diff --git a/outlines/text/generate/sequence.py b/outlines/text/generate/sequence.py
index b5691359e..a57bfcaf4 100644
--- a/outlines/text/generate/sequence.py
+++ b/outlines/text/generate/sequence.py
@@ -1,7 +1,13 @@
-from typing import List, Optional, Tuple, Union
+import itertools
+import math
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union
 
+import interegular
 import torch
 
+from outlines.text.parsing import find_partial_matches
+
 
 class Sequence:
     """Represents a sequence generation method."""
@@ -171,6 +177,98 @@ def update_token_ids(
 
         return new_token_ids
 
+    def find_boundary_tokens(self, prompt: str) -> Dict[int, List[int]]:
+        """Find a list of tokens that cross the prompt boundary."""
+
+        vocabulary = {
+            token_id: self.model.tokenizer.decode([token_id])[0]
+            for token_id in range(len(self.model.tokenizer.vocabulary))
+        }
+        prompt_fsm = interegular.parse_pattern(prompt).to_fsm()
+
+        prompt_token_ids, _ = self.model.tokenizer.encode(prompt)
+        prompt_tokens = self.model.tokenizer.decode(prompt_token_ids[0])
+
+        token_idx_in_prompt = [0] + list(
+            itertools.accumulate([len(t) for t in prompt_tokens])
+        )[:-1]
+
+        boundary_tokens = defaultdict(list)
+        for token_id, token in vocabulary.items():
+            pmatches = find_partial_matches(prompt_fsm, token)
+            for pmatch in pmatches:
+                end_idx, states = pmatch
+                if end_idx is not None and states[-1] == len(prompt):
+                    if states[0] in token_idx_in_prompt:
+                        boundary_tokens[token_idx_in_prompt.index(states[0])].append(
+                            token_id
+                        )
+
+        return boundary_tokens
+
+    def align_prompt_tokens(
+        self, prompt: Union[str, List[str]], rng: torch.Generator
+    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
+        """Align the prompt with the vocabulary."""
+
+        prompts = prompt
+        if isinstance(prompts, str):
+            prompts = [prompts]
+
+        masks = []
+        truncated_attention_masks = []
+        truncated_token_idss = []
+        attention_masks = []
+        for prompt in prompts:
+            boundary_tokens = self.find_boundary_tokens(prompt)
+
+            token_ids, attention_mask = self.model.tokenizer.encode(prompt)
+            token_ids = token_ids.to(self.device)
+            attention_mask = attention_mask.to(self.device)
+
+            last_token = min(boundary_tokens.keys())
+            truncated_token_ids = token_ids[:, :last_token]
+            truncated_attention_mask = attention_mask[:, :last_token]
+
+            allowed_tokens = boundary_tokens[last_token]
+            mask = torch.full(
+                (len(self.model.tokenizer.vocabulary),), -math.inf, device=self.device
+            )
+            mask[allowed_tokens] = 0
+
+            masks.append(mask)
+            truncated_attention_masks.append(truncated_attention_mask.squeeze())
+            attention_masks.append(attention_mask.squeeze())
+            truncated_token_idss.append(truncated_token_ids.squeeze())
+
+        # Pad left and stack
+        from torch.nn.utils.rnn import pad_sequence
+
+        mask = torch.vstack(masks)
+        truncated_attention_mask = pad_sequence(
+            [t.flip(dims=[0]) for t in truncated_attention_masks],
+            batch_first=True,
+            padding_value=0,
+        ).flip(dims=[1])
+        attention_mask = pad_sequence(
+            [a.flip(dims=[0]) for a in attention_masks],
+            batch_first=True,
+            padding_value=0,
+        ).flip(dims=[1])
+        truncated_token_ids = pad_sequence(
+            [t.flip(dims=[0]) for t in truncated_token_idss],
+            batch_first=True,
+            padding_value=self.model.tokenizer.pad_token_id,
+        ).flip(dims=[1])
+
+        probs = self.model(truncated_token_ids, truncated_attention_mask)
+        probs = probs + mask
+        probs = torch.nn.functional.softmax(probs, dim=-1)
+        next_token_ids = torch.multinomial(probs, num_samples=1, generator=rng)
+        token_ids = torch.concatenate([truncated_token_ids, next_token_ids], axis=-1)
+
+        return token_ids, attention_mask
+
     @torch.inference_mode()
     def __call__(
         self,
@@ -192,14 +290,11 @@ def __call__(
         The full sequence that contains the prompts and the generated string.
 
         """
-        token_ids, attention_mask = self.model.tokenizer.encode(prompt)
-
-        token_ids = token_ids.to(self.device)
-        attention_mask = attention_mask.to(self.device)
-
         if rng is None:
             rng = torch.Generator(device=self.device)
 
+        token_ids, attention_mask = self.align_prompt_tokens(prompt, rng)
+
        num_prompt_tokens = token_ids.shape[-1]
 
         if samples > 1:
diff --git a/pyproject.toml b/pyproject.toml
index d1bbffdb9..d57632d0f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -88,7 +88,7 @@ module = [
     "scipy.*",
     "tenacity.*",
     "tiktoken.*",
-    "torch",
+    "torch.*",
     "transformers.*",
     "lark.*",
     "regex.*",
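
Note on find_boundary_tokens: token_idx_in_prompt holds the character offset at which each prompt token starts, so a vocabulary token whose partial match begins at one of those offsets and runs through the end of the prompt is recorded as crossing the boundary. A standalone toy sketch of that offset computation, not part of the patch (the example prompt tokens are made up):

    import itertools

    prompt_tokens = ["Hello", ",", " wor"]  # pretend tokenization of "Hello, wor"
    token_idx_in_prompt = [0] + list(
        itertools.accumulate(len(t) for t in prompt_tokens)
    )[:-1]
    print(token_idx_in_prompt)  # [0, 5, 6]: start offset of each prompt token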
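
The "Pad left and stack" block relies on torch.nn.utils.rnn.pad_sequence only padding on the right: each sequence is reversed, padded, then reversed again, which yields left padding so the prompts stay right-aligned before generation. A minimal standalone sketch of the same trick (the example tensors and pad value are illustrative only):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    sequences = [torch.tensor([11, 12, 13]), torch.tensor([21, 22])]
    left_padded = pad_sequence(
        [s.flip(dims=[0]) for s in sequences],  # reverse each sequence
        batch_first=True,
        padding_value=0,  # right-pad the reversed sequences
    ).flip(dims=[1])  # reverse back, giving left padding
    print(left_padded)  # tensor([[11, 12, 13],
                        #         [ 0, 21, 22]])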