"""Unit tests for collections.defaultdict."""

import os
import copy
import pickle
import tempfile
import unittest
from test import support

from collections import defaultdict

def foobar():
    """Return the built-in ``list`` type itself.

    Used as a module-level ``default_factory`` in the copy tests. Note that
    it hands back the *type* ``list``, not a new list instance.
    """
    factory = list
    return factory

class TestDefaultDict(unittest.TestCase):

    def test_basic(self):
        d1 = defaultdict()
        self.assertEqual(d1.default_factory, None)
        d1.default_factory = list
        d1[12].append(42)
        self.assertEqual(d1, {12: [42]})
        d1[12].append(24)
        self.assertEqual(d1, {12: [42, 24]})
        d1[13]
        d1[14]
        self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
28
        self.assertTrue(d1[12] is not d1[13] is not d1[14])
Guido van Rossum's avatar
Guido van Rossum committed
29 30 31 32 33 34
        d2 = defaultdict(list, foo=1, bar=2)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, {"foo": 1, "bar": 2})
        self.assertEqual(d2["foo"], 1)
        self.assertEqual(d2["bar"], 2)
        self.assertEqual(d2[42], [])
35 36 37 38 39 40 41 42
        self.assertIn("foo", d2)
        self.assertIn("foo", d2.keys())
        self.assertIn("bar", d2)
        self.assertIn("bar", d2.keys())
        self.assertIn(42, d2)
        self.assertIn(42, d2.keys())
        self.assertNotIn(12, d2)
        self.assertNotIn(12, d2.keys())
Guido van Rossum's avatar
Guido van Rossum committed
43 44 45 46
        d2.default_factory = None
        self.assertEqual(d2.default_factory, None)
        try:
            d2[15]
47
        except KeyError as err:
Guido van Rossum's avatar
Guido van Rossum committed
48 49 50
            self.assertEqual(err.args, (15,))
        else:
            self.fail("d2[15] didn't raise KeyError")
51
        self.assertRaises(TypeError, defaultdict, 1)
Guido van Rossum's avatar
Guido van Rossum committed
52 53 54 55 56 57 58 59 60 61 62

    def test_missing(self):
        d1 = defaultdict()
        self.assertRaises(KeyError, d1.__missing__, 42)
        d1.default_factory = list
        self.assertEqual(d1.__missing__(42), [])

    def test_repr(self):
        d1 = defaultdict()
        self.assertEqual(d1.default_factory, None)
        self.assertEqual(repr(d1), "defaultdict(None, {})")
63
        self.assertEqual(eval(repr(d1)), d1)
Guido van Rossum's avatar
Guido van Rossum committed
64 65
        d1[11] = 41
        self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
66 67
        d2 = defaultdict(int)
        self.assertEqual(d2.default_factory, int)
Guido van Rossum's avatar
Guido van Rossum committed
68
        d2[12] = 42
69
        self.assertEqual(repr(d2), "defaultdict(<class 'int'>, {12: 42})")
Guido van Rossum's avatar
Guido van Rossum committed
70 71
        def foo(): return 43
        d3 = defaultdict(foo)
72
        self.assertTrue(d3.default_factory is foo)
Guido van Rossum's avatar
Guido van Rossum committed
73 74 75 76 77 78 79 80 81 82 83 84 85 86
        d3[13]
        self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))

    def test_print(self):
        d1 = defaultdict()
        def foo(): return 42
        d2 = defaultdict(foo, {1: 2})
        # NOTE: We can't use tempfile.[Named]TemporaryFile since this
        # code must exercise the tp_print C code, which only gets
        # invoked for *real* files.
        tfn = tempfile.mktemp()
        try:
            f = open(tfn, "w+")
            try:
87 88
                print(d1, file=f)
                print(d2, file=f)
Guido van Rossum's avatar
Guido van Rossum committed
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
                f.seek(0)
                self.assertEqual(f.readline(), repr(d1) + "\n")
                self.assertEqual(f.readline(), repr(d2) + "\n")
            finally:
                f.close()
        finally:
            os.remove(tfn)

    def test_copy(self):
        d1 = defaultdict()
        d2 = d1.copy()
        self.assertEqual(type(d2), defaultdict)
        self.assertEqual(d2.default_factory, None)
        self.assertEqual(d2, {})
        d1.default_factory = list
        d3 = d1.copy()
        self.assertEqual(type(d3), defaultdict)
        self.assertEqual(d3.default_factory, list)
        self.assertEqual(d3, {})
        d1[42]
        d4 = d1.copy()
        self.assertEqual(type(d4), defaultdict)
        self.assertEqual(d4.default_factory, list)
        self.assertEqual(d4, {42: []})
        d4[12]
        self.assertEqual(d4, {42: [], 12: []})

116 117 118 119 120 121
        # Issue 6637: Copy fails for empty default dict
        d = defaultdict()
        d['a'] = 42
        e = d.copy()
        self.assertEqual(e['a'], 42)

Guido van Rossum's avatar
Guido van Rossum committed
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
    def test_shallow_copy(self):
        d1 = defaultdict(foobar, {1: 1})
        d2 = copy.copy(d1)
        self.assertEqual(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        d1.default_factory = list
        d2 = copy.copy(d1)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_deep_copy(self):
        d1 = defaultdict(foobar, {1: [1]})
        d2 = copy.deepcopy(d1)
        self.assertEqual(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
137
        self.assertTrue(d1[1] is not d2[1])
Guido van Rossum's avatar
Guido van Rossum committed
138 139 140 141 142
        d1.default_factory = list
        d2 = copy.deepcopy(d1)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, d1)

143 144 145 146 147
    def test_keyerror_without_factory(self):
        d1 = defaultdict()
        try:
            d1[(1,)]
        except KeyError as err:
148
            self.assertEqual(err.args[0], (1,))
149 150 151
        else:
            self.fail("expected KeyError")

Christian Heimes's avatar
Christian Heimes committed
152 153 154 155 156 157 158 159
    def test_recursive_repr(self):
        # Issue2045: stack overflow when default_factory is a bound method
        class sub(defaultdict):
            def __init__(self):
                self.default_factory = self._factory
            def _factory(self):
                return []
        d = sub()
160
        self.assertTrue(repr(d).startswith(
Christian Heimes's avatar
Christian Heimes committed
161 162 163 164 165 166 167 168 169 170 171 172 173 174
            "defaultdict(<bound method sub._factory of defaultdict(..."))

        # NOTE: printing a subclass of a builtin type does not call its
        # tp_print slot. So this part is essentially the same test as above.
        tfn = tempfile.mktemp()
        try:
            f = open(tfn, "w+")
            try:
                print(d, file=f)
            finally:
                f.close()
        finally:
            os.remove(tfn)

175 176 177 178 179 180 181
    def test_pickleing(self):
        d = defaultdict(int)
        d[1]
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            s = pickle.dumps(d, proto)
            o = pickle.loads(s)
            self.assertEqual(d, o)
Guido van Rossum's avatar
Guido van Rossum committed
182

183
def test_main():
    """Run the TestDefaultDict suite through test.support's runner."""
    support.run_unittest(TestDefaultDict)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_main()