%PDF- %PDF-
| Direktori : /lib/python3/dist-packages/pythran/optimizations/ |
| Current File : //lib/python3/dist-packages/pythran/optimizations/square.py |
""" Replaces **2 by a call to numpy.square. """
import copy
import operator

import gast as ast

from pythran.passmanager import Transformation
from pythran.analyses.ast_matcher import ASTMatcher, AST_any
from pythran.conversion import mangle
from pythran.utils import isnum
class Square(Transformation):

    """
    Replaces **2 by a call to numpy.square.

    Positive integral constant exponents (``x ** n``) are also expanded into
    products of numpy.square calls, see ``expand_pow``.

    >>> import gast as ast
    >>> from pythran import passmanager, backend
    >>> node = ast.parse('a**2')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(Square, node)
    >>> print(pm.dump(backend.Python, node))
    import numpy as __pythran_import_numpy
    __pythran_import_numpy.square(a)

    >>> node = ast.parse('__pythran_import_numpy.power(a,2)')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(Square, node)
    >>> print(pm.dump(backend.Python, node))
    import numpy as __pythran_import_numpy
    __pythran_import_numpy.square(a)
    """

    # Matches ``<anything> ** 2``.
    POW_PATTERN = ast.BinOp(AST_any(), ast.Pow(), ast.Constant(2, None))
    # Matches ``numpy.power(<anything>, 2)`` where numpy is referenced
    # through its mangled import name.
    POWER_PATTERN = ast.Call(
        ast.Attribute(ast.Name(mangle('numpy'), ast.Load(), None, None),
                      'power',
                      ast.Load()),
        [AST_any(), ast.Constant(2, None)],
        [])

    def __init__(self):
        Transformation.__init__(self)

    def replace(self, value):
        """
        Build and return a ``numpy.square(value)`` call node.

        Side effects: marks the tree as modified (``self.update``) and
        requests that ``visit_Module`` add the mangled numpy import
        (``self.need_import``).
        """
        self.update = self.need_import = True
        module_name = ast.Name(mangle('numpy'), ast.Load(), None, None)
        return ast.Call(ast.Attribute(module_name, 'square', ast.Load()),
                        [value], [])

    def visit_Module(self, node):
        """
        Transform the whole module, then prepend
        ``import numpy as <mangled numpy>`` if any rewrite needs it.
        """
        self.need_import = False
        self.generic_visit(node)
        if self.need_import:
            import_alias = ast.alias(name='numpy', asname=mangle('numpy'))
            import_node = ast.Import(names=[import_alias])
            node.body.insert(0, import_node)
        return node

    def expand_pow(self, node, n):
        """
        Return an expression equivalent to ``node ** n`` built from squares.

        Uses binary exponentiation: ``x ** n`` becomes
        ``(square(x)) ** (n >> 1)`` (expanded recursively), multiplied by
        one extra ``x`` when ``n`` is odd.

        ``n`` must be a non-negative Python int: it is decomposed with
        ``>>`` and ``&``.

        NOTE(review): for odd ``n`` the base expression is deep-copied and
        therefore evaluated more than once; this assumes the base is free of
        side effects (and cheap to recompute) — confirm with callers.
        """
        if n == 0:
            return ast.Constant(1, None)
        elif n == 1:
            return node
        else:
            node_square = self.replace(node)
            node_pow = self.expand_pow(node_square, n >> 1)
            if n & 1:
                return ast.BinOp(node_pow, ast.Mult(), copy.deepcopy(node))
            else:
                return node_pow

    @staticmethod
    def _pow_exponent(node):
        """
        Return the exponent of the ``Pow`` BinOp ``node`` as a Python int
        if it is a positive integer-like constant, otherwise None.

        ``operator.index`` accepts ints (and other ``__index__`` types such
        as numpy integers) but rejects floats and complex numbers.

        Expanding ``x ** 3.0`` would crash in ``expand_pow`` (``>>`` on a
        float). ``int()`` raises on complex values. Either expansion would
        also change the result type, because ``int ** float`` is a float.
        """
        if not isnum(node.right):
            return None
        try:
            n = operator.index(node.right.value)
        except TypeError:
            return None
        return n if n > 0 else None

    def visit_BinOp(self, node):
        """
        Rewrite ``x ** 2`` as ``numpy.square(x)`` and ``x ** n`` (positive
        integral constant ``n``) through ``expand_pow``.

        Children are transformed first; any other BinOp is returned as is.
        """
        self.generic_visit(node)
        if not isinstance(node.op, ast.Pow):
            # Only Pow nodes can be rewritten; skipping the matcher here
            # avoids re-walking every operator chain's subtree.
            return node
        # ASTMatcher.search may also report matches located deeper in the
        # subtree; only a match on ``node`` itself is relevant.
        if node in ASTMatcher(Square.POW_PATTERN).search(node):
            return self.replace(node.left)
        n = self._pow_exponent(node)
        if n is None:
            return node
        return self.expand_pow(node.left, n)

    def visit_Call(self, node):
        """
        Rewrite ``numpy.power(x, 2)`` as ``numpy.square(x)``.

        Arguments are transformed first; any other call is returned as is.
        """
        self.generic_visit(node)
        func = node.func
        if not (isinstance(func, ast.Attribute) and func.attr == 'power'):
            # Cheap pre-filter: POWER_PATTERN requires a ``.power`` call.
            return node
        # Require that the match is ``node`` itself, not a nested call.
        if node in ASTMatcher(Square.POWER_PATTERN).search(node):
            return self.replace(node.args[0])
        return node