"""Stream-related things."""

__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
           'open_connection', 'start_server',
           'IncompleteReadError',
           ]

import socket

if hasattr(socket, 'AF_UNIX'):
    __all__.extend(['open_unix_connection', 'start_unix_server'])

from . import coroutines
from . import events
from . import futures
from . import protocols
from .coroutines import coroutine
from .log import logger


_DEFAULT_LIMIT = 2**16


class IncompleteReadError(EOFError):
    """
    Incomplete read error. Attributes:

    - partial: bytes read before the end of the stream was reached
    - expected: total number of expected bytes
    """
    def __init__(self, partial, expected):
        EOFError.__init__(self, "%s bytes read on a total of %s expected bytes"
                                % (len(partial), expected))
        self.partial = partial
        self.expected = expected


@coroutine
def open_connection(host=None, port=None, *,
                    loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """A wrapper for create_connection() returning a (reader, writer) pair.

    The reader returned is a StreamReader instance; the writer is a
    StreamWriter instance.

    The arguments are all the usual arguments to create_connection()
    except protocol_factory; most common are positional host and port,
    with various optional keyword arguments following.

    Additional optional keyword arguments are loop (to set the event loop
    instance to use) and limit (to set the buffer limit passed to the
    StreamReader).

    (If you want to customize the StreamReader and/or
    StreamReaderProtocol classes, just copy the code -- there's
    really nothing special here except some convenience.)
    """
    if loop is None:
        loop = events.get_event_loop()
    reader = StreamReader(limit=limit, loop=loop)
    protocol = StreamReaderProtocol(reader, loop=loop)
    transport, _ = yield from loop.create_connection(
        lambda: protocol, host, port, **kwds)
    writer = StreamWriter(transport, protocol, reader, loop)
    return reader, writer
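

# A minimal usage sketch, illustrative only and not part of the module API:
# a request/reply client built on open_connection().  The address and
# protocol below are hypothetical; any line-oriented server would do.
def _example_open_connection_client():
    @coroutine
    def fetch():
        reader, writer = yield from open_connection('127.0.0.1', 8888)
        writer.write(b'ping\r\n')
        yield from writer.drain()
        reply = yield from reader.readline()
        writer.close()
        return reply

    loop = events.get_event_loop()
    return loop.run_until_complete(fetch())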


@coroutine
def start_server(client_connected_cb, host=None, port=None, *,
                 loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server, call back for each client connected.

    The first parameter, `client_connected_cb`, takes two parameters:
    client_reader, client_writer.  client_reader is a StreamReader
    object, while client_writer is a StreamWriter object.  This
    parameter can either be a plain callback function or a coroutine;
    if it is a coroutine, it will be automatically converted into a
    Task.

    The rest of the arguments are all the usual arguments to
    loop.create_server() except protocol_factory; most common are
    positional host and port, with various optional keyword arguments
    following.

    Additional optional keyword arguments are loop (to set the event loop
    instance to use) and limit (to set the buffer limit passed to the
    StreamReader).

    The return value is the same as loop.create_server(), i.e. a
    Server object which can be used to stop the service.
    """
    if loop is None:
        loop = events.get_event_loop()

    def factory():
        reader = StreamReader(limit=limit, loop=loop)
        protocol = StreamReaderProtocol(reader, client_connected_cb,
                                        loop=loop)
        return protocol

    return (yield from loop.create_server(factory, host, port, **kwds))
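

# A minimal usage sketch, illustrative only: an echo server built on
# start_server().  Because the handler below is a coroutine, it is wrapped
# in a Task automatically for each new client, as described above.  The
# address is hypothetical.
def _example_echo_server():
    @coroutine
    def handle_echo(client_reader, client_writer):
        data = yield from client_reader.readline()
        client_writer.write(data)
        yield from client_writer.drain()
        client_writer.close()

    loop = events.get_event_loop()
    return loop.run_until_complete(
        start_server(handle_echo, '127.0.0.1', 8888, loop=loop))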


if hasattr(socket, 'AF_UNIX'):
    # UNIX Domain Sockets are supported on this platform

    @coroutine
    def open_unix_connection(path=None, *,
                             loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `open_connection` but works with UNIX Domain Sockets."""
        if loop is None:
            loop = events.get_event_loop()
        reader = StreamReader(limit=limit, loop=loop)
        protocol = StreamReaderProtocol(reader, loop=loop)
        transport, _ = yield from loop.create_unix_connection(
            lambda: protocol, path, **kwds)
        writer = StreamWriter(transport, protocol, reader, loop)
        return reader, writer


    @coroutine
    def start_unix_server(client_connected_cb, path=None, *,
                          loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `start_server` but works with UNIX Domain Sockets."""
        if loop is None:
            loop = events.get_event_loop()

        def factory():
            reader = StreamReader(limit=limit, loop=loop)
            protocol = StreamReaderProtocol(reader, client_connected_cb,
                                            loop=loop)
            return protocol

        return (yield from loop.create_unix_server(factory, path, **kwds))
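

    # A minimal usage sketch, illustrative only: the UNIX domain helpers
    # mirror the TCP ones above, taking a filesystem path (the path below
    # is hypothetical) instead of a host/port pair.
    def _example_unix_echo_server():
        @coroutine
        def handle(reader, writer):
            data = yield from reader.readline()
            writer.write(data)
            yield from writer.drain()
            writer.close()

        loop = events.get_event_loop()
        return loop.run_until_complete(
            start_unix_server(handle, '/tmp/example.sock', loop=loop))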


class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_writing() and connection_lost().  If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must wait for the _drain_helper() coroutine.
    """

    def __init__(self, loop=None):
        self._loop = loop  # May be None; we may never need it.
        self._paused = False
        self._drain_waiter = None
        self._connection_lost = False

    def pause_writing(self):
        assert not self._paused
        self._paused = True
        if self._loop.get_debug():
            logger.debug("%r pauses writing", self)

    def resume_writing(self):
        assert self._paused
        self._paused = False
        if self._loop.get_debug():
            logger.debug("%r resumes writing", self)

        waiter = self._drain_waiter
        if waiter is not None:
            self._drain_waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def connection_lost(self, exc):
        self._connection_lost = True
        # Wake up the writer if currently paused.
        if not self._paused:
            return
        waiter = self._drain_waiter
        if waiter is None:
            return
        self._drain_waiter = None
        if waiter.done():
            return
        if exc is None:
            waiter.set_result(None)
        else:
            waiter.set_exception(exc)

    @coroutine
    def _drain_helper(self):
        if self._connection_lost:
            raise ConnectionResetError('Connection lost')
        if not self._paused:
            return
        waiter = self._drain_waiter
        assert waiter is None or waiter.cancelled()
        waiter = futures.Future(loop=self._loop)
        self._drain_waiter = waiter
        yield from waiter
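

# An illustrative sketch, not part of the module: a FlowControlMixin subclass
# that overrides the flow control hooks.  As the docstring above notes, any
# override must call the super method so that drain() keeps working.
class _ExampleLoggingFlowControl(FlowControlMixin):

    def pause_writing(self):
        super().pause_writing()    # required: keeps _paused in sync
        logger.debug("example: write buffer full, writers should drain()")

    def resume_writing(self):
        super().resume_writing()   # required: wakes coroutines blocked in drain()
        logger.debug("example: write buffer drained, writing may resume")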


class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Helper class to adapt between Protocol and StreamReader.

    (This is a helper class instead of making StreamReader itself a
    Protocol subclass, because the StreamReader has other potential
    uses, and to prevent the user of the StreamReader from accidentally
    calling inappropriate methods of the protocol.)
    """

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        super().__init__(loop=loop)
        self._stream_reader = stream_reader
        self._stream_writer = None
        self._client_connected_cb = client_connected_cb

    def connection_made(self, transport):
        self._stream_reader.set_transport(transport)
        if self._client_connected_cb is not None:
            self._stream_writer = StreamWriter(transport, self,
                                               self._stream_reader,
                                               self._loop)
            res = self._client_connected_cb(self._stream_reader,
                                            self._stream_writer)
            if coroutines.iscoroutine(res):
                self._loop.create_task(res)

    def connection_lost(self, exc):
        if exc is None:
            self._stream_reader.feed_eof()
        else:
            self._stream_reader.set_exception(exc)
        super().connection_lost(exc)

    def data_received(self, data):
        self._stream_reader.feed_data(data)

    def eof_received(self):
        self._stream_reader.feed_eof()


class StreamWriter:
    """Wraps a Transport.

    This exposes write(), writelines(), [can_]write_eof(),
    get_extra_info() and close().  It adds drain(), a coroutine you can
    yield from to wait for flow control.  It also adds a transport
    property which references the Transport directly.
    """

    def __init__(self, transport, protocol, reader, loop):
        self._transport = transport
        self._protocol = protocol
        # drain() expects that the reader has an exception() method
        assert reader is None or isinstance(reader, StreamReader)
        self._reader = reader
        self._loop = loop

    def __repr__(self):
        info = [self.__class__.__name__, 'transport=%r' % self._transport]
        if self._reader is not None:
            info.append('reader=%r' % self._reader)
        return '<%s>' % ' '.join(info)

    @property
    def transport(self):
        return self._transport

    def write(self, data):
        self._transport.write(data)

    def writelines(self, data):
        self._transport.writelines(data)

    def write_eof(self):
        return self._transport.write_eof()

    def can_write_eof(self):
        return self._transport.can_write_eof()

    def close(self):
        return self._transport.close()

    def get_extra_info(self, name, default=None):
        return self._transport.get_extra_info(name, default)

    @coroutine
    def drain(self):
        """Flush the write buffer.

        The intended use is to write

          w.write(data)
          yield from w.drain()
        """
        if self._reader is not None:
            exc = self._reader.exception()
            if exc is not None:
                raise exc
        yield from self._protocol._drain_helper()
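

# An illustrative sketch, not part of the module: writing a large payload in
# chunks and yielding from drain() after each write, so the coroutine pauses
# whenever the transport reports that its write buffer is full.  The address
# and chunk size are hypothetical.
def _example_chunked_write():
    @coroutine
    def send(payload):
        reader, writer = yield from open_connection('127.0.0.1', 8888)
        for start in range(0, len(payload), 4096):
            writer.write(payload[start:start + 4096])
            yield from writer.drain()
        writer.close()

    loop = events.get_event_loop()
    loop.run_until_complete(send(b'x' * (1 << 20)))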


class StreamReader:

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.
        self._limit = limit
        if loop is None:
            loop = events.get_event_loop()
        self._loop = loop
        self._buffer = bytearray()
        self._eof = False  # Whether we're done.
        self._waiter = None  # A future.
        self._exception = None
        self._transport = None
        self._paused = False

    def exception(self):
        return self._exception

    def set_exception(self, exc):
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def set_transport(self, transport):
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        self._eof = True
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(True)

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(False)

        if (self._transport is not None and
            not self._paused and
            len(self._buffer) > 2*self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    def _create_waiter(self, func_name):
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine. Running two read coroutines at the same time
        # would have unexpected behaviour: it would not be possible to know
        # which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError('%s() called while another coroutine is '
                               'already waiting for incoming data' % func_name)
        return futures.Future(loop=self._loop)

    @coroutine
    def readline(self):
        if self._exception is not None:
            raise self._exception

        line = bytearray()
        not_enough = True

        while not_enough:
            while self._buffer and not_enough:
                ichar = self._buffer.find(b'\n')
                if ichar < 0:
                    line.extend(self._buffer)
                    self._buffer.clear()
                else:
                    ichar += 1
                    line.extend(self._buffer[:ichar])
                    del self._buffer[:ichar]
                    not_enough = False

                if len(line) > self._limit:
                    self._maybe_resume_transport()
                    raise ValueError('Line is too long')

            if self._eof:
                break

            if not_enough:
                self._waiter = self._create_waiter('readline')
                try:
                    yield from self._waiter
                finally:
                    self._waiter = None

        self._maybe_resume_transport()
        return bytes(line)

    @coroutine
    def read(self, n=-1):
        if self._exception is not None:
            raise self._exception

        if not n:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.read(self._limit) until EOF.
            blocks = []
            while True:
                block = yield from self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)
        else:
            if not self._buffer and not self._eof:
                self._waiter = self._create_waiter('read')
                try:
                    yield from self._waiter
                finally:
                    self._waiter = None

        if n < 0 or len(self._buffer) <= n:
            data = bytes(self._buffer)
            self._buffer.clear()
        else:
            # n > 0 and len(self._buffer) > n
            data = bytes(self._buffer[:n])
            del self._buffer[:n]

        self._maybe_resume_transport()
        return data

    @coroutine
    def readexactly(self, n):
        if self._exception is not None:
            raise self._exception

        # There used to be "optimized" code here.  It created its own
        # Future and waited until self._buffer had at least the n
        # bytes, then called read(n).  Unfortunately, this could pause
        # the transport if the argument was larger than the pause
        # limit (which is twice self._limit).  So now we just read()
        # into a local buffer.

        blocks = []
        while n > 0:
            block = yield from self.read(n)
            if not block:
                partial = b''.join(blocks)
                raise IncompleteReadError(partial, len(partial) + n)
            blocks.append(block)
            n -= len(block)

        return b''.join(blocks)
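

# An illustrative sketch, not part of the module: reading a length-prefixed
# frame with readexactly().  If the peer closes the connection mid-frame,
# IncompleteReadError carries the bytes that did arrive (partial) and the
# total number that was expected (expected).
@coroutine
def _example_read_frame(reader):
    try:
        header = yield from reader.readexactly(4)
        size = int.from_bytes(header, 'big')
        return (yield from reader.readexactly(size))
    except IncompleteReadError as exc:
        logger.debug("short read: got %d of %d expected bytes",
                     len(exc.partial), exc.expected)
        raise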