Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify Logits Processors, Ensure Tokenizers Have Identical Interfaces #676

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions outlines/models/llamacpp.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import math
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from numpy.typing import NDArray

from outlines.fsm.fsm import CFGFSM, FSM, FSMState, RegexFSM
from outlines.models.tokenizer import Tokenizer


class LlamaCpp:
Expand Down Expand Up @@ -89,20 +90,31 @@ def stream(
)


class LlamaCppTokenizer:
class LlamaCppTokenizer(Tokenizer):
def __init__(self, model, **kwargs):
self.tokenizer = model.model.tokenizer()

self.eos_token_id = model.model.token_eos()
self.eos_token = self.tokenizer.decode([self.eos_token_id])
self.pad_token_id = self.eos_token_id
self.special_tokens = {}

self.vocabulary = {}
for t in range(model.model.n_vocab()):
token_piece = model.model.tokenizer().decode([t])
token_piece = self.tokenizer.decode([t])
self.vocabulary[token_piece] = t

def convert_token_to_string(self, token: str) -> str:
return token

def encode(
self, prompt: Union[str, List[str]]
) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
return self.tokenizer.encode(prompt)

def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
return self.tokenizer.decode(token_ids)


def llamacpp(
model_name: str,
Expand Down
29 changes: 27 additions & 2 deletions outlines/models/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import hashlib
import json
from abc import abstractmethod
from typing import Dict, Hashable, List, Protocol, Set, Tuple, Union
from typing import Dict, List, Optional, Protocol, Set, Tuple, Union

import numpy as np
from numpy.typing import NDArray


class Tokenizer(Protocol, Hashable):
class Tokenizer(Protocol):
eos_token: str
eos_token_id: int
pad_token_id: int
vocabulary: Dict[str, int]
special_tokens: Set[int]

_hash: Optional[int] = None

@abstractmethod
def encode(
self, prompt: Union[str, List[str]]
Expand All @@ -34,3 +38,24 @@ def convert_token_to_string(self, token: str) -> str:

"""
...

def __eq__(self, other):
return hash(self) == hash(other)

def __hash__(self):
if self._hash is None:
result = hashlib.md5()
result.update(
json.dumps(
[
self.eos_token,
self.eos_token_id,
self.pad_token_id,
self.vocabulary,
sorted(self.special_tokens),
],
sort_keys=True,
).encode("utf-8")
)
self._hash = hash(result.hexdigest())
return self._hash
10 changes: 0 additions & 10 deletions outlines/models/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,16 +169,6 @@ def convert_token_to_string(self, token: str) -> str:

return string

def __eq__(self, other):
if isinstance(other, type(self)):
return other.model_name == self.model_name and other.kwargs == self.kwargs
return NotImplemented

def __hash__(self):
from datasets.fingerprint import Hasher

return hash(Hasher.hash(self.tokenizer))


def transformers(
model_name: str,
Expand Down
Empty file added outlines/processors/__init__.py
Empty file.
50 changes: 50 additions & 0 deletions outlines/processors/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from abc import abstractmethod
from typing import List, Protocol, Union

import numpy as np
import torch
from numpy.typing import NDArray

from outlines.models.tokenizer import Tokenizer


class LogitsProcessor(Protocol):
tokenizer: Tokenizer

def __init__(self, tokenizer):
self.tokenizer = tokenizer

def __copy__(self):
return self

@abstractmethod
def process_logits(
self, input_ids: List[int], logits: torch.Tensor
) -> torch.Tensor:
...

def __call__(
self,
input_ids: Union[NDArray[np.int64], List[int], torch.Tensor],
logits: Union[NDArray[np.float32], torch.Tensor],
) -> Union[np.ndarray, torch.Tensor]:
"""
Apply logits processor

Unify type
- convert input_ids: either ndarray, List[int], or Tensor -> List[int]
- convert logits: either ndarray or Tensor -> Tensor

Call process_logits() to perform business logic
"""
if not isinstance(input_ids, list):
input_ids = input_ids.tolist()

if isinstance(logits, np.ndarray):
# Unify type, convert numpy array to Tensor
# from_numpy and .numpy() don't copy the data, it uses the same memory address
torch_logits = torch.from_numpy(logits)
processed_torch_logits = self.process_logits(input_ids, torch_logits)
return processed_torch_logits.detach().numpy()
else:
return self.process_logits(input_ids, logits)
122 changes: 122 additions & 0 deletions outlines/processors/fsm_logits_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import collections
import json
import math
from typing import DefaultDict, Dict, List, Optional, Protocol

import torch
from pydantic import BaseModel

from outlines.fsm.fsm import CFGFSM, FSM, FSMState, RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.models.tokenizer import Tokenizer
from outlines.processors.base import LogitsProcessor


class FSMLogitsProcessor(LogitsProcessor, Protocol):
"""
Base class for processing logits with an automaton, either FSM or CFGFSM
FSMLogitsProcessors are stateful and for ONE TIME USE
"""

tokenizer: Tokenizer
fsm: FSM
fsm_state: DefaultDict[int, FSMState]

def __init__(self, fsm: FSM, tokenizer: Tokenizer):
self.fsm = fsm
self.fsm_state: DefaultDict = collections.defaultdict(int)
super().__init__(tokenizer)

def __copy__(self):
return self.__class__(self.tokenizer, self.fsm.copy())

def process_logits(
self, input_ids: List[int], logits: torch.Tensor
) -> torch.Tensor:
seq_id = hash(tuple(input_ids))

# set initial state as 0
# handles case where input_ids passed include prompt tokens
if not self.fsm_state:
self.fsm_state[seq_id] = FSMState(0)

else:
# apply state transitions
last_token = input_ids[-1]
last_seq_id = hash(tuple(input_ids[:-1]))
self.fsm_state[seq_id] = self.fsm.next_state(
self.fsm_state[last_seq_id], last_token
)

allowed_token_ids = self.fsm.allowed_token_ids(self.fsm_state[seq_id])

# bias logits with mask
mask = torch.full(logits.shape[-1:], -math.inf, device=logits.device)
mask[allowed_token_ids] = 0

return logits + mask


class RegexLogitsProcessor(FSMLogitsProcessor):
def __init__(self, regex_string: str, tokenizer: Tokenizer):
"""Compile the FSM that drives the regex-guided generation.

Parameters
----------
regex_string
A string that represents a regular expression
tokenizer
An outlines compatible Tokenizer
"""
fsm = RegexFSM(regex_string, tokenizer)
super().__init__(fsm, tokenizer)


class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(
self,
schema: Dict,
tokenizer: Tokenizer,
whitespace_pattern: Optional[str] = None,
):
"""Compile the FSM that drives the JSON-guided generation.

Parameters
----------
schema
A JSON schema that encodes the structure we want the model to generate
tokenizer
An outlines compatible Tokenizer
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
"""
if isinstance(schema, type(BaseModel)):
schema_str = json.dumps(schema.model_json_schema())
elif isinstance(schema, Dict):
schema_str = json.dumps(schema)
elif isinstance(schema, str):
schema_str = schema
else:
raise ValueError(
f"Cannot parse schema {schema}. The schema must be either "
+ "a Pydantic object, a dictionary or a string that contains the JSON "
+ "Schema specification"
)
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
super().__init__(regex_string, tokenizer)


class CFGLogitsProcessor(FSMLogitsProcessor):
def __init__(self, cfg_str: str, tokenizer: Tokenizer):
"""Compile the FSM that drives the CFG-guided generation.

Parameters
----------
cfg_str
A string that represents a grammar
tokenizer
An outlines compatible Tokenizer
"""
fsm = CFGFSM(cfg_str, tokenizer)
super().__init__(fsm, tokenizer)
42 changes: 42 additions & 0 deletions outlines/processors/vllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from outlines.processors import fsm_logits_processor


class VLLMTokenizerAdapter:
@staticmethod
def _vllm_adapt_tokenizer(llm):
"""Adapt vLLM's tokenizer to use to compile the FSM.

The API of Outlines tokenizers is slightly different to that of
`transformers`. In addition we need to handle the missing spaces to
Llama's tokenizer to be able to compile FSMs for this model.
"""
tokenizer = llm.tokenizer.tokenizer
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)

def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE

string = tokenizer.convert_tokens_to_string([token])

# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string

return string

tokenizer.convert_token_to_string = convert_token_to_string

return tokenizer


class RegexLogitsProcessor(fsm_logits_processor.RegexLogitsProcessor):
_adapt_tokenizer = VLLMTokenizerAdapter._vllm_adapt_tokenizer


class JSONLogitsProcessor(fsm_logits_processor.JSONLogitsProcessor):
_adapt_tokenizer = VLLMTokenizerAdapter._vllm_adapt_tokenizer


class CFGLogitsProcessor(fsm_logits_processor.JSONLogitsProcessor):
_adapt_tokenizer = VLLMTokenizerAdapter._vllm_adapt_tokenizer
Loading