Kaydet (Commit) 484c9419 authored tarafından Anthony Sottile's avatar Anthony Sottile

Add trailing commas for class definitions

üst bc57cb44
...@@ -17,6 +17,7 @@ from tokenize_rt import UNIMPORTANT_WS ...@@ -17,6 +17,7 @@ from tokenize_rt import UNIMPORTANT_WS
Offset = collections.namedtuple('Offset', ('line', 'utf8_byte_offset')) Offset = collections.namedtuple('Offset', ('line', 'utf8_byte_offset'))
Call = collections.namedtuple('Call', ('node', 'star_args', 'arg_offsets')) Call = collections.namedtuple('Call', ('node', 'star_args', 'arg_offsets'))
Func = collections.namedtuple('Func', ('node', 'star_args', 'arg_offsets')) Func = collections.namedtuple('Func', ('node', 'star_args', 'arg_offsets'))
Class = collections.namedtuple('Class', ('node', 'star_args', 'arg_offsets'))
Literal = collections.namedtuple('Literal', ('node', 'backtrack')) Literal = collections.namedtuple('Literal', ('node', 'backtrack'))
Literal.__new__.__defaults__ = (False,) Literal.__new__.__defaults__ = (False,)
Fix = collections.namedtuple('Fix', ('braces', 'multi_arg', 'initial_indent')) Fix = collections.namedtuple('Fix', ('braces', 'multi_arg', 'initial_indent'))
...@@ -65,6 +66,7 @@ class FindNodes(ast.NodeVisitor): ...@@ -65,6 +66,7 @@ class FindNodes(ast.NodeVisitor):
self.literals = {} self.literals = {}
self.tuples = {} self.tuples = {}
self.imports = set() self.imports = set()
self.classes = {}
def _visit_literal(self, node, key='elts'): def _visit_literal(self, node, key='elts'):
if getattr(node, key): if getattr(node, key):
...@@ -148,6 +150,21 @@ class FindNodes(ast.NodeVisitor): ...@@ -148,6 +150,21 @@ class FindNodes(ast.NodeVisitor):
self.imports.add(Offset(node.lineno, node.col_offset)) self.imports.add(Offset(node.lineno, node.col_offset))
self.generic_visit(node) self.generic_visit(node)
def visit_ClassDef(self, node):
# starargs are allowed in py3 class definitions, py35+ allows trailing
# commas. py34 does not, but adding an option for this very obscure
# case seems not worth it.
has_starargs = False
args = list(node.bases)
args.extend(getattr(node, 'keywords', ())) # py3 only
arg_offsets = {_to_offset(arg) for arg in args}
if arg_offsets:
key = Offset(node.lineno, node.col_offset)
self.classes[key] = Class(node, has_starargs, arg_offsets)
self.generic_visit(node)
def _find_simple(first_brace, tokens): def _find_simple(first_brace, tokens):
brace_stack = [first_brace] brace_stack = [first_brace]
...@@ -358,16 +375,19 @@ def _fix_src(contents_text, py35_plus, py36_plus): ...@@ -358,16 +375,19 @@ def _fix_src(contents_text, py35_plus, py36_plus):
add_comma = not func.star_args or py36_plus add_comma = not func.star_args or py36_plus
# functions can be treated as calls # functions can be treated as calls
fixes.append((add_comma, _find_call(func, i, tokens))) fixes.append((add_comma, _find_call(func, i, tokens)))
elif key in visitor.classes:
# classes can be treated as calls
fixes.append((True, _find_call(visitor.classes[key], i, tokens)))
elif key in visitor.literals: elif key in visitor.literals:
fixes.append((True, _find_simple(i, tokens))) fixes.append((True, _find_simple(i, tokens)))
# Handle parenthesized things, unhug of tuples, and comprehensions
elif token.src in START_BRACES:
fixes.append((False, _find_simple(i, tokens)))
elif key in visitor.imports: elif key in visitor.imports:
# some imports do not have parens # some imports do not have parens
fix = _find_import(i, tokens) fix = _find_import(i, tokens)
if fix: if fix:
fixes.append((True, fix)) fixes.append((True, fix))
# Handle parenthesized things, unhug of tuples, and comprehensions
elif token.src in START_BRACES:
fixes.append((False, _find_simple(i, tokens)))
for add_comma, fix_data in fixes: for add_comma, fix_data in fixes:
if fix_data is not None: if fix_data is not None:
......
...@@ -750,6 +750,76 @@ def test_fix_from_import(src, expected): ...@@ -750,6 +750,76 @@ def test_fix_from_import(src, expected):
assert _fix_src(src, py35_plus=False, py36_plus=False) == expected assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
@pytest.mark.parametrize(
'src',
(
'class C: pass',
'class C(): pass',
'class C(object): pass',
'class C(\n'
' object,\n'
'): pass',
),
)
def test_fix_classes_noop(src):
assert _fix_src(src, py35_plus=False, py36_plus=False) == src
@pytest.mark.parametrize(
('src', 'expected'),
(
(
'class C(\n'
' object\n'
'): pass',
'class C(\n'
' object,\n'
'): pass',
),
),
)
def test_fix_classes(src, expected):
assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
@xfailif_py2
@pytest.mark.parametrize(
('src', 'expected'),
(
(
'bases = (object,)\n'
'class C(\n'
' *bases\n'
'): pass',
'bases = (object,)\n'
'class C(\n'
' *bases,\n'
'): pass',
),
(
'kws = {"metaclass": type}\n'
'class C(\n'
' **kws\n'
'): pass',
'kws = {"metaclass": type}\n'
'class C(\n'
' **kws,\n'
'): pass',
),
(
'class C(\n'
' metaclass=type\n'
'): pass',
'class C(\n'
' metaclass=type,\n'
'): pass',
),
),
)
def test_fix_classes_py3_only_syntax(src, expected):
assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
def test_main_trivial(): def test_main_trivial():
assert main(()) == 0 assert main(()) == 0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment