weakref.py 20 KB
Newer Older
1 2 3 4
"""Weak reference support for Python.

This module is an implementation of PEP 205:

http://www.python.org/dev/peps/pep-0205/
"""

8 9 10 11
# Naming convention: Variables named "wr" are weak reference objects;
# they are called this instead of "ref" to avoid name collisions with
# the module-global ref() function imported from _weakref.

12 13 14 15 16 17 18
from _weakref import (
     getweakrefcount,
     getweakrefs,
     ref,
     proxy,
     CallableProxyType,
     ProxyType,
     ReferenceType,
     _remove_dead_weakref)

from _weakrefset import WeakSet, _IterationGuard

import _collections_abc  # Import after _weakref to avoid circular import.
import collections  # Import after _weakref to avoid circular import.
import sys
import itertools
27

28 29
# The two proxy types imported from _weakref, bundled into one tuple.
ProxyTypes = (ProxyType, CallableProxyType)

# Public names of this module.
__all__ = [
    "ref",
    "proxy",
    "getweakrefcount",
    "getweakrefs",
    "WeakKeyDictionary",
    "ReferenceType",
    "ProxyType",
    "CallableProxyType",
    "ProxyTypes",
    "WeakValueDictionary",
    "WeakSet",
    "WeakMethod",
    "finalize",
]
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87


class WeakMethod(ref):
    """
    A custom `weakref.ref` subclass which simulates a weak reference to
    a bound method, working around the lifetime problem of bound methods.

    Calling an instance rebuilds and returns the bound method while both
    the method's object and its function are alive, and returns None
    otherwise.
    """

    __slots__ = "_func_ref", "_meth_type", "_alive", "__weakref__"

    def __new__(cls, meth, callback=None):
        """Create a weak reference to the bound method *meth*.

        *callback*, if not None, is invoked at most once with this
        WeakMethod as its argument, when either the method's object or
        its function is finalized.

        Raises TypeError if *meth* has no __self__ or __func__ attribute.
        """
        try:
            obj = meth.__self__
            func = meth.__func__
        except AttributeError:
            raise TypeError("argument should be a bound method, not {}"
                            .format(type(meth))) from None
        def _cb(arg):
            # The self-weakref trick is needed to avoid creating a reference
            # cycle.
            self = self_wr()
            # _cb is registered on both the object ref and the function
            # ref; the _alive flag makes sure the user callback runs once.
            if self._alive:
                self._alive = False
                if callback is not None:
                    callback(self)
        self = ref.__new__(cls, obj, _cb)
        self._func_ref = ref(func, _cb)
        self._meth_type = type(meth)
        self._alive = True
        self_wr = ref(self)
        return self

    def __call__(self):
        """Return the re-bound method, or None if either part is dead."""
        obj = super().__call__()
        func = self._func_ref()
        if obj is None or func is None:
            return None
        return self._meth_type(func, obj)

    def __eq__(self, other):
        """Compare with another WeakMethod.

        Dead references are only equal to themselves; live ones are equal
        when both the object refs and the function refs compare equal.
        For any other operand type NotImplemented is returned so that the
        other operand's comparison gets a chance to run.
        """
        if not isinstance(other, WeakMethod):
            return NotImplemented
        if not self._alive or not other._alive:
            return self is other
        return ref.__eq__(self, other) and self._func_ref == other._func_ref

    def __ne__(self, other):
        """Inverse of __eq__, with the same NotImplemented fallback."""
        if not isinstance(other, WeakMethod):
            return NotImplemented
        if not self._alive or not other._alive:
            return self is not other
        return ref.__ne__(self, other) or self._func_ref != other._func_ref

    __hash__ = ref.__hash__
88

89

90
class WeakValueDictionary(_collections_abc.MutableMapping):
    """Mapping class that references values weakly.

    Entries in the dictionary will be discarded when no strong
    reference to the value exists anymore.
    """
    # The constructor does not inspect the input dictionary itself; it
    # hands it to our .update() method, so we get the right checks (if
    # the other dictionary is a WeakValueDictionary, objects are
    # unwrapped on the way out, and we always wrap on the way in).

    def __init__(*args, **kw):
        """Create the mapping, optionally filled from one positional
        mapping/iterable of pairs and from keyword arguments.

        The signature is ``*args`` so that keyword arguments named
        ``self`` can be stored as ordinary keys.  Raises TypeError when
        called without an instance or with more than one positional
        argument.
        """
        if not args:
            raise TypeError("descriptor '__init__' of 'WeakValueDictionary' "
                            "object needs an argument")
        self, *args = args
        if len(args) > 1:
            raise TypeError('expected at most 1 arguments, got %d' % len(args))
        # The helper is bound as a default argument so the callback does
        # not depend on a module-global lookup when it eventually runs.
        def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):
            self = selfref()
            if self is not None:
                if self._iterating:
                    # Mutating the dict now would break the iteration in
                    # progress; remember the key for later.
                    self._pending_removals.append(wr.key)
                else:
                    # Atomic removal is necessary since this function
                    # can be called asynchronously by the GC
                    _atomic_removal(self.data, wr.key)
        self._remove = remove
        # A list of keys to be removed
        self._pending_removals = []
        # Holds the currently active _IterationGuard objects.
        self._iterating = set()
        self.data = {}
        self.update(*args, **kw)

    def _commit_removals(self):
        """Drop every dead entry whose removal was deferred."""
        l = self._pending_removals
        d = self.data
        # We shouldn't encounter any KeyError, because this method should
        # always be called *before* mutating the dict.
        while l:
            key = l.pop()
            _remove_dead_weakref(d, key)

    def __getitem__(self, key):
        """Return the live value for *key*; raise KeyError if the key is
        missing or its value is dead."""
        if self._pending_removals:
            self._commit_removals()
        o = self.data[key]()
        if o is None:
            raise KeyError(key)
        else:
            return o

    def __delitem__(self, key):
        if self._pending_removals:
            self._commit_removals()
        del self.data[key]

    def __len__(self):
        if self._pending_removals:
            self._commit_removals()
        return len(self.data)

    def __contains__(self, key):
        """Return True only if *key* is present and its value is alive."""
        if self._pending_removals:
            self._commit_removals()
        try:
            o = self.data[key]()
        except KeyError:
            return False
        return o is not None

    def __repr__(self):
        return "<%s at %#x>" % (self.__class__.__name__, id(self))

    def __setitem__(self, key, value):
        if self._pending_removals:
            self._commit_removals()
        # Store a KeyedRef so the shared removal callback can find the key.
        self.data[key] = KeyedRef(value, self._remove, key)

    def copy(self):
        """Return a new WeakValueDictionary holding the live entries."""
        if self._pending_removals:
            self._commit_removals()
        new = WeakValueDictionary()
        # Guard the loop: a value dying while we iterate must not resize
        # self.data underneath us.
        with _IterationGuard(self):
            for key, wr in self.data.items():
                o = wr()
                if o is not None:
                    new[key] = o
        return new

    __copy__ = copy

    def __deepcopy__(self, memo):
        """Return a new instance with deep-copied keys; the values are
        the same objects as in the original."""
        from copy import deepcopy
        if self._pending_removals:
            self._commit_removals()
        new = self.__class__()
        # Same guard as in copy(): keep self.data stable while iterating.
        with _IterationGuard(self):
            for key, wr in self.data.items():
                o = wr()
                if o is not None:
                    new[deepcopy(key, memo)] = o
        return new

    def get(self, key, default=None):
        """Return the live value for *key*, or *default* if the key is
        missing or its value is dead."""
        if self._pending_removals:
            self._commit_removals()
        try:
            wr = self.data[key]
        except KeyError:
            return default
        else:
            o = wr()
            if o is None:
                # The referent is dead but its entry has not been
                # removed from the dict yet.
                return default
            else:
                return o

    def items(self):
        """Yield (key, value) pairs for entries whose value is alive."""
        if self._pending_removals:
            self._commit_removals()
        with _IterationGuard(self):
            for k, wr in self.data.items():
                v = wr()
                if v is not None:
                    yield k, v

    def keys(self):
        """Yield the keys of entries whose value is alive."""
        if self._pending_removals:
            self._commit_removals()
        with _IterationGuard(self):
            for k, wr in self.data.items():
                if wr() is not None:
                    yield k

    __iter__ = keys

    def itervaluerefs(self):
        """Return an iterator that yields the weak references to the values.

        The references are not guaranteed to be 'live' at the time
        they are used, so the result of calling the references needs
        to be checked before being used.  This can be used to avoid
        creating references that will cause the garbage collector to
        keep the values around longer than needed.

        """
        if self._pending_removals:
            self._commit_removals()
        with _IterationGuard(self):
            yield from self.data.values()

    def values(self):
        """Yield the values that are still alive."""
        if self._pending_removals:
            self._commit_removals()
        with _IterationGuard(self):
            for wr in self.data.values():
                obj = wr()
                if obj is not None:
                    yield obj

    def popitem(self):
        """Remove and return a (key, value) pair whose value is alive.

        Dead entries popped along the way are discarded.  The KeyError
        raised by dict.popitem() on an empty dict propagates.
        """
        if self._pending_removals:
            self._commit_removals()
        while True:
            key, wr = self.data.popitem()
            o = wr()
            if o is not None:
                return key, o

    def pop(self, key, *args):
        """Remove *key* and return its live value.

        If the key is missing or its value is dead, return the first
        extra positional argument when one is given, otherwise raise
        KeyError.
        """
        if self._pending_removals:
            self._commit_removals()
        try:
            o = self.data.pop(key)()
        except KeyError:
            o = None
        if o is None:
            if args:
                return args[0]
            else:
                raise KeyError(key)
        else:
            return o

    def setdefault(self, key, default=None):
        """Return the live value for *key*; if there is none, store a weak
        reference to *default* under *key* and return *default*."""
        try:
            o = self.data[key]()
        except KeyError:
            o = None
        if o is None:
            if self._pending_removals:
                self._commit_removals()
            self.data[key] = KeyedRef(default, self._remove, key)
            return default
        else:
            return o

    def update(*args, **kwargs):
        """Add entries from one optional positional mapping/iterable of
        pairs and from keyword arguments, wrapping each value weakly.

        Uses the same ``*args`` convention and TypeError checks as
        __init__.
        """
        if not args:
            raise TypeError("descriptor 'update' of 'WeakValueDictionary' "
                            "object needs an argument")
        self, *args = args
        if len(args) > 1:
            raise TypeError('expected at most 1 arguments, got %d' % len(args))
        other = args[0] if args else None
        if self._pending_removals:
            self._commit_removals()
        d = self.data
        if other is not None:
            if not hasattr(other, "items"):
                # Accept an iterable of (key, value) pairs as dict() does.
                other = dict(other)
            for key, o in other.items():
                d[key] = KeyedRef(o, self._remove, key)
        if kwargs:
            self.update(kwargs)

    def valuerefs(self):
        """Return a list of weak references to the values.

        The references are not guaranteed to be 'live' at the time
        they are used, so the result of calling the references needs
        to be checked before being used.  This can be used to avoid
        creating references that will cause the garbage collector to
        keep the values around longer than needed.

        """
        if self._pending_removals:
            self._commit_removals()
        return list(self.data.values())
320

321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339

class KeyedRef(ref):
    """Weak reference that also remembers the mapping key of its referent.

    WeakValueDictionary stores these so that a single shared callback
    can read the 'key' attribute of the dying reference, instead of one
    function object per key having to capture the key from an enclosing
    scope.

    """

    __slots__ = ("key",)

    def __new__(type, ob, callback, key):
        # The key cannot be passed on to ref.__new__, so it is attached
        # to the freshly created reference afterwards.
        keyed = ref.__new__(type, ob, callback)
        keyed.key = key
        return keyed

    def __init__(self, ob, callback, key):
        # Forward only the arguments the base class knows about.
        super().__init__(ob, callback)
341

342

343
class WeakKeyDictionary(_collections_abc.MutableMapping):
    """ Mapping class that references keys weakly.

    Entries in the dictionary will be discarded when there is no
    longer a strong reference to the key. This can be used to
    associate additional data with an object owned by other parts of
    an application without adding attributes to those objects. This
    can be especially useful with objects that override attribute
    accesses.
    """

    def __init__(self, dict=None):
        """Create the mapping, optionally filled from *dict* (a mapping
        or an iterable of pairs, as accepted by update())."""
        self.data = {}
        def remove(k, selfref=ref(self)):
            self = selfref()
            if self is not None:
                if self._iterating:
                    # Mutating the dict now would break the iteration in
                    # progress; remember the dead key for later.
                    self._pending_removals.append(k)
                else:
                    # The entry may already be gone: it can have been
                    # deleted explicitly while the weakref itself was
                    # kept alive elsewhere (e.g. by keyrefs()).
                    try:
                        del self.data[k]
                    except KeyError:
                        pass
        self._remove = remove
        # A list of dead weakrefs (keys to be removed)
        self._pending_removals = []
        # Holds the currently active _IterationGuard objects.
        self._iterating = set()
        # Set by explicit removals; tells __len__ that the pending list
        # may mention keys that are no longer in self.data.
        self._dirty_len = False
        if dict is not None:
            self.update(dict)

    def _commit_removals(self):
        """Delete every entry whose removal was deferred."""
        # NOTE: We don't need to call this method before mutating the dict,
        # because a dead weakref never compares equal to a live weakref,
        # even if they happened to refer to equal objects.
        # However, it means keys may already have been removed.
        l = self._pending_removals
        d = self.data
        while l:
            try:
                del d[l.pop()]
            except KeyError:
                pass

    def _scrub_removals(self):
        """Forget pending removals whose key is no longer in the dict."""
        d = self.data
        self._pending_removals = [k for k in self._pending_removals if k in d]
        self._dirty_len = False

    def __delitem__(self, key):
        self._dirty_len = True
        del self.data[ref(key)]

    def __getitem__(self, key):
        return self.data[ref(key)]

    def __len__(self):
        """Return the number of entries, not counting pending removals."""
        if self._dirty_len and self._pending_removals:
            # self._pending_removals may still contain keys which were
            # explicitly removed, we have to scrub them (see issue #21173).
            self._scrub_removals()
        return len(self.data) - len(self._pending_removals)

    def __repr__(self):
        return "<%s at %#x>" % (self.__class__.__name__, id(self))

    def __setitem__(self, key, value):
        self.data[ref(key, self._remove)] = value

    def copy(self):
        """Return a new WeakKeyDictionary holding the live entries."""
        new = WeakKeyDictionary()
        # Guard the loop: a key dying while we iterate must not resize
        # self.data underneath us.
        with _IterationGuard(self):
            for key, value in self.data.items():
                o = key()
                if o is not None:
                    new[o] = value
        return new

    __copy__ = copy

    def __deepcopy__(self, memo):
        """Return a new instance with deep-copied values; the keys are
        the same objects as in the original."""
        from copy import deepcopy
        new = self.__class__()
        # Same guard as in copy(): keep self.data stable while iterating.
        with _IterationGuard(self):
            for key, value in self.data.items():
                o = key()
                if o is not None:
                    new[o] = deepcopy(value, memo)
        return new

    def get(self, key, default=None):
        return self.data.get(ref(key),default)

    def __contains__(self, key):
        """Return whether *key* is present; objects that cannot be weakly
        referenced are reported as absent instead of raising."""
        try:
            wr = ref(key)
        except TypeError:
            return False
        return wr in self.data

    def items(self):
        """Yield (key, value) pairs for entries whose key is alive."""
        with _IterationGuard(self):
            for wr, value in self.data.items():
                key = wr()
                if key is not None:
                    yield key, value

    def keys(self):
        """Yield the keys that are still alive."""
        with _IterationGuard(self):
            for wr in self.data:
                obj = wr()
                if obj is not None:
                    yield obj

    __iter__ = keys

    def values(self):
        """Yield the values of entries whose key is alive."""
        with _IterationGuard(self):
            for wr, value in self.data.items():
                if wr() is not None:
                    yield value

    def keyrefs(self):
        """Return a list of weak references to the keys.

        The references are not guaranteed to be 'live' at the time
        they are used, so the result of calling the references needs
        to be checked before being used.  This can be used to avoid
        creating references that will cause the garbage collector to
        keep the keys around longer than needed.

        """
        return list(self.data)

    def popitem(self):
        """Remove and return a (key, value) pair whose key is alive.

        Dead entries popped along the way are discarded.  The KeyError
        raised by dict.popitem() on an empty dict propagates.
        """
        self._dirty_len = True
        while True:
            key, value = self.data.popitem()
            o = key()
            if o is not None:
                return o, value

    def pop(self, key, *args):
        self._dirty_len = True
        return self.data.pop(ref(key), *args)

    def setdefault(self, key, default=None):
        return self.data.setdefault(ref(key, self._remove),default)

    def update(self, dict=None, **kwargs):
        """Add entries from *dict* (a mapping or an iterable of pairs)
        and from keyword arguments, referencing each key weakly."""
        d = self.data
        if dict is not None:
            if not hasattr(dict, "items"):
                # The parameter shadows the builtin, hence type({}).
                dict = type({})(dict)
            for key, value in dict.items():
                d[ref(key, self._remove)] = value
        if len(kwargs):
            self.update(kwargs)
496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520


class finalize:
    """Class for finalization of weakrefable objects

    finalize(obj, func, *args, **kwargs) builds a callable finalizer
    that gets invoked when obj is garbage collected.  On its first call
    the finalizer evaluates func(*args, **kwargs) and hands back the
    result; from then on it is dead and calling it simply yields None.

    At program exit, every finalizer that is still alive and whose
    atexit attribute is true is run, in reverse order of creation.
    The atexit attribute is true by default.
    """

    # Finalizer objects don't have any state of their own.  They are
    # just used as keys to lookup _Info objects in the registry.  This
    # ensures that they cannot be part of a ref-cycle.

    __slots__ = ()
    _registry = {}
    _shutdown = False
    _index_iter = itertools.count()
    _dirty = False
    _registered_with_atexit = False

    class _Info:
        __slots__ = ("weakref", "func", "args", "kwargs", "atexit", "index")

    def __init__(self, obj, func, *args, **kwargs):
        if not self._registered_with_atexit:
            # We may register the exit function more than once because
            # of a thread race, but that is harmless
            import atexit
            atexit.register(self._exitfunc)
            finalize._registered_with_atexit = True
        record = self._Info()
        # The finalizer itself is the weakref callback.
        record.weakref = ref(obj, self)
        record.func = func
        record.args = args
        record.kwargs = kwargs or None
        record.atexit = True
        record.index = next(self._index_iter)
        self._registry[self] = record
        # Tell a running _exitfunc loop that the registry has grown.
        finalize._dirty = True

    def __call__(self, _=None):
        """If alive then mark as dead and return func(*args, **kwargs);
        otherwise return None"""
        record = self._registry.pop(self, None)
        if not record or self._shutdown:
            return None
        return record.func(*record.args, **(record.kwargs or {}))

    def detach(self):
        """If alive then mark as dead and return (obj, func, args, kwargs);
        otherwise return None"""
        record = self._registry.get(self)
        target = record.weakref() if record else record
        if target is None:
            return None
        if not self._registry.pop(self, None):
            return None
        return (target, record.func, record.args, record.kwargs or {})

    def peek(self):
        """If alive then return (obj, func, args, kwargs);
        otherwise return None"""
        record = self._registry.get(self)
        target = record.weakref() if record else record
        if target is None:
            return None
        return (target, record.func, record.args, record.kwargs or {})

    @property
    def alive(self):
        """Whether finalizer is alive"""
        return self in self._registry

    @property
    def atexit(self):
        """Whether finalizer should be called at exit"""
        record = self._registry.get(self)
        return bool(record) and record.atexit

    @atexit.setter
    def atexit(self, value):
        record = self._registry.get(self)
        if record:
            record.atexit = bool(value)

    def __repr__(self):
        record = self._registry.get(self)
        target = record.weakref() if record else record
        if target is not None:
            return '<%s object at %#x; for %r at %#x>' % \
                (type(self).__name__, id(self),
                 type(target).__name__, id(target))
        return '<%s object at %#x; dead>' % (type(self).__name__, id(self))

    @classmethod
    def _select_for_exit(cls):
        # Return live finalizers marked for exit, oldest first
        marked = [pair for pair in cls._registry.items() if pair[1].atexit]
        marked.sort(key=lambda pair: pair[1].index)
        return [fin for fin, _ in marked]

    @classmethod
    def _exitfunc(cls):
        # At shutdown invoke finalizers for which atexit is true.
        # This is called once all other non-daemonic threads have been
        # joined.
        gc_was_enabled = False
        try:
            if cls._registry:
                import gc
                if gc.isenabled():
                    gc_was_enabled = True
                    gc.disable()
                queue = None
                while True:
                    # Rebuild the work list on the first pass and every
                    # time a finalizer has been registered since.
                    if queue is None or finalize._dirty:
                        queue = cls._select_for_exit()
                        finalize._dirty = False
                    if not queue:
                        break
                    # The list is oldest first, so pop() gives the newest.
                    fin = queue.pop()
                    try:
                        # gc is disabled, so (assuming no daemonic
                        # threads) the following is the only line in
                        # this function which might trigger creation
                        # of a new finalizer
                        fin()
                    except Exception:
                        sys.excepthook(*sys.exc_info())
                    assert fin not in cls._registry
        finally:
            # prevent any more finalizers from executing during shutdown
            finalize._shutdown = True
            if gc_was_enabled:
                gc.enable()