Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add torch compile for mixtral #30793

Closed
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
2aa4df3
first version
zhenglongjiepheonix May 3, 2024
1ddd617
fix sliding window
zhenglongjiepheonix May 4, 2024
2f5c7ca
fix style
zhenglongjiepheonix May 7, 2024
e117323
add sliding window cache
zhenglongjiepheonix May 9, 2024
dec4904
fix style
zhenglongjiepheonix May 9, 2024
900615b
address comments
zhenglongjiepheonix May 10, 2024
e04d68b
fix test
zhenglongjiepheonix May 10, 2024
bb0811b
fix style
zhenglongjiepheonix May 10, 2024
dd7ff33
move sliding window check inside cache init
zhenglongjiepheonix May 10, 2024
3e08b7e
add compile for mixtral
zhenglongjiepheonix May 13, 2024
dcac131
first version
zhenglongjiepheonix May 3, 2024
9afc73b
fix sliding window
zhenglongjiepheonix May 4, 2024
3fa9285
fix style
zhenglongjiepheonix May 7, 2024
5246ced
add sliding window cache
zhenglongjiepheonix May 9, 2024
6367154
fix style
zhenglongjiepheonix May 9, 2024
c74b329
address comments
zhenglongjiepheonix May 10, 2024
1cd711c
fix test
zhenglongjiepheonix May 10, 2024
06b64ca
fix style
zhenglongjiepheonix May 10, 2024
d46c601
move sliding window check inside cache init
zhenglongjiepheonix May 10, 2024
ec8f338
revert changes on irrelevant files & add comment on SlidingWindowCache
zhenglongjiepheonix May 13, 2024
a51b44f
address comments & fix style
zhenglongjiepheonix May 13, 2024
c1fca1a
update causal mask
zhenglongjiepheonix May 13, 2024
ad969d2
merge from main
zhenglongjiepheonix May 13, 2024
fb00bb9
merge changes on Mistral
zhenglongjiepheonix May 13, 2024
6d0bf35
fix style
zhenglongjiepheonix May 13, 2024
66de109
revert setup.py
zhenglongjiepheonix May 14, 2024
210179b
fix some bug
zhenglongjiepheonix May 14, 2024
c6bf7a1
attempt
zhenglongjiepheonix May 15, 2024
93680ea
attempt
zhenglongjiepheonix May 15, 2024
b2ab9b3
Merge remote-tracking branch 'upstream/main' into longjie/add_torch_c…
zhenglongjiepheonix May 15, 2024
18fa186
attempt
zhenglongjiepheonix May 16, 2024
9b2c104
fix some bug
zhenglongjiepheonix May 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/llm_optims.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ To optimize this, you can use a kv-cache to store the past keys and values inste
The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with torch.compile for up to a 4x speed up.

> [!WARNING]
> Currently, only [Command R](./model_doc/cohere), [Gemma](./model_doc/gemma) and [Llama](./model_doc/llama2) models support static kv-cache and torch.compile.
> Currently, only [Llama](./model_doc/llama2) and a few other models support static kv-cache and torch.compile. Check [this issue](https://github.com/huggingface/transformers/issues/28981) for a live model compatibility list.

For this example, let's load the [Gemma](https://hf.co/google/gemma-2b) model.

Expand Down
121 changes: 121 additions & 0 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,124 @@ def reset(self):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()


class SlidingWindowCache(Cache):
"""
Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window_size - 1`,
if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint),
we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in.

The `to_shift` is only true once we are above sliding_window_size. Thus with `sliding_window_size==64`:

indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window_size
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 0])

We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window_size`)

Parameters:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
"""

def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)

super().__init__()
self.max_batch_size = max_batch_size
# take the minimum of max_cache_len and config.sliding_window so that we allocate less memory
# when we do short-sentence generation
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.model_sliding_window_size = config.sliding_window
self.sliding_window_size = min(self.max_cache_len, self.model_sliding_window_size)
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)

self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)

cache_shape = (
config.num_hidden_layers,
max_batch_size,
self.num_key_value_heads,
self.sliding_window_size,
self.head_dim,
)

self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)

torch._dynamo.mark_static_address(self.key_cache)
torch._dynamo.mark_static_address(self.value_cache)

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Dict[str, Any] | None = None,
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]

# assume this only happens in prefill phase when prompt length > sliding_window_size
if cache_position.shape[0] > self.sliding_window_size:
k_out = key_states[:, :, -self.sliding_window_size :, :]
v_out = value_states[:, :, -self.sliding_window_size :, :]
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states

slicing = torch.ones(self.sliding_window_size, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, self.sliding_window_size - 1)
to_shift = cache_position >= self.sliding_window_size - 1
indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size

k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]

k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states

self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out

return k_out, v_out

def get_seq_length(self, layer_idx: int | None = 0) -> int:
# assume this will be called only in the first generation step
# `cache_postion` will be used in other cases
return 0

def get_max_length(self) -> int | None:
# in theory there is no limit because the sliding window size is fixed
# no matter how long the sentence is
return None

def reset(self):
self.key_cache.zero_()
self.value_cache.zero_()
42 changes: 25 additions & 17 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch.distributed as dist
from torch import nn

from ..cache_utils import Cache, DynamicCache, StaticCache
from ..cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
Expand Down Expand Up @@ -96,9 +96,7 @@
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

NEED_SETUP_CACHE_CLASSES_MAPPING = {
"static": StaticCache,
}
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}


@dataclass
Expand Down Expand Up @@ -1326,33 +1324,42 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
return model_kwargs

def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache:
def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache:
"""
Sets a static cache for `generate`, that will persist across calls. A new cache will only be initialized a
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
new `generate` call requires a larger cache.

Returns the resulting static cache object.
Returns the resulting cache object.
"""
needs_new_cache = (
not hasattr(self, "_static_cache")
or self._static_cache.max_batch_size < max_batch_size
or self._static_cache.max_cache_len < max_cache_len
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
need_new_cache = (
not hasattr(self, "_cache")
or (not isinstance(self._cache, cache_cls))
or self._cache.max_batch_size < max_batch_size
)
if needs_new_cache:
if cache_implementation == "sliding_window":
need_new_cache = need_new_cache or (
self._cache.sliding_window_size < self._cache.model_sliding_window_size
and max_cache_len > self._cache.max_cache_len
)
elif cache_implementation == "static":
need_new_cache = need_new_cache or self._cache.max_cache_len < max_cache_len

if need_new_cache:
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
self._static_cache = StaticCache(
self._cache = cache_cls(
config=self.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=self.device,
dtype=cache_dtype,
)
else:
self._static_cache.reset() # reset the cache for a new generation
return self._static_cache
self._cache.reset()
return self._cache

def _prepare_special_tokens(
self,
Expand Down Expand Up @@ -1615,8 +1622,9 @@ def generate(
"This model does not support the `cache_implementation` argument. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981."
)
if generation_config.cache_implementation == "static":
model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)
model_kwargs["past_key_values"] = self._get_cache(
generation_config.cache_implementation, batch_size, generation_config.max_length
)

self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def forward(self, hidden_states):
return self.weight * hidden_states.to(input_dtype)


# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding with Qwen2->OpenLlama
class OpenLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
Expand Down Expand Up @@ -154,7 +154,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

Expand Down Expand Up @@ -124,7 +124,7 @@ def _get_unpad_data(attention_mask):
)


# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding with Qwen2->Falcon
class FalconRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
Expand Down
12 changes: 2 additions & 10 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
""" PyTorch Gemma model."""

import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -250,7 +249,6 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

Expand All @@ -262,8 +260,8 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
Expand Down Expand Up @@ -617,7 +615,6 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
Expand All @@ -633,10 +630,6 @@ def forward(
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)

residual = hidden_states

Expand All @@ -651,7 +644,6 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = residual + hidden_states

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def attention_mask_func(attention_scores, ltor_mask):


class GPTNeoXRotaryEmbedding(nn.Module):
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

Expand Down Expand Up @@ -614,7 +614,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):

# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding
class RotaryEmbedding(nn.Module):
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/idefics/modeling_idefics.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

Expand Down
Loading
Loading