Unverified Kaydet (Commit) 74387926 authored tarafından Andrew Svetlov's avatar Andrew Svetlov Kaydeden (comit) GitHub

bpo-30064: Refactor sock_* asyncio API (#10419)

üst 9404e773
......@@ -358,26 +358,29 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
"""
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
try:
return sock.recv(n)
except (BlockingIOError, InterruptedError):
pass
fut = self.create_future()
self._sock_recv(fut, None, sock, n)
fd = sock.fileno()
self.add_reader(fd, self._sock_recv, fut, sock, n)
fut.add_done_callback(
functools.partial(self._sock_read_done, fd))
return await fut
def _sock_recv(self, fut, registered_fd, sock, n):
def _sock_read_done(self, fd, fut):
self.remove_reader(fd)
def _sock_recv(self, fut, sock, n):
# _sock_recv() can add itself as an I/O callback if the operation can't
# be done immediately. Don't use it directly, call sock_recv().
if registered_fd is not None:
# Remove the callback early. It should be rare that the
# selector says the fd is ready but the call still returns
# EAGAIN, and I am willing to take a hit in that case in
# order to simplify the common case.
self.remove_reader(registered_fd)
if fut.cancelled():
if fut.done():
return
try:
data = sock.recv(n)
except (BlockingIOError, InterruptedError):
fd = sock.fileno()
self.add_reader(fd, self._sock_recv, fut, fd, sock, n)
return # try again next time
except Exception as exc:
fut.set_exception(exc)
else:
......@@ -391,27 +394,27 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
"""
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
try:
return sock.recv_into(buf)
except (BlockingIOError, InterruptedError):
pass
fut = self.create_future()
self._sock_recv_into(fut, None, sock, buf)
fd = sock.fileno()
self.add_reader(fd, self._sock_recv_into, fut, sock, buf)
fut.add_done_callback(
functools.partial(self._sock_read_done, fd))
return await fut
def _sock_recv_into(self, fut, registered_fd, sock, buf):
def _sock_recv_into(self, fut, sock, buf):
# _sock_recv_into() can add itself as an I/O callback if the operation
# can't be done immediately. Don't use it directly, call
# sock_recv_into().
if registered_fd is not None:
# Remove the callback early. It should be rare that the
# selector says the FD is ready but the call still returns
# EAGAIN, and I am willing to take a hit in that case in
# order to simplify the common case.
self.remove_reader(registered_fd)
if fut.cancelled():
if fut.done():
return
try:
nbytes = sock.recv_into(buf)
except (BlockingIOError, InterruptedError):
fd = sock.fileno()
self.add_reader(fd, self._sock_recv_into, fut, fd, sock, buf)
return # try again next time
except Exception as exc:
fut.set_exception(exc)
else:
......@@ -428,23 +431,32 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
"""
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
fut = self.create_future()
if data:
self._sock_sendall(fut, None, sock, data)
try:
n = sock.send(data)
except (BlockingIOError, InterruptedError):
n = 0
if n == len(data):
# all data sent
return
else:
fut.set_result(None)
data = bytearray(memoryview(data)[n:])
fut = self.create_future()
fd = sock.fileno()
fut.add_done_callback(
functools.partial(self._sock_write_done, fd))
self.add_writer(fd, self._sock_sendall, fut, sock, data)
return await fut
def _sock_sendall(self, fut, registered_fd, sock, data):
if registered_fd is not None:
self.remove_writer(registered_fd)
if fut.cancelled():
def _sock_sendall(self, fut, sock, data):
if fut.done():
# Future cancellation can be scheduled on previous loop iteration
return
try:
n = sock.send(data)
except (BlockingIOError, InterruptedError):
n = 0
return
except Exception as exc:
fut.set_exception(exc)
return
......@@ -452,10 +464,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
if n == len(data):
fut.set_result(None)
else:
if n:
data = data[n:]
fd = sock.fileno()
self.add_writer(fd, self._sock_sendall, fut, fd, sock, data)
del data[:n]
async def sock_connect(self, sock, address):
"""Connect to a remote socket at address.
......@@ -484,18 +493,18 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
# becomes writable to be notified when the connection succeed or
# fails.
fut.add_done_callback(
functools.partial(self._sock_connect_done, fd))
functools.partial(self._sock_write_done, fd))
self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
except Exception as exc:
fut.set_exception(exc)
else:
fut.set_result(None)
def _sock_connect_done(self, fd, fut):
def _sock_write_done(self, fd, fut):
self.remove_writer(fd)
def _sock_connect_cb(self, fut, sock, address):
if fut.cancelled():
if fut.done():
return
try:
......@@ -529,7 +538,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
fd = sock.fileno()
if registered:
self.remove_reader(fd)
if fut.cancelled():
if fut.done():
return
try:
conn, address = sock.accept()
......
......@@ -2,6 +2,7 @@ import socket
import asyncio
import sys
from asyncio import proactor_events
from itertools import cycle, islice
from test.test_asyncio import utils as test_utils
from test import support
......@@ -120,6 +121,110 @@ class BaseSockTestsMixin:
sock = socket.socket()
self._basetest_sock_recv_into(httpd, sock)
async def _basetest_huge_content(self, address):
sock = socket.socket()
sock.setblocking(False)
DATA_SIZE = 10_000_00
chunk = b'0123456789' * (DATA_SIZE // 10)
await self.loop.sock_connect(sock, address)
await self.loop.sock_sendall(sock,
(b'POST /loop HTTP/1.0\r\n' +
b'Content-Length: %d\r\n' % DATA_SIZE +
b'\r\n'))
task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))
data = await self.loop.sock_recv(sock, DATA_SIZE)
# HTTP headers size is less than MTU,
# they are sent by the first packet always
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
while data.find(b'\r\n\r\n') == -1:
data += await self.loop.sock_recv(sock, DATA_SIZE)
# Strip headers
headers = data[:data.index(b'\r\n\r\n') + 4]
data = data[len(headers):]
size = DATA_SIZE
checker = cycle(b'0123456789')
expected = bytes(islice(checker, len(data)))
self.assertEqual(data, expected)
size -= len(data)
while True:
data = await self.loop.sock_recv(sock, DATA_SIZE)
if not data:
break
expected = bytes(islice(checker, len(data)))
self.assertEqual(data, expected)
size -= len(data)
self.assertEqual(size, 0)
await task
sock.close()
def test_huge_content(self):
with test_utils.run_test_server() as httpd:
self.loop.run_until_complete(
self._basetest_huge_content(httpd.address))
async def _basetest_huge_content_recvinto(self, address):
sock = socket.socket()
sock.setblocking(False)
DATA_SIZE = 10_000_00
chunk = b'0123456789' * (DATA_SIZE // 10)
await self.loop.sock_connect(sock, address)
await self.loop.sock_sendall(sock,
(b'POST /loop HTTP/1.0\r\n' +
b'Content-Length: %d\r\n' % DATA_SIZE +
b'\r\n'))
task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))
array = bytearray(DATA_SIZE)
buf = memoryview(array)
nbytes = await self.loop.sock_recv_into(sock, buf)
data = bytes(buf[:nbytes])
# HTTP headers size is less than MTU,
# they are sent by the first packet always
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
while data.find(b'\r\n\r\n') == -1:
nbytes = await self.loop.sock_recv_into(sock, buf)
data = bytes(buf[:nbytes])
# Strip headers
headers = data[:data.index(b'\r\n\r\n') + 4]
data = data[len(headers):]
size = DATA_SIZE
checker = cycle(b'0123456789')
expected = bytes(islice(checker, len(data)))
self.assertEqual(data, expected)
size -= len(data)
while True:
nbytes = await self.loop.sock_recv_into(sock, buf)
data = buf[:nbytes]
if not data:
break
expected = bytes(islice(checker, len(data)))
self.assertEqual(data, expected)
size -= len(data)
self.assertEqual(size, 0)
await task
sock.close()
def test_huge_content_recvinto(self):
with test_utils.run_test_server() as httpd:
self.loop.run_until_complete(
self._basetest_huge_content_recvinto(httpd.address))
@support.skip_unless_bind_unix_socket
def test_unix_sock_client_ops(self):
with test_utils.run_test_unix_server() as httpd:
......
......@@ -180,11 +180,21 @@ class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
def loop(environ):
size = int(environ['CONTENT_LENGTH'])
while size:
data = environ['wsgi.input'].read(min(size, 0x10000))
yield data
size -= len(data)
def app(environ, start_response):
status = '200 OK'
headers = [('Content-type', 'text/plain')]
start_response(status, headers)
return [b'Test message']
if environ['PATH_INFO'] == '/loop':
return loop(environ)
else:
return [b'Test message']
# Run the test WSGI server in a separate thread in order not to
# interfere with event handling in the main thread
......
Use add_done_callback() in sock_* asyncio API to unsubscribe reader/writer
early on calcellation.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment