Kaydet (Commit) 3343fe9b authored tarafından Anthony Sottile's avatar Anthony Sottile Kaydeden (comit) GitHub

Merge pull request #27 from asottile/commmas_for_all

Add trailing commas on function defs if --py36-plus is passed
......@@ -121,9 +121,19 @@ Note that this would cause a **`SyntaxError`** in earlier python versions.
):
```
Note that functions with starargs (`*args`), kwargs (`**kwargs`), or python 3
keyword-only arguments (`..., *, ...`) cannot have a trailing comma due to it
being a syntax error.
### trailing commas for function definitions with unpacking arguments
If `--py36-plus` is passed, `add-trailing-comma` will also perform the
following change:
```diff
def f(
- *args
+ *args,
):
```
Note that this would cause a **`SyntaxError`** in earlier python versions.
### unhug trailing paren
......
......@@ -16,7 +16,7 @@ from tokenize_rt import UNIMPORTANT_WS
Offset = collections.namedtuple('Offset', ('line', 'utf8_byte_offset'))
Call = collections.namedtuple('Call', ('node', 'star_args', 'arg_offsets'))
Func = collections.namedtuple('Func', ('node', 'arg_offsets'))
Func = collections.namedtuple('Func', ('node', 'star_args', 'arg_offsets'))
Literal = collections.namedtuple('Literal', ('node', 'backtrack'))
Literal.__new__.__defaults__ = (False,)
Fix = collections.namedtuple('Fix', ('braces', 'multi_arg', 'initial_indent'))
......@@ -119,17 +119,27 @@ class FindNodes(ast.NodeVisitor):
self.generic_visit(node)
def visit_FunctionDef(self, node):
has_starargs = (
node.args.vararg or node.args.kwarg or
# python 3 only
getattr(node.args, 'kwonlyargs', None)
)
arg_offsets = {_to_offset(arg) for arg in node.args.args}
if arg_offsets and not has_starargs:
has_starargs = False
args = list(node.args.args)
if node.args.vararg:
if isinstance(node.args.vararg, ast.AST): # pragma: no cover (py3)
args.append(node.args.vararg)
has_starargs = True
if node.args.kwarg:
if isinstance(node.args.kwarg, ast.AST): # pragma: no cover (py3)
args.append(node.args.kwarg)
has_starargs = True
py3_kwonlyargs = getattr(node.args, 'kwonlyargs', None)
if py3_kwonlyargs: # pragma: no cover (py3)
args.extend(py3_kwonlyargs)
has_starargs = True
arg_offsets = {_to_offset(arg) for arg in args}
if arg_offsets:
key = Offset(node.lineno, node.col_offset)
self.funcs[key] = Func(node, arg_offsets)
self.funcs[key] = Func(node, has_starargs, arg_offsets)
self.generic_visit(node)
......@@ -304,7 +314,7 @@ def _changing_list(lst):
i += 1
def _fix_src(contents_text, py35_plus):
def _fix_src(contents_text, py35_plus, py36_plus):
try:
ast_obj = ast_parse(contents_text)
except SyntaxError:
......@@ -324,8 +334,10 @@ def _fix_src(contents_text, py35_plus):
add_comma = not call.star_args or py35_plus
fixes.append((add_comma, _find_call(call, i, tokens)))
elif key in visitor.funcs:
func = visitor.funcs[key]
add_comma = not func.star_args or py36_plus
# functions can be treated as calls
fixes.append((True, _find_call(visitor.funcs[key], i, tokens)))
fixes.append((add_comma, _find_call(func, i, tokens)))
# Handle parenthesized things
elif token.src == '(':
fixes.append((False, _find_simple(i, tokens)))
......@@ -355,7 +367,7 @@ def fix_file(filename, args):
print('{} is non-utf-8 (not supported)'.format(filename))
return 1
contents_text = _fix_src(contents_text, args.py35_plus)
contents_text = _fix_src(contents_text, args.py35_plus, args.py36_plus)
if contents_text != contents_text_orig:
print('Rewriting {}'.format(filename))
......@@ -366,10 +378,25 @@ def fix_file(filename, args):
return 0
class StoreTrueImplies(argparse.Action):
def __init__(self, option_strings, dest, implies, **kwargs):
self.implies = implies
kwargs.update(const=True, default=False, nargs=0)
super(StoreTrueImplies, self).__init__(option_strings, dest, **kwargs)
def __call__(self, parser, namespace, values, option_string=None):
assert hasattr(namespace, self.implies), self.implies
setattr(namespace, self.dest, self.const)
setattr(namespace, self.implies, self.const)
def main(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
parser.add_argument('--py35-plus', action='store_true')
parser.add_argument(
'--py36-plus', action=StoreTrueImplies, implies='py35_plus',
)
args = parser.parse_args(argv)
ret = 0
......
......@@ -50,7 +50,7 @@ xfailif_lt_py35 = pytest.mark.xfail(sys.version_info < (3, 5), reason='py35+')
),
)
def test_fix_calls_noops(src):
ret = _fix_src(src, py35_plus=False)
ret = _fix_src(src, py35_plus=False, py36_plus=False)
assert ret == src
......@@ -67,7 +67,7 @@ def test_ignores_invalid_ast_node():
' """\n'
')'
)
assert _fix_src(src, py35_plus=False) == src
assert _fix_src(src, py35_plus=False, py36_plus=False) == src
def test_py35_plus_rewrite():
......@@ -76,7 +76,7 @@ def test_py35_plus_rewrite():
' *args\n'
')'
)
ret = _fix_src(src, py35_plus=True)
ret = _fix_src(src, py35_plus=True, py36_plus=False)
assert ret == (
'x(\n'
' *args,\n'
......@@ -139,7 +139,7 @@ def test_py35_plus_rewrite():
),
)
def test_fixes_calls(src, expected):
assert _fix_src(src, py35_plus=False) == expected
assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
@pytest.mark.parametrize(
......@@ -152,7 +152,7 @@ def test_fixes_calls(src, expected):
),
)
def test_noop_one_line_literals(src):
assert _fix_src(src, py35_plus=False) == src
assert _fix_src(src, py35_plus=False, py36_plus=False) == src
@pytest.mark.parametrize(
......@@ -199,7 +199,7 @@ def test_noop_one_line_literals(src):
),
)
def test_fixes_literals(src, expected):
assert _fix_src(src, py35_plus=False) == expected
assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
@xfailif_lt_py35
......@@ -245,7 +245,7 @@ def test_fixes_literals(src, expected):
),
)
def test_fixes_py35_plus_literals(src, expected):
assert _fix_src(src, py35_plus=False) == expected
assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
def test_noop_tuple_literal_without_braces():
......@@ -255,7 +255,7 @@ def test_noop_tuple_literal_without_braces():
' 2, \\\n'
' 3'
)
assert _fix_src(src, py35_plus=False) == src
assert _fix_src(src, py35_plus=False, py36_plus=False) == src
@pytest.mark.parametrize(
......@@ -282,7 +282,7 @@ def test_noop_tuple_literal_without_braces():
),
)
def test_noop_function_defs(src):
assert _fix_src(src, py35_plus=False) == src
assert _fix_src(src, py35_plus=False, py36_plus=False) == src
@pytest.mark.parametrize(
......@@ -300,7 +300,44 @@ def test_noop_function_defs(src):
),
)
def test_fixes_defs(src, expected):
assert _fix_src(src, py35_plus=False) == expected
assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
@xfailif_py2
@pytest.mark.parametrize(
('src', 'expected'),
(
(
'def f(\n'
' *args\n'
'): pass',
'def f(\n'
' *args,\n'
'): pass',
),
(
'def f(\n'
' **kwargs\n'
'): pass',
'def f(\n'
' **kwargs,\n'
'): pass',
),
(
'def f(\n'
' *, kw=1\n'
'): pass',
'def f(\n'
' *, kw=1,\n'
'): pass',
),
),
)
def test_fixes_defs_py36_plus(src, expected):
assert _fix_src(src, py35_plus=True, py36_plus=True) == expected
@pytest.mark.parametrize(
......@@ -324,7 +361,7 @@ def test_fixes_defs(src, expected):
),
)
def test_noop_unhugs(src):
assert _fix_src(src, py35_plus=False) == src
assert _fix_src(src, py35_plus=False, py36_plus=False) == src
@pytest.mark.parametrize(
......@@ -527,7 +564,7 @@ def test_noop_unhugs(src):
),
)
def test_fix_unhugs(src, expected):
assert _fix_src(src, py35_plus=False) == expected
assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
@xfailif_py2
......@@ -546,7 +583,7 @@ def test_fix_unhugs(src, expected):
),
)
def test_fix_unhugs_py3_only(src, expected):
assert _fix_src(src, py35_plus=False) == expected
assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
@pytest.mark.parametrize(
......@@ -569,7 +606,7 @@ def test_fix_unhugs_py3_only(src, expected):
),
)
def test_noop_trailing_brace(src):
assert _fix_src(src, py35_plus=False) == src
assert _fix_src(src, py35_plus=False, py36_plus=False) == src
@pytest.mark.parametrize(
......@@ -611,7 +648,7 @@ def test_noop_trailing_brace(src):
),
)
def test_fix_trailing_brace(src, expected):
assert _fix_src(src, py35_plus=False) == expected
assert _fix_src(src, py35_plus=False, py36_plus=False) == expected
def test_main_trivial():
......@@ -664,3 +701,24 @@ def test_main_py35_plus_argument_star_star_kwargs(tmpdir):
assert f.read() == 'x(\n **args\n)\n'
assert main((f.strpath, '--py35-plus')) == 1
assert f.read() == 'x(\n **args,\n)\n'
def test_main_py36_plus_implies_py35_plus(tmpdir):
f = tmpdir.join('f.py')
f.write('x(\n **kwargs\n)\n')
assert main((f.strpath,)) == 0
assert f.read() == 'x(\n **kwargs\n)\n'
assert main((f.strpath, '--py36-plus')) == 1
assert f.read() == 'x(\n **kwargs,\n)\n'
@xfailif_py2
def test_main_py36_plus_function_trailing_commas(
tmpdir,
): # pragma: no cover (py3+)
f = tmpdir.join('f.py')
f.write('def f(\n **kwargs\n): pass\n')
assert main((f.strpath,)) == 0
assert f.read() == 'def f(\n **kwargs\n): pass\n'
assert main((f.strpath, '--py36-plus')) == 1
assert f.read() == 'def f(\n **kwargs,\n): pass\n'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment