%PDF- %PDF-
Direktori : /lib/python3/dist-packages/pythran/transformations/ |
Current File : //lib/python3/dist-packages/pythran/transformations/normalize_static_if.py |
""" NormalizeStaticIf adds support for static guards.

A static guard is an ``if`` statement or conditional expression whose test
is a static expression. ``SplitStaticExpression`` rewrites tests that merely
*contain* a static expression (combined with ``&`` or ``|``) into nested
conditionals, so that static expressions end up alone in a test.
``NormalizeStaticIf`` then outlines each branch of a static guard into a
module-level function and replaces the guard with a call built on
``builtins.pythran.static_if``.
"""

from pythran.analyses import (ImportedIds, HasReturn, IsAssigned, CFG,
                              HasBreak, HasContinue, DefUseChains,
                              Ancestors, StaticExpressions,
                              HasStaticExpression)
from pythran.passmanager import Transformation
from pythran.syntax import PythranSyntaxError

import gast as ast
from copy import deepcopy

# Status codes stored in the first element of an outlined function's result
# (when it contains break/continue and/or return); the call site built by
# NormalizeStaticIf.visit_If compares against them to resume control flow.
LOOP_NONE, EARLY_RET, LOOP_BREAK, LOOP_CONT = range(4)


def outline(name, formal_parameters, out_parameters, stmts,
            has_return, has_break, has_cont):
    """Build a function definition named *name* out of *stmts*.

    :param name: name of the generated function.
    :param formal_parameters: identifiers used as the function's positional
        parameters.
    :param out_parameters: identifiers returned as a tuple at the end of the
        generated body. Must be empty when *stmts* is an expression.
    :param stmts: either a single expression, which becomes the returned
        value, or a list of statements, which becomes the function body.
        A statement list is used as-is and is modified in place.
    :param has_return: whether the statements contain a ``return``; if so,
        every return is rewritten by :class:`PatchReturn`.
    :param has_break: whether the statements contain a ``break``.
    :param has_cont: whether the statements contain a ``continue``. When
        either loop flag is set, loop exits are rewritten by
        :class:`PatchBreakContinue`.
    :returns: the new ``ast.FunctionDef``.
    """
    args = ast.arguments(
        [ast.Name(fp, ast.Param(), None, None) for fp in formal_parameters],
        [], None, [], [], None, [])

    if isinstance(stmts, ast.expr):
        assert not out_parameters, "no out parameters with expr"
        fdef = ast.FunctionDef(name, args, [ast.Return(stmts)], [], None, None)
    else:
        fdef = ast.FunctionDef(name, args, stmts, [], None, None)

        # This is part of a huge trick that plays with delayed type inference.
        # It computes the return type based on the out parameters: the return
        # statement below is added unconditionally, so if there are other
        # returns, the output type is computed from the __combined of the
        # regular return types and this one. The original returns are patched
        # below (by PatchReturn) to have a different type that cunningly
        # combines with this output tuple.
        #
        # This is the only trick I found to let pythran compute both the output
        # variable type and the early return type. But hey, a dirty one :-/
        stmts.append(
            ast.Return(
                ast.Tuple(
                    [ast.Name(fp, ast.Load(), None, None)
                     for fp in out_parameters],
                    ast.Load()
                )
            )
        )
        if has_return:
            pr = PatchReturn(stmts[-1], has_break or has_cont)
            pr.visit(fdef)

        if has_break or has_cont:
            if not has_return:
                # Prefix the out tuple with a status code so the call site
                # can tell a plain fall-through from a break/continue.
                stmts[-1].value = ast.Tuple([ast.Constant(LOOP_NONE, None),
                                             stmts[-1].value],
                                            ast.Load())
            pbc = PatchBreakContinue(stmts[-1])
            pbc.visit(fdef)

    return fdef


class PatchReturn(ast.NodeTransformer):
    """Wrap the value of every ``return`` in a ``builtins.pythran`` marker.

    The *guard* return (the out-parameter tuple appended by :func:`outline`)
    becomes ``builtins.pythran.StaticIfNoReturn(value)``; every other return
    becomes ``builtins.pythran.StaticIfReturn(value)``, a bare ``return``
    being wrapped as ``StaticIfReturn(None)``.
    """

    def __init__(self, guard, has_break_or_cont):
        self.guard = guard
        # Kept for compatibility; not read anywhere in this class.
        self.has_break_or_cont = has_break_or_cont

    def visit_Return(self, node):
        if node is self.guard:
            holder = "StaticIfNoReturn"
        else:
            holder = "StaticIfReturn"

        value = node.value

        return ast.Return(
            ast.Call(
                ast.Attribute(
                    ast.Attribute(
                        ast.Name("builtins", ast.Load(), None, None),
                        "pythran",
                        ast.Load()),
                    holder,
                    ast.Load()),
                [value] if value is not None else [ast.Constant(None, None)],
                []))


class PatchBreakContinue(ast.NodeTransformer):
    """Turn the ``break``/``continue`` of an outlined body into returns.

    Each ``break``/``continue`` that is not nested in an inner loop is
    replaced by a copy of *guard* (the final return built by
    :func:`outline`) tagged with ``LOOP_BREAK`` or ``LOOP_CONT``.
    """

    def __init__(self, guard):
        self.guard = guard

    def visit_For(self, node):
        # A break/continue inside a nested loop targets that loop, so do not
        # descend into it. The node must be returned unchanged: a
        # NodeTransformer deletes any node whose visitor returns None, which
        # would drop the nested loop from the outlined body.
        # NOTE(review): break/continue in a nested loop's ``else:`` clause
        # target the enclosing loop but are not rewritten here either.
        return node

    visit_While = visit_For

    def patch_Control(self, node, flag):
        """Return a copy of the guard return tagged with *flag*."""
        new_node = deepcopy(self.guard)
        ret_val = new_node.value
        if isinstance(ret_val, ast.Call):
            # The guard was wrapped by PatchReturn: encode the control flow
            # in the name of the builtins.pythran marker.
            if flag == LOOP_BREAK:
                ret_val.func.attr = "StaticIfBreak"
            else:
                ret_val.func.attr = "StaticIfCont"
        else:
            # The guard is (LOOP_NONE, out_tuple): overwrite the status code.
            new_node.value.elts[0].value = flag
        return new_node

    def visit_Break(self, node):
        return self.patch_Control(node, LOOP_BREAK)

    def visit_Continue(self, node):
        return self.patch_Control(node, LOOP_CONT)


class NormalizeStaticIf(Transformation):
    """Outline the branches of static guards into module-level functions.

    An ``if`` or conditional expression whose test belongs to
    ``StaticExpressions`` has each branch turned into a function (named
    ``$isstatic<N>``) appended to the module body, and is replaced by
    ``builtins.pythran.static_if(test, true_fn, false_fn)(*ids)`` where
    ``ids`` are the identifiers the branches need. For statements, extra code
    unpacks the result to re-assign the variables modified by the branches
    and to perform any ``return``, ``break`` or ``continue`` they contained.
    """

    def __init__(self):
        super(NormalizeStaticIf, self).__init__(StaticExpressions, Ancestors,
                                                DefUseChains)

    def visit_Module(self, node):
        self.new_functions = []
        self.funcs = []
        self.cfgs = []
        self.generic_visit(node)
        # Outlined branches become top-level functions of the module.
        node.body.extend(self.new_functions)
        return node

    def escaping_ids(self, scope_stmt, stmts):
        """Gather identifiers assigned in *stmts* and used outside of them.

        A name escapes when one of its uses, according to the def-use chains,
        does not have *scope_stmt* among its ancestors.
        """
        assigned_nodes = self.gather(IsAssigned, self.make_fake(stmts))
        escaping = set()
        for assigned_node in assigned_nodes:
            head = self.def_use_chains.chains[assigned_node]
            for user in head.users():
                if scope_stmt not in self.ancestors[user.node]:
                    escaping.add(head.name())
        return escaping

    @staticmethod
    def make_fake(stmts):
        """Wrap *stmts* in a dummy ``if 0:`` so they form a single node."""
        return ast.If(ast.Constant(0, None), stmts, [])

    @staticmethod
    def make_dispatcher(static_expr, func_true, func_false,
                        imported_ids):
        """Build ``builtins.pythran.static_if(test, true, false)(*ids)``.

        :param static_expr: the static test expression.
        :param func_true: outlined function of the true branch.
        :param func_false: outlined function of the false branch.
        :param imported_ids: names passed, in order, as positional arguments
            to the result of ``static_if``.
        """
        dispatcher_args = [static_expr,
                           ast.Name(func_true.name, ast.Load(), None, None),
                           ast.Name(func_false.name, ast.Load(), None, None)]

        dispatcher = ast.Call(
            ast.Attribute(
                ast.Attribute(
                    ast.Name("builtins", ast.Load(), None, None),
                    "pythran",
                    ast.Load()),
                "static_if",
                ast.Load()),
            dispatcher_args, [])

        actual_call = ast.Call(
            dispatcher,
            [ast.Name(ii, ast.Load(), None, None) for ii in imported_ids],
            [])

        return actual_call

    def true_name(self):
        """Name of the next outlined true-branch function."""
        return "$isstatic{}".format(len(self.new_functions) + 0)

    def false_name(self):
        """Name of the next outlined false-branch function."""
        return "$isstatic{}".format(len(self.new_functions) + 1)

    def visit_FunctionDef(self, node):
        """Visit *node* with its CFG on top of ``self.cfgs``."""
        self.cfgs.append(self.gather(CFG, node))
        self.funcs.append(node)
        onode = self.generic_visit(node)
        self.funcs.pop()
        self.cfgs.pop()
        return onode

    def visit_IfExp(self, node):
        """Replace a static conditional expression by a static_if call.

        Each branch is outlined into a function taking the identifiers
        imported by *node*; no out parameters are needed since the branch
        value is returned directly.
        """
        self.generic_visit(node)

        if node.test not in self.static_expressions:
            return node

        imported_ids = sorted(self.gather(ImportedIds, node))

        func_true = outline(self.true_name(), imported_ids, [],
                            node.body, False, False, False)
        func_false = outline(self.false_name(), imported_ids, [],
                             node.orelse, False, False, False)
        self.new_functions.extend((func_true, func_false))

        actual_call = self.make_dispatcher(node.test, func_true,
                                           func_false, imported_ids)

        # The tree has been rewritten and new functions queued for the
        # module, just as in visit_If: flag the update accordingly.
        self.update = True
        return actual_call

    def make_control_flow_handlers(self, cont_n, status_n, expected_return,
                                   has_cont, has_break):
        '''
        Build the statements that unpack the static_if result held in
        *cont_n* into the *expected_return* targets and then execute the
        ``continue`` / ``break`` selected by the status code in *status_n*.

        :returns: a list of statements.
        '''
        if expected_return:
            assign = cont_ass = [ast.Assign(
                [ast.Tuple(expected_return, ast.Store())],
                ast.Name(cont_n, ast.Load(), None, None), None)]
        else:
            assign = cont_ass = []

        if has_cont:
            cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None, None),
                               [ast.Eq()], [ast.Constant(LOOP_CONT, None)])
            cont_ass = [ast.If(cmpr,
                               deepcopy(assign) + [ast.Continue()],
                               cont_ass)]
        if has_break:
            cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None, None),
                               [ast.Eq()], [ast.Constant(LOOP_BREAK, None)])
            cont_ass = [ast.If(cmpr,
                               deepcopy(assign) + [ast.Break()],
                               cont_ass)]
        return cont_ass

    def visit_If(self, node):
        """Outline a static ``if`` statement.

        The replacement is one of:

        * ``$status<N>, $return<N>, $cont<N> = call`` followed by either an
          unconditional return (every CFG successor of *node* is a
          Return/Yield and both branches contain a return) or a dispatch on
          ``$status<N>``, when a branch contains ``return``;
        * ``$status<N>, $cont<N> = call`` followed by a break/continue
          dispatch, when a branch contains ``break`` or ``continue``;
        * ``out_vars = call`` when the branches assign escaping names;
        * a bare expression statement otherwise.
        """
        if node.test not in self.static_expressions:
            return self.generic_visit(node)

        imported_ids = self.gather(ImportedIds, node)

        assigned_ids_left = self.escaping_ids(node, node.body)
        assigned_ids_right = self.escaping_ids(node, node.orelse)
        assigned_ids_both = assigned_ids_left.union(assigned_ids_right)

        # Both outlined functions return the same out tuple, so a name
        # assigned in only one branch must be passed to both as a parameter.
        imported_ids.update(assigned_ids_left ^ assigned_ids_right)
        imported_ids = sorted(imported_ids)

        assigned_ids = sorted(assigned_ids_both)

        # Analyse control flow before generic_visit rewrites nested guards.
        fbody = self.make_fake(node.body)
        true_has_return = self.gather(HasReturn, fbody)
        true_has_break = self.gather(HasBreak, fbody)
        true_has_cont = self.gather(HasContinue, fbody)

        felse = self.make_fake(node.orelse)
        false_has_return = self.gather(HasReturn, felse)
        false_has_break = self.gather(HasBreak, felse)
        false_has_cont = self.gather(HasContinue, felse)

        has_return = true_has_return or false_has_return
        has_break = true_has_break or false_has_break
        has_cont = true_has_cont or false_has_cont

        self.generic_visit(node)

        func_true = outline(self.true_name(), imported_ids, assigned_ids,
                            node.body, has_return, has_break, has_cont)
        func_false = outline(self.false_name(), imported_ids, assigned_ids,
                             node.orelse, has_return, has_break, has_cont)
        self.new_functions.extend((func_true, func_false))

        actual_call = self.make_dispatcher(node.test,
                                           func_true, func_false, imported_ids)

        # variable modified within the static_if
        expected_return = [ast.Name(ii, ast.Store(), None, None)
                           for ii in assigned_ids]

        self.update = True

        # name for various variables resulting from the static_if
        n = len(self.new_functions)
        status_n = "$status{}".format(n)
        return_n = "$return{}".format(n)
        cont_n = "$cont{}".format(n)

        if has_return:
            cfg = self.cfgs[-1]
            always_return = all(isinstance(x, (ast.Return, ast.Yield))
                                for x in cfg[node])
            always_return &= true_has_return and false_has_return

            fast_return = [ast.Name(status_n, ast.Store(), None, None),
                           ast.Name(return_n, ast.Store(), None, None),
                           ast.Name(cont_n, ast.Store(), None, None)]

            if always_return:
                return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                                   actual_call, None),
                        ast.Return(ast.Name(return_n, ast.Load(), None, None))]
            else:
                cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                                           expected_return,
                                                           has_cont, has_break)

                cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None, None),
                                   [ast.Eq()], [ast.Constant(EARLY_RET, None)])
                return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                                   actual_call, None),
                        ast.If(cmpr,
                               [ast.Return(ast.Name(return_n, ast.Load(),
                                                    None, None))],
                               cont_ass)]
        elif has_break or has_cont:
            cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                                       expected_return,
                                                       has_cont, has_break)

            fast_return = [ast.Name(status_n, ast.Store(), None, None),
                           ast.Name(cont_n, ast.Store(), None, None)]
            return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                               actual_call, None)] + cont_ass
        elif expected_return:
            return ast.Assign([ast.Tuple(expected_return, ast.Store())],
                              actual_call, None)
        else:
            return ast.Expr(actual_call)


class SplitStaticExpression(Transformation):
    """Isolate static expressions combined with ``&``/``|`` in tests."""

    def __init__(self):
        super(SplitStaticExpression, self).__init__(StaticExpressions)

    def visit_Cond(self, node):
        '''
        Generic expression-splitting algorithm, shared by ``IfExp`` and
        ``If``; W(rap) and U(n)W(rap) hide the difference between an
        expression body and a statement-list body.

        The test, a BinOp, is split into three expressions:
            1. a (possibly empty) non-static expr
            2. an expr containing a static expr
            3. a (possibly empty) remaining expr
        The conditional is then rebuilt as nested conditionals with the same
        semantics, e.g. for a static ``s`` and non-static ``a``, ``b``:
            ``if a & s: B else: E`` -> ``if a: (if s: B else: E) else: E``
            ``if s | b: B else: E`` -> ``if s: B else: (if b: B else: E)``
        and visited again, until every static expr is alone in a test
        condition.

        :raises PythranSyntaxError: if the operator is neither ``&`` nor
            ``|``.
        '''
        NodeTy = type(node)
        if NodeTy is ast.IfExp:
            def W(x):
                return x

            def UW(x):
                return x
        else:
            def W(x):
                return [x]

            def UW(x):
                return x[0]

        has_static_expr = self.gather(HasStaticExpression, node.test)

        if not has_static_expr:
            return self.generic_visit(node)

        if node.test in self.static_expressions:
            return self.generic_visit(node)

        if not isinstance(node.test, ast.BinOp):
            return self.generic_visit(node)

        before, static = [], []
        values = [node.test.right, node.test.left]

        def has_static_expression(n):
            return self.gather(HasStaticExpression, n)

        while values and not has_static_expression(values[-1]):
            before.append(values.pop())

        while values and has_static_expression(values[-1]):
            static.append(values.pop())

        after = list(reversed(values))

        test_before = NodeTy(None, None, None)
        if before:
            assert len(before) == 1
            test_before.test = before[0]

        test_static = NodeTy(None, None, None)
        if static:
            test_static.test = static[0]
            if len(static) > 1:
                if after:
                    assert len(after) == 1
                    after = [ast.BinOp(static[1], node.test.op, after[0])]
                else:
                    after = static[1:]

        test_after = NodeTy(None, None, None)
        if after:
            assert len(after) == 1
            test_after.test = after[0]

        if isinstance(node.test.op, ast.BitAnd):
            if after:
                test_after.body = deepcopy(node.body)
                test_after.orelse = deepcopy(node.orelse)
                test_after = W(test_after)
            else:
                test_after = deepcopy(node.body)

            if static:
                test_static.body = test_after
                test_static.orelse = deepcopy(node.orelse)
                test_static = W(test_static)
            else:
                test_static = test_after

            if before:
                test_before.body = test_static
                test_before.orelse = node.orelse
                node = test_before
            else:
                node = UW(test_static)

        elif isinstance(node.test.op, ast.BitOr):
            if after:
                test_after.body = deepcopy(node.body)
                test_after.orelse = deepcopy(node.orelse)
                test_after = W(test_after)
            else:
                test_after = deepcopy(node.orelse)

            if static:
                test_static.body = deepcopy(node.body)
                test_static.orelse = test_after
                test_static = W(test_static)
            else:
                test_static = test_after

            if before:
                test_before.body = deepcopy(node.body)
                test_before.orelse = test_static
                node = test_before
            else:
                node = UW(test_static)

        else:
            raise PythranSyntaxError("operator not supported in a static if",
                                     node)

        self.update = True
        return self.generic_visit(node)

    visit_If = visit_IfExp = visit_Cond