Adds support for custom regex parsers (for multimodal structured generation) (#1039)

As [discussed in our Discord
server](https://discord.com/channels/1182316225284554793/1182317446225481788/1261998326077984802)

This PR adds support for custom regex parsers. It doesn't change the default
behavior of Outlines, but it allows us to write custom `Guide` classes that
use custom regex parsers, e.g. for multimodal generation.

It also improves the documentation.
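
A custom regex parser is just a callable that maps a regex string to an `interegular.Pattern`. As a rough sketch of what this enables (the `<image>` placeholder rewrite, the `multimodal_regex_parser` helper, and the `gpt2` model choice are illustrative assumptions, not part of this commit):

```python
import interegular

from outlines import models
from outlines.fsm.guide import create_states_mapping


def multimodal_regex_parser(regex_string: str) -> interegular.Pattern:
    # Hypothetical preprocessing: expand an `<image>` placeholder into a
    # concrete sub-pattern before delegating to the default parser.
    regex_string = regex_string.replace("<image>", r"(<img>){64}")
    return interegular.parse_pattern(regex_string)


model = models.transformers("gpt2")  # assumes a model exposing an Outlines tokenizer

states_to_token_maps, empty_token_ids, final_states = create_states_mapping(
    r"<image>Caption: [a-zA-Z ]+",
    model.tokenizer,
    regex_parser=multimodal_regex_parser,
)
```

Passing `frozen_tokens` additionally keeps selected multi-character symbols intact when the token-level FSM is expanded to byte level (see the changes to `outlines/fsm/regex.py` below).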
leloykun authored Jul 14, 2024
1 parent a48f86f commit 62b7601
Showing 5 changed files with 148 additions and 24 deletions.
54 changes: 46 additions & 8 deletions outlines/fsm/guide.py
@@ -1,5 +1,15 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Protocol, Tuple, Union
from typing import (
TYPE_CHECKING,
Callable,
Dict,
List,
Optional,
Protocol,
Set,
Tuple,
Union,
)

import interegular
import torch
@@ -108,16 +118,44 @@ def copy(self):

@cache()
def create_states_mapping(
regex_string: str, tokenizer: "Tokenizer"
) -> Tuple[dict, set, set]:
regex_string: str,
tokenizer: "Tokenizer",
regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern,
frozen_tokens: List[str] = [],
) -> Tuple[Dict[int, Dict[int, int]], Set[int], set]:
"""Create the variables related to the mapping between states and tokens
The parameters of the function are used for caching purpose
The parameters of the function are used for caching purpose.
Parameters
----------
regex_string: (`str`):
The regular expression string to generate a states mapping for.
tokenizer: (`Tokenizer`):
The model's tokenizer.
regex_parser: (`Callable[[str], interegular.Pattern]`, *optional*):
A function that parses a regex string into an `interegular` Pattern object.
frozen_tokens: (`List[str]`, *optional*):
A list of tokens that should be kept as-is when expanding the token-level FSM
into a byte-level FSM. Defaults to an empty list.
Returns
-------
states_to_token_maps: (`Dict[int, Dict[int, int]]`):
A mapping from states to a mapping from token ids originating from that state
to the next state to transition to given that token. The structure is as follows:
(origin_state -> (token_id -> next_state))
empty_token_ids: (`Set[int]`):
A set of token ids that correspond to empty strings.
final_states: (`set`):
A set of final states in the FSM.
"""
regex_pattern = interegular.parse_pattern(regex_string)
byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
regex_pattern = regex_parser(regex_string)
byte_fsm = make_byte_level_fsm(
regex_pattern.to_fsm().reduce(), keep_utf8=True, frozen_tokens=frozen_tokens
)
regex_fsm, _ = make_deterministic_fsm(byte_fsm)
states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
regex_fsm, tokenizer
regex_fsm, tokenizer, frozen_tokens=frozen_tokens
)

# We make sure that it is possible to generate strings in the language
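
To make the return structure concrete, `states_to_token_maps` is a plain nested dictionary of the form `origin_state -> (token_id -> next_state)`; a toy illustration with made-up ids:

```python
# Made-up values, purely to show the documented shape
# (origin_state -> (token_id -> next_state)).
states_to_token_maps = {
    0: {17: 1, 42: 2},  # from the initial state, token 17 leads to state 1, token 42 to state 2
    1: {42: 2},         # from state 1 only token 42 is allowed, leading to state 2
}


def allowed_token_ids(state: int) -> list:
    # The ids a guided generator may sample at `state` are just the keys.
    return list(states_to_token_maps.get(state, {}).keys())


print(allowed_token_ids(0))  # [17, 42]
print(allowed_token_ids(2))  # [] -- no outgoing transitions
```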
@@ -138,7 +176,7 @@ class RegexGuide(Guide):

initial_state = 0

def __init__(self, regex_string: str, tokenizer):
def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
(
self.states_to_token_maps,
self.empty_token_ids,
112 changes: 98 additions & 14 deletions outlines/fsm/regex.py
Original file line number Diff line number Diff line change
@@ -199,11 +199,27 @@ def byte_symbol(byte: int) -> str:
return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte)


def make_byte_level_fsm(fsm: FSM, keep_utf8=False) -> FSM:
def make_byte_level_fsm(
fsm: FSM, keep_utf8: bool = False, frozen_tokens: List[str] = []
) -> FSM:
"""Convert an FSM to a byte-level FSM, expanding multi-byte characters as
sequences of single-byte transitions. If keep_utf8 is set, the original
utf-8 characters are kept in the alphabet.
NOTE: we're representing bytes as strings to keep it type-compatible.
sequences of single-byte transitions.
Parameters
----------
fsm: (`interegular.FSM`):
The token-level FSM to convert to a byte-level FSM.
keep_utf8: (`bool`, *optional*):
If set to True, the original utf-8 characters are kept as-is. Defaults to
False. NOTE: we're representing bytes as strings to keep it type-compatible.
frozen_tokens: (`List[str]`, *optional*):
A list of tokens that should be kept as-is in the byte-level FSM. That is,
these tokens will not be expanded into byte-level transitions. Defaults to
an empty list.
Returns
-------
`interegular.FSM`: A byte-level FSM.
"""

anything_else_key = fsm.alphabet[anything_else]
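
For reference, the byte-level conversion can be exercised directly on an `interegular` FSM; a minimal sketch of the same call chain used in `create_states_mapping` (the pattern is arbitrary, and `frozen_tokens` only has an effect when the FSM's alphabet actually contains those multi-character symbols):

```python
import interegular

from outlines.fsm.regex import make_byte_level_fsm, make_deterministic_fsm

pattern = interegular.parse_pattern(r"café|tea")
token_level_fsm = pattern.to_fsm().reduce()

# Expand the multi-byte character "é" into single-byte transitions while
# keeping the original utf-8 symbols in the alphabet.
byte_fsm = make_byte_level_fsm(token_level_fsm, keep_utf8=True, frozen_tokens=[])
deterministic_fsm, _ = make_deterministic_fsm(byte_fsm)
```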
@@ -218,8 +234,8 @@ def make_byte_level_fsm(fsm: FSM, keep_utf8=False) -> FSM:
all_bytes: Set[int] = set()
max_key = max(fsm.alphabet.values())
for symbol, transition_key in fsm.alphabet.items():
assert symbol == anything_else or len(symbol) == 1
if symbol == anything_else or ord(symbol) < 0x80:
assert symbol == anything_else or symbol in frozen_tokens or len(symbol) == 1
if symbol == anything_else or symbol in frozen_tokens or ord(symbol) < 0x80:
symbol_mapping[symbol] = transition_key
else:
if keep_utf8:
@@ -714,15 +730,40 @@ def get_vocabulary_transition_keys(
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
vocabulary: List[Tuple[str, Sequence[int]]],
frozen_tokens: List[str] = numba.typed.List.empty_list(numba.types.unicode_type),
) -> List[Sequence[int]]:
"""
Calculate the sequence transition keys for each token str within a vocabulary
Parameters
----------
alphabet_symbol_mapping: (`Dict[str, int]`):
A mapping from an alphabet symbol in a FSM to its corresponding transition key.
alphabet_anything_value: (`int`):
The transition key for the anything_else symbol in the FSM.
vocabulary: (`List[Tuple[str, Sequence[int]]]`):
A list of tuples, each containing a token and a list of equivalent token ids.
frozen_tokens: (`List[str]`, *optional*):
A list of tokens that are kept as-is when transforming the FSM.
Defaults to an empty list.
Returns
-------
`List[Sequence[int]]`:
A list of token transition keys for each token in the vocabulary.
"""
vocab_transition_keys = numba.typed.List.empty_list(numba.int64[:])
for token_str, _ in vocabulary:
token_transition_keys = get_token_transition_keys(
alphabet_symbol_mapping, alphabet_anything_value, token_str
)
# Since these tokens are not expanded into byte-level transitions, we can
# simply get their transition keys directly.
if token_str in frozen_tokens:
token_transition_keys = np.array(
[alphabet_symbol_mapping[token_str]], dtype=np.int64
)
else:
token_transition_keys = get_token_transition_keys(
alphabet_symbol_mapping, alphabet_anything_value, token_str
)
vocab_transition_keys.append(token_transition_keys)

return vocab_transition_keys
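
The frozen-token branch above can be pictured without numba: a frozen token contributes a single transition key (its own alphabet entry), while any other token is decomposed symbol by symbol (the real `get_token_transition_keys` also handles byte symbols, which this plain-Python sketch with a made-up alphabet omits):

```python
# Made-up alphabet for illustration: symbol -> transition key.
alphabet_symbol_mapping = {"a": 0, "b": 1, "<image>": 2}
alphabet_anything_value = 3
frozen_tokens = ["<image>"]


def token_transition_keys(token_str: str) -> list:
    if token_str in frozen_tokens:
        # Frozen tokens are kept as a single alphabet symbol, so one key suffices.
        return [alphabet_symbol_mapping[token_str]]
    # Other tokens are walked symbol by symbol, falling back to anything_else.
    return [alphabet_symbol_mapping.get(c, alphabet_anything_value) for c in token_str]


print(token_transition_keys("ab"))       # [0, 1]
print(token_transition_keys("<image>"))  # [2]
print(token_transition_keys("ax"))       # [0, 3]
```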
@@ -731,8 +772,26 @@ def create_fsm_index_end_to_end(
def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: List[Tuple[str, Sequence[int]]],
frozen_tokens: List[str] = [],
) -> Dict[int, Set[Tuple[int, int]]]:
"""Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""
"""Create an FSM state-to-vocabulary map/index through end-to-end token parsing.
Parameters
----------
fsm_info: (`interegular.FSMInfo`):
The FSM information object containing the FSM's alphabet, transitions, initial
and final states, and other relevant information.
vocabulary: (`List[Tuple[str, Sequence[int]]]`):
A list of tuples, each containing a token and a list of equivalent token ids.
frozen_tokens: (`List[str]`, *optional*):
A list of tokens that are kept as-is when transforming the FSM.
Returns
-------
`Dict[int, Set[Tuple[int, int]]]`:
A mapping from FSM states to sets of tuples containing token ids and the end
states of the FSM after parsing the token.
"""

# TODO: Consider using a `List` of `Set`s instead; that way we can JIT this
# code, too.
@@ -750,6 +809,11 @@ def create_fsm_index_end_to_end(
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
vocabulary,
frozen_tokens=(
numba.typed.List(frozen_tokens)
if len(frozen_tokens) > 0
else numba.typed.List.empty_list(numba.types.unicode_type)
),
)

while next_states:
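
The conditional conversion above is needed because a typed numba list cannot infer an element type from an empty Python list, so the element type (unicode strings) has to be stated explicitly in that case. The same idiom in isolation:

```python
import numba

frozen_tokens: list = []

# A non-empty Python list can be converted directly; an empty one needs an
# explicit element type so numba knows the list holds unicode strings.
typed_frozen_tokens = (
    numba.typed.List(frozen_tokens)
    if len(frozen_tokens) > 0
    else numba.typed.List.empty_list(numba.types.unicode_type)
)
```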
@@ -883,21 +947,41 @@ def reduced_vocabulary(


def create_fsm_index_tokenizer(
fsm: BetterFSM,
tokenizer: "Tokenizer",
fsm: BetterFSM, tokenizer: "Tokenizer", frozen_tokens: List[str] = []
) -> Tuple[Dict[int, Dict[int, int]], Set[int]]:
"""Construct an FMS index from a tokenizer.
This uses the end-to-end approach of `create_fsm_index_end_to_end`.
Parameters
----------
fsm: (`BetterFSM`):
A cache-friendly FSM. Other interegular FSMs can also be used, but caching
may not work as expected.
tokenizer: (`Tokenizer`):
The model's tokenizer.
frozen_tokens: (`List[str]`, *optional*):
A list of tokens that should be kept as-is when expanding the token-level
FSM into a byte-level FSM. Defaults to an empty list.
Returns
-------
states_to_token_maps: (`Dict[int, Dict[int, int]]`):
A mapping from states to a mapping from token ids originating from that state
to the next state to transition to given that token. The structure is as follows:
(origin_state -> (token_id -> next_state))
empty_token_ids: (`Set[int]`):
A set of token ids that correspond to empty strings.
.. warning::
`fsm` needs to be deterministically ordered so that future caching makes sense.
"""
vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)

states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)
states_to_token_subsets = create_fsm_index_end_to_end(
fsm.fsm_info, vocabulary, frozen_tokens
)

# Allow transitions to EOS from all terminals FSM states that are
# reachable
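
Putting the pieces together, the index can be built directly from a deterministic byte-level FSM and a tokenizer; a minimal end-to-end sketch (the date regex and the `gpt2` model choice are illustrative):

```python
import interegular

from outlines import models
from outlines.fsm.regex import (
    create_fsm_index_tokenizer,
    make_byte_level_fsm,
    make_deterministic_fsm,
)

model = models.transformers("gpt2")  # assumes a model exposing an Outlines tokenizer

pattern = interegular.parse_pattern(r"[0-9]{4}-[0-9]{2}-[0-9]{2}")
byte_fsm = make_byte_level_fsm(pattern.to_fsm().reduce(), keep_utf8=True)
regex_fsm, _ = make_deterministic_fsm(byte_fsm)

# `states_to_token_maps` drives guided generation; `empty_token_ids` collects
# token ids that decode to the empty string.
states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
    regex_fsm, model.tokenizer, frozen_tokens=[]
)
```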
2 changes: 1 addition & 1 deletion outlines/models/llamacpp.py
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ def __init__(self, model: "Llama"):
self.eos_token_id = model.token_eos()
self.eos_token = model.tokenizer().decode([self.eos_token_id])
self.pad_token_id = self.eos_token_id
self.special_tokens: Set[int] = set()
self.special_tokens: Set[str] = set()

self.vocabulary: Dict[str, int] = dict()

2 changes: 1 addition & 1 deletion outlines/models/tokenizer.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@ class Tokenizer(Hashable, Protocol):
eos_token_id: int
pad_token_id: int
vocabulary: Dict[str, int]
special_tokens: Set[int]
special_tokens: Set[str]

def encode(
self, prompt: Union[str, List[str]]
2 changes: 2 additions & 0 deletions tests/fsm/test_regex.py
Original file line number Diff line number Diff line change
@@ -600,6 +600,7 @@ def convert_token_to_string(self, token):
regex_fsm.fsm_info.alphabet_symbol_mapping,
regex_fsm.fsm_info.alphabet_anything_value,
vocabulary,
numba.typed.List.empty_list(numba.types.unicode_type),
)

token_str_to_tranition_keys = {
@@ -637,6 +638,7 @@ def convert_token_to_string(self, token):
regex_fsm.fsm_info.alphabet_symbol_mapping,
regex_fsm.fsm_info.alphabet_anything_value,
vocabulary,
numba.typed.List.empty_list(numba.types.unicode_type),
)

token_str_trans_key_seq = {
