Kaydet (Commit) b78d100f authored tarafından Tom's avatar Tom Kaydeden (comit) Tim Graham

Fixed #27849 -- Added filtering support to aggregates.

üst 489421b0
...@@ -8,10 +8,10 @@ __all__ = [ ...@@ -8,10 +8,10 @@ __all__ = [
class StatAggregate(Aggregate): class StatAggregate(Aggregate):
def __init__(self, y, x, output_field=FloatField()): def __init__(self, y, x, output_field=FloatField(), filter=None):
if not x or not y: if not x or not y:
raise ValueError('Both y and x must be provided.') raise ValueError('Both y and x must be provided.')
super().__init__(y, x, output_field=output_field) super().__init__(y, x, output_field=output_field, filter=filter)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
return super().resolve_expression(query, allow_joins, reuse, summarize) return super().resolve_expression(query, allow_joins, reuse, summarize)
...@@ -22,9 +22,9 @@ class Corr(StatAggregate): ...@@ -22,9 +22,9 @@ class Corr(StatAggregate):
class CovarPop(StatAggregate): class CovarPop(StatAggregate):
def __init__(self, y, x, sample=False): def __init__(self, y, x, sample=False, filter=None):
self.function = 'COVAR_SAMP' if sample else 'COVAR_POP' self.function = 'COVAR_SAMP' if sample else 'COVAR_POP'
super().__init__(y, x) super().__init__(y, x, filter=filter)
class RegrAvgX(StatAggregate): class RegrAvgX(StatAggregate):
...@@ -38,8 +38,8 @@ class RegrAvgY(StatAggregate): ...@@ -38,8 +38,8 @@ class RegrAvgY(StatAggregate):
class RegrCount(StatAggregate): class RegrCount(StatAggregate):
function = 'REGR_COUNT' function = 'REGR_COUNT'
def __init__(self, y, x): def __init__(self, y, x, filter=None):
super().__init__(y=y, x=x, output_field=IntegerField()) super().__init__(y=y, x=x, output_field=IntegerField(), filter=filter)
def convert_value(self, value, expression, connection): def convert_value(self, value, expression, connection):
if value is None: if value is None:
......
...@@ -229,6 +229,10 @@ class BaseDatabaseFeatures: ...@@ -229,6 +229,10 @@ class BaseDatabaseFeatures:
supports_select_difference = True supports_select_difference = True
supports_slicing_ordering_in_compound = False supports_slicing_ordering_in_compound = False
# Does the database support SQL 2003 FILTER (WHERE ...) in aggregate
# expressions?
supports_aggregate_filter_clause = False
# Does the backend support indexing a TextField? # Does the backend support indexing a TextField?
supports_index_on_text_field = True supports_index_on_text_field = True
......
...@@ -50,6 +50,10 @@ class DatabaseFeatures(BaseDatabaseFeatures): ...@@ -50,6 +50,10 @@ class DatabaseFeatures(BaseDatabaseFeatures):
END; END;
$$ LANGUAGE plpgsql;""" $$ LANGUAGE plpgsql;"""
@cached_property
def supports_aggregate_filter_clause(self):
return self.connection.pg_version >= 90400
@cached_property @cached_property
def has_select_for_update_skip_locked(self): def has_select_for_update_skip_locked(self):
return self.connection.pg_version >= 90500 return self.connection.pg_version >= 90500
......
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
Classes to represent the definitions of aggregate functions. Classes to represent the definitions of aggregate functions.
""" """
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models.expressions import Func, Star from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import DecimalField, FloatField, IntegerField from django.db.models.fields import DecimalField, FloatField, IntegerField
from django.db.models.query_utils import Q
__all__ = [ __all__ = [
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance', 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
...@@ -13,12 +14,36 @@ __all__ = [ ...@@ -13,12 +14,36 @@ __all__ = [
class Aggregate(Func): class Aggregate(Func):
contains_aggregate = True contains_aggregate = True
name = None name = None
filter_template = '%s FILTER (WHERE %%(filter)s)'
def __init__(self, *args, filter=None, **kwargs):
self.filter = filter
super().__init__(*args, **kwargs)
def get_source_fields(self):
# Don't return the filter expression since it's not a source field.
return [e._output_field_or_none for e in super().get_source_expressions()]
def get_source_expressions(self):
source_expressions = super().get_source_expressions()
if self.filter:
source_expressions += [self.filter]
return source_expressions
def set_source_expressions(self, exprs):
if self.filter:
self.filter = exprs.pop()
return super().set_source_expressions(exprs)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
# Aggregates are not allowed in UPDATE queries, so ignore for_save # Aggregates are not allowed in UPDATE queries, so ignore for_save
c = super().resolve_expression(query, allow_joins, reuse, summarize) c = super().resolve_expression(query, allow_joins, reuse, summarize)
if c.filter:
c.filter = c.filter.resolve_expression(query, allow_joins, reuse, summarize)
if not summarize: if not summarize:
expressions = c.get_source_expressions() # Call Aggregate.get_source_expressions() to avoid
# returning self.filter and including that in this loop.
expressions = super(Aggregate, c).get_source_expressions()
for index, expr in enumerate(expressions): for index, expr in enumerate(expressions):
if expr.contains_aggregate: if expr.contains_aggregate:
before_resolved = self.get_source_expressions()[index] before_resolved = self.get_source_expressions()[index]
...@@ -36,6 +61,29 @@ class Aggregate(Func): ...@@ -36,6 +61,29 @@ class Aggregate(Func):
def get_group_by_cols(self): def get_group_by_cols(self):
return [] return []
def as_sql(self, compiler, connection, **extra_context):
if self.filter:
if connection.features.supports_aggregate_filter_clause:
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
template = self.filter_template % extra_context.get('template', self.template)
sql, params = super().as_sql(compiler, connection, template=template, filter=filter_sql)
return sql, params + filter_params
else:
copy = self.copy()
copy.filter = None
condition = When(Q())
source_expressions = copy.get_source_expressions()
condition.set_source_expressions([self.filter, source_expressions[0]])
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
return super(Aggregate, copy).as_sql(compiler, connection, **extra_context)
return super().as_sql(compiler, connection, **extra_context)
def _get_repr_options(self):
options = super()._get_repr_options()
if self.filter:
options.update({'filter': self.filter})
return options
class Avg(Aggregate): class Avg(Aggregate):
function = 'AVG' function = 'AVG'
...@@ -52,7 +100,7 @@ class Avg(Aggregate): ...@@ -52,7 +100,7 @@ class Avg(Aggregate):
expression = self.get_source_expressions()[0] expression = self.get_source_expressions()[0]
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
return compiler.compile( return compiler.compile(
SecondsToInterval(Avg(IntervalToSeconds(expression))) SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter))
) )
return super().as_sql(compiler, connection) return super().as_sql(compiler, connection)
...@@ -62,16 +110,19 @@ class Count(Aggregate): ...@@ -62,16 +110,19 @@ class Count(Aggregate):
name = 'Count' name = 'Count'
template = '%(function)s(%(distinct)s%(expressions)s)' template = '%(function)s(%(distinct)s%(expressions)s)'
def __init__(self, expression, distinct=False, **extra): def __init__(self, expression, distinct=False, filter=None, **extra):
if expression == '*': if expression == '*':
expression = Star() expression = Star()
if isinstance(expression, Star) and filter is not None:
raise ValueError('Star cannot be used with filter. Please specify a field.')
super().__init__( super().__init__(
expression, distinct='DISTINCT ' if distinct else '', expression, distinct='DISTINCT ' if distinct else '',
output_field=IntegerField(), **extra output_field=IntegerField(), filter=filter, **extra
) )
def _get_repr_options(self): def _get_repr_options(self):
return {'distinct': self.extra['distinct'] != ''} options = super()._get_repr_options()
return dict(options, distinct=self.extra['distinct'] != '')
def convert_value(self, value, expression, connection): def convert_value(self, value, expression, connection):
if value is None: if value is None:
...@@ -97,7 +148,8 @@ class StdDev(Aggregate): ...@@ -97,7 +148,8 @@ class StdDev(Aggregate):
super().__init__(expression, output_field=FloatField(), **extra) super().__init__(expression, output_field=FloatField(), **extra)
def _get_repr_options(self): def _get_repr_options(self):
return {'sample': self.function == 'STDDEV_SAMP'} options = super()._get_repr_options()
return dict(options, sample=self.function == 'STDDEV_SAMP')
def convert_value(self, value, expression, connection): def convert_value(self, value, expression, connection):
if value is None: if value is None:
...@@ -127,7 +179,8 @@ class Variance(Aggregate): ...@@ -127,7 +179,8 @@ class Variance(Aggregate):
super().__init__(expression, output_field=FloatField(), **extra) super().__init__(expression, output_field=FloatField(), **extra)
def _get_repr_options(self): def _get_repr_options(self):
return {'sample': self.function == 'VAR_SAMP'} options = super()._get_repr_options()
return dict(options, sample=self.function == 'VAR_SAMP')
def convert_value(self, value, expression, connection): def convert_value(self, value, expression, connection):
if value is None: if value is None:
......
...@@ -22,7 +22,7 @@ General-purpose aggregation functions ...@@ -22,7 +22,7 @@ General-purpose aggregation functions
``ArrayAgg`` ``ArrayAgg``
------------ ------------
.. class:: ArrayAgg(expression, distinct=False, **extra) .. class:: ArrayAgg(expression, distinct=False, filter=None, **extra)
Returns a list of values, including nulls, concatenated into an array. Returns a list of values, including nulls, concatenated into an array.
...@@ -36,7 +36,7 @@ General-purpose aggregation functions ...@@ -36,7 +36,7 @@ General-purpose aggregation functions
``BitAnd`` ``BitAnd``
---------- ----------
.. class:: BitAnd(expression, **extra) .. class:: BitAnd(expression, filter=None, **extra)
Returns an ``int`` of the bitwise ``AND`` of all non-null input values, or Returns an ``int`` of the bitwise ``AND`` of all non-null input values, or
``None`` if all values are null. ``None`` if all values are null.
...@@ -44,7 +44,7 @@ General-purpose aggregation functions ...@@ -44,7 +44,7 @@ General-purpose aggregation functions
``BitOr`` ``BitOr``
--------- ---------
.. class:: BitOr(expression, **extra) .. class:: BitOr(expression, filter=None, **extra)
Returns an ``int`` of the bitwise ``OR`` of all non-null input values, or Returns an ``int`` of the bitwise ``OR`` of all non-null input values, or
``None`` if all values are null. ``None`` if all values are null.
...@@ -52,7 +52,7 @@ General-purpose aggregation functions ...@@ -52,7 +52,7 @@ General-purpose aggregation functions
``BoolAnd`` ``BoolAnd``
----------- -----------
.. class:: BoolAnd(expression, **extra) .. class:: BoolAnd(expression, filter=None, **extra)
Returns ``True``, if all input values are true, ``None`` if all values are Returns ``True``, if all input values are true, ``None`` if all values are
null or if there are no values, otherwise ``False``. null or if there are no values, otherwise ``False``.
...@@ -60,7 +60,7 @@ General-purpose aggregation functions ...@@ -60,7 +60,7 @@ General-purpose aggregation functions
``BoolOr`` ``BoolOr``
---------- ----------
.. class:: BoolOr(expression, **extra) .. class:: BoolOr(expression, filter=None, **extra)
Returns ``True`` if at least one input value is true, ``None`` if all Returns ``True`` if at least one input value is true, ``None`` if all
values are null or if there are no values, otherwise ``False``. values are null or if there are no values, otherwise ``False``.
...@@ -68,7 +68,7 @@ General-purpose aggregation functions ...@@ -68,7 +68,7 @@ General-purpose aggregation functions
``JSONBAgg`` ``JSONBAgg``
------------ ------------
.. class:: JSONBAgg(expressions, **extra) .. class:: JSONBAgg(expressions, filter=None, **extra)
.. versionadded:: 1.11 .. versionadded:: 1.11
...@@ -77,7 +77,7 @@ General-purpose aggregation functions ...@@ -77,7 +77,7 @@ General-purpose aggregation functions
``StringAgg`` ``StringAgg``
------------- -------------
.. class:: StringAgg(expression, delimiter, distinct=False) .. class:: StringAgg(expression, delimiter, distinct=False, filter=None)
Returns the input values concatenated into a string, separated by Returns the input values concatenated into a string, separated by
the ``delimiter`` string. the ``delimiter`` string.
...@@ -105,7 +105,7 @@ field or an expression returning a numeric data. Both are required. ...@@ -105,7 +105,7 @@ field or an expression returning a numeric data. Both are required.
``Corr`` ``Corr``
-------- --------
.. class:: Corr(y, x) .. class:: Corr(y, x, filter=None)
Returns the correlation coefficient as a ``float``, or ``None`` if there Returns the correlation coefficient as a ``float``, or ``None`` if there
aren't any matching rows. aren't any matching rows.
...@@ -113,7 +113,7 @@ field or an expression returning a numeric data. Both are required. ...@@ -113,7 +113,7 @@ field or an expression returning a numeric data. Both are required.
``CovarPop`` ``CovarPop``
------------ ------------
.. class:: CovarPop(y, x, sample=False) .. class:: CovarPop(y, x, sample=False, filter=None)
Returns the population covariance as a ``float``, or ``None`` if there Returns the population covariance as a ``float``, or ``None`` if there
aren't any matching rows. aren't any matching rows.
...@@ -129,7 +129,7 @@ field or an expression returning a numeric data. Both are required. ...@@ -129,7 +129,7 @@ field or an expression returning a numeric data. Both are required.
``RegrAvgX`` ``RegrAvgX``
------------ ------------
.. class:: RegrAvgX(y, x) .. class:: RegrAvgX(y, x, filter=None)
Returns the average of the independent variable (``sum(x)/N``) as a Returns the average of the independent variable (``sum(x)/N``) as a
``float``, or ``None`` if there aren't any matching rows. ``float``, or ``None`` if there aren't any matching rows.
...@@ -137,7 +137,7 @@ field or an expression returning a numeric data. Both are required. ...@@ -137,7 +137,7 @@ field or an expression returning a numeric data. Both are required.
``RegrAvgY`` ``RegrAvgY``
------------ ------------
.. class:: RegrAvgY(y, x) .. class:: RegrAvgY(y, x, filter=None)
Returns the average of the dependent variable (``sum(y)/N``) as a Returns the average of the dependent variable (``sum(y)/N``) as a
``float``, or ``None`` if there aren't any matching rows. ``float``, or ``None`` if there aren't any matching rows.
...@@ -145,7 +145,7 @@ field or an expression returning a numeric data. Both are required. ...@@ -145,7 +145,7 @@ field or an expression returning a numeric data. Both are required.
``RegrCount`` ``RegrCount``
------------- -------------
.. class:: RegrCount(y, x) .. class:: RegrCount(y, x, filter=None)
Returns an ``int`` of the number of input rows in which both expressions Returns an ``int`` of the number of input rows in which both expressions
are not null. are not null.
...@@ -153,7 +153,7 @@ field or an expression returning a numeric data. Both are required. ...@@ -153,7 +153,7 @@ field or an expression returning a numeric data. Both are required.
``RegrIntercept`` ``RegrIntercept``
----------------- -----------------
.. class:: RegrIntercept(y, x) .. class:: RegrIntercept(y, x, filter=None)
Returns the y-intercept of the least-squares-fit linear equation determined Returns the y-intercept of the least-squares-fit linear equation determined
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
...@@ -162,7 +162,7 @@ field or an expression returning a numeric data. Both are required. ...@@ -162,7 +162,7 @@ field or an expression returning a numeric data. Both are required.
``RegrR2`` ``RegrR2``
---------- ----------
.. class:: RegrR2(y, x) .. class:: RegrR2(y, x, filter=None)
Returns the square of the correlation coefficient as a ``float``, or Returns the square of the correlation coefficient as a ``float``, or
``None`` if there aren't any matching rows. ``None`` if there aren't any matching rows.
...@@ -170,7 +170,7 @@ field or an expression returning a numeric data. Both are required. ...@@ -170,7 +170,7 @@ field or an expression returning a numeric data. Both are required.
``RegrSlope`` ``RegrSlope``
------------- -------------
.. class:: RegrSlope(y, x) .. class:: RegrSlope(y, x, filter=None)
Returns the slope of the least-squares-fit linear equation determined Returns the slope of the least-squares-fit linear equation determined
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
...@@ -179,7 +179,7 @@ field or an expression returning a numeric data. Both are required. ...@@ -179,7 +179,7 @@ field or an expression returning a numeric data. Both are required.
``RegrSXX`` ``RegrSXX``
----------- -----------
.. class:: RegrSXX(y, x) .. class:: RegrSXX(y, x, filter=None)
Returns ``sum(x^2) - sum(x)^2/N`` ("sum of squares" of the independent Returns ``sum(x^2) - sum(x)^2/N`` ("sum of squares" of the independent
variable) as a ``float``, or ``None`` if there aren't any matching rows. variable) as a ``float``, or ``None`` if there aren't any matching rows.
...@@ -187,7 +187,7 @@ field or an expression returning a numeric data. Both are required. ...@@ -187,7 +187,7 @@ field or an expression returning a numeric data. Both are required.
``RegrSXY`` ``RegrSXY``
----------- -----------
.. class:: RegrSXY(y, x) .. class:: RegrSXY(y, x, filter=None)
Returns ``sum(x*y) - sum(x) * sum(y)/N`` ("sum of products" of independent Returns ``sum(x*y) - sum(x) * sum(y)/N`` ("sum of products" of independent
times dependent variable) as a ``float``, or ``None`` if there aren't any times dependent variable) as a ``float``, or ``None`` if there aren't any
...@@ -196,7 +196,7 @@ field or an expression returning a numeric data. Both are required. ...@@ -196,7 +196,7 @@ field or an expression returning a numeric data. Both are required.
``RegrSYY`` ``RegrSYY``
----------- -----------
.. class:: RegrSYY(y, x) .. class:: RegrSYY(y, x, filter=None)
Returns ``sum(y^2) - sum(y)^2/N`` ("sum of squares" of the dependent Returns ``sum(y^2) - sum(y)^2/N`` ("sum of squares" of the dependent
variable) as a ``float``, or ``None`` if there aren't any matching rows. variable) as a ``float``, or ``None`` if there aren't any matching rows.
......
...@@ -184,12 +184,14 @@ their registration dates. We can do this using a conditional expression and the ...@@ -184,12 +184,14 @@ their registration dates. We can do this using a conditional expression and the
>>> Client.objects.values_list('name', 'account_type') >>> Client.objects.values_list('name', 'account_type')
<QuerySet [('Jane Doe', 'G'), ('James Smith', 'R'), ('Jack Black', 'P')]> <QuerySet [('Jane Doe', 'G'), ('James Smith', 'R'), ('Jack Black', 'P')]>
.. _conditional-aggregation:
Conditional aggregation Conditional aggregation
----------------------- -----------------------
What if we want to find out how many clients there are for each What if we want to find out how many clients there are for each
``account_type``? We can nest conditional expression within ``account_type``? We can use the ``filter`` argument of :ref:`aggregate
:ref:`aggregate functions <aggregation-functions>` to achieve this:: functions <aggregation-functions>` to achieve this::
>>> # Create some more Clients first so we can have something to count >>> # Create some more Clients first so we can have something to count
>>> Client.objects.create( >>> Client.objects.create(
...@@ -207,17 +209,30 @@ What if we want to find out how many clients there are for each ...@@ -207,17 +209,30 @@ What if we want to find out how many clients there are for each
>>> # Get counts for each value of account_type >>> # Get counts for each value of account_type
>>> from django.db.models import IntegerField, Sum >>> from django.db.models import Count, Q
>>> Client.objects.aggregate( >>> Client.objects.aggregate(
... regular=Sum( ... regular=Count('pk', filter=Q(account_type=Client.REGULAR)),
... Case(When(account_type=Client.REGULAR, then=1), ... gold=Count('pk', filter=Q(account_type=Client.GOLD)),
... output_field=IntegerField()) ... platinum=Count('pk', filter=Q(account_type=Client.PLATINUM)),
... ),
... gold=Sum(
... Case(When(account_type=Client.GOLD, then=1),
... output_field=IntegerField())
... ),
... platinum=Sum(
... Case(When(account_type=Client.PLATINUM, then=1),
... output_field=IntegerField())
... )
... ) ... )
{'regular': 2, 'gold': 1, 'platinum': 3} {'regular': 2, 'gold': 1, 'platinum': 3}
This aggregate produces a query with the SQL 2003 ``FILTER WHERE`` syntax
on databases that support it:
.. code-block:: sql
SELECT count('id') FILTER (WHERE account_type=1) as regular,
count('id') FILTER (WHERE account_type=2) as gold,
count('id') FILTER (WHERE account_type=3) as platinum
FROM clients;
On other databases, this is emulated using a ``CASE`` statement:
.. code-block:: sql
SELECT count(CASE WHEN account_type=1 THEN id ELSE null END) as regular,
count(CASE WHEN account_type=2 THEN id ELSE null END) as gold,
count(CASE WHEN account_type=3 THEN id ELSE null END) as platinum
FROM clients;
The two SQL statements are functionally equivalent but the more explicit
``FILTER`` may perform better.
...@@ -339,7 +339,7 @@ some complex computations:: ...@@ -339,7 +339,7 @@ some complex computations::
The ``Aggregate`` API is as follows: The ``Aggregate`` API is as follows:
.. class:: Aggregate(expression, output_field=None, **extra) .. class:: Aggregate(expression, output_field=None, filter=None, **extra)
.. attribute:: template .. attribute:: template
...@@ -370,9 +370,17 @@ should define the desired ``output_field``. For example, adding an ...@@ -370,9 +370,17 @@ should define the desired ``output_field``. For example, adding an
``IntegerField()`` and a ``FloatField()`` together should probably have ``IntegerField()`` and a ``FloatField()`` together should probably have
``output_field=FloatField()`` defined. ``output_field=FloatField()`` defined.
The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's
used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
and :ref:`filtering-on-annotations` for example usage.
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute. into the ``template`` attribute.
.. versionchanged:: 2.0
The ``filter`` argument was added.
Creating your own Aggregate Functions Creating your own Aggregate Functions
------------------------------------- -------------------------------------
......
...@@ -3085,6 +3085,17 @@ of the return value ...@@ -3085,6 +3085,17 @@ of the return value
``output_field`` if all fields are of the same type. Otherwise, you ``output_field`` if all fields are of the same type. Otherwise, you
must provide the ``output_field`` yourself. must provide the ``output_field`` yourself.
``filter``
~~~~~~~~~~
.. versionadded:: 2.0
An optional :class:`Q object <django.db.models.Q>` that's used to filter the
rows that are aggregated.
See :ref:`conditional-aggregation` and :ref:`filtering-on-annotations` for
example usage.
``**extra`` ``**extra``
~~~~~~~~~~~ ~~~~~~~~~~~
...@@ -3094,7 +3105,7 @@ by the aggregate. ...@@ -3094,7 +3105,7 @@ by the aggregate.
``Avg`` ``Avg``
~~~~~~~ ~~~~~~~
.. class:: Avg(expression, output_field=FloatField(), **extra) .. class:: Avg(expression, output_field=FloatField(), filter=None, **extra)
Returns the mean value of the given expression, which must be numeric Returns the mean value of the given expression, which must be numeric
unless you specify a different ``output_field``. unless you specify a different ``output_field``.
...@@ -3106,7 +3117,7 @@ by the aggregate. ...@@ -3106,7 +3117,7 @@ by the aggregate.
``Count`` ``Count``
~~~~~~~~~ ~~~~~~~~~
.. class:: Count(expression, distinct=False, **extra) .. class:: Count(expression, distinct=False, filter=None, **extra)
Returns the number of objects that are related through the provided Returns the number of objects that are related through the provided
expression. expression.
...@@ -3125,7 +3136,7 @@ by the aggregate. ...@@ -3125,7 +3136,7 @@ by the aggregate.
``Max`` ``Max``
~~~~~~~ ~~~~~~~
.. class:: Max(expression, output_field=None, **extra) .. class:: Max(expression, output_field=None, filter=None, **extra)
Returns the maximum value of the given expression. Returns the maximum value of the given expression.
...@@ -3135,7 +3146,7 @@ by the aggregate. ...@@ -3135,7 +3146,7 @@ by the aggregate.
``Min`` ``Min``
~~~~~~~ ~~~~~~~
.. class:: Min(expression, output_field=None, **extra) .. class:: Min(expression, output_field=None, filter=None, **extra)
Returns the minimum value of the given expression. Returns the minimum value of the given expression.
...@@ -3145,7 +3156,7 @@ by the aggregate. ...@@ -3145,7 +3156,7 @@ by the aggregate.
``StdDev`` ``StdDev``
~~~~~~~~~~ ~~~~~~~~~~
.. class:: StdDev(expression, sample=False, **extra) .. class:: StdDev(expression, sample=False, filter=None, **extra)
Returns the standard deviation of the data in the provided expression. Returns the standard deviation of the data in the provided expression.
...@@ -3169,7 +3180,7 @@ by the aggregate. ...@@ -3169,7 +3180,7 @@ by the aggregate.
``Sum`` ``Sum``
~~~~~~~ ~~~~~~~
.. class:: Sum(expression, output_field=None, **extra) .. class:: Sum(expression, output_field=None, filter=None, **extra)
Computes the sum of all values of the given expression. Computes the sum of all values of the given expression.
...@@ -3179,7 +3190,7 @@ by the aggregate. ...@@ -3179,7 +3190,7 @@ by the aggregate.
``Variance`` ``Variance``
~~~~~~~~~~~~ ~~~~~~~~~~~~
.. class:: Variance(expression, sample=False, **extra) .. class:: Variance(expression, sample=False, filter=None, **extra)
Returns the variance of the data in the provided expression. Returns the variance of the data in the provided expression.
......
...@@ -273,6 +273,10 @@ Models ...@@ -273,6 +273,10 @@ Models
parameters, if the backend supports this feature. Of Django's built-in parameters, if the backend supports this feature. Of Django's built-in
backends, only Oracle supports it. backends, only Oracle supports it.
* The new ``filter`` argument for built-in aggregates allows :ref:`adding
different conditionals <conditional-aggregation>` to multiple aggregations
over the same fields or relations.
Requests and Responses Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -84,6 +84,16 @@ In a hurry? Here's how to do common aggregate queries, assuming the models above ...@@ -84,6 +84,16 @@ In a hurry? Here's how to do common aggregate queries, assuming the models above
>>> pubs[0].num_books >>> pubs[0].num_books
73 73
# Each publisher, with a separate count of books with a rating above and below 5
>>> from django.db.models import Q
>>> above_5 = Count('book', filter=Q(book__rating__gt=5))
>>> below_5 = Count('book', filter=Q(book__rating__lte=5))
>>> pubs = Publisher.objects.annotate(below_5=below_5).annotate(above_5=above_5)
>>> pubs[0].above_5
23
>>> pubs[0].below_5
12
# The top 5 publishers, in order by number of books. # The top 5 publishers, in order by number of books.
>>> pubs = Publisher.objects.annotate(num_books=Count('book')).order_by('-num_books')[:5] >>> pubs = Publisher.objects.annotate(num_books=Count('book')).order_by('-num_books')[:5]
>>> pubs[0].num_books >>> pubs[0].num_books
...@@ -324,6 +334,8 @@ title that starts with "Django" using the query:: ...@@ -324,6 +334,8 @@ title that starts with "Django" using the query::
>>> Book.objects.filter(name__startswith="Django").aggregate(Avg('price')) >>> Book.objects.filter(name__startswith="Django").aggregate(Avg('price'))
.. _filtering-on-annotations:
Filtering on annotations Filtering on annotations
~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~
...@@ -339,6 +351,27 @@ you can issue the query:: ...@@ -339,6 +351,27 @@ you can issue the query::
This query generates an annotated result set, and then generates a filter This query generates an annotated result set, and then generates a filter
based upon that annotation. based upon that annotation.
If you need two annotations with two separate filters you can use the
``filter`` argument with any aggregate. For example, to generate a list of
authors with a count of highly rated books::
>>> highly_rated = Count('books', filter=Q(books__rating__gte=7))
>>> Author.objects.annotate(num_books=Count('books'), highly_rated_books=highly_rated)
Each ``Author`` in the result set will have the ``num_books`` and
``highly_rated_books`` attributes.
.. admonition:: Choosing between ``filter`` and ``QuerySet.filter()``
Avoid using the ``filter`` argument with a single annotation or
aggregation. It's more efficient to use ``QuerySet.filter()`` to exclude
rows. The aggregation ``filter`` argument is only useful when using two or
more aggregations over the same relations with different conditionals.
.. versionchanged:: 2.0
The ``filter`` argument was added to aggregates.
Order of ``annotate()`` and ``filter()`` clauses Order of ``annotate()`` and ``filter()`` clauses
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
import datetime
from decimal import Decimal
from django.db.models import Case, Count, F, Q, Sum, When
from django.test import TestCase
from .models import Author, Book, Publisher
class FilteredAggregateTests(TestCase):
    """Tests for the ``filter`` argument accepted by aggregate expressions."""

    @classmethod
    def setUpTestData(cls):
        """Create three authors, one publisher and three books, then link them."""
        # Author ages are 40, 60 and 100; every name starts with 'test'.
        cls.a1 = Author.objects.create(name='test', age=40)
        cls.a2 = Author.objects.create(name='test2', age=60)
        cls.a3 = Author.objects.create(name='test3', age=100)

        one_day = datetime.timedelta(days=1)
        cls.p1 = Publisher.objects.create(name='Apress', num_awards=3, duration=one_day)

        cls.b1 = Book.objects.create(
            isbn='159059725',
            name='The Definitive Guide to Django: Web Development Done Right',
            pages=447,
            rating=4.5,
            price=Decimal('30.00'),
            contact=cls.a1,
            publisher=cls.p1,
            pubdate=datetime.date(2007, 12, 6),
        )
        cls.b2 = Book.objects.create(
            isbn='067232959',
            name='Sams Teach Yourself Django in 24 Hours',
            pages=528,
            rating=3.0,
            price=Decimal('23.09'),
            contact=cls.a2,
            publisher=cls.p1,
            pubdate=datetime.date(2008, 3, 3),
        )
        cls.b3 = Book.objects.create(
            isbn='159059996',
            name='Practical Django Projects',
            pages=600,
            rating=4.5,
            price=Decimal('29.69'),
            contact=cls.a3,
            publisher=cls.p1,
            pubdate=datetime.date(2008, 6, 23),
        )

        # The first author is friends with the other two.
        for friend in (cls.a2, cls.a3):
            cls.a1.friends.add(friend)

        # Book/author links, added one at a time in a fixed order.
        authorship = (
            (cls.b1, cls.a1),
            (cls.b1, cls.a3),
            (cls.b2, cls.a2),
            (cls.b3, cls.a3),
        )
        for book, author in authorship:
            book.authors.add(author)

    def test_filtered_aggregates(self):
        """A filter matching every author sums all three ages."""
        starts_with_test = Q(name__startswith='test')
        result = Author.objects.aggregate(age=Sum('age', filter=starts_with_test))
        self.assertEqual(result['age'], 200)

    def test_double_filtered_aggregates(self):
        """A Q object wrapping a combined, partly negated condition is accepted."""
        combined = Q(name='test2') & ~Q(name='test')
        result = Author.objects.aggregate(age=Sum('age', filter=Q(combined)))
        self.assertEqual(result['age'], 60)

    def test_excluded_aggregates(self):
        """A negated Q leaves the matching author out of the sum."""
        not_test2 = ~Q(name='test2')
        result = Author.objects.aggregate(age=Sum('age', filter=not_test2))
        self.assertEqual(result['age'], 140)

    def test_related_aggregates_m2m(self):
        """The filter may reference the same m2m relation that is aggregated."""
        friends_age = Sum('friends__age', filter=~Q(friends__name='test'))
        result = Author.objects.filter(name='test').aggregate(age=friends_age)
        self.assertEqual(result['age'], 160)

    def test_related_aggregates_m2m_and_fk(self):
        """The filter may span an m2m relation followed by foreign keys."""
        condition = Q(friends__book__publisher__name='Apress') & ~Q(friends__name='test3')
        friends_pages = Sum('friends__book__pages', filter=condition)
        result = Author.objects.filter(name='test').aggregate(pages=friends_pages)
        self.assertEqual(result['pages'], 528)

    def test_plain_annotate(self):
        """A filtered aggregate used in annotate() yields one value per author."""
        rated_pages = Sum('book__pages', filter=Q(book__rating__gt=3))
        authors = Author.objects.annotate(pages=rated_pages).order_by('pk')
        per_author_pages = [author.pages for author in authors]
        self.assertSequenceEqual(per_author_pages, [447, None, 1047])

    def test_filtered_aggregate_on_annotate(self):
        """An aggregate's filter may refer to an earlier filtered annotation."""
        pages_annotate = Sum('book__pages', filter=Q(book__rating__gt=3))
        age_agg = Sum('age', filter=Q(total_pages__gte=400))
        annotated = Author.objects.annotate(total_pages=pages_annotate)
        aggregated = annotated.aggregate(summed_age=age_agg)
        self.assertEqual(aggregated, {'summed_age': 140})

    def test_case_aggregate(self):
        """A Case/When source expression can be combined with a filter."""
        friends_aged_40 = Case(When(friends__age=40, then=F('friends__age')))
        named_test = Q(friends__name__startswith='test')
        result = Author.objects.aggregate(age=Sum(friends_aged_40, filter=named_test))
        self.assertEqual(result['age'], 80)

    def test_sum_star_exception(self):
        """Count('*') combined with a filter raises ValueError."""
        expected_message = 'Star cannot be used with filter. Please specify a field.'
        with self.assertRaisesMessage(ValueError, expected_message):
            Count('*', filter=Q(age=40))
...@@ -5,7 +5,7 @@ from copy import deepcopy ...@@ -5,7 +5,7 @@ from copy import deepcopy
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import DatabaseError, connection, models, transaction from django.db import DatabaseError, connection, models, transaction
from django.db.models import CharField, TimeField, UUIDField from django.db.models import CharField, Q, TimeField, UUIDField
from django.db.models.aggregates import ( from django.db.models.aggregates import (
Avg, Count, Max, Min, StdDev, Sum, Variance, Avg, Count, Max, Min, StdDev, Sum, Variance,
) )
...@@ -1369,3 +1369,16 @@ class ReprTests(TestCase): ...@@ -1369,3 +1369,16 @@ class ReprTests(TestCase):
self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)") self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)")
self.assertEqual(repr(Sum('a')), "Sum(F(a))") self.assertEqual(repr(Sum('a')), "Sum(F(a))")
self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)") self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")
def test_filtered_aggregates(self):
filter = Q(a=1)
self.assertEqual(repr(Avg('a', filter=filter)), "Avg(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), distinct=False, filter=(AND: ('a', 1)))")
self.assertEqual(repr(Max('a', filter=filter)), "Max(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(repr(Min('a', filter=filter)), "Min(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(repr(StdDev('a', filter=filter)), "StdDev(F(a), filter=(AND: ('a', 1)), sample=False)")
self.assertEqual(repr(Sum('a', filter=filter)), "Sum(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(
repr(Variance('a', sample=True, filter=filter)),
"Variance(F(a), filter=(AND: ('a', 1)), sample=True)"
)
...@@ -1253,6 +1253,15 @@ class CaseDocumentationExamples(TestCase): ...@@ -1253,6 +1253,15 @@ class CaseDocumentationExamples(TestCase):
account_type=Client.PLATINUM, account_type=Client.PLATINUM,
registered_on=date.today(), registered_on=date.today(),
) )
self.assertEqual(
Client.objects.aggregate(
regular=models.Count('pk', filter=Q(account_type=Client.REGULAR)),
gold=models.Count('pk', filter=Q(account_type=Client.GOLD)),
platinum=models.Count('pk', filter=Q(account_type=Client.PLATINUM)),
),
{'regular': 2, 'gold': 1, 'platinum': 3}
)
# This was the example before the filter argument was added.
self.assertEqual( self.assertEqual(
Client.objects.aggregate( Client.objects.aggregate(
regular=models.Sum(Case( regular=models.Sum(Case(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment