Kaydet (Commit) ab3c8898 authored tarafından Guido van Rossum's avatar Guido van Rossum

asyncio: Locks refactor: use a separate context manager; remove Semaphore._locked.

üst ab27a9fc
...@@ -9,6 +9,36 @@ from . import futures ...@@ -9,6 +9,36 @@ from . import futures
from . import tasks from . import tasks
class _ContextManager:
"""Context manager.
This enables the following idiom for acquiring and releasing a
lock around a block:
with (yield from lock):
<block>
while failing loudly when accidentally using:
with lock:
<block>
"""
def __init__(self, lock):
self._lock = lock
def __enter__(self):
# We have no use for the "as ..." clause in the with
# statement for locks.
return None
def __exit__(self, *args):
try:
self._lock.release()
finally:
self._lock = None # Crudely prevent reuse.
class Lock: class Lock:
"""Primitive lock objects. """Primitive lock objects.
...@@ -124,17 +154,29 @@ class Lock: ...@@ -124,17 +154,29 @@ class Lock:
raise RuntimeError('Lock is not acquired.') raise RuntimeError('Lock is not acquired.')
def __enter__(self): def __enter__(self):
if not self._locked: raise RuntimeError(
raise RuntimeError( '"yield from" should be used as context manager expression')
'"yield from" should be used as context manager expression')
return True
def __exit__(self, *args): def __exit__(self, *args):
self.release() # This must exist because __enter__ exists, even though that
# always raises; that's how the with-statement works.
pass
def __iter__(self): def __iter__(self):
# This is not a coroutine. It is meant to enable the idiom:
#
# with (yield from lock):
# <block>
#
# as an alternative to:
#
# yield from lock.acquire()
# try:
# <block>
# finally:
# lock.release()
yield from self.acquire() yield from self.acquire()
return self return _ContextManager(self)
class Event: class Event:
...@@ -311,14 +353,16 @@ class Condition: ...@@ -311,14 +353,16 @@ class Condition:
self.notify(len(self._waiters)) self.notify(len(self._waiters))
def __enter__(self): def __enter__(self):
return self._lock.__enter__() raise RuntimeError(
'"yield from" should be used as context manager expression')
def __exit__(self, *args): def __exit__(self, *args):
return self._lock.__exit__(*args) pass
def __iter__(self): def __iter__(self):
# See comment in Lock.__iter__().
yield from self.acquire() yield from self.acquire()
return self return _ContextManager(self)
class Semaphore: class Semaphore:
...@@ -341,7 +385,6 @@ class Semaphore: ...@@ -341,7 +385,6 @@ class Semaphore:
raise ValueError("Semaphore initial value must be >= 0") raise ValueError("Semaphore initial value must be >= 0")
self._value = value self._value = value
self._waiters = collections.deque() self._waiters = collections.deque()
self._locked = (value == 0)
if loop is not None: if loop is not None:
self._loop = loop self._loop = loop
else: else:
...@@ -349,7 +392,7 @@ class Semaphore: ...@@ -349,7 +392,7 @@ class Semaphore:
def __repr__(self): def __repr__(self):
res = super().__repr__() res = super().__repr__()
extra = 'locked' if self._locked else 'unlocked,value:{}'.format( extra = 'locked' if self.locked() else 'unlocked,value:{}'.format(
self._value) self._value)
if self._waiters: if self._waiters:
extra = '{},waiters:{}'.format(extra, len(self._waiters)) extra = '{},waiters:{}'.format(extra, len(self._waiters))
...@@ -357,7 +400,7 @@ class Semaphore: ...@@ -357,7 +400,7 @@ class Semaphore:
def locked(self): def locked(self):
"""Returns True if semaphore can not be acquired immediately.""" """Returns True if semaphore can not be acquired immediately."""
return self._locked return self._value == 0
@tasks.coroutine @tasks.coroutine
def acquire(self): def acquire(self):
...@@ -371,8 +414,6 @@ class Semaphore: ...@@ -371,8 +414,6 @@ class Semaphore:
""" """
if not self._waiters and self._value > 0: if not self._waiters and self._value > 0:
self._value -= 1 self._value -= 1
if self._value == 0:
self._locked = True
return True return True
fut = futures.Future(loop=self._loop) fut = futures.Future(loop=self._loop)
...@@ -380,8 +421,6 @@ class Semaphore: ...@@ -380,8 +421,6 @@ class Semaphore:
try: try:
yield from fut yield from fut
self._value -= 1 self._value -= 1
if self._value == 0:
self._locked = True
return True return True
finally: finally:
self._waiters.remove(fut) self._waiters.remove(fut)
...@@ -392,23 +431,22 @@ class Semaphore: ...@@ -392,23 +431,22 @@ class Semaphore:
become larger than zero again, wake up that coroutine. become larger than zero again, wake up that coroutine.
""" """
self._value += 1 self._value += 1
self._locked = False
for waiter in self._waiters: for waiter in self._waiters:
if not waiter.done(): if not waiter.done():
waiter.set_result(True) waiter.set_result(True)
break break
def __enter__(self): def __enter__(self):
# TODO: This is questionable. How do we know the user actually raise RuntimeError(
# wrote "with (yield from sema)" instead of "with sema"? '"yield from" should be used as context manager expression')
return True
def __exit__(self, *args): def __exit__(self, *args):
self.release() pass
def __iter__(self): def __iter__(self):
# See comment in Lock.__iter__().
yield from self.acquire() yield from self.acquire()
return self return _ContextManager(self)
class BoundedSemaphore(Semaphore): class BoundedSemaphore(Semaphore):
......
...@@ -208,6 +208,24 @@ class LockTests(unittest.TestCase): ...@@ -208,6 +208,24 @@ class LockTests(unittest.TestCase):
self.assertFalse(lock.locked()) self.assertFalse(lock.locked())
def test_context_manager_cant_reuse(self):
lock = asyncio.Lock(loop=self.loop)
@asyncio.coroutine
def acquire_lock():
return (yield from lock)
# This spells "yield from lock" outside a generator.
cm = self.loop.run_until_complete(acquire_lock())
with cm:
self.assertTrue(lock.locked())
self.assertFalse(lock.locked())
with self.assertRaises(AttributeError):
with cm:
pass
def test_context_manager_no_yield(self): def test_context_manager_no_yield(self):
lock = asyncio.Lock(loop=self.loop) lock = asyncio.Lock(loop=self.loop)
...@@ -219,6 +237,8 @@ class LockTests(unittest.TestCase): ...@@ -219,6 +237,8 @@ class LockTests(unittest.TestCase):
str(err), str(err),
'"yield from" should be used as context manager expression') '"yield from" should be used as context manager expression')
self.assertFalse(lock.locked())
class EventTests(unittest.TestCase): class EventTests(unittest.TestCase):
...@@ -655,6 +675,8 @@ class ConditionTests(unittest.TestCase): ...@@ -655,6 +675,8 @@ class ConditionTests(unittest.TestCase):
str(err), str(err),
'"yield from" should be used as context manager expression') '"yield from" should be used as context manager expression')
self.assertFalse(cond.locked())
class SemaphoreTests(unittest.TestCase): class SemaphoreTests(unittest.TestCase):
...@@ -830,6 +852,19 @@ class SemaphoreTests(unittest.TestCase): ...@@ -830,6 +852,19 @@ class SemaphoreTests(unittest.TestCase):
self.assertEqual(2, sem._value) self.assertEqual(2, sem._value)
def test_context_manager_no_yield(self):
sem = asyncio.Semaphore(2, loop=self.loop)
try:
with sem:
self.fail('RuntimeError is not raised in with expression')
except RuntimeError as err:
self.assertEqual(
str(err),
'"yield from" should be used as context manager expression')
self.assertEqual(2, sem._value)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment