Kaydet (Commit) 0eaa5ac9 authored tarafından Guido van Rossum's avatar Guido van Rossum

asyncio: Refactor SIGCHLD handling. By Anthony Baire.

üst ccea0846
"""Event loop and event loop policy.""" """Event loop and event loop policy."""
__all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', __all__ = ['AbstractEventLoopPolicy',
'AbstractEventLoop', 'AbstractServer', 'AbstractEventLoop', 'AbstractServer',
'Handle', 'TimerHandle', 'Handle', 'TimerHandle',
'get_event_loop_policy', 'set_event_loop_policy', 'get_event_loop_policy', 'set_event_loop_policy',
'get_event_loop', 'set_event_loop', 'new_event_loop', 'get_event_loop', 'set_event_loop', 'new_event_loop',
'get_child_watcher', 'set_child_watcher',
] ]
import subprocess import subprocess
...@@ -318,8 +319,18 @@ class AbstractEventLoopPolicy: ...@@ -318,8 +319,18 @@ class AbstractEventLoopPolicy:
"""XXX""" """XXX"""
raise NotImplementedError raise NotImplementedError
# Child processes handling (Unix only).
class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): def get_child_watcher(self):
"""XXX"""
raise NotImplementedError
def set_child_watcher(self, watcher):
"""XXX"""
raise NotImplementedError
class BaseDefaultEventLoopPolicy(AbstractEventLoopPolicy):
"""Default policy implementation for accessing the event loop. """Default policy implementation for accessing the event loop.
In this policy, each thread has its own event loop. However, we In this policy, each thread has its own event loop. However, we
...@@ -332,28 +343,34 @@ class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): ...@@ -332,28 +343,34 @@ class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy):
associated). associated).
""" """
_loop = None _loop_factory = None
_set_called = False
class _Local(threading.local):
_loop = None
_set_called = False
def __init__(self):
self._local = self._Local()
def get_event_loop(self): def get_event_loop(self):
"""Get the event loop. """Get the event loop.
This may be None or an instance of EventLoop. This may be None or an instance of EventLoop.
""" """
if (self._loop is None and if (self._local._loop is None and
not self._set_called and not self._local._set_called and
isinstance(threading.current_thread(), threading._MainThread)): isinstance(threading.current_thread(), threading._MainThread)):
self._loop = self.new_event_loop() self._local._loop = self.new_event_loop()
assert self._loop is not None, \ assert self._local._loop is not None, \
('There is no current event loop in thread %r.' % ('There is no current event loop in thread %r.' %
threading.current_thread().name) threading.current_thread().name)
return self._loop return self._local._loop
def set_event_loop(self, loop): def set_event_loop(self, loop):
"""Set the event loop.""" """Set the event loop."""
self._set_called = True self._local._set_called = True
assert loop is None or isinstance(loop, AbstractEventLoop) assert loop is None or isinstance(loop, AbstractEventLoop)
self._loop = loop self._local._loop = loop
def new_event_loop(self): def new_event_loop(self):
"""Create a new event loop. """Create a new event loop.
...@@ -361,12 +378,7 @@ class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): ...@@ -361,12 +378,7 @@ class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy):
You must call set_event_loop() to make this the current event You must call set_event_loop() to make this the current event
loop. loop.
""" """
if sys.platform == 'win32': # pragma: no cover return self._loop_factory()
from . import windows_events
return windows_events.SelectorEventLoop()
else: # pragma: no cover
from . import unix_events
return unix_events.SelectorEventLoop()
# Event loop policy. The policy itself is always global, even if the # Event loop policy. The policy itself is always global, even if the
...@@ -375,12 +387,22 @@ class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): ...@@ -375,12 +387,22 @@ class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy):
# call to get_event_loop_policy(). # call to get_event_loop_policy().
_event_loop_policy = None _event_loop_policy = None
# Lock for protecting the on-the-fly creation of the event loop policy.
_lock = threading.Lock()
def _init_event_loop_policy():
global _event_loop_policy
with _lock:
if _event_loop_policy is None: # pragma: no branch
from . import DefaultEventLoopPolicy
_event_loop_policy = DefaultEventLoopPolicy()
def get_event_loop_policy(): def get_event_loop_policy():
"""XXX""" """XXX"""
global _event_loop_policy
if _event_loop_policy is None: if _event_loop_policy is None:
_event_loop_policy = DefaultEventLoopPolicy() _init_event_loop_policy()
return _event_loop_policy return _event_loop_policy
...@@ -404,3 +426,13 @@ def set_event_loop(loop): ...@@ -404,3 +426,13 @@ def set_event_loop(loop):
def new_event_loop(): def new_event_loop():
"""XXX""" """XXX"""
return get_event_loop_policy().new_event_loop() return get_event_loop_policy().new_event_loop()
def get_child_watcher():
"""XXX"""
return get_event_loop_policy().get_child_watcher()
def set_child_watcher(watcher):
"""XXX"""
return get_event_loop_policy().set_child_watcher(watcher)
This diff is collapsed.
...@@ -7,6 +7,7 @@ import weakref ...@@ -7,6 +7,7 @@ import weakref
import struct import struct
import _winapi import _winapi
from . import events
from . import base_subprocess from . import base_subprocess
from . import futures from . import futures
from . import proactor_events from . import proactor_events
...@@ -17,7 +18,9 @@ from .log import logger ...@@ -17,7 +18,9 @@ from .log import logger
from . import _overlapped from . import _overlapped
__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor'] __all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor',
'DefaultEventLoopPolicy',
]
NULL = 0 NULL = 0
...@@ -108,7 +111,7 @@ class PipeServer(object): ...@@ -108,7 +111,7 @@ class PipeServer(object):
__del__ = close __del__ = close
class SelectorEventLoop(selector_events.BaseSelectorEventLoop): class _WindowsSelectorEventLoop(selector_events.BaseSelectorEventLoop):
"""Windows version of selector event loop.""" """Windows version of selector event loop."""
def _socketpair(self): def _socketpair(self):
...@@ -453,3 +456,13 @@ class _WindowsSubprocessTransport(base_subprocess.BaseSubprocessTransport): ...@@ -453,3 +456,13 @@ class _WindowsSubprocessTransport(base_subprocess.BaseSubprocessTransport):
f = self._loop._proactor.wait_for_handle(int(self._proc._handle)) f = self._loop._proactor.wait_for_handle(int(self._proc._handle))
f.add_done_callback(callback) f.add_done_callback(callback)
SelectorEventLoop = _WindowsSelectorEventLoop
class _WindowsDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy):
_loop_factory = SelectorEventLoop
DefaultEventLoopPolicy = _WindowsDefaultEventLoopPolicy
...@@ -1308,8 +1308,17 @@ else: ...@@ -1308,8 +1308,17 @@ else:
from asyncio import selectors from asyncio import selectors
from asyncio import unix_events from asyncio import unix_events
class UnixEventLoopTestsMixin(EventLoopTestsMixin):
def setUp(self):
super().setUp()
events.set_child_watcher(unix_events.SafeChildWatcher(self.loop))
def tearDown(self):
events.set_child_watcher(None)
super().tearDown()
if hasattr(selectors, 'KqueueSelector'): if hasattr(selectors, 'KqueueSelector'):
class KqueueEventLoopTests(EventLoopTestsMixin, class KqueueEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): unittest.TestCase):
...@@ -1318,7 +1327,7 @@ else: ...@@ -1318,7 +1327,7 @@ else:
selectors.KqueueSelector()) selectors.KqueueSelector())
if hasattr(selectors, 'EpollSelector'): if hasattr(selectors, 'EpollSelector'):
class EPollEventLoopTests(EventLoopTestsMixin, class EPollEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): unittest.TestCase):
...@@ -1326,7 +1335,7 @@ else: ...@@ -1326,7 +1335,7 @@ else:
return unix_events.SelectorEventLoop(selectors.EpollSelector()) return unix_events.SelectorEventLoop(selectors.EpollSelector())
if hasattr(selectors, 'PollSelector'): if hasattr(selectors, 'PollSelector'):
class PollEventLoopTests(EventLoopTestsMixin, class PollEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): unittest.TestCase):
...@@ -1334,7 +1343,7 @@ else: ...@@ -1334,7 +1343,7 @@ else:
return unix_events.SelectorEventLoop(selectors.PollSelector()) return unix_events.SelectorEventLoop(selectors.PollSelector())
# Should always exist. # Should always exist.
class SelectEventLoopTests(EventLoopTestsMixin, class SelectEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): unittest.TestCase):
...@@ -1557,25 +1566,36 @@ class ProtocolsAbsTests(unittest.TestCase): ...@@ -1557,25 +1566,36 @@ class ProtocolsAbsTests(unittest.TestCase):
class PolicyTests(unittest.TestCase): class PolicyTests(unittest.TestCase):
def create_policy(self):
if sys.platform == "win32":
from asyncio import windows_events
return windows_events.DefaultEventLoopPolicy()
else:
from asyncio import unix_events
return unix_events.DefaultEventLoopPolicy()
def test_event_loop_policy(self): def test_event_loop_policy(self):
policy = events.AbstractEventLoopPolicy() policy = events.AbstractEventLoopPolicy()
self.assertRaises(NotImplementedError, policy.get_event_loop) self.assertRaises(NotImplementedError, policy.get_event_loop)
self.assertRaises(NotImplementedError, policy.set_event_loop, object()) self.assertRaises(NotImplementedError, policy.set_event_loop, object())
self.assertRaises(NotImplementedError, policy.new_event_loop) self.assertRaises(NotImplementedError, policy.new_event_loop)
self.assertRaises(NotImplementedError, policy.get_child_watcher)
self.assertRaises(NotImplementedError, policy.set_child_watcher,
object())
def test_get_event_loop(self): def test_get_event_loop(self):
policy = events.DefaultEventLoopPolicy() policy = self.create_policy()
self.assertIsNone(policy._loop) self.assertIsNone(policy._local._loop)
loop = policy.get_event_loop() loop = policy.get_event_loop()
self.assertIsInstance(loop, events.AbstractEventLoop) self.assertIsInstance(loop, events.AbstractEventLoop)
self.assertIs(policy._loop, loop) self.assertIs(policy._local._loop, loop)
self.assertIs(loop, policy.get_event_loop()) self.assertIs(loop, policy.get_event_loop())
loop.close() loop.close()
def test_get_event_loop_after_set_none(self): def test_get_event_loop_after_set_none(self):
policy = events.DefaultEventLoopPolicy() policy = self.create_policy()
policy.set_event_loop(None) policy.set_event_loop(None)
self.assertRaises(AssertionError, policy.get_event_loop) self.assertRaises(AssertionError, policy.get_event_loop)
...@@ -1583,7 +1603,7 @@ class PolicyTests(unittest.TestCase): ...@@ -1583,7 +1603,7 @@ class PolicyTests(unittest.TestCase):
def test_get_event_loop_thread(self, m_current_thread): def test_get_event_loop_thread(self, m_current_thread):
def f(): def f():
policy = events.DefaultEventLoopPolicy() policy = self.create_policy()
self.assertRaises(AssertionError, policy.get_event_loop) self.assertRaises(AssertionError, policy.get_event_loop)
th = threading.Thread(target=f) th = threading.Thread(target=f)
...@@ -1591,14 +1611,14 @@ class PolicyTests(unittest.TestCase): ...@@ -1591,14 +1611,14 @@ class PolicyTests(unittest.TestCase):
th.join() th.join()
def test_new_event_loop(self): def test_new_event_loop(self):
policy = events.DefaultEventLoopPolicy() policy = self.create_policy()
loop = policy.new_event_loop() loop = policy.new_event_loop()
self.assertIsInstance(loop, events.AbstractEventLoop) self.assertIsInstance(loop, events.AbstractEventLoop)
loop.close() loop.close()
def test_set_event_loop(self): def test_set_event_loop(self):
policy = events.DefaultEventLoopPolicy() policy = self.create_policy()
old_loop = policy.get_event_loop() old_loop = policy.get_event_loop()
self.assertRaises(AssertionError, policy.set_event_loop, object()) self.assertRaises(AssertionError, policy.set_event_loop, object())
...@@ -1621,7 +1641,7 @@ class PolicyTests(unittest.TestCase): ...@@ -1621,7 +1641,7 @@ class PolicyTests(unittest.TestCase):
old_policy = events.get_event_loop_policy() old_policy = events.get_event_loop_policy()
policy = events.DefaultEventLoopPolicy() policy = self.create_policy()
events.set_event_loop_policy(policy) events.set_event_loop_policy(policy)
self.assertIs(policy, events.get_event_loop_policy()) self.assertIs(policy, events.get_event_loop_policy())
self.assertIsNot(policy, old_policy) self.assertIsNot(policy, old_policy)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment