Commit 04240b23 authored by acrefoot, committed by Tim Graham

Refs #19527 -- Allowed QuerySet.bulk_create() to set the primary key of its objects.

PostgreSQL support only.

Thanks Vladislav Manchev and alesasnouski for working on the patch.
parent 60633ef3
@@ -24,6 +24,7 @@ class BaseDatabaseFeatures(object):
     can_use_chunked_reads = True
     can_return_id_from_insert = False
+    can_return_ids_from_bulk_insert = False
     has_bulk_insert = False
     uses_savepoints = False
     can_release_savepoints = False
@@ -5,6 +5,7 @@ from django.db.utils import InterfaceError
 class DatabaseFeatures(BaseDatabaseFeatures):
     allows_group_by_selected_pks = True
     can_return_id_from_insert = True
+    can_return_ids_from_bulk_insert = True
     has_real_datatype = True
     has_native_uuid_field = True
     has_native_duration_field = True
@@ -59,6 +59,14 @@ class DatabaseOperations(BaseDatabaseOperations):
     def deferrable_sql(self):
         return " DEFERRABLE INITIALLY DEFERRED"

+    def fetch_returned_insert_ids(self, cursor):
+        """
+        Given a cursor object that has just performed an INSERT...RETURNING
+        statement into a table that has an auto-incrementing ID, return the
+        list of newly created IDs.
+        """
+        return [item[0] for item in cursor.fetchall()]
+
     def lookup_cast(self, lookup_type, internal_type=None):
         lookup = '%s'
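The new method only peels the first column off each row the driver hands back after an INSERT ... RETURNING statement. A minimal standalone sketch of that driver-level behavior, using psycopg2 directly; the DSN and the demo table are assumptions for illustration, not part of the commit:

# Sketch: what fetch_returned_insert_ids() sees at the driver level.
import psycopg2

conn = psycopg2.connect("dbname=test")  # assumed local database
with conn, conn.cursor() as cur:
    cur.execute("CREATE TABLE IF NOT EXISTS demo (id serial PRIMARY KEY, name text)")
    cur.execute(
        "INSERT INTO demo (name) VALUES (%s), (%s) RETURNING id",
        ["alice", "bob"],
    )
    # fetchall() yields one 1-tuple per inserted row, e.g. [(1,), (2,)];
    # taking item[0] from each row gives the flat list of new ids.
    ids = [row[0] for row in cur.fetchall()]
print(ids)  # e.g. [1, 2]
conn.close()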
@@ -411,17 +411,21 @@ class QuerySet(object):
         Inserts each of the instances into the database. This does *not* call
         save() on each of the instances, does not send any pre/post save
         signals, and does not set the primary key attribute if it is an
-        autoincrement field. Multi-table models are not supported.
-        """
-        # So this case is fun. When you bulk insert you don't get the primary
-        # keys back (if it's an autoincrement), so you can't insert into the
-        # child tables which references this. There are two workarounds, 1)
-        # this could be implemented if you didn't have an autoincrement pk,
-        # and 2) you could do it by doing O(n) normal inserts into the parent
-        # tables to get the primary keys back, and then doing a single bulk
-        # insert into the childmost table. Some databases might allow doing
-        # this by using RETURNING clause for the insert query. We're punting
-        # on these for now because they are relatively rare cases.
+        autoincrement field (except if features.can_return_ids_from_bulk_insert=True).
+        Multi-table models are not supported.
+        """
+        # When you bulk insert you don't get the primary keys back (if it's an
+        # autoincrement, except if can_return_ids_from_bulk_insert=True), so
+        # you can't insert into the child tables which references this. There
+        # are two workarounds:
+        # 1) This could be implemented if you didn't have an autoincrement pk
+        # 2) You could do it by doing O(n) normal inserts into the parent
+        #    tables to get the primary keys back and then doing a single bulk
+        #    insert into the childmost table.
+        # We currently set the primary keys on the objects when using
+        # PostgreSQL via the RETURNING ID clause. It should be possible for
+        # Oracle as well, but the semantics for extracting the primary keys is
+        # trickier so it's not done yet.
         assert batch_size is None or batch_size > 0
         # Check that the parents share the same concrete model with the our
         # model to detect the inheritance pattern ConcreteGrandParent ->
@@ -447,7 +451,11 @@ class QuerySet(object):
                 self._batched_insert(objs_with_pk, fields, batch_size)
             if objs_without_pk:
                 fields = [f for f in fields if not isinstance(f, AutoField)]
-                self._batched_insert(objs_without_pk, fields, batch_size)
+                ids = self._batched_insert(objs_without_pk, fields, batch_size)
+                if connection.features.can_return_ids_from_bulk_insert:
+                    assert len(ids) == len(objs_without_pk)
+                for i in range(len(ids)):
+                    objs_without_pk[i].pk = ids[i]
         return objs
@@ -1051,10 +1059,19 @@ class QuerySet(object):
             return
         ops = connections[self.db].ops
         batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1))
-        for batch in [objs[i:i + batch_size]
-                      for i in range(0, len(objs), batch_size)]:
-            self.model._base_manager._insert(batch, fields=fields,
-                                             using=self.db)
+        inserted_ids = []
+        for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
+            if connections[self.db].features.can_return_ids_from_bulk_insert:
+                inserted_id = self.model._base_manager._insert(
+                    item, fields=fields, using=self.db, return_id=True
+                )
+                if len(objs) > 1:
+                    inserted_ids.extend(inserted_id)
+                if len(objs) == 1:
+                    inserted_ids.append(inserted_id)
+            else:
+                self.model._base_manager._insert(item, fields=fields, using=self.db)
+        return inserted_ids

     def _clone(self, **kwargs):
         query = self.query.clone()
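Two details worth noting in the rewritten _batched_insert(): the extend/append split exists because _insert(..., return_id=True) hands back a list of ids for a multi-object query but a single id when only one object was inserted (see execute_sql() below), and the batching itself is plain list slicing. A standalone illustration of the slicing expression, with made-up values:

# Sketch: how the comprehension splits objs into slices of at most batch_size.
objs = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
batch_size = 3
batches = [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]
assert batches == [['a', 'b', 'c'], ['d', 'e', 'f'], ['g']]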
@@ -1019,16 +1019,20 @@ class SQLInsertCompiler(SQLCompiler):
         placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)

         if self.return_id and self.connection.features.can_return_id_from_insert:
-            params = param_rows[0]
+            if self.connection.features.can_return_ids_from_bulk_insert:
+                result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
+                params = param_rows
+            else:
+                result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
+                params = param_rows[0]
             col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
-            result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
             r_fmt, r_params = self.connection.ops.return_insert_id()
             # Skip empty r_fmt to allow subclasses to customize behavior for
             # 3rd party backends. Refs #19096.
             if r_fmt:
                 result.append(r_fmt % col)
                 params += r_params
-            return [(" ".join(result), tuple(params))]
+            return [(" ".join(result), tuple(chain.from_iterable(params)))]

         if can_bulk:
             result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
@@ -1040,14 +1044,20 @@ class SQLInsertCompiler(SQLCompiler):
         ]

     def execute_sql(self, return_id=False):
-        assert not (return_id and len(self.query.objs) != 1)
+        assert not (
+            return_id and len(self.query.objs) != 1 and
+            not self.connection.features.can_return_ids_from_bulk_insert
+        )
         self.return_id = return_id
         with self.connection.cursor() as cursor:
             for sql, params in self.as_sql():
                 cursor.execute(sql, params)
             if not (return_id and cursor):
                 return
+            if self.connection.features.can_return_ids_from_bulk_insert and len(self.query.objs) > 1:
+                return self.connection.ops.fetch_returned_insert_ids(cursor)
             if self.connection.features.can_return_id_from_insert:
+                assert len(self.query.objs) == 1
                 return self.connection.ops.fetch_returned_insert_id(cursor)
             return self.connection.ops.last_insert_id(cursor,
                 self.query.get_meta().db_table, self.query.get_meta().pk.column)
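Taken together, the compiler changes make a bulk insert with return_id come out as one multi-row statement carrying a RETURNING clause. A hedged sketch of the rough (sql, params) shape as_sql() now produces on PostgreSQL; the table and column names are hypothetical:

# Sketch: approximate output of as_sql() for a two-row bulk insert when
# can_return_ids_from_bulk_insert is True. Names are hypothetical.
sql = (
    'INSERT INTO "app_country" ("name") '
    'VALUES (%s), (%s) '
    'RETURNING "app_country"."id"'
)
params = ('Iceland', 'Norway')  # param_rows flattened via chain.from_iterable
# cursor.execute(sql, params); fetch_returned_insert_ids(cursor) -> e.g. [1, 2]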
@@ -1794,13 +1794,19 @@ This has a number of caveats though:
   ``post_save`` signals will not be sent.
 * It does not work with child models in a multi-table inheritance scenario.
 * If the model's primary key is an :class:`~django.db.models.AutoField` it
-  does not retrieve and set the primary key attribute, as ``save()`` does.
+  does not retrieve and set the primary key attribute, as ``save()`` does,
+  unless the database backend supports it (currently PostgreSQL).
 * It does not work with many-to-many relationships.

 .. versionchanged:: 1.9

     Support for using ``bulk_create()`` with proxy models was added.

+.. versionchanged:: 1.10
+
+    Support for setting primary keys on objects created using ``bulk_create()``
+    when using PostgreSQL was added.
+
 The ``batch_size`` parameter controls how many objects are created in single
 query. The default is to create all objects in one batch, except for SQLite
 where the default is such that at most 999 variables per query are used.
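In user-facing terms, the documented change looks like the sketch below, assuming a project configured against PostgreSQL and a simple Country model with an AutoField pk (mirroring the tests further down):

# Sketch: bulk_create() now populates pks on backends that support it.
countries = Country.objects.bulk_create([
    Country(name='Iceland', iso_two_letter='IS'),
    Country(name='Norway', iso_two_letter='NO'),
])
# On PostgreSQL each instance comes back with its pk set; on other
# backends the pks of freshly inserted autoincrement rows stay None.
assert all(c.pk is not None for c in countries)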
@@ -203,6 +203,11 @@ Database backends
 * Temporal data subtraction was unified on all backends.

+* If the database supports it, backends can set
+  ``DatabaseFeatures.can_return_ids_from_bulk_insert=True`` and implement
+  ``DatabaseOperations.fetch_returned_insert_ids()`` to set primary keys
+  on objects created using ``QuerySet.bulk_create()``.
+
 Email
 ~~~~~
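For third-party backend authors, the opt-in described in that note amounts to roughly the following; a minimal sketch, not a working backend, and whether the underlying driver can return all generated ids from one multi-row INSERT is an assumption:

# Sketch of the two hooks named in the release note.
from django.db.backends.base.features import BaseDatabaseFeatures
from django.db.backends.base.operations import BaseDatabaseOperations


class DatabaseFeatures(BaseDatabaseFeatures):
    # Advertise the capability so QuerySet.bulk_create() asks for ids.
    can_return_ids_from_bulk_insert = True


class DatabaseOperations(BaseDatabaseOperations):
    def fetch_returned_insert_ids(self, cursor):
        # Flatten the driver's returned rows into a list of ids, mirroring
        # the PostgreSQL implementation in this commit.
        return [row[0] for row in cursor.fetchall()]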
@@ -315,6 +320,9 @@ Models
 * The :func:`~django.db.models.prefetch_related_objects` function is now a
   public API.

+* :meth:`QuerySet.bulk_create() <django.db.models.query.QuerySet.bulk_create>`
+  sets the primary key on objects when using PostgreSQL.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
@@ -198,3 +198,22 @@ class BulkCreateTests(TestCase):
         ])
         bbb = Restaurant.objects.filter(name="betty's beetroot bar")
         self.assertEqual(bbb.count(), 1)
+
+    @skipUnlessDBFeature('can_return_ids_from_bulk_insert')
+    def test_set_pk_and_insert_single_item(self):
+        countries = []
+        with self.assertNumQueries(1):
+            countries = Country.objects.bulk_create([self.data[0]])
+        self.assertEqual(len(countries), 1)
+        self.assertEqual(Country.objects.get(pk=countries[0].pk), countries[0])
+
+    @skipUnlessDBFeature('can_return_ids_from_bulk_insert')
+    def test_set_pk_and_query_efficiency(self):
+        countries = []
+        with self.assertNumQueries(1):
+            countries = Country.objects.bulk_create(self.data)
+        self.assertEqual(len(countries), 4)
+        self.assertEqual(Country.objects.get(pk=countries[0].pk), countries[0])
+        self.assertEqual(Country.objects.get(pk=countries[1].pk), countries[1])
+        self.assertEqual(Country.objects.get(pk=countries[2].pk), countries[2])
+        self.assertEqual(Country.objects.get(pk=countries[3].pk), countries[3])