collections.py 39.3 KB
Newer Older
1
# Public names exported by this module (the collection ABCs are appended
# below from _abcoll).
__all__ = ['deque', 'defaultdict', 'namedtuple', 'UserDict', 'UserList',
           'UserString', 'Counter', 'OrderedDict']
3 4 5 6 7 8
# For bootstrapping reasons, the collection ABCs are defined in _abcoll.py.
# They should however be considered an integral part of collections.py.
from _abcoll import *
import _abcoll
__all__ += _abcoll.__all__

9 10 11 12
from _collections import deque, defaultdict
from operator import itemgetter as _itemgetter
from keyword import iskeyword as _iskeyword
import sys as _sys
13
import heapq as _heapq
14
from weakref import proxy as _proxy
15
from itertools import repeat as _repeat, chain as _chain, starmap as _starmap
16
from reprlib import recursive_repr as _recursive_repr
17 18 19 20 21

################################################################################
### OrderedDict
################################################################################

22 23 24
class _Link(object):
    # One node of OrderedDict's circular doubly linked list.  __slots__ keeps
    # nodes small; '__weakref__' allows weakref proxies to point at nodes.
    __slots__ = 'prev', 'next', 'key', '__weakref__'


class OrderedDict(dict):
    'Dictionary that remembers insertion order'
    # An inherited dict maps keys to values.
    # The inherited dict provides __getitem__, __len__, __contains__, and get.
    # The remaining methods are order-aware.
    # Big-O running times for all methods are the same as for regular dictionaries.

    # The internal self.__map dictionary maps keys to links in a doubly linked list.
    # The circular doubly linked list starts and ends with a sentinel element.
    # The sentinel element never gets deleted (this simplifies the algorithm).
    # The sentinel is stored in self.__hardroot with a weakref proxy in self.__root.
    # Every .prev pointer is a weakref proxy (to prevent circular references);
    # .next pointers are hard references, except those that point back at the
    # sentinel, which use the self.__root proxy.
    # Individual links are kept alive by the hard reference in self.__map.
    # Those hard references disappear when a key is deleted from an OrderedDict.

    def __init__(self, *args, **kwds):
        '''Initialize an ordered dictionary.  Signature is the same as for
        regular dictionaries, but keyword arguments are not recommended
        because their insertion order is arbitrary.

        '''
        if len(args) > 1:
            raise TypeError('expected at most 1 arguments, got %d' % len(args))
        try:
            self.__root
        except AttributeError:
            # First initialization only: build the sentinel.  Calling
            # __init__ again on an existing instance just adds more items.
            self.__hardroot = _Link()
            self.__root = root = _proxy(self.__hardroot)
            root.prev = root.next = root
            self.__map = {}
        self.__update(*args, **kwds)

    def __setitem__(self, key, value,
                    dict_setitem=dict.__setitem__, proxy=_proxy, Link=_Link):
        'od.__setitem__(i, y) <==> od[i]=y'
        # Setting a new item creates a new link which goes at the end of the linked
        # list, and the inherited dictionary is updated with the new key/value pair.
        if key not in self:
            self.__map[key] = link = Link()
            root = self.__root
            last = root.prev
            link.prev, link.next, link.key = last, root, key
            last.next = link
            root.prev = proxy(link)
        dict_setitem(self, key, value)

    def __delitem__(self, key, dict_delitem=dict.__delitem__):
        'od.__delitem__(y) <==> del od[y]'
        # Deleting an existing item uses self.__map to find the link which is
        # then removed by updating the links in the predecessor and successor nodes.
        dict_delitem(self, key)
        link = self.__map.pop(key)
        link_prev = link.prev
        link_next = link.next
        link_prev.next = link_next
        link_next.prev = link_prev

    def __iter__(self):
        'od.__iter__() <==> iter(od)'
        # Traverse the linked list in order.
        root = self.__root
        curr = root.next
        while curr is not root:
            yield curr.key
            curr = curr.next

    def __reversed__(self):
        'od.__reversed__() <==> reversed(od)'
        # Traverse the linked list in reverse order.
        root = self.__root
        curr = root.prev
        while curr is not root:
            yield curr.key
            curr = curr.prev

    def clear(self):
        'od.clear() -> None.  Remove all items from od.'
        root = self.__root
        root.prev = root.next = root
        self.__map.clear()
        dict.clear(self)

    def popitem(self, last=True):
        '''od.popitem() -> (k, v), return and remove a (key, value) pair.
        Pairs are returned in LIFO order if last is true or FIFO order if false.

        '''
        if not self:
            raise KeyError('dictionary is empty')
        root = self.__root
        if last:
            link = root.prev
            link_prev = link.prev
            link_prev.next = root
            root.prev = link_prev
        else:
            link = root.next
            link_next = link.next
            root.next = link_next
            link_next.prev = root
        key = link.key
        del self.__map[key]
        value = dict.pop(self, key)
        return key, value

    def move_to_end(self, key, last=True):
        '''Move an existing element to the end (or beginning if last==False).

        Raises KeyError if the element does not exist.
        When last=True, acts like a fast version of self[key]=self.pop(key).

        '''
        link = self.__map[key]
        # Unlink the node from its current position.
        link_prev = link.prev
        link_next = link.next
        link_prev.next = link_next
        link_next.prev = link_prev
        root = self.__root
        # Whatever .prev slot ends up pointing at this link must hold a weak
        # proxy: a hard reference there would pair with the neighbour's hard
        # .next and form a reference cycle that refcounting cannot free.
        soft_link = _proxy(link)
        if last:
            last = root.prev
            link.prev = last
            link.next = root
            root.prev = soft_link
            last.next = link
        else:
            first = root.next
            link.prev = root
            link.next = first
            first.prev = soft_link
            root.next = link

    def __reduce__(self):
        'Return state information for pickling'
        items = [[k, self[k]] for k in self]
        # Temporarily hide the private linked-list state so that only
        # user-added instance attributes are pickled.
        tmp = self.__map, self.__root, self.__hardroot
        del self.__map, self.__root, self.__hardroot
        inst_dict = vars(self).copy()
        self.__map, self.__root, self.__hardroot = tmp
        if inst_dict:
            return (self.__class__, (items,), inst_dict)
        return self.__class__, (items,)

    def __sizeof__(self):
        # Approximate footprint: instance dict, the internal map plus the
        # inherited dict, and one link and one proxy per entry (plus root).
        sizeof = _sys.getsizeof
        n = len(self) + 1                       # number of links including root
        size = sizeof(self.__dict__)            # instance dictionary
        size += sizeof(self.__map) * 2          # internal dict and inherited dict
        size += sizeof(self.__hardroot) * n     # link objects
        size += sizeof(self.__root) * n         # proxy objects
        return size

    # __update is a name-mangled alias so __init__ keeps working even if a
    # subclass overrides update().
    update = __update = MutableMapping.update
    keys = MutableMapping.keys
    values = MutableMapping.values
    items = MutableMapping.items
    __ne__ = MutableMapping.__ne__

    # Private sentinel distinguishing "no default given" from default=None.
    __marker = object()

    def pop(self, key, default=__marker):
        '''od.pop(k[,d]) -> v, remove specified key and return the value.
        If key is not found, d is returned if given, otherwise KeyError is raised.

        '''
        if key in self:
            result = self[key]
            del self[key]
            return result
        if default is self.__marker:
            raise KeyError(key)
        return default

    def setdefault(self, key, default=None):
        'OD.setdefault(k[,d]) -> OD.get(k,d), also set OD[k]=d if k not in OD'
        if key in self:
            return self[key]
        self[key] = default
        return default

    @_recursive_repr()
    def __repr__(self):
        'od.__repr__() <==> repr(od)'
        if not self:
            return '%s()' % (self.__class__.__name__,)
        return '%s(%r)' % (self.__class__.__name__, list(self.items()))

    def copy(self):
        'od.copy() -> a shallow copy of od'
        return self.__class__(self)

    @classmethod
    def fromkeys(cls, iterable, value=None):
        '''OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S
        and values equal to v (which defaults to None).

        '''
        d = cls()
        for key in iterable:
            d[key] = value
        return d

    def __eq__(self, other):
        '''od.__eq__(y) <==> od==y.  Comparison to another OD is order-sensitive
        while comparison to a regular mapping is order-insensitive.

        '''
        if isinstance(other, OrderedDict):
            return len(self)==len(other) and \
                   all(p==q for p, q in zip(self.items(), other.items()))
        return dict.__eq__(self, other)

230

231 232 233 234
################################################################################
### namedtuple
################################################################################

235
def namedtuple(typename, field_names, verbose=False, rename=False):
    """Returns a new subclass of tuple with named fields.

    typename is the name of the generated class.  field_names is either a
    single string of names separated by whitespace and/or commas, or an
    iterable of names.  If verbose is true, the generated class source is
    printed before it is executed.  If rename is true, invalid or duplicate
    field names are replaced with positional names such as '_1'.

    Raises ValueError for an empty, non-alphanumeric, keyword, digit-leading,
    underscore-leading (unless rename is true) or duplicate name.

    >>> Point = namedtuple('Point', 'x y')
    >>> Point.__doc__                   # docstring for the new class
    'Point(x, y)'
    >>> p = Point(11, y=22)             # instantiate with positional args or keywords
    >>> p[0] + p[1]                     # indexable like a plain tuple
    33
    >>> x, y = p                        # unpack like a regular tuple
    >>> x, y
    (11, 22)
    >>> p.x + p.y                       # fields also accessible by name
    33
    >>> d = p._asdict()                 # convert to a dictionary
    >>> d['x']
    11
    >>> Point(**d)                      # convert from a dictionary
    Point(x=11, y=22)
    >>> p._replace(x=100)               # _replace() is like str.replace() but targets named fields
    Point(x=100, y=22)

    """

    # Parse and validate the field names.  Validation serves two purposes,
    # generating informative error messages and preventing template injection attacks.
    if isinstance(field_names, str):
        field_names = field_names.replace(',', ' ').split() # names separated by whitespace and/or commas
    field_names = tuple(map(str, field_names))
    if rename:
        names = list(field_names)
        seen = set()
        for i, name in enumerate(names):
            if (not all(c.isalnum() or c=='_' for c in name) or _iskeyword(name)
                or not name or name[0].isdigit() or name.startswith('_')
                or name in seen):
                names[i] = '_%d' % i
            seen.add(name)
        field_names = tuple(names)
    for name in (typename,) + field_names:
        # Check emptiness first: the digit check below indexes name[0].
        if not name:
            raise ValueError('Type names and field names cannot be empty')
        if not all(c.isalnum() or c=='_' for c in name):
            raise ValueError('Type names and field names can only contain alphanumeric characters and underscores: %r' % name)
        if _iskeyword(name):
            raise ValueError('Type names and field names cannot be a keyword: %r' % name)
        if name[0].isdigit():
            raise ValueError('Type names and field names cannot start with a number: %r' % name)
    seen_names = set()
    for name in field_names:
        if name.startswith('_') and not rename:
            raise ValueError('Field names cannot start with an underscore: %r' % name)
        if name in seen_names:
            raise ValueError('Encountered duplicate field name: %r' % name)
        seen_names.add(name)

    # Create and fill-in the class template
    numfields = len(field_names)
    argtxt = repr(field_names).replace("'", "")[1:-1]   # tuple repr without parens or quotes
    reprtxt = ', '.join('%s=%%r' % name for name in field_names)
    template = '''class %(typename)s(tuple):
        '%(typename)s(%(argtxt)s)' \n
        __slots__ = () \n
        _fields = %(field_names)r \n
        def __new__(_cls, %(argtxt)s):
            'Create new instance of %(typename)s(%(argtxt)s)'
            return _tuple.__new__(_cls, (%(argtxt)s)) \n
        @classmethod
        def _make(cls, iterable, new=tuple.__new__, len=len):
            'Make a new %(typename)s object from a sequence or iterable'
            result = new(cls, iterable)
            if len(result) != %(numfields)d:
                raise TypeError('Expected %(numfields)d arguments, got %%d' %% len(result))
            return result \n
        def __repr__(self):
            'Return a nicely formatted representation string'
            return self.__class__.__name__ + '(%(reprtxt)s)' %% self \n
        def _asdict(self):
            'Return a new OrderedDict which maps field names to their values'
            return OrderedDict(zip(self._fields, self)) \n
        def _replace(_self, **kwds):
            'Return a new %(typename)s object replacing specified fields with new values'
            result = _self._make(map(kwds.pop, %(field_names)r, _self))
            if kwds:
                raise ValueError('Got unexpected field names: %%r' %% list(kwds))
            return result \n
        def __getnewargs__(self):
            'Return self as a plain tuple.  Used by copy and pickle.'
            return tuple(self) \n\n''' % locals()
    for i, name in enumerate(field_names):
        template += "        %s = _property(_itemgetter(%d), doc='Alias for field number %d')\n" % (name, i, i)
    if verbose:
        print(template)

    # Execute the template string in a temporary namespace and
    # support tracing utilities by setting a value for frame.f_globals['__name__']
    namespace = dict(_itemgetter=_itemgetter, __name__='namedtuple_%s' % typename,
                     OrderedDict=OrderedDict, _property=property, _tuple=tuple)
    try:
        exec(template, namespace)
    except SyntaxError as e:
        raise SyntaxError(e.msg + ':\n\n' + template)
    result = namespace[typename]

    # For pickling to work, the __module__ variable needs to be set to the frame
    # where the named tuple is created.  Bypass this step in environments where
    # sys._getframe is not defined (Jython for example) or sys._getframe is not
    # defined for arguments greater than 0 (IronPython).
    try:
        result.__module__ = _sys._getframe(1).f_globals.get('__name__', '__main__')
    except (AttributeError, ValueError):
        pass

    return result


349 350 351 352
########################################################################
###  Counter
########################################################################

353 354 355 356 357 358 359 360 361 362 363
def _count_elements(mapping, iterable):
    'Tally elements from the iterable.'
    mapping_get = mapping.get
    for elem in iterable:
        mapping[elem] = mapping_get(elem, 0) + 1

try:                                    # Load C helper function if available
    from _collections import _count_elements
except ImportError:
    pass


class Counter(dict):
    '''Dict subclass for counting hashable items.  Sometimes called a bag
    or multiset.  Elements are stored as dictionary keys and their counts
    are stored as dictionary values.

    >>> c = Counter('abcdeabcdabcaba')  # count elements from a string

    >>> c.most_common(3)                # three most common elements
    [('a', 5), ('b', 4), ('c', 3)]
    >>> sorted(c)                       # list all unique elements
    ['a', 'b', 'c', 'd', 'e']
    >>> ''.join(sorted(c.elements()))   # list elements with repetitions
    'aaaaabbbbcccdde'
    >>> sum(c.values())                 # total of all counts
    15

    >>> c['a']                          # count of letter 'a'
    5
    >>> for elem in 'shazam':           # update counts from an iterable
    ...     c[elem] += 1                # by adding 1 to each element's count
    >>> c['a']                          # now there are seven 'a'
    7
    >>> del c['b']                      # remove all 'b'
    >>> c['b']                          # now there are zero 'b'
    0

    >>> d = Counter('simsalabim')       # make another counter
    >>> c.update(d)                     # add in the second counter
    >>> c['a']                          # now there are nine 'a'
    9

    >>> c.clear()                       # empty the counter
    >>> c
    Counter()

    Note:  If a count is set to zero or reduced to zero, it will remain
    in the counter until the entry is deleted or the counter is cleared:

    >>> c = Counter('aaabbc')
    >>> c['b'] -= 2                     # reduce the count of 'b' by two
    >>> c.most_common()                 # 'b' is still in, but its count is zero
    [('a', 3), ('c', 1), ('b', 0)]

    '''
    # References:
    #   http://en.wikipedia.org/wiki/Multiset
    #   http://www.gnu.org/software/smalltalk/manual-base/html_node/Bag.html
    #   http://www.demo2s.com/Tutorial/Cpp/0380__set-multiset/Catalog0380__set-multiset.htm
    #   http://code.activestate.com/recipes/259174/
    #   Knuth, TAOCP Vol. II section 4.6.3

    def __init__(self, iterable=None, **kwds):
        '''Create a new, empty Counter object.  And if given, count elements
        from an input iterable.  Or, initialize the count from another mapping
        of elements to their counts.

        >>> c = Counter()                           # a new, empty counter
        >>> c = Counter('gallahad')                 # a new counter from an iterable
        >>> c = Counter({'a': 4, 'b': 2})           # a new counter from a mapping
        >>> c = Counter(a=4, b=2)                   # a new counter from keyword args

        '''
        super().__init__()
        self.update(iterable, **kwds)

    def __missing__(self, key):
        'The count of elements not in the Counter is zero.'
        # Needed so that self[missing_item] does not raise KeyError
        return 0

    def most_common(self, n=None):
        '''List the n most common elements and their counts from the most
        common to the least.  If n is None, then list all element counts.

        >>> Counter('abcdeabcdabcaba').most_common(3)
        [('a', 5), ('b', 4), ('c', 3)]

        '''
        # Emulate Bag.sortedByCount from Smalltalk
        if n is None:
            return sorted(self.items(), key=_itemgetter(1), reverse=True)
        return _heapq.nlargest(n, self.items(), key=_itemgetter(1))

    def elements(self):
        '''Iterator over elements repeating each as many times as its count.

        >>> c = Counter('ABCABC')
        >>> sorted(c.elements())
        ['A', 'A', 'B', 'B', 'C', 'C']

        # Knuth's example for prime factors of 1836:  2**2 * 3**3 * 17**1
        >>> prime_factors = Counter({2: 2, 3: 3, 17: 1})
        >>> product = 1
        >>> for factor in prime_factors.elements():     # loop over factors
        ...     product *= factor                       # and multiply them
        >>> product
        1836

        Note, if an element's count has been set to zero or is a negative
        number, elements() will ignore it.

        '''
        # Emulate Bag.do from Smalltalk and Multiset.begin from C++.
        return _chain.from_iterable(_starmap(_repeat, self.items()))

    # Override dict methods where necessary

    @classmethod
    def fromkeys(cls, iterable, v=None):
        # There is no equivalent method for counters because setting v=1
        # means that no element can have a count greater than one.
        raise NotImplementedError(
            'Counter.fromkeys() is undefined.  Use Counter(iterable) instead.')

    def update(self, iterable=None, **kwds):
        '''Like dict.update() but add counts instead of replacing them.

        Source can be an iterable, a dictionary, or another Counter instance.

        >>> c = Counter('which')
        >>> c.update('witch')           # add elements from another iterable
        >>> d = Counter('watch')
        >>> c.update(d)                 # add elements from another counter
        >>> c['h']                      # four 'h' in which, witch, and watch
        4

        '''
        # The regular dict.update() operation makes no sense here because the
        # replace behavior results in some of the original untouched counts
        # being mixed-in with all of the other counts for a mishmash that
        # doesn't have a straightforward interpretation in most counting
        # contexts.  Instead, we implement straight-addition.  Both the inputs
        # and outputs are allowed to contain zero and negative counts.

        if iterable is not None:
            if isinstance(iterable, Mapping):
                if self:
                    self_get = self.get
                    for elem, count in iterable.items():
                        self[elem] = count + self_get(elem, 0)
                else:
                    super().update(iterable) # fast path when counter is empty
            else:
                _count_elements(self, iterable)
        if kwds:
            self.update(kwds)

    def subtract(self, iterable=None, **kwds):
        '''Like dict.update() but subtracts counts instead of replacing them.
        Counts can be reduced below zero.  Both the inputs and outputs are
        allowed to contain zero and negative counts.

        Source can be an iterable, a dictionary, or another Counter instance.

        >>> c = Counter('which')
        >>> c.subtract('witch')             # subtract elements from another iterable
        >>> c.subtract(Counter('watch'))    # subtract elements from another counter
        >>> c['h']                          # 2 in which, minus 1 in witch, minus 1 in watch
        0
        >>> c['w']                          # 1 in which, minus 1 in witch, minus 1 in watch
        -1

        '''
        if iterable is not None:
            self_get = self.get
            if isinstance(iterable, Mapping):
                for elem, count in iterable.items():
                    self[elem] = self_get(elem, 0) - count
            else:
                for elem in iterable:
                    self[elem] = self_get(elem, 0) - 1
        if kwds:
            self.subtract(kwds)

    def copy(self):
        '''Like dict.copy() but returns a Counter instance instead of a dict.
        Subclasses get an instance of their own class, as with __reduce__.

        '''
        return self.__class__(self)

    def __reduce__(self):
        return self.__class__, (dict(self),)

    def __delitem__(self, elem):
        'Like dict.__delitem__() but does not raise KeyError for missing values.'
        if elem in self:
            super().__delitem__(elem)

    def __repr__(self):
        if not self:
            return '%s()' % self.__class__.__name__
        try:
            pairs = self.most_common()
        except TypeError:
            # Counts that cannot be ordered against each other (e.g. a mix of
            # str and int values) are shown in dict iteration order instead.
            pairs = self.items()
        items = ', '.join(map('%r: %r'.__mod__, pairs))
        return '%s({%s})' % (self.__class__.__name__, items)

    # Multiset-style mathematical operations discussed in:
    #       Knuth TAOCP Volume II section 4.6.3 exercise 19
    #       and at http://en.wikipedia.org/wiki/Multiset
    #
    # Outputs guaranteed to only include positive counts.
    #
    # To strip negative and zero counts, add-in an empty counter:
    #       c += Counter()

    def __add__(self, other):
        '''Add counts from two counters.

        >>> Counter('abbb') + Counter('bcc')
        Counter({'b': 4, 'c': 2, 'a': 1})

        '''
        if not isinstance(other, Counter):
            return NotImplemented
        result = Counter()
        for elem in set(self) | set(other):
            newcount = self[elem] + other[elem]
            if newcount > 0:
                result[elem] = newcount
        return result

    def __sub__(self, other):
        ''' Subtract count, but keep only results with positive counts.

        >>> Counter('abbbc') - Counter('bccd')
        Counter({'b': 2, 'a': 1})

        '''
        if not isinstance(other, Counter):
            return NotImplemented
        result = Counter()
        for elem in set(self) | set(other):
            newcount = self[elem] - other[elem]
            if newcount > 0:
                result[elem] = newcount
        return result

    def __or__(self, other):
        '''Union is the maximum of value in either of the input counters.

        >>> Counter('abbb') | Counter('bcc')
        Counter({'b': 3, 'c': 2, 'a': 1})

        '''
        if not isinstance(other, Counter):
            return NotImplemented
        result = Counter()
        for elem in set(self) | set(other):
            p, q = self[elem], other[elem]
            newcount = q if p < q else p
            if newcount > 0:
                result[elem] = newcount
        return result

    def __and__(self, other):
        ''' Intersection is the minimum of corresponding counts.

        >>> Counter('abbb') & Counter('bcc')
        Counter({'b': 1})

        '''
        if not isinstance(other, Counter):
            return NotImplemented
        result = Counter()
        # Iterate over the smaller counter and probe the larger one.
        if len(self) < len(other):
            self, other = other, self
        for elem in filter(self.__contains__, other):
            p, q = self[elem], other[elem]
            newcount = p if p < q else q
            if newcount > 0:
                result[elem] = newcount
        return result

