Kaydet (Commit) 5b3c66d8 authored tarafından Alberto Avila's avatar Alberto Avila Kaydeden (comit) Tim Graham

[1.8.x] Fixed #26071 -- Fixed crash with __in lookup in a Case expression.

Partial backport of afe0bb7b from master.
üst e625859f
......@@ -92,6 +92,10 @@ class Transform(RegisterLookupMixin):
bilateral_transforms.append((self.__class__, self.init_lookups))
return bilateral_transforms
@cached_property
def contains_aggregate(self):
return self.lhs.contains_aggregate
class Lookup(RegisterLookupMixin):
lookup_name = None
......@@ -194,6 +198,10 @@ class Lookup(RegisterLookupMixin):
def as_sql(self, compiler, connection):
raise NotImplementedError
@cached_property
def contains_aggregate(self):
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
class BuiltinLookup(Lookup):
def process_lhs(self, compiler, connection, lhs=None):
......
......@@ -315,9 +315,9 @@ class WhereNode(tree.Node):
@classmethod
def _contains_aggregate(cls, obj):
if not isinstance(obj, tree.Node):
return getattr(obj.lhs, 'contains_aggregate', False) or getattr(obj.rhs, 'contains_aggregate', False)
return any(cls._contains_aggregate(c) for c in obj.children)
if isinstance(obj, tree.Node):
return any(cls._contains_aggregate(c) for c in obj.children)
return obj.contains_aggregate
@cached_property
def contains_aggregate(self):
......@@ -336,6 +336,7 @@ class EverythingNode(object):
"""
A node that matches everything.
"""
contains_aggregate = False
def as_sql(self, compiler=None, connection=None):
return '', []
......@@ -345,11 +346,16 @@ class NothingNode(object):
"""
A node that matches nothing.
"""
contains_aggregate = False
def as_sql(self, compiler=None, connection=None):
raise EmptyResultSet
class ExtraWhere(object):
# The contents are a black box - assume no aggregates are used.
contains_aggregate = False
def __init__(self, sqls, params):
self.sqls = sqls
self.params = params
......@@ -410,6 +416,10 @@ class Constraint(object):
class SubqueryConstraint(object):
# Even if aggregates would be used in a subquery, the outer query isn't
# interested about those.
contains_aggregate = False
def __init__(self, alias, columns, targets, query_object):
self.alias = alias
self.columns = columns
......
......@@ -23,3 +23,6 @@ Bugfixes
``db_index=True`` or ``unique=True`` to a ``CharField`` or ``TextField`` that
already had the other specified, or when removing one of them from a field
that had both (:ticket:`26034`).
* Fixed a crash when using an ``__in`` lookup inside a ``Case`` expression
(:ticket:`26071`).
......@@ -8,7 +8,7 @@ from uuid import UUID
from django.core.exceptions import FieldError
from django.db import connection, models
from django.db.models import F, Q, Max, Min, Value
from django.db.models import F, Q, Max, Min, Sum, Value
from django.db.models.expressions import Case, When
from django.test import TestCase
from django.utils import six
......@@ -119,6 +119,17 @@ class CaseExpressionTests(TestCase):
transform=attrgetter('integer', 'join_test')
)
def test_annotate_with_in_clause(self):
fk_rels = FKCaseTestModel.objects.filter(integer__in=[5])
self.assertQuerysetEqual(
CaseTestModel.objects.only('pk', 'integer').annotate(in_test=Sum(Case(
When(fk_rel__in=fk_rels, then=F('fk_rel__integer')),
default=Value(0),
))).order_by('pk'),
[(1, 0), (2, 0), (3, 0), (2, 0), (3, 0), (3, 0), (4, 5)],
transform=attrgetter('integer', 'in_test')
)
def test_annotate_with_join_in_condition(self):
self.assertQuerysetEqual(
CaseTestModel.objects.annotate(join_test=Case(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment