Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/model_executor/models/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def __init__(self,
# reference:
# https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
layer_idx = extract_layer_index(prefix)
use_sliding_window = (layer_idx % 2 == 0 and
config.interleaved_sliding_window is not None)
use_sliding_window = (layer_idx % 2 == 0 and getattr(
config, "interleaved_sliding_window", None) is not None)
sliding_window = config.interleaved_sliding_window if \
use_sliding_window else None
self.attn = Attention(self.num_heads,
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/models/gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def __init__(self,

# TODO(woosuk): Add reference to the original HF implementation.
layer_idx = extract_layer_index(prefix)
self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
self.is_sliding = (getattr(
config, "interleaved_sliding_window", None) is not None and bool(
(layer_idx + 1) % config.sliding_window_pattern))
# Initialize the rotary embedding.
if self.is_sliding:
# Local attention. Override the values in config.json.
Expand Down
18 changes: 10 additions & 8 deletions vllm/model_executor/models/gemma3_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.sliding_window = config.text_config.interleaved_sliding_window
self.sliding_window = getattr(config.text_config,
"interleaved_sliding_window", None)

self.vision_tower = SiglipVisionModel(config.vision_config,
quant_config,
Expand Down Expand Up @@ -680,13 +681,14 @@ def prepare_attn_masks(
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask)

# Create a local causal mask with sliding window (1024).
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask,
diagonal=-self.sliding_window)
local_attn_mask = torch.where(local_attn_mask == 0,
global_attn_mask, float("-inf"))
local_attn_masks.append(local_attn_mask)
if self.sliding_window is not None:
# Create a local causal mask with sliding window (1024).
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask,
diagonal=-self.sliding_window)
local_attn_mask = torch.where(local_attn_mask == 0,
global_attn_mask, float("-inf"))
local_attn_masks.append(local_attn_mask)
kwargs["global_attn_masks"] = global_attn_masks
kwargs["local_attn_masks"] = local_attn_masks
return kwargs
Expand Down