Kaydet (Commit) 082332ce authored tarafından Yury Selivanov's avatar Yury Selivanov

Issue 24342: Let wrapper set by sys.set_coroutine_wrapper fail gracefully

(Merge 3.5)
......@@ -1085,6 +1085,20 @@ always available.
If called twice, the new wrapper replaces the previous one. The function
is thread-specific.
The *wrapper* callable cannot define new coroutines directly or indirectly::
def wrapper(coro):
async def wrap(coro):
return await coro
return wrap(coro)
sys.set_coroutine_wrapper(wrapper)
async def foo(): pass
# The following line will fail with a RuntimeError, because
# `wrapper` creates a `wrap(coro)` coroutine:
foo()
See also :func:`get_coroutine_wrapper`.
.. versionadded:: 3.5
......
......@@ -23,8 +23,9 @@ PyAPI_FUNC(PyObject *) PyEval_CallMethod(PyObject *obj,
#ifndef Py_LIMITED_API
PyAPI_FUNC(void) PyEval_SetProfile(Py_tracefunc, PyObject *);
PyAPI_FUNC(void) PyEval_SetTrace(Py_tracefunc, PyObject *);
PyAPI_FUNC(void) _PyEval_SetCoroutineWrapper(PyObject *wrapper);
PyAPI_FUNC(void) _PyEval_SetCoroutineWrapper(PyObject *);
PyAPI_FUNC(PyObject *) _PyEval_GetCoroutineWrapper(void);
PyAPI_FUNC(PyObject *) _PyEval_ApplyCoroutineWrapper(PyObject *);
#endif
struct _frame; /* Avoid including frameobject.h */
......
......@@ -135,6 +135,7 @@ typedef struct _ts {
void *on_delete_data;
PyObject *coroutine_wrapper;
int in_coroutine_wrapper;
/* XXX signal handlers should also be here */
......
......@@ -995,6 +995,26 @@ class SysSetCoroWrapperTest(unittest.TestCase):
sys.set_coroutine_wrapper(1)
self.assertIsNone(sys.get_coroutine_wrapper())
def test_set_wrapper_3(self):
async def foo():
return 'spam'
def wrapper(coro):
async def wrap(coro):
return await coro
return wrap(coro)
sys.set_coroutine_wrapper(wrapper)
try:
with self.assertRaisesRegex(
RuntimeError,
"coroutine wrapper.*\.wrapper at 0x.*attempted to "
"recursively wrap <coroutine.*\.wrap"):
foo()
finally:
sys.set_coroutine_wrapper(None)
class CAPITest(unittest.TestCase):
......
......@@ -3921,7 +3921,6 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals,
if (co->co_flags & CO_GENERATOR) {
PyObject *gen;
PyObject *coroutine_wrapper;
/* Don't need to keep the reference to f_back, it will be set
* when the generator is resumed. */
......@@ -3935,14 +3934,9 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals,
if (gen == NULL)
return NULL;
if (co->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE)) {
coroutine_wrapper = _PyEval_GetCoroutineWrapper();
if (coroutine_wrapper != NULL) {
PyObject *wrapped =
PyObject_CallFunction(coroutine_wrapper, "N", gen);
gen = wrapped;
}
}
if (co->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE))
return _PyEval_ApplyCoroutineWrapper(gen);
return gen;
}
......@@ -4407,6 +4401,33 @@ _PyEval_GetCoroutineWrapper(void)
return tstate->coroutine_wrapper;
}
PyObject *
_PyEval_ApplyCoroutineWrapper(PyObject *gen)
{
PyObject *wrapped;
PyThreadState *tstate = PyThreadState_GET();
PyObject *wrapper = tstate->coroutine_wrapper;
if (tstate->in_coroutine_wrapper) {
assert(wrapper != NULL);
PyErr_Format(PyExc_RuntimeError,
"coroutine wrapper %.150R attempted "
"to recursively wrap %.150R",
wrapper,
gen);
return NULL;
}
if (wrapper == NULL) {
return gen;
}
tstate->in_coroutine_wrapper = 1;
wrapped = PyObject_CallFunction(wrapper, "N", gen);
tstate->in_coroutine_wrapper = 0;
return wrapped;
}
PyObject *
PyEval_GetBuiltins(void)
{
......
......@@ -213,6 +213,7 @@ new_threadstate(PyInterpreterState *interp, int init)
tstate->on_delete_data = NULL;
tstate->coroutine_wrapper = NULL;
tstate->in_coroutine_wrapper = 0;
if (init)
_PyThreadState_Init(tstate);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment