Kaydet (Commit) 263b3d2b authored tarafından Dmitry Dygalo's avatar Dmitry Dygalo Kaydeden (comit) Tim Graham

Fixed #25666 -- Fixed the exact lookup of ArrayField.

üst b8f78823
...@@ -5,6 +5,7 @@ from django.contrib.postgres.forms import SimpleArrayField ...@@ -5,6 +5,7 @@ from django.contrib.postgres.forms import SimpleArrayField
from django.contrib.postgres.validators import ArrayMaxLengthValidator from django.contrib.postgres.validators import ArrayMaxLengthValidator
from django.core import checks, exceptions from django.core import checks, exceptions
from django.db.models import Field, IntegerField, Transform from django.db.models import Field, IntegerField, Transform
from django.db.models.lookups import Exact
from django.utils import six from django.utils import six
from django.utils.translation import string_concat, ugettext_lazy as _ from django.utils.translation import string_concat, ugettext_lazy as _
...@@ -166,7 +167,7 @@ class ArrayField(Field): ...@@ -166,7 +167,7 @@ class ArrayField(Field):
class ArrayContains(lookups.DataContains): class ArrayContains(lookups.DataContains):
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
sql, params = super(ArrayContains, self).as_sql(qn, connection) sql, params = super(ArrayContains, self).as_sql(qn, connection)
sql += '::%s' % self.lhs.output_field.db_type(connection) sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
return sql, params return sql, params
...@@ -174,7 +175,15 @@ class ArrayContains(lookups.DataContains): ...@@ -174,7 +175,15 @@ class ArrayContains(lookups.DataContains):
class ArrayContainedBy(lookups.ContainedBy): class ArrayContainedBy(lookups.ContainedBy):
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
sql, params = super(ArrayContainedBy, self).as_sql(qn, connection) sql, params = super(ArrayContainedBy, self).as_sql(qn, connection)
sql += '::%s' % self.lhs.output_field.db_type(connection) sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
return sql, params
@ArrayField.register_lookup
class ArrayExact(Exact):
def as_sql(self, qn, connection):
sql, params = super(ArrayExact, self).as_sql(qn, connection)
sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
return sql, params return sql, params
...@@ -182,7 +191,7 @@ class ArrayContainedBy(lookups.ContainedBy): ...@@ -182,7 +191,7 @@ class ArrayContainedBy(lookups.ContainedBy):
class ArrayOverlap(lookups.Overlap): class ArrayOverlap(lookups.Overlap):
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
sql, params = super(ArrayOverlap, self).as_sql(qn, connection) sql, params = super(ArrayOverlap, self).as_sql(qn, connection)
sql += '::%s' % self.lhs.output_field.db_type(connection) sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
return sql, params return sql, params
......
...@@ -34,3 +34,5 @@ Bugfixes ...@@ -34,3 +34,5 @@ Bugfixes
* Fixed serialization of * Fixed serialization of
:class:`~django.contrib.postgres.fields.DateRangeField` and :class:`~django.contrib.postgres.fields.DateRangeField` and
:class:`~django.contrib.postgres.fields.DateTimeRangeField` (:ticket:`24937`). :class:`~django.contrib.postgres.fields.DateTimeRangeField` (:ticket:`24937`).
* Fixed the exact lookup of ``ArrayField`` (:ticket:`25666`).
...@@ -122,6 +122,20 @@ class TestQuerying(PostgreSQLTestCase): ...@@ -122,6 +122,20 @@ class TestQuerying(PostgreSQLTestCase):
self.objs[:1] self.objs[:1]
) )
def test_exact_charfield(self):
instance = CharArrayModel.objects.create(field=['text'])
self.assertSequenceEqual(
CharArrayModel.objects.filter(field=['text']),
[instance]
)
def test_exact_nested(self):
instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
self.assertSequenceEqual(
NestedIntegerArrayModel.objects.filter(field=[[1, 2], [3, 4]]),
[instance]
)
def test_isnull(self): def test_isnull(self):
self.assertSequenceEqual( self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__isnull=True), NullableIntegerArrayModel.objects.filter(field__isnull=True),
...@@ -244,6 +258,73 @@ class TestQuerying(PostgreSQLTestCase): ...@@ -244,6 +258,73 @@ class TestQuerying(PostgreSQLTestCase):
) )
class TestDateTimeExactQuerying(PostgreSQLTestCase):
def setUp(self):
now = timezone.now()
self.datetimes = [now]
self.dates = [now.date()]
self.times = [now.time()]
self.objs = [
DateTimeArrayModel.objects.create(
datetimes=self.datetimes,
dates=self.dates,
times=self.times,
)
]
def test_exact_datetimes(self):
self.assertSequenceEqual(
DateTimeArrayModel.objects.filter(datetimes=self.datetimes),
self.objs
)
def test_exact_dates(self):
self.assertSequenceEqual(
DateTimeArrayModel.objects.filter(dates=self.dates),
self.objs
)
def test_exact_times(self):
self.assertSequenceEqual(
DateTimeArrayModel.objects.filter(times=self.times),
self.objs
)
class TestOtherTypesExactQuerying(PostgreSQLTestCase):
def setUp(self):
self.ips = ['192.168.0.1', '::1']
self.uuids = [uuid.uuid4()]
self.decimals = [decimal.Decimal(1.25), 1.75]
self.objs = [
OtherTypesArrayModel.objects.create(
ips=self.ips,
uuids=self.uuids,
decimals=self.decimals,
)
]
def test_exact_ip_addresses(self):
self.assertSequenceEqual(
OtherTypesArrayModel.objects.filter(ips=self.ips),
self.objs
)
def test_exact_uuids(self):
self.assertSequenceEqual(
OtherTypesArrayModel.objects.filter(uuids=self.uuids),
self.objs
)
def test_exact_decimals(self):
self.assertSequenceEqual(
OtherTypesArrayModel.objects.filter(decimals=self.decimals),
self.objs
)
class TestChecks(PostgreSQLTestCase): class TestChecks(PostgreSQLTestCase):
def test_field_checks(self): def test_field_checks(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment