-
Notifications
You must be signed in to change notification settings - Fork 561
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve outlines.processors, add integration tests to test_generate.py
#998
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import mlx.core as mx | ||
import numpy as np | ||
import torch | ||
|
||
from outlines.processors import OutlinesLogitsProcessor | ||
|
||
|
||
def is_mlx_lm_allowed():
    """Whether the optional mlx backend can be used on this machine.

    Returns False when ``mlx.core`` is not installed; otherwise reports
    whether the Metal backend is available.
    """
    try:
        import mlx.core as mx
    except ImportError:
        # mlx is an optional dependency (Apple Silicon only).
        return False
    else:
        return mx.metal.is_available()
|
||
|
||
class HalvingLogitsProcessor(OutlinesLogitsProcessor):
    """Trivial benchmark processor: scale every logit by one half.

    A minimal, type-agnostic operation is enough to exercise the base
    class's array-conversion machinery during timing runs.
    """

    def process_logits(self, input_ids, logits):
        # input_ids is ignored; only the logits are transformed.
        return logits / 2
|
||
|
||
class LogitsProcessorBenchmark:
    """Benchmark the per-call overhead of a logits processor across array libraries.

    ASV-style benchmark class: for each entry in ``params``, ``setup`` builds
    an (input_ids, logits) pair in the corresponding array library, and
    ``time_logits_processor`` times a single processor call.
    """

    params = ["torch", "numpy"]
    # Use the module's guarded helper rather than mx.metal.is_available()
    # directly, so the param list matches the availability check that
    # tolerates mlx being uninstalled.
    if is_mlx_lm_allowed():
        params += ["mlx"]

    def setup(self, array_library):
        """Create the processor and the benchmark inputs for `array_library`.

        Shapes: logits (4, 30000) float32, input_ids (4, 2048) int.

        Raises
        ------
        ValueError
            If `array_library` is not one of "torch", "numpy", or "mlx".
        """
        self.logits_processor = HalvingLogitsProcessor()

        # logits shape: (4, 30000), dtype=float
        # input_ids shape: (4, 2048), dtype=int
        if array_library == "torch":
            self.logits = torch.rand((4, 30000), dtype=torch.float)
            self.input_ids = torch.randint(
                low=0, high=30000, size=(4, 2048), dtype=torch.int
            )
        elif array_library == "numpy":
            self.logits = np.random.rand(4, 30000).astype(np.float32)
            self.input_ids = np.random.randint(low=0, high=30000, size=(4, 2048))
        elif array_library == "mlx":
            self.logits = mx.random.uniform(
                low=-1e9, high=1e9, shape=(4, 30000), dtype=mx.float32
            )
            self.input_ids = mx.random.randint(
                low=0, high=30000, shape=(4, 2048), dtype=mx.int32
            )
        else:
            # Name the offending value instead of raising a bare ValueError.
            raise ValueError(f"Unsupported array_library: {array_library!r}")

    def time_logits_processor(self, array_library):
        """Timed region: one processor call over the prepared batch."""
        self.logits_processor(self.input_ids, self.logits)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
from .structured import ( | ||
BaseLogitsProcessor, | ||
CFGLogitsProcessor, | ||
FSMLogitsProcessor, | ||
JSONLogitsProcessor, | ||
OutlinesLogitsProcessor, | ||
RegexLogitsProcessor, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,30 @@ | ||
from abc import abstractmethod | ||
from typing import List, Protocol, Union | ||
from typing import TYPE_CHECKING, List, Protocol, Type, Union | ||
|
||
import numpy as np | ||
import torch | ||
from numpy.typing import NDArray | ||
|
||
if TYPE_CHECKING: | ||
import mlx.core as mx | ||
|
||
def is_mlx_array(logits): | ||
|
||
Array = Union[NDArray, torch.Tensor, List, "mx.array"] | ||
|
||
|
||
def is_mlx_array_type(array_type): | ||
try: | ||
import mlx.core as mx | ||
except ImportError: | ||
return False | ||
return isinstance(logits, mx.array) | ||
return issubclass(array_type, mx.array) | ||
|
||
|
||
class BaseLogitsProcessor(Protocol): | ||
class OutlinesLogitsProcessor(Protocol): | ||
""" | ||
Base class for logits processors which normalizes types of logits: | ||
- ndarray (used by llama-cpp-python), converted to torch.Tensor | ||
- mlx.core.array (used by mlx-lm), converted to torch.Tensor | ||
- torch.Tensor (used by everything else) | ||
|
||
Normalization of types and conversion to torch.Tensor | ||
|
@@ -29,50 +36,100 @@ class BaseLogitsProcessor(Protocol): | |
|
||
@abstractmethod | ||
def process_logits( | ||
self, input_ids: List[int], logits: torch.Tensor | ||
self, input_ids: List[List[int]], logits: torch.Tensor | ||
) -> torch.Tensor: | ||
... | ||
""" | ||
input_ids and logits are always 2D tensors for handling a batch of sequences. | ||
|
||
- input_ids -> List[List[tokens]] | ||
- logits.shape[0] -> 2D_Tensor[logits] | ||
|
||
Important to keep in mind when designing universal logits processors | ||
- logits processors are only used once and never re-applied for a new sequence generator | ||
- Some models only pass output_ids, some models such as llamacpp and transformers prefix with input_ids | ||
- Some sampling methods, such as beam search, result in unstable sequence ordering in models like vLLM | ||
""" | ||
pass | ||
|
||
@torch.no_grad() | ||
def __call__( | ||
self, | ||
input_ids: Union[NDArray[np.int64], List[int], torch.Tensor], | ||
logits: Union[NDArray[np.float32], torch.Tensor], | ||
) -> Union[NDArray[np.int64], torch.Tensor]: | ||
input_ids: Array, | ||
logits: Array, | ||
) -> Array: | ||
""" | ||
Apply logits processor | ||
Unify type | ||
- convert input_ids: either ndarray, List[int], or Tensor -> List[int] | ||
- convert logits: either ndarray, mlx array, Tensor -> Tensor | ||
Call process_logits() to perform business logic | ||
|
||
1) Unify type | ||
- convert input_ids: either ndarray, mlx array, List[int], or Tensor -> List[List[int]] | ||
- convert logits: either ndarray, mlx array, or Tensor -> 2D float Tensor | ||
2) Unify shape, ensure logits and input_ids are 2D | ||
3) Call self.process_logits() to perform business logic | ||
4) Cast logits back to original array library type | ||
""" | ||
with torch.no_grad(): | ||
if not isinstance(input_ids, list): | ||
input_ids = input_ids.tolist() | ||
|
||
if isinstance(logits, np.ndarray): | ||
# Unify type, convert numpy array to Tensor | ||
# from_numpy and .numpy() don't copy the data, it uses the same memory address | ||
torch_logits = torch.from_numpy(logits) | ||
processed_torch_logits = self.process_logits(input_ids, torch_logits) | ||
return processed_torch_logits.detach().numpy() | ||
|
||
elif isinstance(logits, torch.Tensor): | ||
return self.process_logits(input_ids, logits) | ||
|
||
elif is_mlx_array(logits): | ||
# mlx -> torch -> mlx conversion docs: | ||
# https://ml-explore.github.io/mlx/build/html/usage/numpy.html | ||
import mlx.core as mx | ||
|
||
torch_logits = torch.from_dlpack(logits) | ||
processed_torch_logits = self.process_logits(input_ids, torch_logits) | ||
|
||
# numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch | ||
logits_float32_numpy = processed_torch_logits.float().numpy() | ||
return mx.array(logits_float32_numpy) | ||
|
||
else: | ||
raise TypeError( | ||
"LogitsProcessor must be called with either np.NDArray" | ||
", torch.Tensor, or mlx.core.array typed logits" | ||
) | ||
|
||
# ensure logits are torch Tensors | ||
torch_logits = self._to_torch(logits) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the cost of this conversion? We need to profile this. I would personally handle this by dispatching the core logic used to process the logits, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The idea is that we can cast Then we can implement a single This makes the processors easy to maintain and implement. |
||
|
||
assert torch_logits.shape[:-1] == self._to_torch(input_ids).shape[:-1] | ||
|
||
# ensure input_ids are List | ||
if not isinstance(input_ids, list): | ||
input_ids = input_ids.tolist() # compatible with numpy, torch, and mlx | ||
|
||
# Guarantee passed as 2D Tensors, then covert back to original (1D or 2D) shape | ||
if len(torch_logits.shape) == 2: | ||
processed_logits = self.process_logits(input_ids, torch_logits) | ||
elif len(torch_logits.shape) == 1: | ||
processed_logits = self.process_logits( | ||
[input_ids], torch_logits.unsqueeze(0) | ||
).squeeze(0) | ||
|
||
# return logits as passed array type | ||
return self._from_torch(processed_logits, type(logits)) | ||
|
||
@staticmethod | ||
def _to_torch(tensor_like: Array) -> torch.Tensor: | ||
"""Convert various types to torch.Tensor.""" | ||
if isinstance(tensor_like, torch.Tensor): | ||
return tensor_like | ||
|
||
elif isinstance(tensor_like, np.ndarray): | ||
return torch.from_numpy(tensor_like) | ||
|
||
elif isinstance(tensor_like, list): | ||
return torch.tensor(tensor_like) | ||
|
||
elif is_mlx_array_type(type(tensor_like)): | ||
# mlx -> torch -> mlx conversion docs: | ||
# https://ml-explore.github.io/mlx/build/html/usage/numpy.html | ||
return torch.from_dlpack(tensor_like) | ||
|
||
else: | ||
raise TypeError( | ||
"LogitsProcessor must be called with either np.NDArray, " | ||
"torch.Tensor, list, or mlx.core.array typed logits" | ||
) | ||
|
||
@staticmethod | ||
def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array: | ||
"""Convert torch.Tensor to the specified target type.""" | ||
if target_type == torch.Tensor: | ||
return tensor | ||
|
||
elif target_type == np.ndarray: | ||
return tensor.detach().numpy() | ||
|
||
elif target_type == list: | ||
return tensor.detach().tolist() | ||
|
||
elif is_mlx_array_type(target_type): | ||
import mlx.core as mx | ||
|
||
# numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch | ||
return mx.array(tensor.float().numpy()) | ||
|
||
else: | ||
raise TypeError( | ||
f"Failed to convert torch tensors to target_type `{target_type}`" | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,24 +24,22 @@ | |
limitations under the License. | ||
""" | ||
import math | ||
from typing import TYPE_CHECKING, List, Optional, Type, Union | ||
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union | ||
|
||
import numpy as np | ||
import torch | ||
from numpy.typing import NDArray | ||
from pydantic import BaseModel | ||
|
||
from outlines.fsm.guide import CFGGuide, Guide, RegexGuide | ||
from outlines.fsm.json_schema import build_regex_from_schema | ||
from outlines.integrations.utils import convert_json_schema_to_str | ||
|
||
from .base_logits_processor import BaseLogitsProcessor | ||
from .base_logits_processor import OutlinesLogitsProcessor | ||
|
||
if TYPE_CHECKING: | ||
from outlines.models.tokenizer import Tokenizer | ||
|
||
|
||
class FSMLogitsProcessor(BaseLogitsProcessor): | ||
class FSMLogitsProcessor(OutlinesLogitsProcessor): | ||
"""Bias generation using a finite state machine. | ||
|
||
Attributes | ||
|
@@ -63,13 +61,14 @@ def __init__(self, tokenizer: "Tokenizer", fsm: Guide): | |
The finite state machine which is used to bias the logits. | ||
""" | ||
self.tokenizer = tokenizer | ||
self._fsm_state = 0 | ||
self._fsm_states: Dict[int, int] = {} | ||
self.fsm: Guide = fsm | ||
self._is_first_token = True | ||
self._seq_start_idx: Optional[int] = None | ||
|
||
def process_logits( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cf above remark on dispatching the function depending on the type of the tensors. |
||
self, input_ids: List[int], logits: torch.Tensor | ||
) -> NDArray[np.float32]: | ||
self, input_ids: List[List[int]], logits: torch.Tensor | ||
) -> torch.Tensor: | ||
"""Use the FSM to bias the logits before sampling the next token. | ||
|
||
Parameters | ||
|
@@ -84,17 +83,31 @@ def process_logits( | |
torch.Tensor | ||
The biased logits. | ||
""" | ||
sequence_states: List[int] = [] # vector of states corresponding to `input_ids` | ||
|
||
if self._is_first_token: | ||
self._is_first_token = False | ||
self._seq_start_idx = len(input_ids[0]) | ||
|
||
self._fsm_states = {hash(tuple([])): 0} | ||
sequence_states = [0] * len(input_ids) | ||
|
||
else: | ||
last_token = input_ids[-1] | ||
self._fsm_state = self.fsm.get_next_state(self._fsm_state, last_token) | ||
for seq_ids in input_ids: | ||
prev_state_key = hash(tuple(seq_ids[self._seq_start_idx : -1])) | ||
prev_state = self._fsm_states[prev_state_key] | ||
|
||
allowed_tokens = self.fsm.get_next_instruction(self._fsm_state).tokens | ||
allowed_tokens = torch.tensor(allowed_tokens, device=logits.device) | ||
curr_state_key = hash(tuple(seq_ids[self._seq_start_idx :])) | ||
curr_state = self.fsm.get_next_state(prev_state, seq_ids[-1]) | ||
|
||
self._fsm_states[curr_state_key] = curr_state | ||
sequence_states.append(curr_state) | ||
|
||
mask = torch.full_like(logits, -math.inf) | ||
mask[allowed_tokens] = logits[allowed_tokens] | ||
for i, fsm_state in enumerate(sequence_states): | ||
allowed_tokens = self.fsm.get_next_instruction(fsm_state).tokens | ||
mask[i, allowed_tokens] = logits[i, allowed_tokens] | ||
|
||
return mask | ||
|
||
def copy(self) -> "FSMLogitsProcessor": | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not call it `LogitsProcessor`?
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, do we really need that class if all logits processors are ever going to be `FSMLogitsProcessor`s?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My thinking here was: if we want a new logits processor which isn't for structured generation, e.g. `RepetitionPenaltyLogitsProcessor`, we can easily implement this processor once and use it in every `outlines.models` without any changes specific to the model. This provides us the flexibility to have the inference library handle the decoder pass and sampling, while `outlines` handles all logits augmentation.

Another example: `mlxlm` doesn't support `stop_strings`, and `stop_strings` is broken in the latest `transformers` version (`4.41.2`). We could implement a single `StopStringsLogitsProcessor` and make it available to all inference libraries. I'm not arguing that implementing this specific logits processor should be a priority, but I am arguing that having a separate base class which doesn't demand an FSM provides us flexibility and opportunities.