Kaydet (Commit) 8028caa9 authored tarafından Patrick Maupin's avatar Patrick Maupin

Fixes and testcases for issue #27

üst 4043b32e
...@@ -126,13 +126,13 @@ class SourceGenerator(ExplicitNodeVisitor): ...@@ -126,13 +126,13 @@ class SourceGenerator(ExplicitNodeVisitor):
loop_args(node.args, node.defaults) loop_args(node.args, node.defaults)
self.conditional_write(write_comma, '*', node.vararg) self.conditional_write(write_comma, '*', node.vararg)
self.conditional_write(write_comma, '**', node.kwarg)
kwonlyargs = getattr(node, 'kwonlyargs', None) kwonlyargs = getattr(node, 'kwonlyargs', None)
if kwonlyargs: if kwonlyargs:
if node.vararg is None: if node.vararg is None:
self.write(write_comma, '*') self.write(write_comma, '*')
loop_args(kwonlyargs, node.kw_defaults) loop_args(kwonlyargs, node.kw_defaults)
self.conditional_write(write_comma, '**', node.kwarg)
def statement(self, node, *params, **kw): def statement(self, node, *params, **kw):
self.newline(node) self.newline(node)
...@@ -168,7 +168,7 @@ class SourceGenerator(ExplicitNodeVisitor): ...@@ -168,7 +168,7 @@ class SourceGenerator(ExplicitNodeVisitor):
self.statement(node, 'from ', node.level * '.', self.statement(node, 'from ', node.level * '.',
node.module, ' import ') node.module, ' import ')
else: else:
self.statement(node, 'from ', node.level * '. import ') self.statement(node, 'from ', node.level * '.', ' import ')
self.comma_list(node.names) self.comma_list(node.names)
def visit_Import(self, node): def visit_Import(self, node):
...@@ -307,10 +307,10 @@ class SourceGenerator(ExplicitNodeVisitor): ...@@ -307,10 +307,10 @@ class SourceGenerator(ExplicitNodeVisitor):
self.body(node.body) self.body(node.body)
for handler in node.handlers: for handler in node.handlers:
self.visit(handler) self.visit(handler)
self.else_body(node.orelse)
if node.finalbody: if node.finalbody:
self.statement(node, 'finally:') self.statement(node, 'finally:')
self.body(node.finalbody) self.body(node.finalbody)
self.else_body(node.orelse)
def visit_ExceptHandler(self, node): def visit_ExceptHandler(self, node):
self.statement(node, 'except') self.statement(node, 'except')
...@@ -367,6 +367,9 @@ class SourceGenerator(ExplicitNodeVisitor): ...@@ -367,6 +367,9 @@ class SourceGenerator(ExplicitNodeVisitor):
# Expressions # Expressions
def visit_Attribute(self, node): def visit_Attribute(self, node):
if isinstance(node.value, ast.Num):
self.write('(', node.value, ')', '.', node.attr)
else:
self.write(node.value, '.', node.attr) self.write(node.value, '.', node.attr)
def visit_Call(self, node): def visit_Call(self, node):
...@@ -473,14 +476,19 @@ class SourceGenerator(ExplicitNodeVisitor): ...@@ -473,14 +476,19 @@ class SourceGenerator(ExplicitNodeVisitor):
def visit_ExtSlice(self, node): def visit_ExtSlice(self, node):
self.comma_list(node.dims, len(node.dims) == 1) self.comma_list(node.dims, len(node.dims) == 1)
def visit_Yield(self, node): def visit_Yield(self, node, dofrom=False):
is_stmt = self.new_lines
if not is_stmt:
self.write('(')
self.write('yield') self.write('yield')
self.conditional_write(' from' if dofrom else None)
self.conditional_write(' ', node.value) self.conditional_write(' ', node.value)
if not is_stmt:
self.write(')')
# new for Python 3.3 # new for Python 3.3
def visit_YieldFrom(self, node): def visit_YieldFrom(self, node):
self.write('yield from ') self.visit_Yield(node, True)
self.visit(node.value)
# new for Python 3.5 # new for Python 3.5
def visit_Await(self, node): def visit_Await(self, node):
......
...@@ -23,10 +23,12 @@ class CodegenTestCase(unittest.TestCase): ...@@ -23,10 +23,12 @@ class CodegenTestCase(unittest.TestCase):
def assertAstSourceEqual(self, source): def assertAstSourceEqual(self, source):
self.assertEqual(astor.to_source(ast.parse(source)), source) self.assertEqual(astor.to_source(ast.parse(source)), source)
def assertAstSourceEqualIfAtLeastVersion(self, source, version_tuple): def assertAstSourceEqualIfAtLeastVersion(self, source, version_tuple, version2=None):
if version2 is None:
version2 = version_tuple[0], version_tuple[1] - 1
if sys.version_info >= version_tuple: if sys.version_info >= version_tuple:
self.assertAstSourceEqual(source) self.assertAstSourceEqual(source)
else: elif sys.version_info <= version2:
self.assertRaises(SyntaxError, ast.parse, source) self.assertRaises(SyntaxError, ast.parse, source)
def test_imports(self): def test_imports(self):
...@@ -36,6 +38,8 @@ class CodegenTestCase(unittest.TestCase): ...@@ -36,6 +38,8 @@ class CodegenTestCase(unittest.TestCase):
self.assertAstSourceEqual(source) self.assertAstSourceEqual(source)
source = "from math import floor" source = "from math import floor"
self.assertAstSourceEqual(source) self.assertAstSourceEqual(source)
source = "from .. import foobar"
self.assertAstSourceEqual(source)
def test_dictionary_literals(self): def test_dictionary_literals(self):
source = "{'a': 1, 'b': 2}" source = "{'a': 1, 'b': 2}"
...@@ -58,6 +62,20 @@ class CodegenTestCase(unittest.TestCase): ...@@ -58,6 +62,20 @@ class CodegenTestCase(unittest.TestCase):
sys.stdout.write(exc)""") sys.stdout.write(exc)""")
self.assertAstSourceEqual(source) self.assertAstSourceEqual(source)
source = textwrap.dedent("""\
try:
'spam'[10]
except IndexError as exc:
sys.stdout.write(exc)
else:
pass
finally:
pass""")
# This is interesting -- the latest 2.7 compiler seems to
# handle this OK, but creates an AST with nested try/finally
# and try/except, so the source code doesn't match.
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 4), (1, 0))
def test_del_statement(self): def test_del_statement(self):
source = "del l[0]" source = "del l[0]"
self.assertAstSourceEqual(source) self.assertAstSourceEqual(source)
...@@ -82,6 +100,11 @@ class CodegenTestCase(unittest.TestCase): ...@@ -82,6 +100,11 @@ class CodegenTestCase(unittest.TestCase):
if isinstance(n, ast.arguments)][0] if isinstance(n, ast.arguments)][0]
self.assertEqual(astor.to_source(arguments_node), self.assertEqual(astor.to_source(arguments_node),
"a1, a2, b1=j, b2='123', b3={}, b4=[]") "a1, a2, b1=j, b2='123', b3={}, b4=[]")
source = textwrap.dedent("""\
def call(*popenargs, timeout=None, **kwargs):
pass""")
# Probably also works on < 3.4, but doesn't work on 2.7...
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 4), (2, 7))
def test_matrix_multiplication(self): def test_matrix_multiplication(self):
for source in ("(a @ b)", "a @= b"): for source in ("(a @ b)", "a @= b"):
...@@ -113,6 +136,19 @@ class CodegenTestCase(unittest.TestCase): ...@@ -113,6 +136,19 @@ class CodegenTestCase(unittest.TestCase):
pass""") pass""")
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 0)) self.assertAstSourceEqualIfAtLeastVersion(source, (3, 0))
def test_yield(self):
source = "yield"
self.assertAstSourceEqual(source)
source = textwrap.dedent("""\
def dummy():
yield""")
self.assertAstSourceEqual(source)
source = "foo((yield bar))"
self.assertAstSourceEqual(source)
source = "return (yield from sam())"
# Probably also works on < 3.4, but doesn't work on 2.7...
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 4), (2, 7))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment