Kaydet (Commit) 88a5bf0b authored tarafından Yury Selivanov's avatar Yury Selivanov

asyncio: Add support for UNIX Domain Sockets.

üst c36e504c
......@@ -407,6 +407,13 @@ class BaseEventLoop(events.AbstractEventLoop):
sock.setblocking(False)
transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname)
return transport, protocol
@tasks.coroutine
def _create_connection_transport(self, sock, protocol_factory, ssl,
server_hostname):
protocol = protocol_factory()
waiter = futures.Future(loop=self)
if ssl:
......
......@@ -220,6 +220,32 @@ class AbstractEventLoop:
"""
raise NotImplementedError
def create_unix_connection(self, protocol_factory, path, *,
ssl=None, sock=None,
server_hostname=None):
raise NotImplementedError
def create_unix_server(self, protocol_factory, path, *,
sock=None, backlog=100, ssl=None):
"""A coroutine which creates a UNIX Domain Socket server.
The return valud is a Server object, which can be used to stop
the service.
path is a str, representing a file systsem path to bind the
server socket to.
sock can optionally be specified in order to use a preexisting
socket object.
backlog is the maximum number of queued connections passed to
listen() (defaults to 100).
ssl can be set to an SSLContext to enable SSL over the
accepted connections.
"""
raise NotImplementedError
def create_datagram_endpoint(self, protocol_factory,
local_addr=None, remote_addr=None, *,
family=0, proto=0, flags=0):
......
"""Stream-related things."""
__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
'open_connection', 'start_server', 'IncompleteReadError',
'open_connection', 'start_server',
'open_unix_connection', 'start_unix_server',
'IncompleteReadError',
]
import socket
from . import events
from . import futures
from . import protocols
......@@ -93,6 +97,39 @@ def start_server(client_connected_cb, host=None, port=None, *,
return (yield from loop.create_server(factory, host, port, **kwds))
if hasattr(socket, 'AF_UNIX'):
# UNIX Domain Sockets are supported on this platform
@tasks.coroutine
def open_unix_connection(path=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `open_connection` but works with UNIX Domain Sockets."""
if loop is None:
loop = events.get_event_loop()
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, loop=loop)
transport, _ = yield from loop.create_unix_connection(
lambda: protocol, path, **kwds)
writer = StreamWriter(transport, protocol, reader, loop)
return reader, writer
@tasks.coroutine
def start_unix_server(client_connected_cb, path=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `start_server` but works with UNIX Domain Sockets."""
if loop is None:
loop = events.get_event_loop()
def factory():
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, client_connected_cb,
loop=loop)
return protocol
return (yield from loop.create_unix_server(factory, path, **kwds))
class FlowControlMixin(protocols.Protocol):
"""Reusable flow control logic for StreamWriter.drain().
......
......@@ -4,12 +4,18 @@ import collections
import contextlib
import io
import os
import socket
import socketserver
import sys
import tempfile
import threading
import time
import unittest
import unittest.mock
from http.server import HTTPServer
from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
try:
import ssl
except ImportError: # pragma: no cover
......@@ -70,42 +76,51 @@ def run_once(loop):
loop.run_forever()
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
class SilentWSGIRequestHandler(WSGIRequestHandler):
class SilentWSGIRequestHandler(WSGIRequestHandler):
def get_stderr(self):
return io.StringIO()
def get_stderr(self):
return io.StringIO()
def log_message(self, format, *args):
pass
def log_message(self, format, *args):
pass
class SilentWSGIServer(WSGIServer):
def handle_error(self, request, client_address):
class SilentWSGIServer(WSGIServer):
def handle_error(self, request, client_address):
pass
class SSLWSGIServerMixin:
def finish_request(self, request, client_address):
# The relative location of our test directory (which
# contains the ssl key and certificate files) differs
# between the stdlib and stand-alone asyncio.
# Prefer our own if we can find it.
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
if not os.path.isdir(here):
here = os.path.join(os.path.dirname(os.__file__),
'test', 'test_asyncio')
keyfile = os.path.join(here, 'ssl_key.pem')
certfile = os.path.join(here, 'ssl_cert.pem')
ssock = ssl.wrap_socket(request,
keyfile=keyfile,
certfile=certfile,
server_side=True)
try:
self.RequestHandlerClass(ssock, client_address, self)
ssock.close()
except OSError:
# maybe socket has been closed by peer
pass
class SSLWSGIServer(SilentWSGIServer):
def finish_request(self, request, client_address):
# The relative location of our test directory (which
# contains the ssl key and certificate files) differs
# between the stdlib and stand-alone asyncio.
# Prefer our own if we can find it.
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
if not os.path.isdir(here):
here = os.path.join(os.path.dirname(os.__file__),
'test', 'test_asyncio')
keyfile = os.path.join(here, 'ssl_key.pem')
certfile = os.path.join(here, 'ssl_cert.pem')
ssock = ssl.wrap_socket(request,
keyfile=keyfile,
certfile=certfile,
server_side=True)
try:
self.RequestHandlerClass(ssock, client_address, self)
ssock.close()
except OSError:
# maybe socket has been closed by peer
pass
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
pass
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
def app(environ, start_response):
status = '200 OK'
......@@ -115,9 +130,9 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
# Run the test WSGI server in a separate thread in order not to
# interfere with event handling in the main thread
server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
httpd = make_server(host, port, app,
server_class, SilentWSGIRequestHandler)
server_class = server_ssl_cls if use_ssl else server_cls
httpd = server_class(address, SilentWSGIRequestHandler)
httpd.set_app(app)
httpd.address = httpd.server_address
server_thread = threading.Thread(target=httpd.serve_forever)
server_thread.start()
......@@ -129,6 +144,75 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
server_thread.join()
if hasattr(socket, 'AF_UNIX'):
class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
def server_bind(self):
socketserver.UnixStreamServer.server_bind(self)
self.server_name = '127.0.0.1'
self.server_port = 80
class UnixWSGIServer(UnixHTTPServer, WSGIServer):
def server_bind(self):
UnixHTTPServer.server_bind(self)
self.setup_environ()
def get_request(self):
request, client_addr = super().get_request()
# Code in the stdlib expects that get_request
# will return a socket and a tuple (host, port).
# However, this isn't true for UNIX sockets,
# as the second return value will be a path;
# hence we return some fake data sufficient
# to get the tests going
return request, ('127.0.0.1', '')
class SilentUnixWSGIServer(UnixWSGIServer):
def handle_error(self, request, client_address):
pass
class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
pass
def gen_unix_socket_path():
with tempfile.NamedTemporaryFile() as file:
return file.name
@contextlib.contextmanager
def unix_socket_path():
path = gen_unix_socket_path()
try:
yield path
finally:
try:
os.unlink(path)
except OSError:
pass
@contextlib.contextmanager
def run_test_unix_server(*, use_ssl=False):
with unix_socket_path() as path:
yield from _run_test_server(address=path, use_ssl=use_ssl,
server_cls=SilentUnixWSGIServer,
server_ssl_cls=UnixSSLWSGIServer)
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
server_cls=SilentWSGIServer,
server_ssl_cls=SSLWSGIServer)
def make_test_protocol(base):
dct = {}
for name in dir(base):
......@@ -275,5 +359,6 @@ class TestLoop(base_events.BaseEventLoop):
def _write_to_self(self):
pass
def MockCallback(**kwargs):
return unittest.mock.Mock(spec=['__call__'], **kwargs)
......@@ -11,6 +11,7 @@ import sys
import threading
from . import base_events
from . import base_subprocess
from . import constants
from . import events
......@@ -31,9 +32,9 @@ if sys.platform == 'win32': # pragma: no cover
class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
"""Unix event loop
"""Unix event loop.
Adds signal handling to SelectorEventLoop
Adds signal handling and UNIX Domain Socket support to SelectorEventLoop.
"""
def __init__(self, selector=None):
......@@ -164,6 +165,76 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
def _child_watcher_callback(self, pid, returncode, transp):
self.call_soon_threadsafe(transp._process_exited, returncode)
@tasks.coroutine
def create_unix_connection(self, protocol_factory, path, *,
ssl=None, sock=None,
server_hostname=None):
assert server_hostname is None or isinstance(server_hostname, str)
if ssl:
if server_hostname is None:
raise ValueError(
'you have to pass server_hostname when using ssl')
else:
if server_hostname is not None:
raise ValueError('server_hostname is only meaningful with ssl')
if path is not None:
if sock is not None:
raise ValueError(
'path and sock can not be specified at the same time')
try:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
sock.setblocking(False)
yield from self.sock_connect(sock, path)
except OSError:
if sock is not None:
sock.close()
raise
else:
if sock is None:
raise ValueError('no path and sock were specified')
sock.setblocking(False)
transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname)
return transport, protocol
@tasks.coroutine
def create_unix_server(self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None):
if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None')
if path is not None:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
sock.bind(path)
except OSError as exc:
if exc.errno == errno.EADDRINUSE:
# Let's improve the error message by adding
# with what exact address it occurs.
msg = 'Address {!r} is already in use'.format(path)
raise OSError(errno.EADDRINUSE, msg) from None
else:
raise
else:
if sock is None:
raise ValueError(
'path was not specified, and no sock specified')
if sock.family != socket.AF_UNIX:
raise ValueError(
'A UNIX Domain Socket was expected, got {!r}'.format(sock))
server = base_events.Server(self, [sock])
sock.listen(backlog)
sock.setblocking(False)
self._start_serving(protocol_factory, sock, ssl, server)
return server
def _set_nonblocking(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
......
......@@ -212,7 +212,7 @@ class BaseEventLoopTests(unittest.TestCase):
idx = -1
data = [10.0, 10.0, 10.3, 13.0]
self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda:True, ())]
self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda: True, ())]
self.loop._run_once()
self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0])
......
......@@ -55,7 +55,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop.remove_reader = unittest.mock.Mock()
self.loop.remove_writer = unittest.mock.Mock()
waiter = asyncio.Future(loop=self.loop)
transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter)
transport = self.loop._make_ssl_transport(
m, asyncio.Protocol(), m, waiter)
self.assertIsInstance(transport, _SelectorSslTransport)
@unittest.mock.patch('asyncio.selector_events.ssl', None)
......
......@@ -7,8 +7,10 @@ import io
import os
import pprint
import signal
import socket
import stat
import sys
import tempfile
import threading
import unittest
import unittest.mock
......@@ -24,7 +26,7 @@ from asyncio import unix_events
@unittest.skipUnless(signal, 'Signals are not supported')
class SelectorEventLoopTests(unittest.TestCase):
class SelectorEventLoopSignalTests(unittest.TestCase):
def setUp(self):
self.loop = asyncio.SelectorEventLoop()
......@@ -200,6 +202,84 @@ class SelectorEventLoopTests(unittest.TestCase):
m_signal.set_wakeup_fd.assert_called_once_with(-1)
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
'UNIX Sockets are not supported')
class SelectorEventLoopUnixSocketTests(unittest.TestCase):
def setUp(self):
self.loop = asyncio.SelectorEventLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def test_create_unix_server_existing_path_sock(self):
with test_utils.unix_socket_path() as path:
sock = socket.socket(socket.AF_UNIX)
sock.bind(path)
coro = self.loop.create_unix_server(lambda: None, path)
with self.assertRaisesRegexp(OSError,
'Address.*is already in use'):
self.loop.run_until_complete(coro)
def test_create_unix_server_existing_path_nonsock(self):
with tempfile.NamedTemporaryFile() as file:
coro = self.loop.create_unix_server(lambda: None, file.name)
with self.assertRaisesRegexp(OSError,
'Address.*is already in use'):
self.loop.run_until_complete(coro)
def test_create_unix_server_ssl_bool(self):
coro = self.loop.create_unix_server(lambda: None, path='spam',
ssl=True)
with self.assertRaisesRegex(TypeError,
'ssl argument must be an SSLContext'):
self.loop.run_until_complete(coro)
def test_create_unix_server_nopath_nosock(self):
coro = self.loop.create_unix_server(lambda: None, path=None)
with self.assertRaisesRegex(ValueError,
'path was not specified, and no sock'):
self.loop.run_until_complete(coro)
def test_create_unix_server_path_inetsock(self):
coro = self.loop.create_unix_server(lambda: None, path=None,
sock=socket.socket())
with self.assertRaisesRegex(ValueError,
'A UNIX Domain Socket was expected'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_path_sock(self):
coro = self.loop.create_unix_connection(
lambda: None, '/dev/null', sock=object())
with self.assertRaisesRegex(ValueError, 'path and sock can not be'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_nopath_nosock(self):
coro = self.loop.create_unix_connection(
lambda: None, None)
with self.assertRaisesRegex(ValueError,
'no path and sock were specified'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_nossl_serverhost(self):
coro = self.loop.create_unix_connection(
lambda: None, '/dev/null', server_hostname='spam')
with self.assertRaisesRegex(ValueError,
'server_hostname is only meaningful'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_ssl_noserverhost(self):
coro = self.loop.create_unix_connection(
lambda: None, '/dev/null', ssl=True)
with self.assertRaisesRegexp(
ValueError, 'you have to pass server_hostname when using ssl'):
self.loop.run_until_complete(coro)
class UnixReadPipeTransportTests(unittest.TestCase):
def setUp(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment