%PDF- %PDF-
| Direktori : /lib/python3/dist-packages/pythran/transformations/ |
| Current File : //lib/python3/dist-packages/pythran/transformations/handle_import.py |
"""HandleImport transformation takes care of importing user-defined modules."""
from pythran.passmanager import Transformation
from pythran.tables import MODULES, pythran_ward
from pythran.syntax import PythranSyntaxError
import gast as ast
import logging
import os
logger = logging.getLogger('pythran')
def add_filename_field(node, filename):
    """Tag ``node`` and all of its descendants with a ``filename`` attribute.

    The tree is modified in place (every node yielded by ``ast.walk`` gets
    ``.filename = filename``); nothing is returned.
    """
    for child in ast.walk(node):
        setattr(child, 'filename', filename)
def mangle_imported_module(module_name):
    """Return the mangled identifier standing for user module ``module_name``.

    The identifier is ``pythran_ward + "imported__"``, followed by the module
    path with each ``.`` turned into ``$``, and terminated by a ``$`` marker
    (the marker ``is_mangled_module`` looks for).
    """
    flattened = module_name.replace('.', '$')
    header = pythran_ward + "imported__"
    return header + flattened + '$'
def mangle_imported_function(module_name, func_name):
    """Return the mangled identifier of ``func_name`` from user module ``module_name``.

    It is the module's mangled identifier with ``func_name`` appended.
    """
    module_prefix = mangle_imported_module(module_name)
    return module_prefix + func_name
def demangle(name):
    """Recover the dotted module name from a mangled module identifier.

    Strips the ``pythran_ward + "imported__"`` header and the final
    character, then maps ``$`` back to ``.``. This undoes
    ``mangle_imported_module``. No check is made that ``name`` actually
    carries the header; callers are expected to pass a mangled module name.
    """
    header_len = len(pythran_ward + "imported__")
    core = name[header_len:-1]
    return core.replace('$', '.')
def is_builtin_function(func_name):
    """Tell whether ``func_name`` is a builtin function (``len``, ``map``, ...).

    Membership is checked against Pythran's ``MODULES["builtins"]`` table.
    """
    builtin_table = MODULES["builtins"]
    return func_name in builtin_table
def is_builtin_module(module_name):
    """Tell whether ``module_name`` is a module Pythran supports natively.

    Only the top-level package is considered: ``numpy.linalg`` is looked up
    as ``numpy``. Membership is tested against the ``MODULES`` table.
    """
    top_level, _, _ = module_name.partition(".")
    return top_level in MODULES
def is_mangled_module(name):
    """Return True when ``name`` ends with the ``$`` mangled-module marker."""
    return name[-1:] == '$'
def getsource(name, module_dir, level):
    """Load and parse the source file of user module ``name``.

    :param name: dotted module name; dots become path separators and
        ``.py`` is appended to form the file name.
    :param module_dir: directory against which the file is looked up, or
        ``None``, in which case ``level`` must be ``<= 0`` and the path is
        used as-is.
    :param level: import level (0 for plain ``import``); every level beyond
        the first adds one ``..`` component after ``module_dir``.
    :returns: the parsed module tree, every node tagged with a ``filename``
        attribute equal to ``name + ".py"``.
    :raises PythranSyntaxError: if the file cannot be opened or read; the
        underlying ``OSError`` is chained as the cause.
    """
    import tokenize

    module_base = name.replace('.', os.path.sep) + '.py'
    if module_dir is None:
        assert level <= 0, "Cannot use relative path without module_dir"
        module_file = module_base
    else:
        # Levels 0 and 1 both resolve inside module_dir itself
        # (['..'] * -1 and ['..'] * 0 are empty); each extra level climbs
        # one directory.
        module_file = os.path.sep.join([module_dir] + ['..'] * (level - 1)
                                       + [module_base])
    try:
        # tokenize.open decodes the file exactly like the interpreter does:
        # it honours a PEP 263 coding cookie or a UTF-8 BOM and otherwise
        # defaults to UTF-8, instead of the platform's locale encoding.
        with tokenize.open(module_file) as fp:
            source = fp.read()
    except OSError as err:
        raise PythranSyntaxError("Module '{}' not found."
                                 .format(name)) from err

    # Local import, as originally written (presumably to avoid an import
    # cycle with pythran.frontend).
    from pythran.frontend import raw_parse
    node = raw_parse(source)
    add_filename_field(node, name + ".py")
    return node
class HandleImport(Transformation):
    """Inline user-defined modules into the module being compiled.

    Imports of modules found in Pythran's ``MODULES`` table (see
    ``is_builtin_module``) are kept as import statements. Any other imported
    module is parsed from disk with ``getsource``. Its top-level function
    names and plain-name assignment targets are prefixed with a mangled
    module identifier, and its statements are prepended to the body of the
    module being transformed. References to imported names (``from m import
    f`` then ``f``, or ``import m`` then ``m.f``) are rewritten to those
    mangled identifiers.
    """

    def __init__(self):
        super().__init__()
        # Stack of scopes, innermost last. Each maps a source-level name to
        # the identifier that loads of that name are rewritten to.
        self.identifiers = [{}]
        # User modules already inlined; each one is inlined at most once.
        self.imported = set()
        # Stack of prefixes given to top-level names of the module currently
        # being inlined; the transformed module itself uses "".
        self.prefixes = [""]

    def lookup(self, name):
        """Return what ``name`` maps to in the innermost scope that knows it.

        Returns ``None`` when no scope on the stack defines ``name``.
        """
        for renaming in reversed(self.identifiers):
            if name in renaming:
                return renaming[name]
        return None

    def is_imported(self, name):
        """Return True if user module ``name`` has already been inlined."""
        return name in self.imported

    def visit_Module(self, node):
        """Transform the module, then prepend every inlined statement to it."""
        # Filled by visit_Import / visit_ImportFrom while the body is visited.
        self.imported_stmts = []
        self.generic_visit(node)
        node.body = self.imported_stmts + node.body
        return node

    def rename(self, node, attr):
        """Prefix ``node.<attr>`` with the current prefix and record it.

        The old-name -> new-name mapping is stored in the innermost scope.
        """
        prev_name = getattr(node, attr)
        new_name = self.prefixes[-1] + prev_name
        setattr(node, attr, new_name)
        self.identifiers[-1][prev_name] = new_name

    def rename_top_level_functions(self, node):
        """Rename the top-level definitions of a module being inlined.

        Covers ``def`` statements and plain-name targets of ``=`` assignments
        found directly in ``node.body``; other statements are left alone.
        """
        for stmt in node.body:
            if isinstance(stmt, ast.FunctionDef):
                self.rename(stmt, 'name')
            elif isinstance(stmt, ast.Assign):
                for target in stmt.targets:
                    if isinstance(target, ast.Name):
                        self.rename(target, 'id')

    def visit_FunctionDef(self, node):
        """Visit a function inside a fresh, function-local scope."""
        self.identifiers.append({})
        self.generic_visit(node)
        self.identifiers.pop()
        return node

    def visit_ListComp(self, node):
        """Visit the generators before the element expression."""
        # change traversal order so that store happens before load
        for generator in node.generators:
            self.visit(generator)
        self.visit(node.elt)
        return node

    visit_SetComp = visit_ListComp
    visit_GeneratorExp = visit_ListComp

    def visit_DictComp(self, node):
        """Visit the generators before the key and value expressions."""
        for generator in node.generators:
            self.visit(generator)
        self.visit(node.key)
        self.visit(node.value)
        return node

    def visit_comprehension(self, node):
        """Visit the iterable and filters, then the (stored) loop target."""
        self.visit(node.iter)
        for if_ in node.ifs:
            self.visit(if_)
        self.visit(node.target)
        return node

    def visit_assign(self, node):
        """Visit an assignment's value before its targets (load, then store)."""
        self.visit(node.value)
        for target in node.targets:
            self.visit(target)
        return node

    def visit_Assign(self, node):
        """Handle ``alias = module`` specially and other assignments normally.

        When the right-hand side is a name bound to a mangled module, the
        statement is returned untouched; every target must then be a plain
        name.

        :raises PythranSyntaxError: if a module is assigned to a non-name
            target.
        """
        if not isinstance(node.value, ast.Name):
            return self.visit_assign(node)

        renaming = self.lookup(node.value.id)
        if not renaming or not is_mangled_module(renaming):
            return self.visit_assign(node)

        if any(not isinstance(target, ast.Name) for target in node.targets):
            raise PythranSyntaxError("Invalid module assignment", node)

        return node

    def visit_Name(self, node):
        """Rewrite loads through ``lookup``; record stored names as-is."""
        if isinstance(node.ctx, ast.Load):
            renaming = self.lookup(node.id)
            if renaming:
                node.id = renaming
        elif isinstance(node.ctx, (ast.Store, ast.Param)):
            self.identifiers[-1][node.id] = node.id
        elif isinstance(node.ctx, ast.Del):
            pass
        else:
            raise NotImplementedError(node)
        return node

    def visit_Attribute(self, node):
        """Turn ``mod.a.b`` on a user module into a single mangled ``Name``.

        Non-load attributes are returned unchanged. Chains rooted at a
        builtin module (e.g. ``np.zeros``) are returned unchanged so the
        module name is not renamed. When the root is a user module, the
        whole chain becomes one ``Name``. In every other case (root is a
        call, a subscript, a local or an imported non-module name) the
        children are visited normally, so names inside them are renamed.
        """
        if not isinstance(node.ctx, ast.Load):
            return node
        # is that a module attribute load?
        root = node.value
        while isinstance(root, ast.Attribute):
            root = root.value
        if not isinstance(root, ast.Name):
            # e.g. ``f().x`` or ``a[0].x``: ``f`` / ``a`` may need renaming.
            return self.generic_visit(node)

        renaming = self.lookup(root.id)
        if not renaming or not is_mangled_module(renaming):
            # e.g. ``CONST.shape`` where CONST was imported from a user
            # module: the root name itself must still be renamed.
            return self.generic_visit(node)

        if is_builtin_module(demangle(renaming)):
            return node

        # Collect the attribute names of the whole chain, innermost first,
        # so that ``mod.a.b`` yields the suffix ``$a$b``.
        suffix = ""
        current = node
        while isinstance(current, ast.Attribute):
            suffix = '$' + current.attr + suffix
            current = current.value

        # ``renaming`` already ends with '$', hence dropping the leading one.
        return ast.Name(renaming + suffix[1:], node.ctx, None, None)

    def import_module(self, module_name, module_level):
        """Parse user module ``module_name`` and return its transformed body.

        The module is marked as imported. Its top-level names are prefixed
        with ``mangle_imported_module(module_name)``, and its body is
        visited in a scope of its own.
        """
        self.imported.add(module_name)
        module_node = getsource(module_name,
                                self.passmanager.module_dir,
                                module_level)
        self.prefixes.append(mangle_imported_module(module_name))
        self.identifiers.append({})
        self.rename_top_level_functions(module_node)
        self.generic_visit(module_node)
        self.prefixes.pop()
        self.identifiers.pop()
        return module_node.body

    def visit_ImportFrom(self, node):
        """Bind the names of a ``from ... import`` and inline user modules.

        ``from __future__`` imports are dropped. For a builtin module the
        imported names are bound to themselves and the statement is kept.
        Otherwise each name is bound to its mangled identifier, the module
        is inlined (once), and the statement is dropped.

        :raises PythranSyntaxError: for ``from . import name``, which carries
            no module name to resolve.
        """
        if node.module is None:
            raise PythranSyntaxError(
                "Relative import without module name is not supported",
                node)
        if node.module == '__future__':
            return None

        if is_builtin_module(node.module):
            for alias in node.names:
                name = alias.asname or alias.name
                self.identifiers[-1][name] = name
            return node

        for alias in node.names:
            name = alias.asname or alias.name
            self.identifiers[-1][name] = mangle_imported_function(
                node.module, alias.name)

        if self.is_imported(node.module):
            return None

        new_stmts = self.import_module(node.module, node.level)
        self.imported_stmts.extend(new_stmts)
        return None

    def visit_Import(self, node):
        """Bind imported module names and inline user modules.

        Every imported name is bound to the mangled module identifier.
        Aliases of builtin modules stay in the statement. User modules are
        inlined (once) and their aliases dropped. The statement itself is
        removed when no alias remains.
        """
        new_aliases = []
        for alias in node.names:
            name = alias.asname or alias.name
            self.identifiers[-1][name] = mangle_imported_module(alias.name)
            if self.is_imported(alias.name):
                continue
            if is_builtin_module(alias.name):
                new_aliases.append(alias)
                continue
            new_stmts = self.import_module(alias.name, 0)
            self.imported_stmts.extend(new_stmts)

        if new_aliases:
            node.names = new_aliases
            return node
        return None