%PDF- %PDF-
Direktori : /lib/python3/dist-packages/pythran/optimizations/ |
Current File : //lib/python3/dist-packages/pythran/optimizations/loop_full_unrolling.py |
""" LoopFullUnrolling fully unrolls loops with static bounds. """ from pythran import metadata from pythran.analyses import HasBreak, HasContinue, NodeCount from pythran.openmp import OMPDirective from pythran.conversion import to_ast from pythran.passmanager import Transformation from copy import deepcopy import gast as ast class LoopFullUnrolling(Transformation): ''' Fully unroll loops with static bounds >>> import gast as ast >>> from pythran import passmanager, backend >>> node = ast.parse('for j in [1,2,3]: i += j') >>> pm = passmanager.PassManager("test") >>> _, node = pm.apply(LoopFullUnrolling, node) >>> print(pm.dump(backend.Python, node)) j = 1 i += j j = 2 i += j j = 3 i += j >>> node = ast.parse('for j in (a,b): i += j') >>> pm = passmanager.PassManager("test") >>> _, node = pm.apply(LoopFullUnrolling, node) >>> print(pm.dump(backend.Python, node)) j = a i += j j = b i += j >>> node = ast.parse('for j in {1}: i += j') >>> pm = passmanager.PassManager("test") >>> _, node = pm.apply(LoopFullUnrolling, node) >>> print(pm.dump(backend.Python, node)) j = 1 i += j >>> node = ast.parse('for j in builtins.enumerate("1"): j') >>> pm = passmanager.PassManager("test") >>> _, node = pm.apply(LoopFullUnrolling, node) >>> print(pm.dump(backend.Python, node)) j = (0, '1') j ''' MAX_NODE_COUNT = 4096 def visit_For(self, node): # if the user added some OpenMP directive, trust him and no unroll if metadata.get(node, OMPDirective): return node # don't visit children because of collapse # first unroll children if needed or possible self.generic_visit(node) # a break or continue in the loop prevents unrolling too has_break = any(self.gather(HasBreak, n) for n in node.body) has_cont = any(self.gather(HasContinue, n) for n in node.body) if has_break or has_cont: return node # do not unroll too much to prevent code growth node_count = self.gather(NodeCount, node) def unroll(elt, body): return [ast.Assign([deepcopy(node.target)], elt, None)] + body def dc(body, i, n): if i == n - 1: return body else: return deepcopy(body) def getrange(n): return getattr(getattr(n, 'func', None), 'attr', None) if isinstance(node.iter, (ast.Tuple, ast.List)): elts_count = len(node.iter.elts) total_count = node_count * elts_count issmall = total_count < LoopFullUnrolling.MAX_NODE_COUNT if issmall: self.update = True return sum([unroll(elt, dc(node.body, i, elts_count)) for i, elt in enumerate(node.iter.elts)], []) code = compile(ast.gast_to_ast(ast.Expression(node.iter)), '<loop unrolling>', 'eval') try: values = list(eval(code, {'builtins': __import__('builtins')})) except Exception: return node values_count = len(values) total_count = node_count * values_count issmall = total_count < LoopFullUnrolling.MAX_NODE_COUNT if issmall: try: new_node = sum([unroll(to_ast(elt), dc(node.body, i, values_count)) for i, elt in enumerate(values)], []) self.update = True return new_node except Exception: return node return node