Kaydet (Commit) d3f00bd5 authored tarafından Anssi Kääriäinen's avatar Anssi Kääriäinen

Refactored qs.add_q() and utils/tree.py

The sql/query.py add_q method did a lot of where/having tree hacking to
get complex queries to work correctly. The logic was refactored so that
it should be simpler to understand. The new logic should also produce
leaner WHERE conditions.

The changes cascade somewhat, as some other parts of Django (like
add_filter() and WhereNode) expect boolean trees in certain format or
they fail to work. So to fix the add_q() one must fix utils/tree.py,
some things in add_filter(), WhereNode and so on.

This commit also fixed add_filter() to look for negate clauses anywhere
along the path, not just in the immediate parent. A query like
.exclude(Q(reversefk__in=a_list)) didn't work the same way as
.filter(~Q(reversefk__in=a_list)). The reason is that only the
immediate parent negate clauses were seen by add_filter(), and thus a
tree like AND: (NOT AND: (AND: condition)) was not handled correctly,
as there is one intermediary AND node in the tree. The example tree is
generated by .exclude(~Q(reversefk__in=a_list)).

Still, aggregation lost connectors in OR cases, and F() objects and
aggregates in same filter clause caused GROUP BY problems on some
databases.

Fixed #17600, fixed #13198, fixed #17025, fixed #17000, fixed #11293.
üst d744c550
...@@ -32,13 +32,14 @@ class GeoWhereNode(WhereNode): ...@@ -32,13 +32,14 @@ class GeoWhereNode(WhereNode):
Used to represent the SQL where-clause for spatial databases -- Used to represent the SQL where-clause for spatial databases --
these are tied to the GeoQuery class that created it. these are tied to the GeoQuery class that created it.
""" """
def add(self, data, connector):
def _prepare_data(self, data):
if isinstance(data, (list, tuple)): if isinstance(data, (list, tuple)):
obj, lookup_type, value = data obj, lookup_type, value = data
if ( isinstance(obj, Constraint) and if ( isinstance(obj, Constraint) and
isinstance(obj.field, GeometryField) ): isinstance(obj.field, GeometryField) ):
data = (GeoConstraint(obj), lookup_type, value) data = (GeoConstraint(obj), lookup_type, value)
super(GeoWhereNode, self).add(data, connector) return super(GeoWhereNode, self)._prepare_data(data)
def make_atom(self, child, qn, connection): def make_atom(self, child, qn, connection):
lvalue, lookup_type, value_annot, params_or_value = child lvalue, lookup_type, value_annot, params_or_value = child
......
""" """
Classes to represent the definitions of aggregate functions. Classes to represent the definitions of aggregate functions.
""" """
from django.db.models.constants import LOOKUP_SEP
def refs_aggregate(lookup_parts, aggregates):
    """
    Check whether ``lookup_parts`` (a filter path already split on
    LOOKUP_SEP) refers to one of the names in ``aggregates``.

    Because LOOKUP_SEP is contained in the default annotation names
    (e.g. ``book__count``), every non-empty prefix of the lookup path
    must be tested for a match, not just the first part.
    """
    # Start at 1: the empty prefix (i == 0 joins to '') can never be a
    # legitimate annotation name, so testing it was pure wasted work and
    # a latent false positive.
    for i in range(1, len(lookup_parts) + 1):
        if LOOKUP_SEP.join(lookup_parts[0:i]) in aggregates:
            return True
    return False
class Aggregate(object): class Aggregate(object):
""" """
......
...@@ -4,4 +4,3 @@ Constants used across the ORM in general. ...@@ -4,4 +4,3 @@ Constants used across the ORM in general.
# Separator used to split filter strings apart. # Separator used to split filter strings apart.
LOOKUP_SEP = '__' LOOKUP_SEP = '__'
import datetime import datetime
from django.db.models.aggregates import refs_aggregate
from django.db.models.constants import LOOKUP_SEP
from django.utils import tree from django.utils import tree
class ExpressionNode(tree.Node): class ExpressionNode(tree.Node):
...@@ -37,6 +40,18 @@ class ExpressionNode(tree.Node): ...@@ -37,6 +40,18 @@ class ExpressionNode(tree.Node):
obj.add(other, connector) obj.add(other, connector)
return obj return obj
def contains_aggregate(self, existing_aggregates):
    """
    Return True if this expression node (or any of its children)
    references a name in ``existing_aggregates``.
    """
    if not self.children:
        # Leaf node: its name is a lookup path; delegate the prefix
        # matching to refs_aggregate().
        return refs_aggregate(self.name.split(LOOKUP_SEP),
                              existing_aggregates)
    # Interior node: an aggregate anywhere below makes this node
    # aggregate-containing. Children lacking the method (raw values)
    # are ignored.
    for child in self.children:
        if (hasattr(child, 'contains_aggregate')
                and child.contains_aggregate(existing_aggregates)):
            return True
    return False
def prepare_database_save(self, unused):
    # Expression nodes are passed through to the SQL layer unchanged on
    # save; the ``unused`` argument is the field (ignored here).
    return self
################### ###################
# VISITOR METHODS # # VISITOR METHODS #
################### ###################
...@@ -113,9 +128,6 @@ class ExpressionNode(tree.Node): ...@@ -113,9 +128,6 @@ class ExpressionNode(tree.Node):
"Use .bitand() and .bitor() for bitwise logical operations." "Use .bitand() and .bitor() for bitwise logical operations."
) )
def prepare_database_save(self, unused):
return self
class F(ExpressionNode): class F(ExpressionNode):
""" """
An expression representing the value of the given field. An expression representing the value of the given field.
......
...@@ -47,6 +47,7 @@ class Q(tree.Node): ...@@ -47,6 +47,7 @@ class Q(tree.Node):
if not isinstance(other, Q): if not isinstance(other, Q):
raise TypeError(other) raise TypeError(other)
obj = type(self)() obj = type(self)()
obj.connector = conn
obj.add(self, conn) obj.add(self, conn)
obj.add(other, conn) obj.add(other, conn)
return obj return obj
...@@ -63,6 +64,16 @@ class Q(tree.Node): ...@@ -63,6 +64,16 @@ class Q(tree.Node):
obj.negate() obj.negate()
return obj return obj
def clone(self):
    """
    Return a copy of this Q object, recursively cloning any child that
    supports .clone() and sharing the rest.
    """
    copied = self.__class__._new_instance(
        children=[], connector=self.connector, negated=self.negated)
    # Leaf children (e.g. plain (lookup, value) tuples) have no clone()
    # and are carried over as-is.
    copied.children = [
        child.clone() if hasattr(child, 'clone') else child
        for child in self.children
    ]
    return copied
class DeferredAttribute(object): class DeferredAttribute(object):
""" """
A wrapper for a deferred-loading field. When the value is read from this A wrapper for a deferred-loading field. When the value is read from this
......
...@@ -87,6 +87,7 @@ class SQLCompiler(object): ...@@ -87,6 +87,7 @@ class SQLCompiler(object):
where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection) where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection)
having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection) having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection)
having_group_by = self.query.having.get_cols()
params = [] params = []
for val in six.itervalues(self.query.extra_select): for val in six.itervalues(self.query.extra_select):
params.extend(val[1]) params.extend(val[1])
...@@ -107,7 +108,7 @@ class SQLCompiler(object): ...@@ -107,7 +108,7 @@ class SQLCompiler(object):
result.append('WHERE %s' % where) result.append('WHERE %s' % where)
params.extend(w_params) params.extend(w_params)
grouping, gb_params = self.get_grouping(ordering_group_by) grouping, gb_params = self.get_grouping(having_group_by, ordering_group_by)
if grouping: if grouping:
if distinct_fields: if distinct_fields:
raise NotImplementedError( raise NotImplementedError(
...@@ -534,7 +535,7 @@ class SQLCompiler(object): ...@@ -534,7 +535,7 @@ class SQLCompiler(object):
first = False first = False
return result, from_params return result, from_params
def get_grouping(self, ordering_group_by): def get_grouping(self, having_group_by, ordering_group_by):
""" """
Returns a tuple representing the SQL elements in the "group by" clause. Returns a tuple representing the SQL elements in the "group by" clause.
""" """
...@@ -551,7 +552,7 @@ class SQLCompiler(object): ...@@ -551,7 +552,7 @@ class SQLCompiler(object):
] ]
select_cols = [] select_cols = []
seen = set() seen = set()
cols = self.query.group_by + select_cols cols = self.query.group_by + having_group_by + select_cols
for col in cols: for col in cols:
col_params = () col_params = ()
if isinstance(col, (list, tuple)): if isinstance(col, (list, tuple)):
......
...@@ -7,16 +7,14 @@ class SQLEvaluator(object): ...@@ -7,16 +7,14 @@ class SQLEvaluator(object):
def __init__(self, expression, query, allow_joins=True, reuse=None): def __init__(self, expression, query, allow_joins=True, reuse=None):
self.expression = expression self.expression = expression
self.opts = query.get_meta() self.opts = query.get_meta()
self.cols = []
self.contains_aggregate = False
self.reuse = reuse self.reuse = reuse
self.cols = []
self.expression.prepare(self, query, allow_joins) self.expression.prepare(self, query, allow_joins)
def relabeled_clone(self, change_map): def relabeled_clone(self, change_map):
clone = copy.copy(self) clone = copy.copy(self)
clone.cols = [] clone.cols = []
for node, col in self.cols[:]: for node, col in self.cols:
if hasattr(col, 'relabeled_clone'): if hasattr(col, 'relabeled_clone'):
clone.cols.append((node, col.relabeled_clone(change_map))) clone.cols.append((node, col.relabeled_clone(change_map)))
else: else:
...@@ -24,6 +22,15 @@ class SQLEvaluator(object): ...@@ -24,6 +22,15 @@ class SQLEvaluator(object):
(change_map.get(col[0], col[0]), col[1]))) (change_map.get(col[0], col[0]), col[1])))
return clone return clone
def get_cols(self):
    """
    Collect the column references held by this evaluator.

    ``self.cols`` holds (node, col) pairs; a node that itself exposes
    get_cols() is recursed into, otherwise plain tuple columns
    (presumably (alias, column) pairs — see relabeled_clone) are taken
    directly. Non-tuple cols are skipped.
    """
    result = []
    for node, col in self.cols:
        if hasattr(node, 'get_cols'):
            result.extend(node.get_cols())
            continue
        if isinstance(col, tuple):
            result.append(col)
    return result
def prepare(self): def prepare(self):
return self return self
...@@ -44,9 +51,7 @@ class SQLEvaluator(object): ...@@ -44,9 +51,7 @@ class SQLEvaluator(object):
raise FieldError("Joined field references are not permitted in this query") raise FieldError("Joined field references are not permitted in this query")
field_list = node.name.split(LOOKUP_SEP) field_list = node.name.split(LOOKUP_SEP)
if (len(field_list) == 1 and if node.name in query.aggregates:
node.name in query.aggregate_select.keys()):
self.contains_aggregate = True
self.cols.append((node, query.aggregate_select[node.name])) self.cols.append((node, query.aggregate_select[node.name]))
else: else:
try: try:
......
This diff is collapsed.
...@@ -45,7 +45,7 @@ class DeleteQuery(Query): ...@@ -45,7 +45,7 @@ class DeleteQuery(Query):
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
where = self.where_class() where = self.where_class()
where.add((Constraint(None, field.column, field), 'in', where.add((Constraint(None, field.column, field), 'in',
pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]), AND) pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]), AND)
self.do_query(self.model._meta.db_table, where, using=using) self.do_query(self.model._meta.db_table, where, using=using)
def delete_qs(self, query, using): def delete_qs(self, query, using):
...@@ -117,8 +117,8 @@ class UpdateQuery(Query): ...@@ -117,8 +117,8 @@ class UpdateQuery(Query):
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.where = self.where_class() self.where = self.where_class()
self.where.add((Constraint(None, pk_field.column, pk_field), 'in', self.where.add((Constraint(None, pk_field.column, pk_field), 'in',
pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]), pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]),
AND) AND)
self.get_compiler(using).execute_sql(None) self.get_compiler(using).execute_sql(None)
def add_update_values(self, values): def add_update_values(self, values):
......
...@@ -46,18 +46,17 @@ class WhereNode(tree.Node): ...@@ -46,18 +46,17 @@ class WhereNode(tree.Node):
""" """
default = AND default = AND
def add(self, data, connector): def _prepare_data(self, data):
""" """
Add a node to the where-tree. If the data is a list or tuple, it is Prepare data for addition to the tree. If the data is a list or tuple,
expected to be of the form (obj, lookup_type, value), where obj is it is expected to be of the form (obj, lookup_type, value), where obj
a Constraint object, and is then slightly munged before being stored is a Constraint object, and is then slightly munged before being
(to avoid storing any reference to field objects). Otherwise, the 'data' stored (to avoid storing any reference to field objects). Otherwise,
is stored unchanged and can be any class with an 'as_sql()' method. the 'data' is stored unchanged and can be any class with an 'as_sql()'
method.
""" """
if not isinstance(data, (list, tuple)): if not isinstance(data, (list, tuple)):
super(WhereNode, self).add(data, connector) return data
return
obj, lookup_type, value = data obj, lookup_type, value = data
if isinstance(value, collections.Iterator): if isinstance(value, collections.Iterator):
# Consume any generators immediately, so that we can determine # Consume any generators immediately, so that we can determine
...@@ -78,9 +77,7 @@ class WhereNode(tree.Node): ...@@ -78,9 +77,7 @@ class WhereNode(tree.Node):
if hasattr(obj, "prepare"): if hasattr(obj, "prepare"):
value = obj.prepare(lookup_type, value) value = obj.prepare(lookup_type, value)
return (obj, lookup_type, value_annotation, value)
super(WhereNode, self).add(
(obj, lookup_type, value_annotation, value), connector)
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
""" """
...@@ -154,6 +151,18 @@ class WhereNode(tree.Node): ...@@ -154,6 +151,18 @@ class WhereNode(tree.Node):
sql_string = '(%s)' % sql_string sql_string = '(%s)' % sql_string
return sql_string, result_params return sql_string, result_params
def get_cols(self):
    """
    Return the (alias, column) pairs referenced anywhere in this
    where-tree.
    """
    result = []
    for child in self.children:
        if hasattr(child, 'get_cols'):
            # Subtree (another WhereNode-like object): recurse.
            result.extend(child.get_cols())
            continue
        # Leaf: a (constraint, lookup, annotation, value) tuple.
        constraint = child[0]
        if isinstance(constraint, Constraint):
            result.append((constraint.alias, constraint.col))
        # The value slot may itself be an expression (e.g. an
        # SQLEvaluator) carrying further column references.
        if hasattr(child[3], 'get_cols'):
            result.extend(child[3].get_cols())
    return result
def make_atom(self, child, qn, connection): def make_atom(self, child, qn, connection):
""" """
Turn a tuple (Constraint(table_alias, column_name, db_type), Turn a tuple (Constraint(table_alias, column_name, db_type),
...@@ -284,7 +293,6 @@ class WhereNode(tree.Node): ...@@ -284,7 +293,6 @@ class WhereNode(tree.Node):
with empty subtree_parents). Childs must be either (Contraint, lookup, with empty subtree_parents). Childs must be either (Contraint, lookup,
value) tuples, or objects supporting .clone(). value) tuples, or objects supporting .clone().
""" """
assert not self.subtree_parents
clone = self.__class__._new_instance( clone = self.__class__._new_instance(
children=[], connector=self.connector, negated=self.negated) children=[], connector=self.connector, negated=self.negated)
for child in self.children: for child in self.children:
......
...@@ -19,14 +19,9 @@ class Node(object): ...@@ -19,14 +19,9 @@ class Node(object):
""" """
Constructs a new Node. If no connector is given, the default will be Constructs a new Node. If no connector is given, the default will be
used. used.
Warning: You probably don't want to pass in the 'negated' parameter. It
is NOT the same as constructing a node and calling negate() on the
result.
""" """
self.children = children and children[:] or [] self.children = children and children[:] or []
self.connector = connector or self.default self.connector = connector or self.default
self.subtree_parents = []
self.negated = negated self.negated = negated
# We need this because of django.db.models.query_utils.Q. Q. __init__() is # We need this because of django.db.models.query_utils.Q. Q. __init__() is
...@@ -59,7 +54,6 @@ class Node(object): ...@@ -59,7 +54,6 @@ class Node(object):
obj = Node(connector=self.connector, negated=self.negated) obj = Node(connector=self.connector, negated=self.negated)
obj.__class__ = self.__class__ obj.__class__ = self.__class__
obj.children = copy.deepcopy(self.children, memodict) obj.children = copy.deepcopy(self.children, memodict)
obj.subtree_parents = copy.deepcopy(self.subtree_parents, memodict)
return obj return obj
def __len__(self): def __len__(self):
...@@ -83,74 +77,60 @@ class Node(object): ...@@ -83,74 +77,60 @@ class Node(object):
""" """
return other in self.children return other in self.children
def add(self, node, conn_type): def _prepare_data(self, data):
""" """
Adds a new node to the tree. If the conn_type is the same as the root's A subclass hook for doing subclass specific transformations of the
current connector type, the node is added to the first level. given data on combine() or add().
Otherwise, the whole tree is pushed down one level and a new root
connector is created, connecting the existing tree and the new node.
""" """
if node in self.children and conn_type == self.connector: return data
return
if len(self.children) < 2:
self.connector = conn_type
if self.connector == conn_type:
if isinstance(node, Node) and (node.connector == conn_type or
len(node) == 1):
self.children.extend(node.children)
else:
self.children.append(node)
else:
obj = self._new_instance(self.children, self.connector,
self.negated)
self.connector = conn_type
self.children = [obj, node]
def negate(self): def add(self, data, conn_type, squash=True):
""" """
Negate the sense of the root connector. This reorganises the children Combines this tree and the data represented by data using the
so that the current node has a single child: a negated node containing connector conn_type. The combine is done by squashing the node other
all the previous children. This slightly odd construction makes adding away if possible.
new children behave more intuitively.
Interpreting the meaning of this negate is up to client code. This This tree (self) will never be pushed to a child node of the
method is useful for implementing "not" arrangements. combined tree, nor will the connector or negated properties change.
"""
self.children = [self._new_instance(self.children, self.connector,
not self.negated)]
self.connector = self.default
def start_subtree(self, conn_type): The function returns a node which can be used in place of data
""" regardless if the node other got squashed or not.
Sets up internal state so that new nodes are added to a subtree of the
current node. The conn_type specifies how the sub-tree is joined to the If `squash` is False the data is prepared and added as a child to
existing children. this tree without further logic.
""" """
if len(self.children) == 1: if data in self.children:
self.connector = conn_type return data
elif self.connector != conn_type: data = self._prepare_data(data)
self.children = [self._new_instance(self.children, self.connector, if not squash:
self.negated)] self.children.append(data)
return data
if self.connector == conn_type:
# We can reuse self.children to append or squash the node other.
if (isinstance(data, Node) and not data.negated
and (data.connector == conn_type or len(data) == 1)):
# We can squash the other node's children directly into this
# node. We are just doing (AB)(CD) == (ABCD) here, with the
# addition that if the length of the other node is 1 the
# connector doesn't matter. However, for the len(self) == 1
# case we don't want to do the squashing, as it would alter
# self.connector.
self.children.extend(data.children)
return self
else:
# We could use perhaps additional logic here to see if some
# children could be used for pushdown here.
self.children.append(data)
return data
else:
obj = self._new_instance(self.children, self.connector,
self.negated)
self.connector = conn_type self.connector = conn_type
self.negated = False self.children = [obj, data]
return data
self.subtree_parents.append(self.__class__(self.children,
self.connector, self.negated))
self.connector = self.default
self.negated = False
self.children = []
def end_subtree(self): def negate(self):
""" """
Closes off the most recently unmatched start_subtree() call. Negate the sense of the root connector.
This puts the current state into a node of the parent tree and returns
the current instances state to be the parent.
""" """
obj = self.subtree_parents.pop() self.negated = not self.negated
node = self.__class__(self.children, self.connector)
self.connector = obj.connector
self.negated = obj.negated
self.children = obj.children
self.children.append(node)
...@@ -10,6 +10,7 @@ from django.contrib.contenttypes.models import ContentType ...@@ -10,6 +10,7 @@ from django.contrib.contenttypes.models import ContentType
from django.db.models import Count, Max, Avg, Sum, StdDev, Variance, F, Q from django.db.models import Count, Max, Avg, Sum, StdDev, Variance, F, Q
from django.test import TestCase, Approximate, skipUnlessDBFeature from django.test import TestCase, Approximate, skipUnlessDBFeature
from django.utils import six from django.utils import six
from django.utils.unittest import expectedFailure
from .models import (Author, Book, Publisher, Clues, Entries, HardbackBook, from .models import (Author, Book, Publisher, Clues, Entries, HardbackBook,
ItemTag, WithManualPK) ItemTag, WithManualPK)
...@@ -472,7 +473,7 @@ class AggregationTests(TestCase): ...@@ -472,7 +473,7 @@ class AggregationTests(TestCase):
# Regression for #15709 - Ensure each group_by field only exists once # Regression for #15709 - Ensure each group_by field only exists once
# per query # per query
qs = Book.objects.values('publisher').annotate(max_pages=Max('pages')).order_by() qs = Book.objects.values('publisher').annotate(max_pages=Max('pages')).order_by()
grouping, gb_params = qs.query.get_compiler(qs.db).get_grouping([]) grouping, gb_params = qs.query.get_compiler(qs.db).get_grouping([], [])
self.assertEqual(len(grouping), 1) self.assertEqual(len(grouping), 1)
def test_duplicate_alias(self): def test_duplicate_alias(self):
...@@ -847,14 +848,14 @@ class AggregationTests(TestCase): ...@@ -847,14 +848,14 @@ class AggregationTests(TestCase):
# The name of the explicitly provided annotation name in this case # The name of the explicitly provided annotation name in this case
# poses no problem # poses no problem
qs = Author.objects.annotate(book_cnt=Count('book')).filter(book_cnt=2) qs = Author.objects.annotate(book_cnt=Count('book')).filter(book_cnt=2).order_by('name')
self.assertQuerysetEqual( self.assertQuerysetEqual(
qs, qs,
['Peter Norvig'], ['Peter Norvig'],
lambda b: b.name lambda b: b.name
) )
# Neither in this case # Neither in this case
qs = Author.objects.annotate(book_count=Count('book')).filter(book_count=2) qs = Author.objects.annotate(book_count=Count('book')).filter(book_count=2).order_by('name')
self.assertQuerysetEqual( self.assertQuerysetEqual(
qs, qs,
['Peter Norvig'], ['Peter Norvig'],
...@@ -862,7 +863,7 @@ class AggregationTests(TestCase): ...@@ -862,7 +863,7 @@ class AggregationTests(TestCase):
) )
# This case used to fail because the ORM couldn't resolve the # This case used to fail because the ORM couldn't resolve the
# automatically generated annotation name `book__count` # automatically generated annotation name `book__count`
qs = Author.objects.annotate(Count('book')).filter(book__count=2) qs = Author.objects.annotate(Count('book')).filter(book__count=2).order_by('name')
self.assertQuerysetEqual( self.assertQuerysetEqual(
qs, qs,
['Peter Norvig'], ['Peter Norvig'],
...@@ -1020,3 +1021,83 @@ class AggregationTests(TestCase): ...@@ -1020,3 +1021,83 @@ class AggregationTests(TestCase):
('The Definitive Guide to Django: Web Development Done Right', 0) ('The Definitive Guide to Django: Web Development Done Right', 0)
] ]
) )
def test_negated_aggregation(self):
    # Excluding on an annotated aggregate must match the equivalent
    # pk__in exclusion built from a filtered subquery.
    expected_results = Author.objects.exclude(
        pk__in=Author.objects.annotate(book_cnt=Count('book')).filter(book_cnt=2)
    ).order_by('name')
    expected_results = [a.name for a in expected_results]
    # ANDed duplicate conditions inside exclude().
    qs = Author.objects.annotate(book_cnt=Count('book')).exclude(
        Q(book_cnt=2), Q(book_cnt=2)).order_by('name')
    self.assertQuerysetEqual(
        qs,
        expected_results,
        lambda b: b.name
    )
    expected_results = Author.objects.exclude(
        pk__in=Author.objects.annotate(book_cnt=Count('book')).filter(book_cnt=2)
    ).order_by('name')
    expected_results = [a.name for a in expected_results]
    # Same check with the conditions ORed together.
    qs = Author.objects.annotate(book_cnt=Count('book')).exclude(Q(book_cnt=2)|Q(book_cnt=2)).order_by('name')
    self.assertQuerysetEqual(
        qs,
        expected_results,
        lambda b: b.name
    )
def test_name_filters(self):
    # A Q object mixing an aggregate condition (HAVING) with a plain
    # field condition (WHERE) must return the union of both matches.
    qs = Author.objects.annotate(Count('book')).filter(
        Q(book__count__exact=2)|Q(name='Adrian Holovaty')
    ).order_by('name')
    self.assertQuerysetEqual(
        qs,
        ['Adrian Holovaty', 'Peter Norvig'],
        lambda b: b.name
    )
def test_name_expressions(self):
    # Test that aggregates are spotted correctly from F objects.
    # Note that Adrian's age is 34 in the fixtures, and he has one book
    # so both conditions match one author.
    qs = Author.objects.annotate(Count('book')).filter(
        Q(name='Peter Norvig')|Q(age=F('book__count') + 33)
    ).order_by('name')
    self.assertQuerysetEqual(
        qs,
        ['Adrian Holovaty', 'Peter Norvig'],
        lambda b: b.name
    )
def test_ticket_11293(self):
q1 = Q(price__gt=50)
q2 = Q(authors__count__gt=1)
query = Book.objects.annotate(Count('authors')).filter(
q1 | q2).order_by('pk')
self.assertQuerysetEqual(
query, [1, 4, 5, 6],
lambda b: b.pk)
def test_ticket_11293_q_immutable(self):
    """
    Check that splitting a q object to parts for where/having doesn't alter
    the original q-object.
    """
    q1 = Q(isbn='')
    q2 = Q(authors__count__gt=1)
    query = Book.objects.annotate(Count('authors'))
    query.filter(q1 | q2)
    # q2 must still hold exactly its single original condition.
    self.assertEqual(len(q2.children), 1)
def test_fobj_group_by(self):
    """
    Check that an F() object referring to related column works correctly
    in group by.
    """
    qs = Book.objects.annotate(
        acount=Count('authors')
    ).filter(
        # F() in a HAVING condition pulls publisher__num_awards into
        # the GROUP BY clause.
        acount=F('publisher__num_awards')
    )
    self.assertQuerysetEqual(
        qs, ['Sams Teach Yourself Django in 24 Hours'],
        lambda b: b.name)
...@@ -475,3 +475,25 @@ class MyObject(models.Model): ...@@ -475,3 +475,25 @@ class MyObject(models.Model):
parent = models.ForeignKey('self', null=True, blank=True, related_name='children') parent = models.ForeignKey('self', null=True, blank=True, related_name='children')
data = models.CharField(max_length=100) data = models.CharField(max_length=100)
created_at = models.DateTimeField(auto_now_add=True) created_at = models.DateTimeField(auto_now_add=True)
# Models for #17600 regressions
@python_2_unicode_compatible
class Order(models.Model):
    # Explicit integer pk so the #17600 regression tests can create
    # orders with known, stable ids.
    id = models.IntegerField(primary_key=True)

    class Meta:
        ordering = ('pk', )

    def __str__(self):
        return '%s' % self.pk
@python_2_unicode_compatible
class OrderItem(models.Model):
    # Line item belonging to an Order; reverse accessor is
    # order.items, which the exclude() regression tests traverse.
    order = models.ForeignKey(Order, related_name='items')
    status = models.IntegerField()

    class Meta:
        ordering = ('pk', )

    def __str__(self):
        return '%s' % self.pk
...@@ -23,9 +23,9 @@ from .models import (Annotation, Article, Author, Celebrity, Child, Cover, ...@@ -23,9 +23,9 @@ from .models import (Annotation, Article, Author, Celebrity, Child, Cover,
Ranking, Related, Report, ReservedName, Tag, TvChef, Valid, X, Food, Eaten, Ranking, Related, Report, ReservedName, Tag, TvChef, Valid, X, Food, Eaten,
Node, ObjectA, ObjectB, ObjectC, CategoryItem, SimpleCategory, Node, ObjectA, ObjectB, ObjectC, CategoryItem, SimpleCategory,
SpecialCategory, OneToOneCategory, NullableName, ProxyCategory, SpecialCategory, OneToOneCategory, NullableName, ProxyCategory,
SingleObject, RelatedObject, ModelA, ModelD, Responsibility, Job, SingleObject, RelatedObject, ModelA, ModelB, ModelC, ModelD, Responsibility,
JobResponsibilities, BaseA, Identifier, Program, Channel, Page, Paragraph, Job, JobResponsibilities, BaseA, Identifier, Program, Channel, Page,
Chapter, Book, MyObject) Paragraph, Chapter, Book, MyObject, Order, OrderItem)
class BaseQuerysetTest(TestCase): class BaseQuerysetTest(TestCase):
...@@ -834,7 +834,6 @@ class Queries1Tests(BaseQuerysetTest): ...@@ -834,7 +834,6 @@ class Queries1Tests(BaseQuerysetTest):
Note.objects.filter(Q(extrainfo__author=self.a1)|Q(extrainfo=xx)), Note.objects.filter(Q(extrainfo__author=self.a1)|Q(extrainfo=xx)),
['<Note: n1>', '<Note: n3>'] ['<Note: n1>', '<Note: n3>']
) )
xx.delete()
q = Note.objects.filter(Q(extrainfo__author=self.a1)|Q(extrainfo=xx)).query q = Note.objects.filter(Q(extrainfo__author=self.a1)|Q(extrainfo=xx)).query
self.assertEqual( self.assertEqual(
len([x[2] for x in q.alias_map.values() if x[2] == q.LOUTER and q.alias_refcount[x[1]]]), len([x[2] for x in q.alias_map.values() if x[2] == q.LOUTER and q.alias_refcount[x[1]]]),
...@@ -880,7 +879,6 @@ class Queries1Tests(BaseQuerysetTest): ...@@ -880,7 +879,6 @@ class Queries1Tests(BaseQuerysetTest):
Item.objects.filter(Q(tags__name='t4')), Item.objects.filter(Q(tags__name='t4')),
[repr(i) for i in Item.objects.filter(~Q(~Q(tags__name='t4')))]) [repr(i) for i in Item.objects.filter(~Q(~Q(tags__name='t4')))])
@unittest.expectedFailure
def test_exclude_in(self): def test_exclude_in(self):
self.assertQuerysetEqual( self.assertQuerysetEqual(
Item.objects.exclude(Q(tags__name__in=['t4', 't3'])), Item.objects.exclude(Q(tags__name__in=['t4', 't3'])),
...@@ -2291,6 +2289,103 @@ class ExcludeTest(TestCase): ...@@ -2291,6 +2289,103 @@ class ExcludeTest(TestCase):
Responsibility.objects.exclude(jobs__name='Manager'), Responsibility.objects.exclude(jobs__name='Manager'),
['<Responsibility: Programming>']) ['<Responsibility: Programming>'])
class ExcludeTest17600(TestCase):
    """
    Some regression tests for ticket #17600. Some of these likely duplicate
    other existing tests.
    """
    def setUp(self):
        # Create a few Orders.
        self.o1 = Order.objects.create(pk=1)
        self.o2 = Order.objects.create(pk=2)
        self.o3 = Order.objects.create(pk=3)

        # Create some OrderItems for the first order with homogeneous
        # status_id values
        self.oi1 = OrderItem.objects.create(order=self.o1, status=1)
        self.oi2 = OrderItem.objects.create(order=self.o1, status=1)
        self.oi3 = OrderItem.objects.create(order=self.o1, status=1)

        # Create some OrderItems for the second order with heterogeneous
        # status_id values
        self.oi4 = OrderItem.objects.create(order=self.o2, status=1)
        self.oi5 = OrderItem.objects.create(order=self.o2, status=2)
        self.oi6 = OrderItem.objects.create(order=self.o2, status=3)

        # Create some OrderItems for the third order with heterogeneous
        # status_id values
        self.oi7 = OrderItem.objects.create(order=self.o3, status=2)
        self.oi8 = OrderItem.objects.create(order=self.o3, status=3)
        self.oi9 = OrderItem.objects.create(order=self.o3, status=4)

    def test_exclude_plain(self):
        """
        This should exclude Orders which have some items with status 1
        """
        self.assertQuerysetEqual(
            Order.objects.exclude(items__status=1),
            ['<Order: 3>'])

    def test_exclude_plain_distinct(self):
        """
        This should exclude Orders which have some items with status 1
        """
        self.assertQuerysetEqual(
            Order.objects.exclude(items__status=1).distinct(),
            ['<Order: 3>'])

    def test_exclude_with_q_object_distinct(self):
        """
        This should exclude Orders which have some items with status 1
        """
        self.assertQuerysetEqual(
            Order.objects.exclude(Q(items__status=1)).distinct(),
            ['<Order: 3>'])

    def test_exclude_with_q_object_no_distinct(self):
        """
        This should exclude Orders which have some items with status 1
        """
        self.assertQuerysetEqual(
            Order.objects.exclude(Q(items__status=1)),
            ['<Order: 3>'])

    def test_exclude_with_q_is_equal_to_plain_exclude(self):
        """
        Using exclude(condition) and exclude(Q(condition)) should
        yield the same QuerySet
        """
        self.assertEqual(
            list(Order.objects.exclude(items__status=1).distinct()),
            list(Order.objects.exclude(Q(items__status=1)).distinct()))

    def test_exclude_with_q_is_equal_to_plain_exclude_variation(self):
        """
        Using exclude(condition) and exclude(Q(condition)) should
        yield the same QuerySet
        """
        self.assertEqual(
            list(Order.objects.exclude(items__status=1)),
            list(Order.objects.exclude(Q(items__status=1)).distinct()))

    # Known limitation: exclude(~Q(...)) over a multi-valued relation
    # would need nested subqueries to express "ALL items match".
    @unittest.expectedFailure
    def test_only_orders_with_all_items_having_status_1(self):
        """
        This should only return orders having ALL items set to status 1, or
        those items not having any orders at all. The correct way to write
        this query in SQL seems to be using two nested subqueries.
        """
        self.assertQuerysetEqual(
            Order.objects.exclude(~Q(items__status=1)).distinct(),
            ['<Order: 1>'])
class NullInExcludeTest(TestCase): class NullInExcludeTest(TestCase):
def setUp(self): def setUp(self):
NullableName.objects.create(name='i1') NullableName.objects.create(name='i1')
...@@ -2326,6 +2421,14 @@ class NullInExcludeTest(TestCase): ...@@ -2326,6 +2421,14 @@ class NullInExcludeTest(TestCase):
NullableName.objects.exclude(name__in=[None]), NullableName.objects.exclude(name__in=[None]),
['i1'], attrgetter('name')) ['i1'], attrgetter('name'))
def test_double_exclude(self):
    # Double negation must collapse: ~~Q(x) behaves exactly like Q(x)
    # and must not leave an extra IS NOT NULL clause in the SQL.
    self.assertEqual(
        list(NullableName.objects.filter(~~Q(name='i1'))),
        list(NullableName.objects.filter(Q(name='i1'))))
    self.assertNotIn(
        'IS NOT NULL',
        str(NullableName.objects.filter(~~Q(name='i1')).query))
class EmptyStringsAsNullTest(TestCase): class EmptyStringsAsNullTest(TestCase):
""" """
Test that filtering on non-null character fields works as expected. Test that filtering on non-null character fields works as expected.
...@@ -2433,8 +2536,12 @@ class WhereNodeTest(TestCase): ...@@ -2433,8 +2536,12 @@ class WhereNodeTest(TestCase):
class NullJoinPromotionOrTest(TestCase): class NullJoinPromotionOrTest(TestCase):
def setUp(self): def setUp(self):
d = ModelD.objects.create(name='foo') self.d1 = ModelD.objects.create(name='foo')
ModelA.objects.create(name='bar', d=d) d2 = ModelD.objects.create(name='bar')
self.a1 = ModelA.objects.create(name='a1', d=self.d1)
c = ModelC.objects.create(name='c')
b = ModelB.objects.create(name='b', c=c)
self.a2 = ModelA.objects.create(name='a2', b=b, d=d2)
def test_ticket_17886(self): def test_ticket_17886(self):
# The first Q-object is generating the match, the rest of the filters # The first Q-object is generating the match, the rest of the filters
...@@ -2448,12 +2555,38 @@ class NullJoinPromotionOrTest(TestCase): ...@@ -2448,12 +2555,38 @@ class NullJoinPromotionOrTest(TestCase):
Q(b__c__name='foo') Q(b__c__name='foo')
) )
qset = ModelA.objects.filter(q_obj) qset = ModelA.objects.filter(q_obj)
self.assertEqual(len(qset), 1) self.assertEqual(list(qset), [self.a1])
# We generate one INNER JOIN to D. The join is direct and not nullable # We generate one INNER JOIN to D. The join is direct and not nullable
# so we can use INNER JOIN for it. However, we can NOT use INNER JOIN # so we can use INNER JOIN for it. However, we can NOT use INNER JOIN
# for the b->c join, as a->b is nullable. # for the b->c join, as a->b is nullable.
self.assertEqual(str(qset.query).count('INNER JOIN'), 1) self.assertEqual(str(qset.query).count('INNER JOIN'), 1)
def test_isnull_filter_promotion(self):
    # isnull=True over a nullable relation needs a LEFT OUTER join.
    qs = ModelA.objects.filter(Q(b__name__isnull=True))
    self.assertEqual(str(qs.query).count('LEFT OUTER'), 1)
    self.assertEqual(list(qs), [self.a1])
    # Negating it means the related row must exist: INNER JOIN.
    qs = ModelA.objects.filter(~Q(b__name__isnull=True))
    self.assertEqual(str(qs.query).count('INNER JOIN'), 1)
    self.assertEqual(list(qs), [self.a2])
    # Double negation restores the original join promotion.
    qs = ModelA.objects.filter(~~Q(b__name__isnull=True))
    self.assertEqual(str(qs.query).count('LEFT OUTER'), 1)
    self.assertEqual(list(qs), [self.a1])
    # Mirror cases with isnull=False.
    qs = ModelA.objects.filter(Q(b__name__isnull=False))
    self.assertEqual(str(qs.query).count('INNER JOIN'), 1)
    self.assertEqual(list(qs), [self.a2])
    qs = ModelA.objects.filter(~Q(b__name__isnull=False))
    self.assertEqual(str(qs.query).count('LEFT OUTER'), 1)
    self.assertEqual(list(qs), [self.a1])
    qs = ModelA.objects.filter(~~Q(b__name__isnull=False))
    self.assertEqual(str(qs.query).count('INNER JOIN'), 1)
    self.assertEqual(list(qs), [self.a2])
class ReverseJoinTrimmingTest(TestCase): class ReverseJoinTrimmingTest(TestCase):
def test_reverse_trimming(self): def test_reverse_trimming(self):
# Check that we don't accidentally trim reverse joins - we can't know # Check that we don't accidentally trim reverse joins - we can't know
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment