Kaydet (Commit) 3343fe9b authored tarafından Anthony Sottile's avatar Anthony Sottile Kaydeden (comit) GitHub

Merge pull request #27 from asottile/commmas_for_all

Add trailing commas on function defs if --py36-plus is passed
...@@ -121,9 +121,19 @@ Note that this would cause a **`SyntaxError`** in earlier python versions. ...@@ -121,9 +121,19 @@ Note that this would cause a **`SyntaxError`** in earlier python versions.
): ):
``` ```
Note that functions with starargs (`*args`), kwargs (`**kwargs`), or python 3 ### trailing commas for function definitions with unpacking arguments
keyword-only arguments (`..., *, ...`) cannot have a trailing comma due to it
being a syntax error. If `--py36-plus` is passed, `add-trailing-comma` will also perform the
following change:
```diff
def f(
- *args
+ *args,
):
```
Note that this would cause a **`SyntaxError`** in earlier python versions.
### unhug trailing paren ### unhug trailing paren
......
...@@ -16,7 +16,7 @@ from tokenize_rt import UNIMPORTANT_WS ...@@ -16,7 +16,7 @@ from tokenize_rt import UNIMPORTANT_WS
Offset = collections.namedtuple('Offset', ('line', 'utf8_byte_offset')) Offset = collections.namedtuple('Offset', ('line', 'utf8_byte_offset'))
Call = collections.namedtuple('Call', ('node', 'star_args', 'arg_offsets')) Call = collections.namedtuple('Call', ('node', 'star_args', 'arg_offsets'))
Func = collections.namedtuple('Func', ('node', 'arg_offsets')) Func = collections.namedtuple('Func', ('node', 'star_args', 'arg_offsets'))
Literal = collections.namedtuple('Literal', ('node', 'backtrack')) Literal = collections.namedtuple('Literal', ('node', 'backtrack'))
Literal.__new__.__defaults__ = (False,) Literal.__new__.__defaults__ = (False,)
Fix = collections.namedtuple('Fix', ('braces', 'multi_arg', 'initial_indent')) Fix = collections.namedtuple('Fix', ('braces', 'multi_arg', 'initial_indent'))
...@@ -119,17 +119,27 @@ class FindNodes(ast.NodeVisitor): ...@@ -119,17 +119,27 @@ class FindNodes(ast.NodeVisitor):
self.generic_visit(node) self.generic_visit(node)
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
has_starargs = ( has_starargs = False
node.args.vararg or node.args.kwarg or args = list(node.args.args)
# python 3 only
getattr(node.args, 'kwonlyargs', None) if node.args.vararg:
) if isinstance(node.args.vararg, ast.AST): # pragma: no cover (py3)
args.append(node.args.vararg)
has_starargs = True
if node.args.kwarg:
if isinstance(node.args.kwarg, ast.AST): # pragma: no cover (py3)
args.append(node.args.kwarg)
has_starargs = True
py3_kwonlyargs = getattr(node.args, 'kwonlyargs', None)
if py3_kwonlyargs: # pragma: no cover (py3)
args.extend(py3_kwonlyargs)
has_starargs = True
arg_offsets = {_to_offset(arg) for arg in node.args.args} arg_offsets = {_to_offset(arg) for arg in args}
if arg_offsets and not has_starargs: if arg_offsets:
key = Offset(node.lineno, node.col_offset) key = Offset(node.lineno, node.col_offset)
self.funcs[key] = Func(node, arg_offsets) self.funcs[key] = Func(node, has_starargs, arg_offsets)
self.generic_visit(node) self.generic_visit(node)
...@@ -304,7 +314,7 @@ def _changing_list(lst): ...@@ -304,7 +314,7 @@ def _changing_list(lst):
i += 1 i += 1
def _fix_src(contents_text, py35_plus): def _fix_src(contents_text, py35_plus, py36_plus):
try: try:
ast_obj = ast_parse(contents_text) ast_obj = ast_parse(contents_text)
except SyntaxError: except SyntaxError:
...@@ -324,8 +334,10 @@ def _fix_src(contents_text, py35_plus): ...@@ -324,8 +334,10 @@ def _fix_src(contents_text, py35_plus):
add_comma = not call.star_args or py35_plus add_comma = not call.star_args or py35_plus
fixes.append((add_comma, _find_call(call, i, tokens))) fixes.append((add_comma, _find_call(call, i, tokens)))
elif key in visitor.funcs: elif key in visitor.funcs:
func = visitor.funcs[key]
add_comma = not func.star_args or py36_plus
# functions can be treated as calls # functions can be treated as calls
fixes.append((True, _find_call(visitor.funcs[key], i, tokens))) fixes.append((add_comma, _find_call(func, i, tokens)))
# Handle parenthesized things # Handle parenthesized things
elif token.src == '(': elif token.src == '(':
fixes.append((False, _find_simple(i, tokens))) fixes.append((False, _find_simple(i, tokens)))
...@@ -355,7 +367,7 @@ def fix_file(filename, args): ...@@ -355,7 +367,7 @@ def fix_file(filename, args):
print('{} is non-utf-8 (not supported)'.format(filename)) print('{} is non-utf-8 (not supported)'.format(filename))
return 1 return 1
contents_text = _fix_src(contents_text, args.py35_plus) contents_text = _fix_src(contents_text, args.py35_plus, args.py36_plus)
if contents_text != contents_text_orig: if contents_text != contents_text_orig:
print('Rewriting {}'.format(filename)) print('Rewriting {}'.format(filename))
...@@ -366,10 +378,25 @@ def fix_file(filename, args): ...@@ -366,10 +378,25 @@ def fix_file(filename, args):
return 0 return 0
class StoreTrueImplies(argparse.Action):
def __init__(self, option_strings, dest, implies, **kwargs):
self.implies = implies
kwargs.update(const=True, default=False, nargs=0)
super(StoreTrueImplies, self).__init__(option_strings, dest, **kwargs)
def __call__(self, parser, namespace, values, option_string=None):
assert hasattr(namespace, self.implies), self.implies
setattr(namespace, self.dest, self.const)
setattr(namespace, self.implies, self.const)
def main(argv=None): def main(argv=None):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*') parser.add_argument('filenames', nargs='*')
parser.add_argument('--py35-plus', action='store_true') parser.add_argument('--py35-plus', action='store_true')
parser.add_argument(
'--py36-plus', action=StoreTrueImplies, implies='py35_plus',
)
args = parser.parse_args(argv) args = parser.parse_args(argv)
ret = 0 ret = 0
......
...@@ -50,7 +50,7 @@ xfailif_lt_py35 = pytest.mark.xfail(sys.version_info < (3, 5), reason='py35+') ...@@ -50,7 +50,7 @@ xfailif_lt_py35 = pytest.mark.xfail(sys.version_info < (3, 5), reason='py35+')
), ),
) )
def test_fix_calls_noops(src): def test_fix_calls_noops(src):
ret = _fix_src(src, py35_plus=False) ret = _fix_src(src, py35_plus=False, py36_plus=False)
assert ret == src assert ret == src
...@@ -67,7 +67,7 @@ def test_ignores_invalid_ast_node(): ...@@ -67,7 +67,7 @@ def test_ignores_invalid_ast_node():
' """\n' ' """\n'
')' ')'
) )
assert _fix_src(src, py35_plus=False) == src assert _fix_src(src, py35_plus=False, py36_plus=False) == src
def test_py35_plus_rewrite(): def test_py35_plus_rewrite():
...@@ -76,7 +76,7 @@ def test_py35_plus_rewrite(): ...@@ -76,7 +76,7 @@ def test_py35_plus_rewrite():
' *args\n' ' *args\n'
')' ')'
) )
ret = _fix_src(src, py35_plus=True) ret = _fix_src(src, py35_plus=True, py36_plus=False)
assert ret == ( assert ret == (
'x(\n' 'x(\n'
' *args,\n' ' *args,\n'
...@@ -139,7 +139,7 @@ def test_py35_plus_rewrite(): ...@@ -139,7 +139,7 @@ def test_py35_plus_rewrite():
), ),
) )
def test_fixes_calls(src, expected): def test_fixes_calls(src, expected):
assert _fix_src(src, py35_plus=False) == expected assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -152,7 +152,7 @@ def test_fixes_calls(src, expected): ...@@ -152,7 +152,7 @@ def test_fixes_calls(src, expected):
), ),
) )
def test_noop_one_line_literals(src): def test_noop_one_line_literals(src):
assert _fix_src(src, py35_plus=False) == src assert _fix_src(src, py35_plus=False, py36_plus=False) == src
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -199,7 +199,7 @@ def test_noop_one_line_literals(src): ...@@ -199,7 +199,7 @@ def test_noop_one_line_literals(src):
), ),
) )
def test_fixes_literals(src, expected): def test_fixes_literals(src, expected):
assert _fix_src(src, py35_plus=False) == expected assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
@xfailif_lt_py35 @xfailif_lt_py35
...@@ -245,7 +245,7 @@ def test_fixes_literals(src, expected): ...@@ -245,7 +245,7 @@ def test_fixes_literals(src, expected):
), ),
) )
def test_fixes_py35_plus_literals(src, expected): def test_fixes_py35_plus_literals(src, expected):
assert _fix_src(src, py35_plus=False) == expected assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
def test_noop_tuple_literal_without_braces(): def test_noop_tuple_literal_without_braces():
...@@ -255,7 +255,7 @@ def test_noop_tuple_literal_without_braces(): ...@@ -255,7 +255,7 @@ def test_noop_tuple_literal_without_braces():
' 2, \\\n' ' 2, \\\n'
' 3' ' 3'
) )
assert _fix_src(src, py35_plus=False) == src assert _fix_src(src, py35_plus=False, py36_plus=False) == src
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -282,7 +282,7 @@ def test_noop_tuple_literal_without_braces(): ...@@ -282,7 +282,7 @@ def test_noop_tuple_literal_without_braces():
), ),
) )
def test_noop_function_defs(src): def test_noop_function_defs(src):
assert _fix_src(src, py35_plus=False) == src assert _fix_src(src, py35_plus=False, py36_plus=False) == src
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -300,7 +300,44 @@ def test_noop_function_defs(src): ...@@ -300,7 +300,44 @@ def test_noop_function_defs(src):
), ),
) )
def test_fixes_defs(src, expected): def test_fixes_defs(src, expected):
assert _fix_src(src, py35_plus=False) == expected assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
@xfailif_py2
@pytest.mark.parametrize(
('src', 'expected'),
(
(
'def f(\n'
' *args\n'
'): pass',
'def f(\n'
' *args,\n'
'): pass',
),
(
'def f(\n'
' **kwargs\n'
'): pass',
'def f(\n'
' **kwargs,\n'
'): pass',
),
(
'def f(\n'
' *, kw=1\n'
'): pass',
'def f(\n'
' *, kw=1,\n'
'): pass',
),
),
)
def test_fixes_defs_py36_plus(src, expected):
assert _fix_src(src, py35_plus=True, py36_plus=True) == expected
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -324,7 +361,7 @@ def test_fixes_defs(src, expected): ...@@ -324,7 +361,7 @@ def test_fixes_defs(src, expected):
), ),
) )
def test_noop_unhugs(src): def test_noop_unhugs(src):
assert _fix_src(src, py35_plus=False) == src assert _fix_src(src, py35_plus=False, py36_plus=False) == src
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -527,7 +564,7 @@ def test_noop_unhugs(src): ...@@ -527,7 +564,7 @@ def test_noop_unhugs(src):
), ),
) )
def test_fix_unhugs(src, expected): def test_fix_unhugs(src, expected):
assert _fix_src(src, py35_plus=False) == expected assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
@xfailif_py2 @xfailif_py2
...@@ -546,7 +583,7 @@ def test_fix_unhugs(src, expected): ...@@ -546,7 +583,7 @@ def test_fix_unhugs(src, expected):
), ),
) )
def test_fix_unhugs_py3_only(src, expected): def test_fix_unhugs_py3_only(src, expected):
assert _fix_src(src, py35_plus=False) == expected assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -569,7 +606,7 @@ def test_fix_unhugs_py3_only(src, expected): ...@@ -569,7 +606,7 @@ def test_fix_unhugs_py3_only(src, expected):
), ),
) )
def test_noop_trailing_brace(src): def test_noop_trailing_brace(src):
assert _fix_src(src, py35_plus=False) == src assert _fix_src(src, py35_plus=False, py36_plus=False) == src
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -611,7 +648,7 @@ def test_noop_trailing_brace(src): ...@@ -611,7 +648,7 @@ def test_noop_trailing_brace(src):
), ),
) )
def test_fix_trailing_brace(src, expected): def test_fix_trailing_brace(src, expected):
assert _fix_src(src, py35_plus=False) == expected assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
def test_main_trivial(): def test_main_trivial():
...@@ -664,3 +701,24 @@ def test_main_py35_plus_argument_star_star_kwargs(tmpdir): ...@@ -664,3 +701,24 @@ def test_main_py35_plus_argument_star_star_kwargs(tmpdir):
assert f.read() == 'x(\n **args\n)\n' assert f.read() == 'x(\n **args\n)\n'
assert main((f.strpath, '--py35-plus')) == 1 assert main((f.strpath, '--py35-plus')) == 1
assert f.read() == 'x(\n **args,\n)\n' assert f.read() == 'x(\n **args,\n)\n'
def test_main_py36_plus_implies_py35_plus(tmpdir):
f = tmpdir.join('f.py')
f.write('x(\n **kwargs\n)\n')
assert main((f.strpath,)) == 0
assert f.read() == 'x(\n **kwargs\n)\n'
assert main((f.strpath, '--py36-plus')) == 1
assert f.read() == 'x(\n **kwargs,\n)\n'
@xfailif_py2
def test_main_py36_plus_function_trailing_commas(
tmpdir,
): # pragma: no cover (py3+)
f = tmpdir.join('f.py')
f.write('def f(\n **kwargs\n): pass\n')
assert main((f.strpath,)) == 0
assert f.read() == 'def f(\n **kwargs\n): pass\n'
assert main((f.strpath, '--py36-plus')) == 1
assert f.read() == 'def f(\n **kwargs,\n): pass\n'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment