Kaydet (Commit) 5d004635 authored tarafından Patrick Maupin's avatar Patrick Maupin

Merge branch 'pretty_print' into pm_develop

......@@ -9,8 +9,6 @@ Copyright 2013 (c) Berker Peksag
"""
__version__ = '0.6'
from .code_gen import to_source # NOQA
from .node_util import iter_node, strip_tree, dump_tree
from .node_util import ExplicitNodeVisitor
......@@ -19,8 +17,10 @@ from .op_util import get_op_symbol, get_op_precedence # NOQA
from .op_util import symbol_data
from .tree_walk import TreeWalk # NOQA
__version__ = '0.6'
#DEPRECATED!!!
# DEPRECATED!!!
# These aliases support old programs. Please do not use in future.
......@@ -30,7 +30,7 @@ from .tree_walk import TreeWalk # NOQA
# things could be accessed from their submodule.
get_boolop = get_binop = get_cmpop = get_unaryop = get_op_symbol # NOQA
get_boolop = get_binop = get_cmpop = get_unaryop = get_op_symbol # NOQA
get_anyop = get_op_symbol
parsefile = code_to_ast.parse_file
codetoast = code_to_ast
......
This diff is collapsed.
......@@ -4,7 +4,7 @@ Part of the astor library for Python AST manipulation.
License: 3-clause BSD
Copyright (c) 2012-2015 Patrick Maupin
Copyright (c) 2015 Patrick Maupin
This module provides data and functions for mapping
AST nodes to symbols and precedences.
......@@ -14,48 +14,91 @@ AST nodes to symbols and precedences.
import ast
op_data = """
Or or 4
And and 6
Not not 8
Eq == 10
Gt > 10
GtE >= 10
In in 10
Is is 10
NotEq != 10
Lt < 10
LtE <= 10
NotIn not in 10
IsNot is not 10
BitOr | 12
BitXor ^ 14
BitAnd & 16
LShift << 18
RShift >> 18
Add + 20
Sub - 20
Mult * 22
Div / 22
Mod % 22
FloorDiv // 22
MatMult @ 22
UAdd + 24
USub - 24
Invert ~ 24
Pow ** 26
GeneratorExp 1
Assign 1
AugAssign 0
Expr 0
Yield 1
YieldFrom 0
If 1
For 0
While 0
Return 1
Slice 1
Subscript 0
Index 1
ExtSlice 1
comprehension_target 1
Tuple 0
Comma 1
Assert 0
Raise 0
call_one_arg 1
Lambda 1
IfExp 0
comprehension 1
Or or 1
And and 1
Not not 1
Eq == 1
Gt > 0
GtE >= 0
In in 0
Is is 0
NotEq != 0
Lt < 0
LtE <= 0
NotIn not in 0
IsNot is not 0
BitOr | 1
BitXor ^ 1
BitAnd & 1
LShift << 1
RShift >> 0
Add + 1
Sub - 0
Mult * 1
Div / 0
Mod % 0
FloorDiv // 0
MatMult @ 0
PowRHS 1
Invert ~ 1
UAdd + 0
USub - 0
Pow ** 1
Num 1
"""
op_data = [x.split() for x in op_data.splitlines()]
op_data = [(x[0], ' '.join(x[1:-1]), int(x[-1])) for x in op_data if x]
op_data = [[x[0], ' '.join(x[1:-1]), int(x[-1])] for x in op_data if x]
for index in range(1, len(op_data)):
op_data[index][2] *= 2
op_data[index][2] += op_data[index - 1][2]
precedence_data = dict((getattr(ast, x, None), z) for x, y, z in op_data)
symbol_data = dict((getattr(ast, x, None), y) for x, y, z in op_data)
def get_op_symbol(obj, fmt='%s', symbol_data=symbol_data, type=type):
"""Given an AST node object, returns a string containing the symbol.
"""
return fmt % symbol_data[type(obj)]
def get_op_precedence(obj, precedence_data=precedence_data, type=type):
"""Given an AST node object, returns the precedence.
"""
return precedence_data[type(obj)]
class Precedence(object):
vars().update((x, z) for x, y, z in op_data)
highest = max(z for x, y, z in op_data) + 2
# -*- coding: utf-8 -*-
"""
Part of the astor library for Python AST manipulation.
License: 3-clause BSD
Copyright (c) 2015 Patrick Maupin
Pretty-print source -- post-process for the decompiler
The goals of the initial cut of this engine are:
1) Do a passable, if not PEP8, job of line-wrapping.
2) Serve as an example of an interface to the decompiler
for anybody who wants to do a better job. :)
"""
def pretty_source(source):
""" Prettify the source.
"""
return ''.join(flatten(split_lines(source)))
def flatten(source, list=list, isinstance=isinstance):
""" Deal with nested lists
"""
def flatten_iter(source):
for item in source:
if isinstance(item, list):
for item in flatten_iter(item):
yield item
else:
yield item
return flatten_iter(source)
def split_lines(source, maxline=79):
"""Split inputs according to lines.
If a line is short enough, just yield it.
Otherwise, fix it.
"""
line = []
multiline = False
count = 0
for item in source:
if item.startswith('\n'):
if line:
if count <= maxline or multiline:
yield line
else:
for item2 in wrap_line(line, maxline):
yield item2
count = 0
multiline = False
line = []
yield item
else:
line.append(item)
multiline = '\n' in item
count += len(item)
def count(group):
return sum(len(x) for x in group)
def wrap_line(line, maxline=79, count=count):
""" We have a line that is too long,
so we're going to try to wrap it.
"""
# Extract the indentation
indentation = line[0]
lenfirst = len(indentation)
indent = lenfirst - len(indentation.strip())
assert indent in (0, lenfirst)
indentation = line.pop(0) if indent else ''
# Get splittable/non-splittable groups
dgroups = list(delimiter_groups(line))
unsplittable = dgroups[::2]
splittable = dgroups[1::2]
# If the largest non-splittable group won't fit
# on a line, try to add parentheses to the line.
if max(count(x) for x in unsplittable) > maxline - indent:
line = add_parens(line, maxline, indent)
dgroups = list(delimiter_groups(line))
unsplittable = dgroups[::2]
splittable = dgroups[1::2]
# Deal with the first (always unsplittable) group, and
# then set up to deal with the remainder in pairs.
first = unsplittable[0]
yield indentation
yield first
if not splittable:
return
pos = indent + count(first)
indentation += ' '
indent += 4
if indent >= maxline/2:
maxline = maxline/2 + indent
for sg, nsg in zip(splittable, unsplittable[1:]):
if sg:
# If we already have stuff on the line and even
# the very first item won't fit, start a new line
if pos > indent and pos + len(sg[0]) > maxline:
yield '\n'
yield indentation
pos = indent
# Dump lines out of the splittable group
# until the entire thing fits
csg = count(sg)
while pos + csg > maxline:
ready, sg = split_group(sg, pos, maxline)
if ready[-1].endswith(' '):
ready[-1] = ready[-1][:-1]
yield ready
yield '\n'
yield indentation
pos = indent
csg = count(sg)
# Dump the remainder of the splittable group
if sg:
yield sg
pos += csg
# Dump the unsplittable group, optionally
# preceded by a linefeed.
cnsg = count(nsg)
if pos > indent and pos + cnsg > maxline:
yield '\n'
yield indentation
pos = indent
yield nsg
pos += cnsg
def split_group(source, pos, maxline):
""" Split a group into two subgroups. The
first will be appended to the current
line, the second will start the new line.
Note that the first group must always
contain at least one item.
The original group may be destroyed.
"""
first = []
source.reverse()
while source:
tok = source.pop()
first.append(tok)
pos += len(tok)
if source:
tok = source[-1]
allowed = (maxline + 1) if tok.endswith(' ') else (maxline - 4)
if pos + len(tok) > allowed:
break
source.reverse()
return first, source
begin_delim = set('([{')
end_delim = set(')]}')
end_delim.add('):')
def delimiter_groups(line, begin_delim=begin_delim,
end_delim=end_delim):
"""Split a line into alternating groups.
The first group cannot have a line feed inserted,
the next one can, etc.
"""
text = []
line = iter(line)
while True:
# First build and yield an unsplittable group
for item in line:
text.append(item)
if item in begin_delim:
break
if not text:
break
yield text
# Now build and yield a splittable group
level = 0
text = []
for item in line:
if item in begin_delim:
level += 1
elif item in end_delim:
level -= 1
if level < 0:
yield text
text = [item]
break
text.append(item)
else:
assert not text, text
break
statements = set(['del ', 'return', 'yield ', 'if ', 'while '])
def add_parens(line, maxline, indent, statements=statements, count=count):
"""Attempt to add parentheses around the line
in order to make it splittable.
"""
if line[0] in statements:
index = 1
if not line[0].endswith(' '):
index = 2
assert line[1] == ' '
line.insert(index, '(')
if line[-1] == ':':
line.insert(-1, ')')
else:
line.append(')')
# That was the easy stuff. Now for assignments.
groups = list(get_assign_groups(line))
if len(groups) == 1:
# So sad, too bad
return line
counts = list(count(x) for x in groups)
didwrap = False
# If the LHS is large, wrap it first
if sum(counts[:-1]) >= maxline - indent - 4:
for group in groups[:-1]:
didwrap = False # Only want to know about last group
if len(group) > 1:
group.insert(0, '(')
group.insert(-1, ')')
didwrap = True
# Might not need to wrap the RHS if wrapped the LHS
if not didwrap or counts[-1] > maxline - indent - 10:
groups[-1].insert(0, '(')
groups[-1].append(')')
return [item for group in groups for item in group]
# Assignment operators
ops = list('|^&+-*/%@~') + '<< >> // **'.split() + ['']
ops = set(' %s= ' % x for x in ops)
def get_assign_groups(line, ops=ops):
""" Split a line into groups by assignment (including
augmented assignment)
"""
group = []
for item in line:
group.append(item)
if item in ops:
yield group
group = []
yield group
# -*- coding: utf-8 -*-
"""
Part of the astor library for Python AST manipulation.
License: 3-clause BSD
Copyright (c) 2015 Patrick Maupin
Pretty-print strings for the decompiler
We either return the repr() of the string,
or try to format it as a triple-quoted string.
This is a lot harder than you would think.
This has lots of Python 2 / Python 3 ugliness.
"""
import re
import logging
try:
special_unicode = unicode
except NameError:
class special_unicode(object):
pass
try:
basestring = basestring
except NameError:
basestring = str
def _get_line(current_output):
""" Back up in the output buffer to
find the start of the current line,
and return the entire line.
"""
myline = []
index = len(current_output)
while index:
index -= 1
try:
s = str(current_output[index])
except:
raise
myline.append(s)
if '\n' in s:
break
myline = ''.join(reversed(myline))
return myline.rsplit('\n', 1)[-1]
def _properly_indented(s, current_line):
line_indent = len(current_line) - len(current_line.lstrip())
mylist = s.split('\n')[1:]
mylist = [x.rstrip() for x in mylist]
mylist = [x for x in mylist if x]
if not s:
return False
counts = [(len(x) - len(x.lstrip())) for x in mylist]
return counts and min(counts) >= line_indent
mysplit = re.compile(r'(\\|\"\"\"|\"$)').split
replacements = {'\\': '\\\\', '"""': '""\\"', '"': '\\"'}
def _prep_triple_quotes(s, mysplit=mysplit, replacements=replacements):
""" Split the string up and force-feed some replacements
to make sure it will round-trip OK
"""
s = mysplit(s)
s[1::2] = (replacements[x] for x in s[1::2])
return ''.join(s)
def pretty_string(s, current_output, min_trip_str=20, max_line=100):
"""There are a lot of reasons why we might not want to or
be able to return a triple-quoted string. We can always
punt back to the default normal string.
"""
default = repr(s)
# Punt on abnormal strings
if (isinstance(s, special_unicode) or not isinstance(s, basestring)):
return default
len_s = len(default)
current_line = _get_line(current_output)
if current_line.strip():
if len_s < min_trip_str:
return default
total_len = len(current_line) + len_s
if total_len < max_line and not _properly_indented(s, current_line):
return default
fancy = '"""%s"""' % _prep_triple_quotes(s)
# Sometimes this doesn't work. One reason is that
# the AST has no understanding of whether \r\n was
# entered that way in the string or was a cr/lf in the
# file. So we punt just so we can round-trip properly.
try:
if eval(fancy) == s and '\r' not in fancy:
return fancy
except:
pass
"""
logging.warning("***String conversion did not work\n")
#print (eval(fancy), s)
print
print (fancy, repr(s))
print
"""
return default
......@@ -7,6 +7,9 @@ License: 3-clause BSD
Copyright 2012 (c) Patrick Maupin
Copyright 2013 (c) Berker Peksag
This file contains a TreeWalk class that views a node tree
as a unified whole and allows several modes of traversal.
"""
from .node_util import iter_node
......@@ -76,9 +79,9 @@ class TreeWalk(MetaFlatten):
methods can be written. They will be called in alphabetical order.
"""
nodestack = None
def __init__(self, node=None):
self.nodestack = []
self.setup()
if node is not None:
self.walk(node)
......@@ -106,11 +109,11 @@ class TreeWalk(MetaFlatten):
"""
pre_handlers = self.pre_handlers.get
post_handlers = self.post_handlers.get
oldstack = self.nodestack
self.nodestack = nodestack = []
nodestack = self.nodestack
emptystack = len(nodestack)
append, pop = nodestack.append, nodestack.pop
append([node, name, list(iter_node(node, name + '_item')), -1])
while nodestack:
while len(nodestack) > emptystack:
node, name, subnodes, index = nodestack[-1]
if index >= len(subnodes):
handler = (post_handlers(type(node).__name__) or
......@@ -138,7 +141,6 @@ class TreeWalk(MetaFlatten):
else:
node, name = subnodes[index]
append([node, name, list(iter_node(node, name + '_item')), -1])
self.nodestack = oldstack
@property
def parent(self):
......
......@@ -26,29 +26,29 @@ import ast
import astor
all_operators = (
#Selected special operands
# Selected special operands
'3 -3 () yield',
#operators with one parameter
# operators with one parameter
'yield lambda_: not + - ~ $, yield_from',
#operators with two parameters
# operators with two parameters
'or and == != > >= < <= in not_in is is_not '
'| ^ & << >> + - * / % // @ ** for$in$ $($) $[$] . '
'$,$ ',
#operators with 3 parameters
# operators with 3 parameters
'$if$else$ $for$in$'
)
select_operators = (
#Selected special operands -- remove
#some at redundant precedence levels
# Selected special operands -- remove
# some at redundant precedence levels
'-3',
#operators with one parameter
# operators with one parameter
'yield lambda_: not - ~ $,',
#operators with two parameters
# operators with two parameters
'or and == in is '
'| ^ & >> - % ** for$in$ $($) . ',
#operators with 3 parameters
# operators with 3 parameters
'$if$else$ $for$in$'
)
......@@ -111,9 +111,9 @@ def get_sub_combinations(maxop):
if numops:
combo[numops, 1].append((numops-1,))
for op1 in range(numops):
combo[numops, 2].append((op1, numops - op1 -1))
combo[numops, 2].append((op1, numops - op1 - 1))
for op2 in range(numops - op1):
combo[numops, 3].append((op1, op2, numops - op1 -op2-1))
combo[numops, 3].append((op1, op2, numops - op1 - op2 - 1))
return combo
......@@ -127,8 +127,8 @@ def get_paren_combos():
"""
results = [None] * 4
options = [('%s', '(%s)')]
for i in range(1,4):
results[i] = list(itertools.product(*(i*options)))
for i in range(1, 4):
results[i] = list(itertools.product(*(i * options)))
return results
......@@ -149,6 +149,7 @@ def operand_combo(expressions, operands, max_operand=13):
for op in op_combos[expr.count('%s')]:
yield expr % op
def build(numops=2, all_operators=all_operators, use_operands=False,
# Runtime optimization
tuple=tuple):
......@@ -167,22 +168,24 @@ def build(numops=2, all_operators=all_operators, use_operands=False,
for myop, nparams in operators:
myop = myop.replace('%%', '%%%%')
myparens = paren_combos[nparams]
#print combo[numops, nparams]
# print combo[numops, nparams]
for mycombo in combo[numops, nparams]:
#print mycombo
# print mycombo
call_again = (recurse_build(x) for x in mycombo)
for subexpr in product(*call_again):
for parens in myparens:
wrapped = tuple(x % y for (x, y) in izip(parens, subexpr))
wrapped = tuple(x % y for (x, y)
in izip(parens, subexpr))
yield myop % wrapped
result = recurse_build(numops)
return operand_combo(result, operands) if use_operands else result
def makelib():
parse = ast.parse
dump_tree = astor.dump_tree
default_value = lambda: (1000000, '')
def default_value(): return 1000000, ''
mydict = collections.defaultdict(default_value)
allparams = [tuple('abcdefghijklmnop'[:x]) for x in range(13)]
......@@ -191,10 +194,10 @@ def makelib():
build(3, select_operators))
yieldrepl = list(('yield %s %s' % (operator, operand),
'yield %s%s' % (operator, operand))
for operator in '+-' for operand in '(ab')
'yield %s%s' % (operator, operand))
for operator in '+-' for operand in '(ab')
yieldrepl.append(('yield[', 'yield ['))
#alltxt = itertools.chain(build(1), build(2))
# alltxt = itertools.chain(build(1), build(2))
badexpr = 0
goodexpr = 0
silly = '3( 3.( 3[ 3.['.split()
......
......@@ -83,7 +83,8 @@ def checklib():
print('******************')
print()
print()
f.write(('%s %s\n' % (repr(srctxt), repr(dsttxt))).encode('utf-8'))
f.write(('%s %s\n' % (repr(srctxt),
repr(dsttxt))).encode('utf-8'))
if __name__ == '__main__':
checklib()
......@@ -22,6 +22,7 @@ import astor
def canonical(srctxt):
return textwrap.dedent(srctxt).strip()
class CodegenTestCase(unittest.TestCase):
def assertAstEqual(self, srctxt):
......@@ -38,7 +39,7 @@ class CodegenTestCase(unittest.TestCase):
self.assertEqual(dstdmp, srcdmp)
def assertAstEqualIfAtLeastVersion(self, source, min_should_work,
max_should_error=None):
max_should_error=None):
if max_should_error is None:
max_should_error = min_should_work[0], min_should_work[1] - 1
if sys.version_info >= min_should_work:
......@@ -53,10 +54,10 @@ class CodegenTestCase(unittest.TestCase):
which may not always be appropriate.
"""
srctxt = canonical(srctxt)
self.assertEqual(astor.to_source(ast.parse(srctxt)), srctxt)
self.assertEqual(astor.to_source(ast.parse(srctxt)).rstrip(), srctxt)
def assertAstSourceEqualIfAtLeastVersion(self, source, min_should_work,
max_should_error=None):
max_should_error=None):
if max_should_error is None:
max_should_error = min_should_work[0], min_should_work[1] - 1
if sys.version_info >= min_should_work:
......@@ -140,7 +141,7 @@ class CodegenTestCase(unittest.TestCase):
root_node = ast.parse(source)
arguments_node = [n for n in ast.walk(root_node)
if isinstance(n, ast.arguments)][0]
self.assertEqual(astor.to_source(arguments_node),
self.assertEqual(astor.to_source(arguments_node).rstrip(),
"a1, a2, b1=j, b2='123', b3={}, b4=[]")
source = """
def call(*popenargs, timeout=None, **kwargs):
......@@ -197,6 +198,14 @@ class CodegenTestCase(unittest.TestCase):
self.assertAstEqual(source)
source = "[(yield)]"
self.assertAstEqual(source)
source = "if (yield): pass"
self.assertAstEqual(source)
source = "if (yield from foo): pass"
self.assertAstEqualIfAtLeastVersion(source, (3, 3))
source = "(yield from (a, b))"
self.assertAstEqualIfAtLeastVersion(source, (3, 3))
source = "yield from sam()"
self.assertAstSourceEqualIfAtLeastVersion(source, (3, 3))
def test_with(self):
source = """
......@@ -254,7 +263,6 @@ class CodegenTestCase(unittest.TestCase):
"""
self.assertAstEqual(source)
def test_comprehension(self):
source = """
((x,y) for x,y in zip(a,b))
......@@ -265,7 +273,8 @@ class CodegenTestCase(unittest.TestCase):
"""
self.assertAstEqual(source)
source = """
ra = np.fromiter(((i * 3, i * 2) for i in range(10)), n, dtype='i8,f8')
ra = np.fromiter(((i * 3, i * 2) for i in range(10)),
n, dtype='i8,f8')
"""
self.assertAstEqual(source)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment