Kaydet (Commit) 8028caa9 authored tarafından Patrick Maupin's avatar Patrick Maupin

Fixes and testcases for issue #27

üst 4043b32e
......@@ -126,13 +126,13 @@ class SourceGenerator(ExplicitNodeVisitor):
loop_args(node.args, node.defaults)
self.conditional_write(write_comma, '*', node.vararg)
self.conditional_write(write_comma, '**', node.kwarg)
kwonlyargs = getattr(node, 'kwonlyargs', None)
if kwonlyargs:
if node.vararg is None:
self.write(write_comma, '*')
loop_args(kwonlyargs, node.kw_defaults)
self.conditional_write(write_comma, '**', node.kwarg)
def statement(self, node, *params, **kw):
self.newline(node)
......@@ -168,7 +168,7 @@ class SourceGenerator(ExplicitNodeVisitor):
self.statement(node, 'from ', node.level * '.',
node.module, ' import ')
else:
self.statement(node, 'from ', node.level * '. import ')
self.statement(node, 'from ', node.level * '.', ' import ')
self.comma_list(node.names)
def visit_Import(self, node):
......@@ -307,10 +307,10 @@ class SourceGenerator(ExplicitNodeVisitor):
self.body(node.body)
for handler in node.handlers:
self.visit(handler)
self.else_body(node.orelse)
if node.finalbody:
self.statement(node, 'finally:')
self.body(node.finalbody)
self.else_body(node.orelse)
def visit_ExceptHandler(self, node):
self.statement(node, 'except')
......@@ -367,7 +367,10 @@ class SourceGenerator(ExplicitNodeVisitor):
# Expressions
def visit_Attribute(self, node):
self.write(node.value, '.', node.attr)
if isinstance(node.value, ast.Num):
self.write('(', node.value, ')', '.', node.attr)
else:
self.write(node.value, '.', node.attr)
def visit_Call(self, node):
want_comma = []
......@@ -473,14 +476,19 @@ class SourceGenerator(ExplicitNodeVisitor):
def visit_ExtSlice(self, node):
self.comma_list(node.dims, len(node.dims) == 1)
def visit_Yield(self, node):
def visit_Yield(self, node, dofrom=False):
is_stmt = self.new_lines
if not is_stmt:
self.write('(')
self.write('yield')
self.conditional_write(' from' if dofrom else None)
self.conditional_write(' ', node.value)
if not is_stmt:
self.write(')')
# new for Python 3.3
def visit_YieldFrom(self, node):
self.write('yield from ')
self.visit(node.value)
self.visit_Yield(node, True)
# new for Python 3.5
def visit_Await(self, node):
......
......@@ -23,10 +23,12 @@ class CodegenTestCase(unittest.TestCase):
def assertAstSourceEqual(self, source):
self.assertEqual(astor.to_source(ast.parse(source)), source)
def assertAstSourceEqualIfAtLeastVersion(self, source, version_tuple):
def assertAstSourceEqualIfAtLeastVersion(self, source, version_tuple, version2=None):
if version2 is None:
version2 = version_tuple[0], version_tuple[1] - 1
if sys.version_info >= version_tuple:
self.assertAstSourceEqual(source)
else:
elif sys.version_info <= version2:
self.assertRaises(SyntaxError, ast.parse, source)
def test_imports(self):
......@@ -36,6 +38,8 @@ class CodegenTestCase(unittest.TestCase):
self.assertAstSourceEqual(source)
source = "from math import floor"
self.assertAstSourceEqual(source)
source = "from .. import foobar"
self.assertAstSourceEqual(source)
def test_dictionary_literals(self):
source = "{'a': 1, 'b': 2}"
......@@ -58,6 +62,20 @@ class CodegenTestCase(unittest.TestCase):
sys.stdout.write(exc)""")
self.assertAstSourceEqual(source)
source = textwrap.dedent("""\
try:
'spam'[10]
except IndexError as exc:
sys.stdout.write(exc)
else:
pass
finally:
pass""")
# This is interesting -- the latest 2.7 compiler seems to
# handle this OK, but creates an AST with nested try/finally
# and try/except, so the source code doesn't match.
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 4), (1, 0))
def test_del_statement(self):
source = "del l[0]"
self.assertAstSourceEqual(source)
......@@ -82,6 +100,11 @@ class CodegenTestCase(unittest.TestCase):
if isinstance(n, ast.arguments)][0]
self.assertEqual(astor.to_source(arguments_node),
"a1, a2, b1=j, b2='123', b3={}, b4=[]")
source = textwrap.dedent("""\
def call(*popenargs, timeout=None, **kwargs):
pass""")
# Probably also works on < 3.4, but doesn't work on 2.7...
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 4), (2, 7))
def test_matrix_multiplication(self):
for source in ("(a @ b)", "a @= b"):
......@@ -113,6 +136,19 @@ class CodegenTestCase(unittest.TestCase):
pass""")
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 0))
def test_yield(self):
source = "yield"
self.assertAstSourceEqual(source)
source = textwrap.dedent("""\
def dummy():
yield""")
self.assertAstSourceEqual(source)
source = "foo((yield bar))"
self.assertAstSourceEqual(source)
source = "return (yield from sam())"
# Probably also works on < 3.4, but doesn't work on 2.7...
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 4), (2, 7))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment