Kaydet (Commit) 703d81e2 authored tarafından Patrick Maupin's avatar Patrick Maupin Kaydeden (comit) GitHub

Merge PR #60 -- clean up and bug fixes for new features

- Clean up code providing new functionality
- (Thanks  for the initial functionality, Ryan) 👍 
- run extended test cases via astor.rtrip
- add regression tests for new features based on rtrip failures
- fix failing regressions
......@@ -76,6 +76,7 @@ class Delimit(object):
"""
discard = False
coalesce = False
def __init__(self, tree, *args):
""" use write instead of using result directly
......@@ -106,11 +107,14 @@ class Delimit(object):
return self
def __exit__(self, *exc_info):
result = self.result
start = self.index - 1
if self.discard:
self.result[self.index - 1] = ''
result[start] = ''
else:
self.result.append(self.closing)
result.append(self.closing)
if self.coalesce:
result[start:] = [''.join(result[start:])]
class SourceGenerator(ExplicitNodeVisitor):
"""This visitor is able to transform a well formed syntax tree into Python
......@@ -248,15 +252,13 @@ class SourceGenerator(ExplicitNodeVisitor):
node.value)
def visit_AnnAssign(self, node):
set_precedence(node, node.value, node.target)
self.newline(node)
if not node.simple and isinstance(node.target, ast.Name):
self.write('(', node.target, ')')
else:
self.write(node.target)
self.write(': ', node.annotation)
if node.value is not None:
self.write(' = ', node.value)
set_precedence(node, node.target, node.annotation)
set_precedence(Precedence.Comma, node.value)
need_parens = isinstance(node.target, ast.Name) and not node.simple
begin = '(' if need_parens else ''
end = ')' if need_parens else ''
self.statement(node, begin, node.target, end, ': ', node.annotation)
self.conditional_write(' = ', node.value)
def visit_ImportFrom(self, node):
self.statement(node, 'from ', node.level * '.',
......@@ -497,47 +499,44 @@ class SourceGenerator(ExplicitNodeVisitor):
result.pop()
result.append(self.pretty_string(node.s, embedded, result))
def visit_JoinedStr(self, node, nested=False):
if not nested:
self.write("f'")
def visit_JoinedStr(self, node,
# constants
new=sys.version_info >= (3, 0)):
def write_string(s):
if sys.version_info >= (3, 0):
escaped = s.encode('unicode-escape').decode()
else:
escaped = s.encode('string-escape')
self.write(escaped.replace("'", "\\'"))
for value in node.values:
if isinstance(value, ast.Str):
write_string(value.s)
elif isinstance(value, ast.FormattedValue):
self.write('{')
self.visit(value.value)
if value.conversion != -1:
self.write('!%s' % chr(value.conversion))
if value.format_spec is not None:
self.write(':')
# Either a nested Str or JoinedStr can be here.
if isinstance(value.format_spec, ast.Str):
write_string(value.format_spec.s)
elif isinstance(value.format_spec, ast.JoinedStr):
self.visit_JoinedStr(value.format_spec, nested=True)
def recurse(node):
for value in node.values:
if isinstance(value, ast.Str):
if new:
encoded = value.s.encode('unicode-escape').decode()
else:
kind = type(value).__name__
assert False, 'Invalid node %s inside JoinedStr' % kind
self.write('}')
encoded = value.s.encode('string-escape')
self.write(encoded)
elif isinstance(value, ast.FormattedValue):
with self.delimit('{}'):
self.visit(value.value)
if value.conversion != -1:
self.write('!%s' % chr(value.conversion))
if value.format_spec is not None:
self.write(':')
recurse(value.format_spec)
else:
kind = type(value).__name__
assert False, 'Invalid node %s inside JoinedStr' % kind
with self.delimit(("f'", "'")) as delimiters:
recurse(node)
delimiters.coalesce = True
s = self.result.pop()
squotes = s.count("'") - 2
if squotes:
dquotes = s.count('"')
s = s[2:-1]
if dquotes < squotes:
s = 'f"%s"' % s.replace('"', r'\"')
else:
kind = type(value).__name__
assert False, 'Invalid node %s inside JoinedStr' % kind
if not nested:
self.write("'")
def visit_FormattedValue(self, node):
self.visit_JoinedStr(ast.JoinedStr(values=[node]))
s = "f'%s'" % s.replace("'", r"\'")
self.write(s)
def visit_Bytes(self, node):
self.write(repr(node.s))
......@@ -727,8 +726,7 @@ class SourceGenerator(ExplicitNodeVisitor):
def visit_comprehension(self, node):
set_precedence(node, node.iter, *node.ifs)
set_precedence(Precedence.comprehension_target, node.target)
if getattr(node, 'is_async', False):
self.write(' async')
self.write(' for ', node.target, ' in ', node.iter)
stmt = ' async for ' if self.get_is_async(node) else ' for '
self.write(stmt, node.target, ' in ', node.iter)
for if_ in node.ifs:
self.write(' if ', if_)
......@@ -370,14 +370,36 @@ class CodegenTestCase(unittest.TestCase):
f'{int(x)}'
f'a{b:c}d'
f'a{b!s:c{d}e}f'
f'""'
f'"\\''
"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6))
source = """
a_really_long_line_will_probably_break_things = (
f'a{b!s:c{d}e}fghijka{b!s:c{d}e}a{b!s:c{d}e}a{b!s:c{d}e}')
"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6))
source = """
return f"functools.{qualname}({', '.join(args)})"
"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6))
def test_annassign(self):
source = """
a: int
(b): Tuple[int, str, ...]
(a): int
a.b: int
(a.b): int
b: Tuple[int, str, ...]
c.d[e].f: Any
q: 3 = (1, 2, 3)
t: Tuple[int, ...] = (1, 2, 3)
some_list: List[int] = []
(a): int = 0
a:int = 0
(a.b): int = 0
a.b: int = 0
"""
self.assertAstEqualIfAtLeastVersion(source, (3, 6))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment