Improve outlines.processors, add integration tests to test_generate.py #998

Merged (1 commit, Jun 30, 2024)
52 changes: 52 additions & 0 deletions benchmarks/bench_processors.py
@@ -0,0 +1,52 @@
import numpy as np
import torch

from outlines.processors import OutlinesLogitsProcessor


def is_mlx_lm_allowed():
try:
import mlx.core as mx
except ImportError:
return False
return mx.metal.is_available()


class HalvingLogitsProcessor(OutlinesLogitsProcessor):
"""Simply halve the passed logits"""

def process_logits(self, input_ids, logits):
return logits / 2


class LogitsProcessorBenchmark:
params = ["torch", "numpy"]
    if is_mlx_lm_allowed():
params += ["mlx"]

def setup(self, array_library):
self.logits_processor = HalvingLogitsProcessor()

        # logits shape: (4, 30000), dtype=float
        # input_ids shape: (4, 2048), dtype=int
if array_library == "torch":
self.logits = torch.rand((4, 30000), dtype=torch.float)
self.input_ids = torch.randint(
low=0, high=30000, size=(4, 2048), dtype=torch.int
)
elif array_library == "numpy":
self.logits = np.random.rand(4, 30000).astype(np.float32)
self.input_ids = np.random.randint(low=0, high=30000, size=(4, 2048))
elif array_library == "mlx":
self.logits = mx.random.uniform(
low=-1e9, high=1e9, shape=(4, 30000), dtype=mx.float32
)
self.input_ids = mx.random.randint(
low=0, high=30000, shape=(4, 2048), dtype=mx.int32
)
        else:
            raise ValueError(f"Unknown array library: {array_library}")

def time_logits_processor(self, array_library):
self.logits_processor(self.input_ids, self.logits)
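As a usage note, the processor under benchmark can be exercised directly; a minimal sketch (assuming this branch is installed, run in the same module as `HalvingLogitsProcessor`) showing that `OutlinesLogitsProcessor.__call__` returns the same array type it receives:

```python
import numpy as np
import torch

processor = HalvingLogitsProcessor()

# numpy in -> numpy out
np_out = processor(
    np.zeros((4, 2048), dtype=np.int64),    # input_ids
    np.ones((4, 30000), dtype=np.float32),  # logits
)
assert isinstance(np_out, np.ndarray) and np_out[0, 0] == 0.5

# torch in -> torch out
torch_out = processor(
    torch.zeros((4, 2048), dtype=torch.long),
    torch.ones((4, 30000)),
)
assert isinstance(torch_out, torch.Tensor) and torch_out[0, 0] == 0.5
```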
1 change: 1 addition & 0 deletions outlines/__init__.py
@@ -2,6 +2,7 @@
import outlines.generate
import outlines.grammars
import outlines.models
+import outlines.processors
import outlines.types
from outlines.base import vectorize
from outlines.caching import clear_cache, disable_cache, get_cache
6 changes: 3 additions & 3 deletions outlines/models/mlxlm.py
@@ -9,7 +9,7 @@
from transformers import PreTrainedTokenizer

from outlines.generate.api import GenerationParameters, SamplingParameters
-from outlines.processors import BaseLogitsProcessor
+from outlines.processors import OutlinesLogitsProcessor


class MLXLM:
@@ -127,7 +127,7 @@ def generate_step(
temp: Optional[float],
top_p: Optional[float],
sampler: str,
logits_processor: "BaseLogitsProcessor",
logits_processor: "OutlinesLogitsProcessor",
) -> Generator[Tuple[int, float], None, None]:
"""
Adapted from
@@ -142,7 +142,7 @@ def generate_step(
        top_p (float, optional): Nucleus sampling; higher means the model
            considers more less-likely words.
sampler (str): The sampler string defined by SequenceGeneratorAdapter
-        logits_processor (BaseLogitsProcessor): Augment logits before sampling.
+        logits_processor (OutlinesLogitsProcessor): Augment logits before sampling.
"""
import mlx.core as mx
import mlx_lm
2 changes: 1 addition & 1 deletion outlines/processors/__init__.py
@@ -1,7 +1,7 @@
from .structured import (
-    BaseLogitsProcessor,
    CFGLogitsProcessor,
    FSMLogitsProcessor,
    JSONLogitsProcessor,
+    OutlinesLogitsProcessor,
    RegexLogitsProcessor,
)
145 changes: 101 additions & 44 deletions outlines/processors/base_logits_processor.py
@@ -1,23 +1,30 @@
from abc import abstractmethod
-from typing import List, Protocol, Union
+from typing import TYPE_CHECKING, List, Protocol, Type, Union

import numpy as np
import torch
from numpy.typing import NDArray

+if TYPE_CHECKING:
+    import mlx.core as mx
+
+
+Array = Union[NDArray, torch.Tensor, List, "mx.array"]
+

-def is_mlx_array(logits):
+def is_mlx_array_type(array_type):
    try:
        import mlx.core as mx
    except ImportError:
        return False
-    return isinstance(logits, mx.array)
+    return issubclass(array_type, mx.array)


-class BaseLogitsProcessor(Protocol):
+class OutlinesLogitsProcessor(Protocol):
Review comment from Member:
Why not call it LogitsProcessor?

Review comment from Member:
Actually, do we really need that class if all logits processors are ever going to be FSMLogitsProcessors?

Review comment from @lapp0 (Contributor, Author), Jun 22, 2024:
My thinking here was that if we want a new logits processor which isn't for structured generation, e.g. RepetitionPenaltyLogitsProcessor, we can implement it once and use it in every outlines.models without any model-specific changes.

This gives us the flexibility to have the inference library handle the decoder pass and sampling, while outlines handles all logits augmentation.

Another example: mlxlm doesn't support stop_strings, and stop_strings is broken in the latest transformers version (4.41.2). We could implement a single StopStringsLogitsProcessor and make it available to all inference libraries. I'm not arguing that implementing this specific logits processor should be a priority, but I am arguing that having a separate base class which doesn't demand an FSM gives us flexibility and opportunities.
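For illustration, a sketch of the kind of model-agnostic processor described above (the class name and penalty rule are hypothetical, not part of this PR); only `process_logits` needs to be implemented, and the base class handles array types and batching for every inference library:

```python
import torch

from outlines.processors import OutlinesLogitsProcessor


class RepetitionPenaltyLogitsProcessor(OutlinesLogitsProcessor):
    """Hypothetical example: penalize tokens that were already generated."""

    def __init__(self, penalty: float = 1.2):
        self.penalty = penalty

    def process_logits(self, input_ids, logits):
        # input_ids: List[List[int]]; logits: 2D tensor, one row per sequence
        for i, seq_ids in enumerate(input_ids):
            for token_id in set(seq_ids):
                # shrink positive logits, push negative logits further down
                if logits[i, token_id] > 0:
                    logits[i, token_id] /= self.penalty
                else:
                    logits[i, token_id] *= self.penalty
        return logits
```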

"""
Base class for logits processors which normalizes types of logits:
- ndarray (used by llama-cpp-python), converted to torch.Tensor
- mlx.core.array (used by mlx-lm), converted to torch.Tensor
- torch.Tensor (used by everything else)

Normalization of types and conversion to torch.Tensor
@@ -29,50 +36,100 @@ class BaseLogitsProcessor(Protocol):

@abstractmethod
    def process_logits(
-        self, input_ids: List[int], logits: torch.Tensor
+        self, input_ids: List[List[int]], logits: torch.Tensor
    ) -> torch.Tensor:
-        ...
+        """
+        input_ids and logits are always 2D; each row corresponds to one sequence in the batch.
+
+        - input_ids: List[List[int]], the token ids of each sequence in the batch
+        - logits: 2D torch.Tensor, one row of logits per sequence
+
+        Important to keep in mind when designing universal logits processors:
+        - logits processors are only used once and never re-applied for a new sequence generator
+        - Some models only pass output_ids; others, such as llamacpp and transformers, prefix with input_ids
+        - Some sampling methods, such as beam search, result in unstable sequence ordering in models like vLLM
+        """
+        pass

+    @torch.no_grad()
    def __call__(
        self,
-        input_ids: Union[NDArray[np.int64], List[int], torch.Tensor],
-        logits: Union[NDArray[np.float32], torch.Tensor],
-    ) -> Union[NDArray[np.int64], torch.Tensor]:
+        input_ids: Array,
+        logits: Array,
+    ) -> Array:
        """
        Apply logits processor
-        Unify type
-        - convert input_ids: either ndarray, List[int], or Tensor -> List[int]
-        - convert logits: either ndarray, mlx array, Tensor -> Tensor
-        Call process_logits() to perform business logic
+
+        1) Unify type
+            - convert input_ids: either ndarray, mlx array, List[int], or Tensor -> List[List[int]]
+            - convert logits: either ndarray, mlx array, or Tensor -> 2D float Tensor
+        2) Unify shape, ensure logits and input_ids are 2D
+        3) Call self.process_logits() to perform business logic
+        4) Cast logits back to original array library type
        """
-        with torch.no_grad():
-            if not isinstance(input_ids, list):
-                input_ids = input_ids.tolist()
-
-            if isinstance(logits, np.ndarray):
-                # Unify type, convert numpy array to Tensor
-                # from_numpy and .numpy() don't copy the data, it uses the same memory address
-                torch_logits = torch.from_numpy(logits)
-                processed_torch_logits = self.process_logits(input_ids, torch_logits)
-                return processed_torch_logits.detach().numpy()
-
-            elif isinstance(logits, torch.Tensor):
-                return self.process_logits(input_ids, logits)
-
-            elif is_mlx_array(logits):
-                # mlx -> torch -> mlx conversion docs:
-                # https://ml-explore.github.io/mlx/build/html/usage/numpy.html
-                import mlx.core as mx
-
-                torch_logits = torch.from_dlpack(logits)
-                processed_torch_logits = self.process_logits(input_ids, torch_logits)
-
-                # numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch
-                logits_float32_numpy = processed_torch_logits.float().numpy()
-                return mx.array(logits_float32_numpy)
-
-            else:
-                raise TypeError(
-                    "LogitsProcessor must be called with either np.NDArray"
-                    ", torch.Tensor, or mlx.core.array typed logits"
-                )

# ensure logits are torch Tensors
torch_logits = self._to_torch(logits)
Review comment from @rlouf (Member), Jun 22, 2024:
What is the cost of this conversion? We need to profile it.

I would personally handle this by dispatching the core logic used to process the logits, process_logits, depending on the type of input_ids or logits.

Review comment from @lapp0 (Contributor, Author), Jun 22, 2024:
The idea is that we can cast numpy and mlx.core arrays to torch.Tensor using shared memory, not a copy operation :)

Then we can implement a single OutlinesLogitsProcessor subclass with a single process_logits() method, and it works out of the box for any library. While the tests aren't in main yet, I've tested FSMLogitsProcessor and RegexLogitsProcessor with nearly all outlines.models options (I haven't tested against exllamav2) and they work with no additional changes.

This makes the processors easy to maintain and implement.
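As a quick, illustrative check of the shared-memory claim (a sketch, not part of the PR): `torch.from_numpy` wraps the numpy buffer instead of copying it, while `torch.tensor` copies; comparing the two is also a cheap starting point for the profiling question raised above:

```python
import timeit

import numpy as np
import torch

arr = np.zeros((4, 30000), dtype=np.float32)

tensor = torch.from_numpy(arr)  # wraps the same buffer, no copy
tensor += 1.0                   # in-place write through the tensor...
assert arr[0, 0] == 1.0         # ...is visible in the numpy array

print(timeit.timeit(lambda: torch.from_numpy(arr), number=10_000))  # zero-copy wrap
print(timeit.timeit(lambda: torch.tensor(arr), number=10_000))      # full copy
```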


assert torch_logits.shape[:-1] == self._to_torch(input_ids).shape[:-1]

# ensure input_ids are List
if not isinstance(input_ids, list):
input_ids = input_ids.tolist() # compatible with numpy, torch, and mlx

        # Guarantee passed as 2D Tensors, then convert back to original (1D or 2D) shape
        if len(torch_logits.shape) == 2:
            processed_logits = self.process_logits(input_ids, torch_logits)
        elif len(torch_logits.shape) == 1:
            processed_logits = self.process_logits(
                [input_ids], torch_logits.unsqueeze(0)
            ).squeeze(0)
        else:
            raise ValueError(f"Logits must be 1D or 2D, got {torch_logits.dim()}D")

# return logits as passed array type
return self._from_torch(processed_logits, type(logits))

@staticmethod
def _to_torch(tensor_like: Array) -> torch.Tensor:
"""Convert various types to torch.Tensor."""
if isinstance(tensor_like, torch.Tensor):
return tensor_like

elif isinstance(tensor_like, np.ndarray):
return torch.from_numpy(tensor_like)

elif isinstance(tensor_like, list):
return torch.tensor(tensor_like)

elif is_mlx_array_type(type(tensor_like)):
# mlx -> torch -> mlx conversion docs:
# https://ml-explore.github.io/mlx/build/html/usage/numpy.html
return torch.from_dlpack(tensor_like)

else:
raise TypeError(
"LogitsProcessor must be called with either np.NDArray, "
"torch.Tensor, list, or mlx.core.array typed logits"
)

@staticmethod
def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array:
"""Convert torch.Tensor to the specified target type."""
if target_type == torch.Tensor:
return tensor

elif target_type == np.ndarray:
return tensor.detach().numpy()

elif target_type == list:
return tensor.detach().tolist()

elif is_mlx_array_type(target_type):
import mlx.core as mx

# numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch
return mx.array(tensor.float().numpy())

else:
raise TypeError(
f"Failed to convert torch tensors to target_type `{target_type}`"
)
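A minimal sketch of the unified `__call__` behavior end to end (the `IdentityLogitsProcessor` below is illustrative, not part of the PR): callers may pass 1D or 2D inputs, `process_logits` always sees a 2D batch, and the original shape and array type come back out:

```python
import torch

from outlines.processors import OutlinesLogitsProcessor


class IdentityLogitsProcessor(OutlinesLogitsProcessor):
    """Illustrative no-op processor."""

    def process_logits(self, input_ids, logits):
        # always receives a batch: one row of logits per sequence
        assert logits.dim() == 2 and len(input_ids) == logits.shape[0]
        return logits


processor = IdentityLogitsProcessor()

# 1D logits are unsqueezed to a batch of one, then squeezed back
assert processor([1, 2, 3], torch.zeros(8)).shape == (8,)

# 2D logits keep their batch shape
assert processor([[1, 2], [3, 4]], torch.zeros(2, 8)).shape == (2, 8)
```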
39 changes: 26 additions & 13 deletions outlines/processors/structured.py
@@ -24,24 +24,22 @@
limitations under the License.
"""
import math
-from typing import TYPE_CHECKING, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union

-import numpy as np
import torch
-from numpy.typing import NDArray
from pydantic import BaseModel

from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.integrations.utils import convert_json_schema_to_str

-from .base_logits_processor import BaseLogitsProcessor
+from .base_logits_processor import OutlinesLogitsProcessor

if TYPE_CHECKING:
from outlines.models.tokenizer import Tokenizer


-class FSMLogitsProcessor(BaseLogitsProcessor):
+class FSMLogitsProcessor(OutlinesLogitsProcessor):
"""Bias generation using a finite state machine.

Attributes
@@ -63,13 +61,14 @@ def __init__(self, tokenizer: "Tokenizer", fsm: Guide):
The finite state machine which is used to bias the logits.
"""
self.tokenizer = tokenizer
-        self._fsm_state = 0
+        self._fsm_states: Dict[int, int] = {}
        self.fsm: Guide = fsm
        self._is_first_token = True
+        self._seq_start_idx: Optional[int] = None

Review comment from Member:
cf. the above remark on dispatching the function depending on the type of the tensors.

    def process_logits(
-        self, input_ids: List[int], logits: torch.Tensor
-    ) -> NDArray[np.float32]:
+        self, input_ids: List[List[int]], logits: torch.Tensor
+    ) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token.

Parameters
@@ -84,17 +83,31 @@ def process_logits(
torch.Tensor
The biased logits.
"""
+        sequence_states: List[int] = []  # vector of states corresponding to `input_ids`
+
        if self._is_first_token:
            self._is_first_token = False
+            self._seq_start_idx = len(input_ids[0])
+
+            self._fsm_states = {hash(tuple([])): 0}
+            sequence_states = [0] * len(input_ids)

        else:
-            last_token = input_ids[-1]
-            self._fsm_state = self.fsm.get_next_state(self._fsm_state, last_token)
-
-        allowed_tokens = self.fsm.get_next_instruction(self._fsm_state).tokens
-        allowed_tokens = torch.tensor(allowed_tokens, device=logits.device)
+            for seq_ids in input_ids:
+                prev_state_key = hash(tuple(seq_ids[self._seq_start_idx : -1]))
+                prev_state = self._fsm_states[prev_state_key]
+
+                curr_state_key = hash(tuple(seq_ids[self._seq_start_idx :]))
+                curr_state = self.fsm.get_next_state(prev_state, seq_ids[-1])
+
+                self._fsm_states[curr_state_key] = curr_state
+                sequence_states.append(curr_state)

        mask = torch.full_like(logits, -math.inf)
-        mask[allowed_tokens] = logits[allowed_tokens]
+        for i, fsm_state in enumerate(sequence_states):
+            allowed_tokens = self.fsm.get_next_instruction(fsm_state).tokens
+            mask[i, allowed_tokens] = logits[i, allowed_tokens]

        return mask

def copy(self) -> "FSMLogitsProcessor":
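To make the state bookkeeping in `FSMLogitsProcessor.process_logits` concrete, here is a toy trace of the `_fsm_states` dict (a sketch with a stub standing in for outlines' real `Guide`):

```python
class StubGuide:
    """Stand-in for outlines.fsm.guide.Guide, illustrative only."""

    def get_next_state(self, state: int, token_id: int) -> int:
        return state + 1  # placeholder transition function


fsm = StubGuide()
seq_start_idx = 3                  # prompt length, recorded on the first call
fsm_states = {hash(tuple([])): 0}  # empty generated prefix -> initial state 0

# Step 1: the prompt [7, 8, 9] has generated one token (42).
seq_ids = [7, 8, 9, 42]
prev_state = fsm_states[hash(tuple(seq_ids[seq_start_idx:-1]))]  # key () -> 0
curr_state = fsm.get_next_state(prev_state, seq_ids[-1])         # consume 42
fsm_states[hash(tuple(seq_ids[seq_start_idx:]))] = curr_state    # key (42,) -> 1

# Step 2: the lookup still works if the batch order changed (e.g. beam
# search), because states are keyed by the generated prefix, not batch index.
seq_ids = [7, 8, 9, 42, 17]
assert fsm_states[hash(tuple(seq_ids[seq_start_idx:-1]))] == 1   # key (42,)
```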