"""Fixer for import statements.

If spam is being imported from the local directory, this import:
    from spam import eggs
Becomes:
    from .spam import eggs

And this import:
    import spam
Becomes:
    from . import spam
"""

from os.path import dirname, join, exists, pathsep, sep

# Local imports
from .. import fixer_base
from ..fixer_util import FromImport, syms, token


def traverse_imports(names):
    """Yield every module name imported by a plain ``import`` statement.

    Walks the module part of an ``import_name`` node (everything after the
    ``import`` keyword) and yields each imported module as a string, left
    to right, ignoring any ``as`` aliases.

    Args:
        names: a NAME leaf, ``dotted_name``, ``dotted_as_name`` or
            ``dotted_as_names`` node.

    Yields:
        Module names such as ``"spam"`` or ``"ham.eggs"``.

    Raises:
        AssertionError: if a node of any other type is encountered.
    """
    pending = [names]
    while pending:
        node = pending.pop()
        if node.type == token.NAME:
            yield node.value
        elif node.type == syms.dotted_name:
            # The children are the NAME and '.' leaves of the dotted name,
            # so concatenating their values rebuilds the full module name.
            yield "".join(ch.value for ch in node.children)
        elif node.type == syms.dotted_as_name:
            # 'module as alias': only the module (first child) matters.
            pending.append(node.children[0])
        elif node.type == syms.dotted_as_names:
            # Children look like a ',' b ',' c.  Taking every other child
            # from the end skips the commas and reverses the order, so that
            # popping from the stack visits the names left to right.
            pending.extend(node.children[::-2])
        else:
            raise AssertionError("unknown node type: %r" % (node.type,))

37

38
class FixImport(fixer_base.BaseFix):
    """Rewrite implicit sibling imports as explicit relative imports.

    ``from spam import eggs`` becomes ``from .spam import eggs`` and
    ``import spam`` becomes ``from . import spam`` when ``spam`` appears to
    live beside the file being fixed (see ``probably_a_local_import``).
    """

    PATTERN = """
    import_from< 'from' imp=any 'import' ['('] any [')'] >
    |
    import_name< 'import' imp=any >
    """

    def transform(self, node, results):
        """Make a matched import statement relative if it names a sibling.

        Args:
            node: the matched ``import_from`` or ``import_name`` node.
            results: pattern match results; ``results['imp']`` is the
                module part of the statement.

        Returns:
            None when the statement is left alone or edited in place
            (``from ... import`` statements), or a new
            ``from . import ...`` node that replaces ``node`` (plain
            ``import`` statements whose modules are all local).
        """
        imp = results['imp']

        if node.type == syms.import_from:
            # imp is the module between 'from' and 'import': either a bare
            # NAME leaf ('from ham import x') or a dotted_name node
            # ('from ham.eggs import x').  Descend to its first leaf, which
            # holds the top-level name.
            while not hasattr(imp, 'value'):
                imp = imp.children[0]
            if self.probably_a_local_import(imp.value):
                imp.value = "." + imp.value
                imp.changed()
            # The leaf was modified in place, so report "no replacement".
            # Returning ``node`` itself would ask the driver to replace the
            # node with itself.
            return None

        have_local = False
        have_absolute = False
        for mod_name in traverse_imports(imp):
            if self.probably_a_local_import(mod_name):
                have_local = True
            else:
                have_absolute = True
        if have_absolute:
            if have_local:
                # We won't handle both sibling and absolute imports in the
                # same statement at the moment.
                self.warning(node, "absolute and local imports together")
            return None

        new = FromImport('.', [imp])
        # Keep the original leading whitespace/comments of the statement.
        new.set_prefix(node.get_prefix())
        return new

    def probably_a_local_import(self, imp_name):
        """Guess whether ``imp_name`` refers to a module beside this file.

        Args:
            imp_name: a module name as written in the import, possibly
                dotted (``"spam.eggs"``); only the first component is
                examined.

        Returns:
            True if the directory of ``self.filename`` contains an
            ``__init__.py`` and also a matching module file (``.py``,
            ``.pyc``, ``.so``, ``.sl``, ``.pyd``) or subdirectory; False
            otherwise, including for names that are already relative.
        """
        if imp_name.startswith('.'):
            # Already relative ('from . import x', or a leaf this fixer has
            # rewritten).  Treating it as local would prepend another dot
            # and climb one package level too far.
            return False
        imp_name = imp_name.split('.', 1)[0]
        base_path = dirname(self.filename)
        base_path = join(base_path, imp_name)
        # If there is no __init__.py next to the file it's not in a package
        # so can't be a relative import.
        if not exists(join(dirname(base_path), '__init__.py')):
            return False
        # A trailing directory separator (``sep``, not the PATH-list
        # separator ``pathsep``) only resolves when base_path is a
        # directory, i.e. a subpackage.
        for ext in ['.py', sep, '.pyc', '.so', '.sl', '.pyd']:
            if exists(base_path + ext):
                return True
        return False