Kaydet (Commit) 37d06cfc authored tarafından Claude Paroz's avatar Claude Paroz

Fixed #25499 -- Added the ability to pass an expression in distance lookups

Thanks Bibhas Debnath for the report and Tim Graham for the review.
üst 4a7b5821
...@@ -195,7 +195,7 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): ...@@ -195,7 +195,7 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
""" """
return 'MDSYS.SDO_GEOMETRY' return 'MDSYS.SDO_GEOMETRY'
def get_distance(self, f, value, lookup_type): def get_distance(self, f, value, lookup_type, **kwargs):
""" """
Returns the distance parameters given the value and the lookup type. Returns the distance parameters given the value and the lookup type.
On Oracle, geometry columns with a geodetic coordinate system behave On Oracle, geometry columns with a geodetic coordinate system behave
......
...@@ -34,14 +34,17 @@ class PostGISOperator(SpatialOperator): ...@@ -34,14 +34,17 @@ class PostGISOperator(SpatialOperator):
class PostGISDistanceOperator(PostGISOperator): class PostGISDistanceOperator(PostGISOperator):
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %%s' sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
def as_sql(self, connection, lookup, template_params, sql_params): def as_sql(self, connection, lookup, template_params, sql_params):
if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection): if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection):
sql_template = self.sql_template sql_template = self.sql_template
if len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid': if len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid':
template_params.update({'op': self.op, 'func': 'ST_Distance_Spheroid'}) template_params.update({'op': self.op, 'func': 'ST_Distance_Spheroid'})
sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s) %(op)s %%s' sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s) %(op)s %(value)s'
# Using distance_spheroid requires the spheroid of the field as
# a parameter.
sql_params.insert(1, lookup.lhs.output_field._spheroid)
else: else:
template_params.update({'op': self.op, 'func': 'ST_Distance_Sphere'}) template_params.update({'op': self.op, 'func': 'ST_Distance_Sphere'})
return sql_template % template_params, sql_params return sql_template % template_params, sql_params
...@@ -226,7 +229,7 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): ...@@ -226,7 +229,7 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
geom_type = f.geom_type geom_type = f.geom_type
return 'geometry(%s,%d)' % (geom_type, f.srid) return 'geometry(%s,%d)' % (geom_type, f.srid)
def get_distance(self, f, dist_val, lookup_type): def get_distance(self, f, dist_val, lookup_type, handle_spheroid=True):
""" """
Retrieve the distance parameters for the given geometry field, Retrieve the distance parameters for the given geometry field,
distance lookup value, and the distance lookup type. distance lookup value, and the distance lookup type.
...@@ -236,11 +239,8 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): ...@@ -236,11 +239,8 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
projected geometry columns. In addition, it has to take into account projected geometry columns. In addition, it has to take into account
the geography column type. the geography column type.
""" """
# Getting the distance parameter and any options. # Getting the distance parameter
if len(dist_val) == 1: value = dist_val[0]
value, option = dist_val[0], None
else:
value, option = dist_val
# Shorthand boolean flags. # Shorthand boolean flags.
geodetic = f.geodetic(self.connection) geodetic = f.geodetic(self.connection)
...@@ -260,13 +260,17 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): ...@@ -260,13 +260,17 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
# Assuming the distance is in the units of the field. # Assuming the distance is in the units of the field.
dist_param = value dist_param = value
params = [dist_param]
# handle_spheroid *might* be dropped in Django 2.0 as PostGISDistanceOperator
# also handles it (#25524).
if handle_spheroid and len(dist_val) > 1:
option = dist_val[1]
if (not geography and geodetic and lookup_type != 'dwithin' if (not geography and geodetic and lookup_type != 'dwithin'
and option == 'spheroid'): and option == 'spheroid'):
# using distance_spheroid requires the spheroid of the field as # using distance_spheroid requires the spheroid of the field as
# a parameter. # a parameter.
return [f._spheroid, dist_param] params.insert(0, f._spheroid)
else: return params
return [dist_param]
def get_geom_placeholder(self, f, value, compiler): def get_geom_placeholder(self, f, value, compiler):
""" """
......
...@@ -175,7 +175,7 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): ...@@ -175,7 +175,7 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
""" """
return None return None
def get_distance(self, f, value, lookup_type): def get_distance(self, f, value, lookup_type, **kwargs):
""" """
Returns the distance parameters for the given geometry field, Returns the distance parameters for the given geometry field,
lookup value, and lookup type. SpatiaLite only supports regular lookup value, and lookup type. SpatiaLite only supports regular
......
...@@ -16,6 +16,10 @@ class GISLookup(Lookup): ...@@ -16,6 +16,10 @@ class GISLookup(Lookup):
transform_func = None transform_func = None
distance = False distance = False
def __init__(self, *args, **kwargs):
super(GISLookup, self).__init__(*args, **kwargs)
self.template_params = {}
@classmethod @classmethod
def _check_geo_field(cls, opts, lookup): def _check_geo_field(cls, opts, lookup):
""" """
...@@ -98,7 +102,8 @@ class GISLookup(Lookup): ...@@ -98,7 +102,8 @@ class GISLookup(Lookup):
rhs_sql, rhs_params = self.process_rhs(compiler, connection) rhs_sql, rhs_params = self.process_rhs(compiler, connection)
sql_params.extend(rhs_params) sql_params.extend(rhs_params)
template_params = {'lhs': lhs_sql, 'rhs': rhs_sql} template_params = {'lhs': lhs_sql, 'rhs': rhs_sql, 'value': '%s'}
template_params.update(self.template_params)
rhs_op = self.get_rhs_op(connection, rhs_sql) rhs_op = self.get_rhs_op(connection, rhs_sql)
return rhs_op.as_sql(connection, self, template_params, sql_params) return rhs_op.as_sql(connection, self, template_params, sql_params)
...@@ -302,18 +307,26 @@ gis_lookups['within'] = WithinLookup ...@@ -302,18 +307,26 @@ gis_lookups['within'] = WithinLookup
class DistanceLookupBase(GISLookup): class DistanceLookupBase(GISLookup):
distance = True distance = True
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %%s' sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
def get_db_prep_lookup(self, value, connection): def process_rhs(self, compiler, connection):
if isinstance(value, (tuple, list)): if not isinstance(self.rhs, (tuple, list)) or not 2 <= len(self.rhs) <= 3:
if not 2 <= len(value) <= 3:
raise ValueError("2 or 3-element tuple required for '%s' lookup." % self.lookup_name) raise ValueError("2 or 3-element tuple required for '%s' lookup." % self.lookup_name)
params = [connection.ops.Adapter(value[0])] params = [connection.ops.Adapter(self.rhs[0])]
# Getting the distance parameter in the units of the field. # Getting the distance parameter in the units of the field.
params += connection.ops.get_distance(self.lhs.output_field, value[1:], self.lookup_name) dist_param = self.rhs[1]
return ('%s', params) if hasattr(dist_param, 'resolve_expression'):
dist_param = dist_param.resolve_expression(compiler.query)
sql, expr_params = compiler.compile(dist_param)
self.template_params['value'] = sql
params.extend(expr_params)
else: else:
return super(DistanceLookupBase, self).get_db_prep_lookup(value, connection) params += connection.ops.get_distance(
self.lhs.output_field, (dist_param,) + self.rhs[2:],
self.lookup_name, handle_spheroid=False
)
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler)
return (rhs, params)
class DWithinLookup(DistanceLookupBase): class DWithinLookup(DistanceLookupBase):
......
...@@ -515,14 +515,20 @@ Distance lookups take the following form:: ...@@ -515,14 +515,20 @@ Distance lookups take the following form::
The value passed into a distance lookup is a tuple; the first two The value passed into a distance lookup is a tuple; the first two
values are mandatory, and are the geometry to calculate distances to, values are mandatory, and are the geometry to calculate distances to,
and a distance value (either a number in units of the field or a and a distance value (either a number in units of the field, a
:class:`~django.contrib.gis.measure.Distance` object). On every :class:`~django.contrib.gis.measure.Distance` object, or a `query expression
distance lookup but :lookup:`dwithin`, an optional <ref/models/expressions>`).
With PostGIS, on every distance lookup but :lookup:`dwithin`, an optional
third element, ``'spheroid'``, may be included to tell GeoDjango third element, ``'spheroid'``, may be included to tell GeoDjango
to use the more accurate spheroid distance calculation functions on to use the more accurate spheroid distance calculation functions on
fields with a geodetic coordinate system (e.g., ``ST_Distance_Spheroid`` fields with a geodetic coordinate system (e.g., ``ST_Distance_Spheroid``
would be used instead of ``ST_Distance_Sphere``). would be used instead of ``ST_Distance_Sphere``).
.. versionadded:: 1.10
The ability to pass an expression as the distance value was added.
.. fieldlookup:: distance_gt .. fieldlookup:: distance_gt
distance_gt distance_gt
......
...@@ -59,7 +59,8 @@ Minor features ...@@ -59,7 +59,8 @@ Minor features
:mod:`django.contrib.gis` :mod:`django.contrib.gis`
^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^
* ... * :ref:`Distance lookups <distance-lookups>` now accept expressions as the
distance value parameter.
:mod:`django.contrib.messages` :mod:`django.contrib.messages`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
......
...@@ -21,6 +21,7 @@ class NamedModel(models.Model): ...@@ -21,6 +21,7 @@ class NamedModel(models.Model):
class SouthTexasCity(NamedModel): class SouthTexasCity(NamedModel):
"City model on projected coordinate system for South Texas." "City model on projected coordinate system for South Texas."
point = models.PointField(srid=32140) point = models.PointField(srid=32140)
radius = models.IntegerField(default=10000)
class SouthTexasCityFt(NamedModel): class SouthTexasCityFt(NamedModel):
...@@ -31,6 +32,7 @@ class SouthTexasCityFt(NamedModel): ...@@ -31,6 +32,7 @@ class SouthTexasCityFt(NamedModel):
class AustraliaCity(NamedModel): class AustraliaCity(NamedModel):
"City model for Australia, using WGS84." "City model for Australia, using WGS84."
point = models.PointField() point = models.PointField()
radius = models.IntegerField(default=10000)
class CensusZipcode(NamedModel): class CensusZipcode(NamedModel):
......
...@@ -6,7 +6,7 @@ from django.contrib.gis.db.models.functions import ( ...@@ -6,7 +6,7 @@ from django.contrib.gis.db.models.functions import (
from django.contrib.gis.geos import GEOSGeometry, LineString, Point from django.contrib.gis.geos import GEOSGeometry, LineString, Point
from django.contrib.gis.measure import D # alias for Distance from django.contrib.gis.measure import D # alias for Distance
from django.db import connection from django.db import connection
from django.db.models import Q from django.db.models import F, Q
from django.test import TestCase, ignore_warnings, skipUnlessDBFeature from django.test import TestCase, ignore_warnings, skipUnlessDBFeature
from django.utils.deprecation import RemovedInDjango20Warning from django.utils.deprecation import RemovedInDjango20Warning
...@@ -323,6 +323,31 @@ class DistanceTest(TestCase): ...@@ -323,6 +323,31 @@ class DistanceTest(TestCase):
cities = self.get_names(qs) cities = self.get_names(qs)
self.assertEqual(cities, ['Adelaide', 'Hobart', 'Shellharbour', 'Thirroul']) self.assertEqual(cities, ['Adelaide', 'Hobart', 'Shellharbour', 'Thirroul'])
@skipUnlessDBFeature("supports_distances_lookups")
def test_distance_lookups_with_expression_rhs(self):
qs = SouthTexasCity.objects.filter(
point__distance_lte=(self.stx_pnt, F('radius')),
).order_by('name')
self.assertEqual(
self.get_names(qs),
['Bellaire', 'Downtown Houston', 'Southside Place', 'West University Place']
)
# With a combined expression
qs = SouthTexasCity.objects.filter(
point__distance_lte=(self.stx_pnt, F('radius') * 2),
).order_by('name')
self.assertEqual(len(qs), 5)
self.assertIn('Pearland', self.get_names(qs))
# With spheroid param
if connection.features.supports_distance_geodetic:
hobart = AustraliaCity.objects.get(name='Hobart')
qs = AustraliaCity.objects.filter(
point__distance_lte=(hobart.point, F('radius') * 70, 'spheroid'),
).order_by('name')
self.assertEqual(self.get_names(qs), ['Canberra', 'Hobart', 'Melbourne'])
@skipUnlessDBFeature("has_area_method") @skipUnlessDBFeature("has_area_method")
@ignore_warnings(category=RemovedInDjango20Warning) @ignore_warnings(category=RemovedInDjango20Warning)
def test_area(self): def test_area(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment