Kaydet (Commit) a924fc7a authored tarafından Charles-François Natali's avatar Charles-François Natali

Issue #21565: multiprocessing: use contex-manager protocol for synchronization

primitives.
üst 1691e359
...@@ -1320,6 +1320,9 @@ processes. ...@@ -1320,6 +1320,9 @@ processes.
Note that accessing the ctypes object through the wrapper can be a lot slower Note that accessing the ctypes object through the wrapper can be a lot slower
than accessing the raw ctypes object. than accessing the raw ctypes object.
.. versionchanged:: 3.5
Synchronized objects support the :term:`context manager` protocol.
The table below compares the syntax for creating shared ctypes objects from The table below compares the syntax for creating shared ctypes objects from
shared memory with the normal ctypes syntax. (In the table ``MyStruct`` is some shared memory with the normal ctypes syntax. (In the table ``MyStruct`` is some
......
...@@ -59,9 +59,8 @@ class Connection(object): ...@@ -59,9 +59,8 @@ class Connection(object):
return True return True
if timeout <= 0.0: if timeout <= 0.0:
return False return False
self._in.not_empty.acquire() with self._in.not_empty:
self._in.not_empty.wait(timeout) self._in.not_empty.wait(timeout)
self._in.not_empty.release()
return self._in.qsize() > 0 return self._in.qsize() > 0
def close(self): def close(self):
......
...@@ -216,9 +216,8 @@ class Heap(object): ...@@ -216,9 +216,8 @@ class Heap(object):
assert 0 <= size < sys.maxsize assert 0 <= size < sys.maxsize
if os.getpid() != self._lastpid: if os.getpid() != self._lastpid:
self.__init__() # reinitialize after fork self.__init__() # reinitialize after fork
self._lock.acquire() with self._lock:
self._free_pending_blocks() self._free_pending_blocks()
try:
size = self._roundup(max(size,1), self._alignment) size = self._roundup(max(size,1), self._alignment)
(arena, start, stop) = self._malloc(size) (arena, start, stop) = self._malloc(size)
new_stop = start + size new_stop = start + size
...@@ -227,8 +226,6 @@ class Heap(object): ...@@ -227,8 +226,6 @@ class Heap(object):
block = (arena, start, new_stop) block = (arena, start, new_stop)
self._allocated_blocks.add(block) self._allocated_blocks.add(block)
return block return block
finally:
self._lock.release()
# #
# Class representing a chunk of an mmap -- can be inherited by child process # Class representing a chunk of an mmap -- can be inherited by child process
......
...@@ -306,8 +306,7 @@ class Server(object): ...@@ -306,8 +306,7 @@ class Server(object):
''' '''
Return some info --- useful to spot problems with refcounting Return some info --- useful to spot problems with refcounting
''' '''
self.mutex.acquire() with self.mutex:
try:
result = [] result = []
keys = list(self.id_to_obj.keys()) keys = list(self.id_to_obj.keys())
keys.sort() keys.sort()
...@@ -317,8 +316,6 @@ class Server(object): ...@@ -317,8 +316,6 @@ class Server(object):
(ident, self.id_to_refcount[ident], (ident, self.id_to_refcount[ident],
str(self.id_to_obj[ident][0])[:75])) str(self.id_to_obj[ident][0])[:75]))
return '\n'.join(result) return '\n'.join(result)
finally:
self.mutex.release()
def number_of_objects(self, c): def number_of_objects(self, c):
''' '''
...@@ -343,8 +340,7 @@ class Server(object): ...@@ -343,8 +340,7 @@ class Server(object):
''' '''
Create a new shared object and return its id Create a new shared object and return its id
''' '''
self.mutex.acquire() with self.mutex:
try:
callable, exposed, method_to_typeid, proxytype = \ callable, exposed, method_to_typeid, proxytype = \
self.registry[typeid] self.registry[typeid]
...@@ -374,8 +370,6 @@ class Server(object): ...@@ -374,8 +370,6 @@ class Server(object):
# has been created. # has been created.
self.incref(c, ident) self.incref(c, ident)
return ident, tuple(exposed) return ident, tuple(exposed)
finally:
self.mutex.release()
def get_methods(self, c, token): def get_methods(self, c, token):
''' '''
...@@ -392,22 +386,16 @@ class Server(object): ...@@ -392,22 +386,16 @@ class Server(object):
self.serve_client(c) self.serve_client(c)
def incref(self, c, ident): def incref(self, c, ident):
self.mutex.acquire() with self.mutex:
try:
self.id_to_refcount[ident] += 1 self.id_to_refcount[ident] += 1
finally:
self.mutex.release()
def decref(self, c, ident): def decref(self, c, ident):
self.mutex.acquire() with self.mutex:
try:
assert self.id_to_refcount[ident] >= 1 assert self.id_to_refcount[ident] >= 1
self.id_to_refcount[ident] -= 1 self.id_to_refcount[ident] -= 1
if self.id_to_refcount[ident] == 0: if self.id_to_refcount[ident] == 0:
del self.id_to_obj[ident], self.id_to_refcount[ident] del self.id_to_obj[ident], self.id_to_refcount[ident]
util.debug('disposing of obj with id %r', ident) util.debug('disposing of obj with id %r', ident)
finally:
self.mutex.release()
# #
# Class to represent state of a manager # Class to represent state of a manager
...@@ -671,14 +659,11 @@ class BaseProxy(object): ...@@ -671,14 +659,11 @@ class BaseProxy(object):
def __init__(self, token, serializer, manager=None, def __init__(self, token, serializer, manager=None,
authkey=None, exposed=None, incref=True): authkey=None, exposed=None, incref=True):
BaseProxy._mutex.acquire() with BaseProxy._mutex:
try:
tls_idset = BaseProxy._address_to_local.get(token.address, None) tls_idset = BaseProxy._address_to_local.get(token.address, None)
if tls_idset is None: if tls_idset is None:
tls_idset = util.ForkAwareLocal(), ProcessLocalSet() tls_idset = util.ForkAwareLocal(), ProcessLocalSet()
BaseProxy._address_to_local[token.address] = tls_idset BaseProxy._address_to_local[token.address] = tls_idset
finally:
BaseProxy._mutex.release()
# self._tls is used to record the connection used by this # self._tls is used to record the connection used by this
# thread to communicate with the manager at token.address # thread to communicate with the manager at token.address
......
...@@ -666,8 +666,7 @@ class IMapIterator(object): ...@@ -666,8 +666,7 @@ class IMapIterator(object):
return self return self
def next(self, timeout=None): def next(self, timeout=None):
self._cond.acquire() with self._cond:
try:
try: try:
item = self._items.popleft() item = self._items.popleft()
except IndexError: except IndexError:
...@@ -680,8 +679,6 @@ class IMapIterator(object): ...@@ -680,8 +679,6 @@ class IMapIterator(object):
if self._index == self._length: if self._index == self._length:
raise StopIteration raise StopIteration
raise TimeoutError raise TimeoutError
finally:
self._cond.release()
success, value = item success, value = item
if success: if success:
...@@ -691,8 +688,7 @@ class IMapIterator(object): ...@@ -691,8 +688,7 @@ class IMapIterator(object):
__next__ = next # XXX __next__ = next # XXX
def _set(self, i, obj): def _set(self, i, obj):
self._cond.acquire() with self._cond:
try:
if self._index == i: if self._index == i:
self._items.append(obj) self._items.append(obj)
self._index += 1 self._index += 1
...@@ -706,18 +702,13 @@ class IMapIterator(object): ...@@ -706,18 +702,13 @@ class IMapIterator(object):
if self._index == self._length: if self._index == self._length:
del self._cache[self._job] del self._cache[self._job]
finally:
self._cond.release()
def _set_length(self, length): def _set_length(self, length):
self._cond.acquire() with self._cond:
try:
self._length = length self._length = length
if self._index == self._length: if self._index == self._length:
self._cond.notify() self._cond.notify()
del self._cache[self._job] del self._cache[self._job]
finally:
self._cond.release()
# #
# Class whose instances are returned by `Pool.imap_unordered()` # Class whose instances are returned by `Pool.imap_unordered()`
...@@ -726,15 +717,12 @@ class IMapIterator(object): ...@@ -726,15 +717,12 @@ class IMapIterator(object):
class IMapUnorderedIterator(IMapIterator): class IMapUnorderedIterator(IMapIterator):
def _set(self, i, obj): def _set(self, i, obj):
self._cond.acquire() with self._cond:
try:
self._items.append(obj) self._items.append(obj)
self._index += 1 self._index += 1
self._cond.notify() self._cond.notify()
if self._index == self._length: if self._index == self._length:
del self._cache[self._job] del self._cache[self._job]
finally:
self._cond.release()
# #
# #
...@@ -760,10 +748,7 @@ class ThreadPool(Pool): ...@@ -760,10 +748,7 @@ class ThreadPool(Pool):
@staticmethod @staticmethod
def _help_stuff_finish(inqueue, task_handler, size): def _help_stuff_finish(inqueue, task_handler, size):
# put sentinels at head of inqueue to make workers finish # put sentinels at head of inqueue to make workers finish
inqueue.not_empty.acquire() with inqueue.not_empty:
try:
inqueue.queue.clear() inqueue.queue.clear()
inqueue.queue.extend([None] * size) inqueue.queue.extend([None] * size)
inqueue.not_empty.notify_all() inqueue.not_empty.notify_all()
finally:
inqueue.not_empty.release()
...@@ -81,14 +81,11 @@ class Queue(object): ...@@ -81,14 +81,11 @@ class Queue(object):
if not self._sem.acquire(block, timeout): if not self._sem.acquire(block, timeout):
raise Full raise Full
self._notempty.acquire() with self._notempty:
try:
if self._thread is None: if self._thread is None:
self._start_thread() self._start_thread()
self._buffer.append(obj) self._buffer.append(obj)
self._notempty.notify() self._notempty.notify()
finally:
self._notempty.release()
def get(self, block=True, timeout=None): def get(self, block=True, timeout=None):
if block and timeout is None: if block and timeout is None:
...@@ -201,12 +198,9 @@ class Queue(object): ...@@ -201,12 +198,9 @@ class Queue(object):
@staticmethod @staticmethod
def _finalize_close(buffer, notempty): def _finalize_close(buffer, notempty):
debug('telling queue thread to quit') debug('telling queue thread to quit')
notempty.acquire() with notempty:
try:
buffer.append(_sentinel) buffer.append(_sentinel)
notempty.notify() notempty.notify()
finally:
notempty.release()
@staticmethod @staticmethod
def _feed(buffer, notempty, send_bytes, writelock, close, ignore_epipe): def _feed(buffer, notempty, send_bytes, writelock, close, ignore_epipe):
...@@ -295,35 +289,24 @@ class JoinableQueue(Queue): ...@@ -295,35 +289,24 @@ class JoinableQueue(Queue):
if not self._sem.acquire(block, timeout): if not self._sem.acquire(block, timeout):
raise Full raise Full
self._notempty.acquire() with self._notempty, self._cond:
self._cond.acquire()
try:
if self._thread is None: if self._thread is None:
self._start_thread() self._start_thread()
self._buffer.append(obj) self._buffer.append(obj)
self._unfinished_tasks.release() self._unfinished_tasks.release()
self._notempty.notify() self._notempty.notify()
finally:
self._cond.release()
self._notempty.release()
def task_done(self): def task_done(self):
self._cond.acquire() with self._cond:
try:
if not self._unfinished_tasks.acquire(False): if not self._unfinished_tasks.acquire(False):
raise ValueError('task_done() called too many times') raise ValueError('task_done() called too many times')
if self._unfinished_tasks._semlock._is_zero(): if self._unfinished_tasks._semlock._is_zero():
self._cond.notify_all() self._cond.notify_all()
finally:
self._cond.release()
def join(self): def join(self):
self._cond.acquire() with self._cond:
try:
if not self._unfinished_tasks._semlock._is_zero(): if not self._unfinished_tasks._semlock._is_zero():
self._cond.wait() self._cond.wait()
finally:
self._cond.release()
# #
# Simplified Queue type -- really just a locked pipe # Simplified Queue type -- really just a locked pipe
......
...@@ -188,6 +188,12 @@ class SynchronizedBase(object): ...@@ -188,6 +188,12 @@ class SynchronizedBase(object):
self.acquire = self._lock.acquire self.acquire = self._lock.acquire
self.release = self._lock.release self.release = self._lock.release
def __enter__(self):
return self._lock.__enter__()
def __exit__(self, *args):
return self._lock.__exit__(*args)
def __reduce__(self): def __reduce__(self):
assert_spawning(self) assert_spawning(self)
return synchronized, (self._obj, self._lock) return synchronized, (self._obj, self._lock)
...@@ -212,32 +218,20 @@ class SynchronizedArray(SynchronizedBase): ...@@ -212,32 +218,20 @@ class SynchronizedArray(SynchronizedBase):
return len(self._obj) return len(self._obj)
def __getitem__(self, i): def __getitem__(self, i):
self.acquire() with self:
try:
return self._obj[i] return self._obj[i]
finally:
self.release()
def __setitem__(self, i, value): def __setitem__(self, i, value):
self.acquire() with self:
try:
self._obj[i] = value self._obj[i] = value
finally:
self.release()
def __getslice__(self, start, stop): def __getslice__(self, start, stop):
self.acquire() with self:
try:
return self._obj[start:stop] return self._obj[start:stop]
finally:
self.release()
def __setslice__(self, start, stop, values): def __setslice__(self, start, stop, values):
self.acquire() with self:
try:
self._obj[start:stop] = values self._obj[start:stop] = values
finally:
self.release()
class SynchronizedString(SynchronizedArray): class SynchronizedString(SynchronizedArray):
......
...@@ -337,34 +337,24 @@ class Event(object): ...@@ -337,34 +337,24 @@ class Event(object):
self._flag = ctx.Semaphore(0) self._flag = ctx.Semaphore(0)
def is_set(self): def is_set(self):
self._cond.acquire() with self._cond:
try:
if self._flag.acquire(False): if self._flag.acquire(False):
self._flag.release() self._flag.release()
return True return True
return False return False
finally:
self._cond.release()
def set(self): def set(self):
self._cond.acquire() with self._cond:
try:
self._flag.acquire(False) self._flag.acquire(False)
self._flag.release() self._flag.release()
self._cond.notify_all() self._cond.notify_all()
finally:
self._cond.release()
def clear(self): def clear(self):
self._cond.acquire() with self._cond:
try:
self._flag.acquire(False) self._flag.acquire(False)
finally:
self._cond.release()
def wait(self, timeout=None): def wait(self, timeout=None):
self._cond.acquire() with self._cond:
try:
if self._flag.acquire(False): if self._flag.acquire(False):
self._flag.release() self._flag.release()
else: else:
...@@ -374,8 +364,6 @@ class Event(object): ...@@ -374,8 +364,6 @@ class Event(object):
self._flag.release() self._flag.release()
return True return True
return False return False
finally:
self._cond.release()
# #
# Barrier # Barrier
......
...@@ -327,6 +327,13 @@ class ForkAwareThreadLock(object): ...@@ -327,6 +327,13 @@ class ForkAwareThreadLock(object):
self.acquire = self._lock.acquire self.acquire = self._lock.acquire
self.release = self._lock.release self.release = self._lock.release
def __enter__(self):
return self._lock.__enter__()
def __exit__(self, *args):
return self._lock.__exit__(*args)
class ForkAwareLocal(threading.local): class ForkAwareLocal(threading.local):
def __init__(self): def __init__(self):
register_after_fork(self, lambda obj : obj.__dict__.clear()) register_after_fork(self, lambda obj : obj.__dict__.clear())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment