Kaydet (Commit) acc8ec87 authored tarafından Patrick Maupin's avatar Patrick Maupin Kaydeden (comit) GitHub

Refactor tests (#74)

* Refactor tests
 - Make it easier to test against other packages

* Split out optional tests and use assertXXX naming conventions

* Fix Python 3 test intrapackage import

* Fix Python 3 test intrapackage import
üst a957d588
import ast
try:
import unittest2 as unittest
except ImportError:
import unittest
import test_code_gen
import astunparse
class MyTests(test_code_gen.CodegenTestCase):
to_source = staticmethod(astunparse.unparse)
# Just see if it'll do anything good at all
assertSrcRoundtrips = test_code_gen.CodegenTestCase.assertAstRoundtrips
# Don't look for exact comparison; see if ASTs match
def assertSrcEqual(self, src1, src2):
self.assertAstEqual(ast.parse(src1), ast.parse(src2))
if __name__ == '__main__':
unittest.main()
......@@ -10,7 +10,6 @@ Copyright (c) 2015 Patrick Maupin
import ast
import sys
import textwrap
import warnings
try:
import unittest2 as unittest
......@@ -24,65 +23,76 @@ def canonical(srctxt):
return textwrap.dedent(srctxt).strip()
class CodegenTestCase(unittest.TestCase):
class Comparisons(object):
def assertAstEqual(self, srctxt):
to_source = staticmethod(astor.to_source)
assertSrcEqual = unittest.TestCase.assertEqual
def assertAstEqual(self, ast1, ast2):
dmp1 = astor.dump_tree(ast1)
dmp2 = astor.dump_tree(ast2)
self.assertEqual(dmp1, dmp2)
def assertAstRoundtrips(self, srctxt):
"""This asserts that the reconstituted source
code can be compiled into the exact same AST
as the original source code.
"""
srctxt = canonical(srctxt)
srcast = ast.parse(srctxt)
dsttxt = astor.to_source(srcast)
dsttxt = self.to_source(srcast)
dstast = ast.parse(dsttxt)
srcdmp = astor.dump_tree(srcast)
dstdmp = astor.dump_tree(dstast)
self.assertEqual(dstdmp, srcdmp)
self.assertAstEqual(srcast, dstast)
def assertAstEqualIfAtLeastVersion(self, source, min_should_work,
max_should_error=None):
def assertAstRoundtripsGtVer(self, source, min_should_work,
max_should_error=None):
if max_should_error is None:
max_should_error = min_should_work[0], min_should_work[1] - 1
if sys.version_info >= min_should_work:
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
elif sys.version_info <= max_should_error:
self.assertRaises(SyntaxError, ast.parse, source)
def assertAstSourceEqual(self, srctxt):
def assertSrcRoundtrips(self, srctxt):
"""This asserts that the reconstituted source
code is identical to the original source code.
This is a much stronger statement than assertAstEqual,
This is a much stronger statement than assertAstRoundtrips,
which may not always be appropriate.
"""
srctxt = canonical(srctxt)
self.assertEqual(astor.to_source(ast.parse(srctxt)).rstrip(), srctxt)
self.assertSrcEqual(self.to_source(ast.parse(srctxt)).rstrip(),
srctxt)
def assertAstSourceEqualIfAtLeastVersion(self, source, min_should_work,
max_should_error=None):
def assertSrcRoundtripsGtVer(self, source, min_should_work,
max_should_error=None):
if max_should_error is None:
max_should_error = min_should_work[0], min_should_work[1] - 1
if sys.version_info >= min_should_work:
self.assertAstSourceEqual(source)
self.assertSrcRoundtrips(source)
elif sys.version_info <= max_should_error:
self.assertRaises(SyntaxError, ast.parse, source)
class CodegenTestCase(unittest.TestCase, Comparisons):
def test_imports(self):
source = "import ast"
self.assertAstSourceEqual(source)
self.assertSrcRoundtrips(source)
source = "import operator as op"
self.assertAstSourceEqual(source)
self.assertSrcRoundtrips(source)
source = "from math import floor"
self.assertAstSourceEqual(source)
self.assertSrcRoundtrips(source)
source = "from .. import foobar"
self.assertAstSourceEqual(source)
self.assertSrcRoundtrips(source)
source = "from ..aaa import foo, bar as bar2"
self.assertAstSourceEqual(source)
self.assertSrcRoundtrips(source)
def test_dictionary_literals(self):
source = "{'a': 1, 'b': 2}"
self.assertAstSourceEqual(source)
self.assertSrcRoundtrips(source)
another_source = "{'nested': ['structures', {'are': 'important'}]}"
self.assertAstSourceEqual(another_source)
self.assertSrcRoundtrips(another_source)
def test_try_expect(self):
source = """
......@@ -90,14 +100,14 @@ class CodegenTestCase(unittest.TestCase):
'spam'[10]
except IndexError:
pass"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
try:
'spam'[10]
except IndexError as exc:
sys.stdout.write(exc)"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
try:
......@@ -108,7 +118,7 @@ class CodegenTestCase(unittest.TestCase):
pass
finally:
pass"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
try:
size = len(iterable)
......@@ -117,13 +127,13 @@ class CodegenTestCase(unittest.TestCase):
else:
if n >= size:
return sorted(iterable, key=key, reverse=True)[:n]"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
def test_del_statement(self):
source = "del l[0]"
self.assertAstSourceEqual(source)
self.assertSrcRoundtrips(source)
source = "del obj.x"
self.assertAstSourceEqual(source)
self.assertSrcRoundtrips(source)
def test_arguments(self):
source = """
......@@ -132,7 +142,7 @@ class CodegenTestCase(unittest.TestCase):
def test(a1, a2, b1=j, b2='123', b3={}, b4=[]):
pass"""
self.assertAstSourceEqual(source)
self.assertSrcRoundtrips(source)
def test_pass_arguments_node(self):
source = canonical("""
......@@ -144,27 +154,27 @@ class CodegenTestCase(unittest.TestCase):
root_node = ast.parse(source)
arguments_node = [n for n in ast.walk(root_node)
if isinstance(n, ast.arguments)][0]
self.assertEqual(astor.to_source(arguments_node).rstrip(),
self.assertEqual(self.to_source(arguments_node).rstrip(),
"a1, a2, b1=j, b2='123', b3={}, b4=[]")
source = """
def call(*popenargs, timeout=None, **kwargs):
pass"""
# Probably also works on < 3.4, but doesn't work on 2.7...
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 4), (2, 7))
self.assertSrcRoundtripsGtVer(source, (3, 4), (2, 7))
def test_matrix_multiplication(self):
for source in ("(a @ b)", "a @= b"):
self.assertAstEqualIfAtLeastVersion(source, (3, 5))
self.assertAstRoundtripsGtVer(source, (3, 5))
def test_multiple_call_unpackings(self):
source = """
my_function(*[1], *[2], **{'three': 3}, **{'four': 'four'})"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 5))
self.assertSrcRoundtripsGtVer(source, (3, 5))
def test_right_hand_side_dictionary_unpacking(self):
source = """
our_dict = {'a': 1, **{'b': 2, 'c': 3}}"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 5))
self.assertSrcRoundtripsGtVer(source, (3, 5))
def test_async_def_with_for(self):
source = """
......@@ -174,118 +184,118 @@ class CodegenTestCase(unittest.TestCase):
async for datum in data:
if quux(datum):
return datum"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 5))
self.assertSrcRoundtripsGtVer(source, (3, 5))
def test_double_await(self):
source = """
async def foo():
return await (await bar())"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 5))
self.assertSrcRoundtripsGtVer(source, (3, 5))
def test_class_definition_with_starbases_and_kwargs(self):
source = """
class TreeFactory(*[FactoryMixin, TreeBase], **{'metaclass': Foo}):
pass"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 0))
self.assertSrcRoundtripsGtVer(source, (3, 0))
def test_yield(self):
source = "yield"
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
def dummy():
yield"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = "foo((yield bar))"
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = "(yield bar)()"
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = "return (yield 1)"
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = "return (yield from sam())"
self.assertAstEqualIfAtLeastVersion(source, (3, 3))
self.assertAstRoundtripsGtVer(source, (3, 3))
source = "((yield a) for b in c)"
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = "[(yield)]"
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = "if (yield): pass"
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = "if (yield from foo): pass"
self.assertAstEqualIfAtLeastVersion(source, (3, 3))
self.assertAstRoundtripsGtVer(source, (3, 3))
source = "(yield from (a, b))"
self.assertAstEqualIfAtLeastVersion(source, (3, 3))
self.assertAstRoundtripsGtVer(source, (3, 3))
source = "yield from sam()"
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 3))
self.assertSrcRoundtripsGtVer(source, (3, 3))
def test_with(self):
source = """
with foo:
pass
"""
self.assertAstSourceEqual(source)
self.assertSrcRoundtrips(source)
source = """
with foo as bar:
pass
"""
self.assertAstSourceEqual(source)
self.assertSrcRoundtrips(source)
source = """
with foo as bar, mary, william as bill:
pass
"""
self.assertAstEqualIfAtLeastVersion(source, (2, 7))
self.assertAstRoundtripsGtVer(source, (2, 7))
def test_inf(self):
source = """
(1e1000) + (-1e1000) + (1e1000j) + (-1e1000j)
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
def test_unary(self):
source = """
-(1) + ~(2) + +(3)
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
def test_pow(self):
source = """
(-2) ** (-3)
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
(+2) ** (+3)
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
2 ** 3 ** 4
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
-2 ** -3
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
-2 ** -3 ** -4
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
-((-1) ** other._sign)
(-1) ** self._sign
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
def test_comprehension(self):
source = """
((x,y) for x,y in zip(a,b) if a)
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
fields = [(a, _format(b)) for (a, b) in iter_fields(node) if a]
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
ra = np.fromiter(((i * 3, i * 2) for i in range(10)),
n, dtype='i8,f8')
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
def test_async_comprehension(self):
source = """
......@@ -295,52 +305,52 @@ class CodegenTestCase(unittest.TestCase):
(await x async for x in y)
{i for i in b async for i in a if await i for b in i}
"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6))
self.assertSrcRoundtripsGtVer(source, (3, 6))
def test_tuple_corner_cases(self):
source = """
a = ()
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
assert (a, b), (c, d)
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
return UUID(fields=(time_low, time_mid, time_hi_version,
clock_seq_hi_variant, clock_seq_low, node), version=1)
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
raise(os.error, ('multiple errors:', errors))
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
exec(expr, global_dict, local_dict)
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
source = """
with (a, b) as (c, d):
pass
"""
self.assertAstEqual(source)
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
self.assertAstRoundtrips(source)
source = """
with (a, b) as (c, d), (e,f) as (h,g):
pass
"""
self.assertAstEqualIfAtLeastVersion(source, (2, 7))
self.assertAstRoundtripsGtVer(source, (2, 7))
source = """
Pxx[..., (0,-1)] = xft[..., (0,-1)]**2
"""
self.assertAstEqualIfAtLeastVersion(source, (2, 7))
self.assertAstRoundtripsGtVer(source, (2, 7))
source = """
responses = {
v: (v.phrase, v.description)
for v in HTTPStatus.__members__.values()
}
"""
self.assertAstEqualIfAtLeastVersion(source, (2, 7))
self.assertAstRoundtripsGtVer(source, (2, 7))
def test_output_formatting(self):
source = """
......@@ -351,7 +361,7 @@ class CodegenTestCase(unittest.TestCase):
'ZERO_OR_MORE']
""" # NOQA
self.maxDiff = 2000
self.assertAstSourceEqual(source)
self.assertSrcRoundtrips(source)
def test_elif(self):
source = """
......@@ -364,7 +374,7 @@ class CodegenTestCase(unittest.TestCase):
else:
g
"""
self.assertAstSourceEqual(source)
self.assertSrcRoundtrips(source)
def test_fstrings(self):
source = """
......@@ -376,16 +386,16 @@ class CodegenTestCase(unittest.TestCase):
x = f'""'
x = f'"\\''
"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6))
self.assertSrcRoundtripsGtVer(source, (3, 6))
source = """
a_really_long_line_will_probably_break_things = (
f'a{b!s:c{d}e}fghijka{b!s:c{d}e}a{b!s:c{d}e}a{b!s:c{d}e}')
"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6))
self.assertSrcRoundtripsGtVer(source, (3, 6))
source = """
return f"functools.{qualname}({', '.join(args)})"
"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6))
self.assertSrcRoundtripsGtVer(source, (3, 6))
def test_annassign(self):
source = """
......@@ -403,37 +413,15 @@ class CodegenTestCase(unittest.TestCase):
(a.b): int = 0
a.b: int = 0
"""
self.assertAstEqualIfAtLeastVersion(source, (3, 6))
self.assertAstRoundtripsGtVer(source, (3, 6))
def test_compile_types(self):
code = '(a + b + c) * (d + e + f)\n'
for mode in 'exec eval single'.split():
srcast = compile(code, 'dummy', mode, ast.PyCF_ONLY_AST)
dsttxt = astor.to_source(srcast)
dsttxt = self.to_source(srcast)
if code.strip() != dsttxt.strip():
self.assertEqual('(%s)' % code.strip(), dsttxt.strip())
def test_deprecation(self):
with warnings.catch_warnings(record=True) as w:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
ast1 = astor.code_to_ast.parse_file(__file__)
src1 = astor.to_source(ast1)
ast2 = astor.parsefile(__file__)
src2 = astor.codegen.to_source(ast2)
self.assertEqual(len(w), 2)
w = [warnings.formatwarning(x.message, x.category,
x.filename, x.lineno) for x in w]
w = [x.rsplit(':', 1)[-1].strip() for x in w]
self.assertEqual(w[0], 'astor.parsefile is deprecated. '
'Please use astor.code_to_ast.parse_file.\n'
' ast2 = astor.parsefile(__file__)')
self.assertEqual(w[1], 'astor.codegen is deprecated. '
'Please use astor.code_gen.\n'
' src2 = astor.codegen.to_source(ast2)')
self.assertEqual(src1, src2)
self.assertSrcEqual('(%s)' % code.strip(), dsttxt.strip())
def test_unicode_literals(self):
source = """
......@@ -441,7 +429,7 @@ class CodegenTestCase(unittest.TestCase):
x = b'abc'
y = u'abc'
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
def test_slicing(self):
source = """
......@@ -470,27 +458,13 @@ class CodegenTestCase(unittest.TestCase):
x[1:2,3:4:5]
x[1:2,3:4:-5]
"""
self.assertAstEqual(source)
self.assertAstRoundtrips(source)
def test_non_string_leakage(self):
source = '''
tar_compression = {'gzip': 'gz', None: ''}
'''
self.assertAstEqual(source)
def test_fast_compare(self):
fast_compare = astor.node_util.fast_compare
def check(a, b):
ast_a = ast.parse(a)
ast_b = ast.parse(b)
dump_a = astor.dump_tree(ast_a)
dump_b = astor.dump_tree(ast_b)
self.assertEqual(dump_a == dump_b, fast_compare(ast_a, ast_b))
check('a = 3', 'a = 3')
check('a = 3', 'a = 5')
check('a = 3 - (3, 4, 5)', 'a = 3 - (3, 4, 5)')
check('a = 3 - (3, 4, 5)', 'a = 3 - (3, 4, 6)')
self.assertAstRoundtrips(source)
if __name__ == '__main__':
......
import ast
import sys
import warnings
try:
import unittest2 as unittest
except ImportError:
......@@ -17,5 +17,46 @@ class GetSymbolTestCase(unittest.TestCase):
self.assertEqual('@', astor.get_op_symbol(ast.MatMult()))
class DeprecationTestCase(unittest.TestCase):
def test_deprecation(self):
with warnings.catch_warnings(record=True) as w:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
ast1 = astor.code_to_ast.parse_file(__file__)
src1 = astor.to_source(ast1)
ast2 = astor.parsefile(__file__)
src2 = astor.codegen.to_source(ast2)
self.assertEqual(len(w), 2)
w = [warnings.formatwarning(x.message, x.category,
x.filename, x.lineno) for x in w]
w = [x.rsplit(':', 1)[-1].strip() for x in w]
self.assertEqual(w[0], 'astor.parsefile is deprecated. '
'Please use astor.code_to_ast.parse_file.\n'
' ast2 = astor.parsefile(__file__)')
self.assertEqual(w[1], 'astor.codegen is deprecated. '
'Please use astor.code_gen.\n'
' src2 = astor.codegen.to_source(ast2)')
self.assertEqual(src1, src2)
class FastCompareTestCase(unittest.TestCase):
def test_fast_compare(self):
fast_compare = astor.node_util.fast_compare
def check(a, b):
ast_a = ast.parse(a)
ast_b = ast.parse(b)
dump_a = astor.dump_tree(ast_a)
dump_b = astor.dump_tree(ast_b)
self.assertEqual(dump_a == dump_b, fast_compare(ast_a, ast_b))
check('a = 3', 'a = 3')
check('a = 3', 'a = 5')
check('a = 3 - (3, 4, 5)', 'a = 3 - (3, 4, 5)')
check('a = 3 - (3, 4, 5)', 'a = 3 - (3, 4, 6)')
if __name__ == '__main__':
unittest.main()
"""
Part of the astor library for Python AST manipulation
License: 3-clause BSD
Copyright (c) 2014 Berker Peksag
Copyright (c) 2015, 2017 Patrick Maupin
Use this by putting a link to astunparse's common.py test file.
"""
try:
import unittest2 as unittest
except ImportError:
import unittest
try:
from test_code_gen import Comparisons
except ImportError:
from .test_code_gen import Comparisons
try:
from astunparse_common import AstunparseCommonTestCase
except ImportError:
AstunparseCommonTestCase = None
if AstunparseCommonTestCase is not None:
class UnparseTestCase(AstunparseCommonTestCase, unittest.TestCase,
Comparisons):
def check_roundtrip(self, code1, mode=None):
self.assertAstRoundtrips(code1)
def test_files(self):
""" Don't bother -- we do this manually and more thoroughly """
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment