
Commit d89e18d

Fix sliding window, print loaded ops
1 parent 35fc918 · commit d89e18d

File tree: 2 files changed (+7 −3 lines)

optimum/executorch/attentions/custom_kv_cache.py

Lines changed: 3 additions & 3 deletions
@@ -210,7 +210,7 @@ def __init__(
         # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers.
         self.kv_cache = torch.nn.ModuleList()
         for layer in self.layers:
-            if layer.is_sliding():
+            if layer.is_sliding:
                 # This is a sliding window layer
                 layer_cache = CustomRingKVCache(
                     max_batch_size=layer.max_batch_size,
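
The change from `layer.is_sliding()` to `layer.is_sliding` (repeated in the two hunks below) indicates the cache layer objects expose sliding-window status as a plain boolean attribute, not a method, so the old call would fail at runtime. A minimal sketch of the failure mode, with illustrative stand-in classes (the real layer types live in transformers/optimum-executorch):

```python
# Illustrative stand-ins for the cache layer objects; these only model the
# attribute-vs-method bug, not the real classes.
class GlobalLayer:
    is_sliding = False  # plain attribute, not a callable

class SlidingLayer:
    is_sliding = True

for layer in [GlobalLayer(), SlidingLayer()]:
    # The pre-fix code did `layer.is_sliding()`, which raises
    # TypeError: 'bool' object is not callable.
    if layer.is_sliding:
        print("sliding window layer -> CustomRingKVCache")
    else:
        print("global layer -> CustomKVCache")
```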
@@ -281,7 +281,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
 
         # For CustomRingKVCache, we need to handle the sequence length differently
         layer_cache = self.kv_cache[layer_idx]
-        if self.layers[layer_idx].is_sliding():
+        if self.layers[layer_idx].is_sliding:
             # CustomRingKVCache cache_position_manager which
             # maintains cache position for each slot in the kv cache
             # we return the max position + 1 to indicate max position
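
The surrounding comment describes how a ring cache reports its length: each slot records the absolute position last written into it, and the effective sequence length is the maximum recorded position plus one. The real `cache_position_manager` is not shown in this diff, so the sketch below assumes a `cache_positions` tensor where unwritten slots are marked -1:

```python
import torch

def ring_seq_length(cache_positions: torch.Tensor) -> int:
    """Effective sequence length of a ring KV cache: max written position + 1."""
    # Assumption: unwritten slots hold -1, so an empty cache reports 0.
    return int(cache_positions.max().item()) + 1

# A ring of 8 slots after 11 tokens: positions 3..10 survive, 0..2 were evicted.
positions = torch.tensor([8, 9, 10, 3, 4, 5, 6, 7])
print(ring_seq_length(positions))  # 11
```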
@@ -385,7 +385,7 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dt
     for i in range(len(module.cache.kv_cache)):
         setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
         setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
-        if module.cache.layers[i].is_sliding():
+        if module.cache.layers[i].is_sliding:
             # Register cache_positions as buffer for sliding window layers
             # This prevents it from being traced as a constant
             module.register_buffer(
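
The comment on this branch explains why `register_buffer` matters: a tensor stored as a plain Python attribute gets baked into the exported graph as a frozen constant, while a registered buffer is exported as module state the runtime can update. A minimal sketch of the pattern, with an invented module and shape for illustration:

```python
import torch

class CacheHolder(torch.nn.Module):
    def __init__(self, cache_len: int = 128):
        super().__init__()
        # As a buffer, cache_positions survives export as mutable state; as a
        # plain attribute it would be traced as a constant and the ring
        # positions could never advance.
        self.register_buffer("cache_positions", torch.zeros(cache_len, dtype=torch.long))

    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        return self.cache_positions.index_select(0, pos)
```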

optimum/executorch/modeling.py

Lines changed: 4 additions & 0 deletions
@@ -186,6 +186,10 @@ def _from_pretrained(
             subfolder=subfolder,
             local_files_only=local_files_only,
         )
+        from executorch.extension.pybindings.portable_lib import _get_operator_names
+        print("----------- LOADED OPS ----------")
+        print('\n'.join(_get_operator_names()))
+        print("---------------------------------")
         model = _load_for_executorch(model_cache_path)
         logging.info(
             f"Loaded model from {model_cache_path} ({os.path.getsize(model_cache_path) / (1024 * 1024):.2f} MB)"
