%PDF- %PDF-
| Direktori : /lib/python3/dist-packages/pythran/transformations/ |
| Current File : //lib/python3/dist-packages/pythran/transformations/normalize_static_if.py |
""" NormalizeStaticIf adds support for static guards. """
from pythran.analyses import (ImportedIds, HasReturn, IsAssigned, CFG,
HasBreak, HasContinue, DefUseChains, Ancestors,
StaticExpressions, HasStaticExpression)
from pythran.passmanager import Transformation
from pythran.syntax import PythranSyntaxError
import gast as ast
from copy import deepcopy
# Status codes an outlined static-if branch reports about how it exited.
# LOOP_NONE prefixes the result of a branch that only has break/continue
# handling (see ``outline``); EARLY_RET, LOOP_BREAK and LOOP_CONT are the
# values the generated ``$status`` variable is compared against at the
# call site, and are written into the result by ``PatchBreakContinue``.
LOOP_NONE, EARLY_RET, LOOP_BREAK, LOOP_CONT = 0, 1, 2, 3
def outline(name, formal_parameters, out_parameters, stmts,
            has_return, has_break, has_cont):
    """Build a ``FunctionDef`` called *name* from one static-if branch.

    Each name in *formal_parameters* becomes a positional parameter.

    *stmts* is either a single expression (a branch of an ``IfExp``) or a
    list of statements (a branch of an ``If``):

    * expression: the function simply returns it; *out_parameters* must
      be empty.
    * statement list: the list is used, and mutated in place, as the
      function body.  A trailing ``return (out_parameters...)`` is
      appended.  Returns are rewritten when *has_return* is set, and
      break / continue when *has_break* / *has_cont* is set.

    Returns the new ``FunctionDef`` node.
    """
    params = [ast.Name(fp, ast.Param(), None, None)
              for fp in formal_parameters]
    args = ast.arguments(params, [], None, [], [], None, [])

    if isinstance(stmts, ast.expr):
        # Expression branch: there is no statement-level control flow
        # to encode.
        assert not out_parameters, "no out parameters with expr"
        return ast.FunctionDef(name, args, [ast.Return(stmts)],
                               [], None, None)

    # The FunctionDef shares the ``stmts`` list, so appending to it below
    # also extends the function body.
    fdef = ast.FunctionDef(name, args, stmts, [], None, None)

    # This is part of a huge trick that plays with delayed type inference.
    # A final return of the out parameters is added unconditionally, so
    # that when the branch has other returns, the output type is computed
    # from the __combined of the regular return types and this one.  The
    # original returns are patched to a different type that cunningly
    # combines with this output tuple.
    #
    # This is the only trick found to let pythran compute both the output
    # variable type and the early return type.  A dirty one :-/
    out_values = [ast.Name(fp, ast.Load(), None, None)
                  for fp in out_parameters]
    stmts.append(ast.Return(ast.Tuple(out_values, ast.Load())))

    if has_return:
        # stmts[-1] is the synthetic return just appended; PatchReturn
        # wraps it differently from the branch's own returns.
        PatchReturn(stmts[-1], has_break or has_cont).visit(fdef)

    if not (has_break or has_cont):
        return fdef

    # stmts[-1] is read again on purpose: PatchReturn may have replaced
    # the trailing return node in place.
    if not has_return:
        # Prefix the out tuple with a LOOP_NONE status flag.
        tail = stmts[-1]
        tail.value = ast.Tuple([ast.Constant(LOOP_NONE, None), tail.value],
                               ast.Load())
    PatchBreakContinue(stmts[-1]).visit(fdef)
    return fdef
class PatchReturn(ast.NodeTransformer):
    """Wrap each ``return`` value of an outlined branch in a pythran marker.

    The synthetic trailing return (*guard*) becomes
    ``return builtins.pythran.StaticIfNoReturn(value)``.  Every other
    return becomes ``return builtins.pythran.StaticIfReturn(value)``.
    A bare ``return`` is treated as returning ``None``.
    """

    def __init__(self, guard, has_break_or_cont):
        # The trailing Return node appended by ``outline``.
        self.guard = guard
        # Stored for callers; not read by this class.
        self.has_break_or_cont = has_break_or_cont

    def visit_Return(self, node):
        if node is self.guard:
            holder = "StaticIfNoReturn"
        else:
            holder = "StaticIfReturn"

        pythran_module = ast.Attribute(
            ast.Name("builtins", ast.Load(), None, None),
            "pythran",
            ast.Load())
        marker = ast.Attribute(pythran_module, holder, ast.Load())
        payload = node.value if node.value else ast.Constant(None, None)
        return ast.Return(ast.Call(marker, [payload], []))
class PatchBreakContinue(ast.NodeTransformer):
    """Turn ``break`` / ``continue`` of an outlined branch into returns.

    Each ``break`` / ``continue`` that targets the loop enclosing the
    static if is replaced by a copy of *guard*, the branch's trailing
    return:

    * If the guard returns a call (the PatchReturn form), the callee is
      renamed to ``StaticIfBreak`` / ``StaticIfCont``.
    * Otherwise the guard returns a tuple, and its first element (the
      status constant) is set to LOOP_BREAK / LOOP_CONT.
    """

    def __init__(self, guard):
        self.guard = guard

    def visit_For(self, node):
        # Do not descend: a break/continue inside a nested loop targets
        # that loop, not the one around the static if.  The node must be
        # returned: a NodeTransformer deletes any node for which the
        # visitor returns None, which previously dropped nested loops
        # from the outlined body.
        return node

    def visit_While(self, node):
        # Same as visit_For: keep the nested loop and leave it untouched.
        return node

    def patch_Control(self, node, flag):
        """Return a copy of the guard encoding *flag* (LOOP_BREAK/CONT)."""
        new_node = deepcopy(self.guard)
        ret_val = new_node.value
        if isinstance(ret_val, ast.Call):
            if flag == LOOP_BREAK:
                ret_val.func.attr = "StaticIfBreak"
            else:
                ret_val.func.attr = "StaticIfCont"
        else:
            new_node.value.elts[0].value = flag
        return new_node

    def visit_Break(self, node):
        return self.patch_Control(node, LOOP_BREAK)

    def visit_Continue(self, node):
        return self.patch_Control(node, LOOP_CONT)
class NormalizeStaticIf(Transformation):
    """Outline the branches of static ``if`` statements and expressions.

    An ``If`` / ``IfExp`` whose test is in ``StaticExpressions`` is
    replaced by a call of the form
    ``builtins.pythran.static_if(test, $isstaticN, $isstaticM)(ids...)``.
    Here ``$isstaticN`` / ``$isstaticM`` are new module-level functions
    that ``outline`` builds from the true / false branch.

    For statements, the values the branches assign, and any return /
    break / continue they perform, are carried back through the call's
    result.  They are then replayed at the call site.
    """

    def __init__(self):
        super(NormalizeStaticIf, self).__init__(StaticExpressions, Ancestors,
                                                DefUseChains)

    def visit_Module(self, node):
        # Per-module state: the outlined functions to append, plus stacks
        # of the enclosing FunctionDef nodes and their CFGs.
        self.new_functions = []
        self.funcs = []
        self.cfgs = []
        self.generic_visit(node)
        # The outlined branch functions become top-level definitions.
        node.body.extend(self.new_functions)
        return node

    def escaping_ids(self, scope_stmt, stmts):
        """Return the names defined in *stmts* and used outside *scope_stmt*.

        A definition escapes when at least one of its users (according to
        the def-use chains) does not have *scope_stmt* among its ancestors.
        """
        assigned_nodes = self.gather(IsAssigned, self.make_fake(stmts))
        escaping = set()
        for assigned_node in assigned_nodes:
            head = self.def_use_chains.chains[assigned_node]
            for user in head.users():
                if scope_stmt not in self.ancestors[user.node]:
                    escaping.add(head.name())
        return escaping

    @staticmethod
    def make_fake(stmts):
        """Wrap *stmts* in a dummy ``if 0:`` node so analyses accept them.

        The statement list is shared with the new node, not copied.
        """
        return ast.If(ast.Constant(0, None), stmts, [])

    @staticmethod
    def make_dispatcher(static_expr, func_true, func_false,
                        imported_ids):
        """Build ``builtins.pythran.static_if(expr, f_true, f_false)(ids)``.

        The names in *imported_ids* are passed, in order, as positional
        arguments to whatever ``static_if`` returns.  Presumably that is
        the branch function selected by *static_expr*; confirm against
        pythran's builtins.
        """
        dispatcher_args = [static_expr,
                           ast.Name(func_true.name, ast.Load(), None, None),
                           ast.Name(func_false.name, ast.Load(), None, None)]

        dispatcher = ast.Call(
            ast.Attribute(
                ast.Attribute(
                    ast.Name("builtins", ast.Load(), None, None),
                    "pythran",
                    ast.Load()),
                "static_if",
                ast.Load()),
            dispatcher_args, [])

        actual_call = ast.Call(
            dispatcher,
            [ast.Name(ii, ast.Load(), None, None) for ii in imported_ids],
            [])

        return actual_call

    def true_name(self):
        # Both names are computed before new_functions is extended by the
        # pair, so each static if gets two consecutive, unused indices.
        # '$' cannot appear in a Python identifier, so these names cannot
        # clash with user code.
        return "$isstatic{}".format(len(self.new_functions) + 0)

    def false_name(self):
        return "$isstatic{}".format(len(self.new_functions) + 1)

    def visit_FunctionDef(self, node):
        # visit_If reads self.cfgs[-1], the CFG of the enclosing function.
        self.cfgs.append(self.gather(CFG, node))
        self.funcs.append(node)
        onode = self.generic_visit(node)
        self.funcs.pop()
        self.cfgs.pop()
        return onode

    def visit_IfExp(self, node):
        # Nested expressions are normalized first.
        self.generic_visit(node)

        if node.test not in self.static_expressions:
            return node

        imported_ids = sorted(self.gather(ImportedIds, node))

        # Expression branches only produce a value: no out parameters and
        # no return/break/continue to encode.
        func_true = outline(self.true_name(), imported_ids, [],
                            node.body, False, False, False)
        func_false = outline(self.false_name(), imported_ids, [],
                             node.orelse, False, False, False)
        self.new_functions.extend((func_true, func_false))

        actual_call = self.make_dispatcher(node.test, func_true,
                                           func_false, imported_ids)

        return actual_call

    def make_control_flow_handlers(self, cont_n, status_n, expected_return,
                                   has_cont, has_break):
        '''
        Create the statements that unpack the static_if result and then
        execute the control flow instruction it requests.

        The generated shape is (each part present only when requested):

            if status_n == LOOP_BREAK: <assign>; break
            else:
                if status_n == LOOP_CONT: <assign>; continue
                else: <assign>

        Here <assign> is ``(expected_return...) = cont_n``, or nothing
        when *expected_return* is empty.  Returns a list of statements.
        '''
        if expected_return:
            assign = cont_ass = [ast.Assign(
                [ast.Tuple(expected_return, ast.Store())],
                ast.Name(cont_n, ast.Load(), None, None), None)]
        else:
            assign = cont_ass = []

        if has_cont:
            cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None, None),
                               [ast.Eq()], [ast.Constant(LOOP_CONT, None)])
            # The assignment is deep-copied into each guarded body so no
            # AST node ends up at two places in the tree.
            cont_ass = [ast.If(cmpr,
                               deepcopy(assign) + [ast.Continue()],
                               cont_ass)]
        if has_break:
            cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None, None),
                               [ast.Eq()], [ast.Constant(LOOP_BREAK, None)])
            cont_ass = [ast.If(cmpr,
                               deepcopy(assign) + [ast.Break()],
                               cont_ass)]
        return cont_ass

    def visit_If(self, node):
        if node.test not in self.static_expressions:
            return self.generic_visit(node)

        # Presumably the names the static if reads from its context;
        # these become the outlined functions' parameters.
        imported_ids = self.gather(ImportedIds, node)

        assigned_ids_left = self.escaping_ids(node, node.body)
        assigned_ids_right = self.escaping_ids(node, node.orelse)
        assigned_ids_both = assigned_ids_left.union(assigned_ids_right)

        # Both branch functions return every name in assigned_ids_both.
        # A name assigned in only one branch must therefore also be passed
        # in, so that the other branch can return its incoming value.
        imported_ids.update(i for i in assigned_ids_left
                            if i not in assigned_ids_right)
        imported_ids.update(i for i in assigned_ids_right
                            if i not in assigned_ids_left)
        imported_ids = sorted(imported_ids)

        assigned_ids = sorted(assigned_ids_both)

        # Control-flow analyses are run on the branches before nested
        # static ifs inside them are rewritten by generic_visit below.
        fbody = self.make_fake(node.body)
        true_has_return = self.gather(HasReturn, fbody)
        true_has_break = self.gather(HasBreak, fbody)
        true_has_cont = self.gather(HasContinue, fbody)

        felse = self.make_fake(node.orelse)
        false_has_return = self.gather(HasReturn, felse)
        false_has_break = self.gather(HasBreak, felse)
        false_has_cont = self.gather(HasContinue, felse)

        has_return = true_has_return or false_has_return
        has_break = true_has_break or false_has_break
        has_cont = true_has_cont or false_has_cont

        self.generic_visit(node)

        # Both branches get the same flags, so their results share one
        # layout.
        func_true = outline(self.true_name(), imported_ids, assigned_ids,
                            node.body, has_return, has_break, has_cont)
        func_false = outline(self.false_name(), imported_ids, assigned_ids,
                             node.orelse, has_return, has_break, has_cont)
        self.new_functions.extend((func_true, func_false))

        actual_call = self.make_dispatcher(node.test,
                                           func_true, func_false, imported_ids)

        # variable modified within the static_if
        expected_return = [ast.Name(ii, ast.Store(), None, None)
                           for ii in assigned_ids]

        self.update = True

        # name for various variables resulting from the static_if
        n = len(self.new_functions)
        status_n = "$status{}".format(n)
        return_n = "$return{}".format(n)
        cont_n = "$cont{}".format(n)

        if has_return:
            # Shortcut when every CFG successor of this If is a Return or
            # Yield and both branches contain a return.  This assumes
            # cfg[node] yields the successors.
            cfg = self.cfgs[-1]
            always_return = all(isinstance(x, (ast.Return, ast.Yield))
                                for x in cfg[node])
            always_return &= true_has_return and false_has_return

            fast_return = [ast.Name(status_n, ast.Store(), None, None),
                           ast.Name(return_n, ast.Store(), None, None),
                           ast.Name(cont_n, ast.Store(), None, None)]

            if always_return:
                # $status, $return, $cont = call; return $return
                return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                                   actual_call, None),
                        ast.Return(ast.Name(return_n, ast.Load(), None, None))]
            else:
                # $status, $return, $cont = call
                # if $status == EARLY_RET: return $return
                # else: <break / continue / assignment handlers>
                cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                                           expected_return,
                                                           has_cont, has_break)

                cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None, None),
                                   [ast.Eq()], [ast.Constant(EARLY_RET, None)])
                return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                                   actual_call, None),
                        ast.If(cmpr,
                               [ast.Return(ast.Name(return_n, ast.Load(),
                                                    None, None))],
                               cont_ass)]
        elif has_break or has_cont:
            # $status, $cont = call; <break / continue / assignment handlers>
            cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                                       expected_return,
                                                       has_cont, has_break)

            fast_return = [ast.Name(status_n, ast.Store(), None, None),
                           ast.Name(cont_n, ast.Store(), None, None)]
            return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                               actual_call, None)] + cont_ass
        elif expected_return:
            # Plain data flow: (a, b, ...) = call
            return ast.Assign([ast.Tuple(expected_return, ast.Store())],
                              actual_call, None)
        else:
            # Nothing escapes: evaluate the call for its effects only.
            return ast.Expr(actual_call)
