Kaydet (Commit) 7a6fbf36 authored tarafından Jon Dufresne's avatar Jon Dufresne Kaydeden (comit) Tim Graham

Fixed #28853 -- Updated connection.cursor() uses to use a context manager.

üst 33080858
...@@ -11,8 +11,7 @@ class MySQLIntrospection(DatabaseIntrospection): ...@@ -11,8 +11,7 @@ class MySQLIntrospection(DatabaseIntrospection):
data_types_reverse[FIELD_TYPE.GEOMETRY] = 'GeometryField' data_types_reverse[FIELD_TYPE.GEOMETRY] = 'GeometryField'
def get_geometry_type(self, table_name, geo_col): def get_geometry_type(self, table_name, geo_col):
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
try:
# In order to get the specific geometry type of the field, # In order to get the specific geometry type of the field,
# we introspect on the table definition using `DESCRIBE`. # we introspect on the table definition using `DESCRIBE`.
cursor.execute('DESCRIBE %s' % cursor.execute('DESCRIBE %s' %
...@@ -27,9 +26,6 @@ class MySQLIntrospection(DatabaseIntrospection): ...@@ -27,9 +26,6 @@ class MySQLIntrospection(DatabaseIntrospection):
field_type = OGRGeomType(typ).django field_type = OGRGeomType(typ).django
field_params = {} field_params = {}
break break
finally:
cursor.close()
return field_type, field_params return field_type, field_params
def supports_spatial_index(self, cursor, table_name): def supports_spatial_index(self, cursor, table_name):
......
...@@ -11,8 +11,7 @@ class OracleIntrospection(DatabaseIntrospection): ...@@ -11,8 +11,7 @@ class OracleIntrospection(DatabaseIntrospection):
data_types_reverse[cx_Oracle.OBJECT] = 'GeometryField' data_types_reverse[cx_Oracle.OBJECT] = 'GeometryField'
def get_geometry_type(self, table_name, geo_col): def get_geometry_type(self, table_name, geo_col):
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
try:
# Querying USER_SDO_GEOM_METADATA to get the SRID and dimension information. # Querying USER_SDO_GEOM_METADATA to get the SRID and dimension information.
try: try:
cursor.execute( cursor.execute(
...@@ -40,7 +39,4 @@ class OracleIntrospection(DatabaseIntrospection): ...@@ -40,7 +39,4 @@ class OracleIntrospection(DatabaseIntrospection):
dim = dim.size() dim = dim.size()
if dim != 2: if dim != 2:
field_params['dim'] = dim field_params['dim'] = dim
finally:
cursor.close()
return field_type, field_params return field_type, field_params
...@@ -59,15 +59,11 @@ class PostGISIntrospection(DatabaseIntrospection): ...@@ -59,15 +59,11 @@ class PostGISIntrospection(DatabaseIntrospection):
# to query the PostgreSQL pg_type table corresponding to the # to query the PostgreSQL pg_type table corresponding to the
# PostGIS custom data types. # PostGIS custom data types.
oid_sql = 'SELECT "oid" FROM "pg_type" WHERE "typname" = %s' oid_sql = 'SELECT "oid" FROM "pg_type" WHERE "typname" = %s'
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
try:
for field_type in field_types: for field_type in field_types:
cursor.execute(oid_sql, (field_type[0],)) cursor.execute(oid_sql, (field_type[0],))
for result in cursor.fetchall(): for result in cursor.fetchall():
postgis_types[result[0]] = field_type[1] postgis_types[result[0]] = field_type[1]
finally:
cursor.close()
return postgis_types return postgis_types
def get_field_type(self, data_type, description): def get_field_type(self, data_type, description):
...@@ -88,8 +84,7 @@ class PostGISIntrospection(DatabaseIntrospection): ...@@ -88,8 +84,7 @@ class PostGISIntrospection(DatabaseIntrospection):
PointField or a PolygonField). Thus, this routine queries the PostGIS PointField or a PolygonField). Thus, this routine queries the PostGIS
metadata tables to determine the geometry type. metadata tables to determine the geometry type.
""" """
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
try:
try: try:
# First seeing if this geometry column is in the `geometry_columns` # First seeing if this geometry column is in the `geometry_columns`
cursor.execute('SELECT "coord_dimension", "srid", "type" ' cursor.execute('SELECT "coord_dimension", "srid", "type" '
...@@ -122,7 +117,4 @@ class PostGISIntrospection(DatabaseIntrospection): ...@@ -122,7 +117,4 @@ class PostGISIntrospection(DatabaseIntrospection):
field_params['srid'] = srid field_params['srid'] = srid
if dim != 2: if dim != 2:
field_params['dim'] = dim field_params['dim'] = dim
finally:
cursor.close()
return field_type, field_params return field_type, field_params
...@@ -25,8 +25,7 @@ class SpatiaLiteIntrospection(DatabaseIntrospection): ...@@ -25,8 +25,7 @@ class SpatiaLiteIntrospection(DatabaseIntrospection):
data_types_reverse = GeoFlexibleFieldLookupDict() data_types_reverse = GeoFlexibleFieldLookupDict()
def get_geometry_type(self, table_name, geo_col): def get_geometry_type(self, table_name, geo_col):
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
try:
# Querying the `geometry_columns` table to get additional metadata. # Querying the `geometry_columns` table to get additional metadata.
cursor.execute('SELECT coord_dimension, srid, geometry_type ' cursor.execute('SELECT coord_dimension, srid, geometry_type '
'FROM geometry_columns ' 'FROM geometry_columns '
...@@ -55,9 +54,6 @@ class SpatiaLiteIntrospection(DatabaseIntrospection): ...@@ -55,9 +54,6 @@ class SpatiaLiteIntrospection(DatabaseIntrospection):
field_params['srid'] = srid field_params['srid'] = srid
if (isinstance(dim, str) and 'Z' in dim) or dim == 3: if (isinstance(dim, str) and 'Z' in dim) or dim == 3:
field_params['dim'] = 3 field_params['dim'] = 3
finally:
cursor.close()
return field_type, field_params return field_type, field_params
def get_constraints(self, cursor, table_name): def get_constraints(self, cursor, table_name):
......
...@@ -573,11 +573,10 @@ class BaseDatabaseWrapper: ...@@ -573,11 +573,10 @@ class BaseDatabaseWrapper:
Provide a cursor: with self.temporary_connection() as cursor: ... Provide a cursor: with self.temporary_connection() as cursor: ...
""" """
must_close = self.connection is None must_close = self.connection is None
cursor = self.cursor()
try: try:
yield cursor with self.cursor() as cursor:
yield cursor
finally: finally:
cursor.close()
if must_close: if must_close:
self.close() self.close()
......
...@@ -116,21 +116,20 @@ class BaseDatabaseIntrospection: ...@@ -116,21 +116,20 @@ class BaseDatabaseIntrospection:
from django.db import router from django.db import router
sequence_list = [] sequence_list = []
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
for app_config in apps.get_app_configs():
for app_config in apps.get_app_configs(): for model in router.get_migratable_models(app_config, self.connection.alias):
for model in router.get_migratable_models(app_config, self.connection.alias): if not model._meta.managed:
if not model._meta.managed: continue
continue if model._meta.swapped:
if model._meta.swapped: continue
continue sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields))
sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields)) for f in model._meta.local_many_to_many:
for f in model._meta.local_many_to_many: # If this is an m2m using an intermediate table,
# If this is an m2m using an intermediate table, # we don't need to reset the sequence.
# we don't need to reset the sequence. if f.remote_field.through is None:
if f.remote_field.through is None: sequence = self.get_sequences(cursor, f.m2m_db_table())
sequence = self.get_sequences(cursor, f.m2m_db_table()) sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}])
sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}])
return sequence_list return sequence_list
def get_sequences(self, cursor, table_name, table_fields=()): def get_sequences(self, cursor, table_name, table_fields=()):
......
...@@ -294,36 +294,37 @@ class DatabaseWrapper(BaseDatabaseWrapper): ...@@ -294,36 +294,37 @@ class DatabaseWrapper(BaseDatabaseWrapper):
Backends can override this method if they can more directly apply Backends can override this method if they can more directly apply
constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE") constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE")
""" """
cursor = self.cursor() with self.cursor() as cursor:
if table_names is None: if table_names is None:
table_names = self.introspection.table_names(cursor) table_names = self.introspection.table_names(cursor)
for table_name in table_names: for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
if not primary_key_column_name: if not primary_key_column_name:
continue continue
key_columns = self.introspection.get_key_columns(cursor, table_name) key_columns = self.introspection.get_key_columns(cursor, table_name)
for column_name, referenced_table_name, referenced_column_name in key_columns: for column_name, referenced_table_name, referenced_column_name in key_columns:
cursor.execute( cursor.execute(
""" """
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`) ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
""" % ( """ % (
primary_key_column_name, column_name, table_name, primary_key_column_name, column_name, table_name,
referenced_table_name, column_name, referenced_column_name, referenced_table_name, column_name, referenced_column_name,
column_name, referenced_column_name, column_name, referenced_column_name,
)
)
for bad_row in cursor.fetchall():
raise utils.IntegrityError(
"The row in table '%s' with primary key '%s' has an invalid "
"foreign key: %s.%s contains a value '%s' that does not have a corresponding value in %s.%s."
% (
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,
) )
) )
for bad_row in cursor.fetchall():
raise utils.IntegrityError(
"The row in table '%s' with primary key '%s' has an invalid "
"foreign key: %s.%s contains a value '%s' that does not "
"have a corresponding value in %s.%s."
% (
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,
)
)
def is_usable(self): def is_usable(self):
try: try:
......
...@@ -30,75 +30,72 @@ class DatabaseCreation(BaseDatabaseCreation): ...@@ -30,75 +30,72 @@ class DatabaseCreation(BaseDatabaseCreation):
def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False): def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False):
parameters = self._get_test_db_params() parameters = self._get_test_db_params()
cursor = self._maindb_connection.cursor() with self._maindb_connection.cursor() as cursor:
if self._test_database_create(): if self._test_database_create():
try: try:
self._execute_test_db_creation(cursor, parameters, verbosity, keepdb) self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
except Exception as e: except Exception as e:
if 'ORA-01543' not in str(e): if 'ORA-01543' not in str(e):
# All errors except "tablespace already exists" cancel tests # All errors except "tablespace already exists" cancel tests
sys.stderr.write("Got an error creating the test database: %s\n" % e) sys.stderr.write("Got an error creating the test database: %s\n" % e)
sys.exit(2) sys.exit(2)
if not autoclobber: if not autoclobber:
confirm = input( confirm = input(
"It appears the test database, %s, already exists. " "It appears the test database, %s, already exists. "
"Type 'yes' to delete it, or 'no' to cancel: " % parameters['user']) "Type 'yes' to delete it, or 'no' to cancel: " % parameters['user'])
if autoclobber or confirm == 'yes': if autoclobber or confirm == 'yes':
if verbosity >= 1: if verbosity >= 1:
print("Destroying old test database for alias '%s'..." % self.connection.alias) print("Destroying old test database for alias '%s'..." % self.connection.alias)
try: try:
self._execute_test_db_destruction(cursor, parameters, verbosity) self._execute_test_db_destruction(cursor, parameters, verbosity)
except DatabaseError as e: except DatabaseError as e:
if 'ORA-29857' in str(e): if 'ORA-29857' in str(e):
self._handle_objects_preventing_db_destruction(cursor, parameters, self._handle_objects_preventing_db_destruction(cursor, parameters,
verbosity, autoclobber) verbosity, autoclobber)
else: else:
# Ran into a database error that isn't about leftover objects in the tablespace # Ran into a database error that isn't about leftover objects in the tablespace
sys.stderr.write("Got an error destroying the old test database: %s\n" % e)
sys.exit(2)
except Exception as e:
sys.stderr.write("Got an error destroying the old test database: %s\n" % e) sys.stderr.write("Got an error destroying the old test database: %s\n" % e)
sys.exit(2) sys.exit(2)
except Exception as e: try:
sys.stderr.write("Got an error destroying the old test database: %s\n" % e) self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
sys.exit(2) except Exception as e:
try: sys.stderr.write("Got an error recreating the test database: %s\n" % e)
self._execute_test_db_creation(cursor, parameters, verbosity, keepdb) sys.exit(2)
except Exception as e: else:
sys.stderr.write("Got an error recreating the test database: %s\n" % e) print("Tests cancelled.")
sys.exit(2) sys.exit(1)
else:
print("Tests cancelled.")
sys.exit(1)
if self._test_user_create(): if self._test_user_create():
if verbosity >= 1: if verbosity >= 1:
print("Creating test user...") print("Creating test user...")
try: try:
self._create_test_user(cursor, parameters, verbosity, keepdb) self._create_test_user(cursor, parameters, verbosity, keepdb)
except Exception as e: except Exception as e:
if 'ORA-01920' not in str(e): if 'ORA-01920' not in str(e):
# All errors except "user already exists" cancel tests # All errors except "user already exists" cancel tests
sys.stderr.write("Got an error creating the test user: %s\n" % e) sys.stderr.write("Got an error creating the test user: %s\n" % e)
sys.exit(2)
if not autoclobber:
confirm = input(
"It appears the test user, %s, already exists. Type "
"'yes' to delete it, or 'no' to cancel: " % parameters['user'])
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print("Destroying old test user...")
self._destroy_test_user(cursor, parameters, verbosity)
if verbosity >= 1:
print("Creating test user...")
self._create_test_user(cursor, parameters, verbosity, keepdb)
except Exception as e:
sys.stderr.write("Got an error recreating the test user: %s\n" % e)
sys.exit(2) sys.exit(2)
else: if not autoclobber:
print("Tests cancelled.") confirm = input(
sys.exit(1) "It appears the test user, %s, already exists. Type "
"'yes' to delete it, or 'no' to cancel: " % parameters['user'])
# Cursor must be closed before closing connection. if autoclobber or confirm == 'yes':
cursor.close() try:
if verbosity >= 1:
print("Destroying old test user...")
self._destroy_test_user(cursor, parameters, verbosity)
if verbosity >= 1:
print("Creating test user...")
self._create_test_user(cursor, parameters, verbosity, keepdb)
except Exception as e:
sys.stderr.write("Got an error recreating the test user: %s\n" % e)
sys.exit(2)
else:
print("Tests cancelled.")
sys.exit(1)
self._maindb_connection.close() # done with main user -- test user and tablespaces created self._maindb_connection.close() # done with main user -- test user and tablespaces created
self._switch_to_test_user(parameters) self._switch_to_test_user(parameters)
return self.connection.settings_dict['NAME'] return self.connection.settings_dict['NAME']
...@@ -175,17 +172,15 @@ class DatabaseCreation(BaseDatabaseCreation): ...@@ -175,17 +172,15 @@ class DatabaseCreation(BaseDatabaseCreation):
self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD'] self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD']
self.connection.close() self.connection.close()
parameters = self._get_test_db_params() parameters = self._get_test_db_params()
cursor = self._maindb_connection.cursor() with self._maindb_connection.cursor() as cursor:
if self._test_user_create(): if self._test_user_create():
if verbosity >= 1: if verbosity >= 1:
print('Destroying test user...') print('Destroying test user...')
self._destroy_test_user(cursor, parameters, verbosity) self._destroy_test_user(cursor, parameters, verbosity)
if self._test_database_create(): if self._test_database_create():
if verbosity >= 1: if verbosity >= 1:
print('Destroying test database tables...') print('Destroying test database tables...')
self._execute_test_db_destruction(cursor, parameters, verbosity) self._execute_test_db_destruction(cursor, parameters, verbosity)
# Cursor must be closed before closing connection.
cursor.close()
self._maindb_connection.close() self._maindb_connection.close()
def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False): def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False):
......
...@@ -237,37 +237,37 @@ class DatabaseWrapper(BaseDatabaseWrapper): ...@@ -237,37 +237,37 @@ class DatabaseWrapper(BaseDatabaseWrapper):
Backends can override this method if they can more directly apply Backends can override this method if they can more directly apply
constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE") constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE")
""" """
cursor = self.cursor() with self.cursor() as cursor:
if table_names is None: if table_names is None:
table_names = self.introspection.table_names(cursor) table_names = self.introspection.table_names(cursor)
for table_name in table_names: for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
if not primary_key_column_name: if not primary_key_column_name:
continue continue
key_columns = self.introspection.get_key_columns(cursor, table_name) key_columns = self.introspection.get_key_columns(cursor, table_name)
for column_name, referenced_table_name, referenced_column_name in key_columns: for column_name, referenced_table_name, referenced_column_name in key_columns:
cursor.execute( cursor.execute(
""" """
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`) ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
""" """
% ( % (
primary_key_column_name, column_name, table_name, primary_key_column_name, column_name, table_name,
referenced_table_name, column_name, referenced_column_name, referenced_table_name, column_name, referenced_column_name,
column_name, referenced_column_name, column_name, referenced_column_name,
)
)
for bad_row in cursor.fetchall():
raise utils.IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s." % (
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,
) )
) )
for bad_row in cursor.fetchall():
raise utils.IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s." % (
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,
)
)
def is_usable(self): def is_usable(self):
return True return True
......
...@@ -322,7 +322,8 @@ class MigrationExecutor: ...@@ -322,7 +322,8 @@ class MigrationExecutor:
apps = after_state.apps apps = after_state.apps
found_create_model_migration = False found_create_model_migration = False
found_add_field_migration = False found_add_field_migration = False
existing_table_names = self.connection.introspection.table_names(self.connection.cursor()) with self.connection.cursor() as cursor:
existing_table_names = self.connection.introspection.table_names(cursor)
# Make sure all create model and add field operations are done # Make sure all create model and add field operations are done
for operation in migration.operations: for operation in migration.operations:
if isinstance(operation, migrations.CreateModel): if isinstance(operation, migrations.CreateModel):
......
...@@ -852,9 +852,9 @@ class TransactionTestCase(SimpleTestCase): ...@@ -852,9 +852,9 @@ class TransactionTestCase(SimpleTestCase):
no_style(), conn.introspection.sequence_list()) no_style(), conn.introspection.sequence_list())
if sql_list: if sql_list:
with transaction.atomic(using=db_name): with transaction.atomic(using=db_name):
cursor = conn.cursor() with conn.cursor() as cursor:
for sql in sql_list: for sql in sql_list:
cursor.execute(sql) cursor.execute(sql)
def _fixture_setup(self): def _fixture_setup(self):
for db_name in self._databases_names(include_mirrors=False): for db_name in self._databases_names(include_mirrors=False):
......
...@@ -664,7 +664,8 @@ object that allows you to retrieve a specific connection using its ...@@ -664,7 +664,8 @@ object that allows you to retrieve a specific connection using its
alias:: alias::
from django.db import connections from django.db import connections
cursor = connections['my_db_alias'].cursor() with connections['my_db_alias'].cursor() as cursor:
...
Limitations of multiple databases Limitations of multiple databases
================================= =================================
......
...@@ -279,8 +279,8 @@ object that allows you to retrieve a specific connection using its ...@@ -279,8 +279,8 @@ object that allows you to retrieve a specific connection using its
alias:: alias::
from django.db import connections from django.db import connections
cursor = connections['my_db_alias'].cursor() with connections['my_db_alias'].cursor() as cursor:
# Your code here... # Your code here...
By default, the Python DB API will return results without their field names, By default, the Python DB API will return results without their field names,
which means you end up with a ``list`` of values, rather than a ``dict``. At a which means you end up with a ``list`` of values, rather than a ``dict``. At a
......
...@@ -9,15 +9,15 @@ from ..models import Person ...@@ -9,15 +9,15 @@ from ..models import Person
@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL") @unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")
class DatabaseSequenceTests(TestCase): class DatabaseSequenceTests(TestCase):
def test_get_sequences(self): def test_get_sequences(self):
cursor = connection.cursor() with connection.cursor() as cursor:
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
self.assertEqual( self.assertEqual(
seqs, seqs,
[{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}] [{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}]
) )
cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq') cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq')
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
self.assertEqual( self.assertEqual(
seqs, seqs,
[{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}] [{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}]
) )
...@@ -44,10 +44,10 @@ class Tests(TestCase): ...@@ -44,10 +44,10 @@ class Tests(TestCase):
# Ensure the database default time zone is different than # Ensure the database default time zone is different than
# the time zone in new_connection.settings_dict. We can # the time zone in new_connection.settings_dict. We can
# get the default time zone by reset & show. # get the default time zone by reset & show.
cursor = new_connection.cursor() with new_connection.cursor() as cursor:
cursor.execute("RESET TIMEZONE") cursor.execute("RESET TIMEZONE")
cursor.execute("SHOW TIMEZONE") cursor.execute("SHOW TIMEZONE")
db_default_tz = cursor.fetchone()[0] db_default_tz = cursor.fetchone()[0]
new_tz = 'Europe/Paris' if db_default_tz == 'UTC' else 'UTC' new_tz = 'Europe/Paris' if db_default_tz == 'UTC' else 'UTC'
new_connection.close() new_connection.close()
...@@ -59,12 +59,12 @@ class Tests(TestCase): ...@@ -59,12 +59,12 @@ class Tests(TestCase):
# time zone, run a query and rollback. # time zone, run a query and rollback.
with self.settings(TIME_ZONE=new_tz): with self.settings(TIME_ZONE=new_tz):
new_connection.set_autocommit(False) new_connection.set_autocommit(False)
cursor = new_connection.cursor()
new_connection.rollback() new_connection.rollback()
# Now let's see if the rollback rolled back the SET TIME ZONE. # Now let's see if the rollback rolled back the SET TIME ZONE.
cursor.execute("SHOW TIMEZONE") with new_connection.cursor() as cursor:
tz = cursor.fetchone()[0] cursor.execute("SHOW TIMEZONE")
tz = cursor.fetchone()[0]
self.assertEqual(new_tz, tz) self.assertEqual(new_tz, tz)
finally: finally:
......
...@@ -82,11 +82,11 @@ class LastExecutedQueryTest(TestCase): ...@@ -82,11 +82,11 @@ class LastExecutedQueryTest(TestCase):
# If SQLITE_MAX_VARIABLE_NUMBER (default = 999) has been changed to be # If SQLITE_MAX_VARIABLE_NUMBER (default = 999) has been changed to be
# greater than SQLITE_MAX_COLUMN (default = 2000), last_executed_query # greater than SQLITE_MAX_COLUMN (default = 2000), last_executed_query
# can hit the SQLITE_MAX_COLUMN limit (#26063). # can hit the SQLITE_MAX_COLUMN limit (#26063).
cursor = connection.cursor() with connection.cursor() as cursor:
sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001) sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001)
params = list(range(2001)) params = list(range(2001))
# This should not raise an exception. # This should not raise an exception.
cursor.db.ops.last_executed_query(cursor.cursor, sql, params) cursor.db.ops.last_executed_query(cursor.cursor, sql, params)
@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests') @unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests')
...@@ -97,9 +97,9 @@ class EscapingChecks(TestCase): ...@@ -97,9 +97,9 @@ class EscapingChecks(TestCase):
""" """
def test_parameter_escaping(self): def test_parameter_escaping(self):
# '%s' escaping support for sqlite3 (#13648). # '%s' escaping support for sqlite3 (#13648).
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("select strftime('%s', date('now'))") cursor.execute("select strftime('%s', date('now'))")
response = cursor.fetchall()[0][0] response = cursor.fetchall()[0][0]
# response should be an non-zero integer # response should be an non-zero integer
self.assertTrue(int(response)) self.assertTrue(int(response))
......
...@@ -56,8 +56,8 @@ class LastExecutedQueryTest(TestCase): ...@@ -56,8 +56,8 @@ class LastExecutedQueryTest(TestCase):
last_executed_query should not raise an exception even if no previous last_executed_query should not raise an exception even if no previous
query has been run. query has been run.
""" """
cursor = connection.cursor() with connection.cursor() as cursor:
connection.ops.last_executed_query(cursor, '', ()) connection.ops.last_executed_query(cursor, '', ())
def test_debug_sql(self): def test_debug_sql(self):
list(Reporter.objects.filter(first_name="test")) list(Reporter.objects.filter(first_name="test"))
...@@ -78,16 +78,16 @@ class ParameterHandlingTest(TestCase): ...@@ -78,16 +78,16 @@ class ParameterHandlingTest(TestCase):
def test_bad_parameter_count(self): def test_bad_parameter_count(self):
"An executemany call with too many/not enough parameters will raise an exception (Refs #12612)" "An executemany call with too many/not enough parameters will raise an exception (Refs #12612)"
cursor = connection.cursor() with connection.cursor() as cursor:
query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % ( query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % (
connection.introspection.table_name_converter('backends_square'), connection.introspection.table_name_converter('backends_square'),
connection.ops.quote_name('root'), connection.ops.quote_name('root'),
connection.ops.quote_name('square') connection.ops.quote_name('square')
)) ))
with self.assertRaises(Exception): with self.assertRaises(Exception):
cursor.executemany(query, [(1, 2, 3)]) cursor.executemany(query, [(1, 2, 3)])
with self.assertRaises(Exception): with self.assertRaises(Exception):
cursor.executemany(query, [(1,)]) cursor.executemany(query, [(1,)])
class LongNameTest(TransactionTestCase): class LongNameTest(TransactionTestCase):
...@@ -133,9 +133,10 @@ class LongNameTest(TransactionTestCase): ...@@ -133,9 +133,10 @@ class LongNameTest(TransactionTestCase):
'table': VLM._meta.db_table 'table': VLM._meta.db_table
}, },
] ]
cursor = connection.cursor() sql_list = connection.ops.sql_flush(no_style(), tables, sequences)
for statement in connection.ops.sql_flush(no_style(), tables, sequences): with connection.cursor() as cursor:
cursor.execute(statement) for statement in sql_list:
cursor.execute(statement)
class SequenceResetTest(TestCase): class SequenceResetTest(TestCase):
...@@ -146,10 +147,10 @@ class SequenceResetTest(TestCase): ...@@ -146,10 +147,10 @@ class SequenceResetTest(TestCase):
Post.objects.create(id=10, name='1st post', text='hello world') Post.objects.create(id=10, name='1st post', text='hello world')
# Reset the sequences for the database # Reset the sequences for the database
cursor = connection.cursor()
commands = connections[DEFAULT_DB_ALIAS].ops.sequence_reset_sql(no_style(), [Post]) commands = connections[DEFAULT_DB_ALIAS].ops.sequence_reset_sql(no_style(), [Post])
for sql in commands: with connection.cursor() as cursor:
cursor.execute(sql) for sql in commands:
cursor.execute(sql)
# If we create a new object now, it should have a PK greater # If we create a new object now, it should have a PK greater
# than the PK we specified manually. # than the PK we specified manually.
...@@ -192,14 +193,14 @@ class EscapingChecks(TestCase): ...@@ -192,14 +193,14 @@ class EscapingChecks(TestCase):
bare_select_suffix = connection.features.bare_select_suffix bare_select_suffix = connection.features.bare_select_suffix
def test_paramless_no_escaping(self): def test_paramless_no_escaping(self):
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("SELECT '%s'" + self.bare_select_suffix) cursor.execute("SELECT '%s'" + self.bare_select_suffix)
self.assertEqual(cursor.fetchall()[0][0], '%s') self.assertEqual(cursor.fetchall()[0][0], '%s')
def test_parameter_escaping(self): def test_parameter_escaping(self):
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',)) cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',))
self.assertEqual(cursor.fetchall()[0], ('%', '%d')) self.assertEqual(cursor.fetchall()[0], ('%', '%d'))
@override_settings(DEBUG=True) @override_settings(DEBUG=True)
...@@ -215,7 +216,6 @@ class BackendTestCase(TransactionTestCase): ...@@ -215,7 +216,6 @@ class BackendTestCase(TransactionTestCase):
self.create_squares(args, 'format', True) self.create_squares(args, 'format', True)
def create_squares(self, args, paramstyle, multiple): def create_squares(self, args, paramstyle, multiple):
cursor = connection.cursor()
opts = Square._meta opts = Square._meta
tbl = connection.introspection.table_name_converter(opts.db_table) tbl = connection.introspection.table_name_converter(opts.db_table)
f1 = connection.ops.quote_name(opts.get_field('root').column) f1 = connection.ops.quote_name(opts.get_field('root').column)
...@@ -226,10 +226,11 @@ class BackendTestCase(TransactionTestCase): ...@@ -226,10 +226,11 @@ class BackendTestCase(TransactionTestCase):
query = 'INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)' % (tbl, f1, f2) query = 'INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)' % (tbl, f1, f2)
else: else:
raise ValueError("unsupported paramstyle in test") raise ValueError("unsupported paramstyle in test")
if multiple: with connection.cursor() as cursor:
cursor.executemany(query, args) if multiple:
else: cursor.executemany(query, args)
cursor.execute(query, args) else:
cursor.execute(query, args)
def test_cursor_executemany(self): def test_cursor_executemany(self):
# Test cursor.executemany #4896 # Test cursor.executemany #4896
...@@ -297,18 +298,18 @@ class BackendTestCase(TransactionTestCase): ...@@ -297,18 +298,18 @@ class BackendTestCase(TransactionTestCase):
Person(first_name="Clark", last_name="Kent").save() Person(first_name="Clark", last_name="Kent").save()
opts2 = Person._meta opts2 = Person._meta
f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name') f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name')
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute( cursor.execute(
'SELECT %s, %s FROM %s ORDER BY %s' % ( 'SELECT %s, %s FROM %s ORDER BY %s' % (
qn(f3.column), qn(f3.column),
qn(f4.column), qn(f4.column),
connection.introspection.table_name_converter(opts2.db_table), connection.introspection.table_name_converter(opts2.db_table),
qn(f3.column), qn(f3.column),
)
) )
) self.assertEqual(cursor.fetchone(), ('Clark', 'Kent'))
self.assertEqual(cursor.fetchone(), ('Clark', 'Kent')) self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')])
self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')]) self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')])
self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')])
def test_unicode_password(self): def test_unicode_password(self):
old_password = connection.settings_dict['PASSWORD'] old_password = connection.settings_dict['PASSWORD']
...@@ -344,10 +345,10 @@ class BackendTestCase(TransactionTestCase): ...@@ -344,10 +345,10 @@ class BackendTestCase(TransactionTestCase):
def test_duplicate_table_error(self): def test_duplicate_table_error(self):
""" Creating an existing table returns a DatabaseError """ """ Creating an existing table returns a DatabaseError """
cursor = connection.cursor()
query = 'CREATE TABLE %s (id INTEGER);' % Article._meta.db_table query = 'CREATE TABLE %s (id INTEGER);' % Article._meta.db_table
with self.assertRaises(DatabaseError): with connection.cursor() as cursor:
cursor.execute(query) with self.assertRaises(DatabaseError):
cursor.execute(query)
def test_cursor_contextmanager(self): def test_cursor_contextmanager(self):
""" """
......
...@@ -26,10 +26,10 @@ class DatabaseErrorWrapperTests(TestCase): ...@@ -26,10 +26,10 @@ class DatabaseErrorWrapperTests(TestCase):
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL test') @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL test')
def test_reraising_backend_specific_database_exception(self): def test_reraising_backend_specific_database_exception(self):
cursor = connection.cursor() with connection.cursor() as cursor:
msg = 'table "X" does not exist' msg = 'table "X" does not exist'
with self.assertRaisesMessage(ProgrammingError, msg) as cm: with self.assertRaisesMessage(ProgrammingError, msg) as cm:
cursor.execute('DROP TABLE "X"') cursor.execute('DROP TABLE "X"')
self.assertNotEqual(type(cm.exception), type(cm.exception.__cause__)) self.assertNotEqual(type(cm.exception), type(cm.exception.__cause__))
self.assertIsNotNone(cm.exception.__cause__) self.assertIsNotNone(cm.exception.__cause__)
self.assertIsNotNone(cm.exception.__cause__.pgcode) self.assertIsNotNone(cm.exception.__cause__.pgcode)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment