
Commit 3a960bb

Bump transformers to 4.54.1
1 parent ab6261d commit 3a960bb

5 files changed: +38 additions, -41 deletions

install_dev.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def install_dep_from_source():
             "-m",
             "pip",
             "install",
-            "git+https://github.com/huggingface/transformers@896e9cea1ade521b2648f4798218550f6c72190c#egg=transformers",  # 4.53.1
+            "git+https://github.com/huggingface/transformers@9c641dc16154964e5ffc0c13e9ec6aaffa295ed6#egg=transformers",  # 4.54.1
         ]
     )
     subprocess.check_call(
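
A quick post-install sanity check, as a minimal sketch (not part of this commit), assuming the pinned commit reports the version named in the inline comment above:

import transformers

# The diff's inline comment maps the pinned commit to 4.54.1; print rather than
# assert, since a from-source install may report a dev version string instead.
print(transformers.__version__)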

optimum/executorch/attentions/custom_kv_cache.py

Lines changed: 32 additions & 35 deletions
@@ -54,12 +54,12 @@ def __init__(

         # Create a list of CustomKVCache instances, one per layer
         self.kv_cache = torch.nn.ModuleList()
-        for _ in range(config.num_hidden_layers):
+        for layer in self.layers:
             layer_cache = CustomKVCache(
-                max_batch_size=self.max_batch_size,
-                max_context_length=self.max_cache_len,
-                n_heads=self.num_key_value_heads,
-                head_dim=self.head_dim,
+                max_batch_size=layer.max_batch_size,
+                max_context_length=layer.max_cache_len,
+                n_heads=layer.num_heads,
+                head_dim=layer.head_dim,
                 dtype=dtype,
             )
             self.kv_cache.append(layer_cache)
@@ -202,32 +202,29 @@ def __init__(
             layer_device_map=layer_device_map,
         )

-        # make sure layer_device_map is none
         assert layer_device_map is None
         assert device is None or device == "cpu", "Device must be None or 'cpu'"

         self.cache_position = None
-        # Create a list of cache instances, one per layer
-        # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers
+        # Create a list of cache instances, one per layer.
+        # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers.
         self.kv_cache = torch.nn.ModuleList()
-        for layer_idx in range(config.num_hidden_layers):
-            # newer version of transfomer has is_sliding defined
-            # for HybridCache
-            if self.is_sliding[layer_idx]:
+        for layer in self.layers:
+            if layer.is_sliding():
                 # This is a sliding window layer
                 layer_cache = CustomRingKVCache(
-                    max_batch_size=self.max_batch_size,
-                    max_context_length=self.sliding_window_len,
-                    n_heads=self.num_key_value_heads,
-                    head_dim=self.head_dim,
+                    max_batch_size=layer.max_batch_size,
+                    max_context_length=layer.max_cache_len,
+                    n_heads=layer.num_heads,
+                    head_dim=layer.head_dim,
                     dtype=dtype,
                 )
             else:
                 layer_cache = CustomKVCache(
-                    max_batch_size=self.max_batch_size,
-                    max_context_length=self.max_cache_len,
-                    n_heads=self.num_key_value_heads,
-                    head_dim=self.head_dim,
+                    max_batch_size=layer.max_batch_size,
+                    max_context_length=layer.max_cache_len,
+                    n_heads=layer.num_heads,
+                    head_dim=layer.head_dim,
                     dtype=dtype,
                 )
             self.kv_cache.append(layer_cache)
@@ -284,7 +281,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:

         # For CustomRingKVCache, we need to handle the sequence length differently
         layer_cache = self.kv_cache[layer_idx]
-        if self.is_sliding[layer_idx]:
+        if self.layers[layer_idx].is_sliding():
             # CustomRingKVCache cache_position_manager which
             # maintains cache position for each slot in the kv cache
             # we return the max position + 1 to indicate max position
@@ -308,7 +305,7 @@ def get_layer_cache(self, layer_idx: int):

 def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
     """
-    Replace all KV caches in the module with ETCustomStaticCache.
+    Replace all KV caches in the module with ETCustomStaticCache or ETCustomHybridCache.
     This modifies the model in place.

     Args:
@@ -342,18 +339,18 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dt
         if getattr(module, "replace_cache", None) is not None:
             static_cache = ETCustomStaticCache(
                 config=config,
-                max_batch_size=generation_config.cache_config.batch_size,
-                max_cache_len=generation_config.cache_config.max_cache_len,
-                device=generation_config.cache_config.device,
+                max_batch_size=generation_config.cache_config.get("batch_size"),
+                max_cache_len=generation_config.cache_config.get("max_cache_len"),
+                device=generation_config.cache_config.get("device"),
                 dtype=cache_dtype,
             )
             module.replace_cache(static_cache)
         else:
             module.static_cache = ETCustomStaticCache(
                 config=config,
-                max_batch_size=generation_config.cache_config.batch_size,
-                max_cache_len=generation_config.cache_config.max_cache_len,
-                device=generation_config.cache_config.device,
+                max_batch_size=generation_config.cache_config.get("batch_size"),
+                max_cache_len=generation_config.cache_config.get("max_cache_len"),
+                device=generation_config.cache_config.get("device"),
                 dtype=cache_dtype,
             )
             # Dont know why we need to this even though
@@ -370,25 +367,25 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dt
         if getattr(module, "replace_cache", None) is not None:
             hybrid_cache = ETCustomHybridCache(
                 config=config,
-                max_batch_size=generation_config.cache_config.batch_size,
-                max_cache_len=generation_config.cache_config.max_cache_len,
-                device=generation_config.cache_config.device,
+                max_batch_size=generation_config.cache_config.get("batch_size"),
+                max_cache_len=generation_config.cache_config.get("max_cache_len"),
+                device=generation_config.cache_config.get("device"),
                 dtype=cache_dtype,
             )
             module.replace_cache(hybrid_cache)
         else:
             module.cache = ETCustomHybridCache(
                 config=config,
-                max_batch_size=generation_config.cache_config.batch_size,
-                max_cache_len=generation_config.cache_config.max_cache_len,
-                device=generation_config.cache_config.device,
+                max_batch_size=generation_config.cache_config.get("batch_size"),
+                max_cache_len=generation_config.cache_config.get("max_cache_len"),
+                device=generation_config.cache_config.get("device"),
                 dtype=cache_dtype,
             )
             # Register cache attributes for each layer
             for i in range(len(module.cache.kv_cache)):
                 setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
                 setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
-                if module.cache.is_sliding[i]:
+                if module.cache.layers[i].is_sliding():
                     # Register cache_positions as buffer for sliding window layers
                     # This prevents it from being traced as a constant
                     module.register_buffer(
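
The custom_kv_cache.py changes track the transformers 4.54 cache refactor: per-layer geometry is now read from the cache's layers list (each entry exposing max_batch_size, max_cache_len, num_heads, head_dim, and is_sliding()) rather than from cache-level attributes such as num_key_value_heads or is_sliding[layer_idx]. A minimal sketch of that access pattern, assuming a cache object shaped like the ones in this diff; describe_layers is illustrative and not part of the repository:

def describe_layers(cache):
    # Walk the per-layer cache objects of the layered cache API and report
    # whether each layer is a sliding-window or a global-attention layer.
    for idx, layer in enumerate(cache.layers):
        kind = "sliding" if layer.is_sliding() else "global"
        print(
            f"layer {idx}: {kind}, batch={layer.max_batch_size}, "
            f"ctx={layer.max_cache_len}, heads={layer.num_heads}, head_dim={layer.head_dim}"
        )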

optimum/exporters/executorch/integrations.py

Lines changed: 2 additions & 2 deletions
@@ -395,8 +395,8 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi
         wrapped_decoder = (
             Seq2SeqLMDecoderExportableModuleWithStaticCache(
                 model=self.full_model,
-                max_static_cache_length=self.generation_config.cache_config.max_cache_len,
-                batch_size=self.generation_config.cache_config.batch_size,
+                max_static_cache_length=self.generation_config.cache_config.get("max_cache_len"),
+                batch_size=self.generation_config.cache_config.get("batch_size"),
             )
             .to("cpu")
             .eval()

optimum/exporters/executorch/utils.py

Lines changed: 2 additions & 2 deletions
@@ -53,8 +53,8 @@ def save_config_to_constant_methods(
     # Check for cache_config and its attributes
     cache_config = getattr(generation_config, "cache_config", None)
     if cache_config is not None:
-        max_batch_size = getattr(cache_config, "batch_size", None)
-        max_seq_len = getattr(cache_config, "max_cache_len", None)
+        max_batch_size = cache_config.get("batch_size")
+        max_seq_len = cache_config.get("max_cache_len")

         if max_batch_size is not None:
             metadata["get_max_batch_size"] = max_batch_size
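
Both integrations.py and utils.py switch from attribute access (cache_config.batch_size, getattr(cache_config, ...)) to dict-style cache_config.get(...), reflecting that generation_config.cache_config is dict-like in transformers 4.54.x. A hedged compatibility sketch; read_cache_config is a hypothetical helper, not part of this repo:

def read_cache_config(cache_config, key, default=None):
    # Tolerate both the old attribute-style cache_config and the new dict-like one.
    if cache_config is None:
        return default
    if isinstance(cache_config, dict):
        return cache_config.get(key, default)
    return getattr(cache_config, key, default)

# Usage: read_cache_config(generation_config.cache_config, "max_cache_len")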

setup.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 INSTALL_REQUIRE = [
     "optimum~=1.24",
     "executorch>=0.6.0",
-    "transformers==4.51.3",
+    "transformers==4.54.1",
 ]

 TESTS_REQUIRE = [
