Kaydet (Commit) 01d440fa authored tarafından Nicolas Delaby's avatar Nicolas Delaby Kaydeden (comit) Tim Graham

Fixed #27332 -- Added FilteredRelation API for conditional join (ON clause) support.

Thanks Anssi Kääriäinen for contributing to the patch.
üst 3f9d85d9
...@@ -348,7 +348,7 @@ class GenericRelation(ForeignObject): ...@@ -348,7 +348,7 @@ class GenericRelation(ForeignObject):
self.to_fields = [self.model._meta.pk.name] self.to_fields = [self.model._meta.pk.name]
return [(self.remote_field.model._meta.get_field(self.object_id_field_name), self.model._meta.pk)] return [(self.remote_field.model._meta.get_field(self.object_id_field_name), self.model._meta.pk)]
def _get_path_info_with_parent(self): def _get_path_info_with_parent(self, filtered_relation):
""" """
Return the path that joins the current model through any parent models. Return the path that joins the current model through any parent models.
The idea is that if you have a GFK defined on a parent model then we The idea is that if you have a GFK defined on a parent model then we
...@@ -365,7 +365,15 @@ class GenericRelation(ForeignObject): ...@@ -365,7 +365,15 @@ class GenericRelation(ForeignObject):
opts = self.remote_field.model._meta.concrete_model._meta opts = self.remote_field.model._meta.concrete_model._meta
parent_opts = opts.get_field(self.object_id_field_name).model._meta parent_opts = opts.get_field(self.object_id_field_name).model._meta
target = parent_opts.pk target = parent_opts.pk
path.append(PathInfo(self.model._meta, parent_opts, (target,), self.remote_field, True, False)) path.append(PathInfo(
from_opts=self.model._meta,
to_opts=parent_opts,
target_fields=(target,),
join_field=self.remote_field,
m2m=True,
direct=False,
filtered_relation=filtered_relation,
))
# Collect joins needed for the parent -> child chain. This is easiest # Collect joins needed for the parent -> child chain. This is easiest
# to do if we collect joins for the child -> parent chain and then # to do if we collect joins for the child -> parent chain and then
# reverse the direction (call to reverse() and use of # reverse the direction (call to reverse() and use of
...@@ -380,19 +388,35 @@ class GenericRelation(ForeignObject): ...@@ -380,19 +388,35 @@ class GenericRelation(ForeignObject):
path.extend(field.remote_field.get_path_info()) path.extend(field.remote_field.get_path_info())
return path return path
def get_path_info(self): def get_path_info(self, filtered_relation=None):
opts = self.remote_field.model._meta opts = self.remote_field.model._meta
object_id_field = opts.get_field(self.object_id_field_name) object_id_field = opts.get_field(self.object_id_field_name)
if object_id_field.model != opts.model: if object_id_field.model != opts.model:
return self._get_path_info_with_parent() return self._get_path_info_with_parent(filtered_relation)
else: else:
target = opts.pk target = opts.pk
return [PathInfo(self.model._meta, opts, (target,), self.remote_field, True, False)] return [PathInfo(
from_opts=self.model._meta,
def get_reverse_path_info(self): to_opts=opts,
target_fields=(target,),
join_field=self.remote_field,
m2m=True,
direct=False,
filtered_relation=filtered_relation,
)]
def get_reverse_path_info(self, filtered_relation=None):
opts = self.model._meta opts = self.model._meta
from_opts = self.remote_field.model._meta from_opts = self.remote_field.model._meta
return [PathInfo(from_opts, opts, (opts.pk,), self, not self.unique, False)] return [PathInfo(
from_opts=from_opts,
to_opts=opts,
target_fields=(opts.pk,),
join_field=self,
m2m=not self.unique,
direct=False,
filtered_relation=filtered_relation,
)]
def value_to_string(self, obj): def value_to_string(self, obj):
qs = getattr(obj, self.name).all() qs = getattr(obj, self.name).all()
......
...@@ -20,6 +20,7 @@ from django.db.models.manager import Manager ...@@ -20,6 +20,7 @@ from django.db.models.manager import Manager
from django.db.models.query import ( from django.db.models.query import (
Prefetch, Q, QuerySet, prefetch_related_objects, Prefetch, Q, QuerySet, prefetch_related_objects,
) )
from django.db.models.query_utils import FilteredRelation
# Imports that would create circular imports if sorted # Imports that would create circular imports if sorted
from django.db.models.base import DEFERRED, Model # isort:skip from django.db.models.base import DEFERRED, Model # isort:skip
...@@ -69,6 +70,7 @@ __all__ += [ ...@@ -69,6 +70,7 @@ __all__ += [
'Window', 'WindowFrame', 'Window', 'WindowFrame',
'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager', 'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager',
'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model',
'FilteredRelation',
'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField', 'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField',
'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel', 'permalink', 'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel', 'permalink',
] ]
...@@ -697,18 +697,33 @@ class ForeignObject(RelatedField): ...@@ -697,18 +697,33 @@ class ForeignObject(RelatedField):
""" """
return None return None
def get_path_info(self): def get_path_info(self, filtered_relation=None):
"""Get path from this field to the related model.""" """Get path from this field to the related model."""
opts = self.remote_field.model._meta opts = self.remote_field.model._meta
from_opts = self.model._meta from_opts = self.model._meta
return [PathInfo(from_opts, opts, self.foreign_related_fields, self, False, True)] return [PathInfo(
from_opts=from_opts,
def get_reverse_path_info(self): to_opts=opts,
target_fields=self.foreign_related_fields,
join_field=self,
m2m=False,
direct=True,
filtered_relation=filtered_relation,
)]
def get_reverse_path_info(self, filtered_relation=None):
"""Get path from the related model to this field's model.""" """Get path from the related model to this field's model."""
opts = self.model._meta opts = self.model._meta
from_opts = self.remote_field.model._meta from_opts = self.remote_field.model._meta
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)] return [PathInfo(
return pathinfos from_opts=from_opts,
to_opts=opts,
target_fields=(opts.pk,),
join_field=self.remote_field,
m2m=not self.unique,
direct=False,
filtered_relation=filtered_relation,
)]
@classmethod @classmethod
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
...@@ -861,12 +876,19 @@ class ForeignKey(ForeignObject): ...@@ -861,12 +876,19 @@ class ForeignKey(ForeignObject):
def target_field(self): def target_field(self):
return self.foreign_related_fields[0] return self.foreign_related_fields[0]
def get_reverse_path_info(self): def get_reverse_path_info(self, filtered_relation=None):
"""Get path from the related model to this field's model.""" """Get path from the related model to this field's model."""
opts = self.model._meta opts = self.model._meta
from_opts = self.remote_field.model._meta from_opts = self.remote_field.model._meta
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)] return [PathInfo(
return pathinfos from_opts=from_opts,
to_opts=opts,
target_fields=(opts.pk,),
join_field=self.remote_field,
m2m=not self.unique,
direct=False,
filtered_relation=filtered_relation,
)]
def validate(self, value, model_instance): def validate(self, value, model_instance):
if self.remote_field.parent_link: if self.remote_field.parent_link:
...@@ -1435,7 +1457,7 @@ class ManyToManyField(RelatedField): ...@@ -1435,7 +1457,7 @@ class ManyToManyField(RelatedField):
) )
return name, path, args, kwargs return name, path, args, kwargs
def _get_path_info(self, direct=False): def _get_path_info(self, direct=False, filtered_relation=None):
"""Called by both direct and indirect m2m traversal.""" """Called by both direct and indirect m2m traversal."""
pathinfos = [] pathinfos = []
int_model = self.remote_field.through int_model = self.remote_field.through
...@@ -1443,10 +1465,10 @@ class ManyToManyField(RelatedField): ...@@ -1443,10 +1465,10 @@ class ManyToManyField(RelatedField):
linkfield2 = int_model._meta.get_field(self.m2m_reverse_field_name()) linkfield2 = int_model._meta.get_field(self.m2m_reverse_field_name())
if direct: if direct:
join1infos = linkfield1.get_reverse_path_info() join1infos = linkfield1.get_reverse_path_info()
join2infos = linkfield2.get_path_info() join2infos = linkfield2.get_path_info(filtered_relation)
else: else:
join1infos = linkfield2.get_reverse_path_info() join1infos = linkfield2.get_reverse_path_info()
join2infos = linkfield1.get_path_info() join2infos = linkfield1.get_path_info(filtered_relation)
# Get join infos between the last model of join 1 and the first model # Get join infos between the last model of join 1 and the first model
# of join 2. Assume the only reason these may differ is due to model # of join 2. Assume the only reason these may differ is due to model
...@@ -1465,11 +1487,11 @@ class ManyToManyField(RelatedField): ...@@ -1465,11 +1487,11 @@ class ManyToManyField(RelatedField):
pathinfos.extend(join2infos) pathinfos.extend(join2infos)
return pathinfos return pathinfos
def get_path_info(self): def get_path_info(self, filtered_relation=None):
return self._get_path_info(direct=True) return self._get_path_info(direct=True, filtered_relation=filtered_relation)
def get_reverse_path_info(self): def get_reverse_path_info(self, filtered_relation=None):
return self._get_path_info(direct=False) return self._get_path_info(direct=False, filtered_relation=filtered_relation)
def _get_m2m_db_table(self, opts): def _get_m2m_db_table(self, opts):
""" """
......
...@@ -163,8 +163,8 @@ class ForeignObjectRel(FieldCacheMixin): ...@@ -163,8 +163,8 @@ class ForeignObjectRel(FieldCacheMixin):
return self.related_name return self.related_name
return opts.model_name + ('_set' if self.multiple else '') return opts.model_name + ('_set' if self.multiple else '')
def get_path_info(self): def get_path_info(self, filtered_relation=None):
return self.field.get_reverse_path_info() return self.field.get_reverse_path_info(filtered_relation)
def get_cache_name(self): def get_cache_name(self):
""" """
......
...@@ -632,7 +632,15 @@ class Options: ...@@ -632,7 +632,15 @@ class Options:
final_field = opts.parents[int_model] final_field = opts.parents[int_model]
targets = (final_field.remote_field.get_related_field(),) targets = (final_field.remote_field.get_related_field(),)
opts = int_model._meta opts = int_model._meta
path.append(PathInfo(final_field.model._meta, opts, targets, final_field, False, True)) path.append(PathInfo(
from_opts=final_field.model._meta,
to_opts=opts,
target_fields=targets,
join_field=final_field,
m2m=False,
direct=True,
filtered_relation=None,
))
return path return path
def get_path_from_parent(self, parent): def get_path_from_parent(self, parent):
......
...@@ -22,7 +22,7 @@ from django.db.models.deletion import Collector ...@@ -22,7 +22,7 @@ from django.db.models.deletion import Collector
from django.db.models.expressions import F from django.db.models.expressions import F
from django.db.models.fields import AutoField from django.db.models.fields import AutoField
from django.db.models.functions import Trunc from django.db.models.functions import Trunc
from django.db.models.query_utils import InvalidQuery, Q from django.db.models.query_utils import FilteredRelation, InvalidQuery, Q
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
from django.utils import timezone from django.utils import timezone
from django.utils.deprecation import RemovedInDjango30Warning from django.utils.deprecation import RemovedInDjango30Warning
...@@ -953,6 +953,12 @@ class QuerySet: ...@@ -953,6 +953,12 @@ class QuerySet:
if lookups == (None,): if lookups == (None,):
clone._prefetch_related_lookups = () clone._prefetch_related_lookups = ()
else: else:
for lookup in lookups:
if isinstance(lookup, Prefetch):
lookup = lookup.prefetch_to
lookup = lookup.split(LOOKUP_SEP, 1)[0]
if lookup in self.query._filtered_relations:
raise ValueError('prefetch_related() is not supported with FilteredRelation.')
clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
return clone return clone
...@@ -984,7 +990,10 @@ class QuerySet: ...@@ -984,7 +990,10 @@ class QuerySet:
if alias in names: if alias in names:
raise ValueError("The annotation '%s' conflicts with a field on " raise ValueError("The annotation '%s' conflicts with a field on "
"the model." % alias) "the model." % alias)
clone.query.add_annotation(annotation, alias, is_summary=False) if isinstance(annotation, FilteredRelation):
clone.query.add_filtered_relation(annotation, alias)
else:
clone.query.add_annotation(annotation, alias, is_summary=False)
for alias, annotation in clone.query.annotations.items(): for alias, annotation in clone.query.annotations.items():
if alias in annotations and annotation.contains_aggregate: if alias in annotations and annotation.contains_aggregate:
...@@ -1060,6 +1069,10 @@ class QuerySet: ...@@ -1060,6 +1069,10 @@ class QuerySet:
# Can only pass None to defer(), not only(), as the rest option. # Can only pass None to defer(), not only(), as the rest option.
# That won't stop people trying to do this, so let's be explicit. # That won't stop people trying to do this, so let's be explicit.
raise TypeError("Cannot pass None as an argument to only().") raise TypeError("Cannot pass None as an argument to only().")
for field in fields:
field = field.split(LOOKUP_SEP, 1)[0]
if field in self.query._filtered_relations:
raise ValueError('only() is not supported with FilteredRelation.')
clone = self._chain() clone = self._chain()
clone.query.add_immediate_loading(fields) clone.query.add_immediate_loading(fields)
return clone return clone
...@@ -1730,9 +1743,9 @@ class RelatedPopulator: ...@@ -1730,9 +1743,9 @@ class RelatedPopulator:
# model's fields. # model's fields.
# - related_populators: a list of RelatedPopulator instances if # - related_populators: a list of RelatedPopulator instances if
# select_related() descends to related models from this model. # select_related() descends to related models from this model.
# - field, remote_field: the fields to use for populating the # - local_setter, remote_setter: Methods to set cached values on
# internal fields cache. If remote_field is set then we also # the object being populated and on the remote object. Usually
# set the reverse link. # these are Field.set_cached_value() methods.
select_fields = klass_info['select_fields'] select_fields = klass_info['select_fields']
from_parent = klass_info['from_parent'] from_parent = klass_info['from_parent']
if not from_parent: if not from_parent:
...@@ -1751,16 +1764,8 @@ class RelatedPopulator: ...@@ -1751,16 +1764,8 @@ class RelatedPopulator:
self.model_cls = klass_info['model'] self.model_cls = klass_info['model']
self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname) self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname)
self.related_populators = get_related_populators(klass_info, select, self.db) self.related_populators = get_related_populators(klass_info, select, self.db)
reverse = klass_info['reverse'] self.local_setter = klass_info['local_setter']
field = klass_info['field'] self.remote_setter = klass_info['remote_setter']
self.remote_field = None
if reverse:
self.field = field.remote_field
self.remote_field = field
else:
self.field = field
if field.unique:
self.remote_field = field.remote_field
def populate(self, row, from_obj): def populate(self, row, from_obj):
if self.reorder_for_init: if self.reorder_for_init:
...@@ -1774,9 +1779,9 @@ class RelatedPopulator: ...@@ -1774,9 +1779,9 @@ class RelatedPopulator:
if self.related_populators: if self.related_populators:
for rel_iter in self.related_populators: for rel_iter in self.related_populators:
rel_iter.populate(row, obj) rel_iter.populate(row, obj)
if self.remote_field: self.local_setter(from_obj, obj)
self.remote_field.set_cached_value(obj, from_obj) if obj is not None:
self.field.set_cached_value(from_obj, obj) self.remote_setter(obj, from_obj)
def get_related_populators(klass_info, select, db): def get_related_populators(klass_info, select, db):
......
...@@ -16,7 +16,7 @@ from django.utils import tree ...@@ -16,7 +16,7 @@ from django.utils import tree
# PathInfo is used when converting lookups (fk__somecol). The contents # PathInfo is used when converting lookups (fk__somecol). The contents
# describe the relation in Model terms (model Options and Fields for both # describe the relation in Model terms (model Options and Fields for both
# sides of the relation. The join_field is the field backing the relation. # sides of the relation. The join_field is the field backing the relation.
PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct') PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct filtered_relation')
class InvalidQuery(Exception): class InvalidQuery(Exception):
...@@ -291,3 +291,44 @@ def check_rel_lookup_compatibility(model, target_opts, field): ...@@ -291,3 +291,44 @@ def check_rel_lookup_compatibility(model, target_opts, field):
check(target_opts) or check(target_opts) or
(getattr(field, 'primary_key', False) and check(field.model._meta)) (getattr(field, 'primary_key', False) and check(field.model._meta))
) )
class FilteredRelation:
"""Specify custom filtering in the ON clause of SQL joins."""
def __init__(self, relation_name, *, condition=Q()):
if not relation_name:
raise ValueError('relation_name cannot be empty.')
self.relation_name = relation_name
self.alias = None
if not isinstance(condition, Q):
raise ValueError('condition argument must be a Q() instance.')
self.condition = condition
self.path = []
def __eq__(self, other):
return (
isinstance(other, self.__class__) and
self.relation_name == other.relation_name and
self.alias == other.alias and
self.condition == other.condition
)
def clone(self):
clone = FilteredRelation(self.relation_name, condition=self.condition)
clone.alias = self.alias
clone.path = self.path[:]
return clone
def resolve_expression(self, *args, **kwargs):
"""
QuerySet.annotate() only accepts expression-like arguments
(with a resolve_expression() method).
"""
raise NotImplementedError('FilteredRelation.resolve_expression() is unused.')
def as_sql(self, compiler, connection):
# Resolve the condition in Join.filtered_relation.
query = compiler.query
where = query.build_filtered_relation_q(self.condition, reuse=set(self.path))
return compiler.compile(where)
...@@ -702,7 +702,7 @@ class SQLCompiler: ...@@ -702,7 +702,7 @@ class SQLCompiler:
""" """
result = [] result = []
params = [] params = []
for alias in self.query.alias_map: for alias in tuple(self.query.alias_map):
if not self.query.alias_refcount[alias]: if not self.query.alias_refcount[alias]:
continue continue
try: try:
...@@ -737,7 +737,7 @@ class SQLCompiler: ...@@ -737,7 +737,7 @@ class SQLCompiler:
f.field.related_query_name() f.field.related_query_name()
for f in opts.related_objects if f.field.unique for f in opts.related_objects if f.field.unique
) )
return chain(direct_choices, reverse_choices) return chain(direct_choices, reverse_choices, self.query._filtered_relations)
related_klass_infos = [] related_klass_infos = []
if not restricted and cur_depth > self.query.max_depth: if not restricted and cur_depth > self.query.max_depth:
...@@ -788,7 +788,8 @@ class SQLCompiler: ...@@ -788,7 +788,8 @@ class SQLCompiler:
klass_info = { klass_info = {
'model': f.remote_field.model, 'model': f.remote_field.model,
'field': f, 'field': f,
'reverse': False, 'local_setter': f.set_cached_value,
'remote_setter': f.remote_field.set_cached_value if f.unique else lambda x, y: None,
'from_parent': False, 'from_parent': False,
} }
related_klass_infos.append(klass_info) related_klass_infos.append(klass_info)
...@@ -825,7 +826,8 @@ class SQLCompiler: ...@@ -825,7 +826,8 @@ class SQLCompiler:
klass_info = { klass_info = {
'model': model, 'model': model,
'field': f, 'field': f,
'reverse': True, 'local_setter': f.remote_field.set_cached_value,
'remote_setter': f.set_cached_value,
'from_parent': from_parent, 'from_parent': from_parent,
} }
related_klass_infos.append(klass_info) related_klass_infos.append(klass_info)
...@@ -842,6 +844,47 @@ class SQLCompiler: ...@@ -842,6 +844,47 @@ class SQLCompiler:
next, restricted) next, restricted)
get_related_klass_infos(klass_info, next_klass_infos) get_related_klass_infos(klass_info, next_klass_infos)
fields_not_found = set(requested).difference(fields_found) fields_not_found = set(requested).difference(fields_found)
for name in list(requested):
# Filtered relations work only on the topmost level.
if cur_depth > 1:
break
if name in self.query._filtered_relations:
fields_found.add(name)
f, _, join_opts, joins, _ = self.query.setup_joins([name], opts, root_alias)
model = join_opts.model
alias = joins[-1]
from_parent = issubclass(model, opts.model) and model is not opts.model
def local_setter(obj, from_obj):
f.remote_field.set_cached_value(from_obj, obj)
def remote_setter(obj, from_obj):
setattr(from_obj, name, obj)
klass_info = {
'model': model,
'field': f,
'local_setter': local_setter,
'remote_setter': remote_setter,
'from_parent': from_parent,
}
related_klass_infos.append(klass_info)
select_fields = []
columns = self.get_default_columns(
start_alias=alias, opts=model._meta,
from_parent=opts.model,
)
for col in columns:
select_fields.append(len(select))
select.append((col, None))
klass_info['select_fields'] = select_fields
next_requested = requested.get(name, {})
next_klass_infos = self.get_related_selections(
select, opts=model._meta, root_alias=alias,
cur_depth=cur_depth + 1, requested=next_requested,
restricted=restricted,
)
get_related_klass_infos(klass_info, next_klass_infos)
fields_not_found = set(requested).difference(fields_found)
if fields_not_found: if fields_not_found:
invalid_fields = ("'%s'" % s for s in fields_not_found) invalid_fields = ("'%s'" % s for s in fields_not_found)
raise FieldError( raise FieldError(
......
...@@ -41,7 +41,7 @@ class Join: ...@@ -41,7 +41,7 @@ class Join:
- relabeled_clone() - relabeled_clone()
""" """
def __init__(self, table_name, parent_alias, table_alias, join_type, def __init__(self, table_name, parent_alias, table_alias, join_type,
join_field, nullable): join_field, nullable, filtered_relation=None):
# Join table # Join table
self.table_name = table_name self.table_name = table_name
self.parent_alias = parent_alias self.parent_alias = parent_alias
...@@ -56,6 +56,7 @@ class Join: ...@@ -56,6 +56,7 @@ class Join:
self.join_field = join_field self.join_field = join_field
# Is this join nullabled? # Is this join nullabled?
self.nullable = nullable self.nullable = nullable
self.filtered_relation = filtered_relation
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
""" """
...@@ -85,7 +86,11 @@ class Join: ...@@ -85,7 +86,11 @@ class Join:
extra_sql, extra_params = compiler.compile(extra_cond) extra_sql, extra_params = compiler.compile(extra_cond)
join_conditions.append('(%s)' % extra_sql) join_conditions.append('(%s)' % extra_sql)
params.extend(extra_params) params.extend(extra_params)
if self.filtered_relation:
extra_sql, extra_params = compiler.compile(self.filtered_relation)
if extra_sql:
join_conditions.append('(%s)' % extra_sql)
params.extend(extra_params)
if not join_conditions: if not join_conditions:
# This might be a rel on the other end of an actual declared field. # This might be a rel on the other end of an actual declared field.
declared_field = getattr(self.join_field, 'field', self.join_field) declared_field = getattr(self.join_field, 'field', self.join_field)
...@@ -101,18 +106,27 @@ class Join: ...@@ -101,18 +106,27 @@ class Join:
def relabeled_clone(self, change_map): def relabeled_clone(self, change_map):
new_parent_alias = change_map.get(self.parent_alias, self.parent_alias) new_parent_alias = change_map.get(self.parent_alias, self.parent_alias)
new_table_alias = change_map.get(self.table_alias, self.table_alias) new_table_alias = change_map.get(self.table_alias, self.table_alias)
if self.filtered_relation is not None:
filtered_relation = self.filtered_relation.clone()
filtered_relation.path = [change_map.get(p, p) for p in self.filtered_relation.path]
else:
filtered_relation = None
return self.__class__( return self.__class__(
self.table_name, new_parent_alias, new_table_alias, self.join_type, self.table_name, new_parent_alias, new_table_alias, self.join_type,
self.join_field, self.nullable) self.join_field, self.nullable, filtered_relation=filtered_relation,
)
def equals(self, other, with_filtered_relation):
return (
isinstance(other, self.__class__) and
self.table_name == other.table_name and
self.parent_alias == other.parent_alias and
self.join_field == other.join_field and
(not with_filtered_relation or self.filtered_relation == other.filtered_relation)
)
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, self.__class__): return self.equals(other, with_filtered_relation=True)
return (
self.table_name == other.table_name and
self.parent_alias == other.parent_alias and
self.join_field == other.join_field
)
return False
def demote(self): def demote(self):
new = self.relabeled_clone({}) new = self.relabeled_clone({})
...@@ -134,6 +148,7 @@ class BaseTable: ...@@ -134,6 +148,7 @@ class BaseTable:
""" """
join_type = None join_type = None
parent_alias = None parent_alias = None
filtered_relation = None
def __init__(self, table_name, alias): def __init__(self, table_name, alias):
self.table_name = table_name self.table_name = table_name
...@@ -146,3 +161,10 @@ class BaseTable: ...@@ -146,3 +161,10 @@ class BaseTable:
def relabeled_clone(self, change_map): def relabeled_clone(self, change_map):
return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias)) return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias))
def equals(self, other, with_filtered_relation):
return (
isinstance(self, other.__class__) and
self.table_name == other.table_name and
self.table_alias == other.table_alias
)
This diff is collapsed.
...@@ -3318,3 +3318,60 @@ lookups or :class:`Prefetch` objects you want to prefetch for. For example:: ...@@ -3318,3 +3318,60 @@ lookups or :class:`Prefetch` objects you want to prefetch for. For example::
>>> from django.db.models import prefetch_related_objects >>> from django.db.models import prefetch_related_objects
>>> restaurants = fetch_top_restaurants_from_cache() # A list of Restaurants >>> restaurants = fetch_top_restaurants_from_cache() # A list of Restaurants
>>> prefetch_related_objects(restaurants, 'pizzas__toppings') >>> prefetch_related_objects(restaurants, 'pizzas__toppings')
``FilteredRelation()`` objects
------------------------------
.. versionadded:: 2.0
.. class:: FilteredRelation(relation_name, *, condition=Q())
.. attribute:: FilteredRelation.relation_name
The name of the field on which you'd like to filter the relation.
.. attribute:: FilteredRelation.condition
A :class:`~django.db.models.Q` object to control the filtering.
``FilteredRelation`` is used with :meth:`~.QuerySet.annotate()` to create an
``ON`` clause when a ``JOIN`` is performed. It doesn't act on the default
relationship but on the annotation name (``pizzas_vegetarian`` in example
below).
For example, to find restaurants that have vegetarian pizzas with
``'mozzarella'`` in the name::
>>> from django.db.models import FilteredRelation, Q
>>> Restaurant.objects.annotate(
... pizzas_vegetarian=FilteredRelation(
... 'pizzas', condition=Q(pizzas__vegetarian=True),
... ),
... ).filter(pizzas_vegetarian__name__icontains='mozzarella')
If there are a large number of pizzas, this queryset performs better than::
>>> Restaurant.objects.filter(
... pizzas__vegetarian=True,
... pizzas__name__icontains='mozzarella',
... )
because the filtering in the ``WHERE`` clause of the first queryset will only
operate on vegetarian pizzas.
``FilteredRelation`` doesn't support:
* Conditions that span relational fields. For example::
>>> Restaurant.objects.annotate(
... pizzas_with_toppings_startswith_n=FilteredRelation(
... 'pizzas__toppings',
... condition=Q(pizzas__toppings__name__startswith='n'),
... ),
... )
Traceback (most recent call last):
...
ValueError: FilteredRelation's condition doesn't support nested relations (got 'pizzas__toppings__name__startswith').
* :meth:`.QuerySet.only` and :meth:`~.QuerySet.prefetch_related`.
* A :class:`~django.contrib.contenttypes.fields.GenericForeignKey`
inherited from a parent model.
...@@ -354,6 +354,9 @@ Models ...@@ -354,6 +354,9 @@ Models
* The new ``named`` parameter of :meth:`.QuerySet.values_list` allows fetching * The new ``named`` parameter of :meth:`.QuerySet.values_list` allows fetching
results as named tuples. results as named tuples.
* The new :class:`.FilteredRelation` class allows adding an ``ON`` clause to
querysets.
Pagination Pagination
~~~~~~~~~~ ~~~~~~~~~~
......
from django.contrib.contenttypes.fields import (
GenericForeignKey, GenericRelation,
)
from django.contrib.contenttypes.models import ContentType
from django.db import models
class Author(models.Model):
name = models.CharField(max_length=50, unique=True)
favorite_books = models.ManyToManyField(
'Book',
related_name='preferred_by_authors',
related_query_name='preferred_by_authors',
)
content_type = models.ForeignKey(ContentType, models.CASCADE, null=True)
object_id = models.PositiveIntegerField(null=True)
content_object = GenericForeignKey()
def __str__(self):
return self.name
class Editor(models.Model):
name = models.CharField(max_length=255)
def __str__(self):
return self.name
class Book(models.Model):
AVAILABLE = 'available'
RESERVED = 'reserved'
RENTED = 'rented'
STATES = (
(AVAILABLE, 'Available'),
(RESERVED, 'reserved'),
(RENTED, 'Rented'),
)
title = models.CharField(max_length=255)
author = models.ForeignKey(
Author,
models.CASCADE,
related_name='books',
related_query_name='book',
)
editor = models.ForeignKey(Editor, models.CASCADE)
generic_author = GenericRelation(Author)
state = models.CharField(max_length=9, choices=STATES, default=AVAILABLE)
def __str__(self):
return self.title
class Borrower(models.Model):
name = models.CharField(max_length=50, unique=True)
def __str__(self):
return self.name
class Reservation(models.Model):
NEW = 'new'
STOPPED = 'stopped'
STATES = (
(NEW, 'New'),
(STOPPED, 'Stopped'),
)
borrower = models.ForeignKey(
Borrower,
models.CASCADE,
related_name='reservations',
related_query_name='reservation',
)
book = models.ForeignKey(
Book,
models.CASCADE,
related_name='reservations',
related_query_name='reservation',
)
state = models.CharField(max_length=7, choices=STATES, default=NEW)
def __str__(self):
return '-'.join((self.book.name, self.borrower.name, self.state))
class RentalSession(models.Model):
NEW = 'new'
STOPPED = 'stopped'
STATES = (
(NEW, 'New'),
(STOPPED, 'Stopped'),
)
borrower = models.ForeignKey(
Borrower,
models.CASCADE,
related_name='rental_sessions',
related_query_name='rental_session',
)
book = models.ForeignKey(
Book,
models.CASCADE,
related_name='rental_sessions',
related_query_name='rental_session',
)
state = models.CharField(max_length=7, choices=STATES, default=NEW)
def __str__(self):
return '-'.join((self.book.name, self.borrower.name, self.state))
This diff is collapsed.
...@@ -53,15 +53,31 @@ class StartsWithRelation(models.ForeignObject): ...@@ -53,15 +53,31 @@ class StartsWithRelation(models.ForeignObject):
def get_joining_columns(self, reverse_join=False): def get_joining_columns(self, reverse_join=False):
return () return ()
def get_path_info(self): def get_path_info(self, filtered_relation=None):
to_opts = self.remote_field.model._meta to_opts = self.remote_field.model._meta
from_opts = self.model._meta from_opts = self.model._meta
return [PathInfo(from_opts, to_opts, (to_opts.pk,), self, False, False)] return [PathInfo(
from_opts=from_opts,
def get_reverse_path_info(self): to_opts=to_opts,
target_fields=(to_opts.pk,),
join_field=self,
m2m=False,
direct=False,
filtered_relation=filtered_relation,
)]
def get_reverse_path_info(self, filtered_relation=None):
to_opts = self.model._meta to_opts = self.model._meta
from_opts = self.remote_field.model._meta from_opts = self.remote_field.model._meta
return [PathInfo(from_opts, to_opts, (to_opts.pk,), self.remote_field, False, False)] return [PathInfo(
from_opts=from_opts,
to_opts=to_opts,
target_fields=(to_opts.pk,),
join_field=self.remote_field,
m2m=False,
direct=False,
filtered_relation=filtered_relation,
)]
def contribute_to_class(self, cls, name, private_only=False): def contribute_to_class(self, cls, name, private_only=False):
super().contribute_to_class(cls, name, private_only) super().contribute_to_class(cls, name, private_only)
......
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models import FilteredRelation
from django.test import SimpleTestCase, TestCase from django.test import SimpleTestCase, TestCase
from .models import ( from .models import (
...@@ -230,3 +231,8 @@ class ReverseSelectRelatedValidationTests(SimpleTestCase): ...@@ -230,3 +231,8 @@ class ReverseSelectRelatedValidationTests(SimpleTestCase):
with self.assertRaisesMessage(FieldError, self.non_relational_error % ('username', fields)): with self.assertRaisesMessage(FieldError, self.non_relational_error % ('username', fields)):
list(User.objects.select_related('username')) list(User.objects.select_related('username'))
def test_reverse_related_validation_with_filtered_relation(self):
fields = 'userprofile, userstat, relation'
with self.assertRaisesMessage(FieldError, self.invalid_error % ('foobar', fields)):
list(User.objects.annotate(relation=FilteredRelation('userprofile')).select_related('foobar'))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment