test_collections.py 36.3 KB
Newer Older
1 2
"""Unit tests for collections.py."""

3
import unittest, doctest
4
import inspect
5
from test import support
6 7
from collections import namedtuple, Counter, OrderedDict
from test import mapping_tests
Georg Brandl's avatar
Georg Brandl committed
8
import pickle, copy
9
from random import randrange, shuffle
10
import operator
11 12
import keyword
import re
13 14 15 16 17
from collections import Hashable, Iterable, Iterator
from collections import Sized, Container, Callable
from collections import Set, MutableSet
from collections import Mapping, MutableMapping
from collections import Sequence, MutableSequence
18
from collections import ByteString
19

# Defined at module level so pickle can locate the class by its qualified
# name; used by TestNamedTuple.test_pickle and TestNamedTuple.test_copy.
TestNT = namedtuple('TestNT', 'x y z')

class TestNamedTuple(unittest.TestCase):
    """Tests for the namedtuple() class factory and the types it creates."""

    def test_factory(self):
        Point = namedtuple('Point', 'x y')
        self.assertEqual(Point.__name__, 'Point')
        self.assertEqual(Point.__doc__, 'Point(x, y)')
        self.assertEqual(Point.__slots__, ())
        self.assertEqual(Point.__module__, __name__)
        self.assertEqual(Point.__getitem__, tuple.__getitem__)
        self.assertEqual(Point._fields, ('x', 'y'))

        self.assertRaises(ValueError, namedtuple, 'abc%', 'efg ghi')       # type has non-alpha char
        self.assertRaises(ValueError, namedtuple, 'class', 'efg ghi')      # type has keyword
        self.assertRaises(ValueError, namedtuple, '9abc', 'efg ghi')       # type starts with digit

        self.assertRaises(ValueError, namedtuple, 'abc', 'efg g%hi')       # field with non-alpha char
        self.assertRaises(ValueError, namedtuple, 'abc', 'abc class')      # field has keyword
        self.assertRaises(ValueError, namedtuple, 'abc', '8efg 9ghi')      # field starts with digit
        self.assertRaises(ValueError, namedtuple, 'abc', '_efg ghi')       # field with leading underscore
        self.assertRaises(ValueError, namedtuple, 'abc', 'efg efg ghi')    # duplicate field

        namedtuple('Point0', 'x1 y2')   # Verify that numbers are allowed in names
        namedtuple('_', 'a b c')        # Test leading underscores in a typename

        nt = namedtuple('nt', 'the quick brown fox')                       # check unicode input
        self.assertTrue("u'" not in repr(nt._fields))
        nt = namedtuple('nt', ('the', 'quick'))                            # check unicode input
        self.assertTrue("u'" not in repr(nt._fields))

        self.assertRaises(TypeError, Point._make, [11])                     # catch too few args
        self.assertRaises(TypeError, Point._make, [11, 22, 33])             # catch too many args

    def test_name_fixer(self):
        # With rename=True, invalid field names are replaced by '_<index>'
        # instead of raising ValueError.
        for spec, renamed in [
            [('efg', 'g%hi'),  ('efg', '_1')],                              # field with non-alpha char
            [('abc', 'class'), ('abc', '_1')],                              # field has keyword
            [('8efg', '9ghi'), ('_0', '_1')],                               # field starts with digit
            [('abc', '_efg'), ('abc', '_1')],                               # field with leading underscore
            [('abc', 'efg', 'efg', 'ghi'), ('abc', 'efg', '_2', 'ghi')],    # duplicate field
            [('abc', '', 'x'), ('abc', '_1', 'x')],                         # fieldname is empty
        ]:
            self.assertEqual(namedtuple('NT', spec, rename=True)._fields, renamed)

    def test_instance(self):
        Point = namedtuple('Point', 'x y')
        p = Point(11, 22)
        self.assertEqual(p, Point(x=11, y=22))
        self.assertEqual(p, Point(11, y=22))
        self.assertEqual(p, Point(y=22, x=11))
        self.assertEqual(p, Point(*(11, 22)))
        self.assertEqual(p, Point(**dict(x=11, y=22)))
        self.assertRaises(TypeError, Point, 1)                              # too few args
        self.assertRaises(TypeError, Point, 1, 2, 3)                        # too many args
        self.assertRaises(TypeError, eval, 'Point(XXX=1, y=2)', locals())   # wrong keyword argument
        self.assertRaises(TypeError, eval, 'Point(x=1)', locals())          # missing keyword argument
        self.assertEqual(repr(p), 'Point(x=11, y=22)')
        self.assertTrue('__dict__' not in dir(p))                           # verify instance has no dict
        self.assertTrue('__weakref__' not in dir(p))
        self.assertEqual(p, Point._make([11, 22]))                          # test _make classmethod
        self.assertEqual(p._fields, ('x', 'y'))                             # test _fields attribute
        self.assertEqual(p._replace(x=1), (1, 22))                          # test _replace method
        self.assertEqual(p._asdict(), dict(x=11, y=22))                     # test _asdict method

        # _replace() must reject keyword names that are not fields.  (The
        # previous try/except version called the nonexistent self._fail.)
        self.assertRaises(ValueError, p._replace, x=1, error=2)

        # verify that field string can have commas
        Point = namedtuple('Point', 'x, y')
        p = Point(x=11, y=22)
        self.assertEqual(repr(p), 'Point(x=11, y=22)')

        # verify that fieldspec can be a non-string sequence
        Point = namedtuple('Point', ('x', 'y'))
        p = Point(x=11, y=22)
        self.assertEqual(repr(p), 'Point(x=11, y=22)')

    def test_tupleness(self):
        Point = namedtuple('Point', 'x y')
        p = Point(11, 22)

        self.assertTrue(isinstance(p, tuple))
        self.assertEqual(p, (11, 22))                                       # matches a real tuple
        self.assertEqual(tuple(p), (11, 22))                                # coercable to a real tuple
        self.assertEqual(list(p), [11, 22])                                 # coercable to a list
        self.assertEqual(max(p), 22)                                        # iterable
        self.assertEqual(max(*p), 22)                                       # star-able
        x, y = p
        self.assertEqual(p, (x, y))                                         # unpacks like a tuple
        self.assertEqual((p[0], p[1]), (11, 22))                            # indexable like a tuple
        self.assertRaises(IndexError, p.__getitem__, 3)

        self.assertEqual(p.x, x)
        self.assertEqual(p.y, y)
        self.assertRaises(AttributeError, eval, 'p.z', locals())

    def test_odd_sizes(self):
        Zero = namedtuple('Zero', '')
        self.assertEqual(Zero(), ())
        self.assertEqual(Zero._make([]), ())
        self.assertEqual(repr(Zero()), 'Zero()')
        self.assertEqual(Zero()._asdict(), {})
        self.assertEqual(Zero()._fields, ())

        Dot = namedtuple('Dot', 'd')
        self.assertEqual(Dot(1), (1,))
        self.assertEqual(Dot._make([1]), (1,))
        self.assertEqual(Dot(1).d, 1)
        self.assertEqual(repr(Dot(1)), 'Dot(d=1)')
        self.assertEqual(Dot(1)._asdict(), {'d':1})
        self.assertEqual(Dot(1)._replace(d=999), (999,))
        self.assertEqual(Dot(1)._fields, ('d',))

        n = 254 # SyntaxError: more than 255 arguments:
        import string, random
        # Random 10-letter names; the set() drops any duplicates, so n is
        # recomputed from the de-duplicated list.
        names = list(set(''.join([random.choice(string.ascii_letters)
                                  for j in range(10)]) for i in range(n)))
        n = len(names)
        Big = namedtuple('Big', names)
        b = Big(*range(n))
        self.assertEqual(b, tuple(range(n)))
        self.assertEqual(Big._make(range(n)), tuple(range(n)))
        for pos, name in enumerate(names):
            self.assertEqual(getattr(b, name), pos)
        repr(b)                                 # make sure repr() doesn't blow-up
        d = b._asdict()
        d_expected = dict(zip(names, range(n)))
        self.assertEqual(d, d_expected)
        b2 = b._replace(**dict([(names[1], 999),(names[-5], 42)]))
        b2_expected = list(range(n))
        b2_expected[1] = 999
        b2_expected[-5] = 42
        self.assertEqual(b2, tuple(b2_expected))
        self.assertEqual(b._fields, tuple(names))

    def test_pickle(self):
        p = TestNT(x=10, y=20, z=30)
        for module in (pickle,):
            loads = getattr(module, 'loads')
            dumps = getattr(module, 'dumps')
            for protocol in -1, 0, 1, 2:
                q = loads(dumps(p, protocol))
                self.assertEqual(p, q)
                self.assertEqual(p._fields, q._fields)

    def test_copy(self):
        p = TestNT(x=10, y=20, z=30)
        for copier in copy.copy, copy.deepcopy:
            q = copier(p)
            self.assertEqual(p, q)
            self.assertEqual(p._fields, q._fields)

    def test_name_conflicts(self):
        # Some names like "self", "cls", "tuple", "itemgetter", and "property"
        # failed when used as field names.  Test to make sure these now work.
        T = namedtuple('T', 'itemgetter property self cls tuple')
        t = T(1, 2, 3, 4, 5)
        self.assertEqual(t, (1,2,3,4,5))
        newt = t._replace(itemgetter=10, property=20, self=30, cls=40, tuple=50)
        self.assertEqual(newt, (10,20,30,40,50))

        # Broader test of all interesting names in a template
        with support.captured_stdout() as template:
            T = namedtuple('T', 'x', verbose=True)
        words = set(re.findall('[A-Za-z]+', template.getvalue()))
        words -= set(keyword.kwlist)
        T = namedtuple('T', words)
        # test __new__
        values = tuple(range(len(words)))
        t = T(*values)
        self.assertEqual(t, values)
        t = T(**dict(zip(T._fields, values)))
        self.assertEqual(t, values)
        # test _make
        t = T._make(values)
        self.assertEqual(t, values)
        # exercise __repr__
        repr(t)
        # test _asdict
        self.assertEqual(t._asdict(), dict(zip(T._fields, values)))
        # test _replace
        t = T._make(values)
        newvalues = tuple(v*10 for v in values)
        newt = t._replace(**dict(zip(T._fields, newvalues)))
        self.assertEqual(newt, newvalues)
        # test _fields
        self.assertEqual(T._fields, tuple(words))
        # test __getnewargs__
        self.assertEqual(t.__getnewargs__(), values)

class ABCTestCase(unittest.TestCase):
    """Base class providing a helper that checks an ABC's abstract-method contract."""

    def validate_abstract_methods(self, abc, *names):
        """Check instantiation of trivial subclasses of *abc*.

        A subclass that stubs out every method in *names* must be
        instantiable.  Then, for each name, a subclass missing just that
        stub must raise TypeError on instantiation if the method is
        abstract in *abc*.  If *abc* already provides a concrete (mixin)
        implementation, the subclass must still instantiate without the stub.
        """
        methodstubs = dict.fromkeys(names, lambda s, *args: 0)

        # everything should work when all required methods are present
        C = type('C', (abc,), methodstubs)
        C()

        # instantiation should fail if a required (abstract) method is missing
        for name in names:
            stubs = methodstubs.copy()
            del stubs[name]
            C = type('C', (abc,), stubs)
            if name in abc.__abstractmethods__:
                # Call C() with no arguments.  The old form C(name) passed a
                # stray argument, which for a class not overriding
                # __init__/__new__ raises TypeError by itself, so the check
                # passed even if abstract methods were not enforced.
                self.assertRaises(TypeError, C)
            else:
                # The ABC supplies this method as a mixin, so the stub is
                # not required.
                C()


class TestOneTrickPonyABCs(ABCTestCase):
    """Tests for the single-method ABCs: Hashable, Iterable, Iterator,
    Sized, Container and Callable."""

    def test_Hashable(self):
        # Check some non-hashables
        non_samples = [bytearray(), list(), set(), dict()]
        for x in non_samples:
            self.assertFalse(isinstance(x, Hashable), repr(x))
            self.assertFalse(issubclass(type(x), Hashable), repr(type(x)))
        # Check some hashables
        samples = [None,
                   int(), float(), complex(),
                   str(),
                   tuple(), frozenset(),
                   int, list, object, type, bytes()
                   ]
        for x in samples:
            self.assertTrue(isinstance(x, Hashable), repr(x))
            self.assertTrue(issubclass(type(x), Hashable), repr(type(x)))
        self.assertRaises(TypeError, Hashable)
        # Check direct subclassing
        class H(Hashable):
            def __hash__(self):
                return super().__hash__()
        self.assertEqual(hash(H()), 0)
        self.assertFalse(issubclass(int, H))
        self.validate_abstract_methods(Hashable, '__hash__')

    def test_Iterable(self):
        # Check some non-iterables
        non_samples = [None, 42, 3.14, 1j]
        for x in non_samples:
            self.assertFalse(isinstance(x, Iterable), repr(x))
            self.assertFalse(issubclass(type(x), Iterable), repr(type(x)))
        # Check some iterables
        samples = [bytes(), str(),
                   tuple(), list(), set(), frozenset(), dict(),
                   dict().keys(), dict().items(), dict().values(),
                   (lambda: (yield))(),
                   (x for x in []),
                   ]
        for x in samples:
            self.assertTrue(isinstance(x, Iterable), repr(x))
            self.assertTrue(issubclass(type(x), Iterable), repr(type(x)))
        # Check direct subclassing
        class I(Iterable):
            def __iter__(self):
                return super().__iter__()
        self.assertEqual(list(I()), [])
        self.assertFalse(issubclass(str, I))
        self.validate_abstract_methods(Iterable, '__iter__')

    def test_Iterator(self):
        non_samples = [None, 42, 3.14, 1j, b"", "", (), [], {}, set()]
        for x in non_samples:
            self.assertFalse(isinstance(x, Iterator), repr(x))
            self.assertFalse(issubclass(type(x), Iterator), repr(type(x)))
        samples = [iter(bytes()), iter(str()),
                   iter(tuple()), iter(list()), iter(dict()),
                   iter(set()), iter(frozenset()),
                   iter(dict().keys()), iter(dict().items()),
                   iter(dict().values()),
                   (lambda: (yield))(),
                   (x for x in []),
                   ]
        for x in samples:
            self.assertTrue(isinstance(x, Iterator), repr(x))
            self.assertTrue(issubclass(type(x), Iterator), repr(type(x)))
        self.validate_abstract_methods(Iterator, '__next__')

    def test_Sized(self):
        non_samples = [None, 42, 3.14, 1j,
                       (lambda: (yield))(),
                       (x for x in []),
                       ]
        for x in non_samples:
            self.assertFalse(isinstance(x, Sized), repr(x))
            self.assertFalse(issubclass(type(x), Sized), repr(type(x)))
        samples = [bytes(), str(),
                   tuple(), list(), set(), frozenset(), dict(),
                   dict().keys(), dict().items(), dict().values(),
                   ]
        for x in samples:
            self.assertTrue(isinstance(x, Sized), repr(x))
            self.assertTrue(issubclass(type(x), Sized), repr(type(x)))
        self.validate_abstract_methods(Sized, '__len__')

    def test_Container(self):
        non_samples = [None, 42, 3.14, 1j,
                       (lambda: (yield))(),
                       (x for x in []),
                       ]
        for x in non_samples:
            self.assertFalse(isinstance(x, Container), repr(x))
            self.assertFalse(issubclass(type(x), Container), repr(type(x)))
        samples = [bytes(), str(),
                   tuple(), list(), set(), frozenset(), dict(),
                   dict().keys(), dict().items(),
                   ]
        for x in samples:
            self.assertTrue(isinstance(x, Container), repr(x))
            self.assertTrue(issubclass(type(x), Container), repr(type(x)))
        self.validate_abstract_methods(Container, '__contains__')

    def test_Callable(self):
        non_samples = [None, 42, 3.14, 1j,
                       "", b"", (), [], {}, set(),
                       (lambda: (yield))(),
                       (x for x in []),
                       ]
        for x in non_samples:
            self.assertFalse(isinstance(x, Callable), repr(x))
            self.assertFalse(issubclass(type(x), Callable), repr(type(x)))
        samples = [lambda: None,
                   type, int, object,
                   len,
                   list.append, [].append,
                   ]
        for x in samples:
            self.assertTrue(isinstance(x, Callable), repr(x))
            self.assertTrue(issubclass(type(x), Callable), repr(type(x)))
        self.validate_abstract_methods(Callable, '__call__')

    def test_direct_subclassing(self):
        for B in Hashable, Iterable, Iterator, Sized, Container, Callable:
            class C(B):
                pass
            self.assertTrue(issubclass(C, B))
            self.assertFalse(issubclass(int, C))

    def test_registration(self):
        for B in Hashable, Iterable, Iterator, Sized, Container, Callable:
            # A fresh class each time, so registration with one ABC cannot
            # affect the check against the next.
            class C:
                __hash__ = None  # Make sure it isn't hashable by default
            self.assertFalse(issubclass(C, B), B.__name__)
            B.register(C)
            self.assertTrue(issubclass(C, B))

class WithSet(MutableSet):
    """Minimal concrete MutableSet that keeps its elements in a built-in set."""

    def __init__(self, it=()):
        self.data = set(it)

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return iter(self.data)

    def __contains__(self, item):
        return item in self.data

    def add(self, item):
        self.data.add(item)

    def discard(self, item):
        self.data.discard(item)

class TestCollectionABCs(ABCTestCase):

    # XXX For now, we only test some virtual inheritance properties.
    # We should also test the proper behavior of the collection ABCs
    # as real base classes or mix-in classes.

    def test_Set(self):
        for sample in [set, frozenset]:
            self.assertTrue(isinstance(sample(), Set))
            self.assertTrue(issubclass(sample, Set))
        self.validate_abstract_methods(Set, '__contains__', '__iter__', '__len__')

    def test_hash_Set(self):
        class OneTwoThreeSet(Set):
            def __init__(self):
                self.contents = [1, 2, 3]
            def __contains__(self, x):
                return x in self.contents
            def __len__(self):
                return len(self.contents)
            def __iter__(self):
                return iter(self.contents)
            def __hash__(self):
                return self._hash()
        a, b = OneTwoThreeSet(), OneTwoThreeSet()
        self.assertEqual(hash(a), hash(b))

    def test_MutableSet(self):
        self.assertTrue(isinstance(set(), MutableSet))
        self.assertTrue(issubclass(set, MutableSet))
        self.assertFalse(isinstance(frozenset(), MutableSet))
        self.assertFalse(issubclass(frozenset, MutableSet))
        self.validate_abstract_methods(MutableSet, '__contains__', '__iter__', '__len__',
            'add', 'discard')

    def test_issue_5647(self):
        # MutableSet.__iand__ mutated the set during iteration
        s = WithSet('abcd')
        s &= WithSet('cdef')            # This used to fail
        self.assertEqual(set(s), set('cd'))

    def test_issue_4920(self):
        # MutableSet.pop() method did not work.
        # Subclass the imported MutableSet name: the bare module name
        # `collections` is not bound here, so `collections.MutableSet`
        # raised NameError.
        class MySet(MutableSet):
            __slots__=['__s']
            def __init__(self,items=None):
                if items is None:
                    items=[]
                self.__s=set(items)
            def __contains__(self,v):
                return v in self.__s
            def __iter__(self):
                return iter(self.__s)
            def __len__(self):
                return len(self.__s)
            def add(self,v):
                result=v not in self.__s
                self.__s.add(v)
                return result
            def discard(self,v):
                result=v in self.__s
                self.__s.discard(v)
                return result
            def __repr__(self):
                return "MySet(%s)" % repr(list(self))
        s = MySet([5,43,2,1])
        self.assertEqual(s.pop(), 1)

    def test_Mapping(self):
        for sample in [dict]:
            self.assertTrue(isinstance(sample(), Mapping))
            self.assertTrue(issubclass(sample, Mapping))
        self.validate_abstract_methods(Mapping, '__contains__', '__iter__', '__len__',
            '__getitem__')

    def test_MutableMapping(self):
        for sample in [dict]:
            self.assertTrue(isinstance(sample(), MutableMapping))
            self.assertTrue(issubclass(sample, MutableMapping))
        self.validate_abstract_methods(MutableMapping, '__contains__', '__iter__', '__len__',
            '__getitem__', '__setitem__', '__delitem__')

    def test_Sequence(self):
        for sample in [tuple, list, bytes, str]:
            self.assertTrue(isinstance(sample(), Sequence))
            self.assertTrue(issubclass(sample, Sequence))
        self.assertTrue(isinstance(range(10), Sequence))
        self.assertTrue(issubclass(range, Sequence))
        self.assertTrue(issubclass(str, Sequence))
        self.validate_abstract_methods(Sequence, '__contains__', '__iter__', '__len__',
            '__getitem__')

    def test_ByteString(self):
        for sample in [bytes, bytearray]:
            self.assertTrue(isinstance(sample(), ByteString))
            self.assertTrue(issubclass(sample, ByteString))
        for sample in [str, list, tuple]:
            self.assertFalse(isinstance(sample(), ByteString))
            self.assertFalse(issubclass(sample, ByteString))
        self.assertFalse(isinstance(memoryview(b""), ByteString))
        self.assertFalse(issubclass(memoryview, ByteString))

    def test_MutableSequence(self):
        for sample in [tuple, str, bytes]:
            self.assertFalse(isinstance(sample(), MutableSequence))
            self.assertFalse(issubclass(sample, MutableSequence))
        for sample in [list, bytearray]:
            self.assertTrue(isinstance(sample(), MutableSequence))
            self.assertTrue(issubclass(sample, MutableSequence))
        self.assertFalse(issubclass(str, MutableSequence))
        self.validate_abstract_methods(MutableSequence, '__contains__', '__iter__',
            '__len__', '__getitem__', '__setitem__', '__delitem__', 'insert')

class TestCounter(unittest.TestCase):
    """Tests for collections.Counter: dict behaviour, copying and multiset ops."""

    def test_basics(self):
        c = Counter('abcaba')
        self.assertEqual(c, Counter({'a':3 , 'b': 2, 'c': 1}))
        self.assertEqual(c, Counter(a=3, b=2, c=1))
        self.assertTrue(isinstance(c, dict))
        self.assertTrue(isinstance(c, Mapping))
        self.assertTrue(issubclass(Counter, dict))
        self.assertTrue(issubclass(Counter, Mapping))
        self.assertEqual(len(c), 3)
        self.assertEqual(sum(c.values()), 6)
        self.assertEqual(sorted(c.values()), [1, 2, 3])
        self.assertEqual(sorted(c.keys()), ['a', 'b', 'c'])
        self.assertEqual(sorted(c), ['a', 'b', 'c'])
        self.assertEqual(sorted(c.items()),
                         [('a', 3), ('b', 2), ('c', 1)])
        self.assertEqual(c['b'], 2)
        self.assertEqual(c['z'], 0)
        self.assertEqual(c.__contains__('c'), True)
        self.assertEqual(c.__contains__('z'), False)
        self.assertEqual(c.get('b', 10), 2)
        self.assertEqual(c.get('z', 10), 10)
        self.assertEqual(c, dict(a=3, b=2, c=1))
        self.assertEqual(repr(c), "Counter({'a': 3, 'b': 2, 'c': 1})")
        self.assertEqual(c.most_common(), [('a', 3), ('b', 2), ('c', 1)])
        for i in range(5):
            self.assertEqual(c.most_common(i),
                             [('a', 3), ('b', 2), ('c', 1)][:i])
        self.assertEqual(''.join(sorted(c.elements())), 'aaabbc')
        c['a'] += 1         # increment an existing value
        c['b'] -= 2         # sub existing value to zero
        del c['c']          # remove an entry
        del c['c']          # make sure that del doesn't raise KeyError
        c['d'] -= 2         # sub from a missing value
        c['e'] = -5         # directly assign a missing value
        c['f'] += 4         # add to a missing value
        self.assertEqual(c, dict(a=4, b=0, d=-2, e=-5, f=4))
        self.assertEqual(''.join(sorted(c.elements())), 'aaaaffff')
        self.assertEqual(c.pop('f'), 4)
        self.assertEqual('f' in c, False)
        for i in range(3):
            elem, cnt = c.popitem()
            self.assertEqual(elem in c, False)
        c.clear()
        self.assertEqual(c, {})
        self.assertEqual(repr(c), 'Counter()')
        self.assertRaises(NotImplementedError, Counter.fromkeys, 'abc')
        self.assertRaises(TypeError, hash, c)
        c.update(dict(a=5, b=3))
        c.update(c=1)
        c.update(Counter('a' * 50 + 'b' * 30))
        c.update()          # test case with no args
        c.__init__('a' * 500 + 'b' * 300)
        c.__init__('cdc')
        c.__init__()
        self.assertEqual(c, dict(a=555, b=333, c=3, d=1))
        self.assertEqual(c.setdefault('d', 5), 1)
        self.assertEqual(c['d'], 1)
        self.assertEqual(c.setdefault('e', 5), 5)
        self.assertEqual(c['e'], 5)

    def test_copying(self):
        # Check that counters are copyable, deepcopyable, picklable, and
        # have a repr/eval round-trip
        words = Counter('which witch had which witches wrist watch'.split())
        update_test = Counter()
        update_test.update(words)
        for i, dup in enumerate([
                    words.copy(),
                    copy.copy(words),
                    copy.deepcopy(words),
                    pickle.loads(pickle.dumps(words, 0)),
                    pickle.loads(pickle.dumps(words, 1)),
                    pickle.loads(pickle.dumps(words, 2)),
                    pickle.loads(pickle.dumps(words, -1)),
                    eval(repr(words)),
                    update_test,
                    Counter(words),
                    ]):
            # Identify which duplicate failed in the assertion message.
            msg = (i, dup, words)
            self.assertTrue(dup is not words, msg)
            self.assertEqual(dup, words, msg)
            self.assertEqual(len(dup), len(words), msg)
            self.assertEqual(type(dup), type(words), msg)

    def test_conversions(self):
        # Convert to: set, list, dict
        s = 'she sells sea shells by the sea shore'
        self.assertEqual(sorted(Counter(s).elements()), sorted(s))
        self.assertEqual(sorted(Counter(s)), sorted(set(s)))
        self.assertEqual(dict(Counter(s)), dict(Counter(s).items()))
        self.assertEqual(set(Counter(s)), set(s))

    def test_invariant_for_the_in_operator(self):
        c = Counter(a=10, b=-2, c=0)
        for elem in c:
            self.assertTrue(elem in c)

    def test_multiset_operations(self):
        # Verify that adding a zero counter will strip zeros and negatives
        c = Counter(a=10, b=-2, c=0) + Counter()
        self.assertEqual(dict(c), dict(a=10))

        elements = 'abcd'
        for i in range(1000):
            # test random pairs of multisets
            p = Counter(dict((elem, randrange(-2,4)) for elem in elements))
            p.update(e=1, f=-1, g=0)
            q = Counter(dict((elem, randrange(-2,4)) for elem in elements))
            q.update(h=1, i=-1, j=0)
            for counterop, numberop in [
                (Counter.__add__, lambda x, y: max(0, x+y)),
                (Counter.__sub__, lambda x, y: max(0, x-y)),
                (Counter.__or__, lambda x, y: max(0,x,y)),
                (Counter.__and__, lambda x, y: max(0, min(x,y))),
            ]:
                result = counterop(p, q)
                for x in elements:
                    self.assertEqual(numberop(p[x], q[x]), result[x],
                                     (counterop, x, p, q))
                # verify that results exclude non-positive counts.  all() is
                # required: a bare generator object is always truthy, so
                # assertTrue(<genexpr>) could never fail.
                self.assertTrue(all(x > 0 for x in result.values()),
                                (counterop, p, q))

        elements = 'abcdef'
        for i in range(100):
            # verify that random multisets with no repeats are exactly like sets
            p = Counter(dict((elem, randrange(0, 2)) for elem in elements))
            q = Counter(dict((elem, randrange(0, 2)) for elem in elements))
            for counterop, setop in [
                (Counter.__sub__, set.__sub__),
                (Counter.__or__, set.__or__),
                (Counter.__and__, set.__and__),
            ]:
                counter_result = counterop(p, q)
                set_result = setop(set(p.elements()), set(q.elements()))
                self.assertEqual(counter_result, dict.fromkeys(set_result, 1))

class TestOrderedDict(unittest.TestCase):
    """Tests for collections.OrderedDict.

    Covers construction, update, deletion, iteration order, popping,
    equality, copying/pickling, __reduce__ and repr.  Where the pure-Python
    and C implementations legitimately differ (the layout of __reduce__()
    and the exact repr spelling) the checks target the shared behavior
    rather than one particular layout.
    """

    def test_init(self):
        # Constructor: argument validation, every accepted input form, and
        # that calling __init__ again merges into (not clears) the contents.
        with self.assertRaises(TypeError):
            OrderedDict([('a', 1), ('b', 2)], None)                                 # too many args
        pairs = [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)]
        self.assertEqual(sorted(OrderedDict(dict(pairs)).items()), pairs)           # dict input
        self.assertEqual(sorted(OrderedDict(**dict(pairs)).items()), pairs)         # kwds input
        self.assertEqual(list(OrderedDict(pairs).items()), pairs)                   # pairs input
        self.assertEqual(list(OrderedDict([('a', 1), ('b', 2), ('c', 9), ('d', 4)],
                                          c=3, e=5).items()), pairs)                # mixed input

        # make sure no positional args conflict with possible kwdargs
        # (getfullargspec: the older getargspec is deprecated and was
        # removed in Python 3.11)
        self.assertEqual(inspect.getfullargspec(OrderedDict.__dict__['__init__']).args,
                         ['self'])

        # Make sure that direct calls to __init__ do not clear previous contents
        d = OrderedDict([('a', 1), ('b', 2), ('c', 3), ('d', 44), ('e', 55)])
        d.__init__([('e', 5), ('f', 6)], g=7, d=4)
        self.assertEqual(list(d.items()),
            [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6), ('g', 7)])

    def test_update(self):
        # update(): argument validation and every accepted input form.
        with self.assertRaises(TypeError):
            OrderedDict().update([('a', 1), ('b', 2)], None)                        # too many args
        pairs = [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)]
        od = OrderedDict()
        od.update(dict(pairs))
        self.assertEqual(sorted(od.items()), pairs)                                 # dict input
        od = OrderedDict()
        od.update(**dict(pairs))
        self.assertEqual(sorted(od.items()), pairs)                                 # kwds input
        od = OrderedDict()
        od.update(pairs)
        self.assertEqual(list(od.items()), pairs)                                   # pairs input
        od = OrderedDict()
        od.update([('a', 1), ('b', 2), ('c', 9), ('d', 4)], c=3, e=5)
        self.assertEqual(list(od.items()), pairs)                                   # mixed input

        # Make sure that direct calls to update do not clear previous contents
        # and that updated items are not moved to the end
        d = OrderedDict([('a', 1), ('b', 2), ('c', 3), ('d', 44), ('e', 55)])
        d.update([('e', 5), ('f', 6)], g=7, d=4)
        self.assertEqual(list(d.items()),
            [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6), ('g', 7)])

    def test_clear(self):
        # clear() empties the mapping regardless of insertion order.
        pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
        shuffle(pairs)
        od = OrderedDict(pairs)
        self.assertEqual(len(od), len(pairs))
        od.clear()
        self.assertEqual(len(od), 0)

    def test_delitem(self):
        # del removes the key, a second del raises KeyError, and the
        # remaining items keep their relative order.
        pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
        od = OrderedDict(pairs)
        del od['a']
        self.assertNotIn('a', od)
        with self.assertRaises(KeyError):
            del od['a']
        self.assertEqual(list(od.items()), pairs[:2] + pairs[3:])

    def test_setitem(self):
        # Assigning an existing key keeps its position; a new key goes last.
        od = OrderedDict([('d', 1), ('b', 2), ('c', 3), ('a', 4), ('e', 5)])
        od['c'] = 10           # existing element
        od['f'] = 20           # new element
        self.assertEqual(list(od.items()),
                         [('d', 1), ('b', 2), ('c', 10), ('a', 4), ('e', 5), ('f', 20)])

    def test_iterators(self):
        # Every iteration view follows insertion order; reversed() inverts it.
        pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
        shuffle(pairs)
        od = OrderedDict(pairs)
        self.assertEqual(list(od), [t[0] for t in pairs])
        self.assertEqual(list(od.keys()), [t[0] for t in pairs])
        self.assertEqual(list(od.values()), [t[1] for t in pairs])
        self.assertEqual(list(od.items()), pairs)
        self.assertEqual(list(reversed(od)),
                         [t[0] for t in reversed(pairs)])

    def test_popitem(self):
        # popitem() returns the most recently inserted pair first and
        # raises KeyError once the mapping is empty.
        pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
        shuffle(pairs)
        od = OrderedDict(pairs)
        while pairs:
            self.assertEqual(od.popitem(), pairs.pop())
        with self.assertRaises(KeyError):
            od.popitem()
        self.assertEqual(len(od), 0)

    def test_pop(self):
        # pop(key) returns the stored value in any order, raises KeyError
        # for a missing key, and honors a default for a missing key.
        pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
        shuffle(pairs)
        od = OrderedDict(pairs)
        shuffle(pairs)
        while pairs:
            k, v = pairs.pop()
            self.assertEqual(od.pop(k), v)
        with self.assertRaises(KeyError):
            od.pop('xyz')
        self.assertEqual(len(od), 0)
        self.assertEqual(od.pop(k, 12345), 12345)

    def test_equality(self):
        # Equality between OrderedDicts is order sensitive; against a plain
        # dict it is not.
        pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
        shuffle(pairs)
        od1 = OrderedDict(pairs)
        od2 = OrderedDict(pairs)
        self.assertEqual(od1, od2)          # same order implies equality
        pairs = pairs[2:] + pairs[:2]
        od2 = OrderedDict(pairs)
        self.assertNotEqual(od1, od2)       # different order implies inequality
        # comparison to regular dict is not order sensitive
        self.assertEqual(od1, dict(od2))
        self.assertEqual(dict(od2), od1)
        # different length implies inequality
        self.assertNotEqual(od1, OrderedDict(pairs[:-1]))

    def test_copying(self):
        # Check that ordered dicts are copyable, deepcopyable, picklable,
        # and have a repr/eval round-trip
        pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
        od = OrderedDict(pairs)
        update_test = OrderedDict()
        update_test.update(od)
        for dup in [
                    od.copy(),
                    copy.copy(od),
                    copy.deepcopy(od),
                    pickle.loads(pickle.dumps(od, 0)),
                    pickle.loads(pickle.dumps(od, 1)),
                    pickle.loads(pickle.dumps(od, 2)),
                    pickle.loads(pickle.dumps(od, 3)),
                    pickle.loads(pickle.dumps(od, -1)),
                    eval(repr(od)),
                    update_test,
                    OrderedDict(od),
                    ]:
            self.assertIsNot(dup, od)
            # assertEqual rather than the deprecated assertEquals alias,
            # which was removed in Python 3.12
            self.assertEqual(dup, od)
            self.assertEqual(list(dup.items()), list(od.items()))
            self.assertEqual(len(dup), len(od))
            self.assertEqual(type(dup), type(od))

    def test_yaml_linkage(self):
        # Verify that __reduce__ is setup in a way that supports PyYAML's dump() feature.
        # In yaml, lists are native but tuples are not.
        pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
        od = OrderedDict(pairs)
        # yaml.dump(od) -->
        # '!!python/object/apply:__main__.OrderedDict\n- - [a, 1]\n  - [b, 2]\n'
        # As the sample dump shows, the pairs sit one level below the reduce
        # arguments: iterating __reduce__()[1] yields the item list itself.
        # Checking only that level would let tuple pairs slip through, so
        # descend into each argument as well.
        for arg in od.__reduce__()[1]:
            self.assertIs(type(arg), list)
            self.assertTrue(all(type(pair) == list for pair in arg))

    def test_reduce_not_too_fat(self):
        # do not save instance dictionary if not needed
        pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
        od = OrderedDict(pairs)
        # __reduce__() may be a bare (callable, args) pair or a longer tuple
        # whose third slot holds the instance state; either way no state
        # should be saved while the instance has no attributes of its own.
        reduced = od.__reduce__()
        self.assertTrue(len(reduced) == 2 or reduced[2] is None, reduced)
        od.x = 10
        # once an instance attribute exists it must be carried as the state
        self.assertEqual(od.__reduce__()[2], {'x': 10})

    def test_repr(self):
        # repr() lists the items in insertion order and round-trips through eval.
        od = OrderedDict([('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)])
        # The list-of-pairs spelling was used through Python 3.11; 3.12
        # switched to a dict-literal spelling.  Both must preserve the order.
        self.assertIn(repr(od), (
            "OrderedDict([('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)])",
            "OrderedDict({'c': 1, 'b': 2, 'a': 3, 'd': 4, 'e': 5, 'f': 6})",
        ))
        self.assertEqual(eval(repr(od)), od)
        self.assertEqual(repr(OrderedDict()), "OrderedDict()")

    def test_setdefault(self):
        # setdefault on an existing key neither changes the value nor the
        # order; on a new key it appends the default at the end.
        pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
        shuffle(pairs)
        od = OrderedDict(pairs)
        pair_order = list(od.items())
        self.assertEqual(od.setdefault('a', 10), 3)
        # make sure order didn't change
        self.assertEqual(list(od.items()), pair_order)
        self.assertEqual(od.setdefault('x', 10), 10)
        # make sure 'x' is added to the end
        self.assertEqual(list(od.items())[-1], ('x', 10))

    def test_reinsert(self):
        # Given insert a, insert b, delete a, re-insert a,
        # verify that a is now later than b.
        od = OrderedDict()
        od['a'] = 1
        od['b'] = 2
        del od['a']
        od['a'] = 1
        self.assertEqual(list(od.items()), [('b', 2), ('a', 1)])



class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
    type2test = OrderedDict

838 839 840 841
    def test_popitem(self):
        d = self._empty_mapping()
        self.assertRaises(KeyError, d.popitem)

842 843 844 845 846 847
class MyOrderedDict(OrderedDict):
    """OrderedDict subclass with no overrides, used as a mapping-test target."""

class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol):
    """Run the generic mapping-protocol checks against an OrderedDict subclass."""

    type2test = MyOrderedDict

    def test_popitem(self):
        # NOTE(review): presumably replaces an inherited test_popitem whose
        # extra checks do not suit OrderedDict.popitem -- confirm against
        # test.mapping_tests.  Only the empty-mapping case is checked here.
        mapping = self._empty_mapping()
        with self.assertRaises(KeyError):
            mapping.popitem()
851 852


853
import doctest, collections
854

855
def test_main(verbose=None):
    """Run this module's test classes, then the collections doctests.

    verbose: forwarded to support.run_doctest.
    """
    # Wrap the doctests embedded in the collections module as a unittest
    # suite so they run alongside the TestCase classes below.
    collections_doctests = doctest.DocTestSuite(module=collections)
    suites = [
        TestNamedTuple,
        collections_doctests,
        TestOneTrickPonyABCs,
        TestCollectionABCs,
        TestCounter,
        TestOrderedDict,
        GeneralMappingTests,
        SubclassMappingTests,
    ]
    support.run_unittest(*suites)
    # NOTE(review): this appears to execute the collections doctests a
    # second time (the DocTestSuite above already targets the same module);
    # confirm whether both runs are intended.
    support.run_doctest(collections, verbose)


if __name__ == "__main__":
    # Executed directly (not imported by a test runner): run verbosely.
    test_main(verbose=True)