Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Florence-2 #105

Merged
merged 40 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
6812665
loading florence2
Blaizzy Oct 23, 2024
695a96c
update enc-decoder flow
Blaizzy Oct 24, 2024
d9d27fa
fix embeddings and update cache
Blaizzy Oct 31, 2024
b064000
working language model florence-2 yay!!!
Blaizzy Nov 3, 2024
b370ffd
Merge branch 'main' into pc/florence-2
Blaizzy Nov 3, 2024
9502562
Florecen-2 LM complete
Blaizzy Nov 3, 2024
1a54af9
Merge branch 'pc/florence-2' of https://github.com/Blaizzy/mlx-vlm in…
Blaizzy Nov 3, 2024
48c9dc6
add simpleKVCache
Blaizzy Nov 3, 2024
4040aed
vision model 90% done
Blaizzy Nov 6, 2024
9c30c90
fix image encoding
Blaizzy Nov 6, 2024
d76b811
update GELU and format code
Blaizzy Nov 7, 2024
8d5724c
add new grouped conv and numerical close to HF
Blaizzy Nov 8, 2024
a8faa79
fix depthwise conv, and add image encoding
Blaizzy Nov 9, 2024
eef0594
expand generate to support florence2
Blaizzy Nov 13, 2024
6dc7156
update embeddings to fix quant errors
Blaizzy Nov 13, 2024
bae6a9d
fix inference and weight sanitization(vision)
Blaizzy Nov 13, 2024
9111eb1
Merge branch 'main' into pc/florence-2
Blaizzy Nov 13, 2024
4e2c593
Merge branch 'main' into pc/florence-2
Blaizzy Nov 21, 2024
a2f3872
Merge branch 'main' into pc/florence-2
Blaizzy Nov 21, 2024
8232e63
remove unused
Blaizzy Nov 21, 2024
b829b12
fix formatting
Blaizzy Nov 21, 2024
dcbd3b1
loading florence2
Blaizzy Oct 23, 2024
7270cd5
update enc-decoder flow
Blaizzy Oct 24, 2024
f544bdd
fix embeddings and update cache
Blaizzy Oct 31, 2024
194518e
working language model florence-2 yay!!!
Blaizzy Nov 3, 2024
31e6b19
Florecen-2 LM complete
Blaizzy Nov 3, 2024
c07eee3
add simpleKVCache
Blaizzy Nov 3, 2024
3e12cb5
vision model 90% done
Blaizzy Nov 6, 2024
d17f9d6
fix image encoding
Blaizzy Nov 6, 2024
35d8646
update GELU and format code
Blaizzy Nov 7, 2024
fbfcb6a
add new grouped conv and numerical close to HF
Blaizzy Nov 8, 2024
636d9ba
fix depthwise conv, and add image encoding
Blaizzy Nov 9, 2024
8ce7adc
expand generate to support florence2
Blaizzy Nov 13, 2024
2dcec73
update embeddings to fix quant errors
Blaizzy Nov 13, 2024
0201536
fix inference and weight sanitization(vision)
Blaizzy Nov 13, 2024
64d3076
remove unused
Blaizzy Nov 21, 2024
833d13a
fix formatting
Blaizzy Nov 21, 2024
529a01c
Merge branch 'pc/florence-2' of https://github.com/Blaizzy/mlx-vlm in…
Blaizzy Nov 21, 2024
5c912a5
fix repetition penalty
Blaizzy Nov 21, 2024
34bb063
add tests and formatting
Blaizzy Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 52 additions & 3 deletions mlx_vlm/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def preprocess(self, images):

class KVCache:

def __init__(self, head_dim, n_kv_heads):
def __init__(self, head_dim, n_kv_heads, step=256):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
Expand All @@ -66,9 +66,13 @@ def __init__(self, head_dim, n_kv_heads):
self.keys = None
self.values = None
self.offset = 0
self.step = 256
self.step = step

def update_and_fetch(self, keys, values):
self.update(keys, values)
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

def update(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
n_steps = (self.step + keys.shape[2] - 1) // self.step
Expand All @@ -88,7 +92,51 @@ def update_and_fetch(self, keys, values):
self.offset += keys.shape[2]
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]


class SimpleKVCache:
"""A simple key-value cache for transformer attention layers.

Stores and concatenates key/value tensors along sequence dimension.
"""

def __init__(self):
self.keys = None
self.values = None
self.cache_length = 0

def update_and_fetch(self, keys, values):
"""Update cache with new key/value tensors and return full cache.

Args:
keys: New key tensor to add [batch, heads, seq_len, head_dim]
values: New value tensor to add [batch, heads, seq_len, head_dim]

Returns:
Tuple of (cached_keys, cached_values) containing full cache history
"""
if self.cache_length == 0:
# First update - just store tensors
self.keys = keys
self.values = values
else:
# Concatenate with existing cache along sequence dimension
self.keys = mx.concatenate([self.keys, keys], axis=2)
self.values = mx.concatenate([self.values, values], axis=2)

self.cache_length += keys.shape[2]
return self.keys, self.values

def update(self, keys, values):
"""Update cache with new key/value tensors without returning.

Args:
keys: New key tensor to store
values: New value tensor to store
"""
self.keys = keys
self.values = values
self.cache_length += keys.shape[2]


class RotatingKVCache:
Expand Down Expand Up @@ -212,3 +260,4 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
class LanguageModelOutput:
logits: mx.array
cross_attention_states: Optional[List[mx.array]] = None
encoder_outputs: Optional[List[mx.array]] = None
8 changes: 8 additions & 0 deletions mlx_vlm/models/florence2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .florence2 import (
LanguageModel,
Model,
ModelConfig,
TextConfig,
VisionConfig,
VisionModel,
)
Loading
Loading