Skip to content

Commit

Permalink
Fixes to alternating SWA layers in Gemma2 (#31775)
Browse files Browse the repository at this point in the history
* HybridCache: Flip order of alternating global-attn/sliding-attn layers

* HybridCache: Read sliding_window argument from cache_kwargs

* Gemma2Model: Flip order of alternating global-attn/sliding-attn layers

* Code formatting
  • Loading branch information
turboderp authored and ArthurZucker committed Jul 11, 2024
1 parent 0be998b commit c43fd9d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,7 @@ def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, devi
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
self.is_sliding = torch.tensor(
[i % 2 for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
[not bool(i % 2) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
)
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
Expand Down Expand Up @@ -1056,9 +1056,9 @@ def update(
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
sliding_window: Optional[int] = None,
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window")
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
k_out = self.key_cache[layer_idx]
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
self.sliding_window = config.sliding_window if layer_idx % 2 else None
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None

def forward(
self,
Expand Down Expand Up @@ -616,7 +616,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.is_sliding = bool(layer_idx % 2)
self.is_sliding = not bool(layer_idx % 2)
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window
Expand Down

0 comments on commit c43fd9d

Please sign in to comment.