Kaydet (Commit) a08d2463 authored tarafından Daniel Wiesmann's avatar Daniel Wiesmann Kaydeden (comit) Claude Paroz

Fixed #26112 -- Error when computing aggregate of GIS areas.

Thanks Simon Charette and Claude Paroz for the reviews.
üst 16baec5c
...@@ -117,24 +117,23 @@ class OracleToleranceMixin(object): ...@@ -117,24 +117,23 @@ class OracleToleranceMixin(object):
class Area(OracleToleranceMixin, GeoFunc): class Area(OracleToleranceMixin, GeoFunc):
output_field_class = AreaField
arity = 1 arity = 1
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
if connection.ops.geography: if connection.ops.geography:
# Geography fields support area calculation, returns square meters. self.output_field.area_att = 'sq_m'
self.output_field = AreaField('sq_m')
elif not self.output_field.geodetic(connection):
# Getting the area units of the geographic field.
units = self.output_field.units_name(connection)
if units:
self.output_field = AreaField(
AreaMeasure.unit_attname(self.output_field.units_name(connection))
)
else:
self.output_field = FloatField()
else: else:
# TODO: Do we want to support raw number areas for geodetic fields? # Getting the area units of the geographic field.
raise NotImplementedError('Area on geodetic coordinate systems not supported.') source_fields = self.get_source_fields()
if len(source_fields):
source_field = source_fields[0]
if source_field.geodetic(connection):
# TODO: Do we want to support raw number areas for geodetic fields?
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
units_name = source_field.units_name(connection)
if units_name:
self.output_field.area_att = AreaMeasure.unit_attname(units_name)
return super(Area, self).as_sql(compiler, connection) return super(Area, self).as_sql(compiler, connection)
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection):
......
...@@ -21,13 +21,14 @@ class BaseField(object): ...@@ -21,13 +21,14 @@ class BaseField(object):
class AreaField(BaseField): class AreaField(BaseField):
"Wrapper for Area values." "Wrapper for Area values."
def __init__(self, area_att): def __init__(self, area_att=None):
self.area_att = area_att self.area_att = area_att
def from_db_value(self, value, expression, connection, context): def from_db_value(self, value, expression, connection, context):
if connection.features.interprets_empty_strings_as_nulls and value == '': if connection.features.interprets_empty_strings_as_nulls and value == '':
value = None value = None
if value is not None: # If the units are known, convert value into area measure.
if value is not None and self.area_att:
value = Area(**{self.area_att: value}) value = Area(**{self.area_att: value})
return value return value
......
...@@ -22,6 +22,10 @@ class Country(NamedModel): ...@@ -22,6 +22,10 @@ class Country(NamedModel):
mpoly = models.MultiPolygonField() # SRID, by default, is 4326 mpoly = models.MultiPolygonField() # SRID, by default, is 4326
class CountryWebMercator(NamedModel):
mpoly = models.MultiPolygonField(srid=3857)
class City(NamedModel): class City(NamedModel):
point = models.PointField() point = models.PointField()
......
...@@ -5,12 +5,14 @@ from decimal import Decimal ...@@ -5,12 +5,14 @@ from decimal import Decimal
from django.contrib.gis.db.models import functions from django.contrib.gis.db.models import functions
from django.contrib.gis.geos import LineString, Point, Polygon, fromstr from django.contrib.gis.geos import LineString, Point, Polygon, fromstr
from django.contrib.gis.measure import Area
from django.db import connection from django.db import connection
from django.db.models import Sum
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from django.utils import six from django.utils import six
from ..utils import mysql, oracle, postgis, spatialite from ..utils import mysql, oracle, postgis, spatialite
from .models import City, Country, State, Track from .models import City, Country, CountryWebMercator, State, Track
@skipUnlessDBFeature("gis_enabled") @skipUnlessDBFeature("gis_enabled")
...@@ -231,6 +233,20 @@ class GISFunctionsTests(TestCase): ...@@ -231,6 +233,20 @@ class GISFunctionsTests(TestCase):
expected = c.mpoly.intersection(geom) expected = c.mpoly.intersection(geom)
self.assertEqual(c.inter, expected) self.assertEqual(c.inter, expected)
@skipUnlessDBFeature("has_Area_function")
def test_area_with_regular_aggregate(self):
# Create projected country objects, for this test to work on all backends.
for c in Country.objects.all():
CountryWebMercator.objects.create(name=c.name, mpoly=c.mpoly)
# Test in projected coordinate system
qs = CountryWebMercator.objects.annotate(area_sum=Sum(functions.Area('mpoly')))
for c in qs:
result = c.area_sum
# If the result is a measure object, get value.
if isinstance(result, Area):
result = result.sq_m
self.assertAlmostEqual((result - c.mpoly.area) / c.mpoly.area, 0)
@skipUnlessDBFeature("has_MemSize_function") @skipUnlessDBFeature("has_MemSize_function")
def test_memsize(self): def test_memsize(self):
ptown = City.objects.annotate(size=functions.MemSize('point')).get(name='Pueblo') ptown = City.objects.annotate(size=functions.MemSize('point')).get(name='Pueblo')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment