diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py
index 511086dcf..ad5ba9b65 100644
--- a/outlines/models/llamacpp.py
+++ b/outlines/models/llamacpp.py
@@ -1,11 +1,12 @@
 import math
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 from numpy.typing import NDArray
 
 from outlines.fsm.fsm import CFGFSM, FSM, FSMState, RegexFSM
+from outlines.models.tokenizer import Tokenizer
 
 
 class LlamaCpp:
@@ -89,20 +90,31 @@ def stream(
         )
 
 
-class LlamaCppTokenizer:
+class LlamaCppTokenizer(Tokenizer):
     def __init__(self, model, **kwargs):
+        self.tokenizer = model.model.tokenizer()
+        self.eos_token_id = model.model.token_eos()
+        self.eos_token = self.tokenizer.decode([self.eos_token_id])
         self.pad_token_id = self.eos_token_id
         self.special_tokens = {}
 
         self.vocabulary = {}
         for t in range(model.model.n_vocab()):
-            token_piece = model.model.tokenizer().decode([t])
+            token_piece = self.tokenizer.decode([t])
             self.vocabulary[token_piece] = t
 
     def convert_token_to_string(self, token: str) -> str:
         return token
 
+    def encode(
+        self, prompt: Union[str, List[str]]
+    ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
+        return self.tokenizer.encode(prompt)
+
+    def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
+        return self.tokenizer.decode(token_ids)
+
 
 def llamacpp(
     model_name: str,
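
A quick sanity-check of the reworked `LlamaCppTokenizer`: `encode`/`decode` now delegate to a single cached `llama_cpp` tokenizer instead of re-instantiating it for every token. A minimal sketch, assuming a local GGUF file (the path and the `SimpleNamespace` stand-in for the `LlamaCpp` wrapper are hypothetical; the tokenizer only reads `model.model`):

```python
from types import SimpleNamespace

from llama_cpp import Llama

from outlines.models.llamacpp import LlamaCppTokenizer

llama = Llama("./model.gguf")  # hypothetical model path
tokenizer = LlamaCppTokenizer(SimpleNamespace(model=llama))

token_ids = tokenizer.encode("Hello, world")  # delegates to llama_cpp's tokenizer
text = tokenizer.decode(token_ids)
assert tokenizer.eos_token_id == llama.token_eos()
```

Note that `encode` returns whatever `llama_cpp` returns (a plain list of token ids), not the `(token_ids, attention_mask)` pair its annotation promises; likewise `decode` returns a single string rather than `List[str]`.
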
diff --git a/outlines/models/tokenizer.py b/outlines/models/tokenizer.py
index 72bdae0fe..16259af38 100644
--- a/outlines/models/tokenizer.py
+++ b/outlines/models/tokenizer.py
@@ -1,17 +1,21 @@
+import hashlib
+import json
 from abc import abstractmethod
-from typing import Dict, Hashable, List, Protocol, Set, Tuple, Union
+from typing import Dict, List, Optional, Protocol, Set, Tuple, Union
 
 import numpy as np
 from numpy.typing import NDArray
 
 
-class Tokenizer(Protocol, Hashable):
+class Tokenizer(Protocol):
     eos_token: str
     eos_token_id: int
     pad_token_id: int
     vocabulary: Dict[str, int]
     special_tokens: Set[int]
 
+    _hash: Optional[int] = None
+
     @abstractmethod
     def encode(
         self, prompt: Union[str, List[str]]
@@ -34,3 +38,24 @@ def convert_token_to_string(self, token: str) -> str:
 
         """
         ...
+
+    def __eq__(self, other):
+        return hash(self) == hash(other)
+
+    def __hash__(self):
+        if self._hash is None:
+            result = hashlib.md5()
+            result.update(
+                json.dumps(
+                    [
+                        self.eos_token,
+                        self.eos_token_id,
+                        self.pad_token_id,
+                        self.vocabulary,
+                        sorted(self.special_tokens),
+                    ],
+                    sort_keys=True,
+                ).encode("utf-8")
+            )
+            self._hash = hash(result.hexdigest())
+        return self._hash
diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py
index 9e333bba0..332e7c4a9 100644
--- a/outlines/models/transformers.py
+++ b/outlines/models/transformers.py
@@ -169,16 +169,6 @@ def convert_token_to_string(self, token: str) -> str:
 
         return string
 
-    def __eq__(self, other):
-        if isinstance(other, type(self)):
-            return other.model_name == self.model_name and other.kwargs == self.kwargs
-        return NotImplemented
-
-    def __hash__(self):
-        from datasets.fingerprint import Hasher
-
-        return hash(Hasher.hash(self.tokenizer))
-
 
 def transformers(
     model_name: str,
diff --git a/outlines/processors/__init__.py b/outlines/processors/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/outlines/processors/base.py b/outlines/processors/base.py
new file mode 100644
index 000000000..2c3b42939
--- /dev/null
+++ b/outlines/processors/base.py
@@ -0,0 +1,50 @@
+from abc import abstractmethod
+from typing import List, Protocol, Union
+
+import numpy as np
+import torch
+from numpy.typing import NDArray
+
+from outlines.models.tokenizer import Tokenizer
+
+
+class LogitsProcessor(Protocol):
+    tokenizer: Tokenizer
+
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+
+    def __copy__(self):
+        return self
+
+    @abstractmethod
+    def process_logits(
+        self, input_ids: List[int], logits: torch.Tensor
+    ) -> torch.Tensor:
+        ...
+
+    def __call__(
+        self,
+        input_ids: Union[NDArray[np.int64], List[int], torch.Tensor],
+        logits: Union[NDArray[np.float32], torch.Tensor],
+    ) -> Union[np.ndarray, torch.Tensor]:
+        """
+        Apply the logits processor.
+
+        Unify the argument types:
+        - input_ids: ndarray, List[int], or Tensor -> List[int]
+        - logits: ndarray or Tensor -> Tensor
+
+        then call process_logits() to apply the actual biasing logic.
+        """
+        if not isinstance(input_ids, list):
+            input_ids = input_ids.tolist()
+
+        if isinstance(logits, np.ndarray):
+            # convert the numpy array to a Tensor;
+            # from_numpy() and .numpy() don't copy the data, they share memory
+            torch_logits = torch.from_numpy(logits)
+            processed_torch_logits = self.process_logits(input_ids, torch_logits)
+            return processed_torch_logits.detach().numpy()
+        else:
+            return self.process_logits(input_ids, logits)
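
To make the type-unification contract above concrete, here is a toy processor (hypothetical, for illustration only) that bans token id 0; the same instance accepts NumPy arrays or Tensors and returns the type it was given:

```python
from typing import List

import numpy as np
import torch

from outlines.processors.base import LogitsProcessor


class BanTokenZero(LogitsProcessor):
    """Toy processor: sends token id 0 to -inf, whatever the input types."""

    def process_logits(
        self, input_ids: List[int], logits: torch.Tensor
    ) -> torch.Tensor:
        logits[..., 0] = -float("inf")
        return logits


proc = BanTokenZero(tokenizer=None)  # the toy never touches the tokenizer

np_out = proc(np.array([1, 2]), np.zeros(4, dtype=np.float32))  # ndarray in, ndarray out
pt_out = proc(torch.tensor([1, 2]), torch.zeros(4))  # Tensor in, Tensor out
assert isinstance(np_out, np.ndarray) and isinstance(pt_out, torch.Tensor)
```

Because `torch.from_numpy` shares memory with the source array, in-place edits made by `process_logits` also mutate the caller's NumPy array; that is the zero-copy behavior the comment above relies on.
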
diff --git a/outlines/processors/fsm_logits_processor.py b/outlines/processors/fsm_logits_processor.py
new file mode 100644
index 000000000..8f72ffc73
--- /dev/null
+++ b/outlines/processors/fsm_logits_processor.py
@@ -0,0 +1,127 @@
+import collections
+import json
+import math
+from typing import DefaultDict, Dict, List, Optional, Protocol, Type, Union
+
+import torch
+from pydantic import BaseModel
+
+from outlines.fsm.fsm import CFGFSM, FSM, FSMState, RegexFSM
+from outlines.fsm.json_schema import build_regex_from_schema
+from outlines.models.tokenizer import Tokenizer
+from outlines.processors.base import LogitsProcessor
+
+
+class FSMLogitsProcessor(LogitsProcessor, Protocol):
+    """
+    Base class for processing logits with an automaton, either a RegexFSM or a CFGFSM.
+
+    FSMLogitsProcessors are stateful and for ONE-TIME USE: create a fresh
+    instance (or a copy) for every generation.
+    """
+
+    tokenizer: Tokenizer
+    fsm: FSM
+    fsm_state: DefaultDict[int, FSMState]
+
+    def __init__(self, fsm: FSM, tokenizer: Tokenizer):
+        self.fsm = fsm
+        self.fsm_state: DefaultDict = collections.defaultdict(int)
+        super().__init__(tokenizer)
+
+    def __copy__(self):
+        # Subclass __init__ signatures differ (regex string, schema, grammar),
+        # so rebuild the copy through FSMLogitsProcessor.__init__ directly.
+        copied = object.__new__(type(self))
+        FSMLogitsProcessor.__init__(copied, self.fsm.copy(), self.tokenizer)
+        return copied
+
+    def process_logits(
+        self, input_ids: List[int], logits: torch.Tensor
+    ) -> torch.Tensor:
+        seq_id = hash(tuple(input_ids))
+
+        # set the initial state to 0; this handles the case where the
+        # input_ids passed in include the prompt tokens
+        if not self.fsm_state:
+            self.fsm_state[seq_id] = FSMState(0)
+        else:
+            # apply the state transition for the last sampled token
+            last_token = input_ids[-1]
+            last_seq_id = hash(tuple(input_ids[:-1]))
+            self.fsm_state[seq_id] = self.fsm.next_state(
+                self.fsm_state[last_seq_id], last_token
+            )
+
+        allowed_token_ids = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
+
+        # bias the logits with the mask: disallowed tokens are sent to -inf
+        mask = torch.full(logits.shape[-1:], -math.inf, device=logits.device)
+        mask[allowed_token_ids] = 0
+
+        return logits + mask
+
+
+class RegexLogitsProcessor(FSMLogitsProcessor):
+    def __init__(self, regex_string: str, tokenizer: Tokenizer):
+        """Compile the FSM that drives the regex-guided generation.
+
+        Parameters
+        ----------
+        regex_string
+            A string that represents a regular expression
+        tokenizer
+            An Outlines-compatible Tokenizer
+        """
+        fsm = RegexFSM(regex_string, tokenizer)
+        super().__init__(fsm, tokenizer)
+
+
+class JSONLogitsProcessor(RegexLogitsProcessor):
+    def __init__(
+        self,
+        schema: Union[Dict, Type[BaseModel], str],
+        tokenizer: Tokenizer,
+        whitespace_pattern: Optional[str] = None,
+    ):
+        """Compile the FSM that drives the JSON-guided generation.
+
+        Parameters
+        ----------
+        schema
+            A JSON schema that encodes the structure we want the model to generate
+        tokenizer
+            An Outlines-compatible Tokenizer
+        whitespace_pattern
+            Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
+            Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
+        """
+        if isinstance(schema, type(BaseModel)):
+            schema_str = json.dumps(schema.model_json_schema())
+        elif isinstance(schema, Dict):
+            schema_str = json.dumps(schema)
+        elif isinstance(schema, str):
+            schema_str = schema
+        else:
+            raise ValueError(
+                f"Cannot parse schema {schema}. The schema must be either "
+                + "a Pydantic object, a dictionary or a string that contains the JSON "
+                + "Schema specification"
+            )
+        regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
+        super().__init__(regex_string, tokenizer)
+
+
+class CFGLogitsProcessor(FSMLogitsProcessor):
+    def __init__(self, cfg_str: str, tokenizer: Tokenizer):
+        """Compile the FSM that drives the CFG-guided generation.
+
+        Parameters
+        ----------
+        cfg_str
+            A string that represents a grammar
+        tokenizer
+            An Outlines-compatible Tokenizer
+        """
+        fsm = CFGFSM(cfg_str, tokenizer)
+        super().__init__(fsm, tokenizer)
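
For reviewers, a sketch of the intended one-time-use lifecycle (assumes an Outlines-compatible `tokenizer` already exists, e.g. the `LlamaCppTokenizer` above; the regex is arbitrary):

```python
import copy

import torch

from outlines.processors.fsm_logits_processor import RegexLogitsProcessor

# Compile the FSM once, then take a fresh copy per sequence: the processor
# carries per-sequence FSM state, so an instance must not be reused.
template = RegexLogitsProcessor(r"[0-9]{3}", tokenizer)
processor = copy.copy(template)

logits = torch.zeros(len(tokenizer.vocabulary))
# First call: input_ids holds only the prompt, so the FSM starts at state 0.
masked = processor([], logits)
# Tokens that cannot start a three-digit match are now at -inf.
allowed = (masked != -float("inf")).nonzero(as_tuple=True)[0]
```
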
+ """ + tokenizer = llm.tokenizer.tokenizer + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + tokenizer.convert_token_to_string = convert_token_to_string + + return tokenizer + + +class RegexLogitsProcessor(fsm_logits_processor.RegexLogitsProcessor): + _adapt_tokenizer = VLLMTokenizerAdapter._vllm_adapt_tokenizer + + +class JSONLogitsProcessor(fsm_logits_processor.JSONLogitsProcessor): + _adapt_tokenizer = VLLMTokenizerAdapter._vllm_adapt_tokenizer + + +class CFGLogitsProcessor(fsm_logits_processor.JSONLogitsProcessor): + _adapt_tokenizer = VLLMTokenizerAdapter._vllm_adapt_tokenizer