Kaydet (Commit) aebce9b0 authored tarafından Patrick Maupin's avatar Patrick Maupin Kaydeden (commit) GitHub

Enhance code generator performance (#70)

* Enhance code generator performance
- Algorithmic gains in pretty printer and string prettifier
- Diminishing gains from a few closures in code_gen
- One code-gen fix (and a test case) for None leaking into code-gen result list.
- rtrip gains via new fast comparison function
- PEP8 fixes
üst 7189366e
...@@ -12,11 +12,11 @@ Copyright 2013 (c) Berker Peksag ...@@ -12,11 +12,11 @@ Copyright 2013 (c) Berker Peksag
import warnings import warnings
from .code_gen import to_source # NOQA from .code_gen import to_source # NOQA
from .node_util import iter_node, strip_tree, dump_tree from .node_util import iter_node, strip_tree, dump_tree # NOQA
from .node_util import ExplicitNodeVisitor from .node_util import ExplicitNodeVisitor # NOQA
from .file_util import CodeToAst, code_to_ast # NOQA from .file_util import CodeToAst, code_to_ast # NOQA
from .op_util import get_op_symbol, get_op_precedence # NOQA from .op_util import get_op_symbol, get_op_precedence # NOQA
from .op_util import symbol_data from .op_util import symbol_data # NOQA
from .tree_walk import TreeWalk # NOQA from .tree_walk import TreeWalk # NOQA
__version__ = '0.6' __version__ = '0.6'
...@@ -44,6 +44,7 @@ codegen = code_gen ...@@ -44,6 +44,7 @@ codegen = code_gen
exec(deprecated) exec(deprecated)
def deprecate(): def deprecate():
def wrap(deprecated_name, target_name): def wrap(deprecated_name, target_name):
if '.' in target_name: if '.' in target_name:
...@@ -51,7 +52,8 @@ def deprecate(): ...@@ -51,7 +52,8 @@ def deprecate():
target_func = getattr(globals()[target_mod], target_fname) target_func = getattr(globals()[target_mod], target_fname)
else: else:
target_func = globals()[target_name] target_func = globals()[target_name]
msg = "astor.%s is deprecated. Please use astor.%s." % (deprecated_name, target_name) msg = "astor.%s is deprecated. Please use astor.%s." % (
deprecated_name, target_name)
if callable(target_func): if callable(target_func):
def newfunc(*args, **kwarg): def newfunc(*args, **kwarg):
warnings.warn(msg, DeprecationWarning, stacklevel=2) warnings.warn(msg, DeprecationWarning, stacklevel=2)
...@@ -65,13 +67,14 @@ def deprecate(): ...@@ -65,13 +67,14 @@ def deprecate():
globals()[deprecated_name] = newfunc globals()[deprecated_name] = newfunc
for line in deprecated.splitlines(): for line in deprecated.splitlines(): # NOQA
line = line.split('#')[0].replace('=', '').split() line = line.split('#')[0].replace('=', '').split()
if line: if line:
target_name = line.pop() target_name = line.pop()
for deprecated_name in line: for deprecated_name in line:
wrap(deprecated_name, target_name) wrap(deprecated_name, target_name)
deprecate() deprecate()
del deprecate, deprecated del deprecate, deprecated
...@@ -5,8 +5,8 @@ Part of the astor library for Python AST manipulation. ...@@ -5,8 +5,8 @@ Part of the astor library for Python AST manipulation.
License: 3-clause BSD License: 3-clause BSD
Copyright (c) 2008 Armin Ronacher Copyright (c) 2008 Armin Ronacher
Copyright (c) 2012-2015 Patrick Maupin Copyright (c) 2012-2017 Patrick Maupin
Copyright (c) 2013-2015 Berker Peksag Copyright (c) 2013-2017 Berker Peksag
This module converts an AST into Python source code. This module converts an AST into Python source code.
...@@ -50,21 +50,35 @@ def to_source(node, indent_with=' ' * 4, add_line_information=False, ...@@ -50,21 +50,35 @@ def to_source(node, indent_with=' ' * 4, add_line_information=False,
pretty_string) pretty_string)
generator.visit(node) generator.visit(node)
generator.result.append('\n') generator.result.append('\n')
return pretty_source(str(s) for s in generator.result) if set(generator.result[0]) == set('\n'):
generator.result[0] = ''
return pretty_source(generator.result)
def set_precedence(value, *nodes): def precedence_setter(AST=ast.AST, get_op_precedence=get_op_precedence,
"""Set the precedence (of the parent) into the children. isinstance=isinstance, list=list):
""" This only uses a closure for performance reasons,
to reduce the number of attribute lookups. (set_precedence
is called a lot of times.)
""" """
if isinstance(value, ast.AST):
value = get_op_precedence(value) def set_precedence(value, *nodes):
for node in nodes: """Set the precedence (of the parent) into the children.
if isinstance(node, ast.AST): """
node._pp = value if isinstance(value, AST):
elif isinstance(node, list): value = get_op_precedence(value)
set_precedence(value, *node) for node in nodes:
else: if isinstance(node, AST):
assert node is None, node node._pp = value
elif isinstance(node, list):
set_precedence(value, *node)
else:
assert node is None, node
return set_precedence
set_precedence = precedence_setter()
class Delimit(object): class Delimit(object):
...@@ -76,7 +90,6 @@ class Delimit(object): ...@@ -76,7 +90,6 @@ class Delimit(object):
""" """
discard = False discard = False
coalesce = False
def __init__(self, tree, *args): def __init__(self, tree, *args):
""" use write instead of using result directly """ use write instead of using result directly
...@@ -113,8 +126,7 @@ class Delimit(object): ...@@ -113,8 +126,7 @@ class Delimit(object):
result[start] = '' result[start] = ''
else: else:
result.append(self.closing) result.append(self.closing)
if self.coalesce:
result[start:] = [''.join(result[start:])]
class SourceGenerator(ExplicitNodeVisitor): class SourceGenerator(ExplicitNodeVisitor):
"""This visitor is able to transform a well formed syntax tree into Python """This visitor is able to transform a well formed syntax tree into Python
...@@ -128,13 +140,45 @@ class SourceGenerator(ExplicitNodeVisitor): ...@@ -128,13 +140,45 @@ class SourceGenerator(ExplicitNodeVisitor):
using_unicode_literals = False using_unicode_literals = False
def __init__(self, indent_with, add_line_information=False, def __init__(self, indent_with, add_line_information=False,
pretty_string=pretty_string): pretty_string=pretty_string,
# constants
len=len, isinstance=isinstance, callable=callable):
self.result = [] self.result = []
self.indent_with = indent_with self.indent_with = indent_with
self.add_line_information = add_line_information self.add_line_information = add_line_information
self.indentation = 0 self.indentation = 0 # Current indentation level
self.new_lines = 0 self.new_lines = 0 # Number of lines to insert before next code
self.colinfo = 0, 0 # index in result of string containing linefeed, and
# position of last linefeed in that string
self.pretty_string = pretty_string self.pretty_string = pretty_string
AST = ast.AST
visit = self.visit
newline = self.newline
result = self.result
append = result.append
def write(*params):
""" self.write is a closure for performance (to reduce the number
of attribute lookups).
"""
for item in params:
if isinstance(item, AST):
visit(item)
elif callable(item):
item()
elif item == '\n':
newline()
else:
if self.new_lines:
append('\n' * self.new_lines)
self.colinfo = len(result), 0
append(self.indent_with * self.indentation)
self.new_lines = 0
if item:
append(item)
self.write = write
def __getattr__(self, name, defaults=dict(keywords=(), def __getattr__(self, name, defaults=dict(keywords=(),
_pp=Precedence.highest).get): _pp=Precedence.highest).get):
...@@ -156,22 +200,6 @@ class SourceGenerator(ExplicitNodeVisitor): ...@@ -156,22 +200,6 @@ class SourceGenerator(ExplicitNodeVisitor):
def delimit(self, *args): def delimit(self, *args):
return Delimit(self, *args) return Delimit(self, *args)
def write(self, *params):
for item in params:
if isinstance(item, ast.AST):
self.visit(item)
elif hasattr(item, '__call__'):
item()
elif item == '\n':
self.newline()
elif item != '':
if self.new_lines:
if self.result:
self.result.append('\n' * self.new_lines)
self.result.append(self.indent_with * self.indentation)
self.new_lines = 0
self.result.append(item)
def conditional_write(self, *stuff): def conditional_write(self, *stuff):
if stuff[-1] is not None: if stuff[-1] is not None:
self.write(*stuff) self.write(*stuff)
...@@ -268,7 +296,7 @@ class SourceGenerator(ExplicitNodeVisitor): ...@@ -268,7 +296,7 @@ class SourceGenerator(ExplicitNodeVisitor):
self.comma_list(node.names) self.comma_list(node.names)
# Goofy stuff for Python 2.7 _pyio module # Goofy stuff for Python 2.7 _pyio module
if node.module == '__future__' and 'unicode_literals' in ( if node.module == '__future__' and 'unicode_literals' in (
x.name for x in node.names): x.name for x in node.names):
self.using_unicode_literals = True self.using_unicode_literals = True
def visit_Import(self, node): def visit_Import(self, node):
...@@ -368,7 +396,7 @@ class SourceGenerator(ExplicitNodeVisitor): ...@@ -368,7 +396,7 @@ class SourceGenerator(ExplicitNodeVisitor):
self.conditional_write(' as ', node.optional_vars) self.conditional_write(' as ', node.optional_vars)
def visit_NameConstant(self, node): def visit_NameConstant(self, node):
self.write(node.value) self.write(str(node.value))
def visit_Pass(self, node): def visit_Pass(self, node):
self.statement(node, 'pass') self.statement(node, 'pass')
...@@ -459,12 +487,13 @@ class SourceGenerator(ExplicitNodeVisitor): ...@@ -459,12 +487,13 @@ class SourceGenerator(ExplicitNodeVisitor):
def visit_Attribute(self, node): def visit_Attribute(self, node):
self.write(node.value, '.', node.attr) self.write(node.value, '.', node.attr)
def visit_Call(self, node): def visit_Call(self, node, len=len):
write = self.write
want_comma = [] want_comma = []
def write_comma(): def write_comma():
if want_comma: if want_comma:
self.write(', ') write(', ')
else: else:
want_comma.append(True) want_comma.append(True)
...@@ -478,72 +507,94 @@ class SourceGenerator(ExplicitNodeVisitor): ...@@ -478,72 +507,94 @@ class SourceGenerator(ExplicitNodeVisitor):
p = Precedence.Comma if numargs > 1 else Precedence.call_one_arg p = Precedence.Comma if numargs > 1 else Precedence.call_one_arg
set_precedence(p, *args) set_precedence(p, *args)
self.visit(node.func) self.visit(node.func)
self.write('(') write('(')
for arg in args: for arg in args:
self.write(write_comma, arg) write(write_comma, arg)
set_precedence(Precedence.Comma, *(x.value for x in keywords)) set_precedence(Precedence.Comma, *(x.value for x in keywords))
for keyword in keywords: for keyword in keywords:
# a keyword.arg of None indicates dictionary unpacking # a keyword.arg of None indicates dictionary unpacking
# (Python >= 3.5) # (Python >= 3.5)
arg = keyword.arg or '' arg = keyword.arg or ''
self.write(write_comma, arg, '=' if arg else '**', keyword.value) write(write_comma, arg, '=' if arg else '**', keyword.value)
# 3.5 no longer has these # 3.5 no longer has these
self.conditional_write(write_comma, '*', starargs) self.conditional_write(write_comma, '*', starargs)
self.conditional_write(write_comma, '**', kwargs) self.conditional_write(write_comma, '**', kwargs)
self.write(')') write(')')
def visit_Name(self, node): def visit_Name(self, node):
self.write(node.id) self.write(node.id)
def visit_Str(self, node): def visit_JoinedStr(self, node):
result = self.result self.visit_Str(node, True)
embedded = self.get__pp(node) > Precedence.Expr
# Cheesy way to force a flush def visit_Str(self, node, is_joined=False):
self.write('foo')
result.pop()
result.append(self.pretty_string(node.s, embedded, result,
uni_lit=self.using_unicode_literals))
def visit_JoinedStr(self, node, # embedded is used to control when we might want
# constants # to use a triple-quoted string. We determine
new=sys.version_info >= (3, 0)): # if we are in an assignment and/or in an expression
precedence = self.get__pp(node)
embedded = ((precedence > Precedence.Expr) +
(precedence >= Precedence.Assign))
def recurse(node): # Flush any pending newlines, because we're about
for value in node.values: # to severely abuse the result list.
if isinstance(value, ast.Str): self.write('')
if new: result = self.result
encoded = value.s.encode('unicode-escape').decode()
# Calculate the string representing the line
# we are working on, up to but not including
# the string we are adding.
res_index, str_index = self.colinfo
current_line = self.result[res_index:]
if str_index:
current_line[0] = current_line[0][str_index:]
current_line = ''.join(current_line)
if is_joined:
# Handle new f-strings. This is a bit complicated, because
# the tree can contain subnodes that recurse back to JoinedStr
# subnodes...
def recurse(node):
for value in node.values:
if isinstance(value, ast.Str):
self.write(value.s)
elif isinstance(value, ast.FormattedValue):
with self.delimit('{}'):
self.visit(value.value)
if value.conversion != -1:
self.write('!%s' % chr(value.conversion))
if value.format_spec is not None:
self.write(':')
recurse(value.format_spec)
else: else:
encoded = value.s.encode('string-escape') kind = type(value).__name__
self.write(encoded) assert False, 'Invalid node %s inside JoinedStr' % kind
elif isinstance(value, ast.FormattedValue):
with self.delimit('{}'):
self.visit(value.value)
if value.conversion != -1:
self.write('!%s' % chr(value.conversion))
if value.format_spec is not None:
self.write(':')
recurse(value.format_spec)
else:
kind = type(value).__name__
assert False, 'Invalid node %s inside JoinedStr' % kind
with self.delimit(("f'", "'")) as delimiters: index = len(result)
recurse(node) recurse(node)
delimiters.coalesce = True mystr = ''.join(result[index:])
del result[index:]
s = self.result.pop() self.colinfo = res_index, str_index # Put it back like we found it
squotes = s.count("'") - 2 uni_lit = False # No formatted byte strings
if squotes:
dquotes = s.count('"') else:
s = s[2:-1] mystr = node.s
if dquotes < squotes: uni_lit = self.using_unicode_literals
s = 'f"%s"' % s.replace('"', r'\"')
else: mystr = self.pretty_string(mystr, embedded, current_line, uni_lit)
s = "f'%s'" % s.replace("'", r"\'")
self.write(s) if is_joined:
mystr = 'f' + mystr
self.write(mystr)
lf = mystr.rfind('\n') + 1
if lf:
self.colinfo = len(result) - 1, lf
def visit_Bytes(self, node): def visit_Bytes(self, node):
self.write(repr(node.s)) self.write(repr(node.s))
......
...@@ -101,4 +101,5 @@ class CodeToAst(object): ...@@ -101,4 +101,5 @@ class CodeToAst(object):
cache[(fname, obj.lineno)] = obj cache[(fname, obj.lineno)] = obj
return cache[key] return cache[key]
code_to_ast = CodeToAst() code_to_ast = CodeToAst()
...@@ -13,6 +13,12 @@ For a whole-tree approach, see the treewalk submodule. ...@@ -13,6 +13,12 @@ For a whole-tree approach, see the treewalk submodule.
""" """
import ast import ast
import itertools
try:
zip_longest = itertools.zip_longest
except AttributeError:
zip_longest = itertools.izip_longest
class NonExistent(object): class NonExistent(object):
...@@ -163,3 +169,40 @@ def allow_ast_comparison(): ...@@ -163,3 +169,40 @@ def allow_ast_comparison():
item.__bases__ = tuple(list(item.__bases__) + [CompareHelper]) item.__bases__ = tuple(list(item.__bases__) + [CompareHelper])
except TypeError: except TypeError:
pass pass
def fast_compare(tree1, tree2):
""" This is optimized to compare two AST trees for equality.
It makes several assumptions that are currently true for
AST trees used by rtrip, and it doesn't examine the _attributes.
"""
geta = ast.AST.__getattribute__
work = [(tree1, tree2)]
pop = work.pop
extend = work.extend
# TypeError in cPython, AttributeError in PyPy
exception = TypeError, AttributeError
zipl = zip_longest
type_ = type
list_ = list
while work:
n1, n2 = pop()
try:
f1 = geta(n1, '_fields')
f2 = geta(n2, '_fields')
except exception:
if type_(n1) is list_:
extend(zipl(n1, n2))
continue
if n1 == n2:
continue
return False
else:
f1 = [x for x in f1 if x != 'ctx']
if f1 != [x for x in f2 if x != 'ctx']:
return False
extend((geta(n1, fname), geta(n2, fname)) for fname in f1)
return True
...@@ -78,19 +78,22 @@ import logging ...@@ -78,19 +78,22 @@ import logging
from astor.code_gen import to_source from astor.code_gen import to_source
from astor.file_util import code_to_ast from astor.file_util import code_to_ast
from astor.node_util import allow_ast_comparison, dump_tree, strip_tree from astor.node_util import (allow_ast_comparison, dump_tree,
strip_tree, fast_compare)
dsttree = 'tmp_rtrip' dsttree = 'tmp_rtrip'
def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False): def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False,
ignore_exceptions=False, fullcomp=False):
"""Walk the srctree, and convert/copy all python files """Walk the srctree, and convert/copy all python files
into the dsttree into the dsttree
""" """
allow_ast_comparison() if fullcomp:
allow_ast_comparison()
parse_file = code_to_ast.parse_file parse_file = code_to_ast.parse_file
find_py_files = code_to_ast.find_py_files find_py_files = code_to_ast.find_py_files
...@@ -134,7 +137,12 @@ def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False): ...@@ -134,7 +137,12 @@ def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False):
badfiles.add(srcfname) badfiles.add(srcfname)
continue continue
dsttxt = to_source(srcast) try:
dsttxt = to_source(srcast)
except:
if not ignore_exceptions:
raise
dsttxt = ''
if not readonly: if not readonly:
dstfname = os.path.join(dstpath, fname) dstfname = os.path.join(dstpath, fname)
...@@ -150,12 +158,15 @@ def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False): ...@@ -150,12 +158,15 @@ def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False):
dstast = ast.parse(dsttxt) if readonly else parse_file(dstfname) dstast = ast.parse(dsttxt) if readonly else parse_file(dstfname)
except SyntaxError: except SyntaxError:
dstast = [] dstast = []
unknown_src_nodes.update(strip_tree(srcast)) if fullcomp:
unknown_dst_nodes.update(strip_tree(dstast)) unknown_src_nodes.update(strip_tree(srcast))
if dumpall or srcast != dstast: unknown_dst_nodes.update(strip_tree(dstast))
bad = srcast != dstast
else:
bad = not fast_compare(srcast, dstast)
if dumpall or bad:
srcdump = dump_tree(srcast) srcdump = dump_tree(srcast)
dstdump = dump_tree(dstast) dstdump = dump_tree(dstast)
bad = srcdump != dstdump
logging.warning(' calculating dump -- %s' % logging.warning(' calculating dump -- %s' %
('bad' if bad else 'OK')) ('bad' if bad else 'OK'))
if bad: if bad:
...@@ -230,6 +241,7 @@ def usage(msg): ...@@ -230,6 +241,7 @@ def usage(msg):
""") % msg) """) % msg)
if __name__ == '__main__': if __name__ == '__main__':
import textwrap import textwrap
......
...@@ -21,21 +21,7 @@ def pretty_source(source): ...@@ -21,21 +21,7 @@ def pretty_source(source):
""" Prettify the source. """ Prettify the source.
""" """
return ''.join(flatten(split_lines(source))) return ''.join(split_lines(source))
def flatten(source, list=list, isinstance=isinstance):
""" Deal with nested lists
"""
def flatten_iter(source):
for item in source:
if isinstance(item, list):
for item in flatten_iter(item):
yield item
else:
yield item
return flatten_iter(source)
def split_lines(source, maxline=79): def split_lines(source, maxline=79):
...@@ -43,38 +29,46 @@ def split_lines(source, maxline=79): ...@@ -43,38 +29,46 @@ def split_lines(source, maxline=79):
If a line is short enough, just yield it. If a line is short enough, just yield it.
Otherwise, fix it. Otherwise, fix it.
""" """
result = []
extend = result.extend
append = result.append
line = [] line = []
multiline = False multiline = False
count = 0 count = 0
find = str.find
for item in source: for item in source:
if item.startswith('\n'): index = find(item, '\n')
if index:
line.append(item)
multiline = index > 0
count += len(item)
else:
if line: if line:
if count <= maxline or multiline: if count <= maxline or multiline:
yield line extend(line)
else: else:
for item2 in wrap_line(line, maxline): wrap_line(line, maxline, result)
yield item2
count = 0 count = 0
multiline = False multiline = False
line = [] line = []
yield item append(item)
else: return result
line.append(item)
multiline = '\n' in item
count += len(item)
def count(group): def count(group, slen=str.__len__):
return sum(len(x) for x in group) return sum([slen(x) for x in group])
def wrap_line(line, maxline=79, count=count): def wrap_line(line, maxline=79, result=[], count=count):
""" We have a line that is too long, """ We have a line that is too long,
so we're going to try to wrap it. so we're going to try to wrap it.
""" """
# Extract the indentation # Extract the indentation
append = result.append
extend = result.extend
indentation = line[0] indentation = line[0]
lenfirst = len(indentation) lenfirst = len(indentation)
indent = lenfirst - len(indentation.strip()) indent = lenfirst - len(indentation.strip())
...@@ -100,10 +94,10 @@ def wrap_line(line, maxline=79, count=count): ...@@ -100,10 +94,10 @@ def wrap_line(line, maxline=79, count=count):
# then set up to deal with the remainder in pairs. # then set up to deal with the remainder in pairs.
first = unsplittable[0] first = unsplittable[0]
yield indentation append(indentation)
yield first extend(first)
if not splittable: if not splittable:
return return result
pos = indent + count(first) pos = indent + count(first)
indentation += ' ' indentation += ' '
indent += 4 indent += 4
...@@ -116,8 +110,8 @@ def wrap_line(line, maxline=79, count=count): ...@@ -116,8 +110,8 @@ def wrap_line(line, maxline=79, count=count):
# If we already have stuff on the line and even # If we already have stuff on the line and even
# the very first item won't fit, start a new line # the very first item won't fit, start a new line
if pos > indent and pos + len(sg[0]) > maxline: if pos > indent and pos + len(sg[0]) > maxline:
yield '\n' append('\n')
yield indentation append(indentation)
pos = indent pos = indent
# Dump lines out of the splittable group # Dump lines out of the splittable group
...@@ -127,25 +121,25 @@ def wrap_line(line, maxline=79, count=count): ...@@ -127,25 +121,25 @@ def wrap_line(line, maxline=79, count=count):
ready, sg = split_group(sg, pos, maxline) ready, sg = split_group(sg, pos, maxline)
if ready[-1].endswith(' '): if ready[-1].endswith(' '):
ready[-1] = ready[-1][:-1] ready[-1] = ready[-1][:-1]
yield ready extend(ready)
yield '\n' append('\n')
yield indentation append(indentation)
pos = indent pos = indent
csg = count(sg) csg = count(sg)
# Dump the remainder of the splittable group # Dump the remainder of the splittable group
if sg: if sg:
yield sg extend(sg)
pos += csg pos += csg
# Dump the unsplittable group, optionally # Dump the unsplittable group, optionally
# preceded by a linefeed. # preceded by a linefeed.
cnsg = count(nsg) cnsg = count(nsg)
if pos > indent and pos + cnsg > maxline: if pos > indent and pos + cnsg > maxline:
yield '\n' append('\n')
yield indentation append(indentation)
pos = indent pos = indent
yield nsg extend(nsg)
pos += cnsg pos += cnsg
...@@ -215,6 +209,7 @@ def delimiter_groups(line, begin_delim=begin_delim, ...@@ -215,6 +209,7 @@ def delimiter_groups(line, begin_delim=begin_delim,
assert not text, text assert not text, text
break break
statements = set(['del ', 'return', 'yield ', 'if ', 'while ']) statements = set(['del ', 'return', 'yield ', 'if ', 'while '])
...@@ -259,6 +254,7 @@ def add_parens(line, maxline, indent, statements=statements, count=count): ...@@ -259,6 +254,7 @@ def add_parens(line, maxline, indent, statements=statements, count=count):
return [item for group in groups for item in group] return [item for group in groups for item in group]
# Assignment operators # Assignment operators
ops = list('|^&+-*/%@~') + '<< >> // **'.split() + [''] ops = list('|^&+-*/%@~') + '<< >> // **'.split() + ['']
ops = set(' %s= ' % x for x in ops) ops = set(' %s= ' % x for x in ops)
......
...@@ -18,7 +18,6 @@ This has lots of Python 2 / Python 3 ugliness. ...@@ -18,7 +18,6 @@ This has lots of Python 2 / Python 3 ugliness.
""" """
import re import re
import logging
try: try:
special_unicode = unicode special_unicode = unicode
...@@ -32,28 +31,7 @@ except NameError: ...@@ -32,28 +31,7 @@ except NameError:
basestring = str basestring = str
def _get_line(current_output): def _properly_indented(s, line_indent):
""" Back up in the output buffer to
find the start of the current line,
and return the entire line.
"""
myline = []
index = len(current_output)
while index:
index -= 1
try:
s = str(current_output[index])
except:
raise
myline.append(s)
if '\n' in s:
break
myline = ''.join(reversed(myline))
return myline.rsplit('\n', 1)[-1]
def _properly_indented(s, current_line):
line_indent = len(current_line) - len(current_line.lstrip())
mylist = s.split('\n')[1:] mylist = s.split('\n')[1:]
mylist = [x.rstrip() for x in mylist] mylist = [x.rstrip() for x in mylist]
mylist = [x for x in mylist if x] mylist = [x for x in mylist if x]
...@@ -62,6 +40,7 @@ def _properly_indented(s, current_line): ...@@ -62,6 +40,7 @@ def _properly_indented(s, current_line):
counts = [(len(x) - len(x.lstrip())) for x in mylist] counts = [(len(x) - len(x.lstrip())) for x in mylist]
return counts and min(counts) >= line_indent return counts and min(counts) >= line_indent
mysplit = re.compile(r'(\\|\"\"\"|\"$)').split mysplit = re.compile(r'(\\|\"\"\"|\"$)').split
replacements = {'\\': '\\\\', '"""': '""\\"', '"': '\\"'} replacements = {'\\': '\\\\', '"""': '""\\"', '"': '\\"'}
...@@ -76,8 +55,8 @@ def _prep_triple_quotes(s, mysplit=mysplit, replacements=replacements): ...@@ -76,8 +55,8 @@ def _prep_triple_quotes(s, mysplit=mysplit, replacements=replacements):
return ''.join(s) return ''.join(s)
def pretty_string(s, embedded, current_output, min_trip_str=20, def pretty_string(s, embedded, current_line, uni_lit=False,
max_line=100, uni_lit=False): min_trip_str=20, max_line=100):
"""There are a lot of reasons why we might not want to or """There are a lot of reasons why we might not want to or
be able to return a triple-quoted string. We can always be able to return a triple-quoted string. We can always
punt back to the default normal string. punt back to the default normal string.
...@@ -92,16 +71,24 @@ def pretty_string(s, embedded, current_output, min_trip_str=20, ...@@ -92,16 +71,24 @@ def pretty_string(s, embedded, current_output, min_trip_str=20,
return 'b' + default return 'b' + default
len_s = len(default) len_s = len(default)
current_line = _get_line(current_output)
if current_line.strip(): if current_line.strip():
if embedded and '\n' not in s: len_current = len(current_line)
second_line_start = s.find('\n') + 1
if embedded > 1 and not second_line_start:
return default return default
if len_s < min_trip_str: if len_s < min_trip_str:
return default return default
total_len = len(current_line) + len_s line_indent = len_current - len(current_line.lstrip())
if total_len < max_line and not _properly_indented(s, current_line):
# Could be on a line by itself...
if embedded and not second_line_start:
return default
total_len = len_current + len_s
if total_len < max_line and not _properly_indented(s, line_indent):
return default return default
fancy = '"""%s"""' % _prep_triple_quotes(s) fancy = '"""%s"""' % _prep_triple_quotes(s)
...@@ -116,11 +103,4 @@ def pretty_string(s, embedded, current_output, min_trip_str=20, ...@@ -116,11 +103,4 @@ def pretty_string(s, embedded, current_output, min_trip_str=20,
return fancy return fancy
except: except:
pass pass
"""
logging.warning("***String conversion did not work\n")
#print (eval(fancy), s)
print
print (fancy, repr(s))
print
"""
return default return default
...@@ -33,6 +33,7 @@ class MetaFlatten(type): ...@@ -33,6 +33,7 @@ class MetaFlatten(type):
# Delegate the real work to type # Delegate the real work to type
return type.__new__(clstype, name, newbases, newdict) return type.__new__(clstype, name, newbases, newdict)
MetaFlatten = MetaFlatten('MetaFlatten', (object, ), {}) MetaFlatten = MetaFlatten('MetaFlatten', (object, ), {})
......
...@@ -236,5 +236,6 @@ def makelib(): ...@@ -236,5 +236,6 @@ def makelib():
f.write(lineend) f.write(lineend)
f.write('"""\n'.encode('utf-8')) f.write('"""\n'.encode('utf-8'))
if __name__ == '__main__': if __name__ == '__main__':
makelib() makelib()
...@@ -23,10 +23,6 @@ should not be part of the automated regressions. ...@@ -23,10 +23,6 @@ should not be part of the automated regressions.
""" """
import sys import sys
import collections
import itertools
import textwrap
import hashlib
import ast import ast
import astor import astor
...@@ -86,5 +82,6 @@ def checklib(): ...@@ -86,5 +82,6 @@ def checklib():
f.write(('%s %s\n' % (repr(srctxt), f.write(('%s %s\n' % (repr(srctxt),
repr(dsttxt))).encode('utf-8')) repr(dsttxt))).encode('utf-8'))
if __name__ == '__main__': if __name__ == '__main__':
checklib() checklib()
...@@ -347,8 +347,8 @@ class CodegenTestCase(unittest.TestCase): ...@@ -347,8 +347,8 @@ class CodegenTestCase(unittest.TestCase):
'RawDescriptionHelpFormatter', 'RawTextHelpFormatter', 'Namespace', 'RawDescriptionHelpFormatter', 'RawTextHelpFormatter', 'Namespace',
'Action', 'ONE_OR_MORE', 'OPTIONAL', 'PARSER', 'REMAINDER', 'SUPPRESS', 'Action', 'ONE_OR_MORE', 'OPTIONAL', 'PARSER', 'REMAINDER', 'SUPPRESS',
'ZERO_OR_MORE'] 'ZERO_OR_MORE']
""" """ # NOQA
self.maxDiff=2000 self.maxDiff = 2000
self.assertAstSourceEqual(source) self.assertAstSourceEqual(source)
def test_elif(self): def test_elif(self):
...@@ -366,13 +366,13 @@ class CodegenTestCase(unittest.TestCase): ...@@ -366,13 +366,13 @@ class CodegenTestCase(unittest.TestCase):
def test_fstrings(self): def test_fstrings(self):
source = """ source = """
f'{x}' x = f'{x}'
f'{x.y}' x = f'{x.y}'
f'{int(x)}' x = f'{int(x)}'
f'a{b:c}d' x = f'a{b:c}d'
f'a{b!s:c{d}e}f' x = f'a{b!s:c{d}e}f'
f'""' x = f'""'
f'"\\'' x = f'"\\''
""" """
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6)) self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6))
source = """ source = """
...@@ -385,7 +385,6 @@ class CodegenTestCase(unittest.TestCase): ...@@ -385,7 +385,6 @@ class CodegenTestCase(unittest.TestCase):
""" """
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6)) self.assertAstSourceEqualIfAtLeastVersion(source, (3, 6))
def test_annassign(self): def test_annassign(self):
source = """ source = """
a: int a: int
...@@ -404,7 +403,6 @@ class CodegenTestCase(unittest.TestCase): ...@@ -404,7 +403,6 @@ class CodegenTestCase(unittest.TestCase):
""" """
self.assertAstEqualIfAtLeastVersion(source, (3, 6)) self.assertAstEqualIfAtLeastVersion(source, (3, 6))
def test_compile_types(self): def test_compile_types(self):
code = '(a + b + c) * (d + e + f)\n' code = '(a + b + c) * (d + e + f)\n'
for mode in 'exec eval single'.split(): for mode in 'exec eval single'.split():
...@@ -472,6 +470,26 @@ class CodegenTestCase(unittest.TestCase): ...@@ -472,6 +470,26 @@ class CodegenTestCase(unittest.TestCase):
""" """
self.assertAstEqual(source) self.assertAstEqual(source)
def test_non_string_leakage(self):
source = '''
tar_compression = {'gzip': 'gz', None: ''}
'''
self.assertAstEqual(source)
def test_fast_compare(self):
fast_compare = astor.node_util.fast_compare
def check(a, b):
ast_a = ast.parse(a)
ast_b = ast.parse(b)
dump_a = astor.dump_tree(ast_a)
dump_b = astor.dump_tree(ast_b)
self.assertEqual(dump_a == dump_b, fast_compare(ast_a, ast_b))
check('a = 3', 'a = 3')
check('a = 3', 'a = 5')
check('a = 3 - (3, 4, 5)', 'a = 3 - (3, 4, 5)')
check('a = 3 - (3, 4, 5)', 'a = 3 - (3, 4, 6)')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -16,5 +16,6 @@ class GetSymbolTestCase(unittest.TestCase): ...@@ -16,5 +16,6 @@ class GetSymbolTestCase(unittest.TestCase):
def test_get_mat_mult(self): def test_get_mat_mult(self):
self.assertEqual('@', astor.get_op_symbol(ast.MatMult())) self.assertEqual('@', astor.get_op_symbol(ast.MatMult()))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment