Kaydet (Commit) 5a36c81f authored tarafından Vinay Karanam's avatar Vinay Karanam Kaydeden (comit) Tim Graham

Fixed #29391 -- Made PostgresSimpleLookup respect Field.get_db_prep_value().

üst c492fdfd
from django.db.models import Lookup, Transform from django.db.models import Lookup, Transform
from django.db.models.lookups import Exact from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin
from .search import SearchVector, SearchVectorExact, SearchVectorField from .search import SearchVector, SearchVectorExact, SearchVectorField
class PostgresSimpleLookup(Lookup): class PostgresSimpleLookup(FieldGetDbPrepValueMixin, Lookup):
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection)
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
Indirection layer for PostgreSQL-specific fields, so the tests don't fail when Indirection layer for PostgreSQL-specific fields, so the tests don't fail when
run with a backend other than PostgreSQL. run with a backend other than PostgreSQL.
""" """
import enum
from django.db import models from django.db import models
try: try:
...@@ -40,3 +42,8 @@ except ImportError: ...@@ -40,3 +42,8 @@ except ImportError:
IntegerRangeField = models.Field IntegerRangeField = models.Field
JSONField = DummyJSONField JSONField = DummyJSONField
SearchVectorField = models.Field SearchVectorField = models.Field
class EnumField(models.CharField):
def get_prep_value(self, value):
return value.value if isinstance(value, enum.Enum) else value
...@@ -3,8 +3,8 @@ from django.db import migrations, models ...@@ -3,8 +3,8 @@ from django.db import migrations, models
from ..fields import ( from ..fields import (
ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField, ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
DateRangeField, DateTimeRangeField, DecimalRangeField, HStoreField, DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
IntegerRangeField, JSONField, SearchVectorField, HStoreField, IntegerRangeField, JSONField, SearchVectorField,
) )
from ..models import TagField from ..models import TagField
...@@ -249,4 +249,15 @@ class Migration(migrations.Migration): ...@@ -249,4 +249,15 @@ class Migration(migrations.Migration):
}, },
bases=(models.Model,), bases=(models.Model,),
), ),
migrations.CreateModel(
name='ArrayEnumModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('array_of_enums', ArrayField(EnumField(max_length=20), null=True, blank=True)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
] ]
...@@ -3,8 +3,8 @@ from django.db import models ...@@ -3,8 +3,8 @@ from django.db import models
from .fields import ( from .fields import (
ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField, ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
DateRangeField, DateTimeRangeField, DecimalRangeField, HStoreField, DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
IntegerRangeField, JSONField, SearchVectorField, HStoreField, IntegerRangeField, JSONField, SearchVectorField,
) )
...@@ -77,6 +77,10 @@ class HStoreModel(PostgreSQLModel): ...@@ -77,6 +77,10 @@ class HStoreModel(PostgreSQLModel):
array_field = ArrayField(HStoreField(), null=True) array_field = ArrayField(HStoreField(), null=True)
class ArrayEnumModel(PostgreSQLModel):
array_of_enums = ArrayField(EnumField(max_length=20))
class CharFieldModel(models.Model): class CharFieldModel(models.Model):
field = models.CharField(max_length=16) field = models.CharField(max_length=16)
......
import decimal import decimal
import enum
import json import json
import unittest import unittest
import uuid import uuid
...@@ -16,9 +17,9 @@ from . import ( ...@@ -16,9 +17,9 @@ from . import (
PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase, PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase,
) )
from .models import ( from .models import (
ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel, ArrayEnumModel, ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel,
NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel, IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel,
PostgreSQLModel, Tag, OtherTypesArrayModel, PostgreSQLModel, Tag,
) )
try: try:
...@@ -357,6 +358,16 @@ class TestQuerying(PostgreSQLTestCase): ...@@ -357,6 +358,16 @@ class TestQuerying(PostgreSQLTestCase):
[self.objs[3]] [self.objs[3]]
) )
def test_enum_lookup(self):
class TestEnum(enum.Enum):
VALUE_1 = 'value_1'
instance = ArrayEnumModel.objects.create(array_of_enums=[TestEnum.VALUE_1])
self.assertSequenceEqual(
ArrayEnumModel.objects.filter(array_of_enums__contains=[TestEnum.VALUE_1]),
[instance]
)
def test_unsupported_lookup(self): def test_unsupported_lookup(self):
msg = "Unsupported lookup '0_bar' for ArrayField or join on the field not permitted." msg = "Unsupported lookup '0_bar' for ArrayField or join on the field not permitted."
with self.assertRaisesMessage(FieldError, msg): with self.assertRaisesMessage(FieldError, msg):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment