Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Fix failed tests #398

Merged
merged 8 commits into from
Aug 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from outlines.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from lark import Lark
from outlines import grammars

class BaseLogitsProcessor:

Expand All @@ -44,6 +45,23 @@ def __call__(self, input_ids: List[int],
last_seq_id = hash(tuple(input_ids[:-1]))
self._fsm_state[seq_id] = self._guide.get_next_state(
state=self._fsm_state[last_seq_id], token_id=last_token)
else:
# Note: this is a hack.
# Lark pickling does not work properly (silent failure),
# which breaks the RPC (which uses python pickleing).
# We need to find a better solution.
# On the first time this is called, we simply re-create
# the Lark object.
if isinstance(self._guide, CFGGuide):
self._guide.parser = Lark(
self._guide.cfg_string,
parser="lalr",
lexer="contextual",
propagate_positions=False,
maybe_placeholders=False,
regex=True,
import_paths=[grammars.GRAMMAR_PATH],
)

instruction = self._guide.get_next_instruction(
state=self._fsm_state[seq_id])
Expand Down
Loading