%PDF- %PDF-
Mini Shell

Mini Shell

Direktori : /lib/python3/dist-packages/mypy/
Upload File :
Create Path :
Current File : //lib/python3/dist-packages/mypy/checkpattern.py

"""Pattern checker. This file is conceptually part of TypeChecker."""

from collections import defaultdict
from typing import List, Optional, Tuple, Dict, NamedTuple, Set, Union
from typing_extensions import Final

import mypy.checker
from mypy.checkmember import analyze_member_access
from mypy.expandtype import expand_type_by_instance
from mypy.join import join_types
from mypy.literals import literal_hash
from mypy.maptype import map_instance_to_supertype
from mypy.meet import narrow_declared_type
from mypy import message_registry
from mypy.messages import MessageBuilder
from mypy.nodes import Expression, ARG_POS, TypeAlias, TypeInfo, Var, NameExpr
from mypy.patterns import (
    Pattern, AsPattern, OrPattern, ValuePattern, SequencePattern, StarredPattern, MappingPattern,
    ClassPattern, SingletonPattern
)
from mypy.plugin import Plugin
from mypy.subtypes import is_subtype
from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union, \
    coerce_to_literal
from mypy.types import (
    ProperType, AnyType, TypeOfAny, Instance, Type, UninhabitedType, get_proper_type,
    TypedDictType, TupleType, NoneType, UnionType
)
from mypy.typevars import fill_typevars
from mypy.visitor import PatternVisitor

self_match_type_names: Final = [
    "builtins.bool",
    "builtins.bytearray",
    "builtins.bytes",
    "builtins.dict",
    "builtins.float",
    "builtins.frozenset",
    "builtins.int",
    "builtins.list",
    "builtins.set",
    "builtins.str",
    "builtins.tuple",
]

non_sequence_match_type_names: Final = [
    "builtins.str",
    "builtins.bytes",
    "builtins.bytearray"
]


# For every Pattern a PatternType can be calculated. This requires recursively calculating
# the PatternTypes of the sub-patterns first.
# Using the data in the PatternType the match subject and captured names can be narrowed/inferred.
PatternType = NamedTuple(
    'PatternType',
    [
        ('type', Type),  # The type the match subject can be narrowed to
        ('rest_type', Type),  # The remaining type if the pattern didn't match
        ('captures', Dict[Expression, Type]),  # The variables captured by the pattern
    ])


class PatternChecker(PatternVisitor[PatternType]):
    """Pattern checker.

    This class checks if a pattern can match a type, what the type can be narrowed to, and what
    type capture patterns should be inferred as.
    """

    # Some services are provided by a TypeChecker instance.
    chk: 'mypy.checker.TypeChecker'
    # This is shared with TypeChecker, but stored also here for convenience.
    msg: MessageBuilder
    # Currently unused
    plugin: Plugin
    # The expression being matched against the pattern
    subject: Expression

    subject_type: Type
    # Type of the subject to check the (sub)pattern against
    type_context: List[Type]
    # Types that match against self instead of their __match_args__ if used as a class pattern
    # Filled in from self_match_type_names
    self_match_types: List[Type]
    # Types that are sequences, but don't match sequence patterns. Filled in from
    # non_sequence_match_type_names
    non_sequence_match_types: List[Type]

    def __init__(self,
                 chk: 'mypy.checker.TypeChecker',
                 msg: MessageBuilder, plugin: Plugin
                 ) -> None:
        self.chk = chk
        self.msg = msg
        self.plugin = plugin

        self.type_context = []
        self.self_match_types = self.generate_types_from_names(self_match_type_names)
        self.non_sequence_match_types = self.generate_types_from_names(
            non_sequence_match_type_names
        )

    def accept(self, o: Pattern, type_context: Type) -> PatternType:
        self.type_context.append(type_context)
        result = o.accept(self)
        self.type_context.pop()

        return result

    def visit_as_pattern(self, o: AsPattern) -> PatternType:
        current_type = self.type_context[-1]
        if o.pattern is not None:
            pattern_type = self.accept(o.pattern, current_type)
            typ, rest_type, type_map = pattern_type
        else:
            typ, rest_type, type_map = current_type, UninhabitedType(), {}

        if not is_uninhabited(typ) and o.name is not None:
            typ, _ = self.chk.conditional_types_with_intersection(current_type,
                                                                  [get_type_range(typ)],
                                                                  o,
                                                                  default=current_type)
            if not is_uninhabited(typ):
                type_map[o.name] = typ

        return PatternType(typ, rest_type, type_map)

    def visit_or_pattern(self, o: OrPattern) -> PatternType:
        current_type = self.type_context[-1]

        #
        # Check all the subpatterns
        #
        pattern_types = []
        for pattern in o.patterns:
            pattern_type = self.accept(pattern, current_type)
            pattern_types.append(pattern_type)
            current_type = pattern_type.rest_type

        #
        # Collect the final type
        #
        types = []
        for pattern_type in pattern_types:
            if not is_uninhabited(pattern_type.type):
                types.append(pattern_type.type)

        #
        # Check the capture types
        #
        capture_types: Dict[Var, List[Tuple[Expression, Type]]] = defaultdict(list)
        # Collect captures from the first subpattern
        for expr, typ in pattern_types[0].captures.items():
            node = get_var(expr)
            capture_types[node].append((expr, typ))

        # Check if other subpatterns capture the same names
        for i, pattern_type in enumerate(pattern_types[1:]):
            vars = {get_var(expr) for expr, _ in pattern_type.captures.items()}
            if capture_types.keys() != vars:
                self.msg.fail(message_registry.OR_PATTERN_ALTERNATIVE_NAMES, o.patterns[i])
            for expr, typ in pattern_type.captures.items():
                node = get_var(expr)
                capture_types[node].append((expr, typ))

        captures: Dict[Expression, Type] = {}
        for var, capture_list in capture_types.items():
            typ = UninhabitedType()
            for _, other in capture_list:
                typ = join_types(typ, other)

            captures[capture_list[0][0]] = typ

        union_type = make_simplified_union(types)
        return PatternType(union_type, current_type, captures)

    def visit_value_pattern(self, o: ValuePattern) -> PatternType:
        current_type = self.type_context[-1]
        typ = self.chk.expr_checker.accept(o.expr)
        typ = coerce_to_literal(typ)
        narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
            current_type,
            [get_type_range(typ)],
            o,
            default=current_type
        )
        return PatternType(narrowed_type, rest_type, {})

    def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType:
        current_type = self.type_context[-1]
        value: Union[bool, None] = o.value
        if isinstance(value, bool):
            typ = self.chk.expr_checker.infer_literal_expr_type(value, "builtins.bool")
        elif value is None:
            typ = NoneType()
        else:
            assert False

        narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
            current_type,
            [get_type_range(typ)],
            o,
            default=current_type
        )
        return PatternType(narrowed_type, rest_type, {})

    def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
        #
        # check for existence of a starred pattern
        #
        current_type = get_proper_type(self.type_context[-1])
        if not self.can_match_sequence(current_type):
            return self.early_non_match()
        star_positions = [i for i, p in enumerate(o.patterns) if isinstance(p, StarredPattern)]
        star_position: Optional[int] = None
        if len(star_positions) == 1:
            star_position = star_positions[0]
        elif len(star_positions) >= 2:
            assert False, "Parser should prevent multiple starred patterns"
        required_patterns = len(o.patterns)
        if star_position is not None:
            required_patterns -= 1

        #
        # get inner types of original type
        #
        if isinstance(current_type, TupleType):
            inner_types = current_type.items
            size_diff = len(inner_types) - required_patterns
            if size_diff < 0:
                return self.early_non_match()
            elif size_diff > 0 and star_position is None:
                return self.early_non_match()
        else:
            inner_type = self.get_sequence_type(current_type)
            if inner_type is None:
                inner_type = self.chk.named_type("builtins.object")
            inner_types = [inner_type] * len(o.patterns)

        #
        # match inner patterns
        #
        contracted_new_inner_types: List[Type] = []
        contracted_rest_inner_types: List[Type] = []
        captures: Dict[Expression, Type] = {}

        contracted_inner_types = self.contract_starred_pattern_types(inner_types,
                                                                     star_position,
                                                                     required_patterns)
        can_match = True
        for p, t in zip(o.patterns, contracted_inner_types):
            pattern_type = self.accept(p, t)
            typ, rest, type_map = pattern_type
            if is_uninhabited(typ):
                can_match = False
            else:
                contracted_new_inner_types.append(typ)
                contracted_rest_inner_types.append(rest)
            self.update_type_map(captures, type_map)
        new_inner_types = self.expand_starred_pattern_types(contracted_new_inner_types,
                                                            star_position,
                                                            len(inner_types))
        rest_inner_types = self.expand_starred_pattern_types(contracted_rest_inner_types,
                                                             star_position,
                                                             len(inner_types))

        #
        # Calculate new type
        #
        new_type: Type
        rest_type: Type = current_type
        if not can_match:
            new_type = UninhabitedType()
        elif isinstance(current_type, TupleType):
            narrowed_inner_types = []
            inner_rest_types = []
            for inner_type, new_inner_type in zip(inner_types, new_inner_types):
                narrowed_inner_type, inner_rest_type = \
                    self.chk.conditional_types_with_intersection(
                        new_inner_type,
                        [get_type_range(inner_type)],
                        o,
                        default=new_inner_type
                    )
                narrowed_inner_types.append(narrowed_inner_type)
                inner_rest_types.append(inner_rest_type)
            if all(not is_uninhabited(typ) for typ in narrowed_inner_types):
                new_type = TupleType(narrowed_inner_types, current_type.partial_fallback)
            else:
                new_type = UninhabitedType()

            if all(is_uninhabited(typ) for typ in inner_rest_types):
                # All subpatterns always match, so we can apply negative narrowing
                rest_type = TupleType(rest_inner_types, current_type.partial_fallback)
        else:
            new_inner_type = UninhabitedType()
            for typ in new_inner_types:
                new_inner_type = join_types(new_inner_type, typ)
            new_type = self.construct_sequence_child(current_type, new_inner_type)
            if is_subtype(new_type, current_type):
                new_type, _ = self.chk.conditional_types_with_intersection(
                    current_type,
                    [get_type_range(new_type)],
                    o,
                    default=current_type
                )
            else:
                new_type = current_type
        return PatternType(new_type, rest_type, captures)

    def get_sequence_type(self, t: Type) -> Optional[Type]:
        t = get_proper_type(t)
        if isinstance(t, AnyType):
            return AnyType(TypeOfAny.from_another_any, t)
        if isinstance(t, UnionType):
            items = [self.get_sequence_type(item) for item in t.items]
            not_none_items = [item for item in items if item is not None]
            if len(not_none_items) > 0:
                return make_simplified_union(not_none_items)
            else:
                return None

        if self.chk.type_is_iterable(t) and isinstance(t, Instance):
            return self.chk.iterable_item_type(t)
        else:
            return None

    def contract_starred_pattern_types(self,
                                       types: List[Type],
                                       star_pos: Optional[int],
                                       num_patterns: int
                                       ) -> List[Type]:
        """
        Contracts a list of types in a sequence pattern depending on the position of a starred
        capture pattern.

        For example if the sequence pattern [a, *b, c] is matched against types [bool, int, str,
        bytes] the contracted types are [bool, Union[int, str], bytes].

        If star_pos in None the types are returned unchanged.
        """
        if star_pos is None:
            return types
        new_types = types[:star_pos]
        star_length = len(types) - num_patterns
        new_types.append(make_simplified_union(types[star_pos:star_pos+star_length]))
        new_types += types[star_pos+star_length:]

        return new_types

    def expand_starred_pattern_types(self,
                                     types: List[Type],
                                     star_pos: Optional[int],
                                     num_types: int
                                     ) -> List[Type]:
        """Undoes the contraction done by contract_starred_pattern_types.

        For example if the sequence pattern is [a, *b, c] and types [bool, int, str] are extended
        to lenght 4 the result is [bool, int, int, str].
        """
        if star_pos is None:
            return types
        new_types = types[:star_pos]
        star_length = num_types - len(types) + 1
        new_types += [types[star_pos]] * star_length
        new_types += types[star_pos+1:]

        return new_types

    def visit_starred_pattern(self, o: StarredPattern) -> PatternType:
        captures: Dict[Expression, Type] = {}
        if o.capture is not None:
            list_type = self.chk.named_generic_type('builtins.list', [self.type_context[-1]])
            captures[o.capture] = list_type
        return PatternType(self.type_context[-1], UninhabitedType(), captures)

    def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
        current_type = get_proper_type(self.type_context[-1])
        can_match = True
        captures: Dict[Expression, Type] = {}
        for key, value in zip(o.keys, o.values):
            inner_type = self.get_mapping_item_type(o, current_type, key)
            if inner_type is None:
                can_match = False
                inner_type = self.chk.named_type("builtins.object")
            pattern_type = self.accept(value, inner_type)
            if is_uninhabited(pattern_type.type):
                can_match = False
            else:
                self.update_type_map(captures, pattern_type.captures)

        if o.rest is not None:
            mapping = self.chk.named_type("typing.Mapping")
            if is_subtype(current_type, mapping) and isinstance(current_type, Instance):
                mapping_inst = map_instance_to_supertype(current_type, mapping.type)
                dict_typeinfo = self.chk.lookup_typeinfo("builtins.dict")
                dict_type = fill_typevars(dict_typeinfo)
                rest_type = expand_type_by_instance(dict_type, mapping_inst)
            else:
                object_type = self.chk.named_type("builtins.object")
                rest_type = self.chk.named_generic_type("builtins.dict",
                                                        [object_type, object_type])

            captures[o.rest] = rest_type

        if can_match:
            # We can't narrow the type here, as Mapping key is invariant.
            new_type = self.type_context[-1]
        else:
            new_type = UninhabitedType()
        return PatternType(new_type, current_type, captures)

    def get_mapping_item_type(self,
                              pattern: MappingPattern,
                              mapping_type: Type,
                              key: Expression
                              ) -> Optional[Type]:
        mapping_type = get_proper_type(mapping_type)
        if isinstance(mapping_type, TypedDictType):
            with self.msg.filter_errors() as local_errors:
                result: Optional[Type] = self.chk.expr_checker.visit_typeddict_index_expr(
                    mapping_type, key)
                has_local_errors = local_errors.has_new_errors()
            # If we can't determine the type statically fall back to treating it as a normal
            # mapping
            if has_local_errors:
                with self.msg.filter_errors() as local_errors:
                    result = self.get_simple_mapping_item_type(pattern,
                                                               mapping_type,
                                                               key,
                                                               self.msg)

                    if local_errors.has_new_errors():
                        result = None
        else:
            with self.msg.filter_errors():
                result = self.get_simple_mapping_item_type(pattern,
                                                           mapping_type,
                                                           key,
                                                           self.msg)
        return result

    def get_simple_mapping_item_type(self,
                                     pattern: MappingPattern,
                                     mapping_type: Type,
                                     key: Expression,
                                     local_errors: MessageBuilder
                                     ) -> Type:
        result, _ = self.chk.expr_checker.check_method_call_by_name('__getitem__',
                                                                    mapping_type,
                                                                    [key],
                                                                    [ARG_POS],
                                                                    pattern,
                                                                    local_errors=local_errors)
        return result

    def visit_class_pattern(self, o: ClassPattern) -> PatternType:
        current_type = get_proper_type(self.type_context[-1])

        #
        # Check class type
        #
        type_info = o.class_ref.node
        if type_info is None:
            return PatternType(AnyType(TypeOfAny.from_error), AnyType(TypeOfAny.from_error), {})
        if isinstance(type_info, TypeAlias) and not type_info.no_args:
            self.msg.fail(message_registry.CLASS_PATTERN_GENERIC_TYPE_ALIAS, o)
            return self.early_non_match()
        if isinstance(type_info, TypeInfo):
            any_type = AnyType(TypeOfAny.implementation_artifact)
            typ: Type = Instance(type_info, [any_type] * len(type_info.defn.type_vars))
        elif isinstance(type_info, TypeAlias):
            typ = type_info.target
        else:
            if isinstance(type_info, Var):
                name = str(type_info.type)
            else:
                name = type_info.name
            self.msg.fail(message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(name), o.class_ref)
            return self.early_non_match()

        new_type, rest_type = self.chk.conditional_types_with_intersection(
            current_type, [get_type_range(typ)], o, default=current_type
        )
        if is_uninhabited(new_type):
            return self.early_non_match()
        # TODO: Do I need this?
        narrowed_type = narrow_declared_type(current_type, new_type)

        #
        # Convert positional to keyword patterns
        #
        keyword_pairs: List[Tuple[Optional[str], Pattern]] = []
        match_arg_set: Set[str] = set()

        captures: Dict[Expression, Type] = {}

        if len(o.positionals) != 0:
            if self.should_self_match(typ):
                if len(o.positionals) > 1:
                    self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
                pattern_type = self.accept(o.positionals[0], narrowed_type)
                if not is_uninhabited(pattern_type.type):
                    return PatternType(pattern_type.type,
                                       join_types(rest_type, pattern_type.rest_type),
                                       pattern_type.captures)
                captures = pattern_type.captures
            else:
                with self.msg.filter_errors() as local_errors:
                    match_args_type = analyze_member_access("__match_args__", typ, o,
                                                            False, False, False,
                                                            self.msg,
                                                            original_type=typ,
                                                            chk=self.chk)
                    has_local_errors = local_errors.has_new_errors()
                if has_local_errors:
                    self.msg.fail(message_registry.MISSING_MATCH_ARGS.format(typ), o)
                    return self.early_non_match()

                proper_match_args_type = get_proper_type(match_args_type)
                if isinstance(proper_match_args_type, TupleType):
                    match_arg_names = get_match_arg_names(proper_match_args_type)

                    if len(o.positionals) > len(match_arg_names):
                        self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
                        return self.early_non_match()
                else:
                    match_arg_names = [None] * len(o.positionals)

                for arg_name, pos in zip(match_arg_names, o.positionals):
                    keyword_pairs.append((arg_name, pos))
                    if arg_name is not None:
                        match_arg_set.add(arg_name)

        #
        # Check for duplicate patterns
        #
        keyword_arg_set = set()
        has_duplicates = False
        for key, value in zip(o.keyword_keys, o.keyword_values):
            keyword_pairs.append((key, value))
            if key in match_arg_set:
                self.msg.fail(
                    message_registry.CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL.format(key),
                    value
                )
                has_duplicates = True
            elif key in keyword_arg_set:
                self.msg.fail(message_registry.CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN.format(key),
                              value)
                has_duplicates = True
            keyword_arg_set.add(key)

        if has_duplicates:
            return self.early_non_match()

        #
        # Check keyword patterns
        #
        can_match = True
        for keyword, pattern in keyword_pairs:
            key_type: Optional[Type] = None
            with self.msg.filter_errors() as local_errors:
                if keyword is not None:
                    key_type = analyze_member_access(keyword,
                                                     narrowed_type,
                                                     pattern,
                                                     False,
                                                     False,
                                                     False,
                                                     self.msg,
                                                     original_type=new_type,
                                                     chk=self.chk)
                else:
                    key_type = AnyType(TypeOfAny.from_error)
                has_local_errors = local_errors.has_new_errors()
            if has_local_errors or key_type is None:
                key_type = AnyType(TypeOfAny.from_error)
                self.msg.fail(message_registry.CLASS_PATTERN_UNKNOWN_KEYWORD.format(typ, keyword),
                              pattern)

            inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type)
            if is_uninhabited(inner_type):
                can_match = False
            else:
                self.update_type_map(captures, inner_captures)
                if not is_uninhabited(inner_rest_type):
                    rest_type = current_type

        if not can_match:
            new_type = UninhabitedType()
        return PatternType(new_type, rest_type, captures)

    def should_self_match(self, typ: Type) -> bool:
        typ = get_proper_type(typ)
        if isinstance(typ, Instance) and typ.type.is_named_tuple:
            return False
        for other in self.self_match_types:
            if is_subtype(typ, other):
                return True
        return False

    def can_match_sequence(self, typ: ProperType) -> bool:
        if isinstance(typ, UnionType):
            return any(self.can_match_sequence(get_proper_type(item)) for item in typ.items)
        for other in self.non_sequence_match_types:
            # We have to ignore promotions, as memoryview should match, but bytes,
            # which it can be promoted to, shouldn't
            if is_subtype(typ, other, ignore_promotions=True):
                return False
        sequence = self.chk.named_type("typing.Sequence")
        # If the static type is more general than sequence the actual type could still match
        return is_subtype(typ, sequence) or is_subtype(sequence, typ)

    def generate_types_from_names(self, type_names: List[str]) -> List[Type]:
        types: List[Type] = []
        for name in type_names:
            try:
                types.append(self.chk.named_type(name))
            except KeyError as e:
                # Some built in types are not defined in all test cases
                if not name.startswith('builtins.'):
                    raise e
                pass

        return types

    def update_type_map(self,
                        original_type_map: Dict[Expression, Type],
                        extra_type_map: Dict[Expression, Type]
                        ) -> None:
        # Calculating this would not be needed if TypeMap directly used literal hashes instead of
        # expressions, as suggested in the TODO above it's definition
        already_captured = set(literal_hash(expr) for expr in original_type_map)
        for expr, typ in extra_type_map.items():
            if literal_hash(expr) in already_captured:
                node = get_var(expr)
                self.msg.fail(message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name),
                              expr)
            else:
                original_type_map[expr] = typ

    def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type:
        """
        If outer_type is a child class of typing.Sequence returns a new instance of
        outer_type, that is a Sequence of inner_type. If outer_type is not a child class of
        typing.Sequence just returns a Sequence of inner_type

        For example:
        construct_sequence_child(List[int], str) = List[str]
        """
        proper_type = get_proper_type(outer_type)
        if isinstance(proper_type, UnionType):
            types = [
                self.construct_sequence_child(item, inner_type) for item in proper_type.items
                if self.can_match_sequence(get_proper_type(item))
            ]
            return make_simplified_union(types)
        sequence = self.chk.named_generic_type("typing.Sequence", [inner_type])
        if is_subtype(outer_type, self.chk.named_type("typing.Sequence")):
            proper_type = get_proper_type(outer_type)
            assert isinstance(proper_type, Instance)
            empty_type = fill_typevars(proper_type.type)
            partial_type = expand_type_by_instance(empty_type, sequence)
            return expand_type_by_instance(partial_type, proper_type)
        else:
            return sequence

    def early_non_match(self) -> PatternType:
        return PatternType(UninhabitedType(), self.type_context[-1], {})


def get_match_arg_names(typ: TupleType) -> List[Optional[str]]:
    args: List[Optional[str]] = []
    for item in typ.items:
        values = try_getting_str_literals_from_type(item)
        if values is None or len(values) != 1:
            args.append(None)
        else:
            args.append(values[0])
    return args


def get_var(expr: Expression) -> Var:
    """
    Warning: this in only true for expressions captured by a match statement.
    Don't call it from anywhere else
    """
    assert isinstance(expr, NameExpr)
    node = expr.node
    assert isinstance(node, Var)
    return node


def get_type_range(typ: Type) -> 'mypy.checker.TypeRange':
    typ = get_proper_type(typ)
    if (isinstance(typ, Instance)
            and typ.last_known_value
            and isinstance(typ.last_known_value.value, bool)):
        typ = typ.last_known_value
    return mypy.checker.TypeRange(typ, is_upper_bound=False)


def is_uninhabited(typ: Type) -> bool:
    return isinstance(get_proper_type(typ), UninhabitedType)

Zerion Mini Shell 1.0