class SplitStaticExpression(Transformation):
    """Split ``&`` / ``|`` tests that mix static and non-static operands.

    Such an ``If`` / ``IfExp`` is rewritten into nested conditionals, so
    that each static expression ends up alone in a test.
    """

    def __init__(self):
        super(SplitStaticExpression, self).__init__(StaticExpressions)

    def visit_Cond(self, node):
        '''
        Generic expression splitting algorithm, shared by IfExp and If.

        W(rap) and U(n)W(rap) hide the difference between expression
        branches (IfExp) and statement-list branches (If).

        The idea is to split a BinOp test into three expressions:
        1. a (possibly empty) non-static expr
        2. an expr containing a static expr
        3. a (possibly empty) non-static expr

        Once split, the body is refactored to keep the semantics.  The
        result is then recursively split again, until every static expr
        is alone in a test condition.

        Raises PythranSyntaxError for operators other than & and |.
        '''
        NodeTy = type(node)
        if NodeTy is ast.IfExp:
            def W(x):
                return x

            def UW(x):
                return x
        else:
            def W(x):
                return [x]

            def UW(x):
                return x[0]

        has_static_expr = self.gather(HasStaticExpression, node.test)

        if not has_static_expr:
            return self.generic_visit(node)

        if node.test in self.static_expressions:
            return self.generic_visit(node)

        if not isinstance(node.test, ast.BinOp):
            return self.generic_visit(node)

        before, static = [], []
        # Operands are popped from the end, so the left operand is
        # examined first.
        values = [node.test.right, node.test.left]

        def has_static_expression(n):
            return self.gather(HasStaticExpression, n)

        # Leading non-static operand(s), then static operand(s); whatever
        # is left over goes into ``after``.
        while values and not has_static_expression(values[-1]):
            before.append(values.pop())

        while values and has_static_expression(values[-1]):
            static.append(values.pop())

        after = list(reversed(values))

        test_before = NodeTy(None, None, None)
        if before:
            assert len(before) == 1
            test_before.test = before[0]

        test_static = NodeTy(None, None, None)
        if static:
            test_static.test = static[0]
            if len(static) > 1:
                # Both operands are static: test the first one alone and
                # leave the second for the ``after`` conditional.
                if after:
                    assert len(after) == 1
                    after = [ast.BinOp(static[1], node.test.op, after[0])]
                else:
                    after = static[1:]

        test_after = NodeTy(None, None, None)
        if after:
            assert len(after) == 1
            test_after.test = after[0]

        # NOTE(review): & and | are rewritten as short-circuiting nested
        # conditionals; this matches logical and/or for boolean operands.
        if isinstance(node.test.op, ast.BitAnd):
            # a & b  ->  if a: (if b: body else: orelse) else: orelse
            if after:
                test_after.body = deepcopy(node.body)
                test_after.orelse = deepcopy(node.orelse)
                test_after = W(test_after)
            else:
                test_after = deepcopy(node.body)

            if static:
                test_static.body = test_after
                test_static.orelse = deepcopy(node.orelse)
                test_static = W(test_static)
            else:
                test_static = test_after

            if before:
                test_before.body = test_static
                test_before.orelse = node.orelse
                node = test_before
            else:
                node = UW(test_static)

        elif isinstance(node.test.op, ast.BitOr):
            # a | b  ->  if a: body else: (if b: body else: orelse)
            if after:
                test_after.body = deepcopy(node.body)
                test_after.orelse = deepcopy(node.orelse)
                test_after = W(test_after)
            else:
                test_after = deepcopy(node.orelse)

            if static:
                test_static.body = deepcopy(node.body)
                test_static.orelse = test_after
                test_static = W(test_static)
            else:
                test_static = test_after

            if before:
                test_before.body = deepcopy(node.body)
                test_before.orelse = test_static
                node = test_before
            else:
                node = UW(test_static)

        else:
            raise PythranSyntaxError("operator not supported in a static if",
                                     node)

        self.update = True
        # Recurse into the rebuilt conditional so that any remaining mixed
        # tests are split too.
        return self.generic_visit(node)

    visit_If = visit_IfExp = visit_Cond