#!/usr/bin/env python
'''
Python unit testing framework, based on Erich Gamma's JUnit and Kent Beck's
Smalltalk testing framework.

This module contains the core framework classes that form the basis of
specific test cases and suites (TestCase, TestSuite etc.), and also a
text-based utility class for running the tests and reporting the results
(TextTestRunner).

Simple usage:

    import unittest

    class IntegerArithmeticTestCase(unittest.TestCase):
        def testAdd(self):  ## test method names begin 'test*'
            self.assertEquals((1 + 2), 3)
            self.assertEquals(0 + 1, 1)
        def testMultiply(self):
            self.assertEquals((0 * 10), 0)
            self.assertEquals((5 * 8), 40)

    if __name__ == '__main__':
        unittest.main()

Further information is available in the bundled documentation, and from

  http://docs.python.org/lib/module-unittest.html

Copyright (c) 1999-2003 Steve Purcell
This module is free software, and you may redistribute it and/or modify
it under the same terms as Python itself, so long as this copyright message
and disclaimer are retained in their original form.

IN NO EVENT SHALL THE AUTHOR BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF
THIS CODE, EVEN IF THE AUTHOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
DAMAGE.

THE AUTHOR SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE.  THE CODE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS,
AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE,
SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
'''

__author__ = "Steve Purcell"
__email__ = "stephen_purcell at yahoo dot com"
__version__ = "#Revision: 1.63 $"[11:-2]

import time
import sys
import traceback
import os
import types
##############################################################################
# Exported classes and functions
##############################################################################
__all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner',
           'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader']

# Expose obsolete functions for backwards compatibility
__all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases'])


##############################################################################
# Backward compatibility
##############################################################################
if sys.version_info[:2] < (2, 2):
    def isinstance(obj, clsinfo):
        """Module-level replacement for isinstance() on Python < 2.2.

        When 'clsinfo' is a tuple or list, 'obj' is checked against each
        entry in turn (the builtin 'type' is mapped to types.ClassType
        first) and 1 or 0 is returned.  Otherwise the call is passed
        straight through to the builtin isinstance().
        """
        import __builtin__
        if type(clsinfo) in (tuple, list):
            for cls in clsinfo:
                if cls is type:
                    cls = types.ClassType
                if __builtin__.isinstance(obj, cls):
                    return 1
            return 0
        return __builtin__.isinstance(obj, clsinfo)
def _CmpToKey(mycmp):
    """Convert a cmp= function into a key= function.

    'mycmp' is a two-argument comparison function returning a negative
    number, zero, or a positive number.  The returned class wraps a value
    and defines '<' in terms of 'mycmp', which is the only comparison the
    sort in this module relies on.
    """
    class K(object):
        def __init__(self, obj):
            self.obj = obj
        def __lt__(self, other):
            # Any negative result means "less than"; a cmp-style function
            # is not required to return exactly -1.
            return mycmp(self.obj, other.obj) < 0
    return K
##############################################################################
# Test framework core
##############################################################################

# All classes defined herein are 'new-style' classes, allowing use of 'super()'
__metaclass__ = type

def _strclass(cls):
    """Return the dotted 'module.ClassName' label for the class 'cls'."""
    module_name = cls.__module__
    class_name = cls.__name__
    return "%s.%s" % (module_name, class_name)

# Marker global: TestResult._is_relevant_tb_level() checks whether the name
# '__unittest' is present in a traceback frame's globals, and frames where
# it is present are skipped when formatting tracebacks.
__unittest = 1

class TestResult:
    """Holder for test result information.

    Test results are automatically managed by the TestCase and TestSuite
    classes, and do not need to be explicitly manipulated by writers of tests.

    Each instance holds the total number of tests run, and collections of
    failures and errors that occurred among those test runs. The collections
    contain tuples of (testcase, exceptioninfo), where exceptioninfo is the
    formatted traceback of the error that occurred.
    """
    def __init__(self):
        self.failures = []
        self.errors = []
        self.testsRun = 0
        self.shouldStop = False

    def startTest(self, test):
        "Called when the given test is about to be run"
        self.testsRun = self.testsRun + 1

    def stopTest(self, test):
        "Called when the given test has been run"
        pass

    def addError(self, test, err):
        """Called when an error has occurred. 'err' is a tuple of values as
        returned by sys.exc_info().
        """
        self.errors.append((test, self._exc_info_to_string(err, test)))

    def addFailure(self, test, err):
        """Called when a failure has occurred. 'err' is a tuple of values as
        returned by sys.exc_info().
        """
        self.failures.append((test, self._exc_info_to_string(err, test)))

    def addSuccess(self, test):
        "Called when a test has completed successfully"
        pass

    def wasSuccessful(self):
        "Tells whether or not this result was a success"
        return len(self.failures) == len(self.errors) == 0

    def stop(self):
        "Indicates that the tests should be aborted"
        self.shouldStop = True

    def _exc_info_to_string(self, err, test):
        """Converts a sys.exc_info()-style tuple of values into a string.

        Leading traceback levels that belong to the test runner are dropped.
        When the exception type is exactly test.failureException, the
        traceback is also cut off where the runner's frames resume, which
        hides the assert*() helper frames.
        """
        exctype, value, tb = err
        # Skip test runner traceback levels
        while tb and self._is_relevant_tb_level(tb):
            tb = tb.tb_next
        if exctype is test.failureException:
            # Skip assert*() traceback levels
            length = self._count_relevant_tb_levels(tb)
            return ''.join(traceback.format_exception(exctype, value, tb, length))
        return ''.join(traceback.format_exception(exctype, value, tb))

    def _is_relevant_tb_level(self, tb):
        """True if the frame's module globals define the '__unittest' marker."""
        return '__unittest' in tb.tb_frame.f_globals

    def _count_relevant_tb_levels(self, tb):
        """Count the traceback levels before the first '__unittest' frame."""
        length = 0
        while tb and not self._is_relevant_tb_level(tb):
            length += 1
            tb = tb.tb_next
        return length

    def __repr__(self):
        return "<%s run=%i errors=%i failures=%i>" % \
               (_strclass(self.__class__), self.testsRun, len(self.errors),
                len(self.failures))

class TestCase:
    """A class whose instances are single test cases.

    By default, the test code itself should be placed in a method named
    'runTest'.

    If the fixture may be used for many test cases, create as
    many test methods as are needed. When instantiating such a TestCase
    subclass, specify in the constructor arguments the name of the test method
    that the instance is to execute.

    Test authors should subclass TestCase for their own tests. Construction
    and deconstruction of the test's environment ('fixture') can be
    implemented by overriding the 'setUp' and 'tearDown' methods respectively.

    If it is necessary to override the __init__ method, the base class
    __init__ method must always be called. It is important that subclasses
    should not change the signature of their __init__ method, since instances
    of the classes are instantiated automatically by parts of the framework
    in order to be run.
    """

    # This attribute determines which exception will be raised when
    # the instance's assertion methods fail; test methods raising this
    # exception will be deemed to have 'failed' rather than 'errored'

    failureException = AssertionError

    def __init__(self, methodName='runTest'):
        """Create an instance of the class that will use the named test
           method when executed. Raises a ValueError if the instance does
           not have a method with the specified name.
        """
        try:
            self._testMethodName = methodName
            testMethod = getattr(self, methodName)
            self._testMethodDoc = testMethod.__doc__
        except AttributeError:
            raise ValueError("no such test method in %s: %s" %
                             (self.__class__, methodName))

    def setUp(self):
        "Hook method for setting up the test fixture before exercising it."
        pass

    def tearDown(self):
        "Hook method for deconstructing the test fixture after testing it."
        pass

    def countTestCases(self):
        return 1

    def defaultTestResult(self):
        return TestResult()

    def shortDescription(self):
        """Returns a one-line description of the test, or None if no
        description has been provided.

        The default implementation of this method returns the first line of
        the specified test method's docstring.
        """
        doc = self._testMethodDoc
        return doc and doc.split("\n")[0].strip() or None

    def id(self):
        return "%s.%s" % (_strclass(self.__class__), self._testMethodName)

    def __eq__(self, other):
        if type(self) is not type(other):
            return False

        return self._testMethodName == other._testMethodName

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        return hash((type(self), self._testMethodName))

    def __str__(self):
        return "%s (%s)" % (self._testMethodName, _strclass(self.__class__))

    def __repr__(self):
        return "<%s testMethod=%s>" % \
               (_strclass(self.__class__), self._testMethodName)

    def run(self, result=None):
        """Run setUp, the test method and tearDown, reporting to 'result'.

        If 'result' is None a fresh one is obtained from
        defaultTestResult().  An exception in setUp is recorded as an error
        and the test method and tearDown are skipped.  tearDown runs
        whether or not the test method raised, and a tearDown error turns
        an otherwise successful test into an error.  KeyboardInterrupt is
        always re-raised.  result.stopTest() is called in every case.
        Returns None.
        """
        if result is None: result = self.defaultTestResult()
        result.startTest(self)
        testMethod = getattr(self, self._testMethodName)
        try:
            try:
                self.setUp()
            except KeyboardInterrupt:
                raise
            except:
                result.addError(self, self._exc_info())
                return

            ok = False
            try:
                testMethod()
                ok = True
            except self.failureException:
                result.addFailure(self, self._exc_info())
            except KeyboardInterrupt:
                raise
            except:
                result.addError(self, self._exc_info())

            try:
                self.tearDown()
            except KeyboardInterrupt:
                raise
            except:
                result.addError(self, self._exc_info())
                ok = False
            if ok: result.addSuccess(self)
        finally:
            result.stopTest(self)

    def __call__(self, *args, **kwds):
        return self.run(*args, **kwds)

    def debug(self):
        """Run the test without collecting errors in a TestResult"""
        self.setUp()
        getattr(self, self._testMethodName)()
        self.tearDown()

    def _exc_info(self):
        """Return sys.exc_info() for the exception being handled.

        This implementation returns the tuple unmodified.
        """
        return sys.exc_info()

    def fail(self, msg=None):
        """Fail immediately, with the given message."""
        # With no message the exception is built without arguments, so
        # str(exc) stays empty instead of becoming 'None'.
        if msg is None:
            raise self.failureException()
        raise self.failureException(msg)

    def failIf(self, expr, msg=None):
        "Fail the test if the expression is true."
        if expr:
            if msg is None:
                raise self.failureException()
            raise self.failureException(msg)

    def failUnless(self, expr, msg=None):
        """Fail the test unless the expression is true."""
        if not expr:
            if msg is None:
                raise self.failureException()
            raise self.failureException(msg)

    def failUnlessRaises(self, excClass, callableObj, *args, **kwargs):
        """Fail unless an exception of class excClass is thrown
           by callableObj when invoked with arguments args and keyword
           arguments kwargs. If a different type of exception is
           thrown, it will not be caught, and the test case will be
           deemed to have suffered an error, exactly as for an
           unexpected exception.
        """
        try:
            callableObj(*args, **kwargs)
        except excClass:
            return
        else:
            if hasattr(excClass,'__name__'): excName = excClass.__name__
            else: excName = str(excClass)
            raise self.failureException("%s not raised" % excName)

    def failUnlessEqual(self, first, second, msg=None):
        """Fail if the two objects are unequal as determined by the '=='
           operator.
        """
        if not first == second:
            raise self.failureException(
                  msg or '%r != %r' % (first, second))

    def failIfEqual(self, first, second, msg=None):
        """Fail if the two objects are equal as determined by the '=='
           operator.
        """
        if first == second:
            raise self.failureException(
                  msg or '%r == %r' % (first, second))

    def failUnlessAlmostEqual(self, first, second, places=7, msg=None):
        """Fail if the two objects are unequal as determined by their
           difference rounded to the given number of decimal places
           (default 7) and comparing to zero.

           Note that decimal places (from zero) are usually not the same
           as significant digits (measured from the most significant digit).
        """
        if round(abs(second-first), places) != 0:
            raise self.failureException(
                  msg or '%r != %r within %r places' % (first, second, places))

    def failIfAlmostEqual(self, first, second, places=7, msg=None):
        """Fail if the two objects are equal as determined by their
           difference rounded to the given number of decimal places
           (default 7) and comparing to zero.

           Note that decimal places (from zero) are usually not the same
           as significant digits (measured from the most significant digit).
        """
        if round(abs(second-first), places) == 0:
            raise self.failureException(
                  msg or '%r == %r within %r places' % (first, second, places))

    # Synonyms for assertion methods

    assertEqual = assertEquals = failUnlessEqual

    assertNotEqual = assertNotEquals = failIfEqual

    assertAlmostEqual = assertAlmostEquals = failUnlessAlmostEqual

    assertNotAlmostEqual = assertNotAlmostEquals = failIfAlmostEqual

    assertRaises = failUnlessRaises

    assert_ = assertTrue = failUnless

    assertFalse = failIf

class TestSuite:
    """A test suite is a composite test consisting of a number of TestCases.

    For use, create an instance of TestSuite, then add test case instances.
    When all tests have been added, the suite can be passed to a test
    runner, such as TextTestRunner. It will run the individual test cases
    in the order in which they were added, aggregating the results. When
    subclassing, do not forget to call the base class constructor.
    """
    def __init__(self, tests=()):
        self._tests = []
        self.addTests(tests)

    def __repr__(self):
        return "<%s tests=%s>" % (_strclass(self.__class__), self._tests)

    __str__ = __repr__

    def __eq__(self, other):
        if type(self) is not type(other):
            return False
        return self._tests == other._tests

    def __ne__(self, other):
        return not self == other

    # Can't guarantee hash invariant, so flag as unhashable
    __hash__ = None

    def __iter__(self):
        return iter(self._tests)

    def countTestCases(self):
        """Return the sum of countTestCases() over all contained tests."""
        cases = 0
        for test in self._tests:
            cases += test.countTestCases()
        return cases

    def addTest(self, test):
        """Append a single test to the suite.

        Raises TypeError if 'test' is not callable, or if it is a TestCase
        or TestSuite class rather than an instance.
        """
        # sanity checks
        if not hasattr(test, '__call__'):
            raise TypeError("the test to add must be callable")
        if (isinstance(test, (type, types.ClassType)) and
            issubclass(test, (TestCase, TestSuite))):
            raise TypeError("TestCases and TestSuites must be instantiated "
                            "before passing them to addTest()")
        self._tests.append(test)

    def addTests(self, tests):
        """Add every test from the iterable 'tests' via addTest().

        Raises TypeError if 'tests' is a string.
        """
        if isinstance(tests, basestring):
            raise TypeError("tests must be an iterable of tests, not a string")
        for test in tests:
            self.addTest(test)

    def run(self, result):
        """Call each test with 'result', stopping early once
        result.shouldStop is set.  Returns 'result'.
        """
        for test in self._tests:
            if result.shouldStop:
                break
            test(result)
        return result

    def __call__(self, *args, **kwds):
        return self.run(*args, **kwds)

    def debug(self):
        """Run the tests without collecting errors in a TestResult"""
        for test in self._tests:
            test.debug()


class FunctionTestCase(TestCase):
    """A test case that wraps a test function.

    This is useful for slipping pre-existing test functions into the
    unittest framework. Optionally, set-up and tidy-up functions can be
    supplied. As with TestCase, the tidy-up ('tearDown') function will
    always be called if the set-up ('setUp') function ran successfully.
    """

    def __init__(self, testFunc, setUp=None, tearDown=None,
                 description=None):
        TestCase.__init__(self)
        self.__setUpFunc = setUp
        self.__tearDownFunc = tearDown
        self.__testFunc = testFunc
        self.__description = description

    def setUp(self):
        if self.__setUpFunc is not None:
            self.__setUpFunc()

    def tearDown(self):
        if self.__tearDownFunc is not None:
            self.__tearDownFunc()

    def runTest(self):
        self.__testFunc()

    def id(self):
        return self.__testFunc.__name__

    def __eq__(self, other):
        if type(self) is not type(other):
            return False

        return self.__setUpFunc == other.__setUpFunc and \
               self.__tearDownFunc == other.__tearDownFunc and \
               self.__testFunc == other.__testFunc and \
               self.__description == other.__description

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        return hash((type(self), self.__setUpFunc, self.__tearDownFunc,
                                           self.__testFunc, self.__description))

    def __str__(self):
        return "%s (%s)" % (_strclass(self.__class__), self.__testFunc.__name__)

    def __repr__(self):
        return "<%s testFunc=%s>" % (_strclass(self.__class__), self.__testFunc)

    def shortDescription(self):
        """Return the explicit description if one was given, otherwise the
        stripped first line of the wrapped function's docstring, or None.
        """
        if self.__description is not None: return self.__description
        doc = self.__testFunc.__doc__
        return doc and doc.split("\n")[0].strip() or None



##############################################################################
# Locating and loading tests
##############################################################################

class TestLoader:
    """This class is responsible for loading tests according to various
    criteria and returning them wrapped in a TestSuite
    """
    # Method names starting with this prefix are collected as tests.
    testMethodPrefix = 'test'
    # cmp-style function used to order test method names; a false value
    # leaves the names in dir() order.
    sortTestMethodsUsing = cmp
    # Class used to build the suites returned by the load* methods.
    suiteClass = TestSuite

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass"""
        if issubclass(testCaseClass, TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite. Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        return self.suiteClass([testCaseClass(name) for name in testCaseNames])

    def loadTestsFromModule(self, module):
        """Return a suite of all tests cases contained in the given module"""
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if (isinstance(obj, (type, types.ClassType)) and
                issubclass(obj, TestCase)):
                tests.append(self.loadTestsFromTestCase(obj))
        return self.suiteClass(tests)

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises TypeError if the name resolves to something that cannot be
        turned into a test.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix of the name.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy: raise
            parts = parts[1:]
        obj = module
        parent = None
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if type(obj) == types.ModuleType:
            return self.loadTestsFromModule(obj)
        elif (isinstance(obj, (type, types.ClassType)) and
              issubclass(obj, TestCase)):
            return self.loadTestsFromTestCase(obj)
        elif (type(obj) == types.UnboundMethodType and
              isinstance(parent, (type, types.ClassType)) and
              issubclass(parent, TestCase)):
            # Wrap in suiteClass, like the other load* methods, so a
            # customised suite class is honoured here too.
            return self.suiteClass([parent(obj.__name__)])
        elif isinstance(obj, TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, TestSuite):
                return test
            elif isinstance(test, TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = [name for name in dir(testCaseClass) if isTestMethod(name)]
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames



# Shared loader instance with the default prefix, ordering and suite class;
# exported in __all__ and used as TestProgram's default 'testLoader'.
defaultTestLoader = TestLoader()


##############################################################################
# Patches for old functions: these functions should be considered obsolete
##############################################################################

def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a new TestLoader using the given method prefix and sort
    function, and 'suiteClass' too when a true value is supplied.
    """
    configured = TestLoader()
    configured.testMethodPrefix = prefix
    configured.sortTestMethodsUsing = sortUsing
    if suiteClass:
        configured.suiteClass = suiteClass
    return configured

def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Obsolete helper: list the test method names of testCaseClass."""
    loader = _makeLoader(prefix, sortUsing)
    return loader.getTestCaseNames(testCaseClass)

def makeSuite(testCaseClass, prefix='test', sortUsing=cmp, suiteClass=TestSuite):
    """Obsolete helper: build a suite from the tests in testCaseClass."""
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)

def findTestCases(module, prefix='test', sortUsing=cmp, suiteClass=TestSuite):
    """Obsolete helper: build a suite from the TestCase classes in module."""
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)


##############################################################################
# Text UI
##############################################################################

class _WritelnDecorator:
    """Used to decorate file-like objects with a handy 'writeln' method"""
    def __init__(self,stream):
        self.stream = stream

    def __getattr__(self, attr):
        """Delegate any attribute not found on the decorator to the stream."""
        # __getattr__ only runs when normal lookup fails, so being asked
        # for 'stream' means it was never set (instance created without
        # __init__); delegating would recurse without end.
        if attr == 'stream':
            raise AttributeError(attr)
        return getattr(self.stream,attr)

    def writeln(self, arg=None):
        """Write 'arg' (if it is true) followed by a newline."""
        if arg: self.write(arg)
        self.write('\n') # text-mode streams translate to \r\n if needed
class _TextTestResult(TestResult):
    """A test result class that can print formatted text results to a stream.

    Used by TextTestRunner.
    """
    separator1 = '=' * 70
    separator2 = '-' * 70

    def __init__(self, stream, descriptions, verbosity):
        TestResult.__init__(self)
        self.stream = stream
        # verbosity > 1: one line per test; verbosity == 1: one character
        # per test; otherwise nothing is written per test.
        self.showAll = verbosity > 1
        self.dots = verbosity == 1
        self.descriptions = descriptions

    def getDescription(self, test):
        """Return the test's short description when descriptions are
        enabled and one exists, otherwise str(test).
        """
        if self.descriptions:
            return test.shortDescription() or str(test)
        else:
            return str(test)

    def startTest(self, test):
        TestResult.startTest(self, test)
        if self.showAll:
            self.stream.write(self.getDescription(test))
            self.stream.write(" ... ")
            self.stream.flush()

    def addSuccess(self, test):
        TestResult.addSuccess(self, test)
        if self.showAll:
            self.stream.writeln("ok")
        elif self.dots:
            self.stream.write('.')
            self.stream.flush()

    def addError(self, test, err):
        TestResult.addError(self, test, err)
        if self.showAll:
            self.stream.writeln("ERROR")
        elif self.dots:
            self.stream.write('E')
            self.stream.flush()

    def addFailure(self, test, err):
        TestResult.addFailure(self, test, err)
        if self.showAll:
            self.stream.writeln("FAIL")
        elif self.dots:
            self.stream.write('F')
            self.stream.flush()

    def printErrors(self):
        """Write the collected errors, then the collected failures."""
        if self.dots or self.showAll:
            self.stream.writeln()
        self.printErrorList('ERROR', self.errors)
        self.printErrorList('FAIL', self.failures)

    def printErrorList(self, flavour, errors):
        """Write one separator-delimited section per (test, traceback) pair."""
        for test, err in errors:
            self.stream.writeln(self.separator1)
            self.stream.writeln("%s: %s" % (flavour,self.getDescription(test)))
            self.stream.writeln(self.separator2)
            self.stream.writeln("%s" % err)


class TextTestRunner:
    """A test runner class that displays results in textual form.

    It prints out the names of tests as they are run, errors as they
    occur, and a summary of the results at the end of the test run.
    """
    def __init__(self, stream=sys.stderr, descriptions=1, verbosity=1):
        self.stream = _WritelnDecorator(stream)
        self.descriptions = descriptions
        self.verbosity = verbosity

    def _makeResult(self):
        return _TextTestResult(self.stream, self.descriptions, self.verbosity)

    def run(self, test):
        "Run the given test case or test suite."
        result = self._makeResult()
        startTime = time.time()
        test(result)
        stopTime = time.time()
        timeTaken = stopTime - startTime
        result.printErrors()
        self.stream.writeln(result.separator2)
        run = result.testsRun
        self.stream.writeln("Ran %d test%s in %.3fs" %
                            (run, run != 1 and "s" or "", timeTaken))
        self.stream.writeln()
        if not result.wasSuccessful():
            self.stream.write("FAILED (")
            failed = len(result.failures)
            errored = len(result.errors)
            if failed:
                self.stream.write("failures=%d" % failed)
            if errored:
                if failed: self.stream.write(", ")
                self.stream.write("errors=%d" % errored)
            self.stream.writeln(")")
        else:
            self.stream.writeln("OK")
        return result


##############################################################################
# Facilities for running tests from the command line
##############################################################################

class TestProgram:
    """A command-line program that runs a set of tests; this is primarily
       for making test modules conveniently executable.

       Constructing an instance parses 'argv' (sys.argv when None), loads
       the requested tests, runs them and then calls sys.exit() with a
       status reflecting whether the run was successful.
    """
    USAGE = """\
Usage: %(progName)s [options] [test] [...]

Options:
  -h, --help       Show this message
  -v, --verbose    Verbose output
  -q, --quiet      Minimal output

Examples:
  %(progName)s                               - run default set of tests
  %(progName)s MyTestSuite                   - run suite 'MyTestSuite'
  %(progName)s MyTestCase.testSomething      - run MyTestCase.testSomething
  %(progName)s MyTestCase                    - run all 'test*' test methods
                                               in MyTestCase
"""
    def __init__(self, module='__main__', defaultTest=None,
                 argv=None, testRunner=TextTestRunner,
                 testLoader=defaultTestLoader):
        if type(module) == type(''):
            # __import__ returns the top-level package, so walk down the
            # remaining dotted parts to reach the named module.
            self.module = __import__(module)
            for part in module.split('.')[1:]:
                self.module = getattr(self.module, part)
        else:
            self.module = module
        if argv is None:
            argv = sys.argv
        self.verbosity = 1
        self.defaultTest = defaultTest
        self.testRunner = testRunner
        self.testLoader = testLoader
        self.progName = os.path.basename(argv[0])
        self.parseArgs(argv)
        self.runTests()

    def usageExit(self, msg=None):
        """Print 'msg' (if any) and the usage text, then exit with status 2."""
        if msg:
            sys.stdout.write("%s\n" % (msg,))
        sys.stdout.write("%s\n" % (self.USAGE % self.__dict__,))
        sys.exit(2)

    def parseArgs(self, argv):
        """Process command-line options and set self.test.

        With no test names and no defaultTest, all tests in self.module are
        loaded; otherwise the given names (or defaultTest) are loaded via
        createTests().  A getopt error leads to usageExit().
        """
        import getopt
        try:
            options, args = getopt.getopt(argv[1:], 'hHvq',
                                          ['help','verbose','quiet'])
            for opt, value in options:
                if opt in ('-h','-H','--help'):
                    self.usageExit()
                if opt in ('-q','--quiet'):
                    self.verbosity = 0
                if opt in ('-v','--verbose'):
                    self.verbosity = 2
            if len(args) == 0 and self.defaultTest is None:
                self.test = self.testLoader.loadTestsFromModule(self.module)
                return
            if len(args) > 0:
                self.testNames = args
            else:
                self.testNames = (self.defaultTest,)
            self.createTests()
        except getopt.error:
            # sys.exc_info() avoids the version-specific
            # 'except X, name' / 'except X as name' spellings.
            self.usageExit(sys.exc_info()[1])

    def createTests(self):
        self.test = self.testLoader.loadTestsFromNames(self.testNames,
                                                       self.module)

    def runTests(self):
        """Run self.test and exit: status 0 on success, 1 otherwise."""
        if isinstance(self.testRunner, (type, types.ClassType)):
            try:
                testRunner = self.testRunner(verbosity=self.verbosity)
            except TypeError:
                # didn't accept the verbosity argument
                testRunner = self.testRunner()
        else:
            # it is assumed to be a TestRunner instance
            testRunner = self.testRunner
        result = testRunner.run(self.test)
        sys.exit(not result.wasSuccessful())

# 'main' is the exported name (see __all__); calling it constructs a
# TestProgram, which parses arguments, runs the tests and exits.
main = TestProgram


##############################################################################
# Executing this module from the command line
##############################################################################

# module=None: TestProgram stores None as its module, so test names given on
# the command line are resolved by TestLoader.loadTestsFromName() through
# import rather than relative to a module.
if __name__ == "__main__":
    main(module=None)