Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added interegular support #1258

Merged
merged 6 commits into from
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 64 additions & 31 deletions lark/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,25 @@
)
from types import ModuleType
import warnings
try:
MegaIng marked this conversation as resolved.
Show resolved Hide resolved
import interegular
except ImportError:
pass
if TYPE_CHECKING:
from .common import LexerConf

from .utils import classify, get_regexp_width, Serialize
from .utils import classify, get_regexp_width, Serialize, logger
from .exceptions import UnexpectedCharacters, LexError, UnexpectedToken
from .grammar import TOKEN_DEFAULT_PRIORITY


###{standalone
from copy import copy

try: # For the standalone parser, we need to make sure that has_interegular is False to avoid NameErrors later on
has_interegular = bool(interegular)
except NameError:
has_interegular = False

class Pattern(Serialize, ABC):

Expand All @@ -27,7 +36,7 @@ class Pattern(Serialize, ABC):
raw: Optional[str]
type: ClassVar[str]

def __init__(self, value: str, flags: Collection[str]=(), raw: Optional[str]=None) -> None:
def __init__(self, value: str, flags: Collection[str] = (), raw: Optional[str] = None) -> None:
self.value = value
self.flags = frozenset(flags)
self.raw = raw
Expand Down Expand Up @@ -110,7 +119,7 @@ class TerminalDef(Serialize):
pattern: Pattern
priority: int

def __init__(self, name: str, pattern: Pattern, priority: int=TOKEN_DEFAULT_PRIORITY) -> None:
def __init__(self, name: str, pattern: Pattern, priority: int = TOKEN_DEFAULT_PRIORITY) -> None:
assert isinstance(pattern, Pattern), pattern
self.name = name
self.pattern = pattern
Expand All @@ -120,7 +129,7 @@ def __repr__(self):
return '%s(%r, %r)' % (type(self).__name__, self.name, self.pattern)

def user_repr(self) -> str:
if self.name.startswith('__'): # We represent a generated terminal
if self.name.startswith('__'): # We represent a generated terminal
return self.pattern.raw or self.name
else:
return self.name
Expand Down Expand Up @@ -162,29 +171,29 @@ class Token(str):

@overload
def __new__(
cls,
type: str,
value: Any,
start_pos: Optional[int]=None,
line: Optional[int]=None,
column: Optional[int]=None,
end_line: Optional[int]=None,
end_column: Optional[int]=None,
end_pos: Optional[int]=None
cls,
type: str,
value: Any,
start_pos: Optional[int] = None,
line: Optional[int] = None,
column: Optional[int] = None,
end_line: Optional[int] = None,
end_column: Optional[int] = None,
end_pos: Optional[int] = None
) -> 'Token':
...

@overload
def __new__(
cls,
type_: str,
value: Any,
start_pos: Optional[int]=None,
line: Optional[int]=None,
column: Optional[int]=None,
end_line: Optional[int]=None,
end_column: Optional[int]=None,
end_pos: Optional[int]=None
cls,
type_: str,
value: Any,
start_pos: Optional[int] = None,
line: Optional[int] = None,
column: Optional[int] = None,
end_line: Optional[int] = None,
end_column: Optional[int] = None,
end_pos: Optional[int] = None
) -> 'Token': ...

def __new__(cls, *args, **kwargs):
Expand Down Expand Up @@ -213,11 +222,11 @@ def _future_new(cls, type, value, start_pos=None, line=None, column=None, end_li
return inst

@overload
def update(self, type: Optional[str]=None, value: Optional[Any]=None) -> 'Token':
def update(self, type: Optional[str] = None, value: Optional[Any] = None) -> 'Token':
...

@overload
def update(self, type_: Optional[str]=None, value: Optional[Any]=None) -> 'Token':
def update(self, type_: Optional[str] = None, value: Optional[Any] = None) -> 'Token':
...

def update(self, *args, **kwargs):
Expand All @@ -230,7 +239,7 @@ def update(self, *args, **kwargs):

return self._future_update(*args, **kwargs)

def _future_update(self, type: Optional[str]=None, value: Optional[Any]=None) -> 'Token':
def _future_update(self, type: Optional[str] = None, value: Optional[Any] = None) -> 'Token':
return Token.new_borrow_pos(
type if type is not None else self.type,
value if value is not None else self.value,
Expand Down Expand Up @@ -364,7 +373,7 @@ def _build_mres(self, terminals, max_size):
try:
mre = self.re_.compile(pattern, self.g_regex_flags)
except AssertionError: # Yes, this is what Python provides us.. :/
return self._build_mres(terminals, max_size//2)
return self._build_mres(terminals, max_size // 2)

mres.append(mre)
terminals = terminals[max_size:]
Expand Down Expand Up @@ -457,26 +466,45 @@ class BasicLexer(Lexer):
callback: Dict[str, _Callback]
re: ModuleType

def __init__(self, conf: 'LexerConf') -> None:
def __init__(self, conf: 'LexerConf', comparator=None) -> None:
terminals = list(conf.terminals)
assert all(isinstance(t, TerminalDef) for t in terminals), terminals

self.re = conf.re_module

if not conf.skip_validation:
# Sanitization
terminal_to_regexp = {}
for t in terminals:
regexp = t.pattern.to_regexp()
try:
self.re.compile(t.pattern.to_regexp(), conf.g_regex_flags)
self.re.compile(regexp, conf.g_regex_flags)
except self.re.error:
raise LexError("Cannot compile token %s: %s" % (t.name, t.pattern))

if t.pattern.min_width == 0:
raise LexError("Lexer does not allow zero-width terminals. (%s: %s)" % (t.name, t.pattern))
if t.pattern.type == "re":
terminal_to_regexp[t] = regexp

if not (set(conf.ignore) <= {t.name for t in terminals}):
raise LexError("Ignore terminals are not defined: %s" % (set(conf.ignore) - {t.name for t in terminals}))

if has_interegular:
if not comparator:
comparator = interegular.Comparator.from_regexes(terminal_to_regexp)
for group in classify(terminal_to_regexp, lambda t: t.priority).values():
for a, b in comparator.check(group, skip_marked=True):
assert a.priority == b.priority
# Mark this pair to not repeat warnings when multiple different BasicLexers see the same collision
comparator.mark(a, b)
erezsh marked this conversation as resolved.
Show resolved Hide resolved

# leave it as a warning for the moment
# raise LexError("Collision between Terminals %s and %s" % (a.name, b.name))
example = comparator.get_example_overlap(a, b).format_multiline()
logger.warning(f"Collision between Terminals {a.name} and {b.name}. "
f"The lexer will choose between them arbitrarily\n" + example)

# Init
self.newline_types = frozenset(t.name for t in terminals if _regexp_has_newline(t.pattern.to_regexp()))
self.ignore_types = frozenset(conf.ignore)
Expand Down Expand Up @@ -517,7 +545,7 @@ def lex(self, state: LexerState, parser_state: Any) -> Iterator[Token]:
while True:
yield self.next_token(state, parser_state)

def next_token(self, lex_state: LexerState, parser_state: Any=None) -> Token:
def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token:
line_ctr = lex_state.line_ctr
while line_ctr.char_pos < len(lex_state.text):
res = self.match(lex_state.text, line_ctr.char_pos)
Expand Down Expand Up @@ -565,6 +593,10 @@ def __init__(self, conf: 'LexerConf', states: Dict[str, Collection[str]], always
trad_conf = copy(conf)
trad_conf.terminals = terminals

if has_interegular and not conf.skip_validation:
comparator = interegular.Comparator.from_regexes({t: t.pattern.to_regexp() for t in terminals})
else:
comparator = None
lexer_by_tokens: Dict[FrozenSet[str], BasicLexer] = {}
self.lexers = {}
for state, accepts in states.items():
Expand All @@ -575,13 +607,14 @@ def __init__(self, conf: 'LexerConf', states: Dict[str, Collection[str]], always
accepts = set(accepts) | set(conf.ignore) | set(always_accept)
lexer_conf = copy(trad_conf)
lexer_conf.terminals = [terminals_by_name[n] for n in accepts if n in terminals_by_name]
lexer = BasicLexer(lexer_conf)
lexer = BasicLexer(lexer_conf, comparator)
lexer_by_tokens[key] = lexer

self.lexers[state] = lexer

assert trad_conf.terminals is terminals
self.root_lexer = BasicLexer(trad_conf)
trad_conf.skip_validation = True # We don't need to verify all terminals again
self.root_lexer = BasicLexer(trad_conf, comparator)

def lex(self, lexer_state: LexerState, parser_state: Any) -> Iterator[Token]:
try:
Expand Down
2 changes: 1 addition & 1 deletion lark/load_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
'_RBRA': r'\]',
'_LBRACE': r'\{',
'_RBRACE': r'\}',
'OP': '[+*]|[?](?![a-z])',
'OP': '[+*]|[?](?![a-z_])',
'_COLON': ':',
'_COMMA': ',',
'_OR': r'\|',
Expand Down
7 changes: 7 additions & 0 deletions lark/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
import warnings

from lark import Lark, logger
try:
from interegular import logger as interegular_logger
has_interegular = True
except ImportError:
has_interegular = False

lalr_argparser = ArgumentParser(add_help=False, epilog='Look at the Lark documentation for more info on the options')

Expand Down Expand Up @@ -40,6 +45,8 @@

def build_lalr(namespace):
logger.setLevel((ERROR, WARN, INFO, DEBUG)[min(namespace.verbose, 3)])
if has_interegular:
interegular_logger.setLevel(logger.getEffectiveLevel())
if len(namespace.start) == 0:
namespace.start.append('start')
kwargs = {n: getattr(namespace, n) for n in options}
Expand Down
4 changes: 2 additions & 2 deletions lark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from functools import reduce
from itertools import product
from collections import deque
from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence
from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence, Iterable

###{standalone
import sys, re
Expand All @@ -21,7 +21,7 @@
T = TypeVar("T")


def classify(seq: Sequence, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
def classify(seq: Iterable, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
d: Dict[Any, Any] = {}
for item in seq:
k = key(item) if (key is not None) else item
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"regex": ["regex"],
"nearley": ["js2py"],
"atomic_cache": ["atomicwrites"],
"interegular": ["interegular>=0.2.4"],
},

package_data = {'': ['*.md', '*.lark'], 'lark': ['py.typed']},
Expand Down
1 change: 1 addition & 0 deletions test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
interegular>=0.2.4
Js2Py==0.68
regex
41 changes: 39 additions & 2 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import logging
from contextlib import contextmanager
from lark import Lark, logger
from unittest import TestCase, main
from unittest import TestCase, main, skipIf

try:
from StringIO import StringIO
except ImportError:
from io import StringIO

try:
import interegular
except ImportError:
interegular = None

@contextmanager
def capture_log():
stream = StringIO()
Expand Down Expand Up @@ -46,7 +51,7 @@ def test_non_debug(self):
Lark(collision_grammar, parser='lalr', debug=False)
log = log.getvalue()
# no log message
self.assertEqual(len(log), 0)
self.assertEqual(log, "")

def test_loglevel_higher(self):
logger.setLevel(logging.ERROR)
Expand All @@ -61,5 +66,37 @@ def test_loglevel_higher(self):
# no log message
self.assertEqual(len(log), 0)

@skipIf(interegular is None, "interegular is not installed, can't test regex collisions")
def test_regex_collision(self):
    """Overlapping terminals must produce a warning that names both of them.

    A ( /a+/ ) and B ( /(a|b)+/ ) both match strings of 'a's, so building
    the lexer should log a collision warning mentioning "A" and "B".
    """
    logger.setLevel(logging.WARNING)
    collision_grammar = '''
    start: A | B
    A: /a+/
    B: /(a|b)+/
    '''
    with capture_log() as captured:
        Lark(collision_grammar, parser='lalr')

    output = captured.getvalue()
    # Both colliding terminal names must appear in the warning text.
    for terminal_name in ("A", "B"):
        self.assertIn(terminal_name, output)

@skipIf(interegular is None, "interegular is not installed, can't test regex collisions")
def test_regex_no_collision(self):
    """Non-overlapping terminal usage must stay silent.

    Although A and B individually overlap, the grammar separates them with
    a literal space, so no collision warning should be emitted.
    """
    logger.setLevel(logging.WARNING)
    collision_grammar = '''
    start: A " " B
    A: /a+/
    B: /(a|b)+/
    '''
    with capture_log() as captured:
        Lark(collision_grammar, parser='lalr')

    # Nothing at WARNING level or above should have been logged.
    self.assertEqual(captured.getvalue(), "")


if __name__ == '__main__':
main()
5 changes: 5 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ passenv =
# to always force recreation and avoid unexpected side effects
recreate = True

# Require since the commands use `git`
allowlist_externals = git

commands =
git submodule sync -q
git submodule update --init
Expand All @@ -23,12 +26,14 @@ skip_install = true
recreate = false
deps =
mypy==0.950
interegular>=0.2.4
types-atomicwrites
types-regex
rich
commands =
mypy


[testenv:lint]
description = run linters on code base
skip_install = true
Expand Down