Kaydet (Commit) bc7e288c authored tarafından Simon Charette's avatar Simon Charette Kaydeden (comit) Tim Graham

Fixed #29745 -- Based Expression equality on detailed initialization signature.

The old implementation considered objects initialized with an equivalent
signature different if some arguments were provided positionally instead of
as keyword arguments.

Refs #11964, #26167.
üst e4df8e6d
import copy import copy
import datetime import datetime
import inspect
from decimal import Decimal from decimal import Decimal
from django.core.exceptions import EmptyResultSet, FieldError from django.core.exceptions import EmptyResultSet, FieldError
...@@ -137,6 +138,16 @@ class Combinable: ...@@ -137,6 +138,16 @@ class Combinable:
) )
def make_hashable(value):
if isinstance(value, list):
return tuple(map(make_hashable, value))
if isinstance(value, dict):
return tuple([
(key, make_hashable(nested_value)) for key, nested_value in value.items()
])
return value
@deconstructible @deconstructible
class BaseExpression: class BaseExpression:
"""Base class for all query expressions.""" """Base class for all query expressions."""
...@@ -360,28 +371,27 @@ class BaseExpression: ...@@ -360,28 +371,27 @@ class BaseExpression:
if expr: if expr:
yield from expr.flatten() yield from expr.flatten()
@cached_property
def identity(self):
constructor_signature = inspect.signature(self.__init__)
args, kwargs = self._constructor_args
signature = constructor_signature.bind_partial(*args, **kwargs)
signature.apply_defaults()
arguments = signature.arguments.items()
identity = [self.__class__]
for arg, value in arguments:
if isinstance(value, fields.Field):
value = type(value)
else:
value = make_hashable(value)
identity.append((arg, value))
return tuple(identity)
def __eq__(self, other): def __eq__(self, other):
if self.__class__ != other.__class__: return isinstance(other, BaseExpression) and other.identity == self.identity
return False
path, args, kwargs = self.deconstruct()
other_path, other_args, other_kwargs = other.deconstruct()
if (path, args) == (other_path, other_args):
kwargs = kwargs.copy()
other_kwargs = other_kwargs.copy()
output_field = type(kwargs.pop('output_field', None))
other_output_field = type(other_kwargs.pop('output_field', None))
if output_field == other_output_field:
return kwargs == other_kwargs
return False
def __hash__(self): def __hash__(self):
path, args, kwargs = self.deconstruct() return hash(self.identity)
kwargs = kwargs.copy()
output_field = type(kwargs.pop('output_field', None))
return hash((path, output_field) + args + tuple([
(key, tuple(value)) if isinstance(value, list) else (key, value)
for key, value in kwargs.items()
]))
class Expression(BaseExpression, Combinable): class Expression(BaseExpression, Combinable):
...@@ -695,9 +705,6 @@ class RawSQL(Expression): ...@@ -695,9 +705,6 @@ class RawSQL(Expression):
def get_group_by_cols(self): def get_group_by_cols(self):
return [self] return [self]
def __hash__(self):
return hash((self.sql, self.output_field) + tuple(self.params))
class Star(Expression): class Star(Expression):
def __repr__(self): def __repr__(self):
......
...@@ -11,8 +11,9 @@ from django.db.models.aggregates import ( ...@@ -11,8 +11,9 @@ from django.db.models.aggregates import (
Avg, Count, Max, Min, StdDev, Sum, Variance, Avg, Count, Max, Min, StdDev, Sum, Variance,
) )
from django.db.models.expressions import ( from django.db.models.expressions import (
Case, Col, Combinable, Exists, ExpressionList, ExpressionWrapper, F, Func, Case, Col, Combinable, Exists, Expression, ExpressionList,
OrderBy, OuterRef, Random, RawSQL, Ref, Subquery, Value, When, ExpressionWrapper, F, Func, OrderBy, OuterRef, Random, RawSQL, Ref,
Subquery, Value, When,
) )
from django.db.models.functions import ( from django.db.models.functions import (
Coalesce, Concat, Length, Lower, Substr, Upper, Coalesce, Concat, Length, Lower, Substr, Upper,
...@@ -822,6 +823,31 @@ class ExpressionsTests(TestCase): ...@@ -822,6 +823,31 @@ class ExpressionsTests(TestCase):
) )
class SimpleExpressionTests(SimpleTestCase):
def test_equal(self):
self.assertEqual(Expression(), Expression())
self.assertEqual(
Expression(models.IntegerField()),
Expression(output_field=models.IntegerField())
)
self.assertNotEqual(
Expression(models.IntegerField()),
Expression(models.CharField())
)
def test_hash(self):
self.assertEqual(hash(Expression()), hash(Expression()))
self.assertEqual(
hash(Expression(models.IntegerField())),
hash(Expression(output_field=models.IntegerField()))
)
self.assertNotEqual(
hash(Expression(models.IntegerField())),
hash(Expression(models.CharField())),
)
class ExpressionsNumericTests(TestCase): class ExpressionsNumericTests(TestCase):
def setUp(self): def setUp(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment