loader.py 21.7 KB
Newer Older
1 2 3
"""Loading unittests."""

import os
Benjamin Peterson's avatar
Benjamin Peterson committed
4
import re
5
import sys
Benjamin Peterson's avatar
Benjamin Peterson committed
6
import traceback
7
import types
8
import functools
9
import warnings
10 11 12 13 14

from fnmatch import fnmatch

from . import case, suite, util

15
# NOTE(review): nothing in this file reads this flag; presumably other parts
# of the package look for it in a module's globals -- confirm before removing.
__unittest = True
16

17
# What about compiled files ('.pyc' etc.)?  Supporting them would require
# making sure the same tests are not loaded twice, once from the '.py' file
# and once from the '.pyc' file.
Benjamin Peterson's avatar
Benjamin Peterson committed
20 21 22
# File names that can be imported as modules: an identifier-like stem (a
# letter or underscore followed by word characters) ending in '.py'.  The
# match is case-insensitive; discovery applies it with .match() to basenames.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)


23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
class _FailedTest(case.TestCase):
    """Placeholder test case whose only test re-raises a stored exception.

    The test method is not defined on the class; it is synthesised by
    __getattr__ for whatever method name the instance was created with.
    """

    # Class-level default so that reading self._testMethodName inside
    # __getattr__ always resolves normally instead of re-entering
    # __getattr__.
    _testMethodName = None

    def __init__(self, method_name, exception):
        self._exception = exception
        super().__init__(method_name)

    def __getattr__(self, name):
        if name == self._testMethodName:
            def testFailure():
                raise self._exception
            return testFailure
        # Any other missing attribute is delegated to the base class.
        return super().__getattr__(name)


Benjamin Peterson's avatar
Benjamin Peterson committed
38
def _make_failed_import_test(name, suiteClass):
    """Build a failing test for a module that could not be imported.

    Must be called while the import failure is being handled: the message
    embeds traceback.format_exc() for the active exception.

    Returns the (suite, message) pair produced by _make_failed_test; the
    stored exception is an ImportError carrying that message.
    """
    message = 'Failed to import test module: %s\n%s' % (
        name, traceback.format_exc())
    return _make_failed_test(name, ImportError(message), suiteClass, message)
Benjamin Peterson's avatar
Benjamin Peterson committed
42

Benjamin Peterson's avatar
Benjamin Peterson committed
43
def _make_failed_load_tests(name, exception, suiteClass):
    """Build a failing test for a load_tests hook that raised.

    Must be called while *exception* is being handled: the message embeds
    traceback.format_exc() for the active exception.

    Returns the (suite, message) pair produced by _make_failed_test, with
    *exception* stored as the error to re-raise.
    """
    message = 'Failed to call load_tests:\n%s' % (traceback.format_exc(),)
    return _make_failed_test(
        name, exception, suiteClass, message)
Benjamin Peterson's avatar
Benjamin Peterson committed
47

48
def _make_failed_test(methodname, exception, suiteClass, message):
    """Return (suite, message) where the suite holds one failing test.

    The suite is suiteClass wrapping a single _FailedTest named
    *methodname* that raises *exception* when its test method is called.
    *message* is passed through unchanged for the caller to record.
    """
    test = _FailedTest(methodname, exception)
    return suiteClass((test,)), message
Benjamin Peterson's avatar
Benjamin Peterson committed
51

52 53 54 55 56 57 58 59
def _make_skipped_test(methodname, exception, suiteClass):
    """Return a suite holding one test that is skipped with str(exception).

    A throwaway TestCase subclass named "ModuleSkipped" is created with a
    single method called *methodname*, decorated with case.skip.
    """
    def testSkipped(self):
        pass

    skipped_method = case.skip(str(exception))(testSkipped)
    skipped_class = type(
        "ModuleSkipped", (case.TestCase,), {methodname: skipped_method})
    return suiteClass((skipped_class(methodname),))

60 61 62 63 64
def _jython_aware_splitext(path):
    """Return *path* without its extension.

    A trailing '$py.class' (matched case-insensitively) is removed as a
    whole; every other path is handled by os.path.splitext.
    """
    suffix = '$py.class'
    if not path.lower().endswith(suffix):
        return os.path.splitext(path)[0]
    return path[:-len(suffix)]

Benjamin Peterson's avatar
Benjamin Peterson committed
65

66 67 68 69 70 71 72 73 74 75
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite.

    Class attributes (may be overridden per instance or in a subclass):

    testMethodPrefix -- names starting with this prefix are test methods.
    sortTestMethodsUsing -- three-way comparison function used to order
        test method names; a false value disables sorting.
    suiteClass -- callable used to build every suite that is returned.

    Non-fatal loading problems are appended to self.errors as messages and
    are also represented in the returned suite by a test that fails.
    """
    testMethodPrefix = 'test'
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    suiteClass = suite.TestSuite
    # Remembered by discover() so nested discover() calls can omit it.
    _top_level_dir = None

    def __init__(self):
        super().__init__()
        # Messages describing problems encountered while loading.
        self.errors = []
        # Tracks packages which we have called into via load_tests, to
        # avoid infinite re-entrancy.
        self._loading_packages = set()

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all test cases contained in testCaseClass.

        If no method matches the prefix but the class defines runTest, a
        single 'runTest' test is loaded.  Raises TypeError when given a
        TestSuite subclass.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from "
                            "TestSuite. Maybe you meant to derive from "
                            "TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    # XXX After Python 3.5, remove backward compatibility hacks for
    # use_load_tests deprecation via *args and **kws.  See issue 16662.
    def loadTestsFromModule(self, module, *args, pattern=None, **kws):
        """Return a suite of all test cases contained in the given module.

        Every TestCase subclass found via dir(module) is loaded.  If the
        module defines load_tests, it is called with (loader, tests,
        pattern) and its result is returned instead; an exception raised
        by it is recorded in self.errors and turned into a failing test.
        """
        # This method used to take an undocumented and unofficial
        # use_load_tests argument.  For backward compatibility, we still
        # accept the argument (which can also be the first position) but we
        # ignore it and issue a deprecation warning if it's present.
        if len(args) > 0 or 'use_load_tests' in kws:
            warnings.warn('use_load_tests is deprecated and ignored',
                          DeprecationWarning)
            kws.pop('use_load_tests', None)
        if len(args) > 1:
            # Complain about the number of arguments, but don't forget the
            # required `module` argument.
            complaint = len(args) + 1
            raise TypeError('loadTestsFromModule() takes 1 positional argument but {} were given'.format(complaint))
        if len(kws) != 0:
            # Since the keyword arguments are unsorted (see PEP 468), just
            # pick the alphabetically sorted first argument to complain about,
            # if multiple were given.  At least the error message will be
            # predictable.
            complaint = sorted(kws)[0]
            raise TypeError("loadTestsFromModule() got an unexpected keyword argument '{}'".format(complaint))
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if load_tests is not None:
            try:
                return load_tests(self, tests, pattern)
            except Exception as e:
                error_case, error_message = _make_failed_load_tests(
                    module.__name__, e, self.suiteClass)
                self.errors.append(error_message)
                return error_case
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all test cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Import and attribute errors are recorded in self.errors and
        returned as a failing test.  TypeError is raised when the resolved
        object cannot be turned into a test.
        """
        parts = name.split('.')
        error_case, error_message = None, None
        if module is None:
            # Import the longest importable prefix of the dotted name.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module_name = '.'.join(parts_copy)
                    module = __import__(module_name)
                    break
                except ImportError:
                    next_attribute = parts_copy.pop()
                    # Last error so we can give it to the user if needed.
                    error_case, error_message = _make_failed_import_test(
                        next_attribute, self.suiteClass)
                    if not parts_copy:
                        # Even the top level import failed: report that error.
                        self.errors.append(error_message)
                        return error_case
            # __import__ returns the top-level package, so the remaining
            # parts are resolved from it by attribute access.
            parts = parts[1:]
        obj = module
        for part in parts:
            try:
                parent, obj = obj, getattr(obj, part)
            except AttributeError as e:
                # We can't traverse some part of the name.
                if (getattr(obj, '__path__', None) is not None
                    and error_case is not None):
                    # This is a package (it has __path__, per the importlib
                    # docs), and we encountered an error importing
                    # something. We cannot tell the difference between
                    # package.WrongNameTestClass and
                    # package.wrong_module_name so we just report the
                    # ImportError - it is more informative.
                    self.errors.append(error_message)
                    return error_case
                else:
                    # Otherwise, we signal that an AttributeError has occurred.
                    error_case, error_message = _make_failed_test(
                        part, e, self.suiteClass,
                        'Failed to access attribute:\n%s' % (
                            traceback.format_exc(),))
                    self.errors.append(error_message)
                    return error_case

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = parts[-1]
            inst = parent(name)
            # static methods follow a different path
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if callable(obj):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all test cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass.

        A name qualifies when it starts with testMethodPrefix and the
        attribute is callable.  Sorting uses sortTestMethodsUsing and is
        skipped when that attribute is false.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                callable(getattr(testCaseClass, attrname))
        testFnNames = list(filter(isTestMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them and return all
        tests found within them. Only test files that match the pattern will
        be loaded. (Using shell style pattern matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with (loader, tests, pattern) unless
        the package has already had load_tests called from the same discovery
        invocation, in which case the package module object is not scanned for
        tests - this ensures that when a package uses discover to further
        discover child tests that infinite recursion does not happen.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Paths are sorted before being imported to ensure reproducible execution
        order even on filesystems with non-alphabetical ordering like ext3/4.

        start_dir may also be a dotted module name.  ImportError is raised
        when the start directory is not importable, TypeError when a dotted
        name refers to something discovery cannot handle.
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
        self._top_level_dir = top_level_dir

        is_not_importable = False
        is_namespace = False
        tests = []
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                try:
                    start_dir = os.path.abspath(
                       os.path.dirname((the_module.__file__)))
                except AttributeError:
                    # look for namespace packages
                    try:
                        spec = the_module.__spec__
                    except AttributeError:
                        spec = None

                    if spec and spec.loader is None:
                        if spec.submodule_search_locations is not None:
                            is_namespace = True

                            for path in the_module.__path__:
                                if (not set_implicit_top and
                                    not path.startswith(top_level_dir)):
                                    continue
                                self._top_level_dir = \
                                    (path.split(the_module.__name__
                                         .replace(".", os.path.sep))[0])
                                tests.extend(self._find_tests(path,
                                                              pattern,
                                                              namespace=True))
                    elif the_module.__name__ in sys.builtin_module_names:
                        # builtin module
                        raise TypeError('Can not use builtin modules '
                                        'as dotted module names') from None
                    else:
                        raise TypeError(
                            'don\'t know how to discover from {!r}'
                            .format(the_module)) from None

                if set_implicit_top:
                    if not is_namespace:
                        self._top_level_dir = \
                           self._get_directory_containing_module(top_part)
                    # The dotted name was only tentatively used as the top
                    # level directory; take it off sys.path again.
                    sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError('Start directory is not importable: %r' % start_dir)

        if not is_namespace:
            tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory from which module_name is importable.

        The module is looked up in sys.modules.  For a package (file name
        starts with '__init__.py') this is the parent of the package
        directory; for a plain module it is the module's own directory.
        """
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Return the dotted module name for path, relative to the top level.

        Returns '.' when path is the top level directory itself.
        """
        if path == self._top_level_dir:
            return '.'
        path = _jython_aware_splitext(os.path.normpath(path))

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted name and return the module from sys.modules."""
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        """Return whether the file name path matches the shell-style pattern."""
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern, namespace=False):
        """Used by discovery. Yields test suites it loads."""
        # Handle the __init__ in this package
        name = self._get_name_from_path(start_dir)
        # name is '.' when start_dir == top_level_dir (and top_level_dir is by
        # definition not a package).
        if name != '.' and name not in self._loading_packages:
            # name is in self._loading_packages while we have called into
            # loadTestsFromModule with name.
            tests, should_recurse = self._find_test_path(
                start_dir, pattern, namespace)
            if tests is not None:
                yield tests
            if not should_recurse:
                # Either an error occurred, or load_tests was used by the
                # package.
                return
        # Handle the contents.
        paths = sorted(os.listdir(start_dir))
        for path in paths:
            full_path = os.path.join(start_dir, path)
            tests, should_recurse = self._find_test_path(
                full_path, pattern, namespace)
            if tests is not None:
                yield tests
            if should_recurse:
                # we found a package that didn't use load_tests.
                name = self._get_name_from_path(full_path)
                self._loading_packages.add(name)
                try:
                    yield from self._find_tests(full_path, pattern, namespace)
                finally:
                    self._loading_packages.discard(name)

    def _find_test_path(self, full_path, pattern, namespace=False):
        """Used by discovery.

        Loads tests from a single file, or a directory's __init__.py when
        passed the directory.

        Returns a tuple (None_or_tests_from_file, should_recurse).  A path
        that is neither a regular file nor a directory yields (None, False).
        """
        basename = os.path.basename(full_path)
        if os.path.isfile(full_path):
            if not VALID_MODULE_NAME.match(basename):
                # valid Python identifiers only
                return None, False
            if not self._match_path(basename, full_path, pattern):
                return None, False
            # if the test file matches, load it
            name = self._get_name_from_path(full_path)
            try:
                module = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                # Deliberately broad: whatever the import raised is reported
                # as a failed test rather than aborting discovery.
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                # Make sure the module that was imported is the file found
                # on disk and not a same-named module from elsewhere.
                mod_file = os.path.abspath(
                    getattr(module, '__file__', full_path))
                realpath = _jython_aware_splitext(
                    os.path.realpath(mod_file))
                fullpath_noext = _jython_aware_splitext(
                    os.path.realpath(full_path))
                if realpath.lower() != fullpath_noext.lower():
                    module_dir = os.path.dirname(realpath)
                    mod_name = _jython_aware_splitext(
                        os.path.basename(full_path))
                    expected_dir = os.path.dirname(full_path)
                    msg = ("%r module incorrectly imported from %r. Expected "
                           "%r. Is this module globally installed?")
                    raise ImportError(
                        msg % (mod_name, module_dir, expected_dir))
                return self.loadTestsFromModule(module, pattern=pattern), False
        elif os.path.isdir(full_path):
            if (not namespace and
                not os.path.isfile(os.path.join(full_path, '__init__.py'))):
                return None, False

            load_tests = None
            tests = None
            name = self._get_name_from_path(full_path)
            try:
                package = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                # Deliberately broad, see the file branch above.
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                load_tests = getattr(package, 'load_tests', None)
                # Mark this package as being in load_tests (possibly ;))
                self._loading_packages.add(name)
                try:
                    tests = self.loadTestsFromModule(package, pattern=pattern)
                    if load_tests is not None:
                        # loadTestsFromModule(package) has loaded tests for us.
                        return tests, False
                    return tests, True
                finally:
                    self._loading_packages.discard(name)
        else:
            # Neither a regular file nor a directory (e.g. a broken symlink
            # or a special file): nothing to load and nothing to recurse
            # into.  Callers unpack the result, so a tuple must be returned.
            return None, False
482

483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506

# Module-level TestLoader instance created with the default settings.
defaultTestLoader = TestLoader()


def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a new TestLoader configured with the given settings.

    suiteClass replaces the loader's default only when it is truthy.
    """
    configured = TestLoader()
    configured.testMethodPrefix = prefix
    configured.sortTestMethodsUsing = sortUsing
    if suiteClass:
        configured.suiteClass = suiteClass
    return configured

def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp):
    """Return the test method names of testCaseClass.

    Uses a throwaway loader configured with *prefix* and *sortUsing*.
    """
    loader = _makeLoader(prefix, sortUsing)
    return loader.getTestCaseNames(testCaseClass)

def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests in testCaseClass.

    Uses a throwaway loader configured with *prefix*, *sortUsing* and
    *suiteClass*.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)

def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the tests found in module.

    Uses a throwaway loader configured with *prefix*, *sortUsing* and
    *suiteClass*.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)