Kaydet (Commit) 7bec480f authored tarafından Alex Hill's avatar Alex Hill Kaydeden (comit) Tim Graham

Fixed #24201 -- Added order_with_respect_to support to GenericForeignKey.

üst e1427cc6
...@@ -7,9 +7,10 @@ from django.core import checks ...@@ -7,9 +7,10 @@ from django.core import checks
from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
from django.db import DEFAULT_DB_ALIAS, connection, models, router, transaction from django.db import DEFAULT_DB_ALIAS, connection, models, router, transaction
from django.db.models import DO_NOTHING, signals from django.db.models import DO_NOTHING, signals
from django.db.models.base import ModelBase from django.db.models.base import ModelBase, make_foreign_order_accessors
from django.db.models.fields.related import ( from django.db.models.fields.related import (
ForeignObject, ForeignObjectRel, ForeignRelatedObjectsDescriptor, ForeignObject, ForeignObjectRel, ForeignRelatedObjectsDescriptor,
lazy_related_operation,
) )
from django.db.models.query_utils import PathInfo from django.db.models.query_utils import PathInfo
from django.utils.encoding import python_2_unicode_compatible, smart_text from django.utils.encoding import python_2_unicode_compatible, smart_text
...@@ -61,6 +62,20 @@ class GenericForeignKey(object): ...@@ -61,6 +62,20 @@ class GenericForeignKey(object):
setattr(cls, name, self) setattr(cls, name, self)
def get_filter_kwargs_for_object(self, obj):
"""See corresponding method on Field"""
return {
self.fk_field: getattr(obj, self.fk_field),
self.ct_field: getattr(obj, self.ct_field),
}
def get_forward_related_filter(self, obj):
"""See corresponding method on RelatedField"""
return {
self.fk_field: obj.pk,
self.ct_field: ContentType.objects.get_for_model(obj).pk,
}
def __str__(self): def __str__(self):
model = self.model model = self.model
app = model._meta.app_label app = model._meta.app_label
...@@ -368,6 +383,21 @@ class GenericRelation(ForeignObject): ...@@ -368,6 +383,21 @@ class GenericRelation(ForeignObject):
self.model = cls self.model = cls
setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self.remote_field)) setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self.remote_field))
# Add get_RELATED_order() and set_RELATED_order() methods if the model
# on the other end of this relation is ordered with respect to this.
def matching_gfk(field):
return (
isinstance(field, GenericForeignKey) and
self.content_type_field_name == field.ct_field and
self.object_id_field_name == field.fk_field
)
def make_generic_foreign_order_accessors(related_model, model):
if matching_gfk(model._meta.order_with_respect_to):
make_foreign_order_accessors(model, related_model)
lazy_related_operation(make_generic_foreign_order_accessors, self.model, self.remote_field.model)
def set_attributes_from_rel(self): def set_attributes_from_rel(self):
pass pass
......
...@@ -311,21 +311,15 @@ class ModelBase(type): ...@@ -311,21 +311,15 @@ class ModelBase(type):
cls.get_next_in_order = curry(cls._get_next_or_previous_in_order, is_next=True) cls.get_next_in_order = curry(cls._get_next_or_previous_in_order, is_next=True)
cls.get_previous_in_order = curry(cls._get_next_or_previous_in_order, is_next=False) cls.get_previous_in_order = curry(cls._get_next_or_previous_in_order, is_next=False)
# defer creating accessors on the foreign class until we are # Defer creating accessors on the foreign class until it has been
# certain it has been created # created and registered. If remote_field is None, we're ordering
def make_foreign_order_accessors(cls, model, field): # with respect to a GenericForeignKey and don't know what the
setattr( # foreign class is - we'll add those accessors later in
field.remote_field.model, # contribute_to_class().
'get_%s_order' % cls.__name__.lower(), if opts.order_with_respect_to.remote_field:
curry(method_get_order, cls) wrt = opts.order_with_respect_to
) remote = wrt.remote_field.model
setattr( lazy_related_operation(make_foreign_order_accessors, cls, remote)
field.remote_field.model,
'set_%s_order' % cls.__name__.lower(),
curry(method_set_order, cls)
)
wrt = opts.order_with_respect_to
lazy_related_operation(make_foreign_order_accessors, cls, wrt.remote_field.model, field=wrt)
# Give the class a docstring -- its definition. # Give the class a docstring -- its definition.
if cls.__doc__ is None: if cls.__doc__ is None:
...@@ -803,8 +797,8 @@ class Model(six.with_metaclass(ModelBase)): ...@@ -803,8 +797,8 @@ class Model(six.with_metaclass(ModelBase)):
# If this is a model with an order_with_respect_to # If this is a model with an order_with_respect_to
# autopopulate the _order field # autopopulate the _order field
field = meta.order_with_respect_to field = meta.order_with_respect_to
order_value = cls._base_manager.using(using).filter( filter_args = field.get_filter_kwargs_for_object(self)
**{field.name: getattr(self, field.attname)}).count() order_value = cls._base_manager.using(using).filter(**filter_args).count()
self._order = order_value self._order = order_value
fields = meta.local_concrete_fields fields = meta.local_concrete_fields
...@@ -892,9 +886,8 @@ class Model(six.with_metaclass(ModelBase)): ...@@ -892,9 +886,8 @@ class Model(six.with_metaclass(ModelBase)):
op = 'gt' if is_next else 'lt' op = 'gt' if is_next else 'lt'
order = '_order' if is_next else '-_order' order = '_order' if is_next else '-_order'
order_field = self._meta.order_with_respect_to order_field = self._meta.order_with_respect_to
obj = self._default_manager.filter(**{ filter_args = order_field.get_filter_kwargs_for_object(self)
order_field.name: getattr(self, order_field.attname) obj = self._default_manager.filter(**filter_args).filter(**{
}).filter(**{
'_order__%s' % op: self._default_manager.values('_order').filter(**{ '_order__%s' % op: self._default_manager.values('_order').filter(**{
self._meta.pk.name: self.pk self._meta.pk.name: self.pk
}) })
...@@ -1653,22 +1646,33 @@ class Model(six.with_metaclass(ModelBase)): ...@@ -1653,22 +1646,33 @@ class Model(six.with_metaclass(ModelBase)):
def method_set_order(ordered_obj, self, id_list, using=None): def method_set_order(ordered_obj, self, id_list, using=None):
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
rel_val = getattr(self, ordered_obj._meta.order_with_respect_to.remote_field.field_name) order_wrt = ordered_obj._meta.order_with_respect_to
order_name = ordered_obj._meta.order_with_respect_to.name filter_args = order_wrt.get_forward_related_filter(self)
# FIXME: It would be nice if there was an "update many" version of update # FIXME: It would be nice if there was an "update many" version of update
# for situations like this. # for situations like this.
with transaction.atomic(using=using, savepoint=False): with transaction.atomic(using=using, savepoint=False):
for i, j in enumerate(id_list): for i, j in enumerate(id_list):
ordered_obj.objects.filter(**{'pk': j, order_name: rel_val}).update(_order=i) ordered_obj.objects.filter(pk=j, **filter_args).update(_order=i)
def method_get_order(ordered_obj, self): def method_get_order(ordered_obj, self):
rel_val = getattr(self, ordered_obj._meta.order_with_respect_to.remote_field.field_name) order_wrt = ordered_obj._meta.order_with_respect_to
order_name = ordered_obj._meta.order_with_respect_to.name filter_args = order_wrt.get_forward_related_filter(self)
pk_name = ordered_obj._meta.pk.name pk_name = ordered_obj._meta.pk.name
return [r[pk_name] for r in return ordered_obj.objects.filter(**filter_args).values_list(pk_name, flat=True)
ordered_obj.objects.filter(**{order_name: rel_val}).values(pk_name)]
def make_foreign_order_accessors(model, related_model):
setattr(
related_model,
'get_%s_order' % model.__name__.lower(),
curry(method_get_order, model)
)
setattr(
related_model,
'set_%s_order' % model.__name__.lower(),
curry(method_set_order, model)
)
######## ########
# MISC # # MISC #
......
...@@ -678,6 +678,13 @@ class Field(RegisterLookupMixin): ...@@ -678,6 +678,13 @@ class Field(RegisterLookupMixin):
setattr(cls, 'get_%s_display' % self.name, setattr(cls, 'get_%s_display' % self.name,
curry(cls._get_FIELD_display, field=self)) curry(cls._get_FIELD_display, field=self))
def get_filter_kwargs_for_object(self, obj):
"""
Return a dict that when passed as kwargs to self.model.filter(), would
yield all instances having the same value for this field as obj has.
"""
return {self.name: getattr(obj, self.attname)}
def get_attname(self): def get_attname(self):
return self.name return self.name
......
...@@ -303,6 +303,33 @@ class RelatedField(Field): ...@@ -303,6 +303,33 @@ class RelatedField(Field):
field.do_related_class(related, model) field.do_related_class(related, model)
lazy_related_operation(resolve_related_class, cls, self.remote_field.model, field=self) lazy_related_operation(resolve_related_class, cls, self.remote_field.model, field=self)
def get_forward_related_filter(self, obj):
"""
Return the keyword arguments that when supplied to
self.model.object.filter(), would select all instances related through
this field to the remote obj. This is used to build the querysets
returned by related descriptors. obj is an instance of
self.related_field.model.
"""
return {
'%s__%s' % (self.name, rh_field.name): getattr(obj, rh_field.attname)
for _, rh_field in self.related_fields
}
def get_reverse_related_filter(self, obj):
"""
Complement to get_forward_related_filter(). Return the keyword
arguments that when passed to self.related_field.model.object.filter()
select all instances of self.related_field.model related through
this field to obj. obj is an instance of self.model.
"""
base_filter = {
rh_field.attname: getattr(obj, lh_field.attname)
for lh_field, rh_field in self.related_fields
}
base_filter.update(self.get_extra_descriptor_filter(obj) or {})
return base_filter
@property @property
def swappable_setting(self): def swappable_setting(self):
""" """
...@@ -453,11 +480,9 @@ class SingleRelatedObjectDescriptor(object): ...@@ -453,11 +480,9 @@ class SingleRelatedObjectDescriptor(object):
if related_pk is None: if related_pk is None:
rel_obj = None rel_obj = None
else: else:
params = {} filter_args = self.related.field.get_forward_related_filter(instance)
for lh_field, rh_field in self.related.field.related_fields:
params['%s__%s' % (self.related.field.name, rh_field.name)] = getattr(instance, rh_field.attname)
try: try:
rel_obj = self.get_queryset(instance=instance).get(**params) rel_obj = self.get_queryset(instance=instance).get(**filter_args)
except self.related.related_model.DoesNotExist: except self.related.related_model.DoesNotExist:
rel_obj = None rel_obj = None
else: else:
...@@ -603,16 +628,8 @@ class ReverseSingleRelatedObjectDescriptor(object): ...@@ -603,16 +628,8 @@ class ReverseSingleRelatedObjectDescriptor(object):
if None in val: if None in val:
rel_obj = None rel_obj = None
else: else:
params = {
rh_field.attname: getattr(instance, lh_field.attname)
for lh_field, rh_field in self.field.related_fields}
qs = self.get_queryset(instance=instance) qs = self.get_queryset(instance=instance)
extra_filter = self.field.get_extra_descriptor_filter(instance) qs = qs.filter(**self.field.get_reverse_related_filter(instance))
if isinstance(extra_filter, dict):
params.update(extra_filter)
qs = qs.filter(**params)
else:
qs = qs.filter(extra_filter, **params)
# Assuming the database enforces foreign keys, this won't fail. # Assuming the database enforces foreign keys, this won't fail.
rel_obj = qs.get() rel_obj = qs.get()
if not self.field.remote_field.multiple: if not self.field.remote_field.multiple:
......
...@@ -187,6 +187,13 @@ Minor features ...@@ -187,6 +187,13 @@ Minor features
makes it possible to use ``REMOTE_USER`` for setups where the header is only makes it possible to use ``REMOTE_USER`` for setups where the header is only
populated on login pages instead of every request in the session. populated on login pages instead of every request in the session.
:mod:`django.contrib.contenttypes`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
* It's now possible to use
:attr:`~django.db.models.Options.order_with_respect_to` with a
``GenericForeignKey``.
:mod:`django.contrib.gis` :mod:`django.contrib.gis`
^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.contrib.contenttypes.fields import (
GenericForeignKey, GenericRelation,
)
from django.contrib.contenttypes.models import ContentType
from django.db import models from django.db import models
from django.utils.encoding import python_2_unicode_compatible from django.utils.encoding import python_2_unicode_compatible
from django.utils.http import urlquote from django.utils.http import urlquote
...@@ -76,3 +80,38 @@ class FooWithBrokenAbsoluteUrl(FooWithoutUrl): ...@@ -76,3 +80,38 @@ class FooWithBrokenAbsoluteUrl(FooWithoutUrl):
def get_absolute_url(self): def get_absolute_url(self):
return "/users/%s/" % self.unknown_field return "/users/%s/" % self.unknown_field
class Question(models.Model):
text = models.CharField(max_length=200)
answer_set = GenericRelation('Answer')
@python_2_unicode_compatible
class Answer(models.Model):
text = models.CharField(max_length=200)
content_type = models.ForeignKey(ContentType, models.CASCADE)
object_id = models.PositiveIntegerField()
question = GenericForeignKey()
class Meta:
order_with_respect_to = 'question'
def __str__(self):
return self.text
@python_2_unicode_compatible
class Post(models.Model):
"""An ordered tag on an item."""
title = models.CharField(max_length=200)
content_type = models.ForeignKey(ContentType, models.CASCADE, null=True)
object_id = models.PositiveIntegerField(null=True)
parent = GenericForeignKey()
children = GenericRelation('Post')
class Meta:
order_with_respect_to = 'parent'
def __str__(self):
return self.title
from order_with_respect_to.tests import (
OrderWithRespectToTests, OrderWithRespectToTests2,
)
from .models import Answer, Post, Question
class OrderWithRespectToGFKTests(OrderWithRespectToTests):
Answer = Answer
Question = Question
del OrderWithRespectToTests
class OrderWithRespectToGFKTests2(OrderWithRespectToTests2):
Post = Post
del OrderWithRespectToTests2
""" """
Tests for the order_with_respect_to Meta attribute. Tests for the order_with_respect_to Meta attribute.
We explicitly declare app_label on these models, because they are reused by
contenttypes_tests. When those tests are run in isolation, these models need
app_label because order_with_respect_to isn't in INSTALLED_APPS.
""" """
from django.db import models from django.db import models
...@@ -10,6 +14,9 @@ from django.utils.encoding import python_2_unicode_compatible ...@@ -10,6 +14,9 @@ from django.utils.encoding import python_2_unicode_compatible
class Question(models.Model): class Question(models.Model):
text = models.CharField(max_length=200) text = models.CharField(max_length=200)
class Meta:
app_label = 'order_with_respect_to'
@python_2_unicode_compatible @python_2_unicode_compatible
class Answer(models.Model): class Answer(models.Model):
...@@ -18,6 +25,7 @@ class Answer(models.Model): ...@@ -18,6 +25,7 @@ class Answer(models.Model):
class Meta: class Meta:
order_with_respect_to = 'question' order_with_respect_to = 'question'
app_label = 'order_with_respect_to'
def __str__(self): def __str__(self):
return six.text_type(self.text) return six.text_type(self.text)
...@@ -30,6 +38,7 @@ class Post(models.Model): ...@@ -30,6 +38,7 @@ class Post(models.Model):
class Meta: class Meta:
order_with_respect_to = "parent" order_with_respect_to = "parent"
app_label = 'order_with_respect_to'
def __str__(self): def __str__(self):
return self.title return self.title
...@@ -10,13 +10,17 @@ from .models import Answer, Post, Question ...@@ -10,13 +10,17 @@ from .models import Answer, Post, Question
class OrderWithRespectToTests(TestCase): class OrderWithRespectToTests(TestCase):
# Hook to allow subclasses to run these tests with alternate models.
Answer = Answer
Question = Question
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
cls.q1 = Question.objects.create(text="Which Beatle starts with the letter 'R'?") cls.q1 = cls.Question.objects.create(text="Which Beatle starts with the letter 'R'?")
Answer.objects.create(text="John", question=cls.q1) cls.Answer.objects.create(text="John", question=cls.q1)
Answer.objects.create(text="Paul", question=cls.q1) cls.Answer.objects.create(text="Paul", question=cls.q1)
Answer.objects.create(text="George", question=cls.q1) cls.Answer.objects.create(text="George", question=cls.q1)
Answer.objects.create(text="Ringo", question=cls.q1) cls.Answer.objects.create(text="Ringo", question=cls.q1)
def test_default_to_insertion_order(self): def test_default_to_insertion_order(self):
# Answers will always be ordered in the order they were inserted. # Answers will always be ordered in the order they were inserted.
...@@ -30,30 +34,30 @@ class OrderWithRespectToTests(TestCase): ...@@ -30,30 +34,30 @@ class OrderWithRespectToTests(TestCase):
def test_previous_and_next_in_order(self): def test_previous_and_next_in_order(self):
# We can retrieve the answers related to a particular object, in the # We can retrieve the answers related to a particular object, in the
# order they were created, once we have a particular object. # order they were created, once we have a particular object.
a1 = Answer.objects.filter(question=self.q1)[0] a1 = self.q1.answer_set.all()[0]
self.assertEqual(a1.text, "John") self.assertEqual(a1.text, "John")
self.assertEqual(a1.get_next_in_order().text, "Paul") self.assertEqual(a1.get_next_in_order().text, "Paul")
a2 = list(Answer.objects.filter(question=self.q1))[-1] a2 = list(self.q1.answer_set.all())[-1]
self.assertEqual(a2.text, "Ringo") self.assertEqual(a2.text, "Ringo")
self.assertEqual(a2.get_previous_in_order().text, "George") self.assertEqual(a2.get_previous_in_order().text, "George")
def test_item_ordering(self): def test_item_ordering(self):
# We can retrieve the ordering of the queryset from a particular item. # We can retrieve the ordering of the queryset from a particular item.
a1 = Answer.objects.filter(question=self.q1)[1] a1 = self.q1.answer_set.all()[1]
id_list = [o.pk for o in self.q1.answer_set.all()] id_list = [o.pk for o in self.q1.answer_set.all()]
self.assertEqual(a1.question.get_answer_order(), id_list) self.assertSequenceEqual(a1.question.get_answer_order(), id_list)
# It doesn't matter which answer we use to check the order, it will # It doesn't matter which answer we use to check the order, it will
# always be the same. # always be the same.
a2 = Answer.objects.create(text="Number five", question=self.q1) a2 = self.Answer.objects.create(text="Number five", question=self.q1)
self.assertEqual( self.assertListEqual(
a1.question.get_answer_order(), a2.question.get_answer_order() list(a1.question.get_answer_order()), list(a2.question.get_answer_order())
) )
def test_change_ordering(self): def test_change_ordering(self):
# The ordering can be altered # The ordering can be altered
a = Answer.objects.create(text="Number five", question=self.q1) a = self.Answer.objects.create(text="Number five", question=self.q1)
# Swap the last two items in the order list # Swap the last two items in the order list
id_list = [o.pk for o in self.q1.answer_set.all()] id_list = [o.pk for o in self.q1.answer_set.all()]
...@@ -61,7 +65,7 @@ class OrderWithRespectToTests(TestCase): ...@@ -61,7 +65,7 @@ class OrderWithRespectToTests(TestCase):
id_list.insert(-1, x) id_list.insert(-1, x)
# By default, the ordering is different from the swapped version # By default, the ordering is different from the swapped version
self.assertNotEqual(a.question.get_answer_order(), id_list) self.assertNotEqual(list(a.question.get_answer_order()), id_list)
# Change the ordering to the swapped version - # Change the ordering to the swapped version -
# this changes the ordering of the queryset. # this changes the ordering of the queryset.
...@@ -76,19 +80,25 @@ class OrderWithRespectToTests(TestCase): ...@@ -76,19 +80,25 @@ class OrderWithRespectToTests(TestCase):
class OrderWithRespectToTests2(TestCase): class OrderWithRespectToTests2(TestCase):
# Provide the Post model as a class attribute so that we can subclass this
# test case in contenttypes_tests.test_order_with_respect_to and run these
# tests with alternative implementations of Post.
Post = Post
def test_recursive_ordering(self): def test_recursive_ordering(self):
p1 = Post.objects.create(title='1') p1 = self.Post.objects.create(title="1")
p2 = Post.objects.create(title='2') p2 = self.Post.objects.create(title="2")
p1_1 = Post.objects.create(title="1.1", parent=p1) p1_1 = self.Post.objects.create(title="1.1", parent=p1)
p1_2 = Post.objects.create(title="1.2", parent=p1) p1_2 = self.Post.objects.create(title="1.2", parent=p1)
Post.objects.create(title="2.1", parent=p2) self.Post.objects.create(title="2.1", parent=p2)
p1_3 = Post.objects.create(title="1.3", parent=p1) p1_3 = self.Post.objects.create(title="1.3", parent=p1)
self.assertEqual(p1.get_post_order(), [p1_1.pk, p1_2.pk, p1_3.pk]) self.assertSequenceEqual(p1.get_post_order(), [p1_1.pk, p1_2.pk, p1_3.pk])
def test_duplicate_order_field(self): def test_duplicate_order_field(self):
class Bar(models.Model): class Bar(models.Model):
pass class Meta:
app_label = 'order_with_respect_to'
class Foo(models.Model): class Foo(models.Model):
bar = models.ForeignKey(Bar, models.CASCADE) bar = models.ForeignKey(Bar, models.CASCADE)
...@@ -96,6 +106,7 @@ class OrderWithRespectToTests2(TestCase): ...@@ -96,6 +106,7 @@ class OrderWithRespectToTests2(TestCase):
class Meta: class Meta:
order_with_respect_to = 'bar' order_with_respect_to = 'bar'
app_label = 'order_with_respect_to'
count = 0 count = 0
for field in Foo._meta.local_fields: for field in Foo._meta.local_fields:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment