# test_threading_local.py -- regression tests for threading.local
import unittest
from doctest import DocTestSuite
from test import support
threading = support.import_module('threading')
import weakref
import gc

class Weak:
    """Plain placeholder object whose instances can be weakly referenced."""

def target(local, weaklist):
    """Store a new Weak object on *local* and record a weak reference to it.

    The object is bound to ``local.weak`` and a ``weakref.ref`` pointing at
    it is appended to *weaklist*.
    """
    obj = Weak()
    local.weak = obj
    ref = weakref.ref(obj)
    weaklist.append(ref)

class ThreadingLocalTest(unittest.TestCase):
    """Regression tests for ``threading.local`` and classes derived from it."""

    def test_local_refs(self):
        # Run the reference check with several different thread counts.
        self._local_refs(20)
        self._local_refs(50)
        self._local_refs(100)

    def _local_refs(self, n):
        # Run *n* short-lived threads that each store a Weak object on the
        # same threading.local, then check through the recorded weak
        # references how many of those objects have been freed.
        local = threading.local()
        weaklist = []
        for _ in range(n):
            t = threading.Thread(target=target, args=(local, weaklist))
            t.start()
            t.join()
        # Drop the reference to the last Thread object before collecting.
        del t

        gc.collect()
        self.assertEqual(len(weaklist), n)

        # XXX threading.local keeps the local of the last stopped thread alive.
        deadlist = [weak for weak in weaklist if weak() is None]
        self.assertEqual(len(deadlist), n-1)

        # Assignment to the same thread local frees it sometimes (!)
        local.someothervar = None
        gc.collect()
        deadlist = [weak for weak in weaklist if weak() is None]
        self.assertIn(len(deadlist), (n-1, n), (n, len(deadlist)))

    def test_derived(self):
        # Issue 3088: if there is a threads switch inside the __init__
        # of a threading.local derived class, the per-thread dictionary
        # is created but not correctly set on the object.
        # The first member set may be bogus.
        import time
        class Local(threading.local):
            def __init__(self):
                time.sleep(0.01)
        local = Local()

        # An exception raised inside a worker thread never reaches the test
        # runner, so failures are collected here and checked on the main
        # thread once every worker has finished.
        errors = []

        def f(i):
            try:
                local.x = i
                # Simply check that the variable is correctly set
                self.assertEqual(local.x, i)
            except Exception as exc:
                errors.append((i, exc))

        threads = []
        for i in range(10):
            t = threading.Thread(target=f, args=(i,))
            t.start()
            threads.append(t)

        for t in threads:
            t.join()

        self.assertEqual(errors, [])

    def test_derived_cycle_dealloc(self):
        # http://bugs.python.org/issue6990
        class Local(threading.local):
            pass
        # Filled in by the main thread (step 3) and read by the worker
        # thread (step 4).
        new_locals = None
        passed = False
        e1 = threading.Event()
        e2 = threading.Event()

        def f():
            nonlocal passed
            # 1) Involve Local in a cycle
            cycle = [Local()]
            cycle.append(cycle)
            cycle[0].foo = 'bar'

            # 2) GC the cycle (triggers threadmodule.c::local_clear
            # before local_dealloc)
            del cycle
            gc.collect()
            e1.set()
            e2.wait()

            # 4) New Locals should be empty
            passed = all(not hasattr(obj, 'foo') for obj in new_locals)

        t = threading.Thread(target=f)
        t.start()
        e1.wait()

        # 3) New Locals should recycle the original's address. Creating
        # them in the thread overwrites the thread state and avoids the
        # bug
        new_locals = [Local() for i in range(10)]
        e2.set()
        t.join()

        self.assertTrue(passed)

    def test_arguments(self):
        # Issue 1522237
        from _thread import _local as local
        from _threading_local import local as py_local

        for cls in (local, py_local):
            # A subclass that defines its own __init__ may take arbitrary
            # arguments; the base classes themselves must reject them.
            class MyLocal(cls):
                def __init__(self, *args, **kwargs):
                    pass

            MyLocal(a=1)
            MyLocal(1)
            self.assertRaises(TypeError, cls, a=1)
            self.assertRaises(TypeError, cls, 1)

def test_main():
    """Build and run the complete test suite for this module.

    The suite holds the ``_threading_local`` doctests and the
    ``ThreadingLocalTest`` cases.  When ``_thread._local`` can be imported,
    the doctests are added a second time with ``_threading_local.local``
    temporarily replaced by that implementation.
    """
    suite = unittest.TestSuite()
    suite.addTest(DocTestSuite('_threading_local'))
    suite.addTest(
        unittest.defaultTestLoader.loadTestsFromTestCase(ThreadingLocalTest))

    try:
        # The low-level module is ``_thread`` on Python 3 (test_arguments
        # imports it under that name too); the old ``thread`` name never
        # imports, which silently skipped the doctest pass added below.
        from _thread import _local
    except ImportError:
        # No alternative implementation to test against; the doctests
        # already added above still run.
        pass
    else:
        import _threading_local
        local_orig = _threading_local.local

        def setUp(test):
            # Make the doctests pick up the _thread implementation.
            _threading_local.local = _local

        def tearDown(test):
            # Restore the pure-Python implementation afterwards.
            _threading_local.local = local_orig

        suite.addTest(DocTestSuite('_threading_local',
                                   setUp=setUp, tearDown=tearDown))

    support.run_unittest(suite)

# Run the suite when this file is executed directly as a script.
if __name__ == '__main__':
    test_main()