streams.py 23.8 KB
Newer Older
1 2
"""Stream-related things."""

3
__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
4 5
           'open_connection', 'start_server',
           'IncompleteReadError',
6
           'LimitOverrunError',
7
           ]
8

9 10
import socket

11 12 13
if hasattr(socket, 'AF_UNIX'):
    __all__.extend(['open_unix_connection', 'start_unix_server'])

14
from . import coroutines
15
from . import compat
16 17
from . import events
from . import protocols
18
from .coroutines import coroutine
19
from .log import logger
20 21


22
_DEFAULT_LIMIT = 2 ** 16
23

24

25 26 27 28 29
class IncompleteReadError(EOFError):
    """
    Incomplete read error. Attributes:

    - partial: read bytes string before the end of stream was reached
    - expected: total number of expected bytes (or None if unknown)
    """
    def __init__(self, partial, expected):
        super().__init__("%d bytes read on a total of %r expected bytes"
                         % (len(partial), expected))
        self.partial = partial
        self.expected = expected

    def __reduce__(self):
        # BaseException pickles via self.args, which here holds only the
        # formatted message; reconstruct from the original constructor
        # arguments instead so pickling round-trips correctly.
        return type(self), (self.partial, self.expected)

38

39
class LimitOverrunError(Exception):
    """Reached the buffer limit while looking for a separator.

    Attributes:
    - consumed: total number of to be consumed bytes.
    """
    def __init__(self, message, consumed):
        super().__init__(message)
        self.consumed = consumed

    def __reduce__(self):
        # Default exception pickling would re-call __init__ with only
        # self.args (the message), dropping `consumed`; supply both
        # original constructor arguments so pickling round-trips.
        return type(self), (self.args[0], self.consumed)


50
@coroutine
def open_connection(host=None, port=None, *,
                    loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Connect to (host, port) and return a (reader, writer) pair.

    The first element is a StreamReader, the second a StreamWriter.

    All the usual create_connection() arguments are accepted except
    protocol_factory; host and port are the common positional ones and
    any further keywords are passed through.  In addition, `loop`
    selects the event loop to use and `limit` is handed to the
    StreamReader as its buffer limit.

    (To customize the StreamReader and/or StreamReaderProtocol classes,
    just copy this code -- there is nothing special here beyond
    convenience.)
    """
    if loop is None:
        loop = events.get_event_loop()
    stream_reader = StreamReader(limit=limit, loop=loop)
    stream_protocol = StreamReaderProtocol(stream_reader, loop=loop)
    transport, _ = yield from loop.create_connection(
        lambda: stream_protocol, host, port, **kwds)
    stream_writer = StreamWriter(transport, stream_protocol,
                                 stream_reader, loop)
    return stream_reader, stream_writer
78 79


80
@coroutine
def start_server(client_connected_cb, host=None, port=None, *,
                 loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server and invoke a callback for each client.

    `client_connected_cb` receives two arguments, (client_reader,
    client_writer): a StreamReader/StreamWriter pair for the accepted
    connection.  It may be a plain callable or a coroutine function; a
    coroutine result is automatically wrapped in a Task.

    The remaining arguments mirror loop.create_server() minus
    protocol_factory: most commonly positional host and port, plus any
    optional keywords.  Additionally, `loop` selects the event loop to
    use and `limit` is handed to each StreamReader as its buffer limit.

    Returns the same Server object as loop.create_server(), which can
    be used to stop the service.
    """
    if loop is None:
        loop = events.get_event_loop()

    def protocol_factory():
        # A fresh reader/protocol pair is built per accepted connection.
        stream_reader = StreamReader(limit=limit, loop=loop)
        return StreamReaderProtocol(stream_reader, client_connected_cb,
                                    loop=loop)

    server = yield from loop.create_server(protocol_factory, host, port,
                                           **kwds)
    return server


116 117 118
if hasattr(socket, 'AF_UNIX'):
    # UNIX Domain Sockets are supported on this platform

    @coroutine
    def open_unix_connection(path=None, *,
                             loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `open_connection` but works with UNIX Domain Sockets."""
        if loop is None:
            loop = events.get_event_loop()
        stream_reader = StreamReader(limit=limit, loop=loop)
        stream_protocol = StreamReaderProtocol(stream_reader, loop=loop)
        transport, _ = yield from loop.create_unix_connection(
            lambda: stream_protocol, path, **kwds)
        stream_writer = StreamWriter(transport, stream_protocol,
                                     stream_reader, loop)
        return stream_reader, stream_writer

    @coroutine
    def start_unix_server(client_connected_cb, path=None, *,
                          loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `start_server` but works with UNIX Domain Sockets."""
        if loop is None:
            loop = events.get_event_loop()

        def protocol_factory():
            # A fresh reader/protocol pair is built per accepted connection.
            stream_reader = StreamReader(limit=limit, loop=loop)
            return StreamReaderProtocol(stream_reader, client_connected_cb,
                                        loop=loop)

        server = yield from loop.create_unix_server(protocol_factory, path,
                                                    **kwds)
        return server


148 149 150 151 152 153 154
class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_writing() and connection_lost().  If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must wait for _drain_helper() coroutine.
    """

    def __init__(self, loop=None):
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        # True while the transport has asked us to stop writing.
        self._paused = False
        # Future completed by resume_writing()/connection_lost(); at most
        # one drain() call may be waiting on it at a time.
        self._drain_waiter = None
        self._connection_lost = False

    def pause_writing(self):
        # Called by the transport when its write buffer goes over the
        # high-water mark.
        assert not self._paused
        self._paused = True
        if self._loop.get_debug():
            logger.debug("%r pauses writing", self)

    def resume_writing(self):
        # Called by the transport when its write buffer drains below the
        # low-water mark; releases any drain() call currently waiting.
        assert self._paused
        self._paused = False
        if self._loop.get_debug():
            logger.debug("%r resumes writing", self)

        waiter = self._drain_waiter
        if waiter is not None:
            self._drain_waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def connection_lost(self, exc):
        self._connection_lost = True
        # Wake up the writer if currently paused.
        if not self._paused:
            return
        waiter = self._drain_waiter
        if waiter is None:
            return
        self._drain_waiter = None
        if waiter.done():
            return
        # exc is None on a clean close; otherwise propagate the failure
        # to the waiting drain() call.
        if exc is None:
            waiter.set_result(None)
        else:
            waiter.set_exception(exc)

    @coroutine
    def _drain_helper(self):
        # Block the caller until writing may resume, failing fast if the
        # connection is already gone.
        if self._connection_lost:
            raise ConnectionResetError('Connection lost')
        if not self._paused:
            return
        waiter = self._drain_waiter
        # Any previous waiter must have been consumed or cancelled by now.
        assert waiter is None or waiter.cancelled()
        waiter = self._loop.create_future()
        self._drain_waiter = waiter
        yield from waiter
212 213 214 215


class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Helper class to adapt between Protocol and StreamReader.

    (This is a helper class instead of making StreamReader itself a
    Protocol subclass, because the StreamReader has other potential
    uses, and to prevent the user of the StreamReader to accidentally
    call inappropriate methods of the protocol.)
    """

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        super().__init__(loop=loop)
        self._stream_reader = stream_reader
        # Created lazily in connection_made() when a callback is given.
        self._stream_writer = None
        self._client_connected_cb = client_connected_cb
        self._over_ssl = False

    def connection_made(self, transport):
        self._stream_reader.set_transport(transport)
        self._over_ssl = transport.get_extra_info('sslcontext') is not None
        if self._client_connected_cb is not None:
            self._stream_writer = StreamWriter(transport, self,
                                               self._stream_reader,
                                               self._loop)
            res = self._client_connected_cb(self._stream_reader,
                                            self._stream_writer)
            # A coroutine callback gets scheduled as a Task; a plain
            # callable has already run to completion at this point.
            if coroutines.iscoroutine(res):
                self._loop.create_task(res)

    def connection_lost(self, exc):
        if self._stream_reader is not None:
            if exc is None:
                # Clean close: signal EOF to pending read*() calls.
                self._stream_reader.feed_eof()
            else:
                self._stream_reader.set_exception(exc)
        super().connection_lost(exc)
        # Drop references to help break reference cycles.
        self._stream_reader = None
        self._stream_writer = None

    def data_received(self, data):
        self._stream_reader.feed_data(data)

    def eof_received(self):
        self._stream_reader.feed_eof()
        if self._over_ssl:
            # Prevent a warning in SSLProtocol.eof_received:
            # "returning true from eof_received()
            # has no effect when using ssl"
            return False
        # Returning true keeps the transport open (half-closed mode).
        return True
263 264 265 266 267 268 269 270


class StreamWriter:
    """Transport wrapper providing the write side of a stream.

    Offers write(), writelines(), [can_]write_eof(), get_extra_info()
    and close(), all forwarded to the underlying transport.  Adds a
    drain() coroutine that can be awaited for flow control, and a
    transport property giving direct access to the wrapped Transport.
    """

    def __init__(self, transport, protocol, reader, loop):
        self._transport = transport
        self._protocol = protocol
        # drain() expects that the reader has an exception() method
        assert reader is None or isinstance(reader, StreamReader)
        self._reader = reader
        self._loop = loop

    def __repr__(self):
        parts = [self.__class__.__name__, 'transport=%r' % self._transport]
        if self._reader is not None:
            parts.append('reader=%r' % self._reader)
        return '<{}>'.format(' '.join(parts))

    @property
    def transport(self):
        return self._transport

    def write(self, data):
        # Buffered by the transport; pair with drain() for flow control.
        self._transport.write(data)

    def writelines(self, data):
        self._transport.writelines(data)

    def write_eof(self):
        return self._transport.write_eof()

    def can_write_eof(self):
        return self._transport.can_write_eof()

    def close(self):
        return self._transport.close()

    def get_extra_info(self, name, default=None):
        return self._transport.get_extra_info(name, default)

    @coroutine
    def drain(self):
        """Wait until it is appropriate to resume writing.

        Typical use:

          w.write(data)
          yield from w.drain()
        """
        # Surface any error the reader side has already recorded.
        if self._reader is not None:
            exc = self._reader.exception()
            if exc is not None:
                raise exc
        if self._transport is not None and self._transport.is_closing():
            # Yield to the event loop so connection_lost() may be
            # called.  Without this, _drain_helper() would return
            # immediately, and code that calls
            #     write(...); yield from drain()
            # in a loop would never call connection_lost(), so it
            # would not see an error when the socket is closed.
            yield
        yield from self._protocol._drain_helper()
334 335 336 337 338 339 340


class StreamReader:
    """Buffered readable stream, fed by a protocol via feed_data()/feed_eof()."""

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.

        if limit <= 0:
            raise ValueError('Limit cannot be <= 0')

        self._limit = limit
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._buffer = bytearray()
        self._eof = False    # Whether we're done.
        self._waiter = None  # A future used by _wait_for_data()
        self._exception = None  # Set by set_exception(); re-raised by read*()
        self._transport = None  # Set later via set_transport()
        self._paused = False    # True while transport reading is paused

    def __repr__(self):
        # Only show attributes that differ from their defaults.
        info = ['StreamReader']
        if self._buffer:
            info.append('%d bytes' % len(self._buffer))
        if self._eof:
            info.append('eof')
        if self._limit != _DEFAULT_LIMIT:
            info.append('l=%d' % self._limit)
        if self._waiter:
            info.append('w=%r' % self._waiter)
        if self._exception:
            info.append('e=%r' % self._exception)
        if self._transport:
            info.append('t=%r' % self._transport)
        if self._paused:
            info.append('paused')
        return '<%s>' % ' '.join(info)

    def exception(self):
        return self._exception

    def set_exception(self, exc):
        # Record the exception and deliver it to a pending read*() call.
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def _wakeup_waiter(self):
        """Wakeup read*() functions waiting for data or EOF."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(None)

    def set_transport(self, transport):
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        # Resume reading once the buffer has drained back under the limit.
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        self._eof = True
        self._wakeup_waiter()

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)
        self._wakeup_waiter()

        # Backpressure: pause the transport once the buffer exceeds
        # twice the limit.
        if (self._transport is not None and
                not self._paused and
                len(self._buffer) > 2 * self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    @coroutine
    def _wait_for_data(self, func_name):
        """Wait until feed_data() or feed_eof() is called.

        If stream was paused, automatically resume it.
        """
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine. Running two read coroutines at the same time
        # would have an unexpected behaviour. It would not be possible to know
        # which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError('%s() called while another coroutine is '
                               'already waiting for incoming data' % func_name)

        assert not self._eof, '_wait_for_data after EOF'

        # Waiting for data while paused will make deadlock, so prevent it.
        if self._paused:
            self._paused = False
            self._transport.resume_reading()

        self._waiter = self._loop.create_future()
        try:
            yield from self._waiter
        finally:
            self._waiter = None

    @coroutine
    def readline(self):
        """Read chunk of data from the stream until newline (b'\n') is found.

        On success, return chunk that ends with newline. If only partial
        line can be read due to EOF, return incomplete line without
        terminating newline. When EOF was reached while no bytes read, empty
        bytes object is returned.

        If limit is reached, ValueError will be raised. In that case, if
        newline was found, complete line including newline will be removed
        from internal buffer. Else, internal buffer will be cleared. Limit is
        compared against part of the line without newline.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        sep = b'\n'
        seplen = len(sep)
        try:
            line = yield from self.readuntil(sep)
        except IncompleteReadError as e:
            # EOF hit before a newline: return whatever was read.
            return e.partial
        except LimitOverrunError as e:
            # Discard the offending data (including the newline, if one
            # was found) before reporting the overrun as a ValueError.
            if self._buffer.startswith(sep, e.consumed):
                del self._buffer[:e.consumed + seplen]
            else:
                self._buffer.clear()
            self._maybe_resume_transport()
            raise ValueError(e.args[0])
        return line

    @coroutine
    def readuntil(self, separator=b'\n'):
        """Read data from the stream until ``separator`` is found.

        On success, the data and separator will be removed from the
        internal buffer (consumed). Returned data will include the
        separator at the end.

        Configured stream limit is used to check result. Limit sets the
        maximal length of data that can be returned, not counting the
        separator.

        If an EOF occurs and the complete separator is still not found,
        an IncompleteReadError exception will be raised, and the internal
        buffer will be reset.  The IncompleteReadError.partial attribute
        may contain the separator partially.

        If the data cannot be read because of over limit, a
        LimitOverrunError exception will be raised, and the data
        will be left in the internal buffer, so it can be read again.
        """
        seplen = len(separator)
        if seplen == 0:
            raise ValueError('Separator should be at least one-byte string')

        if self._exception is not None:
            raise self._exception

        # Consume whole buffer except last bytes, which length is
        # one less than seplen. Let's check corner cases with
        # separator='SEPARATOR':
        # * we have received almost complete separator (without last
        #   byte). i.e buffer='some textSEPARATO'. In this case we
        #   can safely consume len(separator) - 1 bytes.
        # * last byte of buffer is first byte of separator, i.e.
        #   buffer='abcdefghijklmnopqrS'. We may safely consume
        #   everything except that last byte, but this require to
        #   analyze bytes of buffer that match partial separator.
        #   This is slow and/or require FSM. For this case our
        #   implementation is not optimal, since require rescanning
        #   of data that is known to not belong to separator. In
        #   real world, separator will not be so long to notice
        #   performance problems. Even when reading MIME-encoded
        #   messages :)

        # `offset` is the number of bytes from the beginning of the buffer
        # where there is no occurrence of `separator`.
        offset = 0

        # Loop until we find `separator` in the buffer, exceed the buffer size,
        # or an EOF has happened.
        while True:
            buflen = len(self._buffer)

            # Check if we now have enough data in the buffer for `separator` to
            # fit.
            if buflen - offset >= seplen:
                isep = self._buffer.find(separator, offset)

                if isep != -1:
                    # `separator` is in the buffer. `isep` will be used later
                    # to retrieve the data.
                    break

                # see upper comment for explanation.
                offset = buflen + 1 - seplen
                if offset > self._limit:
                    raise LimitOverrunError(
                        'Separator is not found, and chunk exceed the limit',
                        offset)

            # Complete message (with full separator) may be present in buffer
            # even when EOF flag is set. This may happen when the last chunk
            # adds data which makes separator be found. That's why we check for
            # EOF *after* inspecting the buffer.
            if self._eof:
                chunk = bytes(self._buffer)
                self._buffer.clear()
                raise IncompleteReadError(chunk, None)

            # _wait_for_data() will resume reading if stream was paused.
            yield from self._wait_for_data('readuntil')

        if isep > self._limit:
            raise LimitOverrunError(
                'Separator is found, but chunk is longer than limit', isep)

        chunk = self._buffer[:isep + seplen]
        del self._buffer[:isep + seplen]
        self._maybe_resume_transport()
        return bytes(chunk)

    @coroutine
    def read(self, n=-1):
        """Read up to `n` bytes from the stream.

        If n is not provided, or set to -1, read until EOF and return all read
        bytes. If the EOF was received and the internal buffer is empty, return
        an empty bytes object.

        If n is zero, return empty bytes object immediately.

        If n is positive, this function try to read `n` bytes, and may return
        less or equal bytes than requested, but at least one byte. If EOF was
        received before any byte is read, this function returns empty byte
        object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.read(self._limit) until EOF.
            blocks = []
            while True:
                block = yield from self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self._buffer and not self._eof:
            yield from self._wait_for_data('read')

        # This will work right even if buffer is less than n bytes
        data = bytes(self._buffer[:n])
        del self._buffer[:n]

        self._maybe_resume_transport()
        return data

    @coroutine
    def readexactly(self, n):
        """Read exactly `n` bytes.

        Raise an IncompleteReadError if EOF is reached before `n` bytes can be
        read. The IncompleteReadError.partial attribute of the exception will
        contain the partial read bytes.

        if n is zero, return empty bytes object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        if n < 0:
            raise ValueError('readexactly size can not be less than zero')

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        # There used to be "optimized" code here.  It created its own
        # Future and waited until self._buffer had at least the n
        # bytes, then called read(n).  Unfortunately, this could pause
        # the transport if the argument was larger than the pause
        # limit (which is twice self._limit).  So now we just read()
        # into a local buffer.

        blocks = []
        while n > 0:
            block = yield from self.read(n)
            if not block:
                # EOF before n bytes arrived: report what was collected.
                partial = b''.join(blocks)
                raise IncompleteReadError(partial, len(partial) + n)
            blocks.append(block)
            n -= len(block)

        assert n == 0

        return b''.join(blocks)

    if compat.PY35:
        @coroutine
        def __aiter__(self):
            return self

        @coroutine
        def __anext__(self):
            # Async iteration yields lines until EOF (empty readline()).
            val = yield from self.readline()
            if val == b'':
                raise StopAsyncIteration
            return val

    if compat.PY352:
        # In Python 3.5.2 and greater, __aiter__ should return
        # the asynchronous iterator directly.
        def __aiter__(self):
            return self