# Copyright 2006 Google, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.

"""Refactoring framework.

Used as a main program, this can refactor any number of files and/or
recursively descend down directories.  Imported as a module, this
provides infrastructure to write your own refactoring tool.
"""

__author__ = "Guido van Rossum <guido@python.org>"


# Python imports
import io
import os
import sys
import logging
import operator
import collections
from itertools import chain

# Local imports
from .pgen2 import driver, tokenize, token
from .fixer_util import find_root
from . import pytree, pygram
from . import btm_matcher as bm

def get_all_fix_names(fixer_pkg, remove_prefix=True):
    """Return a sorted list of all available fix names in the given package."""
    pkg = __import__(fixer_pkg, [], [], ["*"])
    fixer_dir = os.path.dirname(pkg.__file__)
    names = []
    for module_name in sorted(os.listdir(fixer_dir)):
        # Fixer modules are named fix_<name>.py by convention.
        if not (module_name.startswith("fix_") and module_name.endswith(".py")):
            continue
        stem = module_name[:-3]  # drop the ".py" extension
        names.append(stem[4:] if remove_prefix else stem)
    return names

class _EveryNode(Exception):
    # Internal control-flow exception: raised by _get_head_types() when a
    # pattern has no specific head type and so can match any node.
    pass


def _get_head_types(pat):
    """ Accepts a pytree Pattern Node and returns a set
        of the pattern types which will match first. """
    if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
        # Node/Leaf patterns either have both a type and content or
        # neither; a typeless pattern can match any node at all.
        if pat.type is None:
            raise _EveryNode
        return {pat.type}

    if isinstance(pat, pytree.NegatedPattern):
        if pat.content:
            return _get_head_types(pat.content)
        raise _EveryNode  # a bare negation has no type to key on

    if isinstance(pat, pytree.WildcardPattern):
        # Union of the head types of every subpattern of each alternative.
        types = set()
        for alternative in pat.content:
            for subpattern in alternative:
                types.update(_get_head_types(subpattern))
        return types

    raise Exception("Oh no! I don't understand pattern %s" %(pat))
def _get_headnode_dict(fixer_list):
    """ Accepts a list of fixers and returns a dictionary
        of head node type --> fixer list.  """
    head_nodes = collections.defaultdict(list)
    every = []  # fixers that must be offered every node type
    for fixer in fixer_list:
        if fixer.pattern:
            try:
                heads = _get_head_types(fixer.pattern)
            except _EveryNode:
                every.append(fixer)
            else:
                for node_type in heads:
                    head_nodes[node_type].append(fixer)
        elif fixer._accept_type is not None:
            # Pattern-less fixer that declared an accept type explicitly.
            head_nodes[fixer._accept_type].append(fixer)
        else:
            every.append(fixer)
    # The "match anything" fixers are appended under every known symbol
    # and token type.
    for node_type in chain(pygram.python_grammar.symbol2number.values(),
                           pygram.python_grammar.tokens):
        head_nodes[node_type].extend(every)
    return dict(head_nodes)
def get_fixers_from_package(pkg_name):
    """
    Return the fully qualified names for fixers in the package pkg_name.
    """
    fix_names = get_all_fix_names(pkg_name, False)
    return [pkg_name + "." + fix_name for fix_name in fix_names]
def _identity(obj):
    return obj

110

111
def _detect_future_features(source):
112 113 114 115 116
    have_docstring = False
    gen = tokenize.generate_tokens(io.StringIO(source).readline)
    def advance():
        tok = next(gen)
        return tok[0], tok[1]
117
    ignore = frozenset({token.NEWLINE, tokenize.NL, token.COMMENT})
118
    features = set()
119 120 121 122 123 124 125 126 127
    try:
        while True:
            tp, value = advance()
            if tp in ignore:
                continue
            elif tp == token.STRING:
                if have_docstring:
                    break
                have_docstring = True
128 129
            elif tp == token.NAME and value == "from":
                tp, value = advance()
130
                if tp != token.NAME or value != "__future__":
131 132
                    break
                tp, value = advance()
133
                if tp != token.NAME or value != "import":
134 135 136
                    break
                tp, value = advance()
                if tp == token.OP and value == "(":
137
                    tp, value = advance()
138
                while tp == token.NAME:
139
                    features.add(value)
140
                    tp, value = advance()
141
                    if tp != token.OP or value != ",":
142 143 144 145 146 147
                        break
                    tp, value = advance()
            else:
                break
    except StopIteration:
        pass
148
    return frozenset(features)
149 150


151 152 153 154
class FixerError(Exception):
    """A fixer could not be loaded."""
class RefactoringTool(object):

    # Options applied when the caller does not override them.
    _default_options = {"print_function" : False,
                        "write_unchanged_files" : False}

    CLASS_PREFIX = "Fix"  # prefix of fixer class names
    FILE_PREFIX = "fix_"  # prefix of modules that contain a fixer
    def __init__(self, fixer_names, options=None, explicit=None):
        """Initializer.

        Args:
            fixer_names: a list of fixers to import
            options: a dict with configuration.
            explicit: a list of fixers to run even if they are explicit.
        """
        self.fixers = fixer_names
        self.explicit = explicit or []

        # Layer user options over the class defaults.
        self.options = self._default_options.copy()
        if options is not None:
            self.options.update(options)

        if self.options["print_function"]:
            self.grammar = pygram.python_grammar_no_print_statement
        else:
            self.grammar = pygram.python_grammar

        # When this is True, the refactor*() methods will call write_file()
        # even for files that were not changed during refactoring -- but
        # only if the refactor method's write parameter was True.
        self.write_unchanged_files = self.options.get("write_unchanged_files")

        self.errors = []
        self.logger = logging.getLogger("RefactoringTool")
        self.fixer_log = []
        self.wrote = False
        self.driver = driver.Driver(self.grammar,
                                    convert=pytree.convert,
                                    logger=self.logger)
        self.pre_order, self.post_order = self.get_fixers()

        self.files = []  # List of files that were or should be modified

        # Fixers compatible with the bottom-up matcher are registered with
        # it; the rest fall back to traditional per-node traversal.
        self.BM = bm.BottomMatcher()
        self.bmi_pre_order = []   # Bottom Matcher incompatible fixers
        self.bmi_post_order = []

        for fixer in chain(self.post_order, self.pre_order):
            if fixer.BM_compatible:
                self.BM.add_fixer(fixer)
            elif fixer in self.pre_order:
                self.bmi_pre_order.append(fixer)
            elif fixer in self.post_order:
                self.bmi_post_order.append(fixer)

        self.bmi_pre_order_heads = _get_headnode_dict(self.bmi_pre_order)
        self.bmi_post_order_heads = _get_headnode_dict(self.bmi_post_order)
    def get_fixers(self):
        """Inspects the options to load the requested patterns and handlers.

        Returns:
          (pre_order, post_order), where pre_order is the list of fixers that
          want a pre-order AST traversal, and post_order is the list that want
          post-order traversal.
        """
        pre_order_fixers = []
        post_order_fixers = []
        for fix_mod_path in self.fixers:
            mod = __import__(fix_mod_path, {}, {}, ["*"])
            fix_name = fix_mod_path.rsplit(".", 1)[-1]
            if fix_name.startswith(self.FILE_PREFIX):
                fix_name = fix_name[len(self.FILE_PREFIX):]
            # fix_foo_bar -> FixFooBar, the conventional class name.
            class_name = self.CLASS_PREFIX + "".join(
                [part.title() for part in fix_name.split("_")])
            try:
                fix_class = getattr(mod, class_name)
            except AttributeError:
                raise FixerError("Can't find %s.%s" % (fix_name, class_name)) from None
            fixer = fix_class(self.options, self.fixer_log)
            if (fixer.explicit and self.explicit is not True and
                    fix_mod_path not in self.explicit):
                # Explicit fixers only run when requested by name.
                self.log_message("Skipping optional fixer: %s", fix_name)
                continue

            self.log_debug("Adding transformation: %s", fix_name)
            if fixer.order == "pre":
                pre_order_fixers.append(fixer)
            elif fixer.order == "post":
                post_order_fixers.append(fixer)
            else:
                raise FixerError("Illegal fixer order: %r" % fixer.order)

        # Within each traversal order, fixers run by their run_order rank.
        order_key = operator.attrgetter("run_order")
        pre_order_fixers.sort(key=order_key)
        post_order_fixers.sort(key=order_key)
        return (pre_order_fixers, post_order_fixers)

    def log_error(self, msg, *args, **kwds):
        """Called when an error occurs.

        The default implementation re-raises the exception currently
        being handled; subclasses may record it instead.
        """
        raise

    def log_message(self, msg, *args):
        """Hook to log a message."""
        self.logger.info(msg % args if args else msg)
    def log_debug(self, msg, *args):
        """Hook to log a debug-level message."""
        self.logger.debug(msg % args if args else msg)
    def print_output(self, old_text, new_text, filename, equal):
        """Hook called with the old version, new version, and filename of a
        refactored file; the default implementation does nothing."""
        pass
    def refactor(self, items, write=False, doctests_only=False):
        """Refactor a list of files and directories."""
        for path in items:
            if os.path.isdir(path):
                # Directories are walked recursively.
                self.refactor_dir(path, write, doctests_only)
            else:
                self.refactor_file(path, write, doctests_only)
    def refactor_dir(self, dir_name, write=False, doctests_only=False):
        """Descends down a directory and refactor every Python file found.

        Python files are assumed to have a .py extension.

        Files and subdirectories starting with '.' are skipped.
        """
        py_ext = os.extsep + "py"
        for dirpath, dirnames, filenames in os.walk(dir_name):
            self.log_debug("Descending into %s", dirpath)
            # Sort for deterministic traversal order.
            dirnames.sort()
            filenames.sort()
            for name in filenames:
                if (not name.startswith(".") and
                    os.path.splitext(name)[1] == py_ext):
                    self.refactor_file(os.path.join(dirpath, name),
                                       write, doctests_only)
            # Modify dirnames in-place so os.walk() skips dot-directories.
            dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")]
    def _read_python_source(self, filename):
        """
        Do our best to decode a Python source file correctly.

        Returns (source_text, encoding), or (None, None) if the file
        could not be opened (the error is reported via log_error()).
        """
        try:
            f = open(filename, "rb")
        except OSError as err:
            self.log_error("Can't open %s: %s", filename, err)
            return None, None
        # Use a context manager instead of try/finally to guarantee the
        # binary handle is closed after sniffing the encoding declaration.
        with f:
            encoding = tokenize.detect_encoding(f.readline)[0]
        # Re-open in text mode with the detected encoding; newline='' keeps
        # the original line endings intact for faithful round-tripping.
        with io.open(filename, "r", encoding=encoding, newline='') as f:
            return f.read(), encoding

    def refactor_file(self, filename, write=False, doctests_only=False):
        """Refactors a file."""
        source, encoding = self._read_python_source(filename)
        if source is None:
            # Reading the file failed; the error was already logged.
            return
        source += "\n"  # Silence certain parse errors
        if doctests_only:
            self.log_debug("Refactoring doctests in %s", filename)
            output = self.refactor_docstring(source, filename)
            if self.write_unchanged_files or output != source:
                self.processed_file(output, filename, source, write, encoding)
            else:
                self.log_debug("No doctest changes in %s", filename)
        else:
            tree = self.refactor_string(source, filename)
            if self.write_unchanged_files or (tree and tree.was_changed):
                # The [:-1] strips off the \n we added above.
                self.processed_file(str(tree)[:-1], filename,
                                    write=write, encoding=encoding)
            else:
                self.log_debug("No changes in %s", filename)

    def refactor_string(self, data, name):
        """Refactor a given input string.

        Args:
            data: a string holding the code to be refactored.
            name: a human-readable name for use in error/log messages.

        Returns:
            An AST corresponding to the refactored input stream; None if
            there were errors during the parse.
        """
        features = _detect_future_features(data)
        if "print_function" in features:
            # Temporarily switch the driver to the print()-as-function grammar.
            self.driver.grammar = pygram.python_grammar_no_print_statement
        try:
            tree = self.driver.parse_string(data)
        except Exception as err:
            self.log_error("Can't parse %s: %s: %s",
                           name, err.__class__.__name__, err)
            return
        finally:
            # Always restore the configured grammar.
            self.driver.grammar = self.grammar
        tree.future_features = features
        self.log_debug("Refactoring %s", name)
        self.refactor_tree(tree, name)
        return tree
    def refactor_stdin(self, doctests_only=False):
        """Refactor code read from standard input."""
        source = sys.stdin.read()
        if doctests_only:
            self.log_debug("Refactoring doctests in stdin")
            output = self.refactor_docstring(source, "<stdin>")
            if self.write_unchanged_files or output != source:
                self.processed_file(output, "<stdin>", source)
            else:
                self.log_debug("No doctest changes in stdin")
        else:
            tree = self.refactor_string(source, "<stdin>")
            if self.write_unchanged_files or (tree and tree.was_changed):
                self.processed_file(str(tree), "<stdin>", source)
            else:
                self.log_debug("No changes in stdin")

    def refactor_tree(self, tree, name):
        """Refactors a parse tree (modifying the tree in place).

        For compatible patterns the bottom matcher module is
        used. Otherwise the tree is traversed node-to-node for
        matches.

        Args:
            tree: a pytree.Node instance representing the root of the tree
                  to be refactored.
            name: a human-readable name for this tree.

        Returns:
            True if the tree was modified, False otherwise.
        """
        for fixer in chain(self.pre_order, self.post_order):
            fixer.start_tree(tree, name)

        # Use traditional matching for the bottom-matcher-incompatible fixers.
        self.traverse_by(self.bmi_pre_order_heads, tree.pre_order())
        self.traverse_by(self.bmi_post_order_heads, tree.post_order())

        # Obtain the initial set of candidate nodes per fixer.
        match_set = self.BM.run(tree.leaves())

        while any(match_set.values()):
            for fixer in self.BM.fixers:
                if fixer in match_set and match_set[fixer]:
                    # Sort by depth; apply fixers from the bottom of the AST
                    # up to the top.
                    match_set[fixer].sort(key=pytree.Base.depth, reverse=True)

                    if fixer.keep_line_order:
                        # Some fixers (e.g. fix_imports) must be applied in
                        # the original file's line order.
                        match_set[fixer].sort(key=pytree.Base.get_lineno)

                    for node in list(match_set[fixer]):
                        if node in match_set[fixer]:
                            match_set[fixer].remove(node)

                        try:
                            find_root(node)
                        except ValueError:
                            # This node was cut off the tree by a previous
                            # transformation; skip it.
                            continue

                        if node.fixers_applied and fixer in node.fixers_applied:
                            # Do not apply the same fixer to a node twice.
                            continue

                        results = fixer.match(node)
                        if results:
                            new = fixer.transform(node, results)
                            if new is not None:
                                node.replace(new)
                                for node in new.post_order():
                                    # Do not apply the fixer again to this
                                    # node or any of its children.
                                    if not node.fixers_applied:
                                        node.fixers_applied = []
                                    node.fixers_applied.append(fixer)

                                # Fold any matches found in the freshly added
                                # code back into the original match set.
                                new_matches = self.BM.run(new.leaves())
                                for fxr in new_matches:
                                    if fxr not in match_set:
                                        match_set[fxr] = []
                                    match_set[fxr].extend(new_matches[fxr])

        for fixer in chain(self.pre_order, self.post_order):
            fixer.finish_tree(tree, name)
        return tree.was_changed

    def traverse_by(self, fixers, traversal):
        """Traverse an AST, applying a set of fixers to each node.

        This is a helper method for refactor_tree().

        Args:
            fixers: a dict mapping node types to lists of fixer instances.
            traversal: a generator that yields AST nodes.

        Returns:
            None
        """
        if not fixers:
            return
        for node in traversal:
            for fixer in fixers[node.type]:
                results = fixer.match(node)
                if not results:
                    continue
                new = fixer.transform(node, results)
                if new is not None:
                    # Continue matching against the replacement node.
                    node.replace(new)
                    node = new
    def processed_file(self, new_text, filename, old_text=None, write=False,
                       encoding=None):
        """
        Called when a file has been refactored and there may be changes.
        """
        self.files.append(filename)
        if old_text is None:
            # Re-read the original so we can diff against it.
            old_text = self._read_python_source(filename)[0]
            if old_text is None:
                return
        equal = old_text == new_text
        self.print_output(old_text, new_text, filename, equal)
        if equal:
            self.log_debug("No changes to %s", filename)
            if not self.write_unchanged_files:
                return
        if write:
            self.write_file(new_text, filename, old_text, encoding)
        else:
            self.log_debug("Not writing changes to %s", filename)
    def write_file(self, new_text, filename, old_text, encoding=None):
        """Writes a string to a file.

        It first shows a unified diff between the old text and the new text, and
        then rewrites the file; the latter is only done if the write option is
        set.
        """
        try:
            fp = io.open(filename, "w", encoding=encoding, newline='')
        except OSError as err:
            self.log_error("Can't create %s: %s", filename, err)
            return

        with fp:
            try:
                fp.write(new_text)
            except OSError as err:
                # Best effort: report the failure but keep going.
                self.log_error("Can't write %s: %s", filename, err)
        self.log_debug("Wrote changes to %s", filename)
        self.wrote = True

    PS1 = ">>> "  # doctest primary prompt
    PS2 = "... "  # doctest continuation prompt

    def refactor_docstring(self, input, filename):
        """Refactors a docstring, looking for doctests.

        This returns a modified version of the input string.  It looks
        for doctests, which start with a ">>>" prompt, and may be
        continued with "..." prompts, as long as the "..." is indented
        the same as the ">>>".

        (Unfortunately we can't use the doctest module's parser,
        since, like most parsers, it is not geared towards preserving
        the original source.)
        """
        result = []
        block = None          # lines of the doctest currently being collected
        block_lineno = None   # 1-based line number where that doctest starts
        indent = None         # indentation prefix in front of the prompts
        lineno = 0
        for line in input.splitlines(keepends=True):
            lineno += 1
            if line.lstrip().startswith(self.PS1):
                # A new doctest starts here; flush any block in progress.
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block_lineno = lineno
                block = [line]
                i = line.find(self.PS1)
                indent = line[:i]
            elif (indent is not None and
                  (line.startswith(indent + self.PS2) or
                   line == indent + self.PS2.rstrip() + "\n")):
                # Continuation line of the current doctest.
                block.append(line)
            else:
                # Ordinary text line: flush any doctest, copy line through.
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block = None
                indent = None
                result.append(line)
        if block is not None:
            # Flush a doctest that runs to the end of the input.
            result.extend(self.refactor_doctest(block, block_lineno,
                                                indent, filename))
        return "".join(result)

    def refactor_doctest(self, block, lineno, indent, filename):
        """Refactors one doctest.

        A doctest is given as a block of lines, the first of which starts
        with ">>>" (possibly indented), while the remaining lines start
        with "..." (identically indented).

        """
        try:
            tree = self.parse_block(block, lineno, indent)
        except Exception as err:
            if self.logger.isEnabledFor(logging.DEBUG):
                for line in block:
                    self.log_debug("Source: %s", line.rstrip("\n"))
            self.log_error("Can't parse docstring in %s line %s: %s: %s",
                           filename, lineno, err.__class__.__name__, err)
            return block
        if self.refactor_tree(tree, filename):
            new = str(tree).splitlines(keepends=True)
            # Undo the adjustment of the line numbers in wrap_toks() below.
            clipped, new = new[:lineno-1], new[lineno-1:]
            assert clipped == ["\n"] * (lineno-1), clipped
            if not new[-1].endswith("\n"):
                new[-1] += "\n"
            # Re-attach the prompts that gen_lines() stripped off.
            block = [indent + self.PS1 + new.pop(0)]
            if new:
                block += [indent + self.PS2 + line for line in new]
        return block

    def summarize(self):
        """Log a summary of files touched, fixer messages, and errors."""
        were = "were" if self.wrote else "need to be"
        if not self.files:
            self.log_message("No files %s modified.", were)
        else:
            self.log_message("Files that %s modified:", were)
            for file in self.files:
                self.log_message(file)
        if self.fixer_log:
            self.log_message("Warnings/messages while refactoring:")
            for message in self.fixer_log:
                self.log_message(message)
        if self.errors:
            if len(self.errors) == 1:
                self.log_message("There was 1 error:")
            else:
                self.log_message("There were %d errors:", len(self.errors))
            for msg, args, kwds in self.errors:
                self.log_message(msg, *args, **kwds)

    def parse_block(self, block, lineno, indent):
        """Parses a block into a tree.

        This is necessary to get correct line number / offset information
        in the parser diagnostics and embedded into the parse tree.
        """
        tokens = self.wrap_toks(block, lineno, indent)
        tree = self.driver.parse_tokens(tokens)
        # Doctest snippets never carry __future__ imports of their own.
        tree.future_features = frozenset()
        return tree

    def wrap_toks(self, block, lineno, indent):
        """Wraps a tokenize stream to systematically modify start/end."""
        tokens = tokenize.generate_tokens(self.gen_lines(block, indent).__next__)
        for tok_type, value, (line0, col0), (line1, col1), line_text in tokens:
            # Shift line numbers so diagnostics point into the docstring.
            line0 += lineno - 1
            line1 += lineno - 1
            # Don't bother updating the columns; this is too complicated
            # since line_text would also have to be updated and it would
            # still break for tokens spanning lines.  Let the user guess
            # that the column numbers for doctests are relative to the
            # end of the prompt string (PS1 or PS2).
            yield tok_type, value, (line0, col0), (line1, col1), line_text


    def gen_lines(self, block, indent):
        """Generates lines as expected by tokenize from a list of lines.

        This strips the first len(indent + self.PS1) characters off each line.
        """
        prefix1 = indent + self.PS1
        prefix2 = indent + self.PS2
        prefix = prefix1
        for line in block:
            if line.startswith(prefix):
                yield line[len(prefix):]
            elif line == prefix.rstrip() + "\n":
                # A prompt line with no code after it.
                yield "\n"
            else:
                raise AssertionError("line=%r, prefix=%r" % (line, prefix))
            # Only the first line uses PS1; the rest use PS2.
            prefix = prefix2
        # Keep feeding tokenize empty lines once the block is exhausted.
        while True:
            yield ""
class MultiprocessingUnsupported(Exception):
    """Raised when multiprocess refactoring is requested but the
    multiprocessing module cannot be imported."""
    pass


class MultiprocessRefactoringTool(RefactoringTool):
    """RefactoringTool that can distribute refactor_file() calls over
    multiple worker processes via a joinable queue."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Both are created in refactor(); self.queue being non-None marks
        # a multiprocess run in progress (refactor_file() then enqueues).
        self.queue = None
        self.output_lock = None

    def refactor(self, items, write=False, doctests_only=False,
                 num_processes=1):
        """Refactor *items*, optionally fanning work out to
        *num_processes* worker processes.

        Raises:
            MultiprocessingUnsupported: multiprocessing is unavailable.
            RuntimeError: a multiprocess run is already in progress.
        """
        if num_processes == 1:
            return super().refactor(items, write, doctests_only)
        try:
            import multiprocessing
        except ImportError:
            raise MultiprocessingUnsupported
        if self.queue is not None:
            raise RuntimeError("already doing multiple processes")
        self.queue = multiprocessing.JoinableQueue()
        self.output_lock = multiprocessing.Lock()
        processes = [multiprocessing.Process(target=self._child)
                     for i in range(num_processes)]
        try:
            for p in processes:
                p.start()
            # The parent walks the items; refactor_file() enqueues tasks.
            super().refactor(items, write, doctests_only)
        finally:
            self.queue.join()
            # One None sentinel per worker tells it to exit.
            for i in range(num_processes):
                self.queue.put(None)
            for p in processes:
                if p.is_alive():
                    p.join()
            self.queue = None

    def _child(self):
        """Worker loop: process queued refactor_file() tasks until the
        None sentinel arrives."""
        task = self.queue.get()
        while task is not None:
            args, kwargs = task
            try:
                super().refactor_file(*args, **kwargs)
            finally:
                self.queue.task_done()
            task = self.queue.get()

    def refactor_file(self, *args, **kwargs):
        """Enqueue the file for a worker during a multiprocess run;
        otherwise refactor it inline."""
        if self.queue is not None:
            self.queue.put((args, kwargs))
        else:
            return super().refactor_file(*args, **kwargs)