"""Utilities shared by tests."""

import collections
import contextlib
import io
import logging
import os
import re
import socket
import socketserver
import sys
import tempfile
import threading
import time
import unittest
import weakref

from unittest import mock

from http.server import HTTPServer
from wsgiref.simple_server import WSGIRequestHandler, WSGIServer

try:
    import ssl
except ImportError:  # pragma: no cover
    # SSL-dependent helpers check for ``ssl is None``.
    ssl = None

from . import base_events
from . import compat
from . import events
from . import futures
from . import selectors
from . import tasks
from .coroutines import coroutine
from .log import logger


# Re-export socketpair for tests: on Windows it comes from the package's
# own windows_utils module, elsewhere from the socket module.
if sys.platform == 'win32':  # pragma: no cover
    from .windows_utils import socketpair
else:
    from socket import socketpair  # pragma: no cover


def dummy_ssl_context():
    """Return a new ``ssl.SSLContext`` using ``PROTOCOL_SSLv23``.

    Returns ``None`` instead when the ``ssl`` module could not be imported.
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None


def run_briefly(loop):
    """Run *loop* just long enough to complete one trivial task.

    A no-op coroutine is wrapped in a task and driven with
    ``loop.run_until_complete()``, giving the loop a brief chance to run.
    """
    @coroutine
    def once():
        pass

    gen = once()
    task = loop.create_task(gen)
    # Don't log a warning if the task is not done after run_until_complete().
    # It occurs if the loop is stopped or if a task raises a BaseException.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        # Close the coroutine explicitly in case the task never finished.
        gen.close()


def run_until(loop, pred, timeout=30):
    """Run *loop* in small steps until ``pred()`` returns a true value.

    :param loop: event loop to drive.
    :param pred: zero-argument callable, checked before every step.
    :param timeout: maximum number of seconds to wait, or ``None`` to wait
        without a deadline.
    :raises futures.TimeoutError: if *timeout* elapses before *pred* holds.
    """
    # ``timeout=None`` means "no deadline"; computing ``now + None`` would
    # raise TypeError.  A monotonic clock is immune to wall-clock changes.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and deadline - time.monotonic() <= 0:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))


def run_once(loop):
    """Run a single iteration of *loop* (legacy API).

    This is still the recommended pattern for test code.  ``loop.stop`` is
    scheduled with ``call_soon`` before ``run_forever`` is entered, so the
    loop polls for I/O once, runs the callbacks that became ready, and then
    stops.
    """
    loop.call_soon(loop.stop)
    loop.run_forever()


class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that produces no console output.

    The error stream is a throw-away in-memory buffer and request logging
    is disabled, so test runs are not cluttered.
    """

    def get_stderr(self):
        # A fresh buffer per call; nothing ever reads it.
        return io.StringIO()

    def log_message(self, format, *args):
        # Suppress the per-request log line.
        pass

class SilentWSGIServer(WSGIServer):
    """WSGI server for tests: accepted sockets time out, errors are ignored."""

    # Seconds before a blocking operation on an accepted socket times out.
    request_timeout = 2

    def get_request(self):
        # Bound every accepted connection so a stuck client cannot block
        # the server indefinitely.
        request, client_addr = super().get_request()
        request.settimeout(self.request_timeout)
        return request, client_addr

    def handle_error(self, request, client_address):
        # Deliberately silent: do not print tracebacks during tests.
        pass


class SSLWSGIServerMixin:
    """Mixin that wraps each accepted connection in server-side TLS.

    The certificate and key (``ssl_cert.pem`` / ``ssl_key.pem``) are loaded
    from the test directory for every request.
    """

    def finish_request(self, request, client_address):
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        # SSLContext() only accepts a missing protocol on Python 3.6+, so
        # pass one explicitly; prefer the dedicated server protocol when
        # this Python provides it.
        protocol = getattr(ssl, 'PROTOCOL_TLS_SERVER', ssl.PROTOCOL_SSLv23)
        context = ssl.SSLContext(protocol)
        context.load_cert_chain(certfile, keyfile)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            try:
                self.RequestHandlerClass(ssock, client_address, self)
            finally:
                # Close even when the handler fails so the TLS socket is
                # never leaked.
                ssock.close()
        except OSError:
            # maybe socket has been closed by peer
            pass

class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """Silent TCP WSGI test server that wraps connections in TLS."""


def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that runs a WSGI test server in a background thread.

    Meant to be delegated to with ``yield from`` inside a
    ``contextlib.contextmanager``.  It yields the running server, whose
    ``address`` attribute is set to its ``server_address``.  When the
    generator is resumed, closed, or has an exception thrown into it, the
    server is shut down and closed and its thread is joined.

    :param address: address passed to the server class constructor.
    :param use_ssl: use *server_ssl_cls* instead of *server_cls*.
    """

    def app(environ, start_response):
        # Minimal WSGI application: always "200 OK" with a fixed body.
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(target=httpd.serve_forever,
                                     kwargs={'poll_interval': 0.05})
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()


if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer bound to a UNIX socket path instead of (host, port)."""

        def server_bind(self):
            # HTTPServer.server_bind() unpacks server_address as
            # (host, port), which a socket path is not; fill in
            # placeholder values instead.
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name = '127.0.0.1'
            self.server_port = 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server listening on a UNIX socket."""

        # Seconds before a blocking operation on an accepted socket
        # times out.
        request_timeout = 2

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            # WSGIServer.server_bind() would normally do this after binding.
            self.setup_environ()

        def get_request(self):
            request, client_addr = super().get_request()
            request.settimeout(self.request_timeout)
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            return request, ('127.0.0.1', '')


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UNIX-socket WSGI server that ignores request errors."""

        def handle_error(self, request, client_address):
            pass


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """Silent UNIX-socket WSGI server that wraps connections in TLS."""


    def gen_unix_socket_path():
        """Return a fresh, currently unused temporary path for a socket."""
        # Borrow a unique name from NamedTemporaryFile; the file is removed
        # when the ``with`` block exits, leaving the path free to bind.
        # Another process could claim the name in between, which is
        # acceptable for tests.
        with tempfile.NamedTemporaryFile() as file:
            return file.name


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary socket path and remove it afterwards."""
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            try:
                os.unlink(path)
            except OSError:
                # The socket file may never have been created.
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Context manager running a silent WSGI test server on a UNIX socket."""
        with unix_socket_path() as path:
            yield from _run_test_server(address=path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)


@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Context manager running a silent WSGI test server over TCP.

    The yielded server's ``address`` attribute holds the address it
    actually bound to, which matters when *port* is ``0``.
    """
    address = (host, port)
    yield from _run_test_server(
        address=address,
        use_ssl=use_ssl,
        server_cls=SilentWSGIServer,
        server_ssl_cls=SSLWSGIServer,
    )


def make_test_protocol(base):
    """Return an instance of a new subclass of *base* with mocked methods.

    Every name in ``dir(base)`` that is not a ``__dunder__`` is replaced by
    a ``MockCallback`` returning ``None``, so tests can assert on calls.
    The new class lists *base* followed by ``base.__bases__`` as its bases.
    """
    dct = {name: MockCallback(return_value=None)
           for name in dir(base)
           if not (name.startswith('__') and name.endswith('__'))}
    return type('TestProtocol', (base,) + base.__bases__, dct)()


class TestSelector(selectors.BaseSelector):
    """In-memory selector that records registrations but never polls."""

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        # The fd field is always 0: no real file descriptor is involved.
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        # Raises KeyError for an unknown file object.
        return self.keys.pop(fileobj)

    def select(self, timeout):
        # Never reports readiness, whatever the timeout.
        return []

    def get_map(self):
        return self.keys

class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests with a virtual, test-controlled clock.

    ``time()`` returns an internal counter that only moves when
    ``advance_time()`` is called.  The clock is driven by the generator
    function passed to ``__init__``.  After each loop iteration, the
    generator is resumed once for every deadline recorded by ``call_at()``
    since the previous iteration.  It receives that absolute deadline, and
    the value it yields next is the amount by which to advance the clock.

    The generator should look like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of the next
    scheduled handler.  The value passed back with ``yield`` is how far to
    move the loop's time forward.

    When a generator is supplied, ``close()`` checks that it has finished.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            def gen():
                yield
            # The default generator is not required to finish on close().
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        # Prime the generator so it is ready to receive send() calls.
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        # fd -> Handle for registered readers and writers.
        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        """Return the loop's virtual time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        # A falsy advance (0 or None) leaves the clock untouched.
        if advance:
            self._time += advance

    def close(self):
        super().close()
        if self._check_on_close:
            # A user-supplied generator must be exhausted by now.
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        # Record the handle instead of registering with a real selector.
        self.readers[fd] = events.Handle(callback, args, self)

    def _remove_reader(self, fd):
        # Every removal attempt is counted so tests can assert on it.
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Assert that *fd* has a reader registered as callback(*args)."""
        assert fd in self.readers, 'fd {} is not registered'.format(fd)
        handle = self.readers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def _add_writer(self, fd, callback, *args):
        # Record the handle instead of registering with a real selector.
        self.writers[fd] = events.Handle(callback, args, self)

    def _remove_writer(self, fd):
        # Every removal attempt is counted so tests can assert on it.
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Assert that *fd* has a writer registered as callback(*args)."""
        assert fd in self.writers, 'fd {} is not registered'.format(fd)
        handle = self.writers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def _ensure_fd_no_transport(self, fd):
        # Refuse to touch an fd that is registered to a transport.
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd counters of reader/writer removals."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Report each recorded deadline to the time generator and advance
        # the virtual clock by whatever it answers.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args):
        # Remember the deadline so _run_once() can report it.
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        # TestSelector never reports events, so there is nothing to do.
        return

    def _write_to_self(self):
        # No-op: this loop sets up no wake-up mechanism.
        pass

def MockCallback(**kwargs):
    """Return a ``mock.Mock`` that only supports being called.

    The mock is created with ``spec=['__call__']``, so accessing attributes
    outside the spec raises ``AttributeError``.  *kwargs* (for example
    ``return_value``) are passed through to ``mock.Mock``.
    """
    return mock.Mock(spec=['__call__'], **kwargs)


class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
       mock_call.assert_called_with(MockPattern('spam.*ham'))

    The comparison uses ``re.search`` with ``re.S``, so the pattern may
    match anywhere in the other string and ``.`` also matches newlines.
    Non-string operands are left to the other side of the comparison
    (``NotImplemented``) instead of raising ``TypeError``.
    """

    def __eq__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # Without this, str.__ne__ would be inherited and compare the
        # pattern literally, so ``p != s`` could disagree with ``p == s``.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    # One pattern equals many different strings, so no hash can be
    # consistent with __eq__; keep instances unhashable explicitly.
    __hash__ = None


def get_function_source(func):
    """Return ``events._get_function_source(func)``, insisting on a result.

    Presumably the result is a ``(filename, lineno)`` pair; confirm against
    ``events._get_function_source``.

    :raises ValueError: if the source of *func* cannot be determined
        (the helper returned ``None``).
    """
    source = events._get_function_source(func)
    if source is None:
        raise ValueError("unable to get the source of %r" % (func,))
    return source


class TestCase(unittest.TestCase):
    """Base class for asyncio tests.

    For the duration of each test, ``events._get_running_loop`` is replaced
    by a function that always returns ``None``.  The global event loop is
    reset to ``None`` afterwards.
    """

    def set_event_loop(self, loop, *, cleanup=True):
        """Prepare *loop* for a test without installing it globally.

        The global event loop is set to ``None`` so code under test must be
        given *loop* explicitly.  If *cleanup* is true, ``loop.close`` is
        registered as a test cleanup.
        """
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(loop.close)

    def new_test_loop(self, gen=None):
        """Create a ``TestLoop`` driven by *gen* and close it after the test."""
        loop = TestLoop(gen)
        self.set_event_loop(loop)
        return loop

    def setUp(self):
        # Pretend no loop is running; tearDown() restores the original.
        self._get_running_loop = events._get_running_loop
        events._get_running_loop = lambda: None

    def tearDown(self):
        events._get_running_loop = self._get_running_loop

        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))

    if not compat.PY34:
        # Python 3.3 compatibility
        def subTest(self, *args, **kwargs):
            # unittest.TestCase.subTest() only exists since Python 3.4;
            # provide a do-nothing context manager instead.
            class EmptyCM:
                def __enter__(self):
                    pass

                def __exit__(self, *exc):
                    pass

            return EmptyCM()

@contextlib.contextmanager
def disable_logger():
    """Context manager to disable asyncio logger.

    For example, it can be used to ignore warnings in debug mode.  The
    logger's level is set above CRITICAL inside the ``with`` block.  The
    previous level is restored on exit, even if the block raises.
    """
    old_level = logger.level
    try:
        logger.setLevel(logging.CRITICAL + 1)
        yield
    finally:
        logger.setLevel(old_level)

def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket.

    The ``MagicMock`` is specced on ``socket.socket``.  Its ``proto``,
    ``type`` and ``family`` attributes come from the arguments, and
    ``gettimeout()`` returns ``0.0`` (the timeout reported by a
    non-blocking socket).
    """
    sock = mock.MagicMock(socket.socket)
    sock.proto = proto
    sock.type = type
    sock.family = family
    sock.gettimeout.return_value = 0.0
    return sock


def force_legacy_ssl_support():
    """Return a patcher that makes asyncio report sslproto as unavailable.

    ``asyncio.sslproto._is_sslproto_available`` is replaced by a mock that
    returns ``False``.  Use the result like any ``mock.patch`` object (as a
    decorator, a context manager, or via ``start()``/``stop()``).
    """
    target = 'asyncio.sslproto._is_sslproto_available'
    return mock.patch(target, return_value=False)