Kaydet (Commit) eeeebcd8 authored tarafından Victor Stinner's avatar Victor Stinner

asyncio: Synchronize with Tulip

* Issue #159: Fix windows_utils.socketpair()

  - Use "127.0.0.1" (IPv4) or "::1" (IPv6) host instead of "localhost", because
    "localhost" may be a different IP address
  - Reject also invalid arguments: only AF_INET/AF_INET6 with SOCK_STREAM (and
    proto=0) are supported

* Reject add/remove reader/writer when event loop is closed.
* Fix ResourceWarning warnings
üst c5cc5011
...@@ -136,6 +136,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -136,6 +136,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
def add_reader(self, fd, callback, *args): def add_reader(self, fd, callback, *args):
"""Add a reader callback.""" """Add a reader callback."""
if self._selector is None:
raise RuntimeError('Event loop is closed')
handle = events.Handle(callback, args, self) handle = events.Handle(callback, args, self)
try: try:
key = self._selector.get_key(fd) key = self._selector.get_key(fd)
...@@ -151,6 +153,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -151,6 +153,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
def remove_reader(self, fd): def remove_reader(self, fd):
"""Remove a reader callback.""" """Remove a reader callback."""
if self._selector is None:
return False
try: try:
key = self._selector.get_key(fd) key = self._selector.get_key(fd)
except KeyError: except KeyError:
...@@ -171,6 +175,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -171,6 +175,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
def add_writer(self, fd, callback, *args): def add_writer(self, fd, callback, *args):
"""Add a writer callback..""" """Add a writer callback.."""
if self._selector is None:
raise RuntimeError('Event loop is closed')
handle = events.Handle(callback, args, self) handle = events.Handle(callback, args, self)
try: try:
key = self._selector.get_key(fd) key = self._selector.get_key(fd)
...@@ -186,6 +192,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -186,6 +192,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
def remove_writer(self, fd): def remove_writer(self, fd):
"""Remove a writer callback.""" """Remove a writer callback."""
if self._selector is None:
return False
try: try:
key = self._selector.get_key(fd) key = self._selector.get_key(fd)
except KeyError: except KeyError:
......
...@@ -36,12 +36,25 @@ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): ...@@ -36,12 +36,25 @@ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain.
""" """
if family == socket.AF_INET:
host = '127.0.0.1'
elif family == socket.AF_INET6:
host = '::1'
else:
raise ValueError("Ony AF_INET and AF_INET6 socket address families "
"are supported")
if type != socket.SOCK_STREAM:
raise ValueError("Only SOCK_STREAM socket type is supported")
if proto != 0:
raise ValueError("Only protocol zero is supported")
# We create a connected TCP socket. Note the trick with setblocking(0) # We create a connected TCP socket. Note the trick with setblocking(0)
# that prevents us from having to create a thread. # that prevents us from having to create a thread.
lsock = socket.socket(family, type, proto) lsock = socket.socket(family, type, proto)
lsock.bind(('localhost', 0)) lsock.bind((host, 0))
lsock.listen(1) lsock.listen(1)
addr, port = lsock.getsockname() # On IPv6, ignore flow_info and scope_id
addr, port = lsock.getsockname()[:2]
csock = socket.socket(family, type, proto) csock = socket.socket(family, type, proto)
csock.setblocking(False) csock.setblocking(False)
try: try:
......
...@@ -1326,6 +1326,30 @@ class EventLoopTestsMixin: ...@@ -1326,6 +1326,30 @@ class EventLoopTestsMixin:
self.assertIn('address must be resolved', self.assertIn('address must be resolved',
str(cm.exception)) str(cm.exception))
def test_remove_fds_after_closing(self):
loop = self.create_event_loop()
callback = lambda: None
r, w = test_utils.socketpair()
self.addCleanup(r.close)
self.addCleanup(w.close)
loop.add_reader(r, callback)
loop.add_writer(w, callback)
loop.close()
self.assertFalse(loop.remove_reader(r))
self.assertFalse(loop.remove_writer(w))
def test_add_fds_after_closing(self):
loop = self.create_event_loop()
callback = lambda: None
r, w = test_utils.socketpair()
self.addCleanup(r.close)
self.addCleanup(w.close)
loop.close()
with self.assertRaises(RuntimeError):
loop.add_reader(r, callback)
with self.assertRaises(RuntimeError):
loop.add_writer(w, callback)
class SubprocessTestsMixin: class SubprocessTestsMixin:
...@@ -1632,6 +1656,9 @@ if sys.platform == 'win32': ...@@ -1632,6 +1656,9 @@ if sys.platform == 'win32':
def test_create_datagram_endpoint(self): def test_create_datagram_endpoint(self):
raise unittest.SkipTest( raise unittest.SkipTest(
"IocpEventLoop does not have create_datagram_endpoint()") "IocpEventLoop does not have create_datagram_endpoint()")
def test_remove_fds_after_closing(self):
raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
else: else:
from asyncio import selectors from asyncio import selectors
......
"""Tests for window_utils""" """Tests for window_utils"""
import socket
import sys import sys
import test.support import test.support
import unittest import unittest
from test.support import IPV6_ENABLED
from unittest import mock from unittest import mock
if sys.platform != 'win32': if sys.platform != 'win32':
...@@ -16,23 +18,40 @@ from asyncio import _overlapped ...@@ -16,23 +18,40 @@ from asyncio import _overlapped
class WinsocketpairTests(unittest.TestCase): class WinsocketpairTests(unittest.TestCase):
def test_winsocketpair(self): def check_winsocketpair(self, ssock, csock):
ssock, csock = windows_utils.socketpair()
csock.send(b'xxx') csock.send(b'xxx')
self.assertEqual(b'xxx', ssock.recv(1024)) self.assertEqual(b'xxx', ssock.recv(1024))
csock.close() csock.close()
ssock.close() ssock.close()
def test_winsocketpair(self):
ssock, csock = windows_utils.socketpair()
self.check_winsocketpair(ssock, csock)
@unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled')
def test_winsocketpair_ipv6(self):
ssock, csock = windows_utils.socketpair(family=socket.AF_INET6)
self.check_winsocketpair(ssock, csock)
@mock.patch('asyncio.windows_utils.socket') @mock.patch('asyncio.windows_utils.socket')
def test_winsocketpair_exc(self, m_socket): def test_winsocketpair_exc(self, m_socket):
m_socket.AF_INET = socket.AF_INET
m_socket.SOCK_STREAM = socket.SOCK_STREAM
m_socket.socket.return_value.getsockname.return_value = ('', 12345) m_socket.socket.return_value.getsockname.return_value = ('', 12345)
m_socket.socket.return_value.accept.return_value = object(), object() m_socket.socket.return_value.accept.return_value = object(), object()
m_socket.socket.return_value.connect.side_effect = OSError() m_socket.socket.return_value.connect.side_effect = OSError()
self.assertRaises(OSError, windows_utils.socketpair) self.assertRaises(OSError, windows_utils.socketpair)
def test_winsocketpair_invalid_args(self):
self.assertRaises(ValueError,
windows_utils.socketpair, family=socket.AF_UNSPEC)
self.assertRaises(ValueError,
windows_utils.socketpair, type=socket.SOCK_DGRAM)
self.assertRaises(ValueError,
windows_utils.socketpair, proto=1)
class PipeTests(unittest.TestCase): class PipeTests(unittest.TestCase):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment