# Access WeakSet through the weakref module.
# This code is separated-out because it is needed
# by abc.py to load everything else at startup.

from _weakref import ref

__all__ = ['WeakSet']


class _IterationGuard:
    # Context manager that marks its weak container as "being iterated".
    # On entry it adds itself to the container's ``_iterating`` set; on exit
    # it removes itself, and the last guard to leave triggers the
    # container's ``_commit_removals()`` so removals deferred during the
    # iteration are finally applied.
    # The container is only referenced weakly, so a guard never keeps it
    # alive; if it has already been collected, entry and exit do nothing.
    # The original authors note this should be relatively thread-safe
    # (since sets are).

    def __init__(self, weakcontainer):
        # Hold the container through a weak reference to avoid a cycle.
        self.weakcontainer = ref(weakcontainer)

    def __enter__(self):
        container = self.weakcontainer()
        if container is None:
            return self
        container._iterating.add(self)
        return self

    def __exit__(self, e, t, b):
        container = self.weakcontainer()
        if container is None:
            return
        active = container._iterating
        active.remove(self)
        if not active:
            # No iteration is running any more: flush deferred removals.
            container._commit_removals()


class WeakSet:
    """Set-like collection that holds only weak references to its elements.

    Every element is stored in ``self.data`` as a weak reference created with
    ``self._remove`` as its callback, so the entry is dropped once the element
    is garbage collected.  While an iteration is running (``self._iterating``
    is non-empty) such drops are queued in ``self._pending_removals`` instead
    of resizing ``self.data`` under the live iterator; mutating methods flush
    that queue first via ``_commit_removals()``.

    Elements must be hashable and support weak references.
    """

    def __init__(self, data=None):
        self.data = set()
        def _remove(item, selfref=ref(self)):
            # Weakref callback; ``item`` is the now-dead weak reference.
            # The set itself is reached through a weak reference so the
            # callback does not keep it alive.
            self = selfref()
            if self is not None:
                if self._iterating:
                    # An iteration is in progress: defer, so self.data does
                    # not change size underneath it.
                    self._pending_removals.append(item)
                else:
                    self.data.discard(item)
        self._remove = _remove
        # Dead references whose removal was deferred during iteration.
        self._pending_removals = []
        # Iteration guards currently active on this set.
        self._iterating = set()
        if data is not None:
            self.update(data)

    def _commit_removals(self):
        """Apply every removal deferred by ``_remove`` during iteration."""
        l = self._pending_removals
        discard = self.data.discard
        while l:
            discard(l.pop())

    def __iter__(self):
        # The guard defers removals triggered while this generator runs,
        # so self.data is never resized mid-iteration.
        with _IterationGuard(self):
            for itemref in self.data:
                item = itemref()
                if item is not None:
                    yield item

    def __len__(self):
        # Count only references whose referent is still alive.
        return sum(x() is not None for x in self.data)

    def __contains__(self, item):
        """Return True if *item* is a live member of the set.

        Objects that cannot be weakly referenced can never be members, so
        they report False instead of propagating ``ref()``'s TypeError.
        """
        try:
            wr = ref(item)
        except TypeError:
            return False
        return wr in self.data

    def __reduce__(self):
        # Rebuild from the live elements.  The extra state preserves
        # attributes added by subclasses but excludes this class's own
        # bookkeeping (the weakref set, the bound callback, the iteration
        # state).  If those were passed along, copy.copy() would assign them
        # onto the new instance, so it would share the original's data, and
        # pickling would fail on the local ``_remove`` function.
        state = getattr(self, '__dict__', None)
        if state is not None:
            state = {key: value for key, value in state.items()
                     if key not in ('data', '_remove',
                                    '_pending_removals', '_iterating')}
        return (self.__class__, (list(self),), state)

    def add(self, item):
        """Add *item*; it is dropped automatically when it is collected."""
        if self._pending_removals:
            self._commit_removals()
        self.data.add(ref(item, self._remove))

    def clear(self):
        """Remove all elements."""
        if self._pending_removals:
            self._commit_removals()
        self.data.clear()

    def copy(self):
        """Return a new set of the same class with the same live elements."""
        return self.__class__(self)

    def pop(self):
        """Remove and return an arbitrary live element.

        Raises KeyError('pop from empty WeakSet') when no live element is
        left; references whose referent already died are discarded.
        """
        if self._pending_removals:
            self._commit_removals()
        while True:
            try:
                itemref = self.data.pop()
            except KeyError:
                raise KeyError('pop from empty WeakSet')
            item = itemref()
            if item is not None:
                return item

    def remove(self, item):
        """Remove *item*; KeyError if absent, TypeError if not weakrefable."""
        if self._pending_removals:
            self._commit_removals()
        self.data.remove(ref(item))

    def discard(self, item):
        """Remove *item* if present (TypeError if it is not weakrefable)."""
        if self._pending_removals:
            self._commit_removals()
        self.data.discard(ref(item))

    def update(self, other):
        """Add every element of the iterable *other*."""
        if self._pending_removals:
            self._commit_removals()
        if other is self:
            # Union with itself is a no-op.
            return
        # Always go through add() so each stored reference carries this
        # set's own removal callback.  Copying another WeakSet's ``data``
        # wholesale would import references bound to *its* callback, and
        # those entries would never be dropped from this set.
        for element in other:
            self.add(element)

    def __ior__(self, other):
        self.update(other)
        return self

    # Helper functions for simple delegating methods.
    def _apply(self, other, method):
        """Return a new set holding the live results of ``method(other.data)``.

        *other* is first converted to this set's class if needed.
        """
        if not isinstance(other, self.__class__):
            other = self.__class__(other)
        newset = self.__class__()
        # Re-add the surviving referents instead of installing the raw
        # result set: its references carry the callback of self or other,
        # so dead entries would never be removed from newset.
        for itemref in method(other.data):
            item = itemref()
            if item is not None:
                newset.add(item)
        return newset

    def difference(self, other):
        return self._apply(other, self.data.difference)
    __sub__ = difference

    def difference_update(self, other):
        """Remove every element of the iterable *other* from this set."""
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            self.data.clear()
        else:
            self.data.difference_update(ref(item) for item in other)

    def __isub__(self, other):
        self.difference_update(other)
        return self

    def intersection(self, other):
        return self._apply(other, self.data.intersection)
    __and__ = intersection

    def intersection_update(self, other):
        """Keep only the elements that also appear in the iterable *other*."""
        if self._pending_removals:
            self._commit_removals()
        wanted = set(ref(item) for item in other)
        # Discard our own non-matching references instead of calling
        # set.intersection_update with the temporary callback-less refs
        # above, which may become the stored keys; survivors must keep this
        # set's removal callback.
        stale = [wr for wr in self.data if wr not in wanted]
        self.data.difference_update(stale)

    def __iand__(self, other):
        self.intersection_update(other)
        return self

    def issubset(self, other):
        """Return True if every element is also in the iterable *other*."""
        return self.data.issubset(ref(item) for item in other)

    def __le__(self, other):
        return self.data <= set(ref(item) for item in other)

    def __lt__(self, other):
        # Proper subset: ``s < s`` must be False, so this cannot simply
        # alias issubset().
        return self.data < set(ref(item) for item in other)

    def issuperset(self, other):
        """Return True if every element of the iterable *other* is here."""
        return self.data.issuperset(ref(item) for item in other)

    def __ge__(self, other):
        return self.data >= set(ref(item) for item in other)

    def __gt__(self, other):
        # Proper superset: ``s > s`` must be False.
        return self.data > set(ref(item) for item in other)

    def __eq__(self, other):
        # Only WeakSets compare equal; defer to other types otherwise.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.data == set(ref(item) for item in other)

    def symmetric_difference(self, other):
        return self._apply(other, self.data.symmetric_difference)
    __xor__ = symmetric_difference

    def symmetric_difference_update(self, other):
        """Keep elements in exactly one of this set and the iterable *other*."""
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            self.data.clear()
        else:
            # Newly added members must carry this set's removal callback,
            # otherwise they would linger after their referent dies.
            self.data.symmetric_difference_update(
                ref(item, self._remove) for item in other)

    def __ixor__(self, other):
        self.symmetric_difference_update(other)
        return self

    def union(self, other):
        return self._apply(other, self.data.union)
    __or__ = union

    def isdisjoint(self, other):
        """Return True if this set and *other* share no live element."""
        return len(self.intersection(other)) == 0