test_socket.py 19.7 KB
Newer Older
1
#!/usr/bin/env python
2

3
import unittest
4
from test import test_support
5 6

import socket
7
import select
8
import time
9 10
import thread, threading
import Queue
11
import sys
12

13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
# Address that every test server binds to and every client connects to.
HOST, PORT = 'localhost', 50007
# Payload exchanged between the client and server halves of most tests.
MSG = 'Michael Gilfix was here\n'

class SocketTCPTest(unittest.TestCase):
    """Fixture that provides a listening TCP server socket as self.serv."""

    def setUp(self):
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.serv = listener
        # SO_REUSEADDR so consecutive tests can bind the same PORT again.
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind((HOST, PORT))
        listener.listen(1)

    def tearDown(self):
        listener = self.serv
        listener.close()
        self.serv = None

class SocketUDPTest(unittest.TestCase):
    """Fixture that provides a bound UDP server socket as self.serv."""

    def setUp(self):
        receiver = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.serv = receiver
        # SO_REUSEADDR so consecutive tests can bind the same PORT again.
        receiver.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        receiver.bind((HOST, PORT))

    def tearDown(self):
        receiver = self.serv
        receiver.close()
        self.serv = None

class ThreadableTest:
    """Threadable Test class

    The ThreadableTest class makes it easy to create a threaded
    client/server pair from an existing unit test. To create a
    new threaded class from an existing unit test, use multiple
    inheritance:

        class NewClass (OldClass, ThreadableTest):
            pass

    This class defines two new fixture functions with obvious
    purposes for overriding:

        clientSetUp ()
        clientTearDown ()

    Any new test functions within the class must then define
    tests in pairs, where the test name is preceded with a
    '_' to indicate the client portion of the test. Ex:

        def testFoo(self):
            # Server portion

        def _testFoo(self):
            # Client portion

    Any exceptions raised by the clients during their tests
    (including exceptions raised by clientSetUp()) are caught and
    transferred to the main thread to alert the testing framework.

    Note, the server setup function cannot call any blocking
    functions that rely on the client thread during setup,
    unless serverExplicitReady() is called just before
    the blocking call (such as in setting up a client/server
    connection and performing the accept() in setUp()).
    """

    def __init__(self):
        # Swap the true setup function: keep the original setUp/tearDown
        # (name-mangled) and install the threaded wrappers in their place.
        self.__setUp = self.setUp
        self.__tearDown = self.tearDown
        self.setUp = self._setUp
        self.tearDown = self._tearDown

    def serverExplicitReady(self):
        """This method allows the server to explicitly indicate that
        it wants the client thread to proceed. This is useful if the
        server is about to execute a blocking routine that is
        dependent upon the client thread during its setup routine."""
        self.server_ready.set()

    def _setUp(self):
        """Start the client thread, then run the real server setUp()."""
        self.server_ready = threading.Event()
        self.client_ready = threading.Event()
        self.done = threading.Event()
        # Holds at most one exception reported by the client thread.
        self.queue = Queue.Queue(1)

        # Do some munging to start the client test: the client half of
        # "testFoo" is the method "_testFoo".
        methodname = self.id()
        i = methodname.rfind('.')
        methodname = methodname[i+1:]
        test_method = getattr(self, '_' + methodname)
        self.client_thread = thread.start_new_thread(
            self.clientRun, (test_method,))

        self.__setUp()
        # Release the client unless setUp() already did so explicitly.
        if not self.server_ready.isSet():
            self.server_ready.set()
        self.client_ready.wait()

    def _tearDown(self):
        """Run the real tearDown(), wait for the client, report its error."""
        self.__tearDown()
        self.done.wait()

        if not self.queue.empty():
            msg = self.queue.get()
            self.fail(msg)

    def clientRun(self, test_func):
        """Body of the client thread.

        Waits for the server side to be ready, then runs clientSetUp(),
        test_func() and clientTearDown().  An exception raised by
        clientSetUp() or by the test is put on self.queue so that
        _tearDown() can report it from the main thread.  clientTearDown()
        always runs and self.done is always set, so _tearDown() cannot
        block forever when the client side fails early.
        """
        self.server_ready.wait()
        self.client_ready.set()
        try:
            try:
                self.clientSetUp()
                if not callable(test_func):
                    raise TypeError("test_func must be a callable function")
                test_func()
            except Exception:
                # sys.exc_info() works on every Python version this file
                # targets, unlike the "except E, e" / "except E as e" forms.
                self.queue.put(sys.exc_info()[1])
        finally:
            try:
                self.clientTearDown()
            finally:
                # Setting the event twice is harmless; this covers overrides
                # of clientTearDown() that fail before chaining to ours.
                self.done.set()

    def clientSetUp(self):
        raise NotImplementedError("clientSetUp must be implemented.")

    def clientTearDown(self):
        # Signal the main thread, then terminate this client thread.
        self.done.set()
        thread.exit()

class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest):
    """SocketTCPTest whose client half runs in a separate thread.

    The client thread gets a fresh, unconnected TCP socket as self.cli.
    """

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        # Install ThreadableTest's threaded setUp/tearDown wrappers.
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        family, kind = socket.AF_INET, socket.SOCK_STREAM
        self.cli = socket.socket(family, kind)

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        # ThreadableTest's version signals completion and ends the thread.
        ThreadableTest.clientTearDown(self)
151

152
class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
    """SocketUDPTest whose client half runs in a separate thread.

    The client thread gets an unconnected UDP socket as self.cli, which
    is closed again in clientTearDown().
    """

    def __init__(self, methodName='runTest'):
        SocketUDPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

    def clientTearDown(self):
        # Close the client socket (it was previously never closed), then
        # let ThreadableTest signal completion and end the client thread.
        self.cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
160

161 162 163 164 165 166 167
class SocketConnectedTest(ThreadedTCPSocketTest):
    """Threaded TCP fixture with an already established connection.

    Server side: self.cli_conn is the socket returned by accept().
    Client side: self.serv_conn is the connected client socket.
    """

    def __init__(self, methodName='runTest'):
        ThreadedTCPSocketTest.__init__(self, methodName=methodName)

    def setUp(self):
        ThreadedTCPSocketTest.setUp(self)
        # accept() blocks until the client connects, so release the client
        # thread explicitly before making the blocking call.
        self.serverExplicitReady()
        self.cli_conn = self.serv.accept()[0]

    def tearDown(self):
        connection = self.cli_conn
        connection.close()
        self.cli_conn = None
        ThreadedTCPSocketTest.tearDown(self)

    def clientSetUp(self):
        ThreadedTCPSocketTest.clientSetUp(self)
        client = self.cli
        client.connect((HOST, PORT))
        self.serv_conn = client

    def clientTearDown(self):
        connection = self.serv_conn
        connection.close()
        self.serv_conn = None
        ThreadedTCPSocketTest.clientTearDown(self)

#######################################################################
## Begin Tests

class GeneralModuleTests(unittest.TestCase):

    def testSocketError(self):
195
        # Testing socket module exceptions
196 197 198 199 200 201 202 203 204 205 206 207 208 209
        def raise_error(*args, **kwargs):
            raise socket.error
        def raise_herror(*args, **kwargs):
            raise socket.herror
        def raise_gaierror(*args, **kwargs):
            raise socket.gaierror
        self.failUnlessRaises(socket.error, raise_error,
                              "Error raising socket exception.")
        self.failUnlessRaises(socket.error, raise_herror,
                              "Error raising socket exception.")
        self.failUnlessRaises(socket.error, raise_gaierror,
                              "Error raising socket exception.")

    def testCrucialConstants(self):
210
        # Testing for mission critical constants
211 212 213 214 215 216 217 218 219
        socket.AF_INET
        socket.SOCK_STREAM
        socket.SOCK_DGRAM
        socket.SOCK_RAW
        socket.SOCK_RDM
        socket.SOCK_SEQPACKET
        socket.SOL_SOCKET
        socket.SO_REUSEADDR

220
    def testHostnameRes(self):
221
        # Testing hostname resolution mechanisms
222 223 224 225 226 227 228 229
        hostname = socket.gethostname()
        ip = socket.gethostbyname(hostname)
        self.assert_(ip.find('.') >= 0, "Error resolving host to ip.")
        hname, aliases, ipaddrs = socket.gethostbyaddr(ip)
        all_host_names = [hname] + aliases
        fqhn = socket.getfqdn()
        if not fqhn in all_host_names:
            self.fail("Error testing host resolution mechanisms.")
230

231
    def testRefCountGetNameInfo(self):
232
        # Testing reference count for getnameinfo
233
        import sys
234
        if hasattr(sys, "getrefcount"):
235 236 237 238 239 240 241 242 243
            try:
                # On some versions, this loses a reference
                orig = sys.getrefcount(__name__)
                socket.getnameinfo(__name__,0)
            except SystemError:
                if sys.getrefcount(__name__) <> orig:
                    self.fail("socket.getnameinfo loses a reference")

    def testInterpreterCrash(self):
244
        # Making sure getnameinfo doesn't crash the interpreter
245 246 247 248 249 250
        try:
            # On some versions, this crashes the interpreter.
            socket.getnameinfo(('x', 0, 0, 0), 0)
        except socket.error:
            pass

251
    def testNtoH(self):
252 253 254 255 256 257 258 259 260 261 262 263
        # This just checks that htons etc. are their own inverse,
        # when looking at the lower 16 or 32 bits.
        sizes = {socket.htonl: 32, socket.ntohl: 32,
                 socket.htons: 16, socket.ntohs: 16}
        for func, size in sizes.items():
            mask = (1L<<size) - 1
            for i in (0, 1, 0xffff, ~0xffff, 2, 0x01234567, 0x76543210):
                self.assertEqual(i & mask, func(func(i&mask)) & mask)

            swapped = func(mask)
            self.assertEqual(swapped & mask, mask)
            self.assertRaises(OverflowError, func, 1L<<34)
264

265
    def testGetServByName(self):
266
        # Testing getservbyname()
267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283
        # try a few protocols - not everyone has telnet enabled
        found = 0
        for proto in ("telnet", "ssh", "www", "ftp"):
            try:
                socket.getservbyname(proto, 'tcp')
                found = 1
                break
            except socket.error:
                pass
            try:
                socket.getservbyname(proto, 'udp')
                found = 1
                break
            except socket.error:
                pass
            if not found:
                raise socket.error
284

285
    def testDefaultTimeout(self):
286
        # Testing default timeout
287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312
        # The default timeout should initially be None
        self.assertEqual(socket.getdefaulttimeout(), None)
        s = socket.socket()
        self.assertEqual(s.gettimeout(), None)
        s.close()

        # Set the default timeout to 10, and see if it propagates
        socket.setdefaulttimeout(10)
        self.assertEqual(socket.getdefaulttimeout(), 10)
        s = socket.socket()
        self.assertEqual(s.gettimeout(), 10)
        s.close()

        # Reset the default timeout to None, and see if it propagates
        socket.setdefaulttimeout(None)
        self.assertEqual(socket.getdefaulttimeout(), None)
        s = socket.socket()
        self.assertEqual(s.gettimeout(), None)
        s.close()

        # Check that setting it to an invalid value raises ValueError
        self.assertRaises(ValueError, socket.setdefaulttimeout, -1)

        # Check that setting it to an invalid type raises TypeError
        self.assertRaises(TypeError, socket.setdefaulttimeout, "spam")

313
    # XXX The following don't test module-level functionality...
314

315
    def testSockName(self):
316
        # Testing getsockname()
317
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
318
        sock.bind(("0.0.0.0", PORT+1))
319
        name = sock.getsockname()
320
        self.assertEqual(name, ("0.0.0.0", PORT+1))
321 322

    def testGetSockOpt(self):
323
        # Testing getsockopt()
324 325 326
        # We know a socket should start without reuse==0
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
327
        self.failIf(reuse != 0, "initial mode is reuse")
328 329

    def testSetSockOpt(self):
330
        # Testing setsockopt()
331 332 333
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
334
        self.failIf(reuse == 0, "failed to set reuse mode")
335

336
    def testSendAfterClose(self):
337
        # testing send() after close() with timeout
338 339 340 341 342
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(1)
        sock.close()
        self.assertRaises(socket.error, sock.send, "spam")

343 344 345 346 347 348
class BasicTCPTest(SocketConnectedTest):
    """Send/receive checks over an already connected TCP socket pair.

    Each testX method is the server half and reads from self.cli_conn.
    The matching _testX method runs in the client thread and writes
    through self.serv_conn.
    """

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def testRecv(self):
        # Testing large receive over TCP
        msg = self.cli_conn.recv(1024)
        self.assertEqual(msg, MSG)

    def _testRecv(self):
        self.serv_conn.send(MSG)

    def testOverFlowRecv(self):
        # Testing receive in chunks over TCP
        seg1 = self.cli_conn.recv(len(MSG) - 3)
        seg2 = self.cli_conn.recv(1024)
        msg = seg1 + seg2
        self.assertEqual(msg, MSG)

    def _testOverFlowRecv(self):
        self.serv_conn.send(MSG)

    def testRecvFrom(self):
        # Testing large recvfrom() over TCP
        msg, addr = self.cli_conn.recvfrom(1024)
        self.assertEqual(msg, MSG)

    def _testRecvFrom(self):
        self.serv_conn.send(MSG)

    def testOverFlowRecvFrom(self):
        # Testing recvfrom() in chunks over TCP
        seg1, addr = self.cli_conn.recvfrom(len(MSG)-3)
        seg2, addr = self.cli_conn.recvfrom(1024)
        msg = seg1 + seg2
        self.assertEqual(msg, MSG)

    def _testOverFlowRecvFrom(self):
        self.serv_conn.send(MSG)

    def testSendAll(self):
        # Testing sendall() with a 2048 byte string over TCP
        # Read until EOF, which arrives when clientTearDown() closes the
        # client's connection.
        msg = ''
        while 1:
            read = self.cli_conn.recv(1024)
            if not read:
                break
            msg += read
        self.assertEqual(msg, 'f' * 2048)

    def _testSendAll(self):
        big_chunk = 'f' * 2048
        self.serv_conn.sendall(big_chunk)

    def testFromFd(self):
        # Testing fromfd()
        if not hasattr(socket, "fromfd"):
            return # On Windows, this doesn't exist
        fd = self.cli_conn.fileno()
        sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
        try:
            msg = sock.recv(1024)
        finally:
            # fromfd() returns a separate socket object; close it so it is
            # not leaked (self.cli_conn is closed by tearDown()).
            sock.close()
        self.assertEqual(msg, MSG)

    def _testFromFd(self):
        self.serv_conn.send(MSG)

    def testShutdown(self):
        # Testing shutdown()
        msg = self.cli_conn.recv(1024)
        self.assertEqual(msg, MSG)

    def _testShutdown(self):
        self.serv_conn.send(MSG)
        # how=2: disallow further sends and receives.
        self.serv_conn.shutdown(2)

class BasicUDPTest(ThreadedUDPSocketTest):
    """Datagram exchanges between the client thread and self.serv."""

    def __init__(self, methodName='runTest'):
        ThreadedUDPSocketTest.__init__(self, methodName=methodName)

    def _sendMsg(self):
        # Client-side helper: one datagram carrying MSG to the server.
        self.cli.sendto(MSG, 0, (HOST, PORT))

    def testSendtoAndRecv(self):
        # Testing sendto() and Recv() over UDP
        self.assertEqual(self.serv.recv(len(MSG)), MSG)

    def _testSendtoAndRecv(self):
        self._sendMsg()

    def testRecvFrom(self):
        # Testing recvfrom() over UDP
        payload = self.serv.recvfrom(len(MSG))[0]
        self.assertEqual(payload, MSG)

    def _testRecvFrom(self):
        self._sendMsg()

class NonBlockingTCPTests(ThreadedTCPSocketTest):
    """Checks of non-blocking accept/connect/recv on TCP sockets.

    Every connection the server accepts is closed before the test
    method returns.
    """

    def __init__(self, methodName='runTest'):
        ThreadedTCPSocketTest.__init__(self, methodName=methodName)

    def testSetBlocking(self):
        # Testing whether set blocking works: with no client connecting,
        # a non-blocking accept() must return (by raising) immediately.
        self.serv.setblocking(0)
        start = time.time()
        try:
            self.serv.accept()
        except socket.error:
            pass
        end = time.time()
        self.assert_((end - start) < 1.0, "Error setting non-blocking mode.")

    def _testSetBlocking(self):
        pass

    def testAccept(self):
        # Testing non-blocking accept
        self.serv.setblocking(0)
        try:
            conn, addr = self.serv.accept()
        except socket.error:
            pass
        else:
            conn.close()
            self.fail("Error trying to do non-blocking accept.")
        read, write, err = select.select([self.serv], [], [])
        if self.serv in read:
            conn, addr = self.serv.accept()
            conn.close()
        else:
            self.fail("Error trying to do accept after select.")

    def _testAccept(self):
        # Delay so the server's first non-blocking accept() finds no client.
        time.sleep(0.1)
        self.cli.connect((HOST, PORT))

    def testConnect(self):
        # Testing non-blocking connect
        conn, addr = self.serv.accept()
        conn.close()

    def _testConnect(self):
        self.cli.settimeout(10)
        self.cli.connect((HOST, PORT))

    def testRecv(self):
        # Testing non-blocking recv
        conn, addr = self.serv.accept()
        try:
            conn.setblocking(0)
            try:
                msg = conn.recv(len(MSG))
            except socket.error:
                pass
            else:
                self.fail("Error trying to do non-blocking recv.")
            read, write, err = select.select([conn], [], [])
            if conn in read:
                msg = conn.recv(len(MSG))
                self.assertEqual(msg, MSG)
            else:
                self.fail("Error during select call to non-blocking socket.")
        finally:
            conn.close()

    def _testRecv(self):
        self.cli.connect((HOST, PORT))
        # Delay so the server's first non-blocking recv() finds no data.
        time.sleep(0.1)
        self.cli.send(MSG)

class FileObjectClassTestCase(SocketConnectedTest):
    """Exercise socket.makefile() file objects over a connected TCP pair.

    The server reads through self.serv_file, which is created with
    self.bufsize.  The client writes through self.cli_file.  Subclasses
    override bufsize to repeat these tests with other buffering modes.
    """

    bufsize = -1 # Use default buffer size

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def setUp(self):
        SocketConnectedTest.setUp(self)
        self.serv_file = self.cli_conn.makefile('rb', self.bufsize)

    def tearDown(self):
        reader = self.serv_file
        reader.close()
        self.serv_file = None
        SocketConnectedTest.tearDown(self)

    def clientSetUp(self):
        SocketConnectedTest.clientSetUp(self)
        self.cli_file = self.serv_conn.makefile('wb')

    def clientTearDown(self):
        writer = self.cli_file
        writer.close()
        self.cli_file = None
        SocketConnectedTest.clientTearDown(self)

    def _writeAndFlush(self, data):
        # Client-side helper: write data and push it out immediately.
        self.cli_file.write(data)
        self.cli_file.flush()

    def testSmallRead(self):
        # Performing small file read test
        head = self.serv_file.read(len(MSG) - 3)
        tail = self.serv_file.read(3)
        self.assertEqual(head + tail, MSG)

    def _testSmallRead(self):
        self._writeAndFlush(MSG)

    def testFullRead(self):
        # read until EOF
        self.assertEqual(self.serv_file.read(), MSG)

    def _testFullRead(self):
        # The server reads to EOF, so close the writer instead of flushing.
        self.cli_file.write(MSG)
        self.cli_file.close()

    def testUnbufferedRead(self):
        # Performing unbuffered file read test, one byte at a time to EOF
        pieces = []
        char = self.serv_file.read(1)
        while char:
            pieces.append(char)
            char = self.serv_file.read(1)
        self.assertEqual(''.join(pieces), MSG)

    def _testUnbufferedRead(self):
        self._writeAndFlush(MSG)

    def testReadline(self):
        # Performing file readline test
        self.assertEqual(self.serv_file.readline(), MSG)

    def _testReadline(self):
        self._writeAndFlush(MSG)

576 577 578
class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):

    """Repeat the tests from FileObjectClassTestCase with bufsize==0.

    In this case (and in this case only), it should be possible to
    create a file object, read a line from it, create another file
    object, read another line from it, without loss of data in the
    first file object's buffer.  Note that httplib relies on this
    when reading multiple requests from the same socket."""

    bufsize = 0 # Use unbuffered mode

    def testUnbufferedReadline(self):
        # Read a line, create a new file object, read another line with it
        first_file = self.serv_file
        try:
            line = first_file.readline() # first line
            self.assertEqual(line, "A. " + MSG) # first line
            self.serv_file = self.cli_conn.makefile('rb', 0)
            line = self.serv_file.readline() # second line
            self.assertEqual(line, "B. " + MSG) # second line
        finally:
            # tearDown() only closes whatever self.serv_file refers to, so
            # the replaced first file object must be closed here.
            if self.serv_file is not first_file:
                first_file.close()

    def _testUnbufferedReadline(self):
        self.cli_file.write("A. " + MSG)
        self.cli_file.write("B. " + MSG)
        self.cli_file.flush()

601 602 603 604 605 606 607 608
class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase):
    """Repeat the FileObjectClassTestCase tests with bufsize == 1."""

    # As a makefile() buffer size, 1 means default buffering when reading
    # and line buffering when writing.
    bufsize = 1


class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase):
    """Repeat the FileObjectClassTestCase tests with a tiny buffer."""

    # A deliberately small buffer size to exercise the buffering code.
    bufsize = 2
609

610
def test_main():
    """Build one suite from every test case class here and run it."""
    # The UDP tests are left out on the 'mac' platform.
    case_classes = [GeneralModuleTests, BasicTCPTest]
    if sys.platform != 'mac':
        case_classes.append(BasicUDPTest)
    case_classes.extend([
        NonBlockingTCPTests,
        FileObjectClassTestCase,
        UnbufferedFileObjectClassTestCase,
        LineBufferedFileObjectClassTestCase,
        SmallBufferedFileObjectClassTestCase,
    ])
    suite = unittest.TestSuite()
    for case_class in case_classes:
        suite.addTest(unittest.makeSuite(case_class))
    test_support.run_suite(suite)

if __name__ == "__main__":
    # Allow running this file directly as a script.
    test_main()