Kaydet (Commit) 07422afc authored tarafından Anthony Sottile's avatar Anthony Sottile Kaydeden (comit) GitHub

Merge pull request #8 from asottile/unhug_starargs

Also unhug stararg calls / functions
...@@ -15,9 +15,10 @@ from tokenize_rt import UNIMPORTANT_WS ...@@ -15,9 +15,10 @@ from tokenize_rt import UNIMPORTANT_WS
Offset = collections.namedtuple('Offset', ('line', 'utf8_byte_offset')) Offset = collections.namedtuple('Offset', ('line', 'utf8_byte_offset'))
Call = collections.namedtuple('Call', ('node', 'star_args', 'arg_offsets')) Call = collections.namedtuple('Call', ('node', 'star_args', 'arg_offsets'))
Func = collections.namedtuple('Func', ('node', 'arg_offsets')) Func = collections.namedtuple('Func', ('node', 'star_args', 'arg_offsets'))
Literal = collections.namedtuple('Literal', ('node', 'braces', 'backtrack')) Literal = collections.namedtuple('Literal', ('node', 'braces', 'backtrack'))
Literal.__new__.__defaults__ = (False,) Literal.__new__.__defaults__ = (False,)
Fix = collections.namedtuple('Fix', ('braces', 'initial_indent'))
NEWLINES = frozenset(('NEWLINE', 'NL')) NEWLINES = frozenset(('NEWLINE', 'NL'))
NON_CODING_TOKENS = frozenset(('COMMENT', 'NL', UNIMPORTANT_WS)) NON_CODING_TOKENS = frozenset(('COMMENT', 'NL', UNIMPORTANT_WS))
...@@ -141,28 +142,39 @@ class FindNodes(ast.NodeVisitor): ...@@ -141,28 +142,39 @@ class FindNodes(ast.NodeVisitor):
self.generic_visit(node) self.generic_visit(node)
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
has_starargs = ( has_starargs = False
node.args.vararg or node.args.kwarg or args = list(node.args.args)
# python 3 only
getattr(node.args, 'kwonlyargs', None) if node.args.vararg:
) if isinstance(node.args.vararg, ast.AST): # pragma: no cover (py3)
args.append(node.args.vararg)
has_starargs = True
if node.args.kwarg:
if isinstance(node.args.kwarg, ast.AST): # pragma: no cover (py3)
args.append(node.args.kwarg)
has_starargs = True
py3_kwonlyargs = getattr(node.args, 'kwonlyargs', None)
if py3_kwonlyargs: # pragma: no cover (py3)
args.extend(py3_kwonlyargs)
has_starargs = True
orig = node.lineno orig = node.lineno
is_multiline = False is_multiline = False
offsets = set() offsets = set()
for argnode in node.args.args: for argnode in args:
offset = _to_offset(argnode) offset = _to_offset(argnode)
if offset.line > orig: if offset.line > orig:
is_multiline = True is_multiline = True
offsets.add(offset) offsets.add(offset)
if is_multiline and not has_starargs: if is_multiline:
key = Offset(node.lineno, node.col_offset) key = Offset(node.lineno, node.col_offset)
self.funcs[key] = Func(node, offsets) self.funcs[key] = Func(node, has_starargs, offsets)
self.generic_visit(node) self.generic_visit(node)
def _fix_inner(brace_start, brace_end, first_brace, tokens): def _find_simple(brace_start, brace_end, first_brace, tokens):
brace_stack = [first_brace] brace_stack = [first_brace]
for i in range(first_brace + 1, len(tokens)): for i in range(first_brace + 1, len(tokens)):
...@@ -183,12 +195,6 @@ def _fix_inner(brace_start, brace_end, first_brace, tokens): ...@@ -183,12 +195,6 @@ def _fix_inner(brace_start, brace_end, first_brace, tokens):
if tokens[first_brace].line == tokens[last_brace].line: if tokens[first_brace].line == tokens[last_brace].line:
return return
# Figure out if either of the braces are "hugging"
hug_open = tokens[first_brace + 1].name not in NON_CODING_TOKENS
hug_close = tokens[last_brace - 1].name not in NON_CODING_TOKENS
if hug_open and tokens[last_brace - 1].src in END_BRACES:
hug_open = hug_close = False
# determine the initial indentation # determine the initial indentation
i = first_brace i = first_brace
while i >= 0 and tokens[i].name not in NEWLINES: while i >= 0 and tokens[i].name not in NEWLINES:
...@@ -199,51 +205,10 @@ def _fix_inner(brace_start, brace_end, first_brace, tokens): ...@@ -199,51 +205,10 @@ def _fix_inner(brace_start, brace_end, first_brace, tokens):
else: else:
initial_indent = 0 initial_indent = 0
# fix open hugging return Fix(braces=(first_brace, last_brace), initial_indent=initial_indent)
if hug_open:
new_indent = initial_indent + 4
tokens[first_brace + 1:first_brace + 1] = [
Token('NL', '\n'), Token(UNIMPORTANT_WS, ' ' * new_indent),
]
last_brace += 2
# Adust indentation for the rest of the things
min_indent = None
indents = []
for i in range(first_brace + 3, last_brace):
if tokens[i - 1].name == 'NL' and tokens[i].name == UNIMPORTANT_WS:
if min_indent is None:
min_indent = len(tokens[i].src)
elif len(tokens[i].src) < min_indent:
min_indent = len(tokens[i].src)
indents.append(i)
for i in indents:
oldlen = len(tokens[i].src)
newlen = oldlen - min_indent + new_indent
tokens[i] = tokens[i]._replace(src=' ' * newlen)
# fix close hugging def _find_call(call, i, tokens):
if hug_close:
tokens[last_brace:last_brace] = [
Token('NL', '\n'),
Token(UNIMPORTANT_WS, ' ' * initial_indent),
]
last_brace += 2
# From there, we can walk backwards and decide whether a comma is needed
i = last_brace - 1
while tokens[i].name in NON_CODING_TOKENS:
i -= 1
# If we're not a hugging paren, we can insert a comma
if tokens[i].src != ',' and i + 1 != last_brace:
tokens.insert(i + 1, Token('OP', ','))
def _fix_call(call, i, tokens):
# When we get a `call` object, the ast refers to it as this: # When we get a `call` object, the ast refers to it as this:
# #
# func_name(arg, arg, arg) # func_name(arg, arg, arg)
...@@ -273,10 +238,10 @@ def _fix_call(call, i, tokens): ...@@ -273,10 +238,10 @@ def _fix_call(call, i, tokens):
else: else:
raise AssertionError('Past end?') raise AssertionError('Past end?')
_fix_inner(brace_start, brace_end, first_brace, tokens) return _find_simple(brace_start, brace_end, first_brace, tokens)
def _fix_literal(literal, i, tokens): def _find_literal(literal, i, tokens):
brace_start, brace_end = literal.braces brace_start, brace_end = literal.braces
# tuples are evil, we need to backtrack to find the opening paren # tuples are evil, we need to backtrack to find the opening paren
...@@ -289,7 +254,60 @@ def _fix_literal(literal, i, tokens): ...@@ -289,7 +254,60 @@ def _fix_literal(literal, i, tokens):
if tokens[i].src != brace_start: if tokens[i].src != brace_start:
return return
_fix_inner(brace_start, brace_end, i, tokens) return _find_simple(brace_start, brace_end, i, tokens)
def _fix_comma_and_unhug(fix_data, add_comma, tokens):
first_brace, last_brace = fix_data.braces
# Figure out if either of the braces are "hugging"
hug_open = tokens[first_brace + 1].name not in NON_CODING_TOKENS
hug_close = tokens[last_brace - 1].name not in NON_CODING_TOKENS
if hug_open and tokens[last_brace - 1].src in END_BRACES:
hug_open = hug_close = False
# fix open hugging
if hug_open:
new_indent = fix_data.initial_indent + 4
tokens[first_brace + 1:first_brace + 1] = [
Token('NL', '\n'), Token(UNIMPORTANT_WS, ' ' * new_indent),
]
last_brace += 2
# Adust indentation for the rest of the things
min_indent = None
indents = []
for i in range(first_brace + 3, last_brace):
if tokens[i - 1].name == 'NL' and tokens[i].name == UNIMPORTANT_WS:
if min_indent is None:
min_indent = len(tokens[i].src)
elif len(tokens[i].src) < min_indent:
min_indent = len(tokens[i].src)
indents.append(i)
for i in indents:
oldlen = len(tokens[i].src)
newlen = oldlen - min_indent + new_indent
tokens[i] = tokens[i]._replace(src=' ' * newlen)
# fix close hugging
if hug_close:
tokens[last_brace:last_brace] = [
Token('NL', '\n'),
Token(UNIMPORTANT_WS, ' ' * fix_data.initial_indent),
]
last_brace += 2
# From there, we can walk backwards and decide whether a comma is needed
i = last_brace - 1
while tokens[i].name in NON_CODING_TOKENS:
i -= 1
# If we're not a hugging paren, we can insert a comma
if add_comma and tokens[i].src != ',' and i + 1 != last_brace:
tokens.insert(i + 1, Token('OP', ','))
def _fix_src(contents_text, py35_plus): def _fix_src(contents_text, py35_plus):
...@@ -305,16 +323,25 @@ def _fix_src(contents_text, py35_plus): ...@@ -305,16 +323,25 @@ def _fix_src(contents_text, py35_plus):
tokens = src_to_tokens(contents_text) tokens = src_to_tokens(contents_text)
for i, token in reversed(tuple(enumerate(tokens))): for i, token in reversed(tuple(enumerate(tokens))):
key = Offset(token.line, token.utf8_byte_offset) key = Offset(token.line, token.utf8_byte_offset)
add_comma = True
fix_data = None
if key in visitor.calls: if key in visitor.calls:
call = visitor.calls[key] call = visitor.calls[key]
# Only fix stararg calls if asked to # Only fix stararg calls if asked to
if not call.star_args or py35_plus: add_comma = not call.star_args or py35_plus
_fix_call(call, i, tokens) fix_data = _find_call(call, i, tokens)
elif key in visitor.literals:
_fix_literal(visitor.literals[key], i, tokens)
elif key in visitor.funcs: elif key in visitor.funcs:
func = visitor.funcs[key]
# any amount of starargs excludes adding a comma for defs
add_comma = not func.star_args
# functions can be treated as calls # functions can be treated as calls
_fix_call(visitor.funcs[key], i, tokens) fix_data = _find_call(func, i, tokens)
elif key in visitor.literals:
fix_data = _find_literal(visitor.literals[key], i, tokens)
if fix_data is not None:
_fix_comma_and_unhug(fix_data, add_comma, tokens)
return tokens_to_src(tokens) return tokens_to_src(tokens)
......
...@@ -11,6 +11,7 @@ from add_trailing_comma import _fix_src ...@@ -11,6 +11,7 @@ from add_trailing_comma import _fix_src
from add_trailing_comma import main from add_trailing_comma import main
xfailif_py2 = pytest.mark.xfail(sys.version_info < (3,), reason='py3+')
xfailif_lt_py35 = pytest.mark.xfail(sys.version_info < (3, 5), reason='py35+') xfailif_lt_py35 = pytest.mark.xfail(sys.version_info < (3, 5), reason='py35+')
...@@ -264,7 +265,7 @@ def test_noop_tuple_literal_without_braces(): ...@@ -264,7 +265,7 @@ def test_noop_tuple_literal_without_braces():
# *args forbid trailing commas # *args forbid trailing commas
'def f(\n' 'def f(\n'
' *args\n' ' *args\n'
'): pass' '): pass',
# **kwargs forbid trailing commas # **kwargs forbid trailing commas
'def f(\n' 'def f(\n'
' **kwargs\n' ' **kwargs\n'
...@@ -415,12 +416,56 @@ def test_noop_unhugs(src): ...@@ -415,12 +416,56 @@ def test_noop_unhugs(src):
' 1,\n' ' 1,\n'
')', ')',
), ),
(
'f(\n'
' *args)',
'f(\n'
' *args\n'
')',
),
), ),
) )
def test_fix_unhugs(src, expected): def test_fix_unhugs(src, expected):
assert _fix_src(src, py35_plus=False) == expected assert _fix_src(src, py35_plus=False) == expected
@xfailif_py2
@pytest.mark.parametrize(
('src', 'expected'),
(
# python 2 doesn't give offset information for starargs
(
'def f(\n'
' *args): pass',
'def f(\n'
' *args\n'
'): pass',
),
(
'def f(\n'
' **kwargs): pass',
'def f(\n'
' **kwargs\n'
'): pass',
),
# python 2 doesn't kwonlyargs
(
'def f(\n'
' *, kw=1, kw2=2): pass',
'def f(\n'
' *, kw=1, kw2=2\n'
'): pass',
),
),
)
def test_fix_unhugs_py3_only(src, expected):
assert _fix_src(src, py35_plus=False) == expected
def test_main_trivial(): def test_main_trivial():
assert main(()) == 0 assert main(()) == 0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment