Skip to content

Commit

Permalink
dynamic sliding window
Browse files: browse the repository at this point in the history
  • Loading branch information
Blaizzy committed Jan 18, 2025
1 parent 6a98147 commit b3aae82
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
9 changes: 8 additions & 1 deletion mlx_vlm/models/llava/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class TextConfig:
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False
sliding_window: int = None

@classmethod
def from_dict(cls, params):
Expand Down Expand Up @@ -51,9 +52,9 @@ def __init__(self, config: TextConfig):
super().__init__()

dim = config.hidden_size
self.config = config
self.n_heads = n_heads = config.num_attention_heads
self.n_kv_heads = n_kv_heads = config.num_key_value_heads

self.repeats = n_heads // n_kv_heads

head_dim = config.hidden_size // n_heads
Expand Down Expand Up @@ -102,6 +103,12 @@ def __call__(
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
if self.config.sliding_window:
print(self.config.sliding_window)
keys = keys[:, :, self.config.sliding_window :, :]
values = values[:, :, self.config.sliding_window :, :]
if mask is not None:
mask = mask[:, self.config.sliding_window :]
else:
queries = self.rope(queries)
keys = self.rope(keys)
Expand Down
1 change: 1 addition & 0 deletions mlx_vlm/models/llava/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def __call__(
inputs_embeds = self.prefill(
inputs_embeds, cache=cache, prefill_step_size=prefill_step_size
)
self.config.text_config.sliding_window = 4096

logits = self.language_model(
input_ids, cache=cache, inputs_embeds=inputs_embeds
Expand Down

0 comments on commit b3aae82

Please sign in to comment.