Kaydet (Commit) aebce9b0 authored tarafından Patrick Maupin's avatar Patrick Maupin Kaydeden (commit) GitHub

Enhance code generator performance (#70)

* Enhance code generator performance
- Algorithmic gains in pretty printer and string prettifier
- Diminishing gains from a few closures in code_gen
- One code-gen fix (and a test case) for None leaking into code-gen result list.
- rtrip gains via new fast comparison function
- PEP8 fixes
üst 7189366e
...@@ -12,11 +12,11 @@ Copyright 2013 (c) Berker Peksag ...@@ -12,11 +12,11 @@ Copyright 2013 (c) Berker Peksag
import warnings import warnings
from .code_gen import to_source # NOQA from .code_gen import to_source # NOQA
from .node_util import iter_node, strip_tree, dump_tree from .node_util import iter_node, strip_tree, dump_tree # NOQA
from .node_util import ExplicitNodeVisitor from .node_util import ExplicitNodeVisitor # NOQA
from .file_util import CodeToAst, code_to_ast # NOQA from .file_util import CodeToAst, code_to_ast # NOQA
from .op_util import get_op_symbol, get_op_precedence # NOQA from .op_util import get_op_symbol, get_op_precedence # NOQA
from .op_util import symbol_data from .op_util import symbol_data # NOQA
from .tree_walk import TreeWalk # NOQA from .tree_walk import TreeWalk # NOQA
__version__ = '0.6' __version__ = '0.6'
...@@ -44,6 +44,7 @@ codegen = code_gen ...@@ -44,6 +44,7 @@ codegen = code_gen
exec(deprecated) exec(deprecated)
def deprecate(): def deprecate():
def wrap(deprecated_name, target_name): def wrap(deprecated_name, target_name):
if '.' in target_name: if '.' in target_name:
...@@ -51,7 +52,8 @@ def deprecate(): ...@@ -51,7 +52,8 @@ def deprecate():
target_func = getattr(globals()[target_mod], target_fname) target_func = getattr(globals()[target_mod], target_fname)
else: else:
target_func = globals()[target_name] target_func = globals()[target_name]
msg = "astor.%s is deprecated. Please use astor.%s." % (deprecated_name, target_name) msg = "astor.%s is deprecated. Please use astor.%s." % (
deprecated_name, target_name)
if callable(target_func): if callable(target_func):
def newfunc(*args, **kwarg): def newfunc(*args, **kwarg):
warnings.warn(msg, DeprecationWarning, stacklevel=2) warnings.warn(msg, DeprecationWarning, stacklevel=2)
...@@ -65,13 +67,14 @@ def deprecate(): ...@@ -65,13 +67,14 @@ def deprecate():
globals()[deprecated_name] = newfunc globals()[deprecated_name] = newfunc
for line in deprecated.splitlines(): for line in deprecated.splitlines(): # NOQA
line = line.split('#')[0].replace('=', '').split() line = line.split('#')[0].replace('=', '').split()
if line: if line:
target_name = line.pop() target_name = line.pop()
for deprecated_name in line: for deprecated_name in line:
wrap(deprecated_name, target_name) wrap(deprecated_name, target_name)
deprecate() deprecate()
del deprecate, deprecated del deprecate, deprecated
This diff is collapsed.
...@@ -101,4 +101,5 @@ class CodeToAst(object): ...@@ -101,4 +101,5 @@ class CodeToAst(object):
cache[(fname, obj.lineno)] = obj cache[(fname, obj.lineno)] = obj
return cache[key] return cache[key]
code_to_ast = CodeToAst() code_to_ast = CodeToAst()
...@@ -13,6 +13,12 @@ For a whole-tree approach, see the treewalk submodule. ...@@ -13,6 +13,12 @@ For a whole-tree approach, see the treewalk submodule.
""" """
import ast import ast
import itertools
try:
zip_longest = itertools.zip_longest
except AttributeError:
zip_longest = itertools.izip_longest
class NonExistent(object): class NonExistent(object):
...@@ -163,3 +169,40 @@ def allow_ast_comparison(): ...@@ -163,3 +169,40 @@ def allow_ast_comparison():
item.__bases__ = tuple(list(item.__bases__) + [CompareHelper]) item.__bases__ = tuple(list(item.__bases__) + [CompareHelper])
except TypeError: except TypeError:
pass pass
def fast_compare(tree1, tree2):
    """ This is optimized to compare two AST trees for equality.
        It makes several assumptions that are currently true for
        AST trees used by rtrip, and it doesn't examine the _attributes.

        The comparison is iterative (explicit work stack, no recursion).
        Two AST nodes match when they are instances of the same class
        and all of their fields other than 'ctx' match.  A list only
        matches another list whose items match pairwise; a length
        mismatch pairs an item with None, which fails the comparison.
        Anything else (strings, numbers, None, ...) is compared with ==.

        Returns True if the trees are equal, False otherwise.
    """

    geta = ast.AST.__getattribute__

    work = [(tree1, tree2)]
    pop = work.pop
    extend = work.extend
    # TypeError in cPython, AttributeError in PyPy
    exception = TypeError, AttributeError
    zipl = zip_longest
    type_ = type
    list_ = list
    while work:
        n1, n2 = pop()
        try:
            f1 = geta(n1, '_fields')
            f2 = geta(n2, '_fields')
        except exception:
            # At least one of the two items is not an AST node.
            if type_(n1) is list_:
                # A list can only equal another list.  Checking here
                # also keeps zip_longest from raising TypeError when
                # n2 is not iterable (e.g. None).
                if type_(n2) is not list_:
                    return False
                extend(zipl(n1, n2))
                continue
            if n1 == n2:
                continue
            return False
        else:
            # Many node classes (Add, Sub, And, Or, Lt, Gt, ...) have
            # the same (empty) _fields, so comparing the field names
            # alone would treat e.g. "a + b" and "a - b" as equal.
            if type_(n1) is not type_(n2):
                return False
            f1 = [x for x in f1 if x != 'ctx']
            if f1 != [x for x in f2 if x != 'ctx']:
                return False
            extend((geta(n1, fname), geta(n2, fname)) for fname in f1)

    return True
...@@ -78,19 +78,22 @@ import logging ...@@ -78,19 +78,22 @@ import logging
from astor.code_gen import to_source from astor.code_gen import to_source
from astor.file_util import code_to_ast from astor.file_util import code_to_ast
from astor.node_util import allow_ast_comparison, dump_tree, strip_tree from astor.node_util import (allow_ast_comparison, dump_tree,
strip_tree, fast_compare)
dsttree = 'tmp_rtrip' dsttree = 'tmp_rtrip'
def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False): def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False,
ignore_exceptions=False, fullcomp=False):
"""Walk the srctree, and convert/copy all python files """Walk the srctree, and convert/copy all python files
into the dsttree into the dsttree
""" """
allow_ast_comparison() if fullcomp:
allow_ast_comparison()
parse_file = code_to_ast.parse_file parse_file = code_to_ast.parse_file
find_py_files = code_to_ast.find_py_files find_py_files = code_to_ast.find_py_files
...@@ -134,7 +137,12 @@ def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False): ...@@ -134,7 +137,12 @@ def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False):
badfiles.add(srcfname) badfiles.add(srcfname)
continue continue
dsttxt = to_source(srcast) try:
dsttxt = to_source(srcast)
except:
if not ignore_exceptions:
raise
dsttxt = ''
if not readonly: if not readonly:
dstfname = os.path.join(dstpath, fname) dstfname = os.path.join(dstpath, fname)
...@@ -150,12 +158,15 @@ def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False): ...@@ -150,12 +158,15 @@ def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False):
dstast = ast.parse(dsttxt) if readonly else parse_file(dstfname) dstast = ast.parse(dsttxt) if readonly else parse_file(dstfname)
except SyntaxError: except SyntaxError:
dstast = [] dstast = []
unknown_src_nodes.update(strip_tree(srcast)) if fullcomp:
unknown_dst_nodes.update(strip_tree(dstast)) unknown_src_nodes.update(strip_tree(srcast))
if dumpall or srcast != dstast: unknown_dst_nodes.update(strip_tree(dstast))
bad = srcast != dstast
else:
bad = not fast_compare(srcast, dstast)
if dumpall or bad:
srcdump = dump_tree(srcast) srcdump = dump_tree(srcast)
dstdump = dump_tree(dstast) dstdump = dump_tree(dstast)
bad = srcdump != dstdump
logging.warning(' calculating dump -- %s' % logging.warning(' calculating dump -- %s' %
('bad' if bad else 'OK')) ('bad' if bad else 'OK'))
if bad: if bad:
...@@ -230,6 +241,7 @@ def usage(msg): ...@@ -230,6 +241,7 @@ def usage(msg):
""") % msg) """) % msg)
if __name__ == '__main__': if __name__ == '__main__':
import textwrap import textwrap
......
...@@ -21,21 +21,7 @@ def pretty_source(source): ...@@ -21,21 +21,7 @@ def pretty_source(source):
""" Prettify the source. """ Prettify the source.
""" """
return ''.join(flatten(split_lines(source))) return ''.join(split_lines(source))
def flatten(source, list=list, isinstance=isinstance):
    """ Lazily yield the leaf items of arbitrarily nested lists,
        in depth-first, left-to-right order.  Only ``list``
        instances are descended into; every other item is
        yielded unchanged.
    """
    def expand(items):
        for entry in items:
            if not isinstance(entry, list):
                yield entry
                continue
            # Nested list: hand back its leaves one at a time.
            for leaf in expand(entry):
                yield leaf

    return expand(source)
def split_lines(source, maxline=79): def split_lines(source, maxline=79):
...@@ -43,38 +29,46 @@ def split_lines(source, maxline=79): ...@@ -43,38 +29,46 @@ def split_lines(source, maxline=79):
If a line is short enough, just yield it. If a line is short enough, just yield it.
Otherwise, fix it. Otherwise, fix it.
""" """
result = []
extend = result.extend
append = result.append
line = [] line = []
multiline = False multiline = False
count = 0 count = 0
find = str.find
for item in source: for item in source:
if item.startswith('\n'): index = find(item, '\n')
if index:
line.append(item)
multiline = index > 0
count += len(item)
else:
if line: if line:
if count <= maxline or multiline: if count <= maxline or multiline:
yield line extend(line)
else: else:
for item2 in wrap_line(line, maxline): wrap_line(line, maxline, result)
yield item2
count = 0 count = 0
multiline = False multiline = False
line = [] line = []
yield item append(item)
else: return result
line.append(item)
multiline = '\n' in item
count += len(item)
def count(group): def count(group, slen=str.__len__):
return sum(len(x) for x in group) return sum([slen(x) for x in group])
def wrap_line(line, maxline=79, count=count): def wrap_line(line, maxline=79, result=[], count=count):
""" We have a line that is too long, """ We have a line that is too long,
so we're going to try to wrap it. so we're going to try to wrap it.
""" """
# Extract the indentation # Extract the indentation
append = result.append
extend = result.extend
indentation = line[0] indentation = line[0]
lenfirst = len(indentation) lenfirst = len(indentation)
indent = lenfirst - len(indentation.strip()) indent = lenfirst - len(indentation.strip())
...@@ -100,10 +94,10 @@ def wrap_line(line, maxline=79, count=count): ...@@ -100,10 +94,10 @@ def wrap_line(line, maxline=79, count=count):
# then set up to deal with the remainder in pairs. # then set up to deal with the remainder in pairs.
first = unsplittable[0] first = unsplittable[0]
yield indentation append(indentation)
yield first extend(first)
if not splittable: if not splittable:
return return result
pos = indent + count(first) pos = indent + count(first)
indentation += ' ' indentation += ' '
indent += 4 indent += 4
...@@ -116,8 +110,8 @@ def wrap_line(line, maxline=79, count=count): ...@@ -116,8 +110,8 @@ def wrap_line(line, maxline=79, count=count):
# If we already have stuff on the line and even # If we already have stuff on the line and even
# the very first item won't fit, start a new line # the very first item won't fit, start a new line
if pos > indent and pos + len(sg[0]) > maxline: if pos > indent and pos + len(sg[0]) > maxline:
yield '\n' append('\n')
yield indentation append(indentation)
pos = indent pos = indent
# Dump lines out of the splittable group # Dump lines out of the splittable group
...@@ -127,25 +121,25 @@ def wrap_line(line, maxline=79, count=count): ...@@ -127,25 +121,25 @@ def wrap_line(line, maxline=79, count=count):
ready, sg = split_group(sg, pos, maxline) ready, sg = split_group(sg, pos, maxline)
if ready[-1].endswith(' '): if ready[-1].endswith(' '):
ready[-1] = ready[-1][:-1] ready[-1] = ready[-1][:-1]
yield ready extend(ready)
yield '\n' append('\n')
yield indentation append(indentation)
pos = indent pos = indent
csg = count(sg) csg = count(sg)
# Dump the remainder of the splittable group # Dump the remainder of the splittable group
if sg: if sg:
yield sg extend(sg)
pos += csg pos += csg
# Dump the unsplittable group, optionally # Dump the unsplittable group, optionally
# preceded by a linefeed. # preceded by a linefeed.
cnsg = count(nsg) cnsg = count(nsg)
if pos > indent and pos + cnsg > maxline: if pos > indent and pos + cnsg > maxline:
yield '\n' append('\n')
yield indentation append(indentation)
pos = indent pos = indent
yield nsg extend(nsg)
pos += cnsg pos += cnsg
...@@ -215,6 +209,7 @@ def delimiter_groups(line, begin_delim=begin_delim, ...@@ -215,6 +209,7 @@ def delimiter_groups(line, begin_delim=begin_delim,
assert not text, text assert not text, text
break break
statements = set(['del ', 'return', 'yield ', 'if ', 'while ']) statements = set(['del ', 'return', 'yield ', 'if ', 'while '])
...@@ -259,6 +254,7 @@ def add_parens(line, maxline, indent, statements=statements, count=count): ...@@ -259,6 +254,7 @@ def add_parens(line, maxline, indent, statements=statements, count=count):
return [item for group in groups for item in group] return [item for group in groups for item in group]
# Assignment operators # Assignment operators
ops = list('|^&+-*/%@~') + '<< >> // **'.split() + [''] ops = list('|^&+-*/%@~') + '<< >> // **'.split() + ['']
ops = set(' %s= ' % x for x in ops) ops = set(' %s= ' % x for x in ops)
......
...@@ -18,7 +18,6 @@ This has lots of Python 2 / Python 3 ugliness. ...@@ -18,7 +18,6 @@ This has lots of Python 2 / Python 3 ugliness.
""" """
import re import re
import logging
try: try:
special_unicode = unicode special_unicode = unicode
...@@ -32,28 +31,7 @@ except NameError: ...@@ -32,28 +31,7 @@ except NameError:
basestring = str basestring = str
def _get_line(current_output): def _properly_indented(s, line_indent):
""" Back up in the output buffer to
find the start of the current line,
and return the entire line.
"""
myline = []
index = len(current_output)
while index:
index -= 1
try:
s = str(current_output[index])
except:
raise
myline.append(s)
if '\n' in s:
break
myline = ''.join(reversed(myline))
return myline.rsplit('\n', 1)[-1]
def _properly_indented(s, current_line):
line_indent = len(current_line) - len(current_line.lstrip())
mylist = s.split('\n')[1:] mylist = s.split('\n')[1:]
mylist = [x.rstrip() for x in mylist] mylist = [x.rstrip() for x in mylist]
mylist = [x for x in mylist if x] mylist = [x for x in mylist if x]
...@@ -62,6 +40,7 @@ def _properly_indented(s, current_line): ...@@ -62,6 +40,7 @@ def _properly_indented(s, current_line):
counts = [(len(x) - len(x.lstrip())) for x in mylist] counts = [(len(x) - len(x.lstrip())) for x in mylist]
return counts and min(counts) >= line_indent return counts and min(counts) >= line_indent
mysplit = re.compile(r'(\\|\"\"\"|\"$)').split mysplit = re.compile(r'(\\|\"\"\"|\"$)').split
replacements = {'\\': '\\\\', '"""': '""\\"', '"': '\\"'} replacements = {'\\': '\\\\', '"""': '""\\"', '"': '\\"'}
...@@ -76,8 +55,8 @@ def _prep_triple_quotes(s, mysplit=mysplit, replacements=replacements): ...@@ -76,8 +55,8 @@ def _prep_triple_quotes(s, mysplit=mysplit, replacements=replacements):
return ''.join(s) return ''.join(s)
def pretty_string(s, embedded, current_output, min_trip_str=20, def pretty_string(s, embedded, current_line, uni_lit=False,
max_line=100, uni_lit=False): min_trip_str=20, max_line=100):
"""There are a lot of reasons why we might not want to or """There are a lot of reasons why we might not want to or
be able to return a triple-quoted string. We can always be able to return a triple-quoted string. We can always
punt back to the default normal string. punt back to the default normal string.
...@@ -92,16 +71,24 @@ def pretty_string(s, embedded, current_output, min_trip_str=20, ...@@ -92,16 +71,24 @@ def pretty_string(s, embedded, current_output, min_trip_str=20,
return 'b' + default return 'b' + default
len_s = len(default) len_s = len(default)
current_line = _get_line(current_output)
if current_line.strip(): if current_line.strip():
if embedded and '\n' not in s: len_current = len(current_line)
second_line_start = s.find('\n') + 1
if embedded > 1 and not second_line_start:
return default return default
if len_s < min_trip_str: if len_s < min_trip_str:
return default return default
total_len = len(current_line) + len_s line_indent = len_current - len(current_line.lstrip())
if total_len < max_line and not _properly_indented(s, current_line):
# Could be on a line by itself...
if embedded and not second_line_start:
return default
total_len = len_current + len_s
if total_len < max_line and not _properly_indented(s, line_indent):
return default return default
fancy = '"""%s"""' % _prep_triple_quotes(s) fancy = '"""%s"""' % _prep_triple_quotes(s)
...@@ -116,11 +103,4 @@ def pretty_string(s, embedded, current_output, min_trip_str=20, ...@@ -116,11 +103,4 @@ def pretty_string(s, embedded, current_output, min_trip_str=20,
return fancy return fancy
except: except:
pass pass
"""
logging.warning("***String conversion did not work\n")
#print (eval(fancy), s)
print
print (fancy, repr(s))
print
"""
return default return default
...@@ -33,6 +33,7 @@ class MetaFlatten(type): ...@@ -33,6 +33,7 @@ class MetaFlatten(type):
# Delegate the real work to type # Delegate the real work to type
return type.__new__(clstype, name, newbases, newdict) return type.__new__(clstype, name, newbases, newdict)
MetaFlatten = MetaFlatten('MetaFlatten', (object, ), {}) MetaFlatten = MetaFlatten('MetaFlatten', (object, ), {})
......
...@@ -236,5 +236,6 @@ def makelib(): ...@@ -236,5 +236,6 @@ def makelib():
f.write(lineend) f.write(lineend)
f.write('"""\n'.encode('utf-8')) f.write('"""\n'.encode('utf-8'))
if __name__ == '__main__': if __name__ == '__main__':
makelib() makelib()
...@@ -23,10 +23,6 @@ should not be part of the automated regressions. ...@@ -23,10 +23,6 @@ should not be part of the automated regressions.
""" """
import sys import sys
import collections
import itertools
import textwrap
import hashlib
import ast import ast
import astor import astor
...@@ -86,5 +82,6 @@ def checklib(): ...@@ -86,5 +82,6 @@ def checklib():
f.write(('%s %s\n' % (repr(srctxt), f.write(('%s %s\n' % (repr(srctxt),
repr(dsttxt))).encode('utf-8')) repr(dsttxt))).encode('utf-8'))
if __name__ == '__main__': if __name__ == '__main__':
checklib() checklib()
...@@ -347,8 +347,8 @@ class CodegenTestCase(unittest.TestCase): ...@@ -347,8 +347,8 @@ class CodegenTestCase(unittest.TestCase):
'RawDescriptionHelpFormatter', 'RawTextHelpFormatter', 'Namespace', 'RawDescriptionHelpFormatter', 'RawTextHelpFormatter', 'Namespace',
'Action', 'ONE_OR_MORE', 'OPTIONAL', 'PARSER', 'REMAINDER', 'SUPPRESS', 'Action', 'ONE_OR_MORE', 'OPTIONAL', 'PARSER', 'REMAINDER', 'SUPPRESS',
'ZERO_OR_MORE'] 'ZERO_OR_MORE']
""" """ # NOQA
self.maxDiff=2000 self.maxDiff = 2000
self.assertAstSourceEqual(source) self.assertAstSourceEqual(source)
def test_elif(self): def test_elif(self):
...@@ -366,13 +366,13 @@ class CodegenTestCase(unittest.TestCase): ...@@ -366,13 +366,13 @@ class CodegenTestCase(unittest.TestCase):
def test_fstrings(self): def test_fstrings(self):
source = """ source = """
f'{x}' x = f'{x}'
f'{x.y}' x = f'{x.y}'
f'{int(x)}' x = f'{int(x)}'
f'a{b:c}d' x = f'a{b:c}d'
f'a{b!s:c{d}e}f' x = f'a{b!s:c{d}e}f'
f'""' x = f'""'
f'"\\'' x = f'"\\''
""" """
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6)) self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6))
source = """ source = """
...@@ -385,7 +385,6 @@ class CodegenTestCase(unittest.TestCase): ...@@ -385,7 +385,6 @@ class CodegenTestCase(unittest.TestCase):
""" """
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6)) self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6))
def test_annassign(self): def test_annassign(self):
source = """ source = """
a: int a: int
...@@ -404,7 +403,6 @@ class CodegenTestCase(unittest.TestCase): ...@@ -404,7 +403,6 @@ class CodegenTestCase(unittest.TestCase):
""" """
self.assertAstEqualIfAtLeastVersion(source, (3, 6)) self.assertAstEqualIfAtLeastVersion(source, (3, 6))
def test_compile_types(self): def test_compile_types(self):
code = '(a + b + c) * (d + e + f)\n' code = '(a + b + c) * (d + e + f)\n'
for mode in 'exec eval single'.split(): for mode in 'exec eval single'.split():
...@@ -472,6 +470,26 @@ class CodegenTestCase(unittest.TestCase): ...@@ -472,6 +470,26 @@ class CodegenTestCase(unittest.TestCase):
""" """
self.assertAstEqual(source) self.assertAstEqual(source)
def test_non_string_leakage(self):
    """Round-trip a dict literal that uses ``None`` as a key.

    Regression test for None leaking into the code generator's
    result list.
    """
    snippet = '''
tar_compression = {'gzip': 'gz', None: ''}
'''
    self.assertAstEqual(snippet)
def test_fast_compare(self):
    """Check that fast_compare agrees with comparing dump_tree output.

    For each pair of snippets, both are parsed and dumped; the
    equality of the dumps must match the fast_compare result.
    """
    compare = astor.node_util.fast_compare
    snippet_pairs = (
        ('a = 3', 'a = 3'),
        ('a = 3', 'a = 5'),
        ('a = 3 - (3, 4, 5)', 'a = 3 - (3, 4, 5)'),
        ('a = 3 - (3, 4, 5)', 'a = 3 - (3, 4, 6)'),
    )
    for left_src, right_src in snippet_pairs:
        left_tree = ast.parse(left_src)
        right_tree = ast.parse(right_src)
        left_dump = astor.dump_tree(left_tree)
        right_dump = astor.dump_tree(right_tree)
        dumps_match = left_dump == right_dump
        self.assertEqual(dumps_match, compare(left_tree, right_tree))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -16,5 +16,6 @@ class GetSymbolTestCase(unittest.TestCase): ...@@ -16,5 +16,6 @@ class GetSymbolTestCase(unittest.TestCase):
def test_get_mat_mult(self): def test_get_mat_mult(self):
self.assertEqual('@', astor.get_op_symbol(ast.MatMult())) self.assertEqual('@', astor.get_op_symbol(ast.MatMult()))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment