Kaydet (Commit) c4e2fc5d authored tarafından Mikhail Nacharov's avatar Mikhail Nacharov Kaydeden (comit) Tim Graham

Fixed #22669 -- Fixed QuerySet.bulk_create() with empty model fields on Oracle.

üst 965f678a
...@@ -9,7 +9,7 @@ from django.utils import timezone ...@@ -9,7 +9,7 @@ from django.utils import timezone
from django.utils.encoding import force_bytes, force_text from django.utils.encoding import force_bytes, force_text
from .base import Database from .base import Database
from .utils import InsertIdVar, Oracle_datetime from .utils import BulkInsertMapper, InsertIdVar, Oracle_datetime
class DatabaseOperations(BaseDatabaseOperations): class DatabaseOperations(BaseDatabaseOperations):
...@@ -523,10 +523,18 @@ WHEN (new.%(col_name)s IS NULL) ...@@ -523,10 +523,18 @@ WHEN (new.%(col_name)s IS NULL)
return truncate_name(trigger_name, name_length).upper() return truncate_name(trigger_name, name_length).upper()
def bulk_insert_sql(self, fields, placeholder_rows): def bulk_insert_sql(self, fields, placeholder_rows):
return " UNION ALL ".join( query = []
"SELECT %s FROM DUAL" % ", ".join(row) for row in placeholder_rows:
for row in placeholder_rows select = []
) for i, placeholder in enumerate(row):
# A model without any fields has fields=[None].
if not fields[i]:
select.append(placeholder)
else:
internal_type = getattr(fields[i], 'target_field', fields[i]).get_internal_type()
select.append(BulkInsertMapper.types.get(internal_type, '%s') % placeholder)
query.append('SELECT %s FROM DUAL' % ', '.join(select))
return ' UNION ALL '.join(query)
def subtract_temporals(self, internal_type, lhs, rhs): def subtract_temporals(self, internal_type, lhs, rhs):
if internal_type == 'DateField': if internal_type == 'DateField':
......
...@@ -29,3 +29,27 @@ class Oracle_datetime(datetime.datetime): ...@@ -29,3 +29,27 @@ class Oracle_datetime(datetime.datetime):
dt.year, dt.month, dt.day, dt.year, dt.month, dt.day,
dt.hour, dt.minute, dt.second, dt.microsecond, dt.hour, dt.minute, dt.second, dt.microsecond,
) )
class BulkInsertMapper:
BLOB = 'TO_BLOB(%s)'
DATE = 'TO_DATE(%s)'
INTERVAL = 'CAST(%s as INTERVAL DAY(9) TO SECOND(6))'
NUMBER = 'TO_NUMBER(%s)'
TIMESTAMP = 'TO_TIMESTAMP(%s)'
types = {
'BigIntegerField': NUMBER,
'BinaryField': BLOB,
'DateField': DATE,
'DateTimeField': TIMESTAMP,
'DecimalField': NUMBER,
'DurationField': INTERVAL,
'FloatField': NUMBER,
'IntegerField': NUMBER,
'NullBooleanField': NUMBER,
'PositiveIntegerField': NUMBER,
'PositiveSmallIntegerField': NUMBER,
'SmallIntegerField': NUMBER,
'TimeField': TIMESTAMP,
}
import datetime
import uuid
from decimal import Decimal
from django.db import models from django.db import models
from django.utils import timezone
class Country(models.Model): class Country(models.Model):
...@@ -51,3 +56,32 @@ class TwoFields(models.Model): ...@@ -51,3 +56,32 @@ class TwoFields(models.Model):
class NoFields(models.Model): class NoFields(models.Model):
pass pass
class NullableFields(models.Model):
# Fields in db.backends.oracle.BulkInsertMapper
big_int_filed = models.BigIntegerField(null=True, default=1)
binary_field = models.BinaryField(null=True, default=b'data')
date_field = models.DateField(null=True, default=timezone.now)
datetime_field = models.DateTimeField(null=True, default=timezone.now)
decimal_field = models.DecimalField(null=True, max_digits=2, decimal_places=1, default=Decimal('1.1'))
duration_field = models.DurationField(null=True, default=datetime.timedelta(1))
float_field = models.FloatField(null=True, default=3.2)
integer_field = models.IntegerField(null=True, default=2)
null_boolean_field = models.NullBooleanField(null=True, default=False)
positive_integer_field = models.PositiveIntegerField(null=True, default=3)
positive_small_integer_field = models.PositiveSmallIntegerField(null=True, default=4)
small_integer_field = models.SmallIntegerField(null=True, default=5)
time_field = models.TimeField(null=True, default=timezone.now)
# Fields not required in BulkInsertMapper
char_field = models.CharField(null=True, max_length=4, default='char')
email_field = models.EmailField(null=True, default='user@example.com')
duration_field = models.DurationField(null=True, default=datetime.timedelta(1))
file_field = models.FileField(null=True, default='file.txt')
file_path_field = models.FilePathField(path='/tmp', null=True, default='file.txt')
generic_ip_address_field = models.GenericIPAddressField(null=True, default='127.0.0.1')
image_field = models.ImageField(null=True, default='image.jpg')
slug_field = models.SlugField(null=True, default='slug')
text_field = models.TextField(null=True, default='text')
url_field = models.URLField(null=True, default='/')
uuid_field = models.UUIDField(null=True, default=uuid.uuid4)
from operator import attrgetter from operator import attrgetter
from django.db import connection from django.db import connection
from django.db.models import Value from django.db.models import FileField, Value
from django.db.models.functions import Lower from django.db.models.functions import Lower
from django.test import ( from django.test import (
TestCase, override_settings, skipIfDBFeature, skipUnlessDBFeature, TestCase, override_settings, skipIfDBFeature, skipUnlessDBFeature,
) )
from .models import ( from .models import (
Country, NoFields, Pizzeria, ProxyCountry, ProxyMultiCountry, Country, NoFields, NullableFields, Pizzeria, ProxyCountry,
ProxyMultiProxyCountry, ProxyProxyCountry, Restaurant, State, TwoFields, ProxyMultiCountry, ProxyMultiProxyCountry, ProxyProxyCountry, Restaurant,
State, TwoFields,
) )
...@@ -204,6 +205,19 @@ class BulkCreateTests(TestCase): ...@@ -204,6 +205,19 @@ class BulkCreateTests(TestCase):
bbb = Restaurant.objects.filter(name="betty's beetroot bar") bbb = Restaurant.objects.filter(name="betty's beetroot bar")
self.assertEqual(bbb.count(), 1) self.assertEqual(bbb.count(), 1)
@skipUnlessDBFeature('has_bulk_insert')
def test_bulk_insert_nullable_fields(self):
# NULL can be mixed with other values in nullable fields
nullable_fields = [field for field in NullableFields._meta.get_fields() if field.name != 'id']
NullableFields.objects.bulk_create([
NullableFields(**{field.name: None}) for field in nullable_fields
])
self.assertEqual(NullableFields.objects.count(), len(nullable_fields))
for field in nullable_fields:
with self.subTest(field=field):
field_value = '' if isinstance(field, FileField) else None
self.assertEqual(NullableFields.objects.filter(**{field.name: field_value}).count(), 1)
@skipUnlessDBFeature('can_return_ids_from_bulk_insert') @skipUnlessDBFeature('can_return_ids_from_bulk_insert')
def test_set_pk_and_insert_single_item(self): def test_set_pk_and_insert_single_item(self):
with self.assertNumQueries(1): with self.assertNumQueries(1):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment