Kaydet (Commit) 7e74384a authored tarafından Steve Purcell's avatar Steve Purcell

- Fixed loading of tests by name when name refers to unbound

  method (PyUnit issue 563882, thanks to Alexandre Fayolle)
- Ignore non-callable attributes of classes when searching for test
  method names (PyUnit issue 769338, thanks to Seth Falcon)
- New assertTrue and assertFalse aliases for comfort of JUnit users
- Automatically discover 'runTest()' test methods (PyUnit issue 469444,
  thanks to Roeland Rengelink)
- Dropped Python 1.5.2 compatibility, merged appropriate shortcuts from
  Python CVS; should work with Python >= 2.1.
- Removed all references to string module by using string methods instead
üst 1e803597
...@@ -27,7 +27,7 @@ Further information is available in the bundled documentation, and from ...@@ -27,7 +27,7 @@ Further information is available in the bundled documentation, and from
http://pyunit.sourceforge.net/ http://pyunit.sourceforge.net/
Copyright (c) 1999, 2000, 2001 Steve Purcell Copyright (c) 1999-2003 Steve Purcell
This module is free software, and you may redistribute it and/or modify This module is free software, and you may redistribute it and/or modify
it under the same terms as Python itself, so long as this copyright message it under the same terms as Python itself, so long as this copyright message
and disclaimer are retained in their original form. and disclaimer are retained in their original form.
...@@ -46,12 +46,11 @@ SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. ...@@ -46,12 +46,11 @@ SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
__author__ = "Steve Purcell" __author__ = "Steve Purcell"
__email__ = "stephen_purcell at yahoo dot com" __email__ = "stephen_purcell at yahoo dot com"
__version__ = "#Revision: 1.46 $"[11:-2] __version__ = "#Revision: 1.56 $"[11:-2]
import time import time
import sys import sys
import traceback import traceback
import string
import os import os
import types import types
...@@ -61,10 +60,26 @@ import types ...@@ -61,10 +60,26 @@ import types
__all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner', __all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner',
'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader'] 'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader']
# Expose obsolete functions for backwards compatability # Expose obsolete functions for backwards compatibility
__all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases']) __all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases'])
##############################################################################
# Backward compatibility
##############################################################################
if sys.version_info[:2] < (2, 2):
False, True = 0, 1
def isinstance(obj, clsinfo):
import __builtin__
if type(clsinfo) in (types.TupleType, types.ListType):
for cls in clsinfo:
if cls is type: cls = types.ClassType
if __builtin__.isinstance(obj, cls):
return 1
return 0
else: return __builtin__.isinstance(obj, clsinfo)
############################################################################## ##############################################################################
# Test framework core # Test framework core
############################################################################## ##############################################################################
...@@ -121,11 +136,11 @@ class TestResult: ...@@ -121,11 +136,11 @@ class TestResult:
def stop(self): def stop(self):
"Indicates that the tests should be aborted" "Indicates that the tests should be aborted"
self.shouldStop = 1 self.shouldStop = True
def _exc_info_to_string(self, err): def _exc_info_to_string(self, err):
"""Converts a sys.exc_info()-style tuple of values into a string.""" """Converts a sys.exc_info()-style tuple of values into a string."""
return string.join(traceback.format_exception(*err), '') return ''.join(traceback.format_exception(*err))
def __repr__(self): def __repr__(self):
return "<%s run=%i errors=%i failures=%i>" % \ return "<%s run=%i errors=%i failures=%i>" % \
...@@ -196,7 +211,7 @@ class TestCase: ...@@ -196,7 +211,7 @@ class TestCase:
the specified test method's docstring. the specified test method's docstring.
""" """
doc = self.__testMethodDoc doc = self.__testMethodDoc
return doc and string.strip(string.split(doc, "\n")[0]) or None return doc and doc.split("\n")[0].strip() or None
def id(self): def id(self):
return "%s.%s" % (_strclass(self.__class__), self.__testMethodName) return "%s.%s" % (_strclass(self.__class__), self.__testMethodName)
...@@ -209,9 +224,6 @@ class TestCase: ...@@ -209,9 +224,6 @@ class TestCase:
(_strclass(self.__class__), self.__testMethodName) (_strclass(self.__class__), self.__testMethodName)
def run(self, result=None): def run(self, result=None):
return self(result)
def __call__(self, result=None):
if result is None: result = self.defaultTestResult() if result is None: result = self.defaultTestResult()
result.startTest(self) result.startTest(self)
testMethod = getattr(self, self.__testMethodName) testMethod = getattr(self, self.__testMethodName)
...@@ -224,10 +236,10 @@ class TestCase: ...@@ -224,10 +236,10 @@ class TestCase:
result.addError(self, self.__exc_info()) result.addError(self, self.__exc_info())
return return
ok = 0 ok = False
try: try:
testMethod() testMethod()
ok = 1 ok = True
except self.failureException: except self.failureException:
result.addFailure(self, self.__exc_info()) result.addFailure(self, self.__exc_info())
except KeyboardInterrupt: except KeyboardInterrupt:
...@@ -241,11 +253,13 @@ class TestCase: ...@@ -241,11 +253,13 @@ class TestCase:
raise raise
except: except:
result.addError(self, self.__exc_info()) result.addError(self, self.__exc_info())
ok = 0 ok = False
if ok: result.addSuccess(self) if ok: result.addSuccess(self)
finally: finally:
result.stopTest(self) result.stopTest(self)
__call__ = run
def debug(self): def debug(self):
"""Run the test without collecting errors in a TestResult""" """Run the test without collecting errors in a TestResult"""
self.setUp() self.setUp()
...@@ -292,7 +306,7 @@ class TestCase: ...@@ -292,7 +306,7 @@ class TestCase:
else: else:
if hasattr(excClass,'__name__'): excName = excClass.__name__ if hasattr(excClass,'__name__'): excName = excClass.__name__
else: excName = str(excClass) else: excName = str(excClass)
raise self.failureException, excName raise self.failureException, "%s not raised" % excName
def failUnlessEqual(self, first, second, msg=None): def failUnlessEqual(self, first, second, msg=None):
"""Fail if the two objects are unequal as determined by the '==' """Fail if the two objects are unequal as determined by the '=='
...@@ -334,6 +348,8 @@ class TestCase: ...@@ -334,6 +348,8 @@ class TestCase:
raise self.failureException, \ raise self.failureException, \
(msg or '%s == %s within %s places' % (`first`, `second`, `places`)) (msg or '%s == %s within %s places' % (`first`, `second`, `places`))
# Synonyms for assertion methods
assertEqual = assertEquals = failUnlessEqual assertEqual = assertEquals = failUnlessEqual
assertNotEqual = assertNotEquals = failIfEqual assertNotEqual = assertNotEquals = failIfEqual
...@@ -344,7 +360,9 @@ class TestCase: ...@@ -344,7 +360,9 @@ class TestCase:
assertRaises = failUnlessRaises assertRaises = failUnlessRaises
assert_ = failUnless assert_ = assertTrue = failUnless
assertFalse = failIf
...@@ -369,7 +387,7 @@ class TestSuite: ...@@ -369,7 +387,7 @@ class TestSuite:
def countTestCases(self): def countTestCases(self):
cases = 0 cases = 0
for test in self._tests: for test in self._tests:
cases = cases + test.countTestCases() cases += test.countTestCases()
return cases return cases
def addTest(self, test): def addTest(self, test):
...@@ -434,7 +452,7 @@ class FunctionTestCase(TestCase): ...@@ -434,7 +452,7 @@ class FunctionTestCase(TestCase):
def shortDescription(self): def shortDescription(self):
if self.__description is not None: return self.__description if self.__description is not None: return self.__description
doc = self.__testFunc.__doc__ doc = self.__testFunc.__doc__
return doc and string.strip(string.split(doc, "\n")[0]) or None return doc and doc.split("\n")[0].strip() or None
...@@ -452,8 +470,10 @@ class TestLoader: ...@@ -452,8 +470,10 @@ class TestLoader:
def loadTestsFromTestCase(self, testCaseClass): def loadTestsFromTestCase(self, testCaseClass):
"""Return a suite of all tests cases contained in testCaseClass""" """Return a suite of all tests cases contained in testCaseClass"""
return self.suiteClass(map(testCaseClass, testCaseNames = self.getTestCaseNames(testCaseClass)
self.getTestCaseNames(testCaseClass))) if not testCaseNames and hasattr(testCaseClass, 'runTest'):
testCaseNames = ['runTest']
return self.suiteClass(map(testCaseClass, testCaseNames))
def loadTestsFromModule(self, module): def loadTestsFromModule(self, module):
"""Return a suite of all tests cases contained in the given module""" """Return a suite of all tests cases contained in the given module"""
...@@ -474,23 +494,20 @@ class TestLoader: ...@@ -474,23 +494,20 @@ class TestLoader:
The method optionally resolves the names relative to a given module. The method optionally resolves the names relative to a given module.
""" """
parts = string.split(name, '.') parts = name.split('.')
if module is None: if module is None:
if not parts: parts_copy = parts[:]
raise ValueError, "incomplete test name: %s" % name while parts_copy:
else: try:
parts_copy = parts[:] module = __import__('.'.join(parts_copy))
while parts_copy: break
try: except ImportError:
module = __import__(string.join(parts_copy,'.')) del parts_copy[-1]
break if not parts_copy: raise
except ImportError:
del parts_copy[-1]
if not parts_copy: raise
parts = parts[1:] parts = parts[1:]
obj = module obj = module
for part in parts: for part in parts:
obj = getattr(obj, part) parent, obj = obj, getattr(obj, part)
import unittest import unittest
if type(obj) == types.ModuleType: if type(obj) == types.ModuleType:
...@@ -499,11 +516,13 @@ class TestLoader: ...@@ -499,11 +516,13 @@ class TestLoader:
issubclass(obj, unittest.TestCase)): issubclass(obj, unittest.TestCase)):
return self.loadTestsFromTestCase(obj) return self.loadTestsFromTestCase(obj)
elif type(obj) == types.UnboundMethodType: elif type(obj) == types.UnboundMethodType:
return parent(obj.__name__)
return obj.im_class(obj.__name__) return obj.im_class(obj.__name__)
elif isinstance(obj, unittest.TestSuite):
return obj
elif callable(obj): elif callable(obj):
test = obj() test = obj()
if not isinstance(test, unittest.TestCase) and \ if not isinstance(test, (unittest.TestCase, unittest.TestSuite)):
not isinstance(test, unittest.TestSuite):
raise ValueError, \ raise ValueError, \
"calling %s returned %s, not a test" % (obj,test) "calling %s returned %s, not a test" % (obj,test)
return test return test
...@@ -514,16 +533,15 @@ class TestLoader: ...@@ -514,16 +533,15 @@ class TestLoader:
"""Return a suite of all tests cases found using the given sequence """Return a suite of all tests cases found using the given sequence
of string specifiers. See 'loadTestsFromName()'. of string specifiers. See 'loadTestsFromName()'.
""" """
suites = [] suites = [self.loadTestsFromName(name, module) for name in names]
for name in names:
suites.append(self.loadTestsFromName(name, module))
return self.suiteClass(suites) return self.suiteClass(suites)
def getTestCaseNames(self, testCaseClass): def getTestCaseNames(self, testCaseClass):
"""Return a sorted sequence of method names found within testCaseClass """Return a sorted sequence of method names found within testCaseClass
""" """
testFnNames = filter(lambda n,p=self.testMethodPrefix: n[:len(p)] == p, def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix):
dir(testCaseClass)) return attrname[:len(prefix)] == prefix and callable(getattr(testCaseClass, attrname))
testFnNames = filter(isTestMethod, dir(testCaseClass))
for baseclass in testCaseClass.__bases__: for baseclass in testCaseClass.__bases__:
for testFnName in self.getTestCaseNames(baseclass): for testFnName in self.getTestCaseNames(baseclass):
if testFnName not in testFnNames: # handle overridden methods if testFnName not in testFnNames: # handle overridden methods
...@@ -706,7 +724,7 @@ Examples: ...@@ -706,7 +724,7 @@ Examples:
argv=None, testRunner=None, testLoader=defaultTestLoader): argv=None, testRunner=None, testLoader=defaultTestLoader):
if type(module) == type(''): if type(module) == type(''):
self.module = __import__(module) self.module = __import__(module)
for part in string.split(module,'.')[1:]: for part in module.split('.')[1:]:
self.module = getattr(self.module, part) self.module = getattr(self.module, part)
else: else:
self.module = module self.module = module
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment