# test_iterlen.py
""" Test Iterator Length Transparency

Some functions or methods which accept general iterable arguments have
optional, more efficient code paths if they know how many items to expect.
For instance, map(func, iterable), will pre-allocate the exact amount of
space required whenever the iterable can report its length.

The desired invariant is:  len(it)==len(list(it)).

A complication is that an iterable and iterator can be the same object. To
maintain the invariant, an iterator needs to dynamically update its length.
For instance, an iterable such as range(10) always reports its length as ten,
but it=iter(range(10)) starts at ten, and then goes to nine after next(it).
Having this capability means that map() can ignore the distinction between
map(func, iterable) and map(func, iter(iterable)).

When the iterable is immutable, the implementation can straight-forwardly
report the original length minus the cumulative number of calls to next().
This is the case for tuples, range objects, and itertools.repeat().

Some containers become temporarily immutable during iteration.  This includes
dicts, sets, and collections.deque.  Their implementation is equally simple
though they need to permanently set their length to zero whenever there is
an attempt to iterate after a length mutation.

The situation is slightly more involved whenever an object allows length
mutation during iteration.  Lists and sequence iterators are dynamically
updatable.  So, if a list is extended during iteration, the iterator will
continue through the new items.  If it shrinks to a point before the most
recent iteration, then no further items are available and the length is
reported at zero.

Reversed objects can also be wrapped around mutable objects; however, any
appends after the current position are ignored.  Any other approach leads
to confusion and possibly returning the same item more than once.

The iterators not listed above, such as enumerate and the other itertools,
are not length transparent because they have no way to distinguish between
iterables that report static length and iterators whose length changes with
each call (i.e. the difference between enumerate('abc') and
enumerate(iter('abc'))).

"""

import unittest
from builtins import len as _len
from collections import deque
from itertools import repeat
from test import support

# Number of items each iterator under test is created with.
n = 10

def len(obj):
    """Return the length of *obj*, falling back to its length hint.

    This deliberately shadows the builtin ``len`` inside this test module so
    that the tests can query iterators, which do not define ``__len__`` but
    may report their remaining item count through ``__length_hint__``.

    Raises:
        TypeError: if *obj* supports neither ``__len__`` nor
            ``__length_hint__``.
    """
    try:
        return _len(obj)
    except TypeError:
        try:
            # note: calling __length_hint__ directly is only appropriate in
            # tests like these; don't rely on it in your own programs
            return obj.__length_hint__()
        except AttributeError:
            raise TypeError(
                "object of type %r has no len() or __length_hint__()"
                % type(obj).__name__
            )

class TestInvariantWithoutMutations(unittest.TestCase):
    """Shared check that an iterator's reported length tracks consumption.

    Concrete subclasses assign ``self.it`` (an iterator over ``n`` items)
    in ``setUp``; this class does not define ``setUp`` itself.
    """

    def test_invariant(self):
        it = self.it
        # The reported length must count down from n to 1 as items are taken.
        for remaining in range(n, 0, -1):
            self.assertEqual(len(it), remaining)
            next(it)
        # Once exhausted the length stays at zero, before and after the
        # iterator has raised StopIteration.
        self.assertEqual(len(it), 0)
        self.assertRaises(StopIteration, next, it)
        self.assertEqual(len(it), 0)

class TestTemporarilyImmutable(TestInvariantWithoutMutations):
    """Checks for containers that forbid size changes while being iterated.

    In addition to ``self.it``, concrete subclasses assign ``self.mutate``:
    a zero-argument callable that changes the underlying container's size.
    """

    def test_immutable_during_iteration(self):
        # objects such as deques, sets, and dictionaries enforce
        # length immutability during iteration

        it = self.it
        self.assertEqual(len(it), n)
        next(it)
        self.assertEqual(len(it), n-1)
        self.mutate()
        # After the size change the iterator must refuse to continue and
        # report that nothing is left.
        self.assertRaises(RuntimeError, next, it)
        self.assertEqual(len(it), 0)

## ------- Concrete Type Tests -------

class TestRepeat(TestInvariantWithoutMutations):
    """Length checks for ``itertools.repeat``."""

    def setUp(self):
        # A bounded repeat() yielding exactly n items.
        bounded = repeat(None, n)
        self.it = bounded

    def test_no_len_for_infinite_repeat(self):
        # Without a count, repeat() is infinite and has no length to report.
        endless = repeat(None)
        with self.assertRaises(TypeError):
            len(endless)

class TestXrange(TestInvariantWithoutMutations):
    """Length checks for a forward iterator over a range object."""

    def setUp(self):
        self.it = iter(range(n))

class TestXrangeCustomReversed(TestInvariantWithoutMutations):
    """Length checks for the iterator returned by ``reversed(range(n))``."""

    def setUp(self):
        self.it = reversed(range(n))

class TestTuple(TestInvariantWithoutMutations):
    """Length checks for a tuple iterator."""

    def setUp(self):
        self.it = iter(tuple(range(n)))

## ------- Types that should not be mutated during iteration -------

class TestDeque(TestTemporarilyImmutable):
    """Length checks for a forward deque iterator."""

    def setUp(self):
        d = deque(range(n))
        self.it = iter(d)
        # Popping an element changes the deque's size mid-iteration.
        self.mutate = d.pop

class TestDequeReversed(TestTemporarilyImmutable):
    """Length checks for a reversed deque iterator."""

    def setUp(self):
        d = deque(range(n))
        self.it = reversed(d)
        # Popping an element changes the deque's size mid-iteration.
        self.mutate = d.pop

class TestDictKeys(TestTemporarilyImmutable):
    """Length checks for the iterator over a dict's keys."""

    def setUp(self):
        d = dict.fromkeys(range(n))
        self.it = iter(d)
        # Removing an item changes the dict's size mid-iteration.
        self.mutate = d.popitem

class TestDictItems(TestTemporarilyImmutable):
    """Length checks for the iterator over a dict's items view."""

    def setUp(self):
        d = dict.fromkeys(range(n))
        self.it = iter(d.items())
        # Removing an item changes the dict's size mid-iteration.
        self.mutate = d.popitem

class TestDictValues(TestTemporarilyImmutable):
    """Length checks for the iterator over a dict's values view."""

    def setUp(self):
        d = dict.fromkeys(range(n))
        self.it = iter(d.values())
        # Removing an item changes the dict's size mid-iteration.
        self.mutate = d.popitem

class TestSet(TestTemporarilyImmutable):
    """Length checks for a set iterator."""

    def setUp(self):
        d = set(range(n))
        self.it = iter(d)
        # Popping an element changes the set's size mid-iteration.
        self.mutate = d.pop

## ------- Types that can mutate during iteration -------

class TestList(TestInvariantWithoutMutations):
    """Length checks for a list iterator, including mutation while iterating."""

    def setUp(self):
        # Build a real list: iter(range(n)) would give a range iterator, so
        # the inherited invariant test would never touch a list iterator.
        self.it = iter(list(range(n)))

    def test_mutation(self):
        d = list(range(n))
        it = iter(d)
        next(it)
        next(it)
        self.assertEqual(len(it), n-2)
        d.append(n)
        self.assertEqual(len(it), n-1)  # grow with append
        # Truncate the list to a point before the iterator's position.
        d[1:] = []
        self.assertEqual(len(it), 0)
        self.assertEqual(list(it), [])
        # Once exhausted, the iterator stays empty even if the list regrows.
        d.extend(range(20))
        self.assertEqual(len(it), 0)

class TestListReversed(TestInvariantWithoutMutations):
    """Length checks for a reversed list iterator, including mutation."""

    def setUp(self):
        # Build a real list: reversed(range(n)) would give a range iterator,
        # so the inherited invariant test would never touch a list iterator.
        self.it = reversed(list(range(n)))

    def test_mutation(self):
        d = list(range(n))
        it = reversed(d)
        next(it)
        next(it)
        self.assertEqual(len(it), n-2)
        d.append(n)
        self.assertEqual(len(it), n-2)  # ignore append
        # Truncate the list to a point before the iterator's position.
        d[1:] = []
        self.assertEqual(len(it), 0)
        self.assertEqual(list(it), [])  # confirm invariant
        # Once exhausted, the iterator stays empty even if the list regrows.
        d.extend(range(20))
        self.assertEqual(len(it), 0)

## -- Check to make sure exceptions are not suppressed by __length_hint__()


class BadLen(object):
    """Iterable over ten integers whose ``__len__`` always raises."""

    def __len__(self):
        # Any attempt to size this object fails with a non-TypeError error.
        raise RuntimeError('hello')

    def __iter__(self):
        ten_items = range(10)
        return iter(ten_items)

class BadLengthHint(object):
    """Iterable over ten integers whose ``__length_hint__`` always raises."""

    def __length_hint__(self):
        # Any attempt to estimate this object's size fails with a
        # non-TypeError error.
        raise RuntimeError('hello')

    def __iter__(self):
        ten_items = range(10)
        return iter(ten_items)

class NoneLengthHint(object):
    """Iterable over ten integers whose length hint carries no information.

    The class name is historical. ``__length_hint__`` returns
    ``NotImplemented`` rather than ``None``, because the length-hint
    protocol (PEP 424) rejects any other non-integer result with a
    TypeError.  ``NotImplemented`` tells consumers such as ``list()`` to
    fall back to their default size guess.
    """

    def __iter__(self):
        return iter(range(10))

    def __length_hint__(self):
        return NotImplemented

class TestLengthHintExceptions(unittest.TestCase):
    """Errors raised while sizing an iterable must reach the caller."""

    def test_issue1242657(self):
        # A RuntimeError from __len__ or __length_hint__ must propagate
        # through each consumer instead of being swallowed.
        self.assertRaises(RuntimeError, list, BadLen())
        self.assertRaises(RuntimeError, list, BadLengthHint())
        self.assertRaises(RuntimeError, [].extend, BadLen())
        self.assertRaises(RuntimeError, [].extend, BadLengthHint())
        b = bytearray(range(10))
        self.assertRaises(RuntimeError, b.extend, BadLen())
        self.assertRaises(RuntimeError, b.extend, BadLengthHint())

    def test_invalid_hint(self):
        # Make sure an invalid result doesn't muck-up the works
        self.assertEqual(list(NoneLengthHint()), list(range(10)))


def test_main():
    """Run every concrete test case in this module.

    The abstract bases (TestInvariantWithoutMutations and
    TestTemporarilyImmutable) are left out because they define no setUp.
    """
    unittests = [
        TestRepeat,
        TestXrange,
        TestXrangeCustomReversed,
        TestTuple,
        TestDeque,
        TestDequeReversed,
        TestDictKeys,
        TestDictItems,
        TestDictValues,
        TestSet,
        TestList,
        TestListReversed,
        TestLengthHintExceptions,
    ]
    support.run_unittest(*unittests)

# Run the suite when this file is executed directly as a script.
if __name__ == "__main__":
    test_main()