# test_selector_events.py
"""Tests for selector_events.py"""

import errno
import selectors
import socket
import unittest
from unittest import mock
try:
    import ssl
except ImportError:
    ssl = None

import asyncio
from asyncio.selector_events import BaseSelectorEventLoop
from asyncio.selector_events import _SelectorTransport
from asyncio.selector_events import _SelectorSocketTransport
from asyncio.selector_events import _SelectorDatagramTransport
from asyncio.selector_events import _set_nodelay
from test.test_asyncio import utils as test_utils


# Module-level alias for mock.ANY, the sentinel that compares equal to any
# value when matching recorded mock call arguments.
MOCK_ANY = mock.ANY


class TestBaseSelectorEventLoop(BaseSelectorEventLoop):
    """BaseSelectorEventLoop whose self-pipe sockets are mock objects.

    Overriding _make_self_pipe()/_close_self_pipe() means no real sockets
    are created for the self-pipe, so the loop can be built around a mocked
    selector in unit tests.
    """

    def _make_self_pipe(self):
        # Use mocks in place of the self-pipe socket pair.
        self._ssock = mock.Mock()
        self._csock = mock.Mock()
        # Count the (mock) self-pipe as one internal fd.
        self._internal_fds += 1

    def _close_self_pipe(self):
        # Nothing to release: the self-pipe "sockets" are mocks.
        pass


def list_to_buffer(l=()):
    """Concatenate the bytes-like chunks in *l* into a new bytearray.

    With the default empty tuple the result is an empty bytearray.
    """
    return bytearray().join(l)


def close_transport(transport):
    """Close *transport*'s socket directly and detach it from the transport.

    transport.close() is deliberately not called because the event loop and
    the selector are mocked. Repeated calls are harmless: once
    transport._sock is None the function returns immediately.
    """
    if transport._sock is None:
        return
    transport._sock.close()
    transport._sock = None


class BaseSelectorEventLoopTests(test_utils.TestCase):
    """Unit tests for BaseSelectorEventLoop driven by a mocked selector.

    setUp() wraps a mock.Mock() selector whose select() always returns an
    empty event list in a TestBaseSelectorEventLoop. Individual tests then
    mock loop methods (add_reader, remove_writer, ...) and call the loop's
    internal helpers (_sock_recv, _sock_sendall, _process_events, ...)
    directly.
    """

    def setUp(self):
        super().setUp()
        self.selector = mock.Mock()
        self.selector.select.return_value = []
        self.loop = TestBaseSelectorEventLoop(self.selector)
        self.set_event_loop(self.loop)

    def test_make_socket_transport(self):
        m = mock.Mock()
        self.loop.add_reader = mock.Mock()
        self.loop.add_reader._is_coroutine = False
        transport = self.loop._make_socket_transport(m, asyncio.Protocol())
        self.assertIsInstance(transport, _SelectorSocketTransport)

        # Calling repr() must not fail when the event loop is closed
        self.loop.close()
        repr(transport)

        close_transport(transport)

    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_make_ssl_transport(self):
        m = mock.Mock()
        self.loop._add_reader = mock.Mock()
        self.loop._add_reader._is_coroutine = False
        self.loop._add_writer = mock.Mock()
        self.loop._remove_reader = mock.Mock()
        self.loop._remove_writer = mock.Mock()
        waiter = asyncio.Future(loop=self.loop)
        with test_utils.disable_logger():
            transport = self.loop._make_ssl_transport(
                m, asyncio.Protocol(), m, waiter)

            with self.assertRaisesRegex(RuntimeError,
                                        r'SSL transport.*not.*initialized'):
                transport.is_reading()

            # execute the handshake while the logger is disabled
            # to ignore SSL handshake failure
            test_utils.run_briefly(self.loop)

        # pause_reading()/resume_reading() are called twice in a row to
        # check that repeating them does not change the resulting state.
        self.assertTrue(transport.is_reading())
        transport.pause_reading()
        transport.pause_reading()
        self.assertFalse(transport.is_reading())
        transport.resume_reading()
        transport.resume_reading()
        self.assertTrue(transport.is_reading())

        # Sanity check
        class_name = transport.__class__.__name__
        self.assertIn("ssl", class_name.lower())
        self.assertIn("transport", class_name.lower())

        transport.close()
        # execute pending callbacks to close the socket transport
        test_utils.run_briefly(self.loop)

    @mock.patch('asyncio.selector_events.ssl', None)
    @mock.patch('asyncio.sslproto.ssl', None)
    def test_make_ssl_transport_without_ssl_error(self):
        m = mock.Mock()
        self.loop.add_reader = mock.Mock()
        self.loop.add_writer = mock.Mock()
        self.loop.remove_reader = mock.Mock()
        self.loop.remove_writer = mock.Mock()
        with self.assertRaises(RuntimeError):
            self.loop._make_ssl_transport(m, m, m, m)

    def test_close(self):
        # Unlike TestBaseSelectorEventLoop, this subclass keeps the inherited
        # _close_self_pipe(), so close() really tears down the (mocked)
        # self-pipe sockets.
        class EventLoop(BaseSelectorEventLoop):
            def _make_self_pipe(self):
                self._ssock = mock.Mock()
                self._csock = mock.Mock()
                self._internal_fds += 1

        self.loop = EventLoop(self.selector)
        self.set_event_loop(self.loop)

        ssock = self.loop._ssock
        ssock.fileno.return_value = 7
        csock = self.loop._csock
        csock.fileno.return_value = 1
        remove_reader = self.loop._remove_reader = mock.Mock()

        self.loop._selector.close()
        self.loop._selector = selector = mock.Mock()
        self.assertFalse(self.loop.is_closed())

        self.loop.close()
        self.assertTrue(self.loop.is_closed())
        self.assertIsNone(self.loop._selector)
        self.assertIsNone(self.loop._csock)
        self.assertIsNone(self.loop._ssock)
        selector.close.assert_called_with()
        ssock.close.assert_called_with()
        csock.close.assert_called_with()
        remove_reader.assert_called_with(7)

        # it should be possible to call close() more than once
        self.loop.close()
        self.loop.close()

        # operation blocked when the loop is closed
        f = asyncio.Future(loop=self.loop)
        self.assertRaises(RuntimeError, self.loop.run_forever)
        self.assertRaises(RuntimeError, self.loop.run_until_complete, f)
        fd = 0

        def callback():
            pass

        self.assertRaises(RuntimeError, self.loop.add_reader, fd, callback)
        self.assertRaises(RuntimeError, self.loop.add_writer, fd, callback)

    def test_close_no_selector(self):
        self.loop.remove_reader = mock.Mock()
        self.loop._selector.close()
        self.loop._selector = None
        self.loop.close()
        self.assertIsNone(self.loop._selector)

    def test_read_from_self_tryagain(self):
        self.loop._ssock.recv.side_effect = BlockingIOError
        self.assertIsNone(self.loop._read_from_self())

    def test_read_from_self_exception(self):
        self.loop._ssock.recv.side_effect = OSError
        self.assertRaises(OSError, self.loop._read_from_self)

    def test_write_to_self_tryagain(self):
        self.loop._csock.send.side_effect = BlockingIOError
        with test_utils.disable_logger():
            self.assertIsNone(self.loop._write_to_self())

    def test_write_to_self_exception(self):
        # _write_to_self() swallows OSError, so use a non-OSError exception
        # to check that other errors still propagate.
        self.loop._csock.send.side_effect = RuntimeError()
        self.assertRaises(RuntimeError, self.loop._write_to_self)

    def test_sock_recv(self):
        sock = test_utils.mock_nonblocking_socket()
        self.loop._sock_recv = mock.Mock()

        f = self.loop.create_task(self.loop.sock_recv(sock, 1024))
        self.loop.run_until_complete(asyncio.sleep(0.01, loop=self.loop))

        self.assertEqual(self.loop._sock_recv.call_args[0][1:],
                         (None, sock, 1024))

        f.cancel()
        with self.assertRaises(asyncio.CancelledError):
            self.loop.run_until_complete(f)

    def test_sock_recv_reconnection(self):
        sock = mock.Mock()
        sock.fileno.return_value = 10
        sock.recv.side_effect = BlockingIOError
        sock.gettimeout.return_value = 0.0

        self.loop.add_reader = mock.Mock()
        self.loop.remove_reader = mock.Mock()
        fut = self.loop.create_task(
            self.loop.sock_recv(sock, 1024))

        self.loop.run_until_complete(asyncio.sleep(0.01, loop=self.loop))

        callback = self.loop.add_reader.call_args[0][1]
        params = self.loop.add_reader.call_args[0][2:]

        # emulate the old socket has closed, but the new one has
        # the same fileno, so callback is called with old (closed) socket
        sock.fileno.return_value = -1
        sock.recv.side_effect = OSError(9)
        callback(*params)

        self.loop.run_until_complete(asyncio.sleep(0.01, loop=self.loop))

        self.assertIsInstance(fut.exception(), OSError)
        self.assertEqual((10,), self.loop.remove_reader.call_args[0])

    def test__sock_recv_canceled_fut(self):
        sock = mock.Mock()

        f = asyncio.Future(loop=self.loop)
        f.cancel()

        self.loop._sock_recv(f, None, sock, 1024)
        self.assertFalse(sock.recv.called)

    def test__sock_recv_unregister(self):
        sock = mock.Mock()
        sock.fileno.return_value = 10

        f = asyncio.Future(loop=self.loop)
        f.cancel()

        self.loop.remove_reader = mock.Mock()
        self.loop._sock_recv(f, 10, sock, 1024)
        self.assertEqual((10,), self.loop.remove_reader.call_args[0])

    def test__sock_recv_tryagain(self):
        f = asyncio.Future(loop=self.loop)
        sock = mock.Mock()
        sock.fileno.return_value = 10
        sock.recv.side_effect = BlockingIOError

        self.loop.add_reader = mock.Mock()
        self.loop._sock_recv(f, None, sock, 1024)
        self.assertEqual((10, self.loop._sock_recv, f, 10, sock, 1024),
                         self.loop.add_reader.call_args[0])

    def test__sock_recv_exception(self):
        f = asyncio.Future(loop=self.loop)
        sock = mock.Mock()
        sock.fileno.return_value = 10
        err = sock.recv.side_effect = OSError()

        self.loop._sock_recv(f, None, sock, 1024)
        self.assertIs(err, f.exception())

    def test_sock_sendall(self):
        sock = test_utils.mock_nonblocking_socket()
        self.loop._sock_sendall = mock.Mock()

        f = self.loop.create_task(
            self.loop.sock_sendall(sock, b'data'))

        self.loop.run_until_complete(asyncio.sleep(0.01, loop=self.loop))

        self.assertEqual(
            (None, sock, b'data'),
            self.loop._sock_sendall.call_args[0][1:])

        f.cancel()
        with self.assertRaises(asyncio.CancelledError):
            self.loop.run_until_complete(f)

    def test_sock_sendall_nodata(self):
        sock = test_utils.mock_nonblocking_socket()
        self.loop._sock_sendall = mock.Mock()

        f = self.loop.create_task(self.loop.sock_sendall(sock, b''))
        self.loop.run_until_complete(asyncio.sleep(0, loop=self.loop))

        self.assertTrue(f.done())
        self.assertIsNone(f.result())
        self.assertFalse(self.loop._sock_sendall.called)

    def test_sock_sendall_reconnection(self):
        sock = mock.Mock()
        sock.fileno.return_value = 10
        sock.send.side_effect = BlockingIOError
        sock.gettimeout.return_value = 0.0

        self.loop.add_writer = mock.Mock()
        self.loop.remove_writer = mock.Mock()
        fut = self.loop.create_task(self.loop.sock_sendall(sock, b'data'))

        self.loop.run_until_complete(asyncio.sleep(0.01, loop=self.loop))

        callback = self.loop.add_writer.call_args[0][1]
        params = self.loop.add_writer.call_args[0][2:]

        # emulate the old socket has closed, but the new one has
        # the same fileno, so callback is called with old (closed) socket
        sock.fileno.return_value = -1
        sock.send.side_effect = OSError(9)
        callback(*params)

        self.loop.run_until_complete(asyncio.sleep(0.01, loop=self.loop))

        self.assertIsInstance(fut.exception(), OSError)
        self.assertEqual((10,), self.loop.remove_writer.call_args[0])

    def test__sock_sendall_canceled_fut(self):
        sock = mock.Mock()

        f = asyncio.Future(loop=self.loop)
        f.cancel()

        self.loop._sock_sendall(f, None, sock, b'data')
        self.assertFalse(sock.send.called)

    def test__sock_sendall_unregister(self):
        sock = mock.Mock()
        sock.fileno.return_value = 10

        f = asyncio.Future(loop=self.loop)
        f.cancel()

        self.loop.remove_writer = mock.Mock()
        self.loop._sock_sendall(f, 10, sock, b'data')
        self.assertEqual((10,), self.loop.remove_writer.call_args[0])

    def test__sock_sendall_tryagain(self):
        f = asyncio.Future(loop=self.loop)
        sock = mock.Mock()
        sock.fileno.return_value = 10
        sock.send.side_effect = BlockingIOError

        self.loop.add_writer = mock.Mock()
        self.loop._sock_sendall(f, None, sock, b'data')
        self.assertEqual(
            (10, self.loop._sock_sendall, f, 10, sock, b'data'),
            self.loop.add_writer.call_args[0])

    def test__sock_sendall_interrupted(self):
        f = asyncio.Future(loop=self.loop)
        sock = mock.Mock()
        sock.fileno.return_value = 10
        sock.send.side_effect = InterruptedError

        self.loop.add_writer = mock.Mock()
        self.loop._sock_sendall(f, None, sock, b'data')
        self.assertEqual(
            (10, self.loop._sock_sendall, f, 10, sock, b'data'),
            self.loop.add_writer.call_args[0])

    def test__sock_sendall_exception(self):
        f = asyncio.Future(loop=self.loop)
        sock = mock.Mock()
        sock.fileno.return_value = 10
        err = sock.send.side_effect = OSError()

        self.loop._sock_sendall(f, None, sock, b'data')
        self.assertIs(f.exception(), err)

    def test__sock_sendall(self):
        sock = mock.Mock()

        f = asyncio.Future(loop=self.loop)
        sock.fileno.return_value = 10
        sock.send.return_value = 4

        self.loop._sock_sendall(f, None, sock, b'data')
        self.assertTrue(f.done())
        self.assertIsNone(f.result())

    def test__sock_sendall_partial(self):
        sock = mock.Mock()

        f = asyncio.Future(loop=self.loop)
        sock.fileno.return_value = 10
        sock.send.return_value = 2

        self.loop.add_writer = mock.Mock()
        self.loop._sock_sendall(f, None, sock, b'data')
        self.assertFalse(f.done())
        # Only the unsent tail (b'ta') is re-registered for writing.
        self.assertEqual(
            (10, self.loop._sock_sendall, f, 10, sock, b'ta'),
            self.loop.add_writer.call_args[0])

    def test__sock_sendall_none(self):
        sock = mock.Mock()

        f = asyncio.Future(loop=self.loop)
        sock.fileno.return_value = 10
        sock.send.return_value = 0

        self.loop.add_writer = mock.Mock()
        self.loop._sock_sendall(f, None, sock, b'data')
        self.assertFalse(f.done())
        self.assertEqual(
            (10, self.loop._sock_sendall, f, 10, sock, b'data'),
            self.loop.add_writer.call_args[0])

    def test_sock_connect_timeout(self):
        # asyncio issue #205: sock_connect() must unregister the socket on
        # timeout error

        # prepare mocks
        self.loop.add_writer = mock.Mock()
        self.loop.remove_writer = mock.Mock()
        sock = test_utils.mock_nonblocking_socket()
        sock.connect.side_effect = BlockingIOError

        # first call to sock_connect() registers the socket
        fut = self.loop.create_task(
            self.loop.sock_connect(sock, ('127.0.0.1', 80)))
        self.loop._run_once()
        self.assertTrue(sock.connect.called)
        self.assertTrue(self.loop.add_writer.called)

        # on timeout, the socket must be unregistered
        sock.connect.reset_mock()
        fut.cancel()
        with self.assertRaises(asyncio.CancelledError):
            self.loop.run_until_complete(fut)
        self.assertTrue(self.loop.remove_writer.called)

    @mock.patch('socket.getaddrinfo')
    def test_sock_connect_resolve_using_socket_params(self, m_gai):
        addr = ('need-resolution.com', 8080)
        sock = test_utils.mock_nonblocking_socket()

        m_gai.side_effect = (
            lambda *args: [(None, None, None, None, ('127.0.0.1', 0))])

        con = self.loop.create_task(self.loop.sock_connect(sock, addr))
        self.loop.run_until_complete(con)
        # Resolution must use the socket's own family/type/proto.
        m_gai.assert_called_with(
            addr[0], addr[1], sock.family, sock.type, sock.proto, 0)

        # The task has already completed, so the connect call can be
        # checked without running the loop again.
        sock.connect.assert_called_with(('127.0.0.1', 0))

    def test__sock_connect(self):
        f = asyncio.Future(loop=self.loop)

        sock = mock.Mock()
        sock.fileno.return_value = 10

        resolved = self.loop.create_future()
        resolved.set_result([(socket.AF_INET, socket.SOCK_STREAM,
                              socket.IPPROTO_TCP, '', ('127.0.0.1', 8080))])
        self.loop._sock_connect(f, sock, resolved)
        self.assertTrue(f.done())
        self.assertIsNone(f.result())
        self.assertTrue(sock.connect.called)

    def test__sock_connect_cb_cancelled_fut(self):
        sock = mock.Mock()
        self.loop.remove_writer = mock.Mock()

        f = asyncio.Future(loop=self.loop)
        f.cancel()

        self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
        self.assertFalse(sock.getsockopt.called)

    def test__sock_connect_writer(self):
        # check that the fd is registered and then unregistered
        self.loop._process_events = mock.Mock()
        self.loop.add_writer = mock.Mock()
        self.loop.remove_writer = mock.Mock()

        sock = mock.Mock()
        sock.fileno.return_value = 10
        sock.connect.side_effect = BlockingIOError
        sock.getsockopt.return_value = 0
        address = ('127.0.0.1', 8080)
        resolved = self.loop.create_future()
        resolved.set_result([(socket.AF_INET, socket.SOCK_STREAM,
                              socket.IPPROTO_TCP, '', address)])

        f = asyncio.Future(loop=self.loop)
        self.loop._sock_connect(f, sock, resolved)
        self.loop._run_once()
        self.assertTrue(self.loop.add_writer.called)
        self.assertEqual(10, self.loop.add_writer.call_args[0][0])

        self.loop._sock_connect_cb(f, sock, address)
        # need to run the event loop to execute _sock_connect_done() callback
        self.loop.run_until_complete(f)
        self.assertEqual((10,), self.loop.remove_writer.call_args[0])

    def test__sock_connect_cb_tryagain(self):
        f = asyncio.Future(loop=self.loop)
        sock = mock.Mock()
        sock.fileno.return_value = 10
        sock.getsockopt.return_value = errno.EAGAIN

        # SO_ERROR == EAGAIN: the error must be handled (not raised) and
        # the future must be left pending rather than failed.
        self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
        self.assertFalse(f.done())

    def test__sock_connect_cb_exception(self):
        f = asyncio.Future(loop=self.loop)
        sock = mock.Mock()
        sock.fileno.return_value = 10
        sock.getsockopt.return_value = errno.ENOTCONN

        self.loop.remove_writer = mock.Mock()
        self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
        self.assertIsInstance(f.exception(), OSError)

    def test_sock_accept(self):
        sock = test_utils.mock_nonblocking_socket()
        self.loop._sock_accept = mock.Mock()

        f = self.loop.create_task(self.loop.sock_accept(sock))
        self.loop.run_until_complete(asyncio.sleep(0.01, loop=self.loop))

        self.assertFalse(self.loop._sock_accept.call_args[0][1])
        self.assertIs(self.loop._sock_accept.call_args[0][2], sock)

        f.cancel()
        with self.assertRaises(asyncio.CancelledError):
            self.loop.run_until_complete(f)

    def test__sock_accept(self):
        f = asyncio.Future(loop=self.loop)

        conn = mock.Mock()

        sock = mock.Mock()
        sock.fileno.return_value = 10
        sock.accept.return_value = conn, ('127.0.0.1', 1000)

        self.loop._sock_accept(f, False, sock)
        self.assertTrue(f.done())
        self.assertEqual((conn, ('127.0.0.1', 1000)), f.result())
        self.assertEqual((False,), conn.setblocking.call_args[0])

    def test__sock_accept_canceled_fut(self):
        sock = mock.Mock()

        f = asyncio.Future(loop=self.loop)
        f.cancel()

        self.loop._sock_accept(f, False, sock)
        self.assertFalse(sock.accept.called)

    def test__sock_accept_unregister(self):
        sock = mock.Mock()
        sock.fileno.return_value = 10

        f = asyncio.Future(loop=self.loop)
        f.cancel()

        self.loop.remove_reader = mock.Mock()
        self.loop._sock_accept(f, True, sock)
        self.assertEqual((10,), self.loop.remove_reader.call_args[0])

    def test__sock_accept_tryagain(self):
        f = asyncio.Future(loop=self.loop)
        sock = mock.Mock()
        sock.fileno.return_value = 10
        sock.accept.side_effect = BlockingIOError

        self.loop.add_reader = mock.Mock()
        self.loop._sock_accept(f, False, sock)
        self.assertEqual(
            (10, self.loop._sock_accept, f, True, sock),
            self.loop.add_reader.call_args[0])

    def test__sock_accept_exception(self):
        f = asyncio.Future(loop=self.loop)
        sock = mock.Mock()
        sock.fileno.return_value = 10
        err = sock.accept.side_effect = OSError()

        self.loop._sock_accept(f, False, sock)
        self.assertIs(err, f.exception())

    def test_add_reader(self):
        self.loop._selector.get_key.side_effect = KeyError
        cb = lambda: True
        self.loop.add_reader(1, cb)

        self.assertTrue(self.loop._selector.register.called)
        fd, mask, (r, w) = self.loop._selector.register.call_args[0]
        self.assertEqual(1, fd)
        self.assertEqual(selectors.EVENT_READ, mask)
        self.assertEqual(cb, r._callback)
        self.assertIsNone(w)

    def test_add_reader_existing(self):
        reader = mock.Mock()
        writer = mock.Mock()
        self.loop._selector.get_key.return_value = selectors.SelectorKey(
            1, 1, selectors.EVENT_WRITE, (reader, writer))
        cb = lambda: True
        self.loop.add_reader(1, cb)

        self.assertTrue(reader.cancel.called)
        self.assertFalse(self.loop._selector.register.called)
        self.assertTrue(self.loop._selector.modify.called)
        fd, mask, (r, w) = self.loop._selector.modify.call_args[0]
        self.assertEqual(1, fd)
        self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask)
        self.assertEqual(cb, r._callback)
        self.assertEqual(writer, w)

    def test_add_reader_existing_writer(self):
        writer = mock.Mock()
        self.loop._selector.get_key.return_value = selectors.SelectorKey(
            1, 1, selectors.EVENT_WRITE, (None, writer))
        cb = lambda: True
        self.loop.add_reader(1, cb)

        self.assertFalse(self.loop._selector.register.called)
        self.assertTrue(self.loop._selector.modify.called)
        fd, mask, (r, w) = self.loop._selector.modify.call_args[0]
        self.assertEqual(1, fd)
        self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask)
        self.assertEqual(cb, r._callback)
        self.assertEqual(writer, w)

    def test_remove_reader(self):
        self.loop._selector.get_key.return_value = selectors.SelectorKey(
            1, 1, selectors.EVENT_READ, (None, None))
        self.assertFalse(self.loop.remove_reader(1))

        self.assertTrue(self.loop._selector.unregister.called)

    def test_remove_reader_read_write(self):
        reader = mock.Mock()
        writer = mock.Mock()
        self.loop._selector.get_key.return_value = selectors.SelectorKey(
            1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE,
            (reader, writer))
        self.assertTrue(
            self.loop.remove_reader(1))

        self.assertFalse(self.loop._selector.unregister.called)
        self.assertEqual(
            (1, selectors.EVENT_WRITE, (None, writer)),
            self.loop._selector.modify.call_args[0])

    def test_remove_reader_unknown(self):
        self.loop._selector.get_key.side_effect = KeyError
        self.assertFalse(
            self.loop.remove_reader(1))

    def test_add_writer(self):
        self.loop._selector.get_key.side_effect = KeyError
        cb = lambda: True
        self.loop.add_writer(1, cb)

        self.assertTrue(self.loop._selector.register.called)
        fd, mask, (r, w) = self.loop._selector.register.call_args[0]
        self.assertEqual(1, fd)
        self.assertEqual(selectors.EVENT_WRITE, mask)
        self.assertIsNone(r)
        self.assertEqual(cb, w._callback)

    def test_add_writer_existing(self):
        reader = mock.Mock()
        writer = mock.Mock()
        self.loop._selector.get_key.return_value = selectors.SelectorKey(
            1, 1, selectors.EVENT_READ, (reader, writer))
        cb = lambda: True
        self.loop.add_writer(1, cb)

        self.assertTrue(writer.cancel.called)
        self.assertFalse(self.loop._selector.register.called)
        self.assertTrue(self.loop._selector.modify.called)
        fd, mask, (r, w) = self.loop._selector.modify.call_args[0]
        self.assertEqual(1, fd)
        self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask)
        self.assertEqual(reader, r)
        self.assertEqual(cb, w._callback)

    def test_remove_writer(self):
        self.loop._selector.get_key.return_value = selectors.SelectorKey(
            1, 1, selectors.EVENT_WRITE, (None, None))
        self.assertFalse(self.loop.remove_writer(1))

        self.assertTrue(self.loop._selector.unregister.called)

    def test_remove_writer_read_write(self):
        reader = mock.Mock()
        writer = mock.Mock()
        self.loop._selector.get_key.return_value = selectors.SelectorKey(
            1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE,
            (reader, writer))
        self.assertTrue(
            self.loop.remove_writer(1))

        self.assertFalse(self.loop._selector.unregister.called)
        self.assertEqual(
            (1, selectors.EVENT_READ, (reader, None)),
            self.loop._selector.modify.call_args[0])

    def test_remove_writer_unknown(self):
        self.loop._selector.get_key.side_effect = KeyError
        self.assertFalse(
            self.loop.remove_writer(1))

    def test_process_events_read(self):
        reader = mock.Mock()
        reader._cancelled = False

        self.loop._add_callback = mock.Mock()
        self.loop._process_events(
            [(selectors.SelectorKey(
                1, 1, selectors.EVENT_READ, (reader, None)),
              selectors.EVENT_READ)])
        self.assertTrue(self.loop._add_callback.called)
        self.loop._add_callback.assert_called_with(reader)

    def test_process_events_read_cancelled(self):
        reader = mock.Mock()
        # Mark the handle cancelled through the same attribute that
        # test_process_events_read clears (_cancelled). Setting the public
        # name instead would leave _cancelled as a truthy auto-created Mock,
        # so the test would only pass by accident.
        reader._cancelled = True

        self.loop._remove_reader = mock.Mock()
        self.loop._process_events(
            [(selectors.SelectorKey(
                1, 1, selectors.EVENT_READ, (reader, None)),
              selectors.EVENT_READ)])
        self.loop._remove_reader.assert_called_with(1)

    def test_process_events_write(self):
        writer = mock.Mock()
        writer._cancelled = False

        self.loop._add_callback = mock.Mock()
        self.loop._process_events(
            [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE,
                                    (None, writer)),
              selectors.EVENT_WRITE)])
        self.loop._add_callback.assert_called_with(writer)

    def test_process_events_write_cancelled(self):
        writer = mock.Mock()
        # See test_process_events_read_cancelled: set the private flag
        # explicitly rather than relying on Mock truthiness.
        writer._cancelled = True
        self.loop._remove_writer = mock.Mock()

        self.loop._process_events(
            [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE,
                                    (None, writer)),
              selectors.EVENT_WRITE)])
        self.loop._remove_writer.assert_called_with(1)

    def test_accept_connection_multiple(self):
        sock = mock.Mock()
        sock.accept.return_value = (mock.Mock(), mock.Mock())
        backlog = 100
        # Mock the coroutine generation for a connection to prevent
        # warnings related to un-awaited coroutines.
        mock_obj = mock.patch.object
        with mock_obj(self.loop, '_accept_connection2') as accept2_mock:
            accept2_mock.return_value = None
            with mock_obj(self.loop, 'create_task') as task_mock:
                task_mock.return_value = None
                self.loop._accept_connection(
                    mock.Mock(), sock, backlog=backlog)
        # One accept() per backlog slot.
        self.assertEqual(sock.accept.call_count, backlog)


class SelectorTransportTests(test_utils.TestCase):
    """Tests for the _SelectorTransport base class over a mocked socket."""

    def setUp(self):
        super().setUp()
        self.loop = self.new_test_loop()
        self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
        self.sock = mock.Mock(socket.socket)
        self.sock.fileno.return_value = 7

    def create_transport(self):
        """Return a _SelectorTransport over the mock socket.

        The transport's socket is closed again on test cleanup.
        """
        transport = _SelectorTransport(self.loop, self.sock, self.protocol,
                                       None)
        self.addCleanup(close_transport, transport)
        return transport

    def test_ctor(self):
        tr = self.create_transport()
        self.assertIs(tr._loop, self.loop)
        self.assertIs(tr._sock, self.sock)
        # Compare the fd by value; ``assertIs`` on ints only passes thanks
        # to CPython's small-int cache.
        self.assertEqual(tr._sock_fd, 7)

    def test_abort(self):
        tr = self.create_transport()
        tr._force_close = mock.Mock()

        tr.abort()
        tr._force_close.assert_called_with(None)

    def test_close(self):
        tr = self.create_transport()
        tr.close()

        self.assertTrue(tr.is_closing())
        self.assertEqual(1, self.loop.remove_reader_count[7])
        self.protocol.connection_lost(None)
        self.assertEqual(tr._conn_lost, 1)

        # A second close() must not count another loss or remove the
        # reader again.
        tr.close()
        self.assertEqual(tr._conn_lost, 1)
        self.assertEqual(1, self.loop.remove_reader_count[7])

    def test_close_write_buffer(self):
        tr = self.create_transport()
        tr._buffer.extend(b'data')
        tr.close()

        self.assertFalse(self.loop.readers)
        test_utils.run_briefly(self.loop)
        # With data still buffered, connection_lost() must not fire yet.
        self.assertFalse(self.protocol.connection_lost.called)

    def test_force_close(self):
        tr = self.create_transport()
        tr._buffer.extend(b'1')
        self.loop._add_reader(7, mock.sentinel)
        self.loop._add_writer(7, mock.sentinel)
        tr._force_close(None)

        self.assertTrue(tr.is_closing())
        self.assertEqual(tr._buffer, list_to_buffer())
        self.assertFalse(self.loop.readers)
        self.assertFalse(self.loop.writers)

        # second close should not remove reader
        tr._force_close(None)
        self.assertFalse(self.loop.readers)
        self.assertEqual(1, self.loop.remove_reader_count[7])

    @mock.patch('asyncio.log.logger.error')
    def test_fatal_error(self, m_exc):
        exc = OSError()
        tr = self.create_transport()
        tr._force_close = mock.Mock()
        tr._fatal_error(exc)

        m_exc.assert_called_with(
            test_utils.MockPattern(
                'Fatal error on transport\nprotocol:.*\ntransport:.*'),
            exc_info=(OSError, MOCK_ANY, MOCK_ANY))

        tr._force_close.assert_called_with(exc)

    def test_connection_lost(self):
        exc = OSError()
        tr = self.create_transport()
        self.assertIsNotNone(tr._protocol)
        self.assertIsNotNone(tr._loop)
        tr._call_connection_lost(exc)

        self.protocol.connection_lost.assert_called_with(exc)
        self.sock.close.assert_called_with()
        self.assertIsNone(tr._sock)

        # The transport drops its protocol and loop references afterwards.
        self.assertIsNone(tr._protocol)
        self.assertIsNone(tr._loop)


class SelectorSocketTransportTests(test_utils.TestCase):
    """Tests for _SelectorSocketTransport: stream reads, writes and EOF."""

    def setUp(self):
        super().setUp()
        self.loop = self.new_test_loop()
        self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
        self.sock = mock.Mock(socket.socket)
        self.sock_fd = self.sock.fileno.return_value = 7

    def socket_transport(self, waiter=None):
        """Return a socket transport over the mock socket.

        The transport's socket is closed again on test cleanup.
        """
        transport = _SelectorSocketTransport(self.loop, self.sock,
                                             self.protocol, waiter=waiter)
        self.addCleanup(close_transport, transport)
        return transport

    def test_ctor(self):
        waiter = asyncio.Future(loop=self.loop)
        tr = self.socket_transport(waiter=waiter)
        self.loop.run_until_complete(waiter)

        self.loop.assert_reader(7, tr._read_ready)
        test_utils.run_briefly(self.loop)
        self.protocol.connection_made.assert_called_with(tr)

    def test_ctor_with_waiter(self):
        waiter = asyncio.Future(loop=self.loop)
        self.socket_transport(waiter=waiter)
        self.loop.run_until_complete(waiter)

        self.assertIsNone(waiter.result())

    def test_pause_resume_reading(self):
        tr = self.socket_transport()
        test_utils.run_briefly(self.loop)
        self.assertFalse(tr._paused)
        self.assertTrue(tr.is_reading())
        self.loop.assert_reader(7, tr._read_ready)

        # Pausing twice in a row must be harmless.
        tr.pause_reading()
        tr.pause_reading()
        self.assertTrue(tr._paused)
        self.assertFalse(tr.is_reading())
        self.loop.assert_no_reader(7)

        # So must resuming twice in a row.
        tr.resume_reading()
        tr.resume_reading()
        self.assertFalse(tr._paused)
        self.assertTrue(tr.is_reading())
        self.loop.assert_reader(7, tr._read_ready)

        tr.close()
        self.assertFalse(tr.is_reading())
        self.loop.assert_no_reader(7)

    def test_read_ready(self):
        transport = self.socket_transport()

        self.sock.recv.return_value = b'data'
        transport._read_ready()

        self.protocol.data_received.assert_called_with(b'data')

    def test_read_ready_eof(self):
        transport = self.socket_transport()
        transport.close = mock.Mock()

        self.sock.recv.return_value = b''
        transport._read_ready()

        self.protocol.eof_received.assert_called_with()
        transport.close.assert_called_with()

    def test_read_ready_eof_keep_open(self):
        transport = self.socket_transport()
        transport.close = mock.Mock()

        self.sock.recv.return_value = b''
        # A true result from eof_received() keeps the transport open.
        self.protocol.eof_received.return_value = True
        transport._read_ready()

        self.protocol.eof_received.assert_called_with()
        self.assertFalse(transport.close.called)

    @mock.patch('logging.exception')
    def test_read_ready_tryagain(self, m_exc):
        self.sock.recv.side_effect = BlockingIOError

        transport = self.socket_transport()
        transport._fatal_error = mock.Mock()
        transport._read_ready()

        self.assertFalse(transport._fatal_error.called)

    @mock.patch('logging.exception')
    def test_read_ready_tryagain_interrupted(self, m_exc):
        self.sock.recv.side_effect = InterruptedError

        transport = self.socket_transport()
        transport._fatal_error = mock.Mock()
        transport._read_ready()

        self.assertFalse(transport._fatal_error.called)

    @mock.patch('logging.exception')
    def test_read_ready_conn_reset(self, m_exc):
        err = self.sock.recv.side_effect = ConnectionResetError()

        transport = self.socket_transport()
        transport._force_close = mock.Mock()
        with test_utils.disable_logger():
            transport._read_ready()
        transport._force_close.assert_called_with(err)

    @mock.patch('logging.exception')
    def test_read_ready_err(self, m_exc):
        err = self.sock.recv.side_effect = OSError()

        transport = self.socket_transport()
        transport._fatal_error = mock.Mock()
        transport._read_ready()

        transport._fatal_error.assert_called_with(
            err, 'Fatal read error on socket transport')

    def test_write(self):
        data = b'data'
        self.sock.send.return_value = len(data)

        transport = self.socket_transport()
        transport.write(data)
        self.sock.send.assert_called_with(data)

    def test_write_bytearray(self):
        data = bytearray(b'data')
        self.sock.send.return_value = len(data)

        transport = self.socket_transport()
        transport.write(data)
        self.sock.send.assert_called_with(data)
        self.assertEqual(data, bytearray(b'data'))  # Hasn't been mutated.

    def test_write_memoryview(self):
        data = memoryview(b'data')
        self.sock.send.return_value = len(data)

        transport = self.socket_transport()
        transport.write(data)
        self.sock.send.assert_called_with(data)

    def test_write_no_data(self):
        transport = self.socket_transport()
        transport._buffer.extend(b'data')
        transport.write(b'')
        self.assertFalse(self.sock.send.called)
        self.assertEqual(list_to_buffer([b'data']), transport._buffer)

    def test_write_buffer(self):
        transport = self.socket_transport()
        transport._buffer.extend(b'data1')
        transport.write(b'data2')
        self.assertFalse(self.sock.send.called)
        self.assertEqual(list_to_buffer([b'data1', b'data2']),
                         transport._buffer)

    def test_write_partial(self):
        data = b'data'
        self.sock.send.return_value = 2

        transport = self.socket_transport()
        transport.write(data)

        self.loop.assert_writer(7, transport._write_ready)
        self.assertEqual(list_to_buffer([b'ta']), transport._buffer)

    def test_write_partial_bytearray(self):
        data = bytearray(b'data')
        self.sock.send.return_value = 2

        transport = self.socket_transport()
        transport.write(data)

        self.loop.assert_writer(7, transport._write_ready)
        self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
        self.assertEqual(data, bytearray(b'data'))  # Hasn't been mutated.

    def test_write_partial_memoryview(self):
        data = memoryview(b'data')
        self.sock.send.return_value = 2

        transport = self.socket_transport()
        transport.write(data)

        self.loop.assert_writer(7, transport._write_ready)
        self.assertEqual(list_to_buffer([b'ta']), transport._buffer)

    def test_write_partial_none(self):
        data = b'data'
        self.sock.send.return_value = 0

        transport = self.socket_transport()
        transport.write(data)

        self.loop.assert_writer(7, transport._write_ready)
        self.assertEqual(list_to_buffer([b'data']), transport._buffer)

    def test_write_tryagain(self):
        self.sock.send.side_effect = BlockingIOError

        data = b'data'
        transport = self.socket_transport()
        transport.write(data)

        self.loop.assert_writer(7, transport._write_ready)
        self.assertEqual(list_to_buffer([b'data']), transport._buffer)

    @mock.patch('asyncio.selector_events.logger')
    def test_write_exception(self, m_log):
        err = self.sock.send.side_effect = OSError()

        data = b'data'
        transport = self.socket_transport()
        transport._fatal_error = mock.Mock()
        transport.write(data)
        transport._fatal_error.assert_called_with(
            err, 'Fatal write error on socket transport')
        transport._conn_lost = 1

        # Once the connection is marked lost, writes are dropped (counted
        # in _conn_lost) and repeated drops log a warning.
        self.sock.reset_mock()
        transport.write(data)
        self.assertFalse(self.sock.send.called)
        self.assertEqual(transport._conn_lost, 2)
        transport.write(data)
        transport.write(data)
        transport.write(data)
        transport.write(data)
        m_log.warning.assert_called_with('socket.send() raised exception.')

    def test_write_str(self):
        transport = self.socket_transport()
        self.assertRaises(TypeError, transport.write, 'str')

    def test_write_closing(self):
        transport = self.socket_transport()
        transport.close()
        self.assertEqual(transport._conn_lost, 1)
        transport.write(b'data')
        self.assertEqual(transport._conn_lost, 2)

    def test_write_ready(self):
        data = b'data'
        self.sock.send.return_value = len(data)

        transport = self.socket_transport()
        transport._buffer.extend(data)
        self.loop._add_writer(7, transport._write_ready)
        transport._write_ready()
        self.assertTrue(self.sock.send.called)
        self.assertFalse(self.loop.writers)

    def test_write_ready_closing(self):
        data = b'data'
        self.sock.send.return_value = len(data)

        transport = self.socket_transport()
        transport._closing = True
        transport._buffer.extend(data)
        self.loop._add_writer(7, transport._write_ready)
        transport._write_ready()
        self.assertTrue(self.sock.send.called)
        self.assertFalse(self.loop.writers)
        self.sock.close.assert_called_with()
        self.protocol.connection_lost.assert_called_with(None)

    def test_write_ready_no_data(self):
        transport = self.socket_transport()
        # This is an internal error.
        self.assertRaises(AssertionError, transport._write_ready)

    def test_write_ready_partial(self):
        data = b'data'
        self.sock.send.return_value = 2

        transport = self.socket_transport()
        transport._buffer.extend(data)
        self.loop._add_writer(7, transport._write_ready)
        transport._write_ready()
        self.loop.assert_writer(7, transport._write_ready)
        self.assertEqual(list_to_buffer([b'ta']), transport._buffer)

    def test_write_ready_partial_none(self):
        data = b'data'
        self.sock.send.return_value = 0

        transport = self.socket_transport()
        transport._buffer.extend(data)
        self.loop._add_writer(7, transport._write_ready)
        transport._write_ready()
        self.loop.assert_writer(7, transport._write_ready)
        self.assertEqual(list_to_buffer([b'data']), transport._buffer)

    def test_write_ready_tryagain(self):
        self.sock.send.side_effect = BlockingIOError

        transport = self.socket_transport()
        transport._buffer = list_to_buffer([b'data1', b'data2'])
        self.loop._add_writer(7, transport._write_ready)
        transport._write_ready()

        self.loop.assert_writer(7, transport._write_ready)
        self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer)

    def test_write_ready_exception(self):
        err = self.sock.send.side_effect = OSError()

        transport = self.socket_transport()
        transport._fatal_error = mock.Mock()
        transport._buffer.extend(b'data')
        transport._write_ready()
        transport._fatal_error.assert_called_with(
            err, 'Fatal write error on socket transport')

    def test_write_eof(self):
        tr = self.socket_transport()
        self.assertTrue(tr.can_write_eof())
        tr.write_eof()
        self.sock.shutdown.assert_called_with(socket.SHUT_WR)
        # A second write_eof() must not shut the socket down again.
        tr.write_eof()
        self.assertEqual(self.sock.shutdown.call_count, 1)
        tr.close()

    def test_write_eof_buffer(self):
        tr = self.socket_transport()
        self.sock.send.side_effect = BlockingIOError
        tr.write(b'data')
        tr.write_eof()
        # The shutdown is deferred until the buffered data is flushed.
        self.assertEqual(tr._buffer, list_to_buffer([b'data']))
        self.assertTrue(tr._eof)
        self.assertFalse(self.sock.shutdown.called)
        self.sock.send.side_effect = lambda _: 4
        tr._write_ready()
        self.assertTrue(self.sock.send.called)
        self.sock.shutdown.assert_called_with(socket.SHUT_WR)
        tr.close()

    @mock.patch('asyncio.base_events.logger')
    def test_transport_close_remove_writer(self, m_log):
        remove_writer = self.loop._remove_writer = mock.Mock()

        transport = self.socket_transport()
        transport.close()
        remove_writer.assert_called_with(self.sock_fd)


class SelectorDatagramTransportTests(test_utils.TestCase):
    """Tests for _SelectorDatagramTransport: recvfrom/sendto paths."""

    def setUp(self):
        super().setUp()
        self.loop = self.new_test_loop()
        self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
        self.sock = mock.Mock(spec_set=socket.socket)
        self.sock.fileno.return_value = 7

    def datagram_transport(self, address=None):
        """Return a datagram transport, optionally bound to *address*.

        The transport's socket is closed again on test cleanup.
        """
        transport = _SelectorDatagramTransport(self.loop, self.sock,
                                               self.protocol,
                                               address=address)
        self.addCleanup(close_transport, transport)
        return transport

    def test_read_ready(self):
        transport = self.datagram_transport()

        self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234))
        transport._read_ready()

        self.protocol.datagram_received.assert_called_with(
            b'data', ('0.0.0.0', 1234))

    def test_read_ready_tryagain(self):
        transport = self.datagram_transport()

        self.sock.recvfrom.side_effect = BlockingIOError
        transport._fatal_error = mock.Mock()
        transport._read_ready()

        self.assertFalse(transport._fatal_error.called)

    def test_read_ready_err(self):
        transport = self.datagram_transport()

        err = self.sock.recvfrom.side_effect = RuntimeError()
        transport._fatal_error = mock.Mock()
        transport._read_ready()

        transport._fatal_error.assert_called_with(
            err, 'Fatal read error on datagram transport')

    def test_read_ready_oserr(self):
        transport = self.datagram_transport()

        # An OSError from recvfrom() goes to error_received(), not fatal.
        err = self.sock.recvfrom.side_effect = OSError()
        transport._fatal_error = mock.Mock()
        transport._read_ready()

        self.assertFalse(transport._fatal_error.called)
        self.protocol.error_received.assert_called_with(err)

    def test_sendto(self):
        data = b'data'
        transport = self.datagram_transport()
        transport.sendto(data, ('0.0.0.0', 1234))
        self.assertTrue(self.sock.sendto.called)
        self.assertEqual(
            self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))

    def test_sendto_bytearray(self):
        data = bytearray(b'data')
        transport = self.datagram_transport()
        transport.sendto(data, ('0.0.0.0', 1234))
        self.assertTrue(self.sock.sendto.called)
        self.assertEqual(
            self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))

    def test_sendto_memoryview(self):
        data = memoryview(b'data')
        transport = self.datagram_transport()
        transport.sendto(data, ('0.0.0.0', 1234))
        self.assertTrue(self.sock.sendto.called)
        self.assertEqual(
            self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))

    def test_sendto_no_data(self):
        transport = self.datagram_transport()
        transport._buffer.append((b'data', ('0.0.0.0', 12345)))
        transport.sendto(b'', ())
        self.assertFalse(self.sock.sendto.called)
        self.assertEqual(
            [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))

    def test_sendto_buffer(self):
        transport = self.datagram_transport()
        transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
        transport.sendto(b'data2', ('0.0.0.0', 12345))
        self.assertFalse(self.sock.sendto.called)
        self.assertEqual(
            [(b'data1', ('0.0.0.0', 12345)),
             (b'data2', ('0.0.0.0', 12345))],
            list(transport._buffer))

    def test_sendto_buffer_bytearray(self):
        data2 = bytearray(b'data2')
        transport = self.datagram_transport()
        transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
        transport.sendto(data2, ('0.0.0.0', 12345))
        self.assertFalse(self.sock.sendto.called)
        self.assertEqual(
            [(b'data1', ('0.0.0.0', 12345)),
             (b'data2', ('0.0.0.0', 12345))],
            list(transport._buffer))
        # Buffered payloads are copied to immutable bytes.
        self.assertIsInstance(transport._buffer[1][0], bytes)

    def test_sendto_buffer_memoryview(self):
        data2 = memoryview(b'data2')
        transport = self.datagram_transport()
        transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
        transport.sendto(data2, ('0.0.0.0', 12345))
        self.assertFalse(self.sock.sendto.called)
        self.assertEqual(
            [(b'data1', ('0.0.0.0', 12345)),
             (b'data2', ('0.0.0.0', 12345))],
            list(transport._buffer))
        # Buffered payloads are copied to immutable bytes.
        self.assertIsInstance(transport._buffer[1][0], bytes)

    def test_sendto_tryagain(self):
        data = b'data'

        self.sock.sendto.side_effect = BlockingIOError

        transport = self.datagram_transport()
        transport.sendto(data, ('0.0.0.0', 12345))

        self.loop.assert_writer(7, transport._sendto_ready)
        self.assertEqual(
            [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))

    @mock.patch('asyncio.selector_events.logger')
    def test_sendto_exception(self, m_log):
        data = b'data'
        err = self.sock.sendto.side_effect = RuntimeError()

        transport = self.datagram_transport()
        transport._fatal_error = mock.Mock()
        transport.sendto(data, ())

        self.assertTrue(transport._fatal_error.called)
        transport._fatal_error.assert_called_with(
            err, 'Fatal write error on datagram transport')
        transport._conn_lost = 1

        # After a fatal error, repeated sends are dropped and log a warning.
        transport._address = ('123',)
        transport.sendto(data)
        transport.sendto(data)
        transport.sendto(data)
        transport.sendto(data)
        transport.sendto(data)
        m_log.warning.assert_called_with('socket.send() raised exception.')

    def test_sendto_error_received(self):
        data = b'data'

        self.sock.sendto.side_effect = ConnectionRefusedError

        transport = self.datagram_transport()
        transport._fatal_error = mock.Mock()
        transport.sendto(data, ())

        self.assertEqual(transport._conn_lost, 0)
        self.assertFalse(transport._fatal_error.called)
        # The refusal is reported to the protocol, as in the connected case.
        self.assertTrue(self.protocol.error_received.called)

    def test_sendto_error_received_connected(self):
        data = b'data'

        self.sock.send.side_effect = ConnectionRefusedError

        transport = self.datagram_transport(address=('0.0.0.0', 1))
        transport._fatal_error = mock.Mock()
        transport.sendto(data)

        self.assertFalse(transport._fatal_error.called)
        self.assertTrue(self.protocol.error_received.called)

    def test_sendto_str(self):
        transport = self.datagram_transport()
        self.assertRaises(TypeError, transport.sendto, 'str', ())

    def test_sendto_connected_addr(self):
        transport = self.datagram_transport(address=('0.0.0.0', 1))
        self.assertRaises(
            ValueError, transport.sendto, b'str', ('0.0.0.0', 2))

    def test_sendto_closing(self):
        transport = self.datagram_transport(address=(1,))
        transport.close()
        self.assertEqual(transport._conn_lost, 1)
        transport.sendto(b'data', (1,))
        self.assertEqual(transport._conn_lost, 2)

    def test_sendto_ready(self):
        data = b'data'
        self.sock.sendto.return_value = len(data)

        transport = self.datagram_transport()
        transport._buffer.append((data, ('0.0.0.0', 12345)))
        self.loop._add_writer(7, transport._sendto_ready)
        transport._sendto_ready()
        self.assertTrue(self.sock.sendto.called)
        self.assertEqual(
            self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345)))
        self.assertFalse(self.loop.writers)

    def test_sendto_ready_closing(self):
        data = b'data'
        # The unconnected transport flushes through sendto() (asserted
        # below), so that is the method to configure, not send().
        self.sock.sendto.return_value = len(data)

        transport = self.datagram_transport()
        transport._closing = True
        transport._buffer.append((data, ()))
        self.loop._add_writer(7, transport._sendto_ready)
        transport._sendto_ready()
        self.sock.sendto.assert_called_with(data, ())
        self.assertFalse(self.loop.writers)
        self.sock.close.assert_called_with()
        self.protocol.connection_lost.assert_called_with(None)

    def test_sendto_ready_no_data(self):
        transport = self.datagram_transport()
        self.loop._add_writer(7, transport._sendto_ready)
        transport._sendto_ready()
        self.assertFalse(self.sock.sendto.called)
        self.assertFalse(self.loop.writers)

    def test_sendto_ready_tryagain(self):
        self.sock.sendto.side_effect = BlockingIOError

        transport = self.datagram_transport()
        transport._buffer.extend([(b'data1', ()), (b'data2', ())])
        self.loop._add_writer(7, transport._sendto_ready)
        transport._sendto_ready()

        self.loop.assert_writer(7, transport._sendto_ready)
        self.assertEqual(
            [(b'data1', ()), (b'data2', ())],
            list(transport._buffer))

    def test_sendto_ready_exception(self):
        err = self.sock.sendto.side_effect = RuntimeError()

        transport = self.datagram_transport()
        transport._fatal_error = mock.Mock()
        transport._buffer.append((b'data', ()))
        transport._sendto_ready()

        transport._fatal_error.assert_called_with(
            err, 'Fatal write error on datagram transport')

    def test_sendto_ready_error_received(self):
        self.sock.sendto.side_effect = ConnectionRefusedError

        transport = self.datagram_transport()
        transport._fatal_error = mock.Mock()
        transport._buffer.append((b'data', ()))
        transport._sendto_ready()

        self.assertFalse(transport._fatal_error.called)
        # The refusal is reported to the protocol, as in the connected case.
        self.assertTrue(self.protocol.error_received.called)

    def test_sendto_ready_error_received_connection(self):
        self.sock.send.side_effect = ConnectionRefusedError

        transport = self.datagram_transport(address=('0.0.0.0', 1))
        transport._fatal_error = mock.Mock()
        transport._buffer.append((b'data', ()))
        transport._sendto_ready()

        self.assertFalse(transport._fatal_error.called)
        self.assertTrue(self.protocol.error_received.called)

    @mock.patch('asyncio.base_events.logger.error')
    def test_fatal_error_connected(self, m_exc):
        transport = self.datagram_transport(address=('0.0.0.0', 1))
        err = ConnectionRefusedError()
        transport._fatal_error(err)
        self.assertFalse(self.protocol.error_received.called)
        m_exc.assert_called_with(
            test_utils.MockPattern(
                'Fatal error on transport\nprotocol:.*\ntransport:.*'),
            exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY))


class TestSelectorUtils(test_utils.TestCase):
    """Tests for module-level helpers of asyncio.selector_events."""

    def check_set_nodelay(self, sock):
        """Assert TCP_NODELAY is off initially and on after _set_nodelay()."""
        def nodelay_flag():
            return sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)

        self.assertFalse(nodelay_flag())
        _set_nodelay(sock)
        self.assertTrue(nodelay_flag())

    @unittest.skipUnless(hasattr(socket, 'TCP_NODELAY'),
                         'need socket.TCP_NODELAY')
    def test_set_nodelay(self):
        # Check a default (blocking) socket first, then a non-blocking one.
        for switch_to_nonblocking in (False, True):
            with socket.socket(family=socket.AF_INET,
                               type=socket.SOCK_STREAM,
                               proto=socket.IPPROTO_TCP) as tcp_sock:
                if switch_to_nonblocking:
                    tcp_sock.setblocking(False)
                self.check_set_nodelay(tcp_sock)


if __name__ == '__main__':
    # Run this module's test cases when executed directly as a script.
    unittest.main()