633

634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696
########################################################################
###  ChainMap (helper for configparser)
########################################################################

class _ChainMap(MutableMapping):
    ''' A ChainMap groups multiple dicts (or other mappings) together
    to create a single, updateable view.

    The underlying mappings are stored in a list.  That list is public and can
    be accessed or updated using the *maps* attribute.  There is no other state.

    Lookups search the underlying mappings successively until a key is found.
    In contrast, writes, updates, and deletions only operate on the first
    mapping.

    '''

    def __init__(self, *maps):
        '''Initialize a ChainMap by setting *maps* to the given mappings.
        If no mappings are provided, a single empty dictionary is used.

        '''
        self.maps = list(maps) or [{}]          # always at least one map

    def __missing__(self, key):
        raise KeyError(key)

    def __getitem__(self, key):
        for mapping in self.maps:
            try:
                return mapping[key]             # can't use 'key in mapping' with defaultdict
            except KeyError:
                pass
        return self.__missing__(key)            # support subclasses that define __missing__

    def get(self, key, default=None):
        return self[key] if key in self else default

    def __len__(self):
        return len(set().union(*self.maps))     # reuses stored hash values if possible

    def __iter__(self):
        return iter(set().union(*self.maps))

    def __contains__(self, key):
        return any(key in m for m in self.maps)

    def __bool__(self):
        'True if any underlying mapping is non-empty (avoids building the key union).'
        return any(self.maps)

    @_recursive_repr()
    def __repr__(self):
        return '{0.__class__.__name__}({1})'.format(
            self, ', '.join(map(repr, self.maps)))

    @classmethod
    def fromkeys(cls, iterable, *args):
        'Create a ChainMap with a single dict created from the iterable.'
        return cls(dict.fromkeys(iterable, *args))

    def copy(self):
        'New ChainMap or subclass with a new copy of maps[0] and refs to maps[1:]'
        return self.__class__(self.maps[0].copy(), *self.maps[1:])

    __copy__ = copy

    def new_child(self):                        # like Django's Context.push()
        'New ChainMap with a new dict followed by all previous maps.'
        return self.__class__({}, *self.maps)

    @property
    def parents(self):                          # like Django's Context.pop()
        'New ChainMap from maps[1:].'
        return self.__class__(*self.maps[1:])

    def __setitem__(self, key, value):
        self.maps[0][key] = value

    def __delitem__(self, key):
        try:
            del self.maps[0][key]
        except KeyError:
            raise KeyError('Key not found in the first mapping: {!r}'.format(key))

    def popitem(self):
        'Remove and return an item pair from maps[0]. Raise KeyError if maps[0] is empty.'
        try:
            return self.maps[0].popitem()
        except KeyError:
            raise KeyError('No keys found in the first mapping.')

    def pop(self, key, *args):
        'Remove *key* from maps[0] and return its value. Raise KeyError if *key* not in maps[0].'
        try:
            return self.maps[0].pop(key, *args)
        except KeyError:
            raise KeyError('Key not found in the first mapping: {!r}'.format(key))

    def clear(self):
        'Clear maps[0], leaving maps[1:] intact.'
        self.maps[0].clear()


734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758
################################################################################
### UserDict
################################################################################

class UserDict(MutableMapping):
    'Dictionary-like wrapper whose contents live in the plain dict ``self.data``.'

    # Start by filling-out the abstract methods
    def __init__(self, dict=None, **kwargs):
        # NOTE: the parameter name shadows the builtin ``dict``; it is kept
        # because callers may pass it by keyword.
        self.data = {}
        if dict is not None:
            self.update(dict)
        if kwargs:
            self.update(kwargs)
    def __len__(self): return len(self.data)
    def __getitem__(self, key):
        if key in self.data:
            return self.data[key]
        # Like dict subclasses, a subclass may define __missing__ to supply
        # a value for an absent key.
        if hasattr(self.__class__, "__missing__"):
            return self.__class__.__missing__(self, key)
        raise KeyError(key)
    def __setitem__(self, key, item): self.data[key] = item
    def __delitem__(self, key): del self.data[key]
    def __iter__(self):
        return iter(self.data)

    # Modify __contains__ to work correctly when __missing__ is present:
    # membership is decided by self.data alone, never by __getitem__.
    def __contains__(self, key):
        return key in self.data

    # Now, add the methods in dicts but not in MutableMapping
    def __repr__(self): return repr(self.data)
    def __copy__(self):
        '''Shallow copy: same class and instance attributes, but its own data dict.

        Without this, copy.copy() would copy the instance __dict__ and leave
        the original and the copy sharing one ``data`` dict.
        '''
        inst = self.__class__.__new__(self.__class__)
        inst.__dict__.update(self.__dict__)
        # Write through __dict__ so subclass descriptors / __setattr__
        # overrides are not triggered while building the copy.
        inst.__dict__["data"] = self.__dict__["data"].copy()
        return inst
    def copy(self):
        'Return a shallow copy of the same class as self.'
        if self.__class__ is UserDict:
            return UserDict(self.data.copy())
        import copy
        # For subclasses: copy the instance with an empty ``data`` so its
        # other attributes are preserved, then refill it via update().
        data = self.data
        try:
            self.data = {}
            c = copy.copy(self)
        finally:
            self.data = data
        c.update(self)
        return c
    @classmethod
    def fromkeys(cls, iterable, value=None):
        'Create a new instance of cls with each key from *iterable* mapped to *value*.'
        d = cls()
        for key in iterable:
            d[key] = value
        return d


785

786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857
################################################################################
### UserList
################################################################################

class UserList(MutableSequence):
    """A more or less complete user-defined wrapper around list objects.

    The wrapped list is kept in ``self.data`` and every operation delegates
    to it.  Operations that build a new sequence (+, *, copy()) return
    ``self.__class__`` instances; indexing and slicing return whatever the
    underlying list returns.
    """
    def __init__(self, initlist=None):
        self.data = []
        if initlist is not None:
            # An exact list or another UserList is copied into the fresh
            # list; any other iterable is materialised with list().
            if type(initlist) == type(self.data):
                self.data[:] = initlist
            elif isinstance(initlist, UserList):
                self.data[:] = initlist.data[:]
            else:
                self.data = list(initlist)
    def __repr__(self): return repr(self.data)
    # Comparisons unwrap another UserList (see __cast) and compare the
    # underlying lists, so UserList([1]) == [1] is True.
    def __lt__(self, other): return self.data <  self.__cast(other)
    def __le__(self, other): return self.data <= self.__cast(other)
    def __eq__(self, other): return self.data == self.__cast(other)
    def __ne__(self, other): return self.data != self.__cast(other)
    def __gt__(self, other): return self.data >  self.__cast(other)
    def __ge__(self, other): return self.data >= self.__cast(other)
    def __cast(self, other):
        return other.data if isinstance(other, UserList) else other
    def __contains__(self, item): return item in self.data
    def __len__(self): return len(self.data)
    def __getitem__(self, i): return self.data[i]
    def __setitem__(self, i, item): self.data[i] = item
    def __delitem__(self, i): del self.data[i]
    def __add__(self, other):
        if isinstance(other, UserList):
            return self.__class__(self.data + other.data)
        elif isinstance(other, type(self.data)):
            return self.__class__(self.data + other)
        return self.__class__(self.data + list(other))
    def __radd__(self, other):
        if isinstance(other, UserList):
            return self.__class__(other.data + self.data)
        elif isinstance(other, type(self.data)):
            return self.__class__(other + self.data)
        return self.__class__(list(other) + self.data)
    def __iadd__(self, other):
        if isinstance(other, UserList):
            self.data += other.data
        elif isinstance(other, type(self.data)):
            self.data += other
        else:
            self.data += list(other)
        return self
    def __mul__(self, n):
        return self.__class__(self.data*n)
    __rmul__ = __mul__
    def __imul__(self, n):
        self.data *= n
        return self
    def __copy__(self):
        '''Shallow copy: same class and instance attributes, but its own data list.

        Without this, copy.copy() would copy the instance __dict__ and leave
        the original and the copy sharing one ``data`` list.
        '''
        inst = self.__class__.__new__(self.__class__)
        inst.__dict__.update(self.__dict__)
        # Write through __dict__ so subclass descriptors / __setattr__
        # overrides are not triggered while building the copy.
        inst.__dict__["data"] = self.__dict__["data"][:]
        return inst
    def append(self, item): self.data.append(item)
    def insert(self, i, item): self.data.insert(i, item)
    def pop(self, i=-1): return self.data.pop(i)
    def remove(self, item): self.data.remove(item)
    def copy(self):
        'Return a shallow copy as a new instance of self.__class__ (mirrors list.copy).'
        return self.__class__(self)
    def count(self, item): return self.data.count(item)
    def index(self, item, *args): return self.data.index(item, *args)
    def reverse(self): self.data.reverse()
    def sort(self, *args, **kwds): self.data.sort(*args, **kwds)
    def extend(self, other):
        if isinstance(other, UserList):
            self.data.extend(other.data)
        else:
            self.data.extend(other)



858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974
################################################################################
### UserString
################################################################################

class UserString(Sequence):
    '''String-like wrapper that stores its value as a plain str in ``self.data``.

    Most str methods delegate to ``self.data``.  Many re-wrap a string result
    in ``self.__class__`` so subclasses survive the call; format(), join() and
    the split/partition family return plain str objects (or lists/tuples of
    them), and encode() returns bytes.  UserString arguments are unwrapped by
    the comparison operators, __contains__, __add__, count(), find()/rfind(),
    index()/rindex() and replace().
    '''
    def __init__(self, seq):
        if isinstance(seq, str):
            self.data = seq
        elif isinstance(seq, UserString):
            self.data = seq.data[:]
        else:
            # Any other object is stored as its str() form.
            self.data = str(seq)
    def __str__(self): return str(self.data)
    def __repr__(self): return repr(self.data)
    def __int__(self): return int(self.data)
    def __float__(self): return float(self.data)
    def __complex__(self): return complex(self.data)
    def __hash__(self): return hash(self.data)

    # Comparisons unwrap another UserString; anything else is compared
    # directly against self.data.
    def __eq__(self, string):
        if isinstance(string, UserString):
            return self.data == string.data
        return self.data == string
    def __ne__(self, string):
        if isinstance(string, UserString):
            return self.data != string.data
        return self.data != string
    def __lt__(self, string):
        if isinstance(string, UserString):
            return self.data < string.data
        return self.data < string
    def __le__(self, string):
        if isinstance(string, UserString):
            return self.data <= string.data
        return self.data <= string
    def __gt__(self, string):
        if isinstance(string, UserString):
            return self.data > string.data
        return self.data > string
    def __ge__(self, string):
        if isinstance(string, UserString):
            return self.data >= string.data
        return self.data >= string

    def __contains__(self, char):
        if isinstance(char, UserString):
            char = char.data
        return char in self.data

    def __len__(self): return len(self.data)
    def __getitem__(self, index): return self.__class__(self.data[index])
    def __add__(self, other):
        if isinstance(other, UserString):
            return self.__class__(self.data + other.data)
        elif isinstance(other, str):
            return self.__class__(self.data + other)
        return self.__class__(self.data + str(other))
    def __radd__(self, other):
        if isinstance(other, str):
            return self.__class__(other + self.data)
        return self.__class__(str(other) + self.data)
    def __mul__(self, n):
        return self.__class__(self.data*n)
    __rmul__ = __mul__
    def __mod__(self, args):
        return self.__class__(self.data % args)

    # the following methods are defined in alphabetical order:
    def capitalize(self): return self.__class__(self.data.capitalize())
    def center(self, width, *args):
        return self.__class__(self.data.center(width, *args))
    def count(self, sub, start=0, end=_sys.maxsize):
        if isinstance(sub, UserString):
            sub = sub.data
        return self.data.count(sub, start, end)
    def encode(self, encoding=None, errors=None):
        '''Encode self.data and return the resulting bytes.

        *errors* is only honoured when *encoding* is also given (truthy);
        with no encoding, str.encode()'s own default encoding is used.
        '''
        # The result is bytes, not text, so it must not be re-wrapped in
        # self.__class__: __init__ would run str() on it and store e.g. "b'abc'".
        if encoding:
            if errors:
                return self.data.encode(encoding, errors)
            return self.data.encode(encoding)
        return self.data.encode()
    def endswith(self, suffix, start=0, end=_sys.maxsize):
        return self.data.endswith(suffix, start, end)
    def expandtabs(self, tabsize=8):
        return self.__class__(self.data.expandtabs(tabsize))
    def find(self, sub, start=0, end=_sys.maxsize):
        if isinstance(sub, UserString):
            sub = sub.data
        return self.data.find(sub, start, end)
    def format(self, *args, **kwds):
        return self.data.format(*args, **kwds)
    def index(self, sub, start=0, end=_sys.maxsize):
        # Unwrap like find()/count(); str.index rejects UserString arguments.
        if isinstance(sub, UserString):
            sub = sub.data
        return self.data.index(sub, start, end)
    def isalpha(self): return self.data.isalpha()
    def isalnum(self): return self.data.isalnum()
    def isdecimal(self): return self.data.isdecimal()
    def isdigit(self): return self.data.isdigit()
    def isidentifier(self): return self.data.isidentifier()
    def islower(self): return self.data.islower()
    def isnumeric(self): return self.data.isnumeric()
    def isspace(self): return self.data.isspace()
    def istitle(self): return self.data.istitle()
    def isupper(self): return self.data.isupper()
    def join(self, seq): return self.data.join(seq)
    def ljust(self, width, *args):
        return self.__class__(self.data.ljust(width, *args))
    def lower(self): return self.__class__(self.data.lower())
    def lstrip(self, chars=None): return self.__class__(self.data.lstrip(chars))
    def partition(self, sep):
        return self.data.partition(sep)
    def replace(self, old, new, maxsplit=-1):
        # NOTE: *maxsplit* is passed straight through as str.replace's count.
        if isinstance(old, UserString):
            old = old.data
        if isinstance(new, UserString):
            new = new.data
        return self.__class__(self.data.replace(old, new, maxsplit))
    def rfind(self, sub, start=0, end=_sys.maxsize):
        if isinstance(sub, UserString):
            sub = sub.data
        return self.data.rfind(sub, start, end)
    def rindex(self, sub, start=0, end=_sys.maxsize):
        # Unwrap like rfind(); str.rindex rejects UserString arguments.
        if isinstance(sub, UserString):
            sub = sub.data
        return self.data.rindex(sub, start, end)
    def rjust(self, width, *args):
        return self.__class__(self.data.rjust(width, *args))
    def rpartition(self, sep):
        return self.data.rpartition(sep)
    def rstrip(self, chars=None):
        return self.__class__(self.data.rstrip(chars))
    def split(self, sep=None, maxsplit=-1):
        return self.data.split(sep, maxsplit)
    def rsplit(self, sep=None, maxsplit=-1):
        return self.data.rsplit(sep, maxsplit)
    def splitlines(self, keepends=0): return self.data.splitlines(keepends)
    def startswith(self, prefix, start=0, end=_sys.maxsize):
        return self.data.startswith(prefix, start, end)
    def strip(self, chars=None): return self.__class__(self.data.strip(chars))
    def swapcase(self): return self.__class__(self.data.swapcase())
    def title(self): return self.__class__(self.data.title())
    def translate(self, *args):
        return self.__class__(self.data.translate(*args))
    def upper(self): return self.__class__(self.data.upper())
    def zfill(self, width): return self.__class__(self.data.zfill(width))



1003 1004 1005
################################################################################
### Simple tests
################################################################################
1006 1007

if __name__ == '__main__':
    # Self-test / demo script.  Stray extraction residue (bare numbers at
    # column 0) previously sat between these statements and made the suite
    # an IndentationError; it has been removed.

    # verify that instances can be pickled
    from pickle import loads, dumps
    Point = namedtuple('Point', 'x, y', True)
    p = Point(x=10, y=20)
    assert p == loads(dumps(p))

    # test and demonstrate ability to override methods
    class Point(namedtuple('Point', 'x y')):
        __slots__ = ()
        @property
        def hypot(self):
            return (self.x ** 2 + self.y ** 2) ** 0.5
        def __str__(self):
            return 'Point: x=%6.3f  y=%6.3f  hypot=%6.3f' % (self.x, self.y, self.hypot)

    for p in Point(3, 4), Point(14, 5/7.):
        print(p)

    class Point(namedtuple('Point', 'x y')):
        'Point class with optimized _make() and _replace() without error-checking'
        __slots__ = ()
        _make = classmethod(tuple.__new__)
        def _replace(self, _map=map, **kwds):
            # map(kwds.get, names, self) yields kwds.get(name, current_value),
            # i.e. the override if given, else the existing field value.
            return self._make(_map(kwds.get, ('x', 'y'), self))

    print(Point(11, 22)._replace(x=100))

    Point3D = namedtuple('Point3D', Point._fields + ('z',))
    print(Point3D.__doc__)

    import doctest
    TestResults = namedtuple('TestResults', 'failed attempted')
    print(TestResults(*doctest.testmod()))