Kaydet (Commit) bd8c629e authored tarafından Serhiy Storchaka's avatar Serhiy Storchaka

Issue #23799: Added test.test_support.start_threads() for running and

cleaning up multiple threads.
üst 2baaba8a
from test import test_support from test import test_support as support
from test.test_support import TESTFN, _4G, bigmemtest, import_module, findfile from test.test_support import TESTFN, _4G, bigmemtest, import_module, findfile
import unittest import unittest
...@@ -306,10 +306,8 @@ class BZ2FileTest(BaseTest): ...@@ -306,10 +306,8 @@ class BZ2FileTest(BaseTest):
for i in range(5): for i in range(5):
f.write(data) f.write(data)
threads = [threading.Thread(target=comp) for i in range(nthreads)] threads = [threading.Thread(target=comp) for i in range(nthreads)]
for t in threads: with support.start_threads(threads):
t.start() pass
for t in threads:
t.join()
def testMixedIterationReads(self): def testMixedIterationReads(self):
# Issue #8397: mixed iteration and reads should be forbidden. # Issue #8397: mixed iteration and reads should be forbidden.
...@@ -482,13 +480,13 @@ class FuncTest(BaseTest): ...@@ -482,13 +480,13 @@ class FuncTest(BaseTest):
self.assertEqual(text.strip("a"), "") self.assertEqual(text.strip("a"), "")
def test_main(): def test_main():
test_support.run_unittest( support.run_unittest(
BZ2FileTest, BZ2FileTest,
BZ2CompressorTest, BZ2CompressorTest,
BZ2DecompressorTest, BZ2DecompressorTest,
FuncTest FuncTest
) )
test_support.reap_children() support.reap_children()
if __name__ == '__main__': if __name__ == '__main__':
test_main() test_main()
......
...@@ -6,7 +6,7 @@ import sys ...@@ -6,7 +6,7 @@ import sys
import time import time
import random import random
import unittest import unittest
from test import test_support from test import test_support as support
try: try:
import thread import thread
import threading import threading
...@@ -14,7 +14,7 @@ except ImportError: ...@@ -14,7 +14,7 @@ except ImportError:
thread = None thread = None
threading = None threading = None
# Skip this test if the _testcapi module isn't available. # Skip this test if the _testcapi module isn't available.
_testcapi = test_support.import_module('_testcapi') _testcapi = support.import_module('_testcapi')
@unittest.skipUnless(threading, 'Threading required for this test.') @unittest.skipUnless(threading, 'Threading required for this test.')
...@@ -42,7 +42,7 @@ class TestPendingCalls(unittest.TestCase): ...@@ -42,7 +42,7 @@ class TestPendingCalls(unittest.TestCase):
#this busy loop is where we expect to be interrupted to #this busy loop is where we expect to be interrupted to
#run our callbacks. Note that callbacks are only run on the #run our callbacks. Note that callbacks are only run on the
#main thread #main thread
if False and test_support.verbose: if False and support.verbose:
print "(%i)"%(len(l),), print "(%i)"%(len(l),),
for i in xrange(1000): for i in xrange(1000):
a = i*i a = i*i
...@@ -51,7 +51,7 @@ class TestPendingCalls(unittest.TestCase): ...@@ -51,7 +51,7 @@ class TestPendingCalls(unittest.TestCase):
count += 1 count += 1
self.assertTrue(count < 10000, self.assertTrue(count < 10000,
"timeout waiting for %i callbacks, got %i"%(n, len(l))) "timeout waiting for %i callbacks, got %i"%(n, len(l)))
if False and test_support.verbose: if False and support.verbose:
print "(%i)"%(len(l),) print "(%i)"%(len(l),)
def test_pendingcalls_threaded(self): def test_pendingcalls_threaded(self):
...@@ -67,15 +67,11 @@ class TestPendingCalls(unittest.TestCase): ...@@ -67,15 +67,11 @@ class TestPendingCalls(unittest.TestCase):
context.lock = threading.Lock() context.lock = threading.Lock()
context.event = threading.Event() context.event = threading.Event()
for i in range(context.nThreads): threads = [threading.Thread(target=self.pendingcalls_thread,
t = threading.Thread(target=self.pendingcalls_thread, args = (context,)) args=(context,))
t.start() for i in range(context.nThreads)]
threads.append(t) with support.start_threads(threads):
self.pendingcalls_wait(context.l, n, context)
self.pendingcalls_wait(context.l, n, context)
for t in threads:
t.join()
def pendingcalls_thread(self, context): def pendingcalls_thread(self, context):
try: try:
...@@ -84,7 +80,7 @@ class TestPendingCalls(unittest.TestCase): ...@@ -84,7 +80,7 @@ class TestPendingCalls(unittest.TestCase):
with context.lock: with context.lock:
context.nFinished += 1 context.nFinished += 1
nFinished = context.nFinished nFinished = context.nFinished
if False and test_support.verbose: if False and support.verbose:
print "finished threads: ", nFinished print "finished threads: ", nFinished
if nFinished == context.nThreads: if nFinished == context.nThreads:
context.event.set() context.event.set()
...@@ -103,7 +99,7 @@ class TestPendingCalls(unittest.TestCase): ...@@ -103,7 +99,7 @@ class TestPendingCalls(unittest.TestCase):
@unittest.skipUnless(threading and thread, 'Threading required for this test.') @unittest.skipUnless(threading and thread, 'Threading required for this test.')
class TestThreadState(unittest.TestCase): class TestThreadState(unittest.TestCase):
@test_support.reap_threads @support.reap_threads
def test_thread_state(self): def test_thread_state(self):
# some extra thread-state tests driven via _testcapi # some extra thread-state tests driven via _testcapi
def target(): def target():
...@@ -129,14 +125,14 @@ def test_main(): ...@@ -129,14 +125,14 @@ def test_main():
for name in dir(_testcapi): for name in dir(_testcapi):
if name.startswith('test_'): if name.startswith('test_'):
test = getattr(_testcapi, name) test = getattr(_testcapi, name)
if test_support.verbose: if support.verbose:
print "internal", name print "internal", name
try: try:
test() test()
except _testcapi.error: except _testcapi.error:
raise test_support.TestFailed, sys.exc_info()[1] raise support.TestFailed, sys.exc_info()[1]
test_support.run_unittest(TestPendingCalls, TestThreadState) support.run_unittest(TestPendingCalls, TestThreadState)
if __name__ == "__main__": if __name__ == "__main__":
test_main() test_main()
import unittest import unittest
from test.test_support import verbose, run_unittest from test.test_support import verbose, run_unittest, start_threads
import sys import sys
import time import time
import gc import gc
...@@ -352,19 +352,13 @@ class GCTests(unittest.TestCase): ...@@ -352,19 +352,13 @@ class GCTests(unittest.TestCase):
old_checkinterval = sys.getcheckinterval() old_checkinterval = sys.getcheckinterval()
sys.setcheckinterval(3) sys.setcheckinterval(3)
try: try:
exit = False exit = []
threads = [] threads = []
for i in range(N_THREADS): for i in range(N_THREADS):
t = threading.Thread(target=run_thread) t = threading.Thread(target=run_thread)
threads.append(t) threads.append(t)
try: with start_threads(threads, lambda: exit.append(1)):
for t in threads:
t.start()
finally:
time.sleep(1.0) time.sleep(1.0)
exit = True
for t in threads:
t.join()
finally: finally:
sys.setcheckinterval(old_checkinterval) sys.setcheckinterval(old_checkinterval)
gc.collect() gc.collect()
......
...@@ -985,11 +985,8 @@ class BufferedReaderTest(unittest.TestCase, CommonBufferedTests): ...@@ -985,11 +985,8 @@ class BufferedReaderTest(unittest.TestCase, CommonBufferedTests):
errors.append(e) errors.append(e)
raise raise
threads = [threading.Thread(target=f) for x in range(20)] threads = [threading.Thread(target=f) for x in range(20)]
for t in threads: with support.start_threads(threads):
t.start() time.sleep(0.02) # yield
time.sleep(0.02) # yield
for t in threads:
t.join()
self.assertFalse(errors, self.assertFalse(errors,
"the following exceptions were caught: %r" % errors) "the following exceptions were caught: %r" % errors)
s = b''.join(results) s = b''.join(results)
...@@ -1299,11 +1296,8 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests): ...@@ -1299,11 +1296,8 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests):
errors.append(e) errors.append(e)
raise raise
threads = [threading.Thread(target=f) for x in range(20)] threads = [threading.Thread(target=f) for x in range(20)]
for t in threads: with support.start_threads(threads):
t.start() time.sleep(0.02) # yield
time.sleep(0.02) # yield
for t in threads:
t.join()
self.assertFalse(errors, self.assertFalse(errors,
"the following exceptions were caught: %r" % errors) "the following exceptions were caught: %r" % errors)
bufio.close() bufio.close()
...@@ -2555,14 +2549,10 @@ class TextIOWrapperTest(unittest.TestCase): ...@@ -2555,14 +2549,10 @@ class TextIOWrapperTest(unittest.TestCase):
text = "Thread%03d\n" % n text = "Thread%03d\n" % n
event.wait() event.wait()
f.write(text) f.write(text)
threads = [threading.Thread(target=lambda n=x: run(n)) threads = [threading.Thread(target=run, args=(x,))
for x in range(20)] for x in range(20)]
for t in threads: with support.start_threads(threads, event.set):
t.start() time.sleep(0.02)
time.sleep(0.02)
event.set()
for t in threads:
t.join()
with self.open(support.TESTFN) as f: with self.open(support.TESTFN) as f:
content = f.read() content = f.read()
for n in range(20): for n in range(20):
...@@ -3042,9 +3032,11 @@ class SignalsTest(unittest.TestCase): ...@@ -3042,9 +3032,11 @@ class SignalsTest(unittest.TestCase):
# return with a successful (partial) result rather than an EINTR. # return with a successful (partial) result rather than an EINTR.
# The buffered IO layer must check for pending signal # The buffered IO layer must check for pending signal
# handlers, which in this case will invoke alarm_interrupt(). # handlers, which in this case will invoke alarm_interrupt().
self.assertRaises(ZeroDivisionError, try:
wio.write, item * (support.PIPE_MAX_SIZE // len(item) + 1)) with self.assertRaises(ZeroDivisionError):
t.join() wio.write(item * (support.PIPE_MAX_SIZE // len(item) + 1))
finally:
t.join()
# We got one byte, get another one and check that it isn't a # We got one byte, get another one and check that it isn't a
# repeat of the first one. # repeat of the first one.
read_results.append(os.read(r, 1)) read_results.append(os.read(r, 1))
......
...@@ -37,7 +37,7 @@ __all__ = ["Error", "TestFailed", "ResourceDenied", "import_module", ...@@ -37,7 +37,7 @@ __all__ = ["Error", "TestFailed", "ResourceDenied", "import_module",
"captured_stdout", "TransientResource", "transient_internet", "captured_stdout", "TransientResource", "transient_internet",
"run_with_locale", "set_memlimit", "bigmemtest", "bigaddrspacetest", "run_with_locale", "set_memlimit", "bigmemtest", "bigaddrspacetest",
"BasicTestRunner", "run_unittest", "run_doctest", "threading_setup", "BasicTestRunner", "run_unittest", "run_doctest", "threading_setup",
"threading_cleanup", "reap_children", "cpython_only", "threading_cleanup", "reap_threads", "start_threads", "cpython_only",
"check_impl_detail", "get_attribute", "py3k_bytes", "check_impl_detail", "get_attribute", "py3k_bytes",
"import_fresh_module", "threading_cleanup", "reap_children", "import_fresh_module", "threading_cleanup", "reap_children",
"strip_python_stderr", "IPV6_ENABLED"] "strip_python_stderr", "IPV6_ENABLED"]
...@@ -1508,6 +1508,39 @@ def reap_children(): ...@@ -1508,6 +1508,39 @@ def reap_children():
except: except:
break break
@contextlib.contextmanager
def start_threads(threads, unlock=None):
threads = list(threads)
started = []
try:
try:
for t in threads:
t.start()
started.append(t)
except:
if verbose:
print("Can't start %d threads, only %d threads started" %
(len(threads), len(started)))
raise
yield
finally:
if unlock:
unlock()
endtime = starttime = time.time()
for timeout in range(1, 16):
endtime += 60
for t in started:
t.join(max(endtime - time.time(), 0.01))
started = [t for t in started if t.isAlive()]
if not started:
break
if verbose:
print('Unable to join %d threads during a period of '
'%d minutes' % (len(started), timeout))
started = [t for t in started if t.isAlive()]
if started:
raise AssertionError('Unable to join %d threads' % len(started))
@contextlib.contextmanager @contextlib.contextmanager
def swap_attr(obj, attr, new_val): def swap_attr(obj, attr, new_val):
"""Temporary swap out an attribute with a new object. """Temporary swap out an attribute with a new object.
......
...@@ -18,7 +18,7 @@ FILES_PER_THREAD = 50 ...@@ -18,7 +18,7 @@ FILES_PER_THREAD = 50
import tempfile import tempfile
from test.test_support import threading_setup, threading_cleanup, run_unittest, import_module from test.test_support import start_threads, run_unittest, import_module
threading = import_module('threading') threading = import_module('threading')
import unittest import unittest
import StringIO import StringIO
...@@ -46,25 +46,12 @@ class TempFileGreedy(threading.Thread): ...@@ -46,25 +46,12 @@ class TempFileGreedy(threading.Thread):
class ThreadedTempFileTest(unittest.TestCase): class ThreadedTempFileTest(unittest.TestCase):
def test_main(self): def test_main(self):
threads = [] threads = [TempFileGreedy() for i in range(NUM_THREADS)]
thread_info = threading_setup() with start_threads(threads, startEvent.set):
pass
for i in range(NUM_THREADS): ok = sum(t.ok_count for t in threads)
t = TempFileGreedy() errors = [str(t.getName()) + str(t.errors.getvalue())
threads.append(t) for t in threads if t.error_count]
t.start()
startEvent.set()
ok = 0
errors = []
for t in threads:
t.join()
ok += t.ok_count
if t.error_count:
errors.append(str(t.getName()) + str(t.errors.getvalue()))
threading_cleanup(*thread_info)
msg = "Errors: errors %d ok %d\n%s" % (len(errors), ok, msg = "Errors: errors %d ok %d\n%s" % (len(errors), ok,
'\n'.join(errors)) '\n'.join(errors))
......
import unittest import unittest
from doctest import DocTestSuite from doctest import DocTestSuite
from test import test_support from test import test_support as support
import weakref import weakref
import gc import gc
# Modules under test # Modules under test
_thread = test_support.import_module('thread') _thread = support.import_module('thread')
threading = test_support.import_module('threading') threading = support.import_module('threading')
import _threading_local import _threading_local
...@@ -63,14 +63,9 @@ class BaseLocalTest: ...@@ -63,14 +63,9 @@ class BaseLocalTest:
# Simply check that the variable is correctly set # Simply check that the variable is correctly set
self.assertEqual(local.x, i) self.assertEqual(local.x, i)
threads= [] with support.start_threads(threading.Thread(target=f, args=(i,))
for i in range(10): for i in range(10)):
t = threading.Thread(target=f, args=(i,)) pass
t.start()
threads.append(t)
for t in threads:
t.join()
def test_derived_cycle_dealloc(self): def test_derived_cycle_dealloc(self):
# http://bugs.python.org/issue6990 # http://bugs.python.org/issue6990
...@@ -228,7 +223,7 @@ def test_main(): ...@@ -228,7 +223,7 @@ def test_main():
setUp=setUp, tearDown=tearDown) setUp=setUp, tearDown=tearDown)
) )
test_support.run_unittest(suite) support.run_unittest(suite)
if __name__ == '__main__': if __name__ == '__main__':
test_main() test_main()
...@@ -182,6 +182,9 @@ Tools/Demos ...@@ -182,6 +182,9 @@ Tools/Demos
Tests Tests
----- -----
- Issue #23799: Added test.test_support.start_threads() for running and
cleaning up multiple threads.
- Issue #22390: test.regrtest now emits a warning if temporary files or - Issue #22390: test.regrtest now emits a warning if temporary files or
directories are left after running a test. directories are left after running a test.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment