Kaydet (Commit) 22e8ef9b authored tarafından Anthony Sottile's avatar Anthony Sottile Kaydeden (comit) GitHub

Merge pull request #2 from asottile/literals

Add support for adding trailing commas to literals
......@@ -102,9 +102,6 @@ If `--py35-plus` is passed (or python3.5+ syntax is automatically detected),
Note that this would cause a **`SyntaxError`** in earlier python versions.
## Planned features
### trailing commas for tuple / list / dict / set literals
```diff
......@@ -114,6 +111,9 @@ Note that this would cause a **`SyntaxError`** in earlier python versions.
]
```
## Planned features
### trailing commas for function definitions
```diff
......
......@@ -15,6 +15,8 @@ from tokenize_rt import UNIMPORTANT_WS
Offset = collections.namedtuple('Offset', ('line', 'utf8_byte_offset'))
Node = collections.namedtuple('Node', ('node', 'star_args', 'arg_offsets'))
Literal = collections.namedtuple('Literal', ('node', 'braces', 'backtrack'))
Literal.__new__.__defaults__ = (False,)
NON_CODING_TOKENS = frozenset(('COMMENT', 'NL', UNIMPORTANT_WS))
......@@ -48,11 +50,45 @@ def _is_star_star_kwarg(node):
return isinstance(node, ast.keyword) and node.arg is None
class FindCalls(ast.NodeVisitor):
class FindNodes(ast.NodeVisitor):
def __init__(self):
self.calls = {}
self.literals = {}
self.has_new_syntax = False
def _visit_literal(self, node, key='elts', is_multiline=False, **kwargs):
orig = node.lineno
for elt in getattr(node, key):
if elt.lineno > orig:
is_multiline = True
if _is_star_arg(elt): # pragma: no cover (PY35+)
self.has_new_syntax = True
if is_multiline:
key = Offset(node.lineno, node.col_offset)
self.literals[key] = Literal(node, **kwargs)
self.generic_visit(node)
def visit_Set(self, node):
self._visit_literal(node, braces=('{', '}'))
def visit_Dict(self, node):
# unpackings are represented as a `None` key
if None in node.keys: # pragma: no cover (PY35+)
self.has_new_syntax = True
self._visit_literal(node, key='values', braces=('{', '}'))
def visit_List(self, node):
self._visit_literal(node, braces=('[', ']'))
def visit_Tuple(self, node):
# tuples lie about things, so we pretend they are all multiline
# and tell the later machinery to backtrack
self._visit_literal(
node, is_multiline=True, braces=('(', ')'), backtrack=True,
)
def visit_Call(self, node):
orig = node.lineno
......@@ -100,6 +136,37 @@ class FindCalls(ast.NodeVisitor):
self.generic_visit(node)
def _fix_inner(brace_start, brace_end, first_paren, tokens):
i = first_paren
brace_stack = [first_paren]
i += 1
for i in range(i, len(tokens)):
token = tokens[i]
if token.src == brace_start:
brace_stack.append(i)
elif token.src == brace_end:
brace_stack.pop()
if not brace_stack:
break
else:
raise AssertionError('Past end?')
# This was not actually a multi-line call, despite the ast telling us that
if tokens[first_paren].line == tokens[i].line:
return
# From there, we can walk backwards and decide whether a comma is needed
i -= 1
while tokens[i].name in NON_CODING_TOKENS:
i -= 1
# If we're not a hugging paren, we can insert a comma
if tokens[i].src != ',' and tokens[i + 1].src != brace_end:
tokens.insert(i + 1, Token('OP', ','))
def _fix_call(call, i, tokens):
# When we get a `call` object, the ast refers to it as this:
#
......@@ -112,12 +179,6 @@ def _fix_call(call, i, tokens):
#
# func_name(arg, arg, arg)
# ^ outer paren
#
# Once that is identified, walk until the paren stack is empty -- this will
# put us at the last paren
#
# func_name(arg, arg, arg)
# ^ paren stack is empty
first_paren = None
paren_stack = []
for i in range(i, len(tokens)):
......@@ -129,33 +190,36 @@ def _fix_call(call, i, tokens):
if (token.line, token.utf8_byte_offset) in call.arg_offsets:
first_paren = paren_stack[0]
if first_paren is not None and not paren_stack:
break
else:
raise AssertionError('Past end?')
# This was not actually a multi-line call, despite the ast telling us that
if tokens[first_paren].line == tokens[i].line:
return
_fix_inner('(', ')', first_paren, tokens)
# From there, we can walk backwards and decide whether a comma is needed
i -= 1
while tokens[i].name in NON_CODING_TOKENS:
def _fix_literal(literal, i, tokens):
brace_start, brace_end = literal.braces
# tuples are evil, we need to backtrack to find the opening paren
if literal.backtrack:
i -= 1
while tokens[i].name in NON_CODING_TOKENS:
i -= 1
# Sometimes tuples don't even have a paren!
# x = 1, 2, 3
if tokens[i].src != brace_start:
return
# If we're not a hugging paren, we can insert a comma
if tokens[i].src != ',' and tokens[i + 1].src != ')':
tokens.insert(i + 1, Token('OP', ','))
_fix_inner(brace_start, brace_end, i, tokens)
def _fix_calls(contents_text, py35_plus):
def _fix_commas(contents_text, py35_plus):
try:
ast_obj = ast_parse(contents_text)
except SyntaxError:
return contents_text
visitor = FindCalls()
visitor = FindNodes()
visitor.visit(ast_obj)
tokens = src_to_tokens(contents_text)
......@@ -166,6 +230,8 @@ def _fix_calls(contents_text, py35_plus):
# Only fix stararg calls if asked to
if not call.star_args or py35_plus or visitor.has_new_syntax:
_fix_call(call, i, tokens)
elif key in visitor.literals:
_fix_literal(visitor.literals[key], i, tokens)
return tokens_to_src(tokens)
......@@ -180,7 +246,7 @@ def fix_file(filename, args):
print('{} is non-utf-8 (not supported)'.format(filename))
return 1
contents_text = _fix_calls(contents_text, args.py35_plus)
contents_text = _fix_commas(contents_text, args.py35_plus)
if contents_text != contents_text_orig:
print('Rewriting {}'.format(filename))
......
......@@ -7,10 +7,13 @@ import sys
import pytest
from add_trailing_comma import _fix_calls
from add_trailing_comma import _fix_commas
from add_trailing_comma import main
xfailif_lt_py35 = pytest.mark.xfail(sys.version_info < (3, 5), reason='py35+')
@pytest.mark.parametrize(
'src',
(
......@@ -42,7 +45,7 @@ from add_trailing_comma import main
),
)
def test_fix_calls_noops(src):
ret = _fix_calls(src, py35_plus=False)
ret = _fix_commas(src, py35_plus=False)
assert ret == src
......@@ -59,7 +62,7 @@ def test_ignores_invalid_ast_node():
' """\n'
')'
)
assert _fix_calls(src, py35_plus=False) == src
assert _fix_commas(src, py35_plus=False) == src
def test_py35_plus_rewrite():
......@@ -68,7 +71,7 @@ def test_py35_plus_rewrite():
' *args\n'
')'
)
ret = _fix_calls(src, py35_plus=True)
ret = _fix_commas(src, py35_plus=True)
assert ret == (
'x(\n'
' *args,\n'
......@@ -76,10 +79,14 @@ def test_py35_plus_rewrite():
)
@pytest.mark.xfail(sys.version_info < (3, 5), reason='py35+ only feature')
@xfailif_lt_py35
@pytest.mark.parametrize(
'syntax',
(
'(1, 2, *a)\n',
'[1, 2, *a]\n',
'{1, 2, *a}\n',
'{1: 2, **k}\n',
'y(*args1, *args2)\n',
'y(**kwargs1, **kwargs2)\n',
),
......@@ -87,7 +94,159 @@ def test_py35_plus_rewrite():
def test_auto_detected_py35_plus_rewrite(syntax):
src = syntax + 'x(\n *args\n)'
expected = syntax + 'x(\n *args,\n)'
assert _fix_calls(src, py35_plus=False) == expected
assert _fix_commas(src, py35_plus=False) == expected
@pytest.mark.parametrize(
('src', 'expected'),
(
(
'x(\n'
' 1\n'
')',
'x(\n'
' 1,\n'
')',
),
(
'x(\n'
' kwarg=5\n'
')',
'x(\n'
' kwarg=5,\n'
')',
),
(
'foo()(\n'
' 1\n'
')',
'foo()(\n'
' 1,\n'
')',
),
),
)
def test_fixes_calls(src, expected):
assert _fix_commas(src, py35_plus=False) == expected
@pytest.mark.parametrize(
'src',
(
'(1, 2, 3, 4)',
'[1, 2, 3, 4]',
'{1, 2, 3, 4}',
'{1: 2, 3: 4}',
),
)
def test_noop_one_line_literals(src):
assert _fix_commas(src, py35_plus=False) == src
@pytest.mark.parametrize(
('src', 'expected'),
(
(
'x = [\n'
' 1\n'
']',
'x = [\n'
' 1,\n'
']',
),
(
'x = {\n'
' 1\n'
'}',
'x = {\n'
' 1,\n'
'}',
),
(
'x = {\n'
' 1: 2\n'
'}',
'x = {\n'
' 1: 2,\n'
'}',
),
(
'x = (\n'
' 1,\n'
' 2\n'
')',
'x = (\n'
' 1,\n'
' 2,\n'
')',
),
),
)
def test_fixes_literals(src, expected):
assert _fix_commas(src, py35_plus=False) == expected
@xfailif_lt_py35
@pytest.mark.parametrize(
('src', 'expected'),
(
(
'x = {\n'
' 1, *y\n'
'}',
'x = {\n'
' 1, *y,\n'
'}',
),
(
'x = [\n'
' 1, *y\n'
']',
'x = [\n'
' 1, *y,\n'
']',
),
(
'x = (\n'
' 1, *y\n'
')',
'x = (\n'
' 1, *y,\n'
')',
),
(
'x = {\n'
' 1: 2, **y\n'
'}',
'x = {\n'
' 1: 2, **y,\n'
'}',
),
),
)
def test_fixes_py35_plus_literals(src, expected):
assert _fix_commas(src, py35_plus=False) == expected
def test_noop_tuple_literal_without_braces():
src = (
'x = \\\n'
' 1, \\\n'
' 2, \\\n'
' 3'
)
assert _fix_commas(src, py35_plus=False) == src
def test_main_trivial():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment