fix prefill
Blaizzy committed Jan 18, 2025
1 parent 284504c commit 6b9e8d8
Showing 3 changed files with 115 additions and 7 deletions.
17 changes: 16 additions & 1 deletion mlx_vlm/generate.py
@@ -89,6 +89,18 @@ def parse_arguments():
help="Ratio of visual tokens to keep during filtering topk tokens (between 0.1 and 1.0).",
choices=[x / 10 for x in range(1, 11)],
)
parser.add_argument(
"--max-kv-size",
type=int,
default=None,
help="Set the maximum key-value cache size",
)
parser.add_argument(
"--prefill-step-size",
type=int,
default=256,
help="Set the prefill step size",
)
return parser.parse_args()


@@ -113,6 +125,9 @@ def main():
prompt = apply_chat_template(processor, config, prompt, num_images=len(args.image))

kwargs = {}

if args.max_kv_size is not None:
kwargs["max_kv_size"] = args.max_kv_size
if args.resize_shape is not None:
resize_shape = args.resize_shape
if len(resize_shape) not in [1, 2]:
@@ -153,7 +168,7 @@ def main():
prompt,
args.image,
max_tokens=args.max_tokens,
temp=args.temp,
temperature=args.temperature,
**kwargs,
):
response += chunk.text
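
For reference, a hypothetical command line exercising the two new flags; the model path, image, and prompt below are placeholders, not part of this commit:

    python -m mlx_vlm.generate --model <path-or-hf-repo> --image demo.png \
        --prompt "Describe this image." --max-kv-size 1024 --prefill-step-size 256
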
93 changes: 90 additions & 3 deletions mlx_vlm/models/base.py
@@ -98,6 +98,10 @@ def update(self, keys, values):
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values

@property
def state(self):
return self.keys, self.values


class SimpleKVCache:
"""A simple key-value cache for transformer attention layers.
@@ -149,15 +153,15 @@ def update(self, keys, values):

class RotatingKVCache:

def __init__(self, head_dim, n_kv_heads, max_size, keep=4, step=256):
def __init__(self, head_dim, n_kv_heads, max_size, keep=None, step=256):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
self.k_head_dim, self.v_head_dim = head_dim
else:
raise ValueError("head_dim must be an int or a tuple of two ints")
self.keep = keep
self.keep = keep if keep is not None else step // 2
self.keys = None
self.values = None
self.offset = 0
@@ -280,7 +284,7 @@ def __init__(self):
self.vision_tower = None
self.language_model = None

def prefill(self, input_embeds, cache=None, prefill_step_size=256, **kwargs):
def prefill(self, input_embeds, cache=None, prefill_step_size=256):
# Process input in batches for better parallelization
num_batches = (
input_embeds.shape[1] + prefill_step_size - 1
Expand Down Expand Up @@ -308,3 +312,86 @@ def prefill(self, input_embeds, cache=None, prefill_step_size=256, **kwargs):
return remaining_embeds

return input_embeds
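
The collapsed region above holds the actual chunking loop; the sketch below is a minimal, assumed reconstruction of the idea only (the step callable, the loop shape, and the use of the new cache state property are illustrative, not the committed code):

    import mlx.core as mx

    def chunked_prefill(step_fn, input_embeds, cache, prefill_step_size=256):
        # step_fn is assumed to run the language model on a chunk of embeddings
        # and update the per-layer caches in place.
        while input_embeds.shape[1] > prefill_step_size:
            step_fn(input_embeds[:, :prefill_step_size], cache)
            mx.eval([c.state for c in cache])  # materialize each chunk before the next
            input_embeds = input_embeds[:, prefill_step_size:]
        # Whatever remains is consumed by the first decode step.
        return input_embeds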

def get_topk_tokens(self, image_feature, attn, dominant_tokens_ratio=None):
batch_size, seq_len = image_feature.shape[:2]

k_tokens = (
int(image_feature.shape[1] * dominant_tokens_ratio)
if dominant_tokens_ratio is not None
else None
) # e.g. keep 25% of the visual tokens when dominant_tokens_ratio is 0.25
if k_tokens is None:
return image_feature
cls_idx = 0 # self.config.image_token_index

attn_rec = mx.sum(attn[:, :, cls_idx + 1 :, cls_idx], axis=1)

topk_idx = mx.argsort(attn_rec, axis=1)[:, -k_tokens:]
# use this to plot the dominant attention map
# https://github.com/dvlab-research/VisionZip/blob/demo-chat/llava/model/multimodal_encoder/clip_encoder.py#L62
# https://github.com/dvlab-research/VisionZip/blob/demo-chat/llava/serve/gradio_web_server.py#L424

# Create CLS token indices array
# Shape: (B, 1)
cls_indices = mx.full((batch_size, 1), cls_idx, dtype=mx.int32)

# Concat with CLS token index
# Add 1 to account for the offset after CLS token
dominant_idx = mx.concatenate([cls_indices, topk_idx + cls_idx + 1], axis=1)

image_feature = mx.take(image_feature, dominant_idx, axis=1)[0]
return image_feature
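
A hedged usage sketch for the new helper; the shapes, the random attention tensor, and the loaded model object are illustrative assumptions:

    import mlx.core as mx

    # 1 CLS token + 576 patch tokens, hidden size 1024, plus a per-head attention map.
    image_feature = mx.random.normal((1, 577, 1024))
    attn = mx.random.uniform(shape=(1, 16, 577, 577))

    # Keep the CLS token plus roughly 25% of the patch tokens, ranked by how
    # strongly they attend to the CLS position (summed over heads).
    pruned = model.get_topk_tokens(image_feature, attn, dominant_tokens_ratio=0.25)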

def merge_similar_visual_tokens(
self, image_feature, visual_token_ratio, merge_ratio=0.4
):
# Skip CLS token (first token)
tokens = image_feature[:, 1:]
batch_size, num_tokens, hidden_dim = tokens.shape

# Calculate target number of tokens
target_tokens = max(1, int(num_tokens * visual_token_ratio))

while num_tokens > target_tokens:
# Calculate similarities between adjacent tokens
tokens_a = tokens[:, :-1] # all except last
tokens_b = tokens[:, 1:] # all except first

# Calculate cosine similarity
a_norm = mx.sqrt(mx.sum(tokens_a * tokens_a, axis=-1, keepdims=True))
b_norm = mx.sqrt(mx.sum(tokens_b * tokens_b, axis=-1, keepdims=True))
similarities = mx.sum(tokens_a * tokens_b, axis=-1)
similarities = similarities / (a_norm.squeeze(-1) * b_norm.squeeze(-1))

# Sort similarities and get indices of pairs to merge
# Merge a merge_ratio fraction (default 0.4) of the remaining excess tokens in each iteration
num_to_merge = max(1, int((num_tokens - target_tokens) * merge_ratio))
merge_indices = mx.argsort(similarities, axis=-1)[:, -num_to_merge:]

# Create a list to track which indices to merge
to_merge = set(merge_indices[0].tolist())

# Merge selected pairs
new_tokens = []
i = 0
while i < num_tokens:
if i < num_tokens - 1 and i in to_merge:
# Merge this token with the next one
merged = (tokens[:, i : i + 1] + tokens[:, i + 1 : i + 2]) / 2
new_tokens.append(merged)
i += 2
elif i > 0 and (i - 1) in to_merge:
# Skip this token as it was merged in the previous step
i += 1
else:
# Keep this token as is
new_tokens.append(tokens[:, i : i + 1])
i += 1

# Update tokens
tokens = mx.concatenate(new_tokens, axis=1)
num_tokens = tokens.shape[1]

# Reattach CLS token
return mx.concatenate([image_feature[:, :1], tokens], axis=1)
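
A similarly hedged sketch for the adjacent-token merging; random features, illustrative shapes, and the same assumed model object:

    import mlx.core as mx

    features = mx.random.normal((1, 577, 1024))  # CLS token + 576 patch tokens

    # Repeatedly average the most similar neighbouring patch tokens until roughly
    # 30% of them remain; the CLS token at index 0 is never merged.
    reduced = model.merge_similar_visual_tokens(features, visual_token_ratio=0.3)
    print(reduced.shape)  # approximately (1, 1 + 172, 1024)
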
12 changes: 9 additions & 3 deletions mlx_vlm/utils.py
@@ -848,7 +848,7 @@ def generate_step(
repetition_context_size: Optional[int] = 20,
top_p: float = 1.0,
logit_bias: Optional[Dict[int, float]] = None,
max_size: Optional[int] = None,
max_kv_size: Optional[int] = None,
**kwargs,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
@@ -866,6 +866,7 @@
top_p (float, optional): Nucleus sampling, higher means model considers
more less likely words.
logit_bias (dictionary, optional): Additive logit bias.
max_kv_size (int, optional): Set the maximum key-value cache size.
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
@@ -910,11 +911,16 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]:
(SimpleKVCache(), SimpleKVCache()) for n in model.language_model.layers
]
else:
if max_size is None:
if max_kv_size is None:
cache = [KVCache(model.language_model.head_dim, n) for n in kv_heads]
else:
cache = [
RotatingKVCache(model.language_model.head_dim, n, max_size=max_size)
RotatingKVCache(
model.language_model.head_dim,
n,
max_size=max_kv_size,
keep=max_kv_size // 2 if pixel_values is None else 4,
)
for n in kv_heads
]

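At the Python API level the new keyword would travel through the generation kwargs into generate_step; a hedged sketch follows, where the import paths, the config handle, and the streaming call are assumptions inferred from how generate.py uses them in this diff:

    from mlx_vlm import load
    from mlx_vlm.prompt_utils import apply_chat_template
    from mlx_vlm.utils import stream_generate

    model, processor = load("<path-or-hf-repo>")
    prompt = apply_chat_template(processor, model.config, "Describe this image.", num_images=1)

    # max_kv_size reaches generate_step, which then builds bounded RotatingKVCache
    # instances instead of unbounded KVCache instances.
    for chunk in stream_generate(model, processor, prompt, ["demo.png"],
                                 max_tokens=256, max_kv_size=1024):
        print(chunk.text, end="")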
