Kaydet (Commit) 1aa094f7 authored tarafından Ilya Kulakov's avatar Ilya Kulakov Kaydeden (comit) Yury Selivanov

bpo-29302: Implement contextlib.AsyncExitStack. (#4790)

üst 6ab62920
...@@ -435,6 +435,44 @@ Functions and classes provided: ...@@ -435,6 +435,44 @@ Functions and classes provided:
callbacks registered, the arguments passed in will indicate that no callbacks registered, the arguments passed in will indicate that no
exception occurred. exception occurred.
.. class:: AsyncExitStack()
An :ref:`asynchronous context manager <async-context-managers>`, similar
to :class:`ExitStack`, that supports combining both synchronous and
asynchronous context managers, as well as having coroutines for
cleanup logic.
The :meth:`close` method is not implemented, :meth:`aclose` must be used
instead.
.. method:: enter_async_context(cm)
Similar to :meth:`enter_context` but expects an asynchronous context
manager.
.. method:: push_async_exit(exit)
Similar to :meth:`push` but expects either an asynchronous context manager
or a coroutine.
.. method:: push_async_callback(callback, *args, **kwds)
Similar to :meth:`callback` but expects a coroutine.
.. method:: aclose()
Similar to :meth:`close` but properly handles awaitables.
Continuing the example for :func:`asynccontextmanager`::
async with AsyncExitStack() as stack:
connections = [await stack.enter_async_context(get_connection())
for i in range(5)]
# All opened connections will automatically be released at the end of
# the async with statement, even if attempts to open a connection
# later in the list raise an exception.
.. versionadded:: 3.7
Examples and Recipes Examples and Recipes
-------------------- --------------------
......
...@@ -379,6 +379,9 @@ contextlib ...@@ -379,6 +379,9 @@ contextlib
:class:`~contextlib.AbstractAsyncContextManager` have been added. (Contributed :class:`~contextlib.AbstractAsyncContextManager` have been added. (Contributed
by Jelle Zijlstra in :issue:`29679` and :issue:`30241`.) by Jelle Zijlstra in :issue:`29679` and :issue:`30241`.)
:class:`contextlib.AsyncExitStack` has been added. (Contributed by
Alexander Mohr and Ilya Kulakov in :issue:`29302`.)
cProfile cProfile
-------- --------
......
This diff is collapsed.
"""Unit tests for contextlib.py, and other context managers.""" """Unit tests for contextlib.py, and other context managers."""
import asyncio
import io import io
import sys import sys
import tempfile import tempfile
...@@ -505,17 +506,18 @@ class TestContextDecorator(unittest.TestCase): ...@@ -505,17 +506,18 @@ class TestContextDecorator(unittest.TestCase):
self.assertEqual(state, [1, 'something else', 999]) self.assertEqual(state, [1, 'something else', 999])
class TestExitStack(unittest.TestCase): class TestBaseExitStack:
exit_stack = None
@support.requires_docstrings @support.requires_docstrings
def test_instance_docs(self): def test_instance_docs(self):
# Issue 19330: ensure context manager instances have good docstrings # Issue 19330: ensure context manager instances have good docstrings
cm_docstring = ExitStack.__doc__ cm_docstring = self.exit_stack.__doc__
obj = ExitStack() obj = self.exit_stack()
self.assertEqual(obj.__doc__, cm_docstring) self.assertEqual(obj.__doc__, cm_docstring)
def test_no_resources(self): def test_no_resources(self):
with ExitStack(): with self.exit_stack():
pass pass
def test_callback(self): def test_callback(self):
...@@ -531,7 +533,7 @@ class TestExitStack(unittest.TestCase): ...@@ -531,7 +533,7 @@ class TestExitStack(unittest.TestCase):
def _exit(*args, **kwds): def _exit(*args, **kwds):
"""Test metadata propagation""" """Test metadata propagation"""
result.append((args, kwds)) result.append((args, kwds))
with ExitStack() as stack: with self.exit_stack() as stack:
for args, kwds in reversed(expected): for args, kwds in reversed(expected):
if args and kwds: if args and kwds:
f = stack.callback(_exit, *args, **kwds) f = stack.callback(_exit, *args, **kwds)
...@@ -543,9 +545,9 @@ class TestExitStack(unittest.TestCase): ...@@ -543,9 +545,9 @@ class TestExitStack(unittest.TestCase):
f = stack.callback(_exit) f = stack.callback(_exit)
self.assertIs(f, _exit) self.assertIs(f, _exit)
for wrapper in stack._exit_callbacks: for wrapper in stack._exit_callbacks:
self.assertIs(wrapper.__wrapped__, _exit) self.assertIs(wrapper[1].__wrapped__, _exit)
self.assertNotEqual(wrapper.__name__, _exit.__name__) self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
self.assertIsNone(wrapper.__doc__, _exit.__doc__) self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)
self.assertEqual(result, expected) self.assertEqual(result, expected)
def test_push(self): def test_push(self):
...@@ -565,21 +567,21 @@ class TestExitStack(unittest.TestCase): ...@@ -565,21 +567,21 @@ class TestExitStack(unittest.TestCase):
self.fail("Should not be called!") self.fail("Should not be called!")
def __exit__(self, *exc_details): def __exit__(self, *exc_details):
self.check_exc(*exc_details) self.check_exc(*exc_details)
with ExitStack() as stack: with self.exit_stack() as stack:
stack.push(_expect_ok) stack.push(_expect_ok)
self.assertIs(stack._exit_callbacks[-1], _expect_ok) self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
cm = ExitCM(_expect_ok) cm = ExitCM(_expect_ok)
stack.push(cm) stack.push(cm)
self.assertIs(stack._exit_callbacks[-1].__self__, cm) self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
stack.push(_suppress_exc) stack.push(_suppress_exc)
self.assertIs(stack._exit_callbacks[-1], _suppress_exc) self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
cm = ExitCM(_expect_exc) cm = ExitCM(_expect_exc)
stack.push(cm) stack.push(cm)
self.assertIs(stack._exit_callbacks[-1].__self__, cm) self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
stack.push(_expect_exc) stack.push(_expect_exc)
self.assertIs(stack._exit_callbacks[-1], _expect_exc) self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
stack.push(_expect_exc) stack.push(_expect_exc)
self.assertIs(stack._exit_callbacks[-1], _expect_exc) self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
1/0 1/0
def test_enter_context(self): def test_enter_context(self):
...@@ -591,19 +593,19 @@ class TestExitStack(unittest.TestCase): ...@@ -591,19 +593,19 @@ class TestExitStack(unittest.TestCase):
result = [] result = []
cm = TestCM() cm = TestCM()
with ExitStack() as stack: with self.exit_stack() as stack:
@stack.callback # Registered first => cleaned up last @stack.callback # Registered first => cleaned up last
def _exit(): def _exit():
result.append(4) result.append(4)
self.assertIsNotNone(_exit) self.assertIsNotNone(_exit)
stack.enter_context(cm) stack.enter_context(cm)
self.assertIs(stack._exit_callbacks[-1].__self__, cm) self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
result.append(2) result.append(2)
self.assertEqual(result, [1, 2, 3, 4]) self.assertEqual(result, [1, 2, 3, 4])
def test_close(self): def test_close(self):
result = [] result = []
with ExitStack() as stack: with self.exit_stack() as stack:
@stack.callback @stack.callback
def _exit(): def _exit():
result.append(1) result.append(1)
...@@ -614,7 +616,7 @@ class TestExitStack(unittest.TestCase): ...@@ -614,7 +616,7 @@ class TestExitStack(unittest.TestCase):
def test_pop_all(self): def test_pop_all(self):
result = [] result = []
with ExitStack() as stack: with self.exit_stack() as stack:
@stack.callback @stack.callback
def _exit(): def _exit():
result.append(3) result.append(3)
...@@ -627,12 +629,12 @@ class TestExitStack(unittest.TestCase): ...@@ -627,12 +629,12 @@ class TestExitStack(unittest.TestCase):
def test_exit_raise(self): def test_exit_raise(self):
with self.assertRaises(ZeroDivisionError): with self.assertRaises(ZeroDivisionError):
with ExitStack() as stack: with self.exit_stack() as stack:
stack.push(lambda *exc: False) stack.push(lambda *exc: False)
1/0 1/0
def test_exit_suppress(self): def test_exit_suppress(self):
with ExitStack() as stack: with self.exit_stack() as stack:
stack.push(lambda *exc: True) stack.push(lambda *exc: True)
1/0 1/0
...@@ -696,7 +698,7 @@ class TestExitStack(unittest.TestCase): ...@@ -696,7 +698,7 @@ class TestExitStack(unittest.TestCase):
return True return True
try: try:
with ExitStack() as stack: with self.exit_stack() as stack:
stack.callback(raise_exc, IndexError) stack.callback(raise_exc, IndexError)
stack.callback(raise_exc, KeyError) stack.callback(raise_exc, KeyError)
stack.callback(raise_exc, AttributeError) stack.callback(raise_exc, AttributeError)
...@@ -724,7 +726,7 @@ class TestExitStack(unittest.TestCase): ...@@ -724,7 +726,7 @@ class TestExitStack(unittest.TestCase):
return True return True
try: try:
with ExitStack() as stack: with self.exit_stack() as stack:
stack.callback(lambda: None) stack.callback(lambda: None)
stack.callback(raise_exc, IndexError) stack.callback(raise_exc, IndexError)
except Exception as exc: except Exception as exc:
...@@ -733,7 +735,7 @@ class TestExitStack(unittest.TestCase): ...@@ -733,7 +735,7 @@ class TestExitStack(unittest.TestCase):
self.fail("Expected IndexError, but no exception was raised") self.fail("Expected IndexError, but no exception was raised")
try: try:
with ExitStack() as stack: with self.exit_stack() as stack:
stack.callback(raise_exc, KeyError) stack.callback(raise_exc, KeyError)
stack.push(suppress_exc) stack.push(suppress_exc)
stack.callback(raise_exc, IndexError) stack.callback(raise_exc, IndexError)
...@@ -760,7 +762,7 @@ class TestExitStack(unittest.TestCase): ...@@ -760,7 +762,7 @@ class TestExitStack(unittest.TestCase):
# fix, ExitStack would try to fix it *again* and get into an # fix, ExitStack would try to fix it *again* and get into an
# infinite self-referential loop # infinite self-referential loop
try: try:
with ExitStack() as stack: with self.exit_stack() as stack:
stack.enter_context(gets_the_context_right(exc4)) stack.enter_context(gets_the_context_right(exc4))
stack.enter_context(gets_the_context_right(exc3)) stack.enter_context(gets_the_context_right(exc3))
stack.enter_context(gets_the_context_right(exc2)) stack.enter_context(gets_the_context_right(exc2))
...@@ -787,7 +789,7 @@ class TestExitStack(unittest.TestCase): ...@@ -787,7 +789,7 @@ class TestExitStack(unittest.TestCase):
exc4 = Exception(4) exc4 = Exception(4)
exc5 = Exception(5) exc5 = Exception(5)
try: try:
with ExitStack() as stack: with self.exit_stack() as stack:
stack.callback(raise_nested, exc4, exc5) stack.callback(raise_nested, exc4, exc5)
stack.callback(raise_nested, exc2, exc3) stack.callback(raise_nested, exc2, exc3)
raise exc1 raise exc1
...@@ -801,27 +803,25 @@ class TestExitStack(unittest.TestCase): ...@@ -801,27 +803,25 @@ class TestExitStack(unittest.TestCase):
self.assertIsNone( self.assertIsNone(
exc.__context__.__context__.__context__.__context__.__context__) exc.__context__.__context__.__context__.__context__.__context__)
def test_body_exception_suppress(self): def test_body_exception_suppress(self):
def suppress_exc(*exc_details): def suppress_exc(*exc_details):
return True return True
try: try:
with ExitStack() as stack: with self.exit_stack() as stack:
stack.push(suppress_exc) stack.push(suppress_exc)
1/0 1/0
except IndexError as exc: except IndexError as exc:
self.fail("Expected no exception, got IndexError") self.fail("Expected no exception, got IndexError")
def test_exit_exception_chaining_suppress(self): def test_exit_exception_chaining_suppress(self):
with ExitStack() as stack: with self.exit_stack() as stack:
stack.push(lambda *exc: True) stack.push(lambda *exc: True)
stack.push(lambda *exc: 1/0) stack.push(lambda *exc: 1/0)
stack.push(lambda *exc: {}[1]) stack.push(lambda *exc: {}[1])
def test_excessive_nesting(self): def test_excessive_nesting(self):
# The original implementation would die with RecursionError here # The original implementation would die with RecursionError here
with ExitStack() as stack: with self.exit_stack() as stack:
for i in range(10000): for i in range(10000):
stack.callback(int) stack.callback(int)
...@@ -829,10 +829,10 @@ class TestExitStack(unittest.TestCase): ...@@ -829,10 +829,10 @@ class TestExitStack(unittest.TestCase):
class Example(object): pass class Example(object): pass
cm = Example() cm = Example()
cm.__exit__ = object() cm.__exit__ = object()
stack = ExitStack() stack = self.exit_stack()
self.assertRaises(AttributeError, stack.enter_context, cm) self.assertRaises(AttributeError, stack.enter_context, cm)
stack.push(cm) stack.push(cm)
self.assertIs(stack._exit_callbacks[-1], cm) self.assertIs(stack._exit_callbacks[-1][1], cm)
def test_dont_reraise_RuntimeError(self): def test_dont_reraise_RuntimeError(self):
# https://bugs.python.org/issue27122 # https://bugs.python.org/issue27122
...@@ -856,7 +856,7 @@ class TestExitStack(unittest.TestCase): ...@@ -856,7 +856,7 @@ class TestExitStack(unittest.TestCase):
# The UniqueRuntimeError should be caught by second()'s exception # The UniqueRuntimeError should be caught by second()'s exception
# handler which chain raised a new UniqueException. # handler which chain raised a new UniqueException.
with self.assertRaises(UniqueException) as err_ctx: with self.assertRaises(UniqueException) as err_ctx:
with ExitStack() as es_ctx: with self.exit_stack() as es_ctx:
es_ctx.enter_context(second()) es_ctx.enter_context(second())
es_ctx.enter_context(first()) es_ctx.enter_context(first())
raise UniqueRuntimeError("please no infinite loop.") raise UniqueRuntimeError("please no infinite loop.")
...@@ -869,6 +869,10 @@ class TestExitStack(unittest.TestCase): ...@@ -869,6 +869,10 @@ class TestExitStack(unittest.TestCase):
self.assertIs(exc.__cause__, exc.__context__) self.assertIs(exc.__cause__, exc.__context__)
class TestExitStack(TestBaseExitStack, unittest.TestCase):
exit_stack = ExitStack
class TestRedirectStream: class TestRedirectStream:
redirect_stream = None redirect_stream = None
......
import asyncio import asyncio
from contextlib import asynccontextmanager, AbstractAsyncContextManager from contextlib import asynccontextmanager, AbstractAsyncContextManager, AsyncExitStack
import functools import functools
from test import support from test import support
import unittest import unittest
from .test_contextlib import TestBaseExitStack
def _async_test(func): def _async_test(func):
"""Decorator to turn an async function into a test case.""" """Decorator to turn an async function into a test case."""
...@@ -255,5 +257,168 @@ class AsyncContextManagerTestCase(unittest.TestCase): ...@@ -255,5 +257,168 @@ class AsyncContextManagerTestCase(unittest.TestCase):
self.assertEqual(target, (11, 22, 33, 44)) self.assertEqual(target, (11, 22, 33, 44))
class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
class SyncAsyncExitStack(AsyncExitStack):
@staticmethod
def run_coroutine(coro):
loop = asyncio.get_event_loop()
f = asyncio.ensure_future(coro)
f.add_done_callback(lambda f: loop.stop())
loop.run_forever()
exc = f.exception()
if not exc:
return f.result()
else:
context = exc.__context__
try:
raise exc
except:
exc.__context__ = context
raise exc
def close(self):
return self.run_coroutine(self.aclose())
def __enter__(self):
return self.run_coroutine(self.__aenter__())
def __exit__(self, *exc_details):
return self.run_coroutine(self.__aexit__(*exc_details))
exit_stack = SyncAsyncExitStack
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.addCleanup(self.loop.close)
@_async_test
async def test_async_callback(self):
expected = [
((), {}),
((1,), {}),
((1,2), {}),
((), dict(example=1)),
((1,), dict(example=1)),
((1,2), dict(example=1)),
]
result = []
async def _exit(*args, **kwds):
"""Test metadata propagation"""
result.append((args, kwds))
async with AsyncExitStack() as stack:
for args, kwds in reversed(expected):
if args and kwds:
f = stack.push_async_callback(_exit, *args, **kwds)
elif args:
f = stack.push_async_callback(_exit, *args)
elif kwds:
f = stack.push_async_callback(_exit, **kwds)
else:
f = stack.push_async_callback(_exit)
self.assertIs(f, _exit)
for wrapper in stack._exit_callbacks:
self.assertIs(wrapper[1].__wrapped__, _exit)
self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)
self.assertEqual(result, expected)
@_async_test
async def test_async_push(self):
exc_raised = ZeroDivisionError
async def _expect_exc(exc_type, exc, exc_tb):
self.assertIs(exc_type, exc_raised)
async def _suppress_exc(*exc_details):
return True
async def _expect_ok(exc_type, exc, exc_tb):
self.assertIsNone(exc_type)
self.assertIsNone(exc)
self.assertIsNone(exc_tb)
class ExitCM(object):
def __init__(self, check_exc):
self.check_exc = check_exc
async def __aenter__(self):
self.fail("Should not be called!")
async def __aexit__(self, *exc_details):
await self.check_exc(*exc_details)
async with self.exit_stack() as stack:
stack.push_async_exit(_expect_ok)
self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
cm = ExitCM(_expect_ok)
stack.push_async_exit(cm)
self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
stack.push_async_exit(_suppress_exc)
self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
cm = ExitCM(_expect_exc)
stack.push_async_exit(cm)
self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
stack.push_async_exit(_expect_exc)
self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
stack.push_async_exit(_expect_exc)
self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
1/0
@_async_test
async def test_async_enter_context(self):
class TestCM(object):
async def __aenter__(self):
result.append(1)
async def __aexit__(self, *exc_details):
result.append(3)
result = []
cm = TestCM()
async with AsyncExitStack() as stack:
@stack.push_async_callback # Registered first => cleaned up last
async def _exit():
result.append(4)
self.assertIsNotNone(_exit)
await stack.enter_async_context(cm)
self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
result.append(2)
self.assertEqual(result, [1, 2, 3, 4])
@_async_test
async def test_async_exit_exception_chaining(self):
# Ensure exception chaining matches the reference behaviour
async def raise_exc(exc):
raise exc
saved_details = None
async def suppress_exc(*exc_details):
nonlocal saved_details
saved_details = exc_details
return True
try:
async with self.exit_stack() as stack:
stack.push_async_callback(raise_exc, IndexError)
stack.push_async_callback(raise_exc, KeyError)
stack.push_async_callback(raise_exc, AttributeError)
stack.push_async_exit(suppress_exc)
stack.push_async_callback(raise_exc, ValueError)
1 / 0
except IndexError as exc:
self.assertIsInstance(exc.__context__, KeyError)
self.assertIsInstance(exc.__context__.__context__, AttributeError)
# Inner exceptions were suppressed
self.assertIsNone(exc.__context__.__context__.__context__)
else:
self.fail("Expected IndexError, but no exception was raised")
# Check the inner exceptions
inner_exc = saved_details[1]
self.assertIsInstance(inner_exc, ValueError)
self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Add contextlib.AsyncExitStack. Patch by Alexander Mohr and Ilya Kulakov.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment