Unverified Kaydet (Commit) 74387926 authored tarafından Andrew Svetlov's avatar Andrew Svetlov Kaydeden (comit) GitHub

bpo-30064: Refactor sock_* asyncio API (#10419)

üst 9404e773
......@@ -358,26 +358,29 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
"""
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
try:
return sock.recv(n)
except (BlockingIOError, InterruptedError):
pass
fut = self.create_future()
self._sock_recv(fut, None, sock, n)
fd = sock.fileno()
self.add_reader(fd, self._sock_recv, fut, sock, n)
fut.add_done_callback(
functools.partial(self._sock_read_done, fd))
return await fut
def _sock_recv(self, fut, registered_fd, sock, n):
def _sock_read_done(self, fd, fut):
self.remove_reader(fd)
def _sock_recv(self, fut, sock, n):
# _sock_recv() can add itself as an I/O callback if the operation can't
# be done immediately. Don't use it directly, call sock_recv().
if registered_fd is not None:
# Remove the callback early. It should be rare that the
# selector says the fd is ready but the call still returns
# EAGAIN, and I am willing to take a hit in that case in
# order to simplify the common case.
self.remove_reader(registered_fd)
if fut.cancelled():
if fut.done():
return
try:
data = sock.recv(n)
except (BlockingIOError, InterruptedError):
fd = sock.fileno()
self.add_reader(fd, self._sock_recv, fut, fd, sock, n)
return # try again next time
except Exception as exc:
fut.set_exception(exc)
else:
......@@ -391,27 +394,27 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
"""
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
try:
return sock.recv_into(buf)
except (BlockingIOError, InterruptedError):
pass
fut = self.create_future()
self._sock_recv_into(fut, None, sock, buf)
fd = sock.fileno()
self.add_reader(fd, self._sock_recv_into, fut, sock, buf)
fut.add_done_callback(
functools.partial(self._sock_read_done, fd))
return await fut
def _sock_recv_into(self, fut, registered_fd, sock, buf):
def _sock_recv_into(self, fut, sock, buf):
# _sock_recv_into() can add itself as an I/O callback if the operation
# can't be done immediately. Don't use it directly, call
# sock_recv_into().
if registered_fd is not None:
# Remove the callback early. It should be rare that the
# selector says the FD is ready but the call still returns
# EAGAIN, and I am willing to take a hit in that case in
# order to simplify the common case.
self.remove_reader(registered_fd)
if fut.cancelled():
if fut.done():
return
try:
nbytes = sock.recv_into(buf)
except (BlockingIOError, InterruptedError):
fd = sock.fileno()
self.add_reader(fd, self._sock_recv_into, fut, fd, sock, buf)
return # try again next time
except Exception as exc:
fut.set_exception(exc)
else:
......@@ -428,23 +431,32 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
"""
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
fut = self.create_future()
if data:
self._sock_sendall(fut, None, sock, data)
try:
n = sock.send(data)
except (BlockingIOError, InterruptedError):
n = 0
if n == len(data):
# all data sent
return
else:
fut.set_result(None)
data = bytearray(memoryview(data)[n:])
fut = self.create_future()
fd = sock.fileno()
fut.add_done_callback(
functools.partial(self._sock_write_done, fd))
self.add_writer(fd, self._sock_sendall, fut, sock, data)
return await fut
def _sock_sendall(self, fut, registered_fd, sock, data):
if registered_fd is not None:
self.remove_writer(registered_fd)
if fut.cancelled():
def _sock_sendall(self, fut, sock, data):
if fut.done():
# Future cancellation can be scheduled on previous loop iteration
return
try:
n = sock.send(data)
except (BlockingIOError, InterruptedError):
n = 0
return
except Exception as exc:
fut.set_exception(exc)
return
......@@ -452,10 +464,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
if n == len(data):
fut.set_result(None)
else:
if n:
data = data[n:]
fd = sock.fileno()
self.add_writer(fd, self._sock_sendall, fut, fd, sock, data)
del data[:n]
async def sock_connect(self, sock, address):
"""Connect to a remote socket at address.
......@@ -484,18 +493,18 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
# becomes writable to be notified when the connection succeed or
# fails.
fut.add_done_callback(
functools.partial(self._sock_connect_done, fd))
functools.partial(self._sock_write_done, fd))
self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
except Exception as exc:
fut.set_exception(exc)
else:
fut.set_result(None)
def _sock_connect_done(self, fd, fut):
def _sock_write_done(self, fd, fut):
self.remove_writer(fd)
def _sock_connect_cb(self, fut, sock, address):
if fut.cancelled():
if fut.done():
return
try:
......@@ -529,7 +538,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
fd = sock.fileno()
if registered:
self.remove_reader(fd)
if fut.cancelled():
if fut.done():
return
try:
conn, address = sock.accept()
......
......@@ -190,411 +190,6 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
self.loop._csock.send.side_effect = RuntimeError()
self.assertRaises(RuntimeError, self.loop._write_to_self)
def test_sock_recv(self):
sock = test_utils.mock_nonblocking_socket()
self.loop._sock_recv = mock.Mock()
f = self.loop.create_task(self.loop.sock_recv(sock, 1024))
self.loop.run_until_complete(asyncio.sleep(0.01))
self.assertEqual(self.loop._sock_recv.call_args[0][1:],
(None, sock, 1024))
f.cancel()
with self.assertRaises(asyncio.CancelledError):
self.loop.run_until_complete(f)
def test_sock_recv_reconnection(self):
sock = mock.Mock()
sock.fileno.return_value = 10
sock.recv.side_effect = BlockingIOError
sock.gettimeout.return_value = 0.0
self.loop.add_reader = mock.Mock()
self.loop.remove_reader = mock.Mock()
fut = self.loop.create_task(
self.loop.sock_recv(sock, 1024))
self.loop.run_until_complete(asyncio.sleep(0.01))
callback = self.loop.add_reader.call_args[0][1]
params = self.loop.add_reader.call_args[0][2:]
# emulate the old socket has closed, but the new one has
# the same fileno, so callback is called with old (closed) socket
sock.fileno.return_value = -1
sock.recv.side_effect = OSError(9)
callback(*params)
self.loop.run_until_complete(asyncio.sleep(0.01))
self.assertIsInstance(fut.exception(), OSError)
self.assertEqual((10,), self.loop.remove_reader.call_args[0])
def test__sock_recv_canceled_fut(self):
sock = mock.Mock()
f = asyncio.Future(loop=self.loop)
f.cancel()
self.loop._sock_recv(f, None, sock, 1024)
self.assertFalse(sock.recv.called)
def test__sock_recv_unregister(self):
sock = mock.Mock()
sock.fileno.return_value = 10
f = asyncio.Future(loop=self.loop)
f.cancel()
self.loop.remove_reader = mock.Mock()
self.loop._sock_recv(f, 10, sock, 1024)
self.assertEqual((10,), self.loop.remove_reader.call_args[0])
def test__sock_recv_tryagain(self):
f = asyncio.Future(loop=self.loop)
sock = mock.Mock()
sock.fileno.return_value = 10
sock.recv.side_effect = BlockingIOError
self.loop.add_reader = mock.Mock()
self.loop._sock_recv(f, None, sock, 1024)
self.assertEqual((10, self.loop._sock_recv, f, 10, sock, 1024),
self.loop.add_reader.call_args[0])
def test__sock_recv_exception(self):
f = asyncio.Future(loop=self.loop)
sock = mock.Mock()
sock.fileno.return_value = 10
err = sock.recv.side_effect = OSError()
self.loop._sock_recv(f, None, sock, 1024)
self.assertIs(err, f.exception())
def test_sock_sendall(self):
sock = test_utils.mock_nonblocking_socket()
self.loop._sock_sendall = mock.Mock()
f = self.loop.create_task(
self.loop.sock_sendall(sock, b'data'))
self.loop.run_until_complete(asyncio.sleep(0.01))
self.assertEqual(
(None, sock, b'data'),
self.loop._sock_sendall.call_args[0][1:])
f.cancel()
with self.assertRaises(asyncio.CancelledError):
self.loop.run_until_complete(f)
def test_sock_sendall_nodata(self):
sock = test_utils.mock_nonblocking_socket()
self.loop._sock_sendall = mock.Mock()
f = self.loop.create_task(self.loop.sock_sendall(sock, b''))
self.loop.run_until_complete(asyncio.sleep(0))
self.assertTrue(f.done())
self.assertIsNone(f.result())
self.assertFalse(self.loop._sock_sendall.called)
def test_sock_sendall_reconnection(self):
sock = mock.Mock()
sock.fileno.return_value = 10
sock.send.side_effect = BlockingIOError
sock.gettimeout.return_value = 0.0
self.loop.add_writer = mock.Mock()
self.loop.remove_writer = mock.Mock()
fut = self.loop.create_task(self.loop.sock_sendall(sock, b'data'))
self.loop.run_until_complete(asyncio.sleep(0.01))
callback = self.loop.add_writer.call_args[0][1]
params = self.loop.add_writer.call_args[0][2:]
# emulate the old socket has closed, but the new one has
# the same fileno, so callback is called with old (closed) socket
sock.fileno.return_value = -1
sock.send.side_effect = OSError(9)
callback(*params)
self.loop.run_until_complete(asyncio.sleep(0.01))
self.assertIsInstance(fut.exception(), OSError)
self.assertEqual((10,), self.loop.remove_writer.call_args[0])
def test__sock_sendall_canceled_fut(self):
sock = mock.Mock()
f = asyncio.Future(loop=self.loop)
f.cancel()
self.loop._sock_sendall(f, None, sock, b'data')
self.assertFalse(sock.send.called)
def test__sock_sendall_unregister(self):
sock = mock.Mock()
sock.fileno.return_value = 10
f = asyncio.Future(loop=self.loop)
f.cancel()
self.loop.remove_writer = mock.Mock()
self.loop._sock_sendall(f, 10, sock, b'data')
self.assertEqual((10,), self.loop.remove_writer.call_args[0])
def test__sock_sendall_tryagain(self):
f = asyncio.Future(loop=self.loop)
sock = mock.Mock()
sock.fileno.return_value = 10
sock.send.side_effect = BlockingIOError
self.loop.add_writer = mock.Mock()
self.loop._sock_sendall(f, None, sock, b'data')
self.assertEqual(
(10, self.loop._sock_sendall, f, 10, sock, b'data'),
self.loop.add_writer.call_args[0])
def test__sock_sendall_interrupted(self):
f = asyncio.Future(loop=self.loop)
sock = mock.Mock()
sock.fileno.return_value = 10
sock.send.side_effect = InterruptedError
self.loop.add_writer = mock.Mock()
self.loop._sock_sendall(f, None, sock, b'data')
self.assertEqual(
(10, self.loop._sock_sendall, f, 10, sock, b'data'),
self.loop.add_writer.call_args[0])
def test__sock_sendall_exception(self):
f = asyncio.Future(loop=self.loop)
sock = mock.Mock()
sock.fileno.return_value = 10
err = sock.send.side_effect = OSError()
self.loop._sock_sendall(f, None, sock, b'data')
self.assertIs(f.exception(), err)
def test__sock_sendall(self):
sock = mock.Mock()
f = asyncio.Future(loop=self.loop)
sock.fileno.return_value = 10
sock.send.return_value = 4
self.loop._sock_sendall(f, None, sock, b'data')
self.assertTrue(f.done())
self.assertIsNone(f.result())
def test__sock_sendall_partial(self):
sock = mock.Mock()
f = asyncio.Future(loop=self.loop)
sock.fileno.return_value = 10
sock.send.return_value = 2
self.loop.add_writer = mock.Mock()
self.loop._sock_sendall(f, None, sock, b'data')
self.assertFalse(f.done())
self.assertEqual(
(10, self.loop._sock_sendall, f, 10, sock, b'ta'),
self.loop.add_writer.call_args[0])
def test__sock_sendall_none(self):
sock = mock.Mock()
f = asyncio.Future(loop=self.loop)
sock.fileno.return_value = 10
sock.send.return_value = 0
self.loop.add_writer = mock.Mock()
self.loop._sock_sendall(f, None, sock, b'data')
self.assertFalse(f.done())
self.assertEqual(
(10, self.loop._sock_sendall, f, 10, sock, b'data'),
self.loop.add_writer.call_args[0])
def test_sock_connect_timeout(self):
# asyncio issue #205: sock_connect() must unregister the socket on
# timeout error
# prepare mocks
self.loop.add_writer = mock.Mock()
self.loop.remove_writer = mock.Mock()
sock = test_utils.mock_nonblocking_socket()
sock.connect.side_effect = BlockingIOError
# first call to sock_connect() registers the socket
fut = self.loop.create_task(
self.loop.sock_connect(sock, ('127.0.0.1', 80)))
self.loop._run_once()
self.assertTrue(sock.connect.called)
self.assertTrue(self.loop.add_writer.called)
# on timeout, the socket must be unregistered
sock.connect.reset_mock()
fut.cancel()
with self.assertRaises(asyncio.CancelledError):
self.loop.run_until_complete(fut)
self.assertTrue(self.loop.remove_writer.called)
@mock.patch('socket.getaddrinfo')
def test_sock_connect_resolve_using_socket_params(self, m_gai):
addr = ('need-resolution.com', 8080)
sock = test_utils.mock_nonblocking_socket()
m_gai.side_effect = \
lambda *args: [(None, None, None, None, ('127.0.0.1', 0))]
con = self.loop.create_task(self.loop.sock_connect(sock, addr))
self.loop.run_until_complete(con)
m_gai.assert_called_with(
addr[0], addr[1], sock.family, sock.type, sock.proto, 0)
self.loop.run_until_complete(con)
sock.connect.assert_called_with(('127.0.0.1', 0))
def test__sock_connect(self):
f = asyncio.Future(loop=self.loop)
sock = mock.Mock()
sock.fileno.return_value = 10
resolved = self.loop.create_future()
resolved.set_result([(socket.AF_INET, socket.SOCK_STREAM,
socket.IPPROTO_TCP, '', ('127.0.0.1', 8080))])
self.loop._sock_connect(f, sock, resolved)
self.assertTrue(f.done())
self.assertIsNone(f.result())
self.assertTrue(sock.connect.called)
def test__sock_connect_cb_cancelled_fut(self):
sock = mock.Mock()
self.loop.remove_writer = mock.Mock()
f = asyncio.Future(loop=self.loop)
f.cancel()
self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
self.assertFalse(sock.getsockopt.called)
def test__sock_connect_writer(self):
# check that the fd is registered and then unregistered
self.loop._process_events = mock.Mock()
self.loop.add_writer = mock.Mock()
self.loop.remove_writer = mock.Mock()
sock = mock.Mock()
sock.fileno.return_value = 10
sock.connect.side_effect = BlockingIOError
sock.getsockopt.return_value = 0
address = ('127.0.0.1', 8080)
resolved = self.loop.create_future()
resolved.set_result([(socket.AF_INET, socket.SOCK_STREAM,
socket.IPPROTO_TCP, '', address)])
f = asyncio.Future(loop=self.loop)
self.loop._sock_connect(f, sock, resolved)
self.loop._run_once()
self.assertTrue(self.loop.add_writer.called)
self.assertEqual(10, self.loop.add_writer.call_args[0][0])
self.loop._sock_connect_cb(f, sock, address)
# need to run the event loop to execute _sock_connect_done() callback
self.loop.run_until_complete(f)
self.assertEqual((10,), self.loop.remove_writer.call_args[0])
def test__sock_connect_cb_tryagain(self):
f = asyncio.Future(loop=self.loop)
sock = mock.Mock()
sock.fileno.return_value = 10
sock.getsockopt.return_value = errno.EAGAIN
# check that the exception is handled
self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
def test__sock_connect_cb_exception(self):
f = asyncio.Future(loop=self.loop)
sock = mock.Mock()
sock.fileno.return_value = 10
sock.getsockopt.return_value = errno.ENOTCONN
self.loop.remove_writer = mock.Mock()
self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
self.assertIsInstance(f.exception(), OSError)
def test_sock_accept(self):
sock = test_utils.mock_nonblocking_socket()
self.loop._sock_accept = mock.Mock()
f = self.loop.create_task(self.loop.sock_accept(sock))
self.loop.run_until_complete(asyncio.sleep(0.01))
self.assertFalse(self.loop._sock_accept.call_args[0][1])
self.assertIs(self.loop._sock_accept.call_args[0][2], sock)
f.cancel()
with self.assertRaises(asyncio.CancelledError):
self.loop.run_until_complete(f)
def test__sock_accept(self):
f = asyncio.Future(loop=self.loop)
conn = mock.Mock()
sock = mock.Mock()
sock.fileno.return_value = 10
sock.accept.return_value = conn, ('127.0.0.1', 1000)
self.loop._sock_accept(f, False, sock)
self.assertTrue(f.done())
self.assertEqual((conn, ('127.0.0.1', 1000)), f.result())
self.assertEqual((False,), conn.setblocking.call_args[0])
def test__sock_accept_canceled_fut(self):
sock = mock.Mock()
f = asyncio.Future(loop=self.loop)
f.cancel()
self.loop._sock_accept(f, False, sock)
self.assertFalse(sock.accept.called)
def test__sock_accept_unregister(self):
sock = mock.Mock()
sock.fileno.return_value = 10
f = asyncio.Future(loop=self.loop)
f.cancel()
self.loop.remove_reader = mock.Mock()
self.loop._sock_accept(f, True, sock)
self.assertEqual((10,), self.loop.remove_reader.call_args[0])
def test__sock_accept_tryagain(self):
f = asyncio.Future(loop=self.loop)
sock = mock.Mock()
sock.fileno.return_value = 10
sock.accept.side_effect = BlockingIOError
self.loop.add_reader = mock.Mock()
self.loop._sock_accept(f, False, sock)
self.assertEqual(
(10, self.loop._sock_accept, f, True, sock),
self.loop.add_reader.call_args[0])
def test__sock_accept_exception(self):
f = asyncio.Future(loop=self.loop)
sock = mock.Mock()
sock.fileno.return_value = 10
err = sock.accept.side_effect = OSError()
self.loop._sock_accept(f, False, sock)
self.assertIs(err, f.exception())
def test_add_reader(self):
self.loop._selector.get_key.side_effect = KeyError
cb = lambda: True
......
......@@ -2,6 +2,7 @@ import socket
import asyncio
import sys
from asyncio import proactor_events
from itertools import cycle, islice
from test.test_asyncio import utils as test_utils
from test import support
......@@ -120,6 +121,110 @@ class BaseSockTestsMixin:
sock = socket.socket()
self._basetest_sock_recv_into(httpd, sock)
async def _basetest_huge_content(self, address):
sock = socket.socket()
sock.setblocking(False)
DATA_SIZE = 10_000_00
chunk = b'0123456789' * (DATA_SIZE // 10)
await self.loop.sock_connect(sock, address)
await self.loop.sock_sendall(sock,
(b'POST /loop HTTP/1.0\r\n' +
b'Content-Length: %d\r\n' % DATA_SIZE +
b'\r\n'))
task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))
data = await self.loop.sock_recv(sock, DATA_SIZE)
# HTTP headers size is less than MTU,
# they are sent by the first packet always
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
while data.find(b'\r\n\r\n') == -1:
data += await self.loop.sock_recv(sock, DATA_SIZE)
# Strip headers
headers = data[:data.index(b'\r\n\r\n') + 4]
data = data[len(headers):]
size = DATA_SIZE
checker = cycle(b'0123456789')
expected = bytes(islice(checker, len(data)))
self.assertEqual(data, expected)
size -= len(data)
while True:
data = await self.loop.sock_recv(sock, DATA_SIZE)
if not data:
break
expected = bytes(islice(checker, len(data)))
self.assertEqual(data, expected)
size -= len(data)
self.assertEqual(size, 0)
await task
sock.close()
def test_huge_content(self):
with test_utils.run_test_server() as httpd:
self.loop.run_until_complete(
self._basetest_huge_content(httpd.address))
async def _basetest_huge_content_recvinto(self, address):
sock = socket.socket()
sock.setblocking(False)
DATA_SIZE = 10_000_00
chunk = b'0123456789' * (DATA_SIZE // 10)
await self.loop.sock_connect(sock, address)
await self.loop.sock_sendall(sock,
(b'POST /loop HTTP/1.0\r\n' +
b'Content-Length: %d\r\n' % DATA_SIZE +
b'\r\n'))
task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))
array = bytearray(DATA_SIZE)
buf = memoryview(array)
nbytes = await self.loop.sock_recv_into(sock, buf)
data = bytes(buf[:nbytes])
# HTTP headers size is less than MTU,
# they are sent by the first packet always
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
while data.find(b'\r\n\r\n') == -1:
nbytes = await self.loop.sock_recv_into(sock, buf)
data = bytes(buf[:nbytes])
# Strip headers
headers = data[:data.index(b'\r\n\r\n') + 4]
data = data[len(headers):]
size = DATA_SIZE
checker = cycle(b'0123456789')
expected = bytes(islice(checker, len(data)))
self.assertEqual(data, expected)
size -= len(data)
while True:
nbytes = await self.loop.sock_recv_into(sock, buf)
data = buf[:nbytes]
if not data:
break
expected = bytes(islice(checker, len(data)))
self.assertEqual(data, expected)
size -= len(data)
self.assertEqual(size, 0)
await task
sock.close()
def test_huge_content_recvinto(self):
with test_utils.run_test_server() as httpd:
self.loop.run_until_complete(
self._basetest_huge_content_recvinto(httpd.address))
@support.skip_unless_bind_unix_socket
def test_unix_sock_client_ops(self):
with test_utils.run_test_unix_server() as httpd:
......
......@@ -180,11 +180,21 @@ class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
def loop(environ):
size = int(environ['CONTENT_LENGTH'])
while size:
data = environ['wsgi.input'].read(min(size, 0x10000))
yield data
size -= len(data)
def app(environ, start_response):
status = '200 OK'
headers = [('Content-type', 'text/plain')]
start_response(status, headers)
return [b'Test message']
if environ['PATH_INFO'] == '/loop':
return loop(environ)
else:
return [b'Test message']
# Run the test WSGI server in a separate thread in order not to
# interfere with event handling in the main thread
......
Use add_done_callback() in sock_* asyncio API to unsubscribe reader/writer
early on calcellation.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment