Kaydet (Commit) bc05547c authored tarafından Simon Charette's avatar Simon Charette Kaydeden (comit) Tim Graham

Fixed #28658 -- Added DISTINCT handling to the Aggregate class.

üst 222caab6
...@@ -11,14 +11,12 @@ __all__ = [ ...@@ -11,14 +11,12 @@ __all__ = [
class ArrayAgg(OrderableAggMixin, Aggregate): class ArrayAgg(OrderableAggMixin, Aggregate):
function = 'ARRAY_AGG' function = 'ARRAY_AGG'
template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)' template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)'
allow_distinct = True
@property @property
def output_field(self): def output_field(self):
return ArrayField(self.source_expressions[0].output_field) return ArrayField(self.source_expressions[0].output_field)
def __init__(self, expression, distinct=False, **extra):
super().__init__(expression, distinct='DISTINCT ' if distinct else '', **extra)
def convert_value(self, value, expression, connection): def convert_value(self, value, expression, connection):
if not value: if not value:
return [] return []
...@@ -54,10 +52,10 @@ class JSONBAgg(Aggregate): ...@@ -54,10 +52,10 @@ class JSONBAgg(Aggregate):
class StringAgg(OrderableAggMixin, Aggregate): class StringAgg(OrderableAggMixin, Aggregate):
function = 'STRING_AGG' function = 'STRING_AGG'
template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s'%(ordering)s)" template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s'%(ordering)s)"
allow_distinct = True
def __init__(self, expression, delimiter, distinct=False, **extra): def __init__(self, expression, delimiter, **extra):
distinct = 'DISTINCT ' if distinct else '' super().__init__(expression, delimiter=delimiter, **extra)
super().__init__(expression, delimiter=delimiter, distinct=distinct, **extra)
def convert_value(self, value, expression, connection): def convert_value(self, value, expression, connection):
if not value: if not value:
......
...@@ -57,6 +57,11 @@ class DatabaseOperations(BaseDatabaseOperations): ...@@ -57,6 +57,11 @@ class DatabaseOperations(BaseDatabaseOperations):
'aggregations on date/time fields in sqlite3 ' 'aggregations on date/time fields in sqlite3 '
'since date/time is saved as text.' 'since date/time is saved as text.'
) )
if isinstance(expression, aggregates.Aggregate) and len(expression.source_expressions) > 1:
raise utils.NotSupportedError(
"SQLite doesn't support DISTINCT on aggregate functions "
"accepting multiple arguments."
)
def date_extract_sql(self, lookup_type, field_name): def date_extract_sql(self, lookup_type, field_name):
""" """
......
...@@ -11,14 +11,19 @@ __all__ = [ ...@@ -11,14 +11,19 @@ __all__ = [
class Aggregate(Func): class Aggregate(Func):
template = '%(function)s(%(distinct)s%(expressions)s)'
contains_aggregate = True contains_aggregate = True
name = None name = None
filter_template = '%s FILTER (WHERE %%(filter)s)' filter_template = '%s FILTER (WHERE %%(filter)s)'
window_compatible = True window_compatible = True
allow_distinct = False
def __init__(self, *args, filter=None, **kwargs): def __init__(self, *expressions, distinct=False, filter=None, **extra):
if distinct and not self.allow_distinct:
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
self.distinct = distinct
self.filter = filter self.filter = filter
super().__init__(*args, **kwargs) super().__init__(*expressions, **extra)
def get_source_fields(self): def get_source_fields(self):
# Don't return the filter expression since it's not a source field. # Don't return the filter expression since it's not a source field.
...@@ -60,6 +65,7 @@ class Aggregate(Func): ...@@ -60,6 +65,7 @@ class Aggregate(Func):
return [] return []
def as_sql(self, compiler, connection, **extra_context): def as_sql(self, compiler, connection, **extra_context):
extra_context['distinct'] = 'DISTINCT' if self.distinct else ''
if self.filter: if self.filter:
if connection.features.supports_aggregate_filter_clause: if connection.features.supports_aggregate_filter_clause:
filter_sql, filter_params = self.filter.as_sql(compiler, connection) filter_sql, filter_params = self.filter.as_sql(compiler, connection)
...@@ -80,8 +86,10 @@ class Aggregate(Func): ...@@ -80,8 +86,10 @@ class Aggregate(Func):
def _get_repr_options(self): def _get_repr_options(self):
options = super()._get_repr_options() options = super()._get_repr_options()
if self.distinct:
options['distinct'] = self.distinct
if self.filter: if self.filter:
options.update({'filter': self.filter}) options['filter'] = self.filter
return options return options
...@@ -114,21 +122,15 @@ class Avg(Aggregate): ...@@ -114,21 +122,15 @@ class Avg(Aggregate):
class Count(Aggregate): class Count(Aggregate):
function = 'COUNT' function = 'COUNT'
name = 'Count' name = 'Count'
template = '%(function)s(%(distinct)s%(expressions)s)'
output_field = IntegerField() output_field = IntegerField()
allow_distinct = True
def __init__(self, expression, distinct=False, filter=None, **extra): def __init__(self, expression, filter=None, **extra):
if expression == '*': if expression == '*':
expression = Star() expression = Star()
if isinstance(expression, Star) and filter is not None: if isinstance(expression, Star) and filter is not None:
raise ValueError('Star cannot be used with filter. Please specify a field.') raise ValueError('Star cannot be used with filter. Please specify a field.')
super().__init__( super().__init__(expression, filter=filter, **extra)
expression, distinct='DISTINCT ' if distinct else '',
filter=filter, **extra
)
def _get_repr_options(self):
return {**super()._get_repr_options(), 'distinct': self.extra['distinct'] != ''}
def convert_value(self, value, expression, connection): def convert_value(self, value, expression, connection):
return 0 if value is None else value return 0 if value is None else value
......
...@@ -373,7 +373,7 @@ some complex computations:: ...@@ -373,7 +373,7 @@ some complex computations::
The ``Aggregate`` API is as follows: The ``Aggregate`` API is as follows:
.. class:: Aggregate(*expressions, output_field=None, filter=None, **extra) .. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, **extra)
.. attribute:: template .. attribute:: template
...@@ -392,6 +392,14 @@ The ``Aggregate`` API is as follows: ...@@ -392,6 +392,14 @@ The ``Aggregate`` API is as follows:
Defaults to ``True`` since most aggregate functions can be used as the Defaults to ``True`` since most aggregate functions can be used as the
source expression in :class:`~django.db.models.expressions.Window`. source expression in :class:`~django.db.models.expressions.Window`.
.. attribute:: allow_distinct
.. versionadded:: 2.2
A class attribute determining whether or not this aggregate function
allows passing a ``distinct`` keyword argument. If set to ``False``
(default), ``TypeError`` is raised if ``distinct=True`` is passed.
The ``expressions`` positional arguments can include expressions or the names The ``expressions`` positional arguments can include expressions or the names
of model fields. They will be converted to a string and used as the of model fields. They will be converted to a string and used as the
``expressions`` placeholder within the ``template``. ``expressions`` placeholder within the ``template``.
...@@ -409,6 +417,11 @@ should define the desired ``output_field``. For example, adding an ...@@ -409,6 +417,11 @@ should define the desired ``output_field``. For example, adding an
``IntegerField()`` and a ``FloatField()`` together should probably have ``IntegerField()`` and a ``FloatField()`` together should probably have
``output_field=FloatField()`` defined. ``output_field=FloatField()`` defined.
The ``distinct`` argument determines whether or not the aggregate function
should be invoked for each distinct value of ``expressions`` (or set of
values, for multiple ``expressions``). The argument is only supported on
aggregates that have :attr:`~Aggregate.allow_distinct` set to ``True``.
The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's
used to filter the rows that are aggregated. See :ref:`conditional-aggregation` used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
and :ref:`filtering-on-annotations` for example usage. and :ref:`filtering-on-annotations` for example usage.
...@@ -416,6 +429,10 @@ and :ref:`filtering-on-annotations` for example usage. ...@@ -416,6 +429,10 @@ and :ref:`filtering-on-annotations` for example usage.
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute. into the ``template`` attribute.
.. versionadded:: 2.2
The ``allow_distinct`` attribute and ``distinct`` argument were added.
Creating your own Aggregate Functions Creating your own Aggregate Functions
------------------------------------- -------------------------------------
......
...@@ -239,6 +239,13 @@ Models ...@@ -239,6 +239,13 @@ Models
* Added SQLite support for the :class:`~django.db.models.StdDev` and * Added SQLite support for the :class:`~django.db.models.StdDev` and
:class:`~django.db.models.Variance` functions. :class:`~django.db.models.Variance` functions.
* The handling of ``DISTINCT`` aggregation is added to the
:class:`~django.db.models.Aggregate` class. Adding :attr:`allow_distinct =
True <django.db.models.Aggregate.allow_distinct>` as a class attribute on
``Aggregate`` subclasses allows a ``distinct`` keyword argument to be
specified on initialization to ensure that the aggregate function is only
called for each distinct value of ``expressions``.
Requests and Responses Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -1026,7 +1026,7 @@ class AggregateTestCase(TestCase): ...@@ -1026,7 +1026,7 @@ class AggregateTestCase(TestCase):
# test completely changing how the output is rendered # test completely changing how the output is rendered
def lower_case_function_override(self, compiler, connection): def lower_case_function_override(self, compiler, connection):
sql, params = compiler.compile(self.source_expressions[0]) sql, params = compiler.compile(self.source_expressions[0])
substitutions = {'function': self.function.lower(), 'expressions': sql} substitutions = {'function': self.function.lower(), 'expressions': sql, 'distinct': ''}
substitutions.update(self.extra) substitutions.update(self.extra)
return self.template % substitutions, params return self.template % substitutions, params
setattr(MySum, 'as_' + connection.vendor, lower_case_function_override) setattr(MySum, 'as_' + connection.vendor, lower_case_function_override)
...@@ -1053,7 +1053,7 @@ class AggregateTestCase(TestCase): ...@@ -1053,7 +1053,7 @@ class AggregateTestCase(TestCase):
# test overriding all parts of the template # test overriding all parts of the template
def be_evil(self, compiler, connection): def be_evil(self, compiler, connection):
substitutions = {'function': 'MAX', 'expressions': '2'} substitutions = {'function': 'MAX', 'expressions': '2', 'distinct': ''}
substitutions.update(self.extra) substitutions.update(self.extra)
return self.template % substitutions, () return self.template % substitutions, ()
setattr(MySum, 'as_' + connection.vendor, be_evil) setattr(MySum, 'as_' + connection.vendor, be_evil)
......
...@@ -11,6 +11,7 @@ from django.db.models import ( ...@@ -11,6 +11,7 @@ from django.db.models import (
Avg, Case, Count, DecimalField, F, IntegerField, Max, Q, StdDev, Sum, Avg, Case, Count, DecimalField, F, IntegerField, Max, Q, StdDev, Sum,
Value, Variance, When, Value, Variance, When,
) )
from django.db.models.aggregates import Aggregate
from django.test import ( from django.test import (
TestCase, ignore_warnings, skipUnlessAnyDBFeature, skipUnlessDBFeature, TestCase, ignore_warnings, skipUnlessAnyDBFeature, skipUnlessDBFeature,
) )
...@@ -1496,6 +1497,16 @@ class AggregationTests(TestCase): ...@@ -1496,6 +1497,16 @@ class AggregationTests(TestCase):
qs = Author.objects.values_list('age', flat=True).annotate(age_count=Count('age')).filter(age_count__gt=1) qs = Author.objects.values_list('age', flat=True).annotate(age_count=Count('age')).filter(age_count__gt=1)
self.assertSequenceEqual(qs, [29]) self.assertSequenceEqual(qs, [29])
def test_allow_distinct(self):
class MyAggregate(Aggregate):
pass
with self.assertRaisesMessage(TypeError, 'MyAggregate does not allow distinct'):
MyAggregate('foo', distinct=True)
class DistinctAggregate(Aggregate):
allow_distinct = True
DistinctAggregate('foo', distinct=True)
class JoinPromotionTests(TestCase): class JoinPromotionTests(TestCase):
def test_ticket_21150(self): def test_ticket_21150(self):
......
...@@ -4,6 +4,7 @@ import unittest ...@@ -4,6 +4,7 @@ import unittest
from django.db import connection, transaction from django.db import connection, transaction
from django.db.models import Avg, StdDev, Sum, Variance from django.db.models import Avg, StdDev, Sum, Variance
from django.db.models.aggregates import Aggregate
from django.db.models.fields import CharField from django.db.models.fields import CharField
from django.db.utils import NotSupportedError from django.db.utils import NotSupportedError
from django.test import ( from django.test import (
...@@ -34,6 +35,17 @@ class Tests(TestCase): ...@@ -34,6 +35,17 @@ class Tests(TestCase):
**{'complex': aggregate('last_modified') + aggregate('last_modified')} **{'complex': aggregate('last_modified') + aggregate('last_modified')}
) )
def test_distinct_aggregation(self):
class DistinctAggregate(Aggregate):
allow_distinct = True
aggregate = DistinctAggregate('first', 'second', distinct=True)
msg = (
"SQLite doesn't support DISTINCT on aggregate functions accepting "
"multiple arguments."
)
with self.assertRaisesMessage(NotSupportedError, msg):
connection.ops.check_expression_support(aggregate)
def test_memory_db_test_name(self): def test_memory_db_test_name(self):
"""A named in-memory db should be allowed where supported.""" """A named in-memory db should be allowed where supported."""
from django.db.backends.sqlite3.base import DatabaseWrapper from django.db.backends.sqlite3.base import DatabaseWrapper
......
...@@ -1481,18 +1481,22 @@ class ReprTests(SimpleTestCase): ...@@ -1481,18 +1481,22 @@ class ReprTests(SimpleTestCase):
def test_aggregates(self): def test_aggregates(self):
self.assertEqual(repr(Avg('a')), "Avg(F(a))") self.assertEqual(repr(Avg('a')), "Avg(F(a))")
self.assertEqual(repr(Count('a')), "Count(F(a), distinct=False)") self.assertEqual(repr(Count('a')), "Count(F(a))")
self.assertEqual(repr(Count('*')), "Count('*', distinct=False)") self.assertEqual(repr(Count('*')), "Count('*')")
self.assertEqual(repr(Max('a')), "Max(F(a))") self.assertEqual(repr(Max('a')), "Max(F(a))")
self.assertEqual(repr(Min('a')), "Min(F(a))") self.assertEqual(repr(Min('a')), "Min(F(a))")
self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)") self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)")
self.assertEqual(repr(Sum('a')), "Sum(F(a))") self.assertEqual(repr(Sum('a')), "Sum(F(a))")
self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)") self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")
def test_distinct_aggregates(self):
self.assertEqual(repr(Count('a', distinct=True)), "Count(F(a), distinct=True)")
self.assertEqual(repr(Count('*', distinct=True)), "Count('*', distinct=True)")
def test_filtered_aggregates(self): def test_filtered_aggregates(self):
filter = Q(a=1) filter = Q(a=1)
self.assertEqual(repr(Avg('a', filter=filter)), "Avg(F(a), filter=(AND: ('a', 1)))") self.assertEqual(repr(Avg('a', filter=filter)), "Avg(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), distinct=False, filter=(AND: ('a', 1)))") self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(repr(Max('a', filter=filter)), "Max(F(a), filter=(AND: ('a', 1)))") self.assertEqual(repr(Max('a', filter=filter)), "Max(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(repr(Min('a', filter=filter)), "Min(F(a), filter=(AND: ('a', 1)))") self.assertEqual(repr(Min('a', filter=filter)), "Min(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(repr(StdDev('a', filter=filter)), "StdDev(F(a), filter=(AND: ('a', 1)), sample=False)") self.assertEqual(repr(StdDev('a', filter=filter)), "StdDev(F(a), filter=(AND: ('a', 1)), sample=False)")
...@@ -1501,6 +1505,9 @@ class ReprTests(SimpleTestCase): ...@@ -1501,6 +1505,9 @@ class ReprTests(SimpleTestCase):
repr(Variance('a', sample=True, filter=filter)), repr(Variance('a', sample=True, filter=filter)),
"Variance(F(a), filter=(AND: ('a', 1)), sample=True)" "Variance(F(a), filter=(AND: ('a', 1)), sample=True)"
) )
self.assertEqual(
repr(Count('a', filter=filter, distinct=True)), "Count(F(a), distinct=True, filter=(AND: ('a', 1)))"
)
class CombinableTests(SimpleTestCase): class CombinableTests(SimpleTestCase):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment