pickletester.py 28.9 KB
Newer Older
1
import unittest
2
import pickle
3
import cPickle
4
import pickletools
5
import copy_reg
6

7
from test.test_support import TestFailed, have_unicode, TESTFN
8

9 10
# Tests that try a number of pickle protocols should have a
#     for proto in protocols:
11 12 13
# kind of outer loop.
assert pickle.HIGHEST_PROTOCOL == cPickle.HIGHEST_PROTOCOL == 2
protocols = range(pickle.HIGHEST_PROTOCOL + 1)
14

15 16 17 18 19 20 21 22

# Return True if opcode code appears in the pickle, else False.
def opcode_in_pickle(code, pickle):
    """Return True if opcode `code` occurs anywhere in pickle data `pickle`.

    `code` is a one-character opcode such as pickle.STOP.  Note that the
    parameter name `pickle` shadows the pickle module inside this function;
    it is kept as-is for compatibility with existing callers.
    """
    for op, _arg, _pos in pickletools.genops(pickle):
        if op.code == code:
            # Stop scanning at the first match.
            return True
    return False

# Return the number of times opcode code appears in pickle.
def count_opcode(code, pickle):
    """Return how many times opcode `code` occurs in pickle data `pickle`.

    Used by the chunking tests to count APPENDS / SETITEMS opcodes.  The
    parameter name `pickle` shadows the pickle module inside this function;
    it is kept for compatibility with existing callers.
    """
    n = 0
    for op, _arg, _pos in pickletools.genops(pickle):
        if op.code == code:
            n += 1
    return n

# We can't very well test the extension registry without putting known stuff
# in it, but we have to be careful to restore its original state.  Code
# should do this:
#
#     e = ExtensionSaver(extension_code)
#     try:
#         fiddle w/ the extension registry's stuff for extension_code
#     finally:
#         e.restore()

class ExtensionSaver:
    """Temporarily clear one copy_reg extension code, and restore it later."""

    # Remember current registration for code (if any), and remove it (if
    # there is one).
    def __init__(self, code):
        self.code = code
        # The (module, name) pair registered under code, or None if none was.
        self.pair = copy_reg._inverted_registry.get(code)
        if self.pair is not None:
            copy_reg.remove_extension(self.pair[0], self.pair[1], code)

    # Restore previous registration for code.
    def restore(self):
        code = self.code
        # First drop whatever the test registered under code, if anything.
        curpair = copy_reg._inverted_registry.get(code)
        if curpair is not None:
            copy_reg.remove_extension(curpair[0], curpair[1], code)
        # Then re-register the original pair, if there was one.
        pair = self.pair
        if pair is not None:
            copy_reg.add_extension(pair[0], pair[1], code)

class C:
    """Classic class whose instances compare by their attribute dicts."""

    def __cmp__(self, other):
        mine, theirs = self.__dict__, other.__dict__
        return cmp(mine, theirs)

# Publish C as __main__.C: the canned pickles refer to it as '__main__ C'.
import __main__
C.__module__ = "__main__"
setattr(__main__, "C", C)

class myint(int):
    """int subclass that also stores str() of its constructor argument."""

    def __init__(self, x):
        text = str(x)
        self.str = text

class initarg(C):
    """C subclass that supplies its constructor arguments via __getinitargs__."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __getinitargs__(self):
        # Presumably consulted by pickle (Python 2 classic-instance protocol)
        # so that __init__(a, b) is re-run on load -- confirm per pickle docs.
        return self.a, self.b

class metaclass(type):
    """Trivial metaclass; exists so use_metaclass has a non-default one."""

class use_metaclass(object):
    """New-style class whose (Python 2) metaclass is `metaclass`."""
    __metaclass__ = metaclass

# DATA0 .. DATA2 are the pickles we expect under the various protocols, for
# the object returned by create_data().

# Protocol 0 (text) pickle.  Byte offsets must agree with DATA0_DIS.
# break into multiple strings to avoid confusing font-lock-mode
DATA0 = """(lp1
I0
aL1L
aF2
ac__builtin__
complex
p2
""" + \
"""(F3
F0
tRp3
aI1
aI-1
aI255
aI-255
aI-256
aI65535
aI-65535
aI-65536
aI2147483647
aI-2147483647
aI-2147483648
a""" + \
"""(S'abc'
p4
g4
""" + \
"""(i__main__
C
p5
""" + \
"""(dp6
S'foo'
p7
I1
sS'bar'
p8
I2
sbg5
tp9
ag9
aI5
a.
"""

# Disassembly of DATA0.
# Expected pickletools.dis() output for DATA0: byte offset, opcode char,
# opcode name, argument.  Compared against real output by
# dont_test_disassembly (currently disabled).
DATA0_DIS = """\
    0: (    MARK
    1: l        LIST       (MARK at 0)
    2: p    PUT        1
    5: I    INT        0
    8: a    APPEND
    9: L    LONG       1L
   13: a    APPEND
   14: F    FLOAT      2.0
   17: a    APPEND
   18: c    GLOBAL     '__builtin__ complex'
   39: p    PUT        2
   42: (    MARK
   43: F        FLOAT      3.0
   46: F        FLOAT      0.0
   49: t        TUPLE      (MARK at 42)
   50: R    REDUCE
   51: p    PUT        3
   54: a    APPEND
   55: I    INT        1
   58: a    APPEND
   59: I    INT        -1
   63: a    APPEND
   64: I    INT        255
   69: a    APPEND
   70: I    INT        -255
   76: a    APPEND
   77: I    INT        -256
   83: a    APPEND
   84: I    INT        65535
   91: a    APPEND
   92: I    INT        -65535
  100: a    APPEND
  101: I    INT        -65536
  109: a    APPEND
  110: I    INT        2147483647
  122: a    APPEND
  123: I    INT        -2147483647
  136: a    APPEND
  137: I    INT        -2147483648
  150: a    APPEND
  151: (    MARK
  152: S        STRING     'abc'
  159: p        PUT        4
  162: g        GET        4
  165: (        MARK
  166: i            INST       '__main__ C' (MARK at 165)
  178: p        PUT        5
  181: (        MARK
  182: d            DICT       (MARK at 181)
  183: p        PUT        6
  186: S        STRING     'foo'
  193: p        PUT        7
  196: I        INT        1
  199: s        SETITEM
  200: S        STRING     'bar'
  207: p        PUT        8
  210: I        INT        2
  213: s        SETITEM
  214: b        BUILD
  215: g        GET        5
  218: t        TUPLE      (MARK at 151)
  219: p    PUT        9
  222: a    APPEND
  223: g    GET        9
  226: a    APPEND
  227: I    INT        5
  230: a    APPEND
  231: .    STOP
highest protocol among opcodes = 0
"""

# Protocol 1 (binary) pickle of create_data(); DATA1_DIS is its disassembly.
DATA1 = (']q\x01(K\x00L1L\nG@\x00\x00\x00\x00\x00\x00\x00'
         'c__builtin__\ncomplex\nq\x02(G@\x08\x00\x00\x00\x00\x00'
         '\x00G\x00\x00\x00\x00\x00\x00\x00\x00tRq\x03K\x01J\xff\xff'
         '\xff\xffK\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xff'
         'J\x01\x00\xff\xffJ\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00'
         '\x00\x80J\x00\x00\x00\x80(U\x03abcq\x04h\x04(c__main__\n'
         'C\nq\x05oq\x06}q\x07(U\x03fooq\x08K\x01U\x03barq\tK\x02ubh'
         '\x06tq\nh\nK\x05e.'
        )

# Disassembly of DATA1.
# Expected pickletools.dis() output for DATA1.  Compared against real output
# by dont_test_disassembly (currently disabled).
DATA1_DIS = """\
    0: ]    EMPTY_LIST
    1: q    BINPUT     1
    3: (    MARK
    4: K        BININT1    0
    6: L        LONG       1L
   10: G        BINFLOAT   2.0
   19: c        GLOBAL     '__builtin__ complex'
   40: q        BINPUT     2
   42: (        MARK
   43: G            BINFLOAT   3.0
   52: G            BINFLOAT   0.0
   61: t            TUPLE      (MARK at 42)
   62: R        REDUCE
   63: q        BINPUT     3
   65: K        BININT1    1
   67: J        BININT     -1
   72: K        BININT1    255
   74: J        BININT     -255
   79: J        BININT     -256
   84: M        BININT2    65535
   87: J        BININT     -65535
   92: J        BININT     -65536
   97: J        BININT     2147483647
  102: J        BININT     -2147483647
  107: J        BININT     -2147483648
  112: (        MARK
  113: U            SHORT_BINSTRING 'abc'
  118: q            BINPUT     4
  120: h            BINGET     4
  122: (            MARK
  123: c                GLOBAL     '__main__ C'
  135: q                BINPUT     5
  137: o                OBJ        (MARK at 122)
  138: q            BINPUT     6
  140: }            EMPTY_DICT
  141: q            BINPUT     7
  143: (            MARK
  144: U                SHORT_BINSTRING 'foo'
  149: q                BINPUT     8
  151: K                BININT1    1
  153: U                SHORT_BINSTRING 'bar'
  158: q                BINPUT     9
  160: K                BININT1    2
  162: u                SETITEMS   (MARK at 143)
  163: b            BUILD
  164: h            BINGET     6
  166: t            TUPLE      (MARK at 112)
  167: q        BINPUT     10
  169: h        BINGET     10
  171: K        BININT1    5
  173: e        APPENDS    (MARK at 3)
  174: .    STOP
highest protocol among opcodes = 1
"""

# Protocol 2 pickle of create_data() (starts with PROTO 2); DATA2_DIS is its
# disassembly.
DATA2 = ('\x80\x02]q\x01(K\x00\x8a\x01\x01G@\x00\x00\x00\x00\x00\x00\x00'
         'c__builtin__\ncomplex\nq\x02G@\x08\x00\x00\x00\x00\x00\x00G\x00'
         '\x00\x00\x00\x00\x00\x00\x00\x86Rq\x03K\x01J\xff\xff\xff\xffK'
         '\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xffJ\x01\x00\xff\xff'
         'J\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00\x00\x80J\x00\x00\x00'
         '\x80(U\x03abcq\x04h\x04(c__main__\nC\nq\x05oq\x06}q\x07(U\x03foo'
         'q\x08K\x01U\x03barq\tK\x02ubh\x06tq\nh\nK\x05e.')

# Disassembly of DATA2.
# Expected pickletools.dis() output for DATA2.  Note: dont_test_disassembly
# only compares protocols 0 and 1, so this string is not checked there.
DATA2_DIS = """\
    0: \x80 PROTO      2
    2: ]    EMPTY_LIST
    3: q    BINPUT     1
    5: (    MARK
    6: K        BININT1    0
    8: \x8a     LONG1      1L
   11: G        BINFLOAT   2.0
   20: c        GLOBAL     '__builtin__ complex'
   41: q        BINPUT     2
   43: G        BINFLOAT   3.0
   52: G        BINFLOAT   0.0
   61: \x86     TUPLE2
   62: R        REDUCE
   63: q        BINPUT     3
   65: K        BININT1    1
   67: J        BININT     -1
   72: K        BININT1    255
   74: J        BININT     -255
   79: J        BININT     -256
   84: M        BININT2    65535
   87: J        BININT     -65535
   92: J        BININT     -65536
   97: J        BININT     2147483647
  102: J        BININT     -2147483647
  107: J        BININT     -2147483648
  112: (        MARK
  113: U            SHORT_BINSTRING 'abc'
  118: q            BINPUT     4
  120: h            BINGET     4
  122: (            MARK
  123: c                GLOBAL     '__main__ C'
  135: q                BINPUT     5
  137: o                OBJ        (MARK at 122)
  138: q            BINPUT     6
  140: }            EMPTY_DICT
  141: q            BINPUT     7
  143: (            MARK
  144: U                SHORT_BINSTRING 'foo'
  149: q                BINPUT     8
  151: K                BININT1    1
  153: U                SHORT_BINSTRING 'bar'
  158: q                BINPUT     9
  160: K                BININT1    2
  162: u                SETITEMS   (MARK at 143)
  163: b            BUILD
  164: h            BINGET     6
  166: t            TUPLE      (MARK at 112)
  167: q        BINPUT     10
  169: h        BINGET     10
  171: K        BININT1    5
  173: e        APPENDS    (MARK at 5)
  174: .    STOP
highest protocol among opcodes = 2
"""

343
def create_data():
344 345 346 347
    c = C()
    c.foo = 1
    c.bar = 2
    x = [0, 1L, 2.0, 3.0+0j]
348 349 350 351 352 353 354 355 356
    # Append some integer test cases at cPickle.c's internal size
    # cutoffs.
    uint1max = 0xff
    uint2max = 0xffff
    int4max = 0x7fffffff
    x.extend([1, -1,
              uint1max, -uint1max, -uint1max-1,
              uint2max, -uint2max, -uint2max-1,
               int4max,  -int4max,  -int4max-1])
357 358 359 360
    y = ('abc', 'abc', c, c)
    x.append(y)
    x.append(y)
    x.append(5)
361 362 363
    return x

class AbstractPickleTests(unittest.TestCase):
364
    # Subclass must define self.dumps, self.loads, self.error.
365 366 367 368

    _testdata = create_data()

    def setUp(self):
369
        pass
370 371 372

    def test_misc(self):
        # test various datatypes not tested by testdata
373 374 375 376 377
        for proto in protocols:
            x = myint(4)
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)
378

379 380 381 382
            x = (1, ())
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)
383

384 385 386 387
            x = initarg(1, x)
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)
388 389 390

        # XXX test __reduce__ protocol?

391 392 393 394 395 396 397 398 399
    def test_roundtrip_equality(self):
        expected = self._testdata
        for proto in protocols:
            s = self.dumps(expected, proto)
            got = self.loads(s)
            self.assertEqual(expected, got)

    def test_load_from_canned_string(self):
        expected = self._testdata
400
        for canned in DATA0, DATA1, DATA2:
401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419
            got = self.loads(canned)
            self.assertEqual(expected, got)

    # There are gratuitous differences between pickles produced by
    # pickle and cPickle, largely because cPickle starts PUT indices at
    # 1 and pickle starts them at 0.  See XXX comment in cPickle's put2() --
    # there's a comment with an exclamation point there whose meaning
    # is a mystery.  cPickle also suppresses PUT for objects with a refcount
    # of 1.
    def dont_test_disassembly(self):
        from cStringIO import StringIO
        from pickletools import dis

        for proto, expected in (0, DATA0_DIS), (1, DATA1_DIS):
            s = self.dumps(self._testdata, proto)
            filelike = StringIO()
            dis(s, out=filelike)
            got = filelike.getvalue()
            self.assertEqual(expected, got)
420 421 422 423

    def test_recursive_list(self):
        l = []
        l.append(l)
424 425 426 427 428 429
        for proto in protocols:
            s = self.dumps(l, proto)
            x = self.loads(s)
            self.assertEqual(x, l)
            self.assertEqual(x, x[0])
            self.assertEqual(id(x), id(x[0]))
430 431 432 433

    def test_recursive_dict(self):
        d = {}
        d[1] = d
434 435 436 437 438 439
        for proto in protocols:
            s = self.dumps(d, proto)
            x = self.loads(s)
            self.assertEqual(x, d)
            self.assertEqual(x[1], x)
            self.assertEqual(id(x[1]), id(x))
440 441 442 443

    def test_recursive_inst(self):
        i = C()
        i.attr = i
444 445 446 447 448 449
        for proto in protocols:
            s = self.dumps(i, 2)
            x = self.loads(s)
            self.assertEqual(x, i)
            self.assertEqual(x.attr, x)
            self.assertEqual(id(x.attr), id(x))
450 451 452 453 454 455 456

    def test_recursive_multi(self):
        l = []
        d = {1:l}
        i = C()
        i.attr = d
        l.append(i)
457 458 459 460 461 462 463 464 465
        for proto in protocols:
            s = self.dumps(l, proto)
            x = self.loads(s)
            self.assertEqual(x, l)
            self.assertEqual(x[0], i)
            self.assertEqual(x[0].attr, d)
            self.assertEqual(x[0].attr[1], x)
            self.assertEqual(x[0].attr[1][0], i)
            self.assertEqual(x[0].attr[1][0].attr, d)
466 467 468 469 470 471

    def test_garyp(self):
        self.assertRaises(self.error, self.loads, 'garyp')

    def test_insecure_strings(self):
        insecure = ["abc", "2 + 2", # not quoted
472
                    #"'abc' + 'def'", # not a single quoted string
473 474 475
                    "'abc", # quote is not closed
                    "'abc\"", # open quote and close quote don't match
                    "'abc'   ?", # junk after close quote
476
                    "'\\'", # trailing backslash
477
                    # some tests of the quoting rules
478 479
                    #"'abc\"\''",
                    #"'\\\\a\'\'\'\\\'\\\\\''",
480 481 482 483 484
                    ]
        for s in insecure:
            buf = "S" + s + "\012p0\012."
            self.assertRaises(ValueError, self.loads, buf)

485
    if have_unicode:
486 487 488
        def test_unicode(self):
            endcases = [unicode(''), unicode('<\\u>'), unicode('<\\\u1234>'),
                        unicode('<\n>'),  unicode('<\\>')]
489 490 491 492 493
            for proto in protocols:
                for u in endcases:
                    p = self.dumps(u, proto)
                    u2 = self.loads(p)
                    self.assertEqual(u2, u)
494 495 496

    def test_ints(self):
        import sys
497 498 499 500 501 502 503 504
        for proto in protocols:
            n = sys.maxint
            while n:
                for expected in (-n, n):
                    s = self.dumps(expected, proto)
                    n2 = self.loads(s)
                    self.assertEqual(expected, n2)
                n = n >> 1
505 506 507 508 509 510 511 512 513 514 515

    def test_maxint64(self):
        maxint64 = (1L << 63) - 1
        data = 'I' + str(maxint64) + '\n.'
        got = self.loads(data)
        self.assertEqual(got, maxint64)

        # Try too with a bogus literal.
        data = 'I' + str(maxint64) + 'JUNK\n.'
        self.assertRaises(ValueError, self.loads, data)

516 517
    def test_long(self):
        for proto in protocols:
518
            # 256 bytes is where LONG4 begins.
519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534
            for nbits in 1, 8, 8*254, 8*255, 8*256, 8*257:
                nbase = 1L << nbits
                for npos in nbase-1, nbase, nbase+1:
                    for n in npos, -npos:
                        pickle = self.dumps(n, proto)
                        got = self.loads(pickle)
                        self.assertEqual(n, got)
        # Try a monster.  This is quadratic-time in protos 0 & 1, so don't
        # bother with those.
        nbase = long("deadbeeffeedface", 16)
        nbase += nbase << 1000000
        for n in nbase, -nbase:
            p = self.dumps(n, 2)
            got = self.loads(p)
            self.assertEqual(n, got)

535
    def test_reduce(self):
536
        pass
537 538 539 540

    def test_getinitargs(self):
        pass

541 542
    def test_metaclass(self):
        a = use_metaclass()
543 544 545 546
        for proto in protocols:
            s = self.dumps(a, proto)
            b = self.loads(s)
            self.assertEqual(a.__class__, b.__class__)
547

Michael W. Hudson's avatar
Michael W. Hudson committed
548 549
    def test_structseq(self):
        import time
550
        import os
551 552 553 554

        t = time.localtime()
        for proto in protocols:
            s = self.dumps(t, proto)
555 556
            u = self.loads(s)
            self.assertEqual(t, u)
557 558 559 560 561 562 563 564 565 566
            if hasattr(os, "stat"):
                t = os.stat(os.curdir)
                s = self.dumps(t, proto)
                u = self.loads(s)
                self.assertEqual(t, u)
            if hasattr(os, "statvfs"):
                t = os.statvfs(os.curdir)
                s = self.dumps(t, proto)
                u = self.loads(s)
                self.assertEqual(t, u)
Michael W. Hudson's avatar
Michael W. Hudson committed
567

568 569
    # Tests for protocol 2

570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588
    def test_proto(self):
        build_none = pickle.NONE + pickle.STOP
        for proto in protocols:
            expected = build_none
            if proto >= 2:
                expected = pickle.PROTO + chr(proto) + expected
            p = self.dumps(None, proto)
            self.assertEqual(p, expected)

        oob = protocols[-1] + 1     # a future protocol
        badpickle = pickle.PROTO + chr(oob) + build_none
        try:
            self.loads(badpickle)
        except ValueError, detail:
            self.failUnless(str(detail).startswith(
                                            "unsupported pickle protocol"))
        else:
            self.fail("expected bad protocol number to raise ValueError")

589 590
    def test_long1(self):
        x = 12345678910111213141516178920L
591 592 593 594
        for proto in protocols:
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)
595
            self.assertEqual(opcode_in_pickle(pickle.LONG1, s), proto >= 2)
596 597 598

    def test_long4(self):
        x = 12345678910111213141516178920L << (256*8)
599 600 601 602
        for proto in protocols:
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)
603
            self.assertEqual(opcode_in_pickle(pickle.LONG4, s), proto >= 2)
604

605
    def test_short_tuples(self):
606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624
        # Map (proto, len(tuple)) to expected opcode.
        expected_opcode = {(0, 0): pickle.TUPLE,
                           (0, 1): pickle.TUPLE,
                           (0, 2): pickle.TUPLE,
                           (0, 3): pickle.TUPLE,
                           (0, 4): pickle.TUPLE,

                           (1, 0): pickle.EMPTY_TUPLE,
                           (1, 1): pickle.TUPLE,
                           (1, 2): pickle.TUPLE,
                           (1, 3): pickle.TUPLE,
                           (1, 4): pickle.TUPLE,

                           (2, 0): pickle.EMPTY_TUPLE,
                           (2, 1): pickle.TUPLE1,
                           (2, 2): pickle.TUPLE2,
                           (2, 3): pickle.TUPLE3,
                           (2, 4): pickle.TUPLE,
                          }
625
        a = ()
626 627 628 629
        b = (1,)
        c = (1, 2)
        d = (1, 2, 3)
        e = (1, 2, 3, 4)
630
        for proto in protocols:
631 632 633 634
            for x in a, b, c, d, e:
                s = self.dumps(x, proto)
                y = self.loads(s)
                self.assertEqual(x, y, (proto, x, s, y))
635
                expected = expected_opcode[proto, len(x)]
636
                self.assertEqual(opcode_in_pickle(expected, s), True)
637

638
    def test_singletons(self):
639 640 641 642 643 644 645 646 647 648 649 650 651
        # Map (proto, singleton) to expected opcode.
        expected_opcode = {(0, None): pickle.NONE,
                           (1, None): pickle.NONE,
                           (2, None): pickle.NONE,

                           (0, True): pickle.INT,
                           (1, True): pickle.INT,
                           (2, True): pickle.NEWTRUE,

                           (0, False): pickle.INT,
                           (1, False): pickle.INT,
                           (2, False): pickle.NEWFALSE,
                          }
652
        for proto in protocols:
653 654 655 656
            for x in None, False, True:
                s = self.dumps(x, proto)
                y = self.loads(s)
                self.assert_(x is y, (proto, x, s, y))
657
                expected = expected_opcode[proto, x]
658
                self.assertEqual(opcode_in_pickle(expected, s), True)
659

660
    def test_newobj_tuple(self):
661 662 663
        x = MyTuple([1, 2, 3])
        x.foo = 42
        x.bar = "hello"
664 665 666 667 668
        for proto in protocols:
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(tuple(x), tuple(y))
            self.assertEqual(x.__dict__, y.__dict__)
669 670

    def test_newobj_list(self):
671 672 673
        x = MyList([1, 2, 3])
        x.foo = 42
        x.bar = "hello"
674 675 676 677 678
        for proto in protocols:
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(list(x), list(y))
            self.assertEqual(x.__dict__, y.__dict__)
679

680
    def test_newobj_generic(self):
681
        for proto in protocols:
682 683 684 685 686 687 688 689 690 691
            for C in myclasses:
                B = C.__base__
                x = C(C.sample)
                x.foo = 42
                s = self.dumps(x, proto)
                y = self.loads(s)
                detail = (proto, C, B, x, y, type(y))
                self.assertEqual(B(x), B(y), detail)
                self.assertEqual(x.__dict__, y.__dict__, detail)

692 693 694
    # Register a type with copy_reg, with extension code extcode.  Pickle
    # an object of that type.  Check that the resulting pickle uses opcode
    # (EXT[124]) under proto 2, and not in proto 1.
695

696
    def produce_global_ext(self, extcode, opcode):
697
        e = ExtensionSaver(extcode)
698
        try:
699
            copy_reg.add_extension(__name__, "MyList", extcode)
700 701 702 703
            x = MyList([1, 2, 3])
            x.foo = 42
            x.bar = "hello"

704
            # Dump using protocol 1 for comparison.
705
            s1 = self.dumps(x, 1)
706 707 708 709
            self.assert_(__name__ in s1)
            self.assert_("MyList" in s1)
            self.assertEqual(opcode_in_pickle(opcode, s1), False)

710 711 712 713
            y = self.loads(s1)
            self.assertEqual(list(x), list(y))
            self.assertEqual(x.__dict__, y.__dict__)

714
            # Dump using protocol 2 for test.
715
            s2 = self.dumps(x, 2)
716 717 718 719
            self.assert_(__name__ not in s2)
            self.assert_("MyList" not in s2)
            self.assertEqual(opcode_in_pickle(opcode, s2), True)

720 721 722 723 724
            y = self.loads(s2)
            self.assertEqual(list(x), list(y))
            self.assertEqual(x.__dict__, y.__dict__)

        finally:
725
            e.restore()
726 727

    def test_global_ext1(self):
728 729
        self.produce_global_ext(0x00000001, pickle.EXT1)  # smallest EXT1 code
        self.produce_global_ext(0x000000ff, pickle.EXT1)  # largest EXT1 code
730 731

    def test_global_ext2(self):
732 733 734
        self.produce_global_ext(0x00000100, pickle.EXT2)  # smallest EXT2 code
        self.produce_global_ext(0x0000ffff, pickle.EXT2)  # largest EXT2 code
        self.produce_global_ext(0x0000abcd, pickle.EXT2)  # check endianness
735 736

    def test_global_ext4(self):
737 738 739 740
        self.produce_global_ext(0x00010000, pickle.EXT4)  # smallest EXT4 code
        self.produce_global_ext(0x7fffffff, pickle.EXT4)  # largest EXT4 code
        self.produce_global_ext(0x12abcdef, pickle.EXT4)  # check endianness

741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783
    def test_list_chunking(self):
        n = 10  # too small to chunk
        x = range(n)
        for proto in protocols:
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)
            num_appends = count_opcode(pickle.APPENDS, s)
            self.assertEqual(num_appends, proto > 0)

        n = 2500  # expect at least two chunks when proto > 0
        x = range(n)
        for proto in protocols:
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)
            num_appends = count_opcode(pickle.APPENDS, s)
            if proto == 0:
                self.assertEqual(num_appends, 0)
            else:
                self.failUnless(num_appends >= 2)

    def test_dict_chunking(self):
        n = 10  # too small to chunk
        x = dict.fromkeys(range(n))
        for proto in protocols:
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)
            num_setitems = count_opcode(pickle.SETITEMS, s)
            self.assertEqual(num_setitems, proto > 0)

        n = 2500  # expect at least two chunks when proto > 0
        x = dict.fromkeys(range(n))
        for proto in protocols:
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)
            num_setitems = count_opcode(pickle.SETITEMS, s)
            if proto == 0:
                self.assertEqual(num_setitems, 0)
            else:
                self.failUnless(num_setitems >= 2)
784

785 786 787 788 789 790 791 792 793 794
    def test_simple_newobj(self):
        x = object.__new__(SimpleNewObj)  # avoid __init__
        x.abc = 666
        for proto in protocols:
            s = self.dumps(x, proto)
            self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), proto >= 2)
            y = self.loads(s)   # will raise TypeError if __init__ called
            self.assertEqual(y.abc, 666)
            self.assertEqual(x.__dict__, y.__dict__)

795 796 797 798 799 800 801 802 803 804 805
    def test_newobj_list_slots(self):
        x = SlotList([1, 2, 3])
        x.foo = 42
        x.bar = "hello"
        s = self.dumps(x, 2)
        y = self.loads(s)
        self.assertEqual(list(x), list(y))
        self.assertEqual(x.__dict__, y.__dict__)
        self.assertEqual(x.foo, y.foo)
        self.assertEqual(x.bar, y.bar)

806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857
    def test_reduce_overrides_default_reduce_ex(self):
        for proto in 0, 1, 2:
            x = REX_one()
            self.assertEqual(x._reduce_called, 0)
            s = self.dumps(x, proto)
            self.assertEqual(x._reduce_called, 1)
            y = self.loads(s)
            self.assertEqual(y._reduce_called, 0)

    def test_reduce_ex_called(self):
        for proto in 0, 1, 2:
            x = REX_two()
            self.assertEqual(x._proto, None)
            s = self.dumps(x, proto)
            self.assertEqual(x._proto, proto)
            y = self.loads(s)
            self.assertEqual(y._proto, None)

    def test_reduce_ex_overrides_reduce(self):
        for proto in 0, 1, 2:
            x = REX_three()
            self.assertEqual(x._proto, None)
            s = self.dumps(x, proto)
            self.assertEqual(x._proto, proto)
            y = self.loads(s)
            self.assertEqual(y._proto, None)

# Test classes for reduce_ex

class REX_one(object):
    """Defines only __reduce__; the __reduce_ex__ inherited from object
    is expected to defer to it."""
    _reduce_called = 0

    def __reduce__(self):
        # Flag the call on the instance so the test can observe it.
        self._reduce_called = 1
        return (REX_one, ())
    # No __reduce_ex__ here, but inheriting it from object

class REX_two(object):
    """Defines only __reduce_ex__ and records which protocol it was given."""
    _proto = None

    def __reduce_ex__(self, proto):
        self._proto = proto
        return (REX_two, ())
    # No __reduce__ here, but inheriting it from object

class REX_three(object):
    """Defines both hooks; pickling must use __reduce_ex__, never __reduce__."""
    _proto = None

    def __reduce_ex__(self, proto):
        self._proto = proto
        return (REX_two, ())

    def __reduce__(self):
        raise TestFailed("This __reduce__ shouldn't be called")

# Test classes for newobj

class MyInt(int):
    """int subclass; `sample` is the value test instances are built from."""
    sample = 1

class MyLong(long):
    sample = 1L

class MyFloat(float):
    """float subclass; `sample` is the value test instances are built from."""
    sample = float(1)

class MyComplex(complex):
    """complex subclass; `sample` is the value test instances are built from."""
    sample = complex(1.0, 0.0)

class MyStr(str):
    """str subclass; `sample` is the value test instances are built from."""
    sample = 'hello'

class MyUnicode(unicode):
    sample = u"hello \u1234"

877
class MyTuple(tuple):
    """tuple subclass; `sample` is the value test instances are built from."""
    sample = (1, 2, 3)

class MyList(list):
    """list subclass; `sample` is the value test instances are built from."""
    sample = [1, 2, 3]

class MyDict(dict):
    """dict subclass; `sample` is the value test instances are built from."""
    sample = dict(a=1, b=2)

# Builtin-type subclasses iterated over by test_newobj_generic.
myclasses = [MyInt, MyLong, MyFloat,
             MyComplex,
             MyStr, MyUnicode,
             MyTuple, MyList, MyDict]

class SlotList(MyList):
    """MyList subclass that also declares a `foo` slot."""
    __slots__ = ["foo"]

class SimpleNewObj(object):
    """Class whose __init__ always fails, to prove unpickling skips it."""

    def __init__(self, a, b, c):
        # raise an error, to make sure this isn't called
        raise TypeError("SimpleNewObj.__init__() didn't expect to get called")

class AbstractPickleModuleTests(unittest.TestCase):
    """Tests of the module-level dump()/load() API.

    Subclasses must set self.module (presumably pickle or cPickle).
    """

    def test_dump_closed_file(self):
        import os
        f = open(TESTFN, "w")
        try:
            f.close()
            self.assertRaises(ValueError, self.module.dump, 123, f)
        finally:
            os.remove(TESTFN)

    def test_load_closed_file(self):
        import os
        f = open(TESTFN, "w")
        try:
            f.close()
            # Must exercise load(): calling dump() here would only duplicate
            # test_dump_closed_file.
            self.assertRaises(ValueError, self.module.load, f)
        finally:
            os.remove(TESTFN)

    def test_highest_protocol(self):
        # Of course this needs to be changed when HIGHEST_PROTOCOL changes.
        self.assertEqual(self.module.HIGHEST_PROTOCOL, 2)


class AbstractPersistentPicklerTests(unittest.TestCase):

    # This class defines persistent_id() and persistent_load()
    # functions that should be used by the pickler.  All even integers
    # are pickled using persistent ids.

    def persistent_id(self, object):
        if isinstance(object, int) and object % 2 == 0:
            self.id_count += 1
            return str(object)
        else:
            return None

    def persistent_load(self, oid):
        self.load_count += 1
        object = int(oid)
        assert object % 2 == 0
        return object

    def test_persistence(self):
        self.id_count = 0
        self.load_count = 0
        L = range(10)
        self.assertEqual(self.loads(self.dumps(L)), L)
        self.assertEqual(self.id_count, 5)
        self.assertEqual(self.load_count, 5)

    def test_bin_persistence(self):
        self.id_count = 0
        self.load_count = 0
        L = range(10)
        self.assertEqual(self.loads(self.dumps(L, 1)), L)
        self.assertEqual(self.id_count, 5)
        self.assertEqual(self.load_count, 5)