Kaydet (Commit) c504f6e8 authored tarafından Berker Peksag's avatar Berker Peksag

Merge pull request #35 from berkerpeksag/code_gen_cleanup

Code gen cleanup
This diff is collapsed.
......@@ -18,17 +18,49 @@ except ImportError:
import astor
def canonical(srctxt):
return textwrap.dedent(srctxt).strip()
class CodegenTestCase(unittest.TestCase):
def assertAstSourceEqual(self, source):
self.assertEqual(astor.to_source(ast.parse(source)), source)
def assertAstEqual(self, srctxt):
"""This asserts that the reconstituted source
code can be compiled into the exact same AST
as the original source code.
"""
srctxt = canonical(srctxt)
srcast = ast.parse(srctxt)
dsttxt = astor.to_source(srcast)
dstast = ast.parse(dsttxt)
srcdmp = astor.dump_tree(srcast)
dstdmp = astor.dump_tree(dstast)
self.assertEqual(dstdmp, srcdmp)
def assertAstEqualIfAtLeastVersion(self, source, min_should_work,
max_should_error=None):
if max_should_error is None:
max_should_error = min_should_work[0], min_should_work[1] - 1
if sys.version_info >= min_should_work:
self.assertAstEqual(source)
elif sys.version_info <= max_should_error:
self.assertRaises(SyntaxError, ast.parse, source)
def assertAstSourceEqualIfAtLeastVersion(self, source, version_tuple, version2=None):
if version2 is None:
version2 = version_tuple[0], version_tuple[1] - 1
if sys.version_info >= version_tuple:
def assertAstSourceEqual(self, srctxt):
"""This asserts that the reconstituted source
code is identical to the original source code.
This is a much stronger statement than assertAstEqual,
which may not always be appropriate.
"""
srctxt = canonical(srctxt)
self.assertEqual(astor.to_source(ast.parse(srctxt)), srctxt)
def assertAstSourceEqualIfAtLeastVersion(self, source, min_should_work,
max_should_error=None):
if max_should_error is None:
max_should_error = min_should_work[0], min_should_work[1] - 1
if sys.version_info >= min_should_work:
self.assertAstSourceEqual(source)
elif sys.version_info <= version2:
elif sys.version_info <= max_should_error:
self.assertRaises(SyntaxError, ast.parse, source)
def test_imports(self):
......@@ -40,6 +72,8 @@ class CodegenTestCase(unittest.TestCase):
self.assertAstSourceEqual(source)
source = "from .. import foobar"
self.assertAstSourceEqual(source)
source = "from ..aaa import foo, bar as bar2"
self.assertAstSourceEqual(source)
def test_dictionary_literals(self):
source = "{'a': 1, 'b': 2}"
......@@ -48,21 +82,21 @@ class CodegenTestCase(unittest.TestCase):
self.assertAstSourceEqual(another_source)
def test_try_expect(self):
source = textwrap.dedent("""\
source = """
try:
'spam'[10]
except IndexError:
pass""")
self.assertAstSourceEqual(source)
pass"""
self.assertAstEqual(source)
source = textwrap.dedent("""\
source = """
try:
'spam'[10]
except IndexError as exc:
sys.stdout.write(exc)""")
self.assertAstSourceEqual(source)
sys.stdout.write(exc)"""
self.assertAstEqual(source)
source = textwrap.dedent("""\
source = """
try:
'spam'[10]
except IndexError as exc:
......@@ -70,11 +104,8 @@ class CodegenTestCase(unittest.TestCase):
else:
pass
finally:
pass""")
# This is interesting -- the latest 2.7 compiler seems to
# handle this OK, but creates an AST with nested try/finally
# and try/except, so the source code doesn't match.
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 4), (1, 0))
pass"""
self.assertAstEqual(source)
def test_del_statement(self):
source = "del l[0]"
......@@ -83,16 +114,17 @@ class CodegenTestCase(unittest.TestCase):
self.assertAstSourceEqual(source)
def test_arguments(self):
source = textwrap.dedent("""\
source = """
j = [1, 2, 3]
def test(a1, a2, b1=j, b2='123', b3={}, b4=[]):
pass""")
pass"""
self.assertAstSourceEqual(source)
def test_pass_arguments_node(self):
source = textwrap.dedent("""\
source = canonical("""
j = [1, 2, 3]
def test(a1, a2, b1=j, b2='123', b3={}, b4=[]):
pass""")
root_node = ast.parse(source)
......@@ -100,9 +132,9 @@ class CodegenTestCase(unittest.TestCase):
if isinstance(n, ast.arguments)][0]
self.assertEqual(astor.to_source(arguments_node),
"a1, a2, b1=j, b2='123', b3={}, b4=[]")
source = textwrap.dedent("""\
source = """
def call(*popenargs, timeout=None, **kwargs):
pass""")
pass"""
# Probably also works on < 3.4, but doesn't work on 2.7...
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 4), (2, 7))
......@@ -111,45 +143,67 @@ class CodegenTestCase(unittest.TestCase):
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 5))
def test_multiple_call_unpackings(self):
source = textwrap.dedent("""\
my_function(*[1], *[2], **{'three': 3}, **{'four': 'four'})""")
source = """
my_function(*[1], *[2], **{'three': 3}, **{'four': 'four'})"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 5))
def test_right_hand_side_dictionary_unpacking(self):
source = textwrap.dedent("""\
our_dict = {'a': 1, **{'b': 2, 'c': 3}}""")
source = """
our_dict = {'a': 1, **{'b': 2, 'c': 3}}"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 5))
def test_async_def_with_for(self):
source = textwrap.dedent("""\
source = """
async def read_data(db):
async with connect(db) as db_cxn:
data = await db_cxn.fetch('SELECT foo FROM bar;')
async for datum in data:
if quux(datum):
return datum""")
return datum"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 5))
def test_class_definition_with_starbases_and_kwargs(self):
source = textwrap.dedent("""\
source = """
class TreeFactory(*[FactoryMixin, TreeBase], **{'metaclass': Foo}):
pass""")
pass"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 0))
def test_yield(self):
source = "(yield)"
self.assertAstSourceEqual(source)
source = textwrap.dedent("""\
source = "yield"
self.assertAstEqual(source)
source = """
def dummy():
(yield)""")
self.assertAstSourceEqual(source)
yield"""
self.assertAstEqual(source)
source = "foo((yield bar))"
self.assertAstSourceEqual(source)
self.assertAstEqual(source)
source = "(yield bar)()"
self.assertAstSourceEqual(source)
self.assertAstEqual(source)
source = "return (yield from sam())"
# Probably also works on < 3.4, but doesn't work on 2.7...
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 4), (2, 7))
self.assertAstEqualIfAtLeastVersion(source, (3, 3))
def test_with(self):
source = """
with foo:
pass
"""
self.assertAstSourceEqual(source)
source = """
with foo as bar:
pass
"""
self.assertAstSourceEqual(source)
source = """
with foo as bar, mary, william as bill:
pass
"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 3), (1, 0))
def test_inf(self):
source = """
(1e1000) + (-1e1000) + (1e1000j) + (-1e1000j)
"""
self.assertAstEqual(source)
if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment