Kaydet (Commit) a6074e89 authored tarafından Simon Charette's avatar Simon Charette

Fixed #26458 -- Based Avg's default output_field resolution on its source field type.

Thanks Tim for the review and Josh for the input.
üst 361cb7a8
...@@ -3,7 +3,7 @@ Classes to represent the definitions of aggregate functions. ...@@ -3,7 +3,7 @@ Classes to represent the definitions of aggregate functions.
""" """
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models.expressions import Func, Star from django.db.models.expressions import Func, Star
from django.db.models.fields import FloatField, IntegerField from django.db.models.fields import DecimalField, FloatField, IntegerField
__all__ = [ __all__ = [
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance', 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
...@@ -41,9 +41,11 @@ class Avg(Aggregate): ...@@ -41,9 +41,11 @@ class Avg(Aggregate):
function = 'AVG' function = 'AVG'
name = 'Avg' name = 'Avg'
def __init__(self, expression, **extra): def _resolve_output_field(self):
output_field = extra.pop('output_field', FloatField()) source_field = self.get_source_fields()[0]
super(Avg, self).__init__(expression, output_field=output_field, **extra) if isinstance(source_field, (IntegerField, DecimalField)):
self._output_field = FloatField()
super(Avg, self)._resolve_output_field()
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection):
if self.output_field.get_internal_type() == 'DurationField': if self.output_field.get_internal_type() == 'DurationField':
......
...@@ -496,10 +496,16 @@ class AggregateTestCase(TestCase): ...@@ -496,10 +496,16 @@ class AggregateTestCase(TestCase):
self.assertEqual(vals, {"num_authors__avg": Approximate(1.66, places=1)}) self.assertEqual(vals, {"num_authors__avg": Approximate(1.66, places=1)})
def test_avg_duration_field(self): def test_avg_duration_field(self):
# Explicit `output_field`.
self.assertEqual( self.assertEqual(
Publisher.objects.aggregate(Avg('duration', output_field=DurationField())), Publisher.objects.aggregate(Avg('duration', output_field=DurationField())),
{'duration__avg': datetime.timedelta(days=1, hours=12)} {'duration__avg': datetime.timedelta(days=1, hours=12)}
) )
# Implicit `output_field`.
self.assertEqual(
Publisher.objects.aggregate(Avg('duration')),
{'duration__avg': datetime.timedelta(days=1, hours=12)}
)
def test_sum_duration_field(self): def test_sum_duration_field(self):
self.assertEqual( self.assertEqual(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment