Kaydet (Commit) 37cc6a9d authored tarafından Jon Dufresne's avatar Jon Dufresne Kaydeden (comit) Tim Graham

[2.2.x] Fixed #30171 -- Fixed DatabaseError in servers tests.

Made DatabaseWrapper thread sharing logic reentrant. Used a reference
counting like scheme to allow nested uses.

The error appeared after 8c775391.
Backport of 76990cbb from master.
üst 07b44a25
import copy import copy
import threading
import time import time
import warnings import warnings
from collections import deque from collections import deque
...@@ -43,8 +44,7 @@ class BaseDatabaseWrapper: ...@@ -43,8 +44,7 @@ class BaseDatabaseWrapper:
queries_limit = 9000 queries_limit = 9000
def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS, def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
allow_thread_sharing=False):
# Connection related attributes. # Connection related attributes.
# The underlying database connection. # The underlying database connection.
self.connection = None self.connection = None
...@@ -80,7 +80,8 @@ class BaseDatabaseWrapper: ...@@ -80,7 +80,8 @@ class BaseDatabaseWrapper:
self.errors_occurred = False self.errors_occurred = False
# Thread-safety related attributes. # Thread-safety related attributes.
self.allow_thread_sharing = allow_thread_sharing self._thread_sharing_lock = threading.Lock()
self._thread_sharing_count = 0
self._thread_ident = _thread.get_ident() self._thread_ident = _thread.get_ident()
# A list of no-argument functions to run when the transaction commits. # A list of no-argument functions to run when the transaction commits.
...@@ -515,12 +516,27 @@ class BaseDatabaseWrapper: ...@@ -515,12 +516,27 @@ class BaseDatabaseWrapper:
# ##### Thread safety handling ##### # ##### Thread safety handling #####
@property
def allow_thread_sharing(self):
with self._thread_sharing_lock:
return self._thread_sharing_count > 0
def inc_thread_sharing(self):
with self._thread_sharing_lock:
self._thread_sharing_count += 1
def dec_thread_sharing(self):
with self._thread_sharing_lock:
if self._thread_sharing_count <= 0:
raise RuntimeError('Cannot decrement the thread sharing count below zero.')
self._thread_sharing_count -= 1
def validate_thread_sharing(self): def validate_thread_sharing(self):
""" """
Validate that the connection isn't accessed by another thread than the Validate that the connection isn't accessed by another thread than the
one which originally created it, unless the connection was explicitly one which originally created it, unless the connection was explicitly
authorized to be shared between threads (via the `allow_thread_sharing` authorized to be shared between threads (via the `inc_thread_sharing()`
property). Raise an exception if the validation fails. method). Raise an exception if the validation fails.
""" """
if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()): if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
raise DatabaseError( raise DatabaseError(
...@@ -589,11 +605,7 @@ class BaseDatabaseWrapper: ...@@ -589,11 +605,7 @@ class BaseDatabaseWrapper:
potential child threads while (or after) the test database is destroyed. potential child threads while (or after) the test database is destroyed.
Refs #10868, #17786, #16969. Refs #10868, #17786, #16969.
""" """
return self.__class__( return self.__class__({**self.settings_dict, 'NAME': None}, alias=NO_DB_ALIAS)
{**self.settings_dict, 'NAME': None},
alias=NO_DB_ALIAS,
allow_thread_sharing=False,
)
def schema_editor(self, *args, **kwargs): def schema_editor(self, *args, **kwargs):
""" """
...@@ -635,7 +647,7 @@ class BaseDatabaseWrapper: ...@@ -635,7 +647,7 @@ class BaseDatabaseWrapper:
finally: finally:
self.execute_wrappers.pop() self.execute_wrappers.pop()
def copy(self, alias=None, allow_thread_sharing=None): def copy(self, alias=None):
""" """
Return a copy of this connection. Return a copy of this connection.
...@@ -644,6 +656,4 @@ class BaseDatabaseWrapper: ...@@ -644,6 +656,4 @@ class BaseDatabaseWrapper:
settings_dict = copy.deepcopy(self.settings_dict) settings_dict = copy.deepcopy(self.settings_dict)
if alias is None: if alias is None:
alias = self.alias alias = self.alias
if allow_thread_sharing is None: return type(self)(settings_dict, alias)
allow_thread_sharing = self.allow_thread_sharing
return type(self)(settings_dict, alias, allow_thread_sharing)
...@@ -277,7 +277,6 @@ class DatabaseWrapper(BaseDatabaseWrapper): ...@@ -277,7 +277,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return self.__class__( return self.__class__(
{**self.settings_dict, 'NAME': connection.settings_dict['NAME']}, {**self.settings_dict, 'NAME': connection.settings_dict['NAME']},
alias=self.alias, alias=self.alias,
allow_thread_sharing=False,
) )
return nodb_connection return nodb_connection
......
...@@ -1442,7 +1442,7 @@ class LiveServerTestCase(TransactionTestCase): ...@@ -1442,7 +1442,7 @@ class LiveServerTestCase(TransactionTestCase):
# the server thread. # the server thread.
if conn.vendor == 'sqlite' and conn.is_in_memory_db(): if conn.vendor == 'sqlite' and conn.is_in_memory_db():
# Explicitly enable thread-shareability for this connection # Explicitly enable thread-shareability for this connection
conn.allow_thread_sharing = True conn.inc_thread_sharing()
connections_override[conn.alias] = conn connections_override[conn.alias] = conn
cls._live_server_modified_settings = modify_settings( cls._live_server_modified_settings = modify_settings(
...@@ -1478,10 +1478,9 @@ class LiveServerTestCase(TransactionTestCase): ...@@ -1478,10 +1478,9 @@ class LiveServerTestCase(TransactionTestCase):
# Terminate the live server's thread # Terminate the live server's thread
cls.server_thread.terminate() cls.server_thread.terminate()
# Restore sqlite in-memory database connections' non-shareability # Restore sqlite in-memory database connections' non-shareability.
for conn in connections.all(): for conn in cls.server_thread.connections_override.values():
if conn.vendor == 'sqlite' and conn.is_in_memory_db(): conn.dec_thread_sharing()
conn.allow_thread_sharing = False
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
......
...@@ -286,6 +286,9 @@ backends. ...@@ -286,6 +286,9 @@ backends.
* ``_delete_fk_sql()`` (to pair with ``_create_fk_sql()``) * ``_delete_fk_sql()`` (to pair with ``_create_fk_sql()``)
* ``_create_check_sql()`` and ``_delete_check_sql()`` * ``_create_check_sql()`` and ``_delete_check_sql()``
* The third argument of ``DatabaseWrapper.__init__()``,
``allow_thread_sharing``, is removed.
Admin actions are no longer collected from base ``ModelAdmin`` classes Admin actions are no longer collected from base ``ModelAdmin`` classes
---------------------------------------------------------------------- ----------------------------------------------------------------------
......
...@@ -605,21 +605,25 @@ class ThreadTests(TransactionTestCase): ...@@ -605,21 +605,25 @@ class ThreadTests(TransactionTestCase):
connection = connections[DEFAULT_DB_ALIAS] connection = connections[DEFAULT_DB_ALIAS]
# Allow thread sharing so the connection can be closed by the # Allow thread sharing so the connection can be closed by the
# main thread. # main thread.
connection.allow_thread_sharing = True connection.inc_thread_sharing()
connection.cursor() connection.cursor()
connections_dict[id(connection)] = connection connections_dict[id(connection)] = connection
for x in range(2): try:
t = threading.Thread(target=runner) for x in range(2):
t.start() t = threading.Thread(target=runner)
t.join() t.start()
# Each created connection got different inner connection. t.join()
self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3) # Each created connection got different inner connection.
# Finish by closing the connections opened by the other threads (the self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3)
# connection opened in the main thread will automatically be closed on finally:
# teardown). # Finish by closing the connections opened by the other threads
for conn in connections_dict.values(): # (the connection opened in the main thread will automatically be
if conn is not connection: # closed on teardown).
conn.close() for conn in connections_dict.values():
if conn is not connection:
if conn.allow_thread_sharing:
conn.close()
conn.dec_thread_sharing()
def test_connections_thread_local(self): def test_connections_thread_local(self):
""" """
...@@ -636,19 +640,23 @@ class ThreadTests(TransactionTestCase): ...@@ -636,19 +640,23 @@ class ThreadTests(TransactionTestCase):
for conn in connections.all(): for conn in connections.all():
# Allow thread sharing so the connection can be closed by the # Allow thread sharing so the connection can be closed by the
# main thread. # main thread.
conn.allow_thread_sharing = True conn.inc_thread_sharing()
connections_dict[id(conn)] = conn connections_dict[id(conn)] = conn
for x in range(2): try:
t = threading.Thread(target=runner) for x in range(2):
t.start() t = threading.Thread(target=runner)
t.join() t.start()
self.assertEqual(len(connections_dict), 6) t.join()
# Finish by closing the connections opened by the other threads (the self.assertEqual(len(connections_dict), 6)
# connection opened in the main thread will automatically be closed on finally:
# teardown). # Finish by closing the connections opened by the other threads
for conn in connections_dict.values(): # (the connection opened in the main thread will automatically be
if conn is not connection: # closed on teardown).
conn.close() for conn in connections_dict.values():
if conn is not connection:
if conn.allow_thread_sharing:
conn.close()
conn.dec_thread_sharing()
def test_pass_connection_between_threads(self): def test_pass_connection_between_threads(self):
""" """
...@@ -668,25 +676,21 @@ class ThreadTests(TransactionTestCase): ...@@ -668,25 +676,21 @@ class ThreadTests(TransactionTestCase):
t.start() t.start()
t.join() t.join()
# Without touching allow_thread_sharing, which should be False by default. # Without touching thread sharing, which should be False by default.
exceptions = []
do_thread()
# Forbidden!
self.assertIsInstance(exceptions[0], DatabaseError)
# If explicitly setting allow_thread_sharing to False
connections['default'].allow_thread_sharing = False
exceptions = [] exceptions = []
do_thread() do_thread()
# Forbidden! # Forbidden!
self.assertIsInstance(exceptions[0], DatabaseError) self.assertIsInstance(exceptions[0], DatabaseError)
# If explicitly setting allow_thread_sharing to True # After calling inc_thread_sharing() on the connection.
connections['default'].allow_thread_sharing = True connections['default'].inc_thread_sharing()
exceptions = [] try:
do_thread() exceptions = []
# All good do_thread()
self.assertEqual(exceptions, []) # All good
self.assertEqual(exceptions, [])
finally:
connections['default'].dec_thread_sharing()
def test_closing_non_shared_connections(self): def test_closing_non_shared_connections(self):
""" """
...@@ -721,16 +725,33 @@ class ThreadTests(TransactionTestCase): ...@@ -721,16 +725,33 @@ class ThreadTests(TransactionTestCase):
except DatabaseError as e: except DatabaseError as e:
exceptions.add(e) exceptions.add(e)
# Enable thread sharing # Enable thread sharing
connections['default'].allow_thread_sharing = True connections['default'].inc_thread_sharing()
t2 = threading.Thread(target=runner2, args=[connections['default']]) try:
t2.start() t2 = threading.Thread(target=runner2, args=[connections['default']])
t2.join() t2.start()
t2.join()
finally:
connections['default'].dec_thread_sharing()
t1 = threading.Thread(target=runner1) t1 = threading.Thread(target=runner1)
t1.start() t1.start()
t1.join() t1.join()
# No exception was raised # No exception was raised
self.assertEqual(len(exceptions), 0) self.assertEqual(len(exceptions), 0)
def test_thread_sharing_count(self):
self.assertIs(connection.allow_thread_sharing, False)
connection.inc_thread_sharing()
self.assertIs(connection.allow_thread_sharing, True)
connection.inc_thread_sharing()
self.assertIs(connection.allow_thread_sharing, True)
connection.dec_thread_sharing()
self.assertIs(connection.allow_thread_sharing, True)
connection.dec_thread_sharing()
self.assertIs(connection.allow_thread_sharing, False)
msg = 'Cannot decrement the thread sharing count below zero.'
with self.assertRaisesMessage(RuntimeError, msg):
connection.dec_thread_sharing()
class MySQLPKZeroTests(TestCase): class MySQLPKZeroTests(TestCase):
""" """
......
...@@ -18,11 +18,10 @@ class LiveServerThreadTest(TestCase): ...@@ -18,11 +18,10 @@ class LiveServerThreadTest(TestCase):
# Pass a connection to the thread to check they are being closed. # Pass a connection to the thread to check they are being closed.
connections_override = {DEFAULT_DB_ALIAS: conn} connections_override = {DEFAULT_DB_ALIAS: conn}
saved_sharing = conn.allow_thread_sharing conn.inc_thread_sharing()
try: try:
conn.allow_thread_sharing = True
self.assertTrue(conn.is_usable()) self.assertTrue(conn.is_usable())
self.run_live_server_thread(connections_override) self.run_live_server_thread(connections_override)
self.assertFalse(conn.is_usable()) self.assertFalse(conn.is_usable())
finally: finally:
conn.allow_thread_sharing = saved_sharing conn.dec_thread_sharing()
...@@ -64,6 +64,9 @@ class StaticLiveServerChecks(LiveServerBase): ...@@ -64,6 +64,9 @@ class StaticLiveServerChecks(LiveServerBase):
# app without having set the required STATIC_URL setting.") # app without having set the required STATIC_URL setting.")
pass pass
finally: finally:
# Use del to avoid decrementing the database thread sharing count a
# second time.
del cls.server_thread
super().tearDownClass() super().tearDownClass()
def test_test_test(self): def test_test_test(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment