Kaydet (Commit) f7686c1f authored tarafından Neil Aspinall's avatar Neil Aspinall Kaydeden (comit) Andrew Svetlov

bpo-29970: Add timeout for SSL handshake in asyncio

10 seconds by default.
üst 4b965930
...@@ -261,7 +261,7 @@ Tasks ...@@ -261,7 +261,7 @@ Tasks
Creating connections Creating connections
-------------------- --------------------
.. coroutinemethod:: AbstractEventLoop.create_connection(protocol_factory, host=None, port=None, \*, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None) .. coroutinemethod:: AbstractEventLoop.create_connection(protocol_factory, host=None, port=None, \*, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, ssl_handshake_timeout=10.0)
Create a streaming transport connection to a given Internet *host* and Create a streaming transport connection to a given Internet *host* and
*port*: socket family :py:data:`~socket.AF_INET` or *port*: socket family :py:data:`~socket.AF_INET` or
...@@ -325,6 +325,13 @@ Creating connections ...@@ -325,6 +325,13 @@ Creating connections
to bind the socket to locally. The *local_host* and *local_port* to bind the socket to locally. The *local_host* and *local_port*
are looked up using getaddrinfo(), similarly to *host* and *port*. are looked up using getaddrinfo(), similarly to *host* and *port*.
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds
to wait for the SSL handshake to complete before aborting the connection.
.. versionadded:: 3.7
The *ssl_handshake_timeout* parameter.
.. versionchanged:: 3.5 .. versionchanged:: 3.5
On Windows with :class:`ProactorEventLoop`, SSL/TLS is now supported. On Windows with :class:`ProactorEventLoop`, SSL/TLS is now supported.
...@@ -386,7 +393,7 @@ Creating connections ...@@ -386,7 +393,7 @@ Creating connections
:ref:`UDP echo server protocol <asyncio-udp-echo-server-protocol>` examples. :ref:`UDP echo server protocol <asyncio-udp-echo-server-protocol>` examples.
.. coroutinemethod:: AbstractEventLoop.create_unix_connection(protocol_factory, path=None, \*, ssl=None, sock=None, server_hostname=None) .. coroutinemethod:: AbstractEventLoop.create_unix_connection(protocol_factory, path=None, \*, ssl=None, sock=None, server_hostname=None, ssl_handshake_timeout=10.0)
Create UNIX connection: socket family :py:data:`~socket.AF_UNIX`, socket Create UNIX connection: socket family :py:data:`~socket.AF_UNIX`, socket
type :py:data:`~socket.SOCK_STREAM`. The :py:data:`~socket.AF_UNIX` socket type :py:data:`~socket.SOCK_STREAM`. The :py:data:`~socket.AF_UNIX` socket
...@@ -404,6 +411,10 @@ Creating connections ...@@ -404,6 +411,10 @@ Creating connections
Availability: UNIX. Availability: UNIX.
.. versionadded:: 3.7
The *ssl_handshake_timeout* parameter.
.. versionchanged:: 3.7 .. versionchanged:: 3.7
The *path* parameter can now be a :class:`~pathlib.Path` object. The *path* parameter can now be a :class:`~pathlib.Path` object.
...@@ -412,7 +423,7 @@ Creating connections ...@@ -412,7 +423,7 @@ Creating connections
Creating listening connections Creating listening connections
------------------------------ ------------------------------
.. coroutinemethod:: AbstractEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None) .. coroutinemethod:: AbstractEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None, ssl_handshake_timeout=10.0)
Create a TCP server (socket type :data:`~socket.SOCK_STREAM`) bound to Create a TCP server (socket type :data:`~socket.SOCK_STREAM`) bound to
*host* and *port*. *host* and *port*.
...@@ -456,6 +467,13 @@ Creating listening connections ...@@ -456,6 +467,13 @@ Creating listening connections
set this flag when being created. This option is not supported on set this flag when being created. This option is not supported on
Windows. Windows.
* *ssl_handshake_timeout* is (for an SSL server) the time in seconds to wait
for the SSL handshake to complete before aborting the connection.
.. versionadded:: 3.7
The *ssl_handshake_timeout* parameter.
.. versionchanged:: 3.5 .. versionchanged:: 3.5
On Windows with :class:`ProactorEventLoop`, SSL/TLS is now supported. On Windows with :class:`ProactorEventLoop`, SSL/TLS is now supported.
...@@ -470,7 +488,7 @@ Creating listening connections ...@@ -470,7 +488,7 @@ Creating listening connections
The *host* parameter can now be a sequence of strings. The *host* parameter can now be a sequence of strings.
.. coroutinemethod:: AbstractEventLoop.create_unix_server(protocol_factory, path=None, \*, sock=None, backlog=100, ssl=None) .. coroutinemethod:: AbstractEventLoop.create_unix_server(protocol_factory, path=None, \*, sock=None, backlog=100, ssl=None, ssl_handshake_timeout=10.0)
Similar to :meth:`AbstractEventLoop.create_server`, but specific to the Similar to :meth:`AbstractEventLoop.create_server`, but specific to the
socket family :py:data:`~socket.AF_UNIX`. socket family :py:data:`~socket.AF_UNIX`.
...@@ -481,11 +499,15 @@ Creating listening connections ...@@ -481,11 +499,15 @@ Creating listening connections
Availability: UNIX. Availability: UNIX.
.. versionadded:: 3.7
The *ssl_handshake_timeout* parameter.
.. versionchanged:: 3.7 .. versionchanged:: 3.7
The *path* parameter can now be a :class:`~pathlib.Path` object. The *path* parameter can now be a :class:`~pathlib.Path` object.
.. coroutinemethod:: BaseEventLoop.connect_accepted_socket(protocol_factory, sock, \*, ssl=None) .. coroutinemethod:: BaseEventLoop.connect_accepted_socket(protocol_factory, sock, \*, ssl=None, ssl_handshake_timeout=10.0)
Handle an accepted connection. Handle an accepted connection.
...@@ -500,8 +522,15 @@ Creating listening connections ...@@ -500,8 +522,15 @@ Creating listening connections
* *ssl* can be set to an :class:`~ssl.SSLContext` to enable SSL over the * *ssl* can be set to an :class:`~ssl.SSLContext` to enable SSL over the
accepted connections. accepted connections.
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to
wait for the SSL handshake to complete before aborting the connection.
When completed it returns a ``(transport, protocol)`` pair. When completed it returns a ``(transport, protocol)`` pair.
.. versionadded:: 3.7
The *ssl_handshake_timeout* parameter.
.. versionadded:: 3.5.3 .. versionadded:: 3.5.3
......
...@@ -29,6 +29,7 @@ import sys ...@@ -29,6 +29,7 @@ import sys
import warnings import warnings
import weakref import weakref
from . import constants
from . import coroutines from . import coroutines
from . import events from . import events
from . import futures from . import futures
...@@ -275,9 +276,11 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -275,9 +276,11 @@ class BaseEventLoop(events.AbstractEventLoop):
"""Create socket transport.""" """Create socket transport."""
raise NotImplementedError raise NotImplementedError
def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None, *, server_side=False, server_hostname=None,
extra=None, server=None): extra=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
"""Create SSL transport.""" """Create SSL transport."""
raise NotImplementedError raise NotImplementedError
...@@ -635,10 +638,12 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -635,10 +638,12 @@ class BaseEventLoop(events.AbstractEventLoop):
return await self.run_in_executor( return await self.run_in_executor(
None, socket.getnameinfo, sockaddr, flags) None, socket.getnameinfo, sockaddr, flags)
async def create_connection(self, protocol_factory, host=None, port=None, async def create_connection(
self, protocol_factory, host=None, port=None,
*, ssl=None, family=0, *, ssl=None, family=0,
proto=0, flags=0, sock=None, proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None): local_addr=None, server_hostname=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
"""Connect to a TCP server. """Connect to a TCP server.
Create a streaming transport connection to a given Internet host and Create a streaming transport connection to a given Internet host and
...@@ -751,7 +756,8 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -751,7 +756,8 @@ class BaseEventLoop(events.AbstractEventLoop):
f'A Stream Socket was expected, got {sock!r}') f'A Stream Socket was expected, got {sock!r}')
transport, protocol = await self._create_connection_transport( transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname) sock, protocol_factory, ssl, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
if self._debug: if self._debug:
# Get the socket from the transport because SSL transport closes # Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket # the old socket and creates a new SSL socket
...@@ -760,8 +766,10 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -760,8 +766,10 @@ class BaseEventLoop(events.AbstractEventLoop):
sock, host, port, transport, protocol) sock, host, port, transport, protocol)
return transport, protocol return transport, protocol
async def _create_connection_transport(self, sock, protocol_factory, ssl, async def _create_connection_transport(
server_hostname, server_side=False): self, sock, protocol_factory, ssl,
server_hostname, server_side=False,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
sock.setblocking(False) sock.setblocking(False)
...@@ -771,7 +779,8 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -771,7 +779,8 @@ class BaseEventLoop(events.AbstractEventLoop):
sslcontext = None if isinstance(ssl, bool) else ssl sslcontext = None if isinstance(ssl, bool) else ssl
transport = self._make_ssl_transport( transport = self._make_ssl_transport(
sock, protocol, sslcontext, waiter, sock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname) server_side=server_side, server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
else: else:
transport = self._make_socket_transport(sock, protocol, waiter) transport = self._make_socket_transport(sock, protocol, waiter)
...@@ -929,7 +938,8 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -929,7 +938,8 @@ class BaseEventLoop(events.AbstractEventLoop):
raise OSError(f'getaddrinfo({host!r}) returned empty list') raise OSError(f'getaddrinfo({host!r}) returned empty list')
return infos return infos
async def create_server(self, protocol_factory, host=None, port=None, async def create_server(
self, protocol_factory, host=None, port=None,
*, *,
family=socket.AF_UNSPEC, family=socket.AF_UNSPEC,
flags=socket.AI_PASSIVE, flags=socket.AI_PASSIVE,
...@@ -937,7 +947,8 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -937,7 +947,8 @@ class BaseEventLoop(events.AbstractEventLoop):
backlog=100, backlog=100,
ssl=None, ssl=None,
reuse_address=None, reuse_address=None,
reuse_port=None): reuse_port=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
"""Create a TCP server. """Create a TCP server.
The host parameter can be a string, in that case the TCP server is The host parameter can be a string, in that case the TCP server is
...@@ -1026,13 +1037,16 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -1026,13 +1037,16 @@ class BaseEventLoop(events.AbstractEventLoop):
for sock in sockets: for sock in sockets:
sock.listen(backlog) sock.listen(backlog)
sock.setblocking(False) sock.setblocking(False)
self._start_serving(protocol_factory, sock, ssl, server, backlog) self._start_serving(protocol_factory, sock, ssl, server, backlog,
ssl_handshake_timeout)
if self._debug: if self._debug:
logger.info("%r is serving", server) logger.info("%r is serving", server)
return server return server
async def connect_accepted_socket(self, protocol_factory, sock, async def connect_accepted_socket(
*, ssl=None): self, protocol_factory, sock,
*, ssl=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
"""Handle an accepted connection. """Handle an accepted connection.
This is used by servers that accept connections outside of This is used by servers that accept connections outside of
...@@ -1045,7 +1059,8 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -1045,7 +1059,8 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError(f'A Stream Socket was expected, got {sock!r}') raise ValueError(f'A Stream Socket was expected, got {sock!r}')
transport, protocol = await self._create_connection_transport( transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True) sock, protocol_factory, ssl, '', server_side=True,
ssl_handshake_timeout=ssl_handshake_timeout)
if self._debug: if self._debug:
# Get the socket from the transport because SSL transport closes # Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket # the old socket and creates a new SSL socket
......
...@@ -8,3 +8,6 @@ ACCEPT_RETRY_DELAY = 1 ...@@ -8,3 +8,6 @@ ACCEPT_RETRY_DELAY = 1
# The larger the number, the slower the operation in debug mode # The larger the number, the slower the operation in debug mode
# (see extract_stack() in format_helpers.py). # (see extract_stack() in format_helpers.py).
DEBUG_STACK_DEPTH = 10 DEBUG_STACK_DEPTH = 10
# Number of seconds to wait for SSL handshake to complete
SSL_HANDSHAKE_TIMEOUT = 10.0
...@@ -250,16 +250,20 @@ class AbstractEventLoop: ...@@ -250,16 +250,20 @@ class AbstractEventLoop:
async def getnameinfo(self, sockaddr, flags=0): async def getnameinfo(self, sockaddr, flags=0):
raise NotImplementedError raise NotImplementedError
async def create_connection(self, protocol_factory, host=None, port=None, async def create_connection(
self, protocol_factory, host=None, port=None,
*, ssl=None, family=0, proto=0, *, ssl=None, family=0, proto=0,
flags=0, sock=None, local_addr=None, flags=0, sock=None, local_addr=None,
server_hostname=None): server_hostname=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
raise NotImplementedError raise NotImplementedError
async def create_server(self, protocol_factory, host=None, port=None, async def create_server(
self, protocol_factory, host=None, port=None,
*, family=socket.AF_UNSPEC, *, family=socket.AF_UNSPEC,
flags=socket.AI_PASSIVE, sock=None, backlog=100, flags=socket.AI_PASSIVE, sock=None, backlog=100,
ssl=None, reuse_address=None, reuse_port=None): ssl=None, reuse_address=None, reuse_port=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
"""A coroutine which creates a TCP server bound to host and port. """A coroutine which creates a TCP server bound to host and port.
The return value is a Server object which can be used to stop The return value is a Server object which can be used to stop
...@@ -294,16 +298,25 @@ class AbstractEventLoop: ...@@ -294,16 +298,25 @@ class AbstractEventLoop:
the same port as other existing endpoints are bound to, so long as the same port as other existing endpoints are bound to, so long as
they all set this flag when being created. This option is not they all set this flag when being created. This option is not
supported on Windows. supported on Windows.
ssl_handshake_timeout is the time in seconds that an SSL server
will wait for completion of the SSL handshake before aborting the
connection. Default is 10s, longer timeouts may increase vulnerability
to DoS attacks (see https://support.f5.com/csp/article/K13834)
""" """
raise NotImplementedError raise NotImplementedError
async def create_unix_connection(self, protocol_factory, path=None, *, async def create_unix_connection(
self, protocol_factory, path=None, *,
ssl=None, sock=None, ssl=None, sock=None,
server_hostname=None): server_hostname=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
raise NotImplementedError raise NotImplementedError
async def create_unix_server(self, protocol_factory, path=None, *, async def create_unix_server(
sock=None, backlog=100, ssl=None): self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
"""A coroutine which creates a UNIX Domain Socket server. """A coroutine which creates a UNIX Domain Socket server.
The return value is a Server object, which can be used to stop The return value is a Server object, which can be used to stop
...@@ -320,6 +333,9 @@ class AbstractEventLoop: ...@@ -320,6 +333,9 @@ class AbstractEventLoop:
ssl can be set to an SSLContext to enable SSL over the ssl can be set to an SSLContext to enable SSL over the
accepted connections. accepted connections.
ssl_handshake_timeout is the time in seconds that an SSL server
will wait for the SSL handshake to complete (defaults to 10s).
""" """
raise NotImplementedError raise NotImplementedError
......
...@@ -389,11 +389,15 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): ...@@ -389,11 +389,15 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
return _ProactorSocketTransport(self, sock, protocol, waiter, return _ProactorSocketTransport(self, sock, protocol, waiter,
extra, server) extra, server)
def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None, *, server_side=False, server_hostname=None,
extra=None, server=None): extra=None, server=None,
ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter, ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
server_side, server_hostname) ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
_ProactorSocketTransport(self, rawsock, ssl_protocol, _ProactorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server) extra=extra, server=server)
return ssl_protocol._app_transport return ssl_protocol._app_transport
...@@ -486,7 +490,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): ...@@ -486,7 +490,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
self._csock.send(b'\0') self._csock.send(b'\0')
def _start_serving(self, protocol_factory, sock, def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100): sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
def loop(f=None): def loop(f=None):
try: try:
...@@ -499,7 +504,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): ...@@ -499,7 +504,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
if sslcontext is not None: if sslcontext is not None:
self._make_ssl_transport( self._make_ssl_transport(
conn, protocol, sslcontext, server_side=True, conn, protocol, sslcontext, server_side=True,
extra={'peername': addr}, server=server) extra={'peername': addr}, server=server,
ssl_handshake_timeout=ssl_handshake_timeout)
else: else:
self._make_socket_transport( self._make_socket_transport(
conn, protocol, conn, protocol,
......
...@@ -70,11 +70,15 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -70,11 +70,15 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
return _SelectorSocketTransport(self, sock, protocol, waiter, return _SelectorSocketTransport(self, sock, protocol, waiter,
extra, server) extra, server)
def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None, *, server_side=False, server_hostname=None,
extra=None, server=None): extra=None, server=None,
ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter, ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
server_side, server_hostname) ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
_SelectorSocketTransport(self, rawsock, ssl_protocol, _SelectorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server) extra=extra, server=server)
return ssl_protocol._app_transport return ssl_protocol._app_transport
...@@ -143,12 +147,16 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -143,12 +147,16 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
exc_info=True) exc_info=True)
def _start_serving(self, protocol_factory, sock, def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100): sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
self._add_reader(sock.fileno(), self._accept_connection, self._add_reader(sock.fileno(), self._accept_connection,
protocol_factory, sock, sslcontext, server, backlog) protocol_factory, sock, sslcontext, server, backlog,
ssl_handshake_timeout)
def _accept_connection(self, protocol_factory, sock, def _accept_connection(
sslcontext=None, server=None, backlog=100): self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
# This method is only called once for each event loop tick where the # This method is only called once for each event loop tick where the
# listening socket has triggered an EVENT_READ. There may be multiple # listening socket has triggered an EVENT_READ. There may be multiple
# connections waiting for an .accept() so it is called in a loop. # connections waiting for an .accept() so it is called in a loop.
...@@ -179,17 +187,20 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -179,17 +187,20 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
self.call_later(constants.ACCEPT_RETRY_DELAY, self.call_later(constants.ACCEPT_RETRY_DELAY,
self._start_serving, self._start_serving,
protocol_factory, sock, sslcontext, server, protocol_factory, sock, sslcontext, server,
backlog) backlog, ssl_handshake_timeout)
else: else:
raise # The event loop will catch, log and ignore it. raise # The event loop will catch, log and ignore it.
else: else:
extra = {'peername': addr} extra = {'peername': addr}
accept = self._accept_connection2( accept = self._accept_connection2(
protocol_factory, conn, extra, sslcontext, server) protocol_factory, conn, extra, sslcontext, server,
ssl_handshake_timeout)
self.create_task(accept) self.create_task(accept)
async def _accept_connection2(self, protocol_factory, conn, extra, async def _accept_connection2(
sslcontext=None, server=None): self, protocol_factory, conn, extra,
sslcontext=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
protocol = None protocol = None
transport = None transport = None
try: try:
...@@ -198,7 +209,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -198,7 +209,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
if sslcontext: if sslcontext:
transport = self._make_ssl_transport( transport = self._make_ssl_transport(
conn, protocol, sslcontext, waiter=waiter, conn, protocol, sslcontext, waiter=waiter,
server_side=True, extra=extra, server=server) server_side=True, extra=extra, server=server,
ssl_handshake_timeout=ssl_handshake_timeout)
else: else:
transport = self._make_socket_transport( transport = self._make_socket_transport(
conn, protocol, waiter=waiter, extra=extra, conn, protocol, waiter=waiter, extra=extra,
......
...@@ -6,6 +6,7 @@ except ImportError: # pragma: no cover ...@@ -6,6 +6,7 @@ except ImportError: # pragma: no cover
ssl = None ssl = None
from . import base_events from . import base_events
from . import constants
from . import protocols from . import protocols
from . import transports from . import transports
from .log import logger from .log import logger
...@@ -400,7 +401,8 @@ class SSLProtocol(protocols.Protocol): ...@@ -400,7 +401,8 @@ class SSLProtocol(protocols.Protocol):
def __init__(self, loop, app_protocol, sslcontext, waiter, def __init__(self, loop, app_protocol, sslcontext, waiter,
server_side=False, server_hostname=None, server_side=False, server_hostname=None,
call_connection_made=True): call_connection_made=True,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
if ssl is None: if ssl is None:
raise RuntimeError('stdlib ssl module not available') raise RuntimeError('stdlib ssl module not available')
...@@ -434,6 +436,7 @@ class SSLProtocol(protocols.Protocol): ...@@ -434,6 +436,7 @@ class SSLProtocol(protocols.Protocol):
# transport, ex: SelectorSocketTransport # transport, ex: SelectorSocketTransport
self._transport = None self._transport = None
self._call_connection_made = call_connection_made self._call_connection_made = call_connection_made
self._ssl_handshake_timeout = ssl_handshake_timeout
def _wakeup_waiter(self, exc=None): def _wakeup_waiter(self, exc=None):
if self._waiter is None: if self._waiter is None:
...@@ -561,9 +564,18 @@ class SSLProtocol(protocols.Protocol): ...@@ -561,9 +564,18 @@ class SSLProtocol(protocols.Protocol):
# the SSL handshake # the SSL handshake
self._write_backlog.append((b'', 1)) self._write_backlog.append((b'', 1))
self._loop.call_soon(self._process_write_backlog) self._loop.call_soon(self._process_write_backlog)
self._handshake_timeout_handle = \
self._loop.call_later(self._ssl_handshake_timeout,
self._check_handshake_timeout)
def _check_handshake_timeout(self):
if self._in_handshake is True:
logger.warning("%r stalled during handshake", self)
self._abort()
def _on_handshake_complete(self, handshake_exc): def _on_handshake_complete(self, handshake_exc):
self._in_handshake = False self._in_handshake = False
self._handshake_timeout_handle.cancel()
sslobj = self._sslpipe.ssl_object sslobj = self._sslpipe.ssl_object
try: try:
......
...@@ -192,9 +192,11 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): ...@@ -192,9 +192,11 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
def _child_watcher_callback(self, pid, returncode, transp): def _child_watcher_callback(self, pid, returncode, transp):
self.call_soon_threadsafe(transp._process_exited, returncode) self.call_soon_threadsafe(transp._process_exited, returncode)
async def create_unix_connection(self, protocol_factory, path=None, *, async def create_unix_connection(
self, protocol_factory, path=None, *,
ssl=None, sock=None, ssl=None, sock=None,
server_hostname=None): server_hostname=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
assert server_hostname is None or isinstance(server_hostname, str) assert server_hostname is None or isinstance(server_hostname, str)
if ssl: if ssl:
if server_hostname is None: if server_hostname is None:
...@@ -228,11 +230,14 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): ...@@ -228,11 +230,14 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
sock.setblocking(False) sock.setblocking(False)
transport, protocol = await self._create_connection_transport( transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname) sock, protocol_factory, ssl, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
return transport, protocol return transport, protocol
async def create_unix_server(self, protocol_factory, path=None, *, async def create_unix_server(
sock=None, backlog=100, ssl=None): self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
if isinstance(ssl, bool): if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None') raise TypeError('ssl argument must be an SSLContext or None')
...@@ -283,7 +288,8 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): ...@@ -283,7 +288,8 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
server = base_events.Server(self, [sock]) server = base_events.Server(self, [sock])
sock.listen(backlog) sock.listen(backlog)
sock.setblocking(False) sock.setblocking(False)
self._start_serving(protocol_factory, sock, ssl, server) self._start_serving(protocol_factory, sock, ssl, server,
ssl_handshake_timeout=ssl_handshake_timeout)
return server return server
......
...@@ -1301,34 +1301,45 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1301,34 +1301,45 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
ANY = mock.ANY ANY = mock.ANY
handshake_timeout = object()
# First try the default server_hostname. # First try the default server_hostname.
self.loop._make_ssl_transport.reset_mock() self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True) coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
ssl_handshake_timeout=handshake_timeout)
transport, _ = self.loop.run_until_complete(coro) transport, _ = self.loop.run_until_complete(coro)
transport.close() transport.close()
self.loop._make_ssl_transport.assert_called_with( self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY, ANY, ANY, ANY, ANY,
server_side=False, server_side=False,
server_hostname='python.org') server_hostname='python.org',
ssl_handshake_timeout=handshake_timeout)
# Next try an explicit server_hostname. # Next try an explicit server_hostname.
self.loop._make_ssl_transport.reset_mock() self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True, coro = self.loop.create_connection(
server_hostname='perl.com') MyProto, 'python.org', 80, ssl=True,
server_hostname='perl.com',
ssl_handshake_timeout=handshake_timeout)
transport, _ = self.loop.run_until_complete(coro) transport, _ = self.loop.run_until_complete(coro)
transport.close() transport.close()
self.loop._make_ssl_transport.assert_called_with( self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY, ANY, ANY, ANY, ANY,
server_side=False, server_side=False,
server_hostname='perl.com') server_hostname='perl.com',
ssl_handshake_timeout=handshake_timeout)
# Finally try an explicit empty server_hostname. # Finally try an explicit empty server_hostname.
self.loop._make_ssl_transport.reset_mock() self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True, coro = self.loop.create_connection(
server_hostname='') MyProto, 'python.org', 80, ssl=True,
server_hostname='',
ssl_handshake_timeout=handshake_timeout)
transport, _ = self.loop.run_until_complete(coro) transport, _ = self.loop.run_until_complete(coro)
transport.close() transport.close()
self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY, self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False, server_side=False,
server_hostname='') server_hostname='',
ssl_handshake_timeout=handshake_timeout)
def test_create_connection_no_ssl_server_hostname_errors(self): def test_create_connection_no_ssl_server_hostname_errors(self):
# When not using ssl, server_hostname must be None. # When not using ssl, server_hostname must be None.
...@@ -1687,7 +1698,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1687,7 +1698,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
constants.ACCEPT_RETRY_DELAY, constants.ACCEPT_RETRY_DELAY,
# self.loop._start_serving # self.loop._start_serving
mock.ANY, mock.ANY,
MyProto, sock, None, None, mock.ANY) MyProto, sock, None, None, mock.ANY, mock.ANY)
def test_call_coroutine(self): def test_call_coroutine(self):
@asyncio.coroutine @asyncio.coroutine
......
...@@ -11,6 +11,7 @@ except ImportError: ...@@ -11,6 +11,7 @@ except ImportError:
import asyncio import asyncio
from asyncio import log from asyncio import log
from asyncio import sslproto from asyncio import sslproto
from asyncio import tasks
from test.test_asyncio import utils as test_utils from test.test_asyncio import utils as test_utils
...@@ -25,7 +26,8 @@ class SslProtoHandshakeTests(test_utils.TestCase): ...@@ -25,7 +26,8 @@ class SslProtoHandshakeTests(test_utils.TestCase):
def ssl_protocol(self, waiter=None): def ssl_protocol(self, waiter=None):
sslcontext = test_utils.dummy_ssl_context() sslcontext = test_utils.dummy_ssl_context()
app_proto = asyncio.Protocol() app_proto = asyncio.Protocol()
proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter) proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
ssl_handshake_timeout=0.1)
self.assertIs(proto._app_transport.get_protocol(), app_proto) self.assertIs(proto._app_transport.get_protocol(), app_proto)
self.addCleanup(proto._app_transport.close) self.addCleanup(proto._app_transport.close)
return proto return proto
...@@ -63,6 +65,16 @@ class SslProtoHandshakeTests(test_utils.TestCase): ...@@ -63,6 +65,16 @@ class SslProtoHandshakeTests(test_utils.TestCase):
with test_utils.disable_logger(): with test_utils.disable_logger():
self.loop.run_until_complete(handshake_fut) self.loop.run_until_complete(handshake_fut)
def test_handshake_timeout(self):
# bpo-29970: Check that a connection is aborted if handshake is not
# completed in timeout period, instead of remaining open indefinitely
ssl_proto = self.ssl_protocol()
transport = self.connection_made(ssl_proto)
with test_utils.disable_logger():
self.loop.run_until_complete(tasks.sleep(0.2, loop=self.loop))
self.assertTrue(transport.abort.called)
def test_eof_received_waiter(self): def test_eof_received_waiter(self):
waiter = asyncio.Future(loop=self.loop) waiter = asyncio.Future(loop=self.loop)
ssl_proto = self.ssl_protocol(waiter) ssl_proto = self.ssl_protocol(waiter)
......
...@@ -63,6 +63,7 @@ Jeffrey Armstrong ...@@ -63,6 +63,7 @@ Jeffrey Armstrong
Jason Asbahr Jason Asbahr
David Ascher David Ascher
Ammar Askar Ammar Askar
Neil Aspinall
Chris AtLee Chris AtLee
Aymeric Augustin Aymeric Augustin
Cathy Avery Cathy Avery
......
Abort asyncio SSLProtocol connection if handshake not complete within 10s
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment