diff --git a/lark/lexer.py b/lark/lexer.py
index 1232c627..2fba894f 100644
--- a/lark/lexer.py
+++ b/lark/lexer.py
@@ -497,7 +497,24 @@ def _check_regex_collisions(terminal_to_regexp: Dict[TerminalDef, str], comparat
                 return
 
 
-class BasicLexer(Lexer):
+class AbstractBasicLexer(Lexer):
+    terminals_by_name: Dict[str, TerminalDef]
+
+    @abstractmethod
+    def __init__(self, conf: 'LexerConf', comparator=None) -> None:
+        ...
+
+    @abstractmethod
+    def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token:
+        ...
+
+    def lex(self, state: LexerState, parser_state: Any) -> Iterator[Token]:
+        with suppress(EOFError):
+            while True:
+                yield self.next_token(state, parser_state)
+
+
+class BasicLexer(AbstractBasicLexer):
     terminals: Collection[TerminalDef]
     ignore_types: FrozenSet[str]
     newline_types: FrozenSet[str]
@@ -569,11 +586,6 @@ def scanner(self):
     def match(self, text, pos):
         return self.scanner.match(text, pos)
 
-    def lex(self, state: LexerState, parser_state: Any) -> Iterator[Token]:
-        with suppress(EOFError):
-            while True:
-                yield self.next_token(state, parser_state)
-
     def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token:
         line_ctr = lex_state.line_ctr
         while line_ctr.char_pos < len(lex_state.text):
@@ -611,8 +623,10 @@ def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token:
 
 
 class ContextualLexer(Lexer):
-    lexers: Dict[str, BasicLexer]
-    root_lexer: BasicLexer
+    lexers: Dict[str, AbstractBasicLexer]
+    root_lexer: AbstractBasicLexer
+
+    BasicLexer: Type[AbstractBasicLexer] = BasicLexer
 
     def __init__(self, conf: 'LexerConf', states: Dict[str, Collection[str]], always_accept: Collection[str]=()) -> None:
         terminals = list(conf.terminals)
@@ -625,7 +639,7 @@ def __init__(self, conf: 'LexerConf', states: Dict[str, Collection[str]], always
             comparator = interegular.Comparator.from_regexes({t: t.pattern.to_regexp() for t in terminals})
         else:
             comparator = None
-        lexer_by_tokens: Dict[FrozenSet[str], BasicLexer] = {}
+        lexer_by_tokens: Dict[FrozenSet[str], AbstractBasicLexer] = {}
         self.lexers = {}
         for state, accepts in states.items():
             key = frozenset(accepts)
@@ -635,14 +649,14 @@ def __init__(self, conf: 'LexerConf', states: Dict[str, Collection[str]], always
                 accepts = set(accepts) | set(conf.ignore) | set(always_accept)
                 lexer_conf = copy(trad_conf)
                 lexer_conf.terminals = [terminals_by_name[n] for n in accepts if n in terminals_by_name]
-                lexer = BasicLexer(lexer_conf, comparator)
+                lexer = self.BasicLexer(lexer_conf, comparator)
                 lexer_by_tokens[key] = lexer
 
             self.lexers[state] = lexer
 
         assert trad_conf.terminals is terminals
         trad_conf.skip_validation = True  # We don't need to verify all terminals again
-        self.root_lexer = BasicLexer(trad_conf, comparator)
+        self.root_lexer = self.BasicLexer(trad_conf, comparator)
 
     def lex(self, lexer_state: LexerState, parser_state: Any) -> Iterator[Token]:
         try:
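
For context, here is a minimal sketch (not part of this diff) of what the new ContextualLexer.BasicLexer hook enables. The names DebugBasicLexer and DebugContextualLexer are hypothetical; only BasicLexer, ContextualLexer, LexerState, and Token are real names from lark.lexer.

    from typing import Any

    from lark.lexer import BasicLexer, ContextualLexer, LexerState, Token


    class DebugBasicLexer(BasicLexer):
        # Hypothetical subclass: reuses all of BasicLexer's machinery and
        # only wraps next_token() to trace each token as it is produced.
        def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token:
            token = super().next_token(lex_state, parser_state)
            print(f"lexed {token.type}: {token!r}")
            return token


    class DebugContextualLexer(ContextualLexer):
        # The class attribute introduced by this diff: ContextualLexer.__init__
        # now instantiates self.BasicLexer for every per-state lexer and for
        # the root lexer, so overriding it here swaps in the subclass.
        BasicLexer = DebugBasicLexer

The diff itself only guarantees that ContextualLexer.__init__ builds its per-state lexers and its root lexer from self.BasicLexer; how such a subclass gets wired into a parser frontend is outside this change.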