loader.py 22.1 KB
Newer Older
1 2 3
"""Loading unittests."""

import os
Benjamin Peterson's avatar
Benjamin Peterson committed
4
import re
5
import sys
Benjamin Peterson's avatar
Benjamin Peterson committed
6
import traceback
7
import types
8
import functools
9
import warnings
10

11
from fnmatch import fnmatch, fnmatchcase
12 13 14

from . import case, suite, util

# NOTE(review): presumably the marker unittest checks to treat frames from this
# module as framework-internal (hidden from failure tracebacks) -- confirm
# against unittest.result.
__unittest = True

# Discovery only accepts plain '.py' source files whose stem is an identifier.
# What about .pyc (etc)?  Accepting them too would mean guarding against
# loading the same tests twice, once from '.py' *and* once from '.pyc'.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', flags=re.IGNORECASE)


class _FailedTest(case.TestCase):
    """Placeholder test case that re-raises a stored exception when run.

    The loader uses it to surface loading problems (failed imports, failing
    ``load_tests`` hooks, names that cannot be resolved) as ordinary test
    errors instead of aborting the whole run.
    """
    # Class-level default so __getattr__ can compare against it even before
    # an instance value has been assigned.
    _testMethodName = None

    def __init__(self, method_name, exception):
        # Stored before TestCase.__init__ runs, since construction may already
        # look up the (synthesized) test method via __getattr__.
        self._exception = exception
        super(_FailedTest, self).__init__(method_name)

    def __getattr__(self, name):
        """Synthesize the test method; any other missing name is an error."""
        if name != self._testMethodName:
            # Raise directly: neither TestCase nor object defines __getattr__,
            # so delegating via super() only produced a confusing
            # "'super' object has no attribute '__getattr__'" message.
            raise AttributeError('%r object has no attribute %r'
                                 % (type(self).__name__, name))

        def testFailure():
            raise self._exception
        return testFailure


def _make_failed_import_test(name, suiteClass):
    """Wrap the import failure currently being handled as a failing test.

    Called from ``except`` clauses: the active traceback is rendered into
    the message, and the synthesized test raises an ``ImportError`` carrying
    that same text when it runs.

    Returns a ``(suite, message)`` pair, as produced by ``_make_failed_test``.
    """
    details = traceback.format_exc()
    message = 'Failed to import test module: %s\n%s' % (name, details)
    failure = ImportError(message)
    return _make_failed_test(name, failure, suiteClass, message)


def _make_failed_load_tests(name, exception, suiteClass):
    """Report a module whose ``load_tests`` hook raised *exception*.

    Called from the ``except`` clause around the hook invocation, so the
    active traceback is embedded in the returned message.  The synthesized
    test re-raises the original *exception* object.

    Returns a ``(suite, message)`` pair.
    """
    details = traceback.format_exc()
    message = 'Failed to call load_tests:\n%s' % (details,)
    return _make_failed_test(name, exception, suiteClass, message)


def _make_failed_test(methodname, exception, suiteClass, message):
    """Build a one-test suite whose only test raises *exception*.

    :param methodname: name given to the synthesized test method.
    :param exception: exception instance re-raised when the test runs.
    :param suiteClass: callable that wraps the test, e.g. ``TestSuite``.
    :param message: human-readable description, passed through unchanged.
    :returns: ``(suite, message)``.
    """
    failing = _FailedTest(methodname, exception)
    wrapped = suiteClass((failing,))
    return wrapped, message


def _make_skipped_test(methodname, exception, suiteClass):
    """Build a suite holding a single test that is skipped.

    The skip reason is ``str(exception)``.  The loader uses this when
    importing a test module or package raises ``SkipTest``.
    """
    reason = str(exception)

    @case.skip(reason)
    def testSkipped(self):
        pass

    skipped_cls = type("ModuleSkipped", (case.TestCase,), {methodname: testSkipped})
    return suiteClass((skipped_cls(methodname),))


def _jython_aware_splitext(path):
    """Return *path* without its extension.

    A trailing '$py.class' (matched case-insensitively) is treated as a
    single extension.
    """
    suffix = '$py.class'
    if path.lower().endswith(suffix):
        return path[:-len(suffix)]
    root, _ext = os.path.splitext(path)
    return root


class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    testMethodPrefix = 'test'
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
73
    testNamePatterns = None
74 75 76
    suiteClass = suite.TestSuite
    _top_level_dir = None

77 78 79
    def __init__(self):
        super(TestLoader, self).__init__()
        self.errors = []
80 81 82
        # Tracks packages which we have called into via load_tests, to
        # avoid infinite re-entrancy.
        self._loading_packages = set()
83

84
    def loadTestsFromTestCase(self, testCaseClass):
85
        """Return a suite of all test cases contained in testCaseClass"""
86
        if issubclass(testCaseClass, suite.TestSuite):
87 88 89
            raise TypeError("Test cases should not be derived from "
                            "TestSuite. Maybe you meant to derive from "
                            "TestCase?")
90 91 92 93 94 95
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

96 97 98
    # XXX After Python 3.5, remove backward compatibility hacks for
    # use_load_tests deprecation via *args and **kws.  See issue 16662.
    def loadTestsFromModule(self, module, *args, pattern=None, **kws):
99
        """Return a suite of all test cases contained in the given module"""
100 101 102 103
        # This method used to take an undocumented and unofficial
        # use_load_tests argument.  For backward compatibility, we still
        # accept the argument (which can also be the first position) but we
        # ignore it and issue a deprecation warning if it's present.
104
        if len(args) > 0 or 'use_load_tests' in kws:
105 106 107 108
            warnings.warn('use_load_tests is deprecated and ignored',
                          DeprecationWarning)
            kws.pop('use_load_tests', None)
        if len(args) > 1:
109 110 111 112
            # Complain about the number of arguments, but don't forget the
            # required `module` argument.
            complaint = len(args) + 1
            raise TypeError('loadTestsFromModule() takes 1 positional argument but {} were given'.format(complaint))
113 114 115 116 117 118 119
        if len(kws) != 0:
            # Since the keyword arguments are unsorted (see PEP 468), just
            # pick the alphabetically sorted first argument to complain about,
            # if multiple were given.  At least the error message will be
            # predictable.
            complaint = sorted(kws)[0]
            raise TypeError("loadTestsFromModule() got an unexpected keyword argument '{}'".format(complaint))
120 121 122 123 124 125 126
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
127
        tests = self.suiteClass(tests)
128
        if load_tests is not None:
Benjamin Peterson's avatar
Benjamin Peterson committed
129
            try:
130
                return load_tests(self, tests, pattern)
Benjamin Peterson's avatar
Benjamin Peterson committed
131
            except Exception as e:
132 133 134 135
                error_case, error_message = _make_failed_load_tests(
                    module.__name__, e, self.suiteClass)
                self.errors.append(error_message)
                return error_case
136
        return tests
137 138

    def loadTestsFromName(self, name, module=None):
139
        """Return a suite of all test cases given a string specifier.
140 141 142 143 144 145 146 147

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.
        """
        parts = name.split('.')
148
        error_case, error_message = None, None
149 150 151 152
        if module is None:
            parts_copy = parts[:]
            while parts_copy:
                try:
153 154
                    module_name = '.'.join(parts_copy)
                    module = __import__(module_name)
155 156
                    break
                except ImportError:
157 158 159 160
                    next_attribute = parts_copy.pop()
                    # Last error so we can give it to the user if needed.
                    error_case, error_message = _make_failed_import_test(
                        next_attribute, self.suiteClass)
161
                    if not parts_copy:
162 163 164
                        # Even the top level import failed: report that error.
                        self.errors.append(error_message)
                        return error_case
165 166 167
            parts = parts[1:]
        obj = module
        for part in parts:
168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
            try:
                parent, obj = obj, getattr(obj, part)
            except AttributeError as e:
                # We can't traverse some part of the name.
                if (getattr(obj, '__path__', None) is not None
                    and error_case is not None):
                    # This is a package (no __path__ per importlib docs), and we
                    # encountered an error importing something. We cannot tell
                    # the difference between package.WrongNameTestClass and
                    # package.wrong_module_name so we just report the
                    # ImportError - it is more informative.
                    self.errors.append(error_message)
                    return error_case
                else:
                    # Otherwise, we signal that an AttributeError has occurred.
                    error_case, error_message = _make_failed_test(
184
                        part, e, self.suiteClass,
185 186 187 188
                        'Failed to access attribute:\n%s' % (
                            traceback.format_exc(),))
                    self.errors.append(error_message)
                    return error_case
189 190 191 192 193 194 195 196

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
197
            name = parts[-1]
198 199 200
            inst = parent(name)
            # static methods follow a different path
            if not isinstance(getattr(inst, name), types.FunctionType):
Benjamin Peterson's avatar
Benjamin Peterson committed
201
                return self.suiteClass([inst])
202 203
        elif isinstance(obj, suite.TestSuite):
            return obj
204
        if callable(obj):
205 206 207 208
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
Benjamin Peterson's avatar
Benjamin Peterson committed
209
                return self.suiteClass([test])
210 211 212 213 214 215 216
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
217
        """Return a suite of all test cases found using the given sequence
218 219 220 221 222 223 224 225
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass
        """
226
        def shouldIncludeMethod(attrname):
227 228
            if not attrname.startswith(self.testMethodPrefix):
                return False
229
            testFunc = getattr(testCaseClass, attrname)
230
            if not callable(testFunc):
231 232 233 234 235
                return False
            fullName = '%s.%s' % (testCaseClass.__module__, testFunc.__qualname__)
            return self.testNamePatterns is None or \
                any(fnmatchcase(fullName, pattern) for pattern in self.testNamePatterns)
        testFnNames = list(filter(shouldIncludeMethod, dir(testCaseClass)))
236
        if self.sortTestMethodsUsing:
237
            testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
238 239 240 241
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
242 243 244
        directory, recursing into subdirectories to find them and return all
        tests found within them. Only test files that match the pattern will
        be loaded. (Using shell style pattern matching.)
245 246 247 248 249 250 251

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
252 253 254 255 256
        this exists then it will be called with (loader, tests, pattern) unless
        the package has already had load_tests called from the same discovery
        invocation, in which case the package module object is not scanned for
        tests - this ensures that when a package uses discover to further
        discover child tests that infinite recursion does not happen.
257

258
        If load_tests exists then discovery does *not* recurse into the package,
259 260 261 262 263
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().
264 265 266

        Paths are sorted before being imported to ensure reproducible execution
        order even on filesystems with non-alphabetical ordering like ext3/4.
267
        """
Benjamin Peterson's avatar
Benjamin Peterson committed
268
        set_implicit_top = False
269 270 271 272
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
Benjamin Peterson's avatar
Benjamin Peterson committed
273
            set_implicit_top = True
274 275
            top_level_dir = start_dir

Benjamin Peterson's avatar
Benjamin Peterson committed
276
        top_level_dir = os.path.abspath(top_level_dir)
277 278 279

        if not top_level_dir in sys.path:
            # all test modules must be importable from the top level directory
280 281 282 283
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
284 285
        self._top_level_dir = top_level_dir

Benjamin Peterson's avatar
Benjamin Peterson committed
286
        is_not_importable = False
287 288
        is_namespace = False
        tests = []
Benjamin Peterson's avatar
Benjamin Peterson committed
289 290 291 292 293 294 295 296 297 298 299 300 301
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334
                try:
                    start_dir = os.path.abspath(
                       os.path.dirname((the_module.__file__)))
                except AttributeError:
                    # look for namespace packages
                    try:
                        spec = the_module.__spec__
                    except AttributeError:
                        spec = None

                    if spec and spec.loader is None:
                        if spec.submodule_search_locations is not None:
                            is_namespace = True

                            for path in the_module.__path__:
                                if (not set_implicit_top and
                                    not path.startswith(top_level_dir)):
                                    continue
                                self._top_level_dir = \
                                    (path.split(the_module.__name__
                                         .replace(".", os.path.sep))[0])
                                tests.extend(self._find_tests(path,
                                                              pattern,
                                                              namespace=True))
                    elif the_module.__name__ in sys.builtin_module_names:
                        # builtin module
                        raise TypeError('Can not use builtin modules '
                                        'as dotted module names') from None
                    else:
                        raise TypeError(
                            'don\'t know how to discover from {!r}'
                            .format(the_module)) from None

Benjamin Peterson's avatar
Benjamin Peterson committed
335
                if set_implicit_top:
336 337 338 339 340 341
                    if not is_namespace:
                        self._top_level_dir = \
                           self._get_directory_containing_module(top_part)
                        sys.path.remove(top_level_dir)
                    else:
                        sys.path.remove(top_level_dir)
Benjamin Peterson's avatar
Benjamin Peterson committed
342 343

        if is_not_importable:
344 345
            raise ImportError('Start directory is not importable: %r' % start_dir)

346 347
        if not is_namespace:
            tests = list(self._find_tests(start_dir, pattern))
348 349
        return self.suiteClass(tests)

Benjamin Peterson's avatar
Benjamin Peterson committed
350 351 352 353 354 355 356 357 358 359 360 361
    def _get_directory_containing_module(self, module_name):
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

Benjamin Peterson's avatar
Benjamin Peterson committed
362
    def _get_name_from_path(self, path):
363 364
        if path == self._top_level_dir:
            return '.'
365
        path = _jython_aware_splitext(os.path.normpath(path))
366

Benjamin Peterson's avatar
Benjamin Peterson committed
367 368 369 370 371 372
        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name
373

Benjamin Peterson's avatar
Benjamin Peterson committed
374
    def _get_module_from_name(self, name):
375 376 377
        __import__(name)
        return sys.modules[name]

378 379 380 381
    def _match_path(self, path, full_path, pattern):
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

382
    def _find_tests(self, start_dir, pattern, namespace=False):
383
        """Used by discovery. Yields test suites it loads."""
384 385 386 387 388 389 390 391 392 393 394 395
        # Handle the __init__ in this package
        name = self._get_name_from_path(start_dir)
        # name is '.' when start_dir == top_level_dir (and top_level_dir is by
        # definition not a package).
        if name != '.' and name not in self._loading_packages:
            # name is in self._loading_packages while we have called into
            # loadTestsFromModule with name.
            tests, should_recurse = self._find_test_path(
                start_dir, pattern, namespace)
            if tests is not None:
                yield tests
            if not should_recurse:
396
                # Either an error occurred, or load_tests was used by the
397 398 399
                # package.
                return
        # Handle the contents.
400
        paths = sorted(os.listdir(start_dir))
401 402
        for path in paths:
            full_path = os.path.join(start_dir, path)
403 404 405 406 407 408
            tests, should_recurse = self._find_test_path(
                full_path, pattern, namespace)
            if tests is not None:
                yield tests
            if should_recurse:
                # we found a package that didn't use load_tests.
409
                name = self._get_name_from_path(full_path)
410
                self._loading_packages.add(name)
411
                try:
412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479
                    yield from self._find_tests(full_path, pattern, namespace)
                finally:
                    self._loading_packages.discard(name)

    def _find_test_path(self, full_path, pattern, namespace=False):
        """Used by discovery.

        Loads tests from a single file, or a directories' __init__.py when
        passed the directory.

        Returns a tuple (None_or_tests_from_file, should_recurse).
        """
        basename = os.path.basename(full_path)
        if os.path.isfile(full_path):
            if not VALID_MODULE_NAME.match(basename):
                # valid Python identifiers only
                return None, False
            if not self._match_path(basename, full_path, pattern):
                return None, False
            # if the test file matches, load it
            name = self._get_name_from_path(full_path)
            try:
                module = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                mod_file = os.path.abspath(
                    getattr(module, '__file__', full_path))
                realpath = _jython_aware_splitext(
                    os.path.realpath(mod_file))
                fullpath_noext = _jython_aware_splitext(
                    os.path.realpath(full_path))
                if realpath.lower() != fullpath_noext.lower():
                    module_dir = os.path.dirname(realpath)
                    mod_name = _jython_aware_splitext(
                        os.path.basename(full_path))
                    expected_dir = os.path.dirname(full_path)
                    msg = ("%r module incorrectly imported from %r. Expected "
                           "%r. Is this module globally installed?")
                    raise ImportError(
                        msg % (mod_name, module_dir, expected_dir))
                return self.loadTestsFromModule(module, pattern=pattern), False
        elif os.path.isdir(full_path):
            if (not namespace and
                not os.path.isfile(os.path.join(full_path, '__init__.py'))):
                return None, False

            load_tests = None
            tests = None
            name = self._get_name_from_path(full_path)
            try:
                package = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                load_tests = getattr(package, 'load_tests', None)
                # Mark this package as being in load_tests (possibly ;))
                self._loading_packages.add(name)
480 481 482
                try:
                    tests = self.loadTestsFromModule(package, pattern=pattern)
                    if load_tests is not None:
483 484 485 486 487
                        # loadTestsFromModule(package) has loaded tests for us.
                        return tests, False
                    return tests, True
                finally:
                    self._loading_packages.discard(name)
488 489
        else:
            return None, False
490

491 492 493 494

# A ready-made TestLoader using the class defaults ('test' prefix,
# util.three_way_cmp ordering, suite.TestSuite wrapping).
defaultTestLoader = TestLoader()


def _makeLoader(prefix, sortUsing, suiteClass=None, testNamePatterns=None):
    """Return a fresh TestLoader configured with the given settings.

    *suiteClass* overrides the loader's default only when it is truthy.
    """
    loader = TestLoader()
    loader.testMethodPrefix = prefix
    loader.sortTestMethodsUsing = sortUsing
    loader.testNamePatterns = testNamePatterns
    if suiteClass:
        loader.suiteClass = suiteClass
    return loader


def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp, testNamePatterns=None):
    """Return the sorted test method names of *testCaseClass*.

    Builds a throwaway loader with the given *prefix*, comparison function
    and optional name patterns, and delegates to its getTestCaseNames().
    """
    loader = _makeLoader(prefix, sortUsing, testNamePatterns=testNamePatterns)
    return loader.getTestCaseNames(testCaseClass)


def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests in *testCaseClass* named with *prefix*."""
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)


def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the tests found in *module* via loadTestsFromModule()."""
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)