%PDF- %PDF-
| Direktori : /lib/python3/dist-packages/pythran/analyses/ |
| Current File : //lib/python3/dist-packages/pythran/analyses/ast_matcher.py |
""" Module to looks for a specified pattern in a given AST. """
from gast import AST, iter_fields, NodeVisitor, Dict, Set
from itertools import permutations
from math import isnan
# Largest number of keys/elements a Dict/Set pattern may hold. Unordered
# matching tries permutations of the pattern, so Check raises
# DamnTooLongPattern for anything longer than this.
MAX_UNORDERED_LENGTH = 10
class DamnTooLongPattern(Exception):
    """
    Signal that a dict/set pattern is too long to be compared.

    Raising it instead of comparing keeps compile time under control.
    """
class Placeholder(AST):
    """ Pattern node that records the AST value found at its position. """

    def __init__(self, identifier, type=None):
        """
        Create a placeholder.

        Parameters
        ----------
        identifier
            Key identifying this placeholder.
        type
            Optional class (or tuple of classes) stored as-is on the
            placeholder; None means no type constraint.
        """
        self.id, self.type = identifier, type
        super().__init__()
class AST_any(AST):
    """ Wildcard pattern node: the value of the field it stands for is ignored. """
class AST_or(AST):
    """
    Pattern node listing several acceptable values for one ast field.

    Attributes
    ----------
    args: tuple of ast field values
        The candidate values, in the order they were given.
    """

    def __init__(self, *args):
        """ Remember every candidate value passed in. """
        self.args = args
        super().__init__()
class Check(NodeVisitor):
    """
    Checker for ast <-> pattern.

    The pattern is the object being *visited*; the candidate node is held
    in `self.node`. NodeVisitor dispatch is used so that special pattern
    nodes (Placeholder, AST_any, AST_or, Set, Dict) get their own matching
    rule while every other pattern node goes through `generic_visit`.

    Attributes
    ----------
    node : AST or plain field value
        Value we want to compare with the pattern.
    placeholders : dict
        Maps a placeholder identifier to the value it captured, for later
        comparison or replacement. It is shared with (and mutated for) the
        caller.
    """

    def __init__(self, node, placeholders):
        """ Initialize attributes. """
        self.node = node
        self.placeholders = placeholders

    def _attempt(self, trial, *args):
        """
        Run `trial(*args)` as a speculative match.

        When the trial fails, the placeholders it captured are rolled back
        so that they cannot make a later alternative fail.
        """
        snapshot = dict(self.placeholders)
        if trial(*args):
            return True
        self.placeholders.clear()
        self.placeholders.update(snapshot)
        return False

    def check_list(self, node_list, pattern_list):
        """
        Check if two sequences of field values match element-wise.

        Elements go through `field_match` so that non-AST entries (strings
        such as the names of a Global statement, or None entries) are
        compared by value instead of being visited as if they were nodes.
        """
        if len(node_list) != len(pattern_list):
            return False
        return all(self.field_match(node_elt, pattern_elt)
                   for node_elt, pattern_elt in zip(node_list, pattern_list))

    def visit_Placeholder(self, pattern):
        """
        Save matching node or compare it with the existing one.

        A placeholder that already captured a value only matches a value
        equivalent to it; the optional `pattern.type` restricts the class
        of the captured value.

        FIXME : What if the new placeholder is a better choice?
        """
        if pattern.id in self.placeholders:
            # The captured value may be a plain value or a list, hence
            # field_match rather than a direct visit.
            captured = self.placeholders[pattern.id]
            if not self.field_match(self.node, captured):
                return False
        if pattern.type is not None and not isinstance(self.node,
                                                       pattern.type):
            return False
        self.placeholders[pattern.id] = self.node
        return True

    @staticmethod
    def visit_AST_any(_):
        """ Every node match with it. """
        return True

    def visit_AST_or(self, pattern):
        """ Match if any of the or content match with the other node. """
        return any(self._attempt(self.field_match, self.node, value_or)
                   for value_or in pattern.args)

    def visit_Set(self, pattern):
        """
        Set have unordered values.

        Raises
        ------
        DamnTooLongPattern
            If the pattern holds more than MAX_UNORDERED_LENGTH elements.
        """
        if not isinstance(self.node, Set):
            return False
        if len(pattern.elts) > MAX_UNORDERED_LENGTH:
            raise DamnTooLongPattern("Pattern for Set is too long")
        # No permutation can match when sizes differ: skip the n! attempts.
        if len(self.node.elts) != len(pattern.elts):
            return False
        return any(self._attempt(self.check_list, self.node.elts,
                                 pattern_elts)
                   for pattern_elts in permutations(pattern.elts))

    def _match_items(self, pattern_keys, pattern_values):
        """ Match the node's keys then values against one pattern ordering. """
        return (self.check_list(self.node.keys, pattern_keys) and
                self.check_list(self.node.values, pattern_values))

    def visit_Dict(self, pattern):
        """
        Dict can match with unordered values.

        Raises
        ------
        DamnTooLongPattern
            If the pattern holds more than MAX_UNORDERED_LENGTH keys.
        """
        if not isinstance(self.node, Dict):
            return False
        if len(pattern.keys) > MAX_UNORDERED_LENGTH:
            raise DamnTooLongPattern("Pattern for Dict is too long")
        # Sizes must agree; this also bounds the permutation count by the
        # (already checked) pattern length rather than by the node length.
        if len(self.node.keys) != len(pattern.keys):
            return False
        for permutation in permutations(range(len(pattern.keys))):
            pattern_keys = [pattern.keys[i] for i in permutation]
            pattern_values = [pattern.values[i] for i in permutation]
            # Keep looking when the values disagree: another ordering whose
            # keys also match may still succeed.
            if self._attempt(self._match_items, pattern_keys,
                             pattern_values):
                return True
        return False

    def field_match(self, node_field, pattern_field):
        """
        Check if two fields match.

        Field match if:
            - If it is a list, all values have to match.
            - If if is a node, recursively check it.
            - Otherwise, check values are equal.
        """
        if isinstance(pattern_field, list):
            return (isinstance(node_field, list) and
                    self.check_list(node_field, pattern_field))
        if isinstance(pattern_field, AST):
            return Check(node_field, self.placeholders).visit(pattern_field)
        return Check.strict_eq(pattern_field, node_field)

    @staticmethod
    def strict_eq(f0, f1):
        """ Compare two plain values, treating NaN as equal to NaN. """
        if f0 == f1:
            return True
        try:
            return isnan(f0) and isnan(f1)
        except (TypeError, OverflowError):
            # TypeError: not a real number (str, None, complex, ...).
            # OverflowError: int too large to be converted to float.
            return False

    def generic_visit(self, pattern):
        """
        Check if the pattern match with the checked node.

        a node match if:
            - type match
            - all field match
        """
        if not isinstance(pattern, type(self.node)):
            return False
        return all(self.field_match(value, getattr(pattern, field))
                   for field, value in iter_fields(self.node))
class ASTMatcher(NodeVisitor):
    """
    Collect every node of a tree that matches a given pattern.

    Examples
    --------
    >>> import gast as ast
    >>> code = "[(i, j) for i in range(a) for j in range(b)]"
    >>> pattern = ast.Call(func=ast.Name('range', ctx=ast.Load(),
    ...                                  annotation=None,
    ...                                  type_comment=None),
    ...                    args=AST_any(), keywords=[])
    >>> len(ASTMatcher(pattern).search(ast.parse(code)))
    2
    >>> code = "[(i, j) for i in range(a) for j in range(b)]"
    >>> pattern = ast.Call(func=ast.Name(id=AST_or('range', 'range'),
    ...                                  ctx=ast.Load(),
    ...                                  annotation=None,
    ...                                  type_comment=None),
    ...                    args=AST_any(), keywords=[])
    >>> len(ASTMatcher(pattern).search(ast.parse(code)))
    2
    >>> code = "{1:2, 3:4}"
    >>> pattern = ast.Dict(keys=[ast.Constant(3, None), ast.Constant(1, None)],
    ...                    values=[ast.Constant(4, None),
    ...                            ast.Constant(2, None)])
    >>> len(ASTMatcher(pattern).search(ast.parse(code)))
    1
    >>> code = "{1, 2, 3}"
    >>> pattern = ast.Set(elts=[ast.Constant(3, None),
    ...                         ast.Constant(2, None),
    ...                         ast.Constant(1, None)])
    >>> len(ASTMatcher(pattern).search(ast.parse(code)))
    1
    """

    def __init__(self, pattern):
        """ Store the pattern and start with an empty set of matches. """
        self.pattern = pattern
        self.result = set()
        super().__init__()

    def visit(self, node):
        """
        Test the current node against the pattern, then descend.

        A matching node is recorded; children are visited in both cases.
        Each node is checked with a fresh placeholder mapping.
        """
        matched = Check(node, {}).visit(self.pattern)
        if matched:
            self.result.add(node)
        self.generic_visit(node)

    def search(self, node):
        """
        Walk `node` and return the set of matching nodes.

        The returned set is `self.result` itself, so matches from earlier
        searches on the same matcher are still in it.
        """
        self.visit(node)
        return self.result