Kaydet (Commit) d58972c2 authored tarafından Serhiy Storchaka's avatar Serhiy Storchaka

Merge heads

...@@ -186,6 +186,11 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -186,6 +186,11 @@ class BaseEventLoop(events.AbstractEventLoop):
self.call_soon(_raise_stop_error) self.call_soon(_raise_stop_error)
def close(self): def close(self):
"""Close the event loop.
This clears the queues and shuts down the executor,
but does not wait for the executor to finish.
"""
self._ready.clear() self._ready.clear()
self._scheduled.clear() self._scheduled.clear()
executor = self._default_executor executor = self._default_executor
...@@ -275,8 +280,27 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -275,8 +280,27 @@ class BaseEventLoop(events.AbstractEventLoop):
@tasks.coroutine @tasks.coroutine
def create_connection(self, protocol_factory, host=None, port=None, *, def create_connection(self, protocol_factory, host=None, port=None, *,
ssl=None, family=0, proto=0, flags=0, sock=None, ssl=None, family=0, proto=0, flags=0, sock=None,
local_addr=None): local_addr=None, server_hostname=None):
"""XXX""" """XXX"""
if server_hostname is not None and not ssl:
raise ValueError('server_hostname is only meaningful with ssl')
if server_hostname is None and ssl:
# Use host as default for server_hostname. It is an error
# if host is empty or not set, e.g. when an
# already-connected socket was passed or when only a port
# is given. To avoid this error, you can pass
# server_hostname='' -- this will bypass the hostname
# check. (This also means that if host is a numeric
# IP/IPv6 address, we will attempt to verify that exact
# address; this will probably fail, but it is possible to
# create a certificate for a specific IP address, so we
# don't judge it here.)
if not host:
raise ValueError('You must set server_hostname '
'when using ssl without a host')
server_hostname = host
if host is not None or port is not None: if host is not None or port is not None:
if sock is not None: if sock is not None:
raise ValueError( raise ValueError(
...@@ -357,7 +381,7 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -357,7 +381,7 @@ class BaseEventLoop(events.AbstractEventLoop):
sslcontext = None if isinstance(ssl, bool) else ssl sslcontext = None if isinstance(ssl, bool) else ssl
transport = self._make_ssl_transport( transport = self._make_ssl_transport(
sock, protocol, sslcontext, waiter, sock, protocol, sslcontext, waiter,
server_side=False, server_hostname=host) server_side=False, server_hostname=server_hostname)
else: else:
transport = self._make_socket_transport(sock, protocol, waiter) transport = self._make_socket_transport(sock, protocol, waiter)
...@@ -442,6 +466,8 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -442,6 +466,8 @@ class BaseEventLoop(events.AbstractEventLoop):
ssl=None, ssl=None,
reuse_address=None): reuse_address=None):
"""XXX""" """XXX"""
if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None')
if host is not None or port is not None: if host is not None or port is not None:
if sock is not None: if sock is not None:
raise ValueError( raise ValueError(
......
"""Constants.""" """Constants."""
# After the connection is lost, log warnings after this many write()s.
LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5
# Seconds to wait before retrying accept().
ACCEPT_RETRY_DELAY = 1
...@@ -137,6 +137,17 @@ class AbstractEventLoop: ...@@ -137,6 +137,17 @@ class AbstractEventLoop:
"""Return whether the event loop is currently running.""" """Return whether the event loop is currently running."""
raise NotImplementedError raise NotImplementedError
def close(self):
"""Close the loop.
The loop should not be running.
This is idempotent and irreversible.
No other methods should be called after this one.
"""
raise NotImplementedError
# Methods scheduling callbacks. All these return Handles. # Methods scheduling callbacks. All these return Handles.
def call_soon(self, callback, *args): def call_soon(self, callback, *args):
...@@ -172,7 +183,7 @@ class AbstractEventLoop: ...@@ -172,7 +183,7 @@ class AbstractEventLoop:
def create_connection(self, protocol_factory, host=None, port=None, *, def create_connection(self, protocol_factory, host=None, port=None, *,
ssl=None, family=0, proto=0, flags=0, sock=None, ssl=None, family=0, proto=0, flags=0, sock=None,
local_addr=None): local_addr=None, server_hostname=None):
raise NotImplementedError raise NotImplementedError
def create_server(self, protocol_factory, host=None, port=None, *, def create_server(self, protocol_factory, host=None, port=None, *,
...@@ -214,6 +225,8 @@ class AbstractEventLoop: ...@@ -214,6 +225,8 @@ class AbstractEventLoop:
family=0, proto=0, flags=0): family=0, proto=0, flags=0):
raise NotImplementedError raise NotImplementedError
# Pipes and subprocesses.
def connect_read_pipe(self, protocol_factory, pipe): def connect_read_pipe(self, protocol_factory, pipe):
"""Register read pipe in eventloop. """Register read pipe in eventloop.
......
This diff is collapsed.
...@@ -62,8 +62,9 @@ class CoroWrapper: ...@@ -62,8 +62,9 @@ class CoroWrapper:
code = func.__code__ code = func.__code__
filename = code.co_filename filename = code.co_filename
lineno = code.co_firstlineno lineno = code.co_firstlineno
logger.error('Coroutine %r defined at %s:%s was never yielded from', logger.error(
func.__name__, filename, lineno) 'Coroutine %r defined at %s:%s was never yielded from',
func.__name__, filename, lineno)
def coroutine(func): def coroutine(func):
......
...@@ -138,6 +138,7 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop): ...@@ -138,6 +138,7 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop):
@tasks.coroutine @tasks.coroutine
def start_serving_pipe(self, protocol_factory, address): def start_serving_pipe(self, protocol_factory, address):
server = PipeServer(address) server = PipeServer(address)
def loop(f=None): def loop(f=None):
pipe = None pipe = None
try: try:
...@@ -160,6 +161,7 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop): ...@@ -160,6 +161,7 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop):
pipe.close() pipe.close()
else: else:
f.add_done_callback(loop) f.add_done_callback(loop)
self.call_soon(loop) self.call_soon(loop)
return [server] return [server]
...@@ -209,6 +211,7 @@ class IocpProactor: ...@@ -209,6 +211,7 @@ class IocpProactor:
ov.WSARecv(conn.fileno(), nbytes, flags) ov.WSARecv(conn.fileno(), nbytes, flags)
else: else:
ov.ReadFile(conn.fileno(), nbytes) ov.ReadFile(conn.fileno(), nbytes)
def finish(trans, key, ov): def finish(trans, key, ov):
try: try:
return ov.getresult() return ov.getresult()
...@@ -217,6 +220,7 @@ class IocpProactor: ...@@ -217,6 +220,7 @@ class IocpProactor:
raise ConnectionResetError(*exc.args) raise ConnectionResetError(*exc.args)
else: else:
raise raise
return self._register(ov, conn, finish) return self._register(ov, conn, finish)
def send(self, conn, buf, flags=0): def send(self, conn, buf, flags=0):
...@@ -226,6 +230,7 @@ class IocpProactor: ...@@ -226,6 +230,7 @@ class IocpProactor:
ov.WSASend(conn.fileno(), buf, flags) ov.WSASend(conn.fileno(), buf, flags)
else: else:
ov.WriteFile(conn.fileno(), buf) ov.WriteFile(conn.fileno(), buf)
def finish(trans, key, ov): def finish(trans, key, ov):
try: try:
return ov.getresult() return ov.getresult()
...@@ -234,6 +239,7 @@ class IocpProactor: ...@@ -234,6 +239,7 @@ class IocpProactor:
raise ConnectionResetError(*exc.args) raise ConnectionResetError(*exc.args)
else: else:
raise raise
return self._register(ov, conn, finish) return self._register(ov, conn, finish)
def accept(self, listener): def accept(self, listener):
...@@ -241,6 +247,7 @@ class IocpProactor: ...@@ -241,6 +247,7 @@ class IocpProactor:
conn = self._get_accept_socket(listener.family) conn = self._get_accept_socket(listener.family)
ov = _overlapped.Overlapped(NULL) ov = _overlapped.Overlapped(NULL)
ov.AcceptEx(listener.fileno(), conn.fileno()) ov.AcceptEx(listener.fileno(), conn.fileno())
def finish_accept(trans, key, ov): def finish_accept(trans, key, ov):
ov.getresult() ov.getresult()
# Use SO_UPDATE_ACCEPT_CONTEXT so getsockname() etc work. # Use SO_UPDATE_ACCEPT_CONTEXT so getsockname() etc work.
...@@ -249,6 +256,7 @@ class IocpProactor: ...@@ -249,6 +256,7 @@ class IocpProactor:
_overlapped.SO_UPDATE_ACCEPT_CONTEXT, buf) _overlapped.SO_UPDATE_ACCEPT_CONTEXT, buf)
conn.settimeout(listener.gettimeout()) conn.settimeout(listener.gettimeout())
return conn, conn.getpeername() return conn, conn.getpeername()
return self._register(ov, listener, finish_accept) return self._register(ov, listener, finish_accept)
def connect(self, conn, address): def connect(self, conn, address):
...@@ -264,26 +272,31 @@ class IocpProactor: ...@@ -264,26 +272,31 @@ class IocpProactor:
raise raise
ov = _overlapped.Overlapped(NULL) ov = _overlapped.Overlapped(NULL)
ov.ConnectEx(conn.fileno(), address) ov.ConnectEx(conn.fileno(), address)
def finish_connect(trans, key, ov): def finish_connect(trans, key, ov):
ov.getresult() ov.getresult()
# Use SO_UPDATE_CONNECT_CONTEXT so getsockname() etc work. # Use SO_UPDATE_CONNECT_CONTEXT so getsockname() etc work.
conn.setsockopt(socket.SOL_SOCKET, conn.setsockopt(socket.SOL_SOCKET,
_overlapped.SO_UPDATE_CONNECT_CONTEXT, 0) _overlapped.SO_UPDATE_CONNECT_CONTEXT, 0)
return conn return conn
return self._register(ov, conn, finish_connect) return self._register(ov, conn, finish_connect)
def accept_pipe(self, pipe): def accept_pipe(self, pipe):
self._register_with_iocp(pipe) self._register_with_iocp(pipe)
ov = _overlapped.Overlapped(NULL) ov = _overlapped.Overlapped(NULL)
ov.ConnectNamedPipe(pipe.fileno()) ov.ConnectNamedPipe(pipe.fileno())
def finish(trans, key, ov): def finish(trans, key, ov):
ov.getresult() ov.getresult()
return pipe return pipe
return self._register(ov, pipe, finish) return self._register(ov, pipe, finish)
def connect_pipe(self, address): def connect_pipe(self, address):
ov = _overlapped.Overlapped(NULL) ov = _overlapped.Overlapped(NULL)
ov.WaitNamedPipeAndConnect(address, self._iocp, ov.address) ov.WaitNamedPipeAndConnect(address, self._iocp, ov.address)
def finish(err, handle, ov): def finish(err, handle, ov):
# err, handle were arguments passed to PostQueuedCompletionStatus() # err, handle were arguments passed to PostQueuedCompletionStatus()
# in a function run in a thread pool. # in a function run in a thread pool.
...@@ -296,6 +309,7 @@ class IocpProactor: ...@@ -296,6 +309,7 @@ class IocpProactor:
raise OSError(0, msg, None, err) raise OSError(0, msg, None, err)
else: else:
return windows_utils.PipeHandle(handle) return windows_utils.PipeHandle(handle)
return self._register(ov, None, finish, wait_for_post=True) return self._register(ov, None, finish, wait_for_post=True)
def wait_for_handle(self, handle, timeout=None): def wait_for_handle(self, handle, timeout=None):
...@@ -432,8 +446,10 @@ class _WindowsSubprocessTransport(base_subprocess.BaseSubprocessTransport): ...@@ -432,8 +446,10 @@ class _WindowsSubprocessTransport(base_subprocess.BaseSubprocessTransport):
self._proc = windows_utils.Popen( self._proc = windows_utils.Popen(
args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr,
bufsize=bufsize, **kwargs) bufsize=bufsize, **kwargs)
def callback(f): def callback(f):
returncode = self._proc.poll() returncode = self._proc.poll()
self._process_exited(returncode) self._process_exited(returncode)
f = self._loop._proactor.wait_for_handle(int(self._proc._handle)) f = self._loop._proactor.wait_for_handle(int(self._proc._handle))
f.add_done_callback(callback) f.add_done_callback(callback)
...@@ -18,18 +18,18 @@ import _winapi ...@@ -18,18 +18,18 @@ import _winapi
__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] __all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle']
#
# Constants/globals # Constants/globals
#
BUFSIZE = 8192 BUFSIZE = 8192
PIPE = subprocess.PIPE PIPE = subprocess.PIPE
STDOUT = subprocess.STDOUT STDOUT = subprocess.STDOUT
_mmap_counter = itertools.count() _mmap_counter = itertools.count()
#
# Replacement for socket.socketpair() # Replacement for socket.socketpair()
#
def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
"""A socket pair usable as a self-pipe, for Windows. """A socket pair usable as a self-pipe, for Windows.
...@@ -57,9 +57,9 @@ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): ...@@ -57,9 +57,9 @@ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
lsock.close() lsock.close()
return (ssock, csock) return (ssock, csock)
#
# Replacement for os.pipe() using handles instead of fds # Replacement for os.pipe() using handles instead of fds
#
def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE):
"""Like os.pipe() but with overlapped support and using handles not fds.""" """Like os.pipe() but with overlapped support and using handles not fds."""
...@@ -105,9 +105,9 @@ def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): ...@@ -105,9 +105,9 @@ def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE):
_winapi.CloseHandle(h2) _winapi.CloseHandle(h2)
raise raise
#
# Wrapper for a pipe handle # Wrapper for a pipe handle
#
class PipeHandle: class PipeHandle:
"""Wrapper for an overlapped pipe handle which is vaguely file-object like. """Wrapper for an overlapped pipe handle which is vaguely file-object like.
...@@ -137,9 +137,9 @@ class PipeHandle: ...@@ -137,9 +137,9 @@ class PipeHandle:
def __exit__(self, t, v, tb): def __exit__(self, t, v, tb):
self.close() self.close()
#
# Replacement for subprocess.Popen using overlapped pipe handles # Replacement for subprocess.Popen using overlapped pipe handles
#
class Popen(subprocess.Popen): class Popen(subprocess.Popen):
"""Replacement for subprocess.Popen using overlapped pipe handles. """Replacement for subprocess.Popen using overlapped pipe handles.
......
"""Tests for base_events.py""" """Tests for base_events.py"""
import errno
import logging import logging
import socket import socket
import time import time
...@@ -8,6 +9,7 @@ import unittest.mock ...@@ -8,6 +9,7 @@ import unittest.mock
from test.support import find_unused_port, IPV6_ENABLED from test.support import find_unused_port, IPV6_ENABLED
from asyncio import base_events from asyncio import base_events
from asyncio import constants
from asyncio import events from asyncio import events
from asyncio import futures from asyncio import futures
from asyncio import protocols from asyncio import protocols
...@@ -442,6 +444,71 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -442,6 +444,71 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
self.assertRaises( self.assertRaises(
OSError, self.loop.run_until_complete, coro) OSError, self.loop.run_until_complete, coro)
def test_create_connection_ssl_server_hostname_default(self):
self.loop.getaddrinfo = unittest.mock.Mock()
def mock_getaddrinfo(*args, **kwds):
f = futures.Future(loop=self.loop)
f.set_result([(socket.AF_INET, socket.SOCK_STREAM,
socket.SOL_TCP, '', ('1.2.3.4', 80))])
return f
self.loop.getaddrinfo.side_effect = mock_getaddrinfo
self.loop.sock_connect = unittest.mock.Mock()
self.loop.sock_connect.return_value = ()
self.loop._make_ssl_transport = unittest.mock.Mock()
def mock_make_ssl_transport(sock, protocol, sslcontext, waiter,
**kwds):
waiter.set_result(None)
self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
ANY = unittest.mock.ANY
# First try the default server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True)
self.loop.run_until_complete(coro)
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='python.org')
# Next try an explicit server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
server_hostname='perl.com')
self.loop.run_until_complete(coro)
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='perl.com')
# Finally try an explicit empty server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
server_hostname='')
self.loop.run_until_complete(coro)
self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='')
def test_create_connection_no_ssl_server_hostname_errors(self):
# When not using ssl, server_hostname must be None.
coro = self.loop.create_connection(MyProto, 'python.org', 80,
server_hostname='')
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
coro = self.loop.create_connection(MyProto, 'python.org', 80,
server_hostname='python.org')
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
def test_create_connection_ssl_server_hostname_errors(self):
# When using ssl, server_hostname may be None if host is non-empty.
coro = self.loop.create_connection(MyProto, '', 80, ssl=True)
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
coro = self.loop.create_connection(MyProto, None, 80, ssl=True)
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
coro = self.loop.create_connection(MyProto, None, None,
ssl=True, sock=socket.socket())
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
def test_create_server_empty_host(self): def test_create_server_empty_host(self):
# if host is empty string use None instead # if host is empty string use None instead
host = object() host = object()
...@@ -585,11 +652,18 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -585,11 +652,18 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
def test_accept_connection_exception(self, m_log): def test_accept_connection_exception(self, m_log):
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
sock.accept.side_effect = OSError() sock.accept.side_effect = OSError(errno.EMFILE, 'Too many open files')
self.loop.remove_reader = unittest.mock.Mock()
self.loop.call_later = unittest.mock.Mock()
self.loop._accept_connection(MyProto, sock) self.loop._accept_connection(MyProto, sock)
self.assertTrue(sock.close.called)
self.assertTrue(m_log.exception.called) self.assertTrue(m_log.exception.called)
self.assertFalse(sock.close.called)
self.loop.remove_reader.assert_called_with(10)
self.loop.call_later.assert_called_with(constants.ACCEPT_RETRY_DELAY,
# self.loop._start_serving
unittest.mock.ANY,
MyProto, sock, None, None)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -1276,7 +1276,6 @@ if sys.platform == 'win32': ...@@ -1276,7 +1276,6 @@ if sys.platform == 'win32':
def create_event_loop(self): def create_event_loop(self):
return windows_events.SelectorEventLoop() return windows_events.SelectorEventLoop()
class ProactorEventLoopTests(EventLoopTestsMixin, class ProactorEventLoopTests(EventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): unittest.TestCase):
...@@ -1471,6 +1470,8 @@ class AbstractEventLoopTests(unittest.TestCase): ...@@ -1471,6 +1470,8 @@ class AbstractEventLoopTests(unittest.TestCase):
NotImplementedError, loop.stop) NotImplementedError, loop.stop)
self.assertRaises( self.assertRaises(
NotImplementedError, loop.is_running) NotImplementedError, loop.is_running)
self.assertRaises(
NotImplementedError, loop.close)
self.assertRaises( self.assertRaises(
NotImplementedError, loop.call_later, None, None) NotImplementedError, loop.call_later, None, None)
self.assertRaises( self.assertRaises(
......
...@@ -77,7 +77,7 @@ class ProactorTests(unittest.TestCase): ...@@ -77,7 +77,7 @@ class ProactorTests(unittest.TestCase):
stream_reader = streams.StreamReader(loop=self.loop) stream_reader = streams.StreamReader(loop=self.loop)
protocol = streams.StreamReaderProtocol(stream_reader) protocol = streams.StreamReaderProtocol(stream_reader)
trans, proto = yield from self.loop.create_pipe_connection( trans, proto = yield from self.loop.create_pipe_connection(
lambda:protocol, ADDRESS) lambda: protocol, ADDRESS)
self.assertIsInstance(trans, transports.Transport) self.assertIsInstance(trans, transports.Transport)
self.assertEqual(protocol, proto) self.assertEqual(protocol, proto)
clients.append((stream_reader, trans)) clients.append((stream_reader, trans))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment