diff --git a/lark/lark.py b/lark/lark.py
index d4050552..8da98409 100644
--- a/lark/lark.py
+++ b/lark/lark.py
@@ -3,6 +3,7 @@
 import sys, os, pickle, hashlib
 import tempfile
 import types
+import re
 from typing import (
     TypeVar, Type, List, Dict, Iterator, Callable, Union, Optional,
     Sequence, Tuple, Iterable, IO, Any, TYPE_CHECKING, Collection
@@ -15,6 +16,7 @@
         from typing import Literal
     else:
         from typing_extensions import Literal
+    from .parser_frontends import ParsingFrontend
 
 from .exceptions import ConfigurationError, assert_config, UnexpectedInput
 from .utils import Serialize, SerializeMemoizer, FS, isascii, logger
@@ -27,7 +29,7 @@
 from .parser_frontends import _validate_frontend_args, _get_lexer_callbacks, _deserialize_parsing_frontend, _construct_parsing_frontend
 from .grammar import Rule
 
-import re
+
 try:
     import regex
     _has_regex = True
@@ -176,7 +178,7 @@ class LarkOptions(Serialize):
         '_plugins': {},
     }
 
-    def __init__(self, options_dict):
+    def __init__(self, options_dict: Dict[str, Any]) -> None:
         o = dict(options_dict)
 
         options = {}
@@ -205,21 +207,21 @@ def __init__(self, options_dict):
         if o:
             raise ConfigurationError("Unknown options: %s" % o.keys())
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         try:
             return self.__dict__['options'][name]
         except KeyError as e:
             raise AttributeError(e)
 
-    def __setattr__(self, name, value):
+    def __setattr__(self, name: str, value: str) -> None:
         assert_config(name, self.options.keys(), "%r isn't a valid option. Expected one of: %s")
         self.options[name] = value
 
-    def serialize(self, memo):
+    def serialize(self, memo = None) -> Dict[str, Any]:
         return self.options
 
     @classmethod
-    def deserialize(cls, data, memo):
+    def deserialize(cls, data: Dict[str, Any], memo: Dict[int, Union[TerminalDef, Rule]]) -> "LarkOptions":
         return cls(data)
 
 
@@ -252,7 +254,7 @@ class Lark(Serialize):
     grammar: 'Grammar'
     options: LarkOptions
     lexer: Lexer
-    terminals: List[TerminalDef]
+    terminals: Collection[TerminalDef]
 
     def __init__(self, grammar: 'Union[Grammar, str, IO[str]]', **options) -> None:
         self.options = LarkOptions(options)
@@ -446,7 +448,7 @@ def __init__(self, grammar: 'Union[Grammar, str, IO[str]]', **options) -> None:
 
     __serialize_fields__ = 'parser', 'rules', 'options'
 
-    def _build_lexer(self, dont_ignore=False):
+    def _build_lexer(self, dont_ignore: bool=False) -> BasicLexer:
         lexer_conf = self.lexer_conf
         if dont_ignore:
             from copy import copy
@@ -454,7 +456,7 @@ def _build_lexer(self, dont_ignore=False):
             lexer_conf.ignore = ()
         return BasicLexer(lexer_conf)
 
-    def _prepare_callbacks(self):
+    def _prepare_callbacks(self) -> None:
         self._callbacks = {}
         # we don't need these callbacks if we aren't building a tree
         if self.options.ambiguity != 'forest':
@@ -468,7 +470,7 @@ def _prepare_callbacks(self):
             self._callbacks = self._parse_tree_builder.create_callback(self.options.transformer)
         self._callbacks.update(_get_lexer_callbacks(self.options.transformer, self.terminals))
 
-    def _build_parser(self):
+    def _build_parser(self) -> "ParsingFrontend":
         self._prepare_callbacks()
         _validate_frontend_args(self.options.parser, self.options.lexer)
         parser_conf = ParserConf(self.rules, self._callbacks, self.options.start)
@@ -480,7 +482,7 @@ def _build_parser(self):
             options=self.options
         )
 
-    def save(self, f, exclude_options: Collection[str] = ()):
+    def save(self, f, exclude_options: Collection[str] = ()) -> None:
         """Saves the instance into the given file object
 
         Useful for caching and multiprocessing.
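As a usage note for the `save` hunk above, here is a minimal caching sketch that these annotations now describe; the grammar and the `parser.cache` filename are illustrative, not taken from the diff:

```python
from lark import Lark

# Illustrative grammar; any Lark grammar works the same way.
grammar = r"""
    start: WORD "," WORD "!"
    %import common.WORD
    %ignore " "
"""

# Building a Lark instance runs the full grammar analysis.
parser = Lark(grammar, parser="lalr")

# save() pickles the serialized parser into an open binary file, so
# later runs can skip the grammar analysis; it now types as -> None.
with open("parser.cache", "wb") as f:
    parser.save(f)
```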
@@ -491,7 +493,7 @@ def save(self, f, exclude_options: Collection[str] = ()):
         pickle.dump({'data': data, 'memo': m}, f, protocol=pickle.HIGHEST_PROTOCOL)
 
     @classmethod
-    def load(cls, f):
+    def load(cls: Type[_T], f) -> _T:
         """Loads an instance from the given file object
 
         Useful for caching and multiprocessing.
@@ -499,7 +501,7 @@ def load(cls, f):
         inst = cls.__new__(cls)
         return inst._load(f)
 
-    def _deserialize_lexer_conf(self, data, memo, options):
+    def _deserialize_lexer_conf(self, data: Dict[str, Any], memo: Dict[int, Union[TerminalDef, Rule]], options: LarkOptions) -> LexerConf:
         lexer_conf = LexerConf.deserialize(data['lexer_conf'], memo)
         lexer_conf.callbacks = options.lexer_callbacks or {}
         lexer_conf.re_module = regex if options.regex else re
@@ -509,7 +511,7 @@ def _deserialize_lexer_conf(self, data, memo, options):
         lexer_conf.postlex = options.postlex
         return lexer_conf
 
-    def _load(self, f, **kwargs):
+    def _load(self: _T, f: Any, **kwargs) -> _T:
         if isinstance(f, dict):
             d = f
         else:
@@ -593,6 +595,7 @@ def lex(self, text: str, dont_ignore: bool=False) -> Iterator[Token]:
 
         :raises UnexpectedCharacters: In case the lexer cannot find a suitable match.
         """
+        lexer: Lexer
         if not hasattr(self, 'lexer') or dont_ignore:
             lexer = self._build_lexer(dont_ignore)
         else:
diff --git a/lark/parsers/lalr_parser.py b/lark/parsers/lalr_parser.py
index 292d5c24..48dac7b7 100644
--- a/lark/parsers/lalr_parser.py
+++ b/lark/parsers/lalr_parser.py
@@ -3,6 +3,7 @@
 # Author: Erez Shinan (2017)
 # Email : erezshin@gmail.com
 from copy import deepcopy, copy
+from typing import Dict, Any
 from ..lexer import Token
 from ..utils import Serialize
 
@@ -29,7 +30,7 @@ def deserialize(cls, data, memo, callbacks, debug=False):
         inst.parser = _Parser(inst._parse_table, callbacks, debug)
         return inst
 
-    def serialize(self, memo):
+    def serialize(self, memo: Any = None) -> Dict[str, Any]:
         return self._parse_table.serialize(memo)
 
     def parse_interactive(self, lexer, start):
diff --git a/lark/utils.py b/lark/utils.py
index f9c0fd01..6781e6fb 100644
--- a/lark/utils.py
+++ b/lark/utils.py
@@ -2,10 +2,12 @@
 import os
 from functools import reduce
 from collections import deque
+from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence
 
 ###{standalone
 import sys, re
 import logging
+
 logger: logging.Logger = logging.getLogger("lark")
 logger.addHandler(logging.StreamHandler())
 # Set to highest level, since we have some warnings amongst the code
@@ -15,9 +17,11 @@
 
 NO_VALUE = object()
 
+T = TypeVar("T")
+
 
-def classify(seq, key=None, value=None):
-    d = {}
+def classify(seq: Sequence, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
+    d: Dict[Any, Any] = {}
     for item in seq:
         k = key(item) if (key is not None) else item
         v = value(item) if (value is not None) else item
@@ -28,7 +32,7 @@ def classify(seq, key=None, value=None):
     return d
 
 
-def _deserialize(data, namespace, memo):
+def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any:
     if isinstance(data, dict):
         if '__type__' in data:  # Object
             class_ = namespace[data['__type__']]
@@ -41,6 +45,8 @@
     return data
 
 
+_T = TypeVar("_T", bound="Serialize")
+
 class Serialize:
     """Safe-ish serialization interface that doesn't rely on Pickle
 
@@ -50,11 +56,11 @@ class Serialize:
                                         Should include all field types that aren't builtin types.
     """
 
-    def memo_serialize(self, types_to_memoize):
+    def memo_serialize(self, types_to_memoize: List) -> Any:
         memo = SerializeMemoizer(types_to_memoize)
         return self.serialize(memo), memo.serialize()
 
-    def serialize(self, memo=None):
+    def serialize(self, memo = None) -> Dict[str, Any]:
         if memo and memo.in_types(self):
             return {'@': memo.memoized.get(self)}
 
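Tying the `Lark.load` hunk above to practice: with `load(cls: Type[_T], f) -> _T`, a type checker infers the concrete class from the call site. A sketch of the round trip, reusing the hypothetical `parser.cache` from the previous example:

```python
from lark import Lark

# Lark.load() is the counterpart of save(); under the new signature it
# returns _T (here: Lark), so the result is no longer typed as Any.
with open("parser.cache", "rb") as f:
    parser = Lark.load(f)

# The illustrative grammar from the previous sketch accepts this input.
print(parser.parse("hello, world!").pretty())
```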
""" - def memo_serialize(self, types_to_memoize): + def memo_serialize(self, types_to_memoize: List) -> Any: memo = SerializeMemoizer(types_to_memoize) return self.serialize(memo), memo.serialize() - def serialize(self, memo=None): + def serialize(self, memo = None) -> Dict[str, Any]: if memo and memo.in_types(self): return {'@': memo.memoized.get(self)} @@ -62,11 +68,11 @@ def serialize(self, memo=None): res = {f: _serialize(getattr(self, f), memo) for f in fields} res['__type__'] = type(self).__name__ if hasattr(self, '_serialize'): - self._serialize(res, memo) + self._serialize(res, memo) # type: ignore[attr-defined] return res @classmethod - def deserialize(cls, data, memo): + def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T: namespace = getattr(cls, '__serialize_namespace__', []) namespace = {c.__name__:c for c in namespace} @@ -83,7 +89,7 @@ def deserialize(cls, data, memo): raise KeyError("Cannot find key for class", cls, e) if hasattr(inst, '_deserialize'): - inst._deserialize() + inst._deserialize() # type: ignore[attr-defined] return inst @@ -93,18 +99,18 @@ class SerializeMemoizer(Serialize): __serialize_fields__ = 'memoized', - def __init__(self, types_to_memoize): + def __init__(self, types_to_memoize: List) -> None: self.types_to_memoize = tuple(types_to_memoize) self.memoized = Enumerator() - def in_types(self, value): + def in_types(self, value: Serialize) -> bool: return isinstance(value, self.types_to_memoize) - def serialize(self): + def serialize(self) -> Dict[int, Any]: # type: ignore[override] return _serialize(self.memoized.reversed(), None) @classmethod - def deserialize(cls, data, namespace, memo): + def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]: # type: ignore[override] return _deserialize(data, namespace, memo) @@ -123,7 +129,7 @@ def deserialize(cls, data, namespace, memo): categ_pattern = re.compile(r'\\p{[A-Za-z_]+}') -def get_regexp_width(expr): +def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]: if _has_regex: # Since `sre_parse` cannot deal with Unicode categories of the form `\p{Mn}`, we replace these with # a simple letter, which makes no difference as we are only trying to get the possible lengths of the regex @@ -134,7 +140,8 @@ def get_regexp_width(expr): raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr) regexp_final = expr try: - return [int(x) for x in sre_parse.parse(regexp_final).getwidth()] + # Fixed in next version (past 0.960) of typeshed + return [int(x) for x in sre_parse.parse(regexp_final).getwidth()] # type: ignore[attr-defined] except sre_constants.error: if not _has_regex: raise ValueError(expr) @@ -154,19 +161,19 @@ def get_regexp_width(expr): _ID_START = 'Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc' _ID_CONTINUE = _ID_START + ('Nd', 'Nl',) -def _test_unicode_category(s, categories): +def _test_unicode_category(s: str, categories: Sequence[str]) -> bool: if len(s) != 1: return all(_test_unicode_category(char, categories) for char in s) return s == '_' or unicodedata.category(s) in categories -def is_id_continue(s): +def is_id_continue(s: str) -> bool: """ Checks if all characters in `s` are alphanumeric characters (Unicode standard, so diacritics, indian vowels, non-latin numbers, etc. all pass). Synonymous with a Python `ID_CONTINUE` identifier. See PEP 3131 for details. 
""" return _test_unicode_category(s, _ID_CONTINUE) -def is_id_start(s): +def is_id_start(s: str) -> bool: """ Checks if all characters in `s` are alphabetic characters (Unicode standard, so diacritics, indian vowels, non-latin numbers, etc. all pass). Synonymous with a Python `ID_START` identifier. See PEP 3131 for details. @@ -174,19 +181,22 @@ def is_id_start(s): return _test_unicode_category(s, _ID_START) -def dedup_list(l): +def dedup_list(l: List[T]) -> List[T]: """Given a list (l) will removing duplicates from the list, preserving the original order of the list. Assumes that the list entries are hashable.""" dedup = set() - return [x for x in l if not (x in dedup or dedup.add(x))] + # This returns None, but that's expected + return [x for x in l if not (x in dedup or dedup.add(x))] # type: ignore[func-returns-value] + # 2x faster (ordered in PyPy and CPython 3.6+, gaurenteed to be ordered in Python 3.7+) + # return list(dict.fromkeys(l)) class Enumerator(Serialize): - def __init__(self): - self.enums = {} + def __init__(self) -> None: + self.enums: Dict[Any, int] = {} - def get(self, item): + def get(self, item) -> int: if item not in self.enums: self.enums[item] = len(self.enums) return self.enums[item] @@ -194,7 +204,7 @@ def get(self, item): def __len__(self): return len(self.enums) - def reversed(self): + def reversed(self) -> Dict[int, Any]: r = {v: k for k, v in self.enums.items()} assert len(r) == len(self.enums) return r @@ -240,11 +250,11 @@ def open(name, mode="r", **kwargs): -def isascii(s): +def isascii(s: str) -> bool: """ str.isascii only exists in python3.7+ """ - try: + if sys.version_info >= (3, 7): return s.isascii() - except AttributeError: + else: try: s.encode('ascii') return True @@ -257,7 +267,7 @@ def __repr__(self): return '{%s}' % ', '.join(map(repr, self)) -def classify_bool(seq, pred): +def classify_bool(seq: Sequence, pred: Callable) -> Any: true_elems = [] false_elems = [] @@ -270,7 +280,7 @@ def classify_bool(seq, pred): return true_elems, false_elems -def bfs(initial, expand): +def bfs(initial: Sequence, expand: Callable) -> Iterator: open_q = deque(list(initial)) visited = set(open_q) while open_q: @@ -290,7 +300,7 @@ def bfs_all_unique(initial, expand): open_q += expand(node) -def _serialize(value, memo): +def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any: if isinstance(value, Serialize): return value.serialize(memo) elif isinstance(value, list): @@ -305,7 +315,7 @@ def _serialize(value, memo): -def small_factors(n, max_factor): +def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]: """ Splits n up into smaller factors and summands <= max_factor. Returns a list of [(a, b), ...]