Kaydet (Commit) 69597e5b authored tarafından Anssi Kääriäinen's avatar Anssi Kääriäinen

Fixed #10790 -- Refactored sql.Query.setup_joins()

This is a rather large refactoring. The "lookup traversal" code was
splitted out from the setup_joins. There is now names_to_path() method
which does the lookup traveling, the actual work of setup_joins() is
calling names_to_path() and then adding the joins found into the query.

As a side effect it was possible to remove the "process_extra"
functionality used by genric relations. This never worked for left
joins. Now the extra restriction is appended directly to the join
condition instead of the where clause.

To generate the extra condition we need to have the join field
available in the compiler. This has the side-effect that we need more
ugly code in Query.__getstate__ and __setstate__ as Field objects
aren't pickleable.

The join trimming code got a big change - now we trim all direct joins
and never trim reverse joins. This also fixes the problem in #10790
which was join trimming in null filter cases.
üst f8116497
......@@ -205,17 +205,16 @@ class GenericRelation(RelatedField, Field):
# same db_type as well.
return None
def extra_filters(self, pieces, pos, negate):
def get_content_type(self):
"""
Return an extra filter to the queryset so that the results are filtered
on the appropriate content type.
Returns the content type associated with this field's model.
"""
if negate:
return []
content_type = ContentType.objects.get_for_model(self.model)
prefix = "__".join(pieces[:pos + 1])
return [("%s__%s" % (prefix, self.content_type_field_name),
content_type)]
return ContentType.objects.get_for_model(self.model)
def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias):
extra_col = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0].column
contenttype = self.get_content_type().pk
return " AND %s.%s = %%s" % (qn(rhs_alias), qn(extra_col)), [contenttype]
def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS):
"""
......@@ -246,9 +245,6 @@ class ReverseGenericRelatedObjectsDescriptor(object):
if instance is None:
return self
# This import is done here to avoid circular import importing this module
from django.contrib.contenttypes.models import ContentType
# Dynamically create a class that subclasses the related model's
# default manager.
rel_model = self.field.rel.to
......@@ -379,8 +375,6 @@ class BaseGenericInlineFormSet(BaseModelFormSet):
def __init__(self, data=None, files=None, instance=None, save_as_new=None,
prefix=None, queryset=None):
# Avoid a circular import.
from django.contrib.contenttypes.models import ContentType
opts = self.model._meta
self.instance = instance
self.rel_name = '-'.join((
......@@ -409,8 +403,6 @@ class BaseGenericInlineFormSet(BaseModelFormSet):
))
def save_new(self, form, commit=True):
# Avoid a circular import.
from django.contrib.contenttypes.models import ContentType
kwargs = {
self.ct_field.get_attname(): ContentType.objects.get_for_model(self.instance).pk,
self.ct_fk_field.get_attname(): self.instance.pk,
......@@ -432,8 +424,6 @@ def generic_inlineformset_factory(model, form=ModelForm,
defaults ``content_type`` and ``object_id`` respectively.
"""
opts = model._meta
# Avoid a circular import.
from django.contrib.contenttypes.models import ContentType
# if there is no field called `ct_field` let the exception propagate
ct_field = opts.get_field(ct_field)
if not isinstance(ct_field, models.ForeignKey) or ct_field.rel.to != ContentType:
......
......@@ -274,7 +274,8 @@ class SQLCompiler(object):
except KeyError:
link_field = opts.get_ancestor_link(model)
alias = self.query.join((start_alias, model._meta.db_table,
link_field.column, model._meta.pk.column))
link_field.column, model._meta.pk.column),
join_field=link_field)
seen[model] = alias
else:
# If we're starting from the base model of the queryset, the
......@@ -448,8 +449,8 @@ class SQLCompiler(object):
"""
if not alias:
alias = self.query.get_initial_alias()
field, target, opts, joins, _, _ = self.query.setup_joins(pieces,
opts, alias, REUSE_ALL)
field, target, opts, joins, _ = self.query.setup_joins(
pieces, opts, alias, REUSE_ALL)
# We will later on need to promote those joins that were added to the
# query afresh above.
joins_to_promote = [j for j in joins if self.query.alias_refcount[j] < 2]
......@@ -501,20 +502,27 @@ class SQLCompiler(object):
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
first = True
from_params = []
for alias in self.query.tables:
if not self.query.alias_refcount[alias]:
continue
try:
name, alias, join_type, lhs, lhs_col, col, nullable = self.query.alias_map[alias]
name, alias, join_type, lhs, lhs_col, col, _, join_field = self.query.alias_map[alias]
except KeyError:
# Extra tables can end up in self.tables, but not in the
# alias_map if they aren't in a join. That's OK. We skip them.
continue
alias_str = (alias != name and ' %s' % alias or '')
if join_type and not first:
result.append('%s %s%s ON (%s.%s = %s.%s)'
% (join_type, qn(name), alias_str, qn(lhs),
qn2(lhs_col), qn(alias), qn2(col)))
if join_field and hasattr(join_field, 'get_extra_join_sql'):
extra_cond, extra_params = join_field.get_extra_join_sql(
self.connection, qn, lhs, alias)
from_params.extend(extra_params)
else:
extra_cond = ""
result.append('%s %s%s ON (%s.%s = %s.%s%s)' %
(join_type, qn(name), alias_str, qn(lhs),
qn2(lhs_col), qn(alias), qn2(col), extra_cond))
else:
connector = not first and ', ' or ''
result.append('%s%s%s' % (connector, qn(name), alias_str))
......@@ -528,7 +536,7 @@ class SQLCompiler(object):
connector = not first and ', ' or ''
result.append('%s%s' % (connector, qn(alias)))
first = False
return result, []
return result, from_params
def get_grouping(self, ordering_group_by):
"""
......@@ -638,7 +646,7 @@ class SQLCompiler(object):
alias = self.query.join((alias, table, f.column,
f.rel.get_related_field().column),
promote=promote)
promote=promote, join_field=f)
columns, aliases = self.get_default_columns(start_alias=alias,
opts=f.rel.to._meta, as_pairs=True)
self.query.related_select_cols.extend(
......@@ -685,7 +693,7 @@ class SQLCompiler(object):
alias_chain.append(alias)
alias = self.query.join(
(alias, table, f.rel.get_related_field().column, f.column),
promote=True
promote=True, join_field=f
)
from_parent = (opts.model if issubclass(model, opts.model)
else None)
......
......@@ -18,12 +18,19 @@ QUERY_TERMS = set([
# Larger values are slightly faster at the expense of more storage space.
GET_ITERATOR_CHUNK_SIZE = 100
# Constants to make looking up tuple values clearer.
# Namedtuples for sql.* internal use.
# Join lists (indexes into the tuples that are values in the alias_map
# dictionary in the Query class).
JoinInfo = namedtuple('JoinInfo',
'table_name rhs_alias join_type lhs_alias '
'lhs_join_col rhs_join_col nullable')
'lhs_join_col rhs_join_col nullable join_field')
# PathInfo is used when converting lookups (fk__somecol). The contents
# describe the join in Model terms (model Options and Fields for both
# sides of the join. The rel_field is the field we are joining along.
PathInfo = namedtuple('PathInfo',
'from_field to_field from_opts to_opts join_field')
# Pairs of column clauses to select, and (possibly None) field for the clause.
SelectInfo = namedtuple('SelectInfo', 'col field')
......
......@@ -50,10 +50,10 @@ class SQLEvaluator(object):
self.cols.append((node, query.aggregate_select[node.name]))
else:
try:
field, source, opts, join_list, last, _ = query.setup_joins(
field, source, opts, join_list, path = query.setup_joins(
field_list, query.get_meta(),
query.get_initial_alias(), self.reuse)
col, _, join_list = query.trim_joins(source, join_list, last, False)
col, _, join_list = query.trim_joins(source, join_list, path)
if self.reuse is not None and self.reuse != REUSE_ALL:
self.reuse.update(join_list)
self.cols.append((node, (join_list[-1], col)))
......
This diff is collapsed.
......@@ -978,3 +978,7 @@ class AggregationTests(TestCase):
('The Definitive Guide to Django: Web Development Done Right', 2)
]
)
def test_reverse_join_trimming(self):
qs = Author.objects.annotate(Count('book_contact_set__contact'))
self.assertIn(' JOIN ', str(qs.query))
......@@ -283,6 +283,7 @@ class SingleObject(models.Model):
class RelatedObject(models.Model):
single = models.ForeignKey(SingleObject, null=True)
f = models.IntegerField(null=True)
class Meta:
ordering = ['single']
......@@ -311,7 +312,7 @@ class Food(models.Model):
@python_2_unicode_compatible
class Eaten(models.Model):
food = models.ForeignKey(Food, to_field="name")
food = models.ForeignKey(Food, to_field="name", null=True)
meal = models.CharField(max_length=20)
def __str__(self):
......@@ -400,3 +401,23 @@ class ModelA(models.Model):
name = models.TextField()
b = models.ForeignKey(ModelB, null=True)
d = models.ForeignKey(ModelD)
@python_2_unicode_compatible
class Job(models.Model):
name = models.CharField(max_length=20, unique=True)
def __str__(self):
return self.name
class JobResponsibilities(models.Model):
job = models.ForeignKey(Job, to_field='name')
responsibility = models.ForeignKey('Responsibility', to_field='description')
@python_2_unicode_compatible
class Responsibility(models.Model):
description = models.CharField(max_length=20, unique=True)
jobs = models.ManyToManyField(Job, through=JobResponsibilities,
related_name='responsibilities')
def __str__(self):
return self.description
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment