Kaydet (Commit) 5b3c66d8 authored tarafından Alberto Avila's avatar Alberto Avila Kaydeden (comit) Tim Graham

[1.8.x] Fixed #26071 -- Fixed crash with __in lookup in a Case expression.

Partial backport of afe0bb7b from master.
üst e625859f
...@@ -92,6 +92,10 @@ class Transform(RegisterLookupMixin): ...@@ -92,6 +92,10 @@ class Transform(RegisterLookupMixin):
bilateral_transforms.append((self.__class__, self.init_lookups)) bilateral_transforms.append((self.__class__, self.init_lookups))
return bilateral_transforms return bilateral_transforms
@cached_property
def contains_aggregate(self):
return self.lhs.contains_aggregate
class Lookup(RegisterLookupMixin): class Lookup(RegisterLookupMixin):
lookup_name = None lookup_name = None
...@@ -194,6 +198,10 @@ class Lookup(RegisterLookupMixin): ...@@ -194,6 +198,10 @@ class Lookup(RegisterLookupMixin):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
raise NotImplementedError raise NotImplementedError
@cached_property
def contains_aggregate(self):
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
class BuiltinLookup(Lookup): class BuiltinLookup(Lookup):
def process_lhs(self, compiler, connection, lhs=None): def process_lhs(self, compiler, connection, lhs=None):
......
...@@ -315,9 +315,9 @@ class WhereNode(tree.Node): ...@@ -315,9 +315,9 @@ class WhereNode(tree.Node):
@classmethod @classmethod
def _contains_aggregate(cls, obj): def _contains_aggregate(cls, obj):
if not isinstance(obj, tree.Node): if isinstance(obj, tree.Node):
return getattr(obj.lhs, 'contains_aggregate', False) or getattr(obj.rhs, 'contains_aggregate', False) return any(cls._contains_aggregate(c) for c in obj.children)
return any(cls._contains_aggregate(c) for c in obj.children) return obj.contains_aggregate
@cached_property @cached_property
def contains_aggregate(self): def contains_aggregate(self):
...@@ -336,6 +336,7 @@ class EverythingNode(object): ...@@ -336,6 +336,7 @@ class EverythingNode(object):
""" """
A node that matches everything. A node that matches everything.
""" """
contains_aggregate = False
def as_sql(self, compiler=None, connection=None): def as_sql(self, compiler=None, connection=None):
return '', [] return '', []
...@@ -345,11 +346,16 @@ class NothingNode(object): ...@@ -345,11 +346,16 @@ class NothingNode(object):
""" """
A node that matches nothing. A node that matches nothing.
""" """
contains_aggregate = False
def as_sql(self, compiler=None, connection=None): def as_sql(self, compiler=None, connection=None):
raise EmptyResultSet raise EmptyResultSet
class ExtraWhere(object): class ExtraWhere(object):
# The contents are a black box - assume no aggregates are used.
contains_aggregate = False
def __init__(self, sqls, params): def __init__(self, sqls, params):
self.sqls = sqls self.sqls = sqls
self.params = params self.params = params
...@@ -410,6 +416,10 @@ class Constraint(object): ...@@ -410,6 +416,10 @@ class Constraint(object):
class SubqueryConstraint(object): class SubqueryConstraint(object):
# Even if aggregates would be used in a subquery, the outer query isn't
# interested about those.
contains_aggregate = False
def __init__(self, alias, columns, targets, query_object): def __init__(self, alias, columns, targets, query_object):
self.alias = alias self.alias = alias
self.columns = columns self.columns = columns
......
...@@ -23,3 +23,6 @@ Bugfixes ...@@ -23,3 +23,6 @@ Bugfixes
``db_index=True`` or ``unique=True`` to a ``CharField`` or ``TextField`` that ``db_index=True`` or ``unique=True`` to a ``CharField`` or ``TextField`` that
already had the other specified, or when removing one of them from a field already had the other specified, or when removing one of them from a field
that had both (:ticket:`26034`). that had both (:ticket:`26034`).
* Fixed a crash when using an ``__in`` lookup inside a ``Case`` expression
(:ticket:`26071`).
...@@ -8,7 +8,7 @@ from uuid import UUID ...@@ -8,7 +8,7 @@ from uuid import UUID
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connection, models from django.db import connection, models
from django.db.models import F, Q, Max, Min, Value from django.db.models import F, Q, Max, Min, Sum, Value
from django.db.models.expressions import Case, When from django.db.models.expressions import Case, When
from django.test import TestCase from django.test import TestCase
from django.utils import six from django.utils import six
...@@ -119,6 +119,17 @@ class CaseExpressionTests(TestCase): ...@@ -119,6 +119,17 @@ class CaseExpressionTests(TestCase):
transform=attrgetter('integer', 'join_test') transform=attrgetter('integer', 'join_test')
) )
def test_annotate_with_in_clause(self):
fk_rels = FKCaseTestModel.objects.filter(integer__in=[5])
self.assertQuerysetEqual(
CaseTestModel.objects.only('pk', 'integer').annotate(in_test=Sum(Case(
When(fk_rel__in=fk_rels, then=F('fk_rel__integer')),
default=Value(0),
))).order_by('pk'),
[(1, 0), (2, 0), (3, 0), (2, 0), (3, 0), (3, 0), (4, 5)],
transform=attrgetter('integer', 'in_test')
)
def test_annotate_with_join_in_condition(self): def test_annotate_with_join_in_condition(self):
self.assertQuerysetEqual( self.assertQuerysetEqual(
CaseTestModel.objects.annotate(join_test=Case( CaseTestModel.objects.annotate(join_test=Case(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment