Commit aebce9b0 authored by Patrick Maupin, committed by GitHub

Enhance code generator performance (#70)

* Enhance code generator performance
- Algorithmic gains in pretty printer and string prettifier
- Diminishing gains from a few closures in code_gen
- One code-gen fix (and a test case) for None leaking into the code-gen result list (see the round-trip sketch below).
- rtrip gains via new fast comparison function
- PEP8 fixes
parent 7189366e
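
A minimal round-trip sketch of the None-leak fix, using the same input as the new test_non_string_leakage test further down this diff:

    import ast
    import astor

    # A dict literal with a None key previously let a non-string slip into
    # the generated source fragments; with this fix it round-trips cleanly.
    source = "tar_compression = {'gzip': 'gz', None: ''}\n"
    print(astor.to_source(ast.parse(source)))
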
......@@ -12,11 +12,11 @@ Copyright 2013 (c) Berker Peksag
import warnings
from .code_gen import to_source # NOQA
from .node_util import iter_node, strip_tree, dump_tree
from .node_util import ExplicitNodeVisitor
from .node_util import iter_node, strip_tree, dump_tree # NOQA
from .node_util import ExplicitNodeVisitor # NOQA
from .file_util import CodeToAst, code_to_ast # NOQA
from .op_util import get_op_symbol, get_op_precedence # NOQA
from .op_util import symbol_data
from .op_util import symbol_data # NOQA
from .tree_walk import TreeWalk # NOQA
__version__ = '0.6'
......@@ -44,6 +44,7 @@ codegen = code_gen
exec(deprecated)
def deprecate():
def wrap(deprecated_name, target_name):
if '.' in target_name:
......@@ -51,7 +52,8 @@ def deprecate():
target_func = getattr(globals()[target_mod], target_fname)
else:
target_func = globals()[target_name]
msg = "astor.%s is deprecated. Please use astor.%s." % (deprecated_name, target_name)
msg = "astor.%s is deprecated. Please use astor.%s." % (
deprecated_name, target_name)
if callable(target_func):
def newfunc(*args, **kwarg):
warnings.warn(msg, DeprecationWarning, stacklevel=2)
......@@ -65,13 +67,14 @@ def deprecate():
globals()[deprecated_name] = newfunc
for line in deprecated.splitlines():
for line in deprecated.splitlines(): # NOQA
line = line.split('#')[0].replace('=', '').split()
if line:
target_name = line.pop()
for deprecated_name in line:
wrap(deprecated_name, target_name)
deprecate()
del deprecate, deprecated
This diff is collapsed.
......@@ -101,4 +101,5 @@ class CodeToAst(object):
cache[(fname, obj.lineno)] = obj
return cache[key]
code_to_ast = CodeToAst()
......@@ -13,6 +13,12 @@ For a whole-tree approach, see the treewalk submodule.
"""
import ast
import itertools
try:
zip_longest = itertools.zip_longest
except AttributeError:
zip_longest = itertools.izip_longest
class NonExistent(object):
......@@ -163,3 +169,40 @@ def allow_ast_comparison():
item.__bases__ = tuple(list(item.__bases__) + [CompareHelper])
except TypeError:
pass
def fast_compare(tree1, tree2):
""" This is optimized to compare two AST trees for equality.
It makes several assumptions that are currently true for
AST trees used by rtrip, and it doesn't examine the _attributes.
"""
geta = ast.AST.__getattribute__
work = [(tree1, tree2)]
pop = work.pop
extend = work.extend
# TypeError in cPython, AttributeError in PyPy
exception = TypeError, AttributeError
zipl = zip_longest
type_ = type
list_ = list
while work:
n1, n2 = pop()
try:
f1 = geta(n1, '_fields')
f2 = geta(n2, '_fields')
except exception:
if type_(n1) is list_:
extend(zipl(n1, n2))
continue
if n1 == n2:
continue
return False
else:
f1 = [x for x in f1 if x != 'ctx']
if f1 != [x for x in f2 if x != 'ctx']:
return False
extend((geta(n1, fname), geta(n2, fname)) for fname in f1)
return True
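
As a usage sketch, fast_compare is driven directly on parsed trees (mirroring the new test_fast_compare test near the end of this diff; the helper lives in astor.node_util):

    import ast
    from astor.node_util import fast_compare

    # Structurally identical trees compare equal; a differing constant does not.
    assert fast_compare(ast.parse('a = 3'), ast.parse('a = 3'))
    assert not fast_compare(ast.parse('a = 3'), ast.parse('a = 5'))
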
......@@ -78,19 +78,22 @@ import logging
from astor.code_gen import to_source
from astor.file_util import code_to_ast
from astor.node_util import allow_ast_comparison, dump_tree, strip_tree
from astor.node_util import (allow_ast_comparison, dump_tree,
strip_tree, fast_compare)
dsttree = 'tmp_rtrip'
def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False):
def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False,
ignore_exceptions=False, fullcomp=False):
"""Walk the srctree, and convert/copy all python files
into the dsttree
"""
allow_ast_comparison()
if fullcomp:
allow_ast_comparison()
parse_file = code_to_ast.parse_file
find_py_files = code_to_ast.find_py_files
......@@ -134,7 +137,12 @@ def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False):
badfiles.add(srcfname)
continue
dsttxt = to_source(srcast)
try:
dsttxt = to_source(srcast)
except:
if not ignore_exceptions:
raise
dsttxt = ''
if not readonly:
dstfname = os.path.join(dstpath, fname)
......@@ -150,12 +158,15 @@ def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False):
dstast = ast.parse(dsttxt) if readonly else parse_file(dstfname)
except SyntaxError:
dstast = []
unknown_src_nodes.update(strip_tree(srcast))
unknown_dst_nodes.update(strip_tree(dstast))
if dumpall or srcast != dstast:
if fullcomp:
unknown_src_nodes.update(strip_tree(srcast))
unknown_dst_nodes.update(strip_tree(dstast))
bad = srcast != dstast
else:
bad = not fast_compare(srcast, dstast)
if dumpall or bad:
srcdump = dump_tree(srcast)
dstdump = dump_tree(dstast)
bad = srcdump != dstdump
logging.warning(' calculating dump -- %s' %
('bad' if bad else 'OK'))
if bad:
......@@ -230,6 +241,7 @@ def usage(msg):
""") % msg)
if __name__ == '__main__':
import textwrap
......
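
Taken together, the rtrip changes above mean convert() now compares trees with fast_compare by default and only installs the heavier allow_ast_comparison machinery when fullcomp is passed. A hedged usage sketch (the source directory name is hypothetical; the module is assumed importable as astor.rtrip):

    from astor.rtrip import convert

    # Round-trip every .py file under src_dir into tmp_rtrip, comparing the
    # original and regenerated ASTs with fast_compare (the default path).
    # Pass fullcomp=True to switch back to the full AST-comparison machinery.
    convert('src_dir', dsttree='tmp_rtrip', fullcomp=False)
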
......@@ -21,21 +21,7 @@ def pretty_source(source):
""" Prettify the source.
"""
return ''.join(flatten(split_lines(source)))
def flatten(source, list=list, isinstance=isinstance):
""" Deal with nested lists
"""
def flatten_iter(source):
for item in source:
if isinstance(item, list):
for item in flatten_iter(item):
yield item
else:
yield item
return flatten_iter(source)
return ''.join(split_lines(source))
def split_lines(source, maxline=79):
......@@ -43,38 +29,46 @@ def split_lines(source, maxline=79):
If a line is short enough, just yield it.
Otherwise, fix it.
"""
result = []
extend = result.extend
append = result.append
line = []
multiline = False
count = 0
find = str.find
for item in source:
if item.startswith('\n'):
index = find(item, '\n')
if index:
line.append(item)
multiline = index > 0
count += len(item)
else:
if line:
if count <= maxline or multiline:
yield line
extend(line)
else:
for item2 in wrap_line(line, maxline):
yield item2
wrap_line(line, maxline, result)
count = 0
multiline = False
line = []
yield item
else:
line.append(item)
multiline = '\n' in item
count += len(item)
append(item)
return result
def count(group):
return sum(len(x) for x in group)
def count(group, slen=str.__len__):
return sum([slen(x) for x in group])
def wrap_line(line, maxline=79, count=count):
def wrap_line(line, maxline=79, result=[], count=count):
""" We have a line that is too long,
so we're going to try to wrap it.
"""
# Extract the indentation
append = result.append
extend = result.extend
indentation = line[0]
lenfirst = len(indentation)
indent = lenfirst - len(indentation.strip())
......@@ -100,10 +94,10 @@ def wrap_line(line, maxline=79, count=count):
# then set up to deal with the remainder in pairs.
first = unsplittable[0]
yield indentation
yield first
append(indentation)
extend(first)
if not splittable:
return
return result
pos = indent + count(first)
indentation += ' '
indent += 4
......@@ -116,8 +110,8 @@ def wrap_line(line, maxline=79, count=count):
# If we already have stuff on the line and even
# the very first item won't fit, start a new line
if pos > indent and pos + len(sg[0]) > maxline:
yield '\n'
yield indentation
append('\n')
append(indentation)
pos = indent
# Dump lines out of the splittable group
......@@ -127,25 +121,25 @@ def wrap_line(line, maxline=79, count=count):
ready, sg = split_group(sg, pos, maxline)
if ready[-1].endswith(' '):
ready[-1] = ready[-1][:-1]
yield ready
yield '\n'
yield indentation
extend(ready)
append('\n')
append(indentation)
pos = indent
csg = count(sg)
# Dump the remainder of the splittable group
if sg:
yield sg
extend(sg)
pos += csg
# Dump the unsplittable group, optionally
# preceded by a linefeed.
cnsg = count(nsg)
if pos > indent and pos + cnsg > maxline:
yield '\n'
yield indentation
append('\n')
append(indentation)
pos = indent
yield nsg
extend(nsg)
pos += cnsg
......@@ -215,6 +209,7 @@ def delimiter_groups(line, begin_delim=begin_delim,
assert not text, text
break
statements = set(['del ', 'return', 'yield ', 'if ', 'while '])
......@@ -259,6 +254,7 @@ def add_parens(line, maxline, indent, statements=statements, count=count):
return [item for group in groups for item in group]
# Assignment operators
ops = list('|^&+-*/%@~') + '<< >> // **'.split() + ['']
ops = set(' %s= ' % x for x in ops)
......
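
The source_repr changes above replace the old yield-and-flatten pipeline (nested lists flattened by flatten()) with functions that append into one shared result list. A minimal sketch of the pattern, with hypothetical names rather than astor's actual API:

    def emit_nested(items):
        # Old style: yield nested groups, flatten them in a second pass.
        for item in items:
            yield [item, '\n']

    def emit_flat(items, result=None):
        # New style: extend one shared list directly; no flatten pass needed.
        result = [] if result is None else result
        append = result.append
        for item in items:
            append(item)
            append('\n')
        return result

    # Both produce 'a\nb\n'; the second avoids building and walking nested lists.
    print(''.join(x for group in emit_nested(['a', 'b']) for x in group))
    print(''.join(emit_flat(['a', 'b'])))
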
......@@ -18,7 +18,6 @@ This has lots of Python 2 / Python 3 ugliness.
"""
import re
import logging
try:
special_unicode = unicode
......@@ -32,28 +31,7 @@ except NameError:
basestring = str
def _get_line(current_output):
""" Back up in the output buffer to
find the start of the current line,
and return the entire line.
"""
myline = []
index = len(current_output)
while index:
index -= 1
try:
s = str(current_output[index])
except:
raise
myline.append(s)
if '\n' in s:
break
myline = ''.join(reversed(myline))
return myline.rsplit('\n', 1)[-1]
def _properly_indented(s, current_line):
line_indent = len(current_line) - len(current_line.lstrip())
def _properly_indented(s, line_indent):
mylist = s.split('\n')[1:]
mylist = [x.rstrip() for x in mylist]
mylist = [x for x in mylist if x]
......@@ -62,6 +40,7 @@ def _properly_indented(s, current_line):
counts = [(len(x) - len(x.lstrip())) for x in mylist]
return counts and min(counts) >= line_indent
mysplit = re.compile(r'(\\|\"\"\"|\"$)').split
replacements = {'\\': '\\\\', '"""': '""\\"', '"': '\\"'}
......@@ -76,8 +55,8 @@ def _prep_triple_quotes(s, mysplit=mysplit, replacements=replacements):
return ''.join(s)
def pretty_string(s, embedded, current_output, min_trip_str=20,
max_line=100, uni_lit=False):
def pretty_string(s, embedded, current_line, uni_lit=False,
min_trip_str=20, max_line=100):
"""There are a lot of reasons why we might not want to or
be able to return a triple-quoted string. We can always
punt back to the default normal string.
......@@ -92,16 +71,24 @@ def pretty_string(s, embedded, current_output, min_trip_str=20,
return 'b' + default
len_s = len(default)
current_line = _get_line(current_output)
if current_line.strip():
if embedded and '\n' not in s:
len_current = len(current_line)
second_line_start = s.find('\n') + 1
if embedded > 1 and not second_line_start:
return default
if len_s < min_trip_str:
return default
total_len = len(current_line) + len_s
if total_len < max_line and not _properly_indented(s, current_line):
line_indent = len_current - len(current_line.lstrip())
# Could be on a line by itself...
if embedded and not second_line_start:
return default
total_len = len_current + len_s
if total_len < max_line and not _properly_indented(s, line_indent):
return default
fancy = '"""%s"""' % _prep_triple_quotes(s)
......@@ -116,11 +103,4 @@ def pretty_string(s, embedded, current_output, min_trip_str=20,
return fancy
except:
pass
"""
logging.warning("***String conversion did not work\n")
#print (eval(fancy), s)
print
print (fancy, repr(s))
print
"""
return default
......@@ -33,6 +33,7 @@ class MetaFlatten(type):
# Delegate the real work to type
return type.__new__(clstype, name, newbases, newdict)
MetaFlatten = MetaFlatten('MetaFlatten', (object, ), {})
......
......@@ -236,5 +236,6 @@ def makelib():
f.write(lineend)
f.write('"""\n'.encode('utf-8'))
if __name__ == '__main__':
makelib()
......@@ -23,10 +23,6 @@ should not be part of the automated regressions.
"""
import sys
import collections
import itertools
import textwrap
import hashlib
import ast
import astor
......@@ -86,5 +82,6 @@ def checklib():
f.write(('%s %s\n' % (repr(srctxt),
repr(dsttxt))).encode('utf-8'))
if __name__ == '__main__':
checklib()
......@@ -347,8 +347,8 @@ class CodegenTestCase(unittest.TestCase):
'RawDescriptionHelpFormatter', 'RawTextHelpFormatter', 'Namespace',
'Action', 'ONE_OR_MORE', 'OPTIONAL', 'PARSER', 'REMAINDER', 'SUPPRESS',
'ZERO_OR_MORE']
"""
self.maxDiff=2000
""" # NOQA
self.maxDiff = 2000
self.assertAstSourceEqual(source)
def test_elif(self):
......@@ -366,13 +366,13 @@ class CodegenTestCase(unittest.TestCase):
def test_fstrings(self):
source = """
f'{x}'
f'{x.y}'
f'{int(x)}'
f'a{b:c}d'
f'a{b!s:c{d}e}f'
f'""'
f'"\\''
x = f'{x}'
x = f'{x.y}'
x = f'{int(x)}'
x = f'a{b:c}d'
x = f'a{b!s:c{d}e}f'
x = f'""'
x = f'"\\''
"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6))
source = """
......@@ -385,7 +385,6 @@ class CodegenTestCase(unittest.TestCase):
"""
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6))
def test_annassign(self):
source = """
a: int
......@@ -404,7 +403,6 @@ class CodegenTestCase(unittest.TestCase):
"""
self.assertAstEqualIfAtLeastVersion(source, (3, 6))
def test_compile_types(self):
code = '(a + b + c) * (d + e + f)\n'
for mode in 'exec eval single'.split():
......@@ -472,6 +470,26 @@ class CodegenTestCase(unittest.TestCase):
"""
self.assertAstEqual(source)
def test_non_string_leakage(self):
source = '''
tar_compression = {'gzip': 'gz', None: ''}
'''
self.assertAstEqual(source)
def test_fast_compare(self):
fast_compare = astor.node_util.fast_compare
def check(a, b):
ast_a = ast.parse(a)
ast_b = ast.parse(b)
dump_a = astor.dump_tree(ast_a)
dump_b = astor.dump_tree(ast_b)
self.assertEqual(dump_a == dump_b, fast_compare(ast_a, ast_b))
check('a = 3', 'a = 3')
check('a = 3', 'a = 5')
check('a = 3 - (3, 4, 5)', 'a = 3 - (3, 4, 5)')
check('a = 3 - (3, 4, 5)', 'a = 3 - (3, 4, 6)')
if __name__ == '__main__':
unittest.main()
......@@ -16,5 +16,6 @@ class GetSymbolTestCase(unittest.TestCase):
def test_get_mat_mult(self):
self.assertEqual('@', astor.get_op_symbol(ast.MatMult()))
if __name__ == '__main__':
unittest.main()