|
40 | 40 |
|
41 | 41 | import { |
42 | 42 | AutoConfig, |
43 | | - getKeyValueShapes, |
| 43 | + getCacheShapes, |
44 | 44 | } from './configs.js'; |
45 | 45 |
|
46 | 46 | import { |
@@ -318,7 +318,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) { |
318 | 318 | } |
319 | 319 |
|
320 | 320 | if (selectedDevice === 'webgpu') { |
321 | | - const shapes = getKeyValueShapes(options.config, { |
| 321 | + const shapes = getCacheShapes(options.config, { |
322 | 322 | prefix: 'present', |
323 | 323 | }); |
324 | 324 | if (Object.keys(shapes).length > 0 && !isONNXProxy()) { |
@@ -1960,7 +1960,9 @@ export class PreTrainedModel extends Callable { |
1960 | 1960 |
|
1961 | 1961 | for (const name in decoderResults) { |
1962 | 1962 | if (name.startsWith('present')) { |
1963 | | - const newName = name.replace('present', 'past_key_values'); |
| 1963 | + const newName = name |
| 1964 | + .replace('present_conv', 'past_conv') // Hybrid cache architecture (e.g., LFM2) |
| 1965 | + .replace('present', 'past_key_values'); |
1964 | 1966 | const is_encoder_pkv = name.includes('encoder'); |
1965 | 1967 | if (is_encoder_pkv && pastKeyValues) { |
1966 | 1968 | // Optimization introduced by optimum to reuse past key values. |
@@ -2017,14 +2019,14 @@ export class PreTrainedModel extends Callable { |
2017 | 2019 | Object.assign(decoderFeeds, pastKeyValues) |
2018 | 2020 | } else { |
2019 | 2021 | const session = this.sessions['decoder_model_merged'] ?? this.sessions['model']; |
2020 | | - const dtype = session?.config?.kv_cache_dtype ?? 'float32'; |
2021 | | - const empty = (dtype === 'float16') ? new DataTypeMap.float16() : []; |
2022 | | - |
2023 | 2022 | const batch_size = (decoderFeeds[this.main_input_name] ?? decoderFeeds.attention_mask)?.dims?.[0] ?? 1; |
2024 | | - const shapes = getKeyValueShapes(this.config, { batch_size }); |
2025 | 2023 |
|
| 2024 | + const dtype = session?.config?.kv_cache_dtype ?? 'float32'; |
| 2025 | + const cls = (dtype === 'float16') ? DataTypeMap.float16 : DataTypeMap.float32; |
| 2026 | + const shapes = getCacheShapes(this.config, { batch_size }); |
2026 | 2027 | for (const name in shapes) { |
2027 | | - decoderFeeds[name] = new Tensor(dtype, empty, shapes[name]); |
| 2028 | + const size = shapes[name].reduce((a, b) => a * b, 1); |
| 2029 | + decoderFeeds[name] = new Tensor(dtype, new cls(size), shapes[name]); |
2028 | 2030 | } |
2029 | 2031 | } |
2030 | 2032 | } |
@@ -4586,6 +4588,13 @@ export class LlamaModel extends LlamaPreTrainedModel { } |
4586 | 4588 | export class LlamaForCausalLM extends LlamaPreTrainedModel { } |
4587 | 4589 | ////////////////////////////////////////////////// |
4588 | 4590 |
|
| 4591 | +////////////////////////////////////////////////// |
| 4592 | +// LFM2 models: hybrid cache architecture (outputs named 'present_conv*' are fed back as 'past_conv*' by PreTrainedModel); the classes below have empty bodies and are registered under the 'lfm2' model type |
| 4593 | +export class Lfm2PreTrainedModel extends PreTrainedModel { } |
| 4594 | +export class Lfm2Model extends Lfm2PreTrainedModel { } |
| 4595 | +export class Lfm2ForCausalLM extends Lfm2PreTrainedModel { } |
| 4596 | +////////////////////////////////////////////////// |
| 4597 | + |
4589 | 4598 | ////////////////////////////////////////////////// |
4590 | 4599 | // SmolLM3 models |
4591 | 4600 | export class SmolLM3PreTrainedModel extends PreTrainedModel { } |
@@ -7803,6 +7812,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([ |
7803 | 7812 | ['gpt_neox', ['GPTNeoXModel', GPTNeoXModel]], |
7804 | 7813 | ['codegen', ['CodeGenModel', CodeGenModel]], |
7805 | 7814 | ['llama', ['LlamaModel', LlamaModel]], |
| 7815 | + ['lfm2', ['Lfm2Model', Lfm2Model]], |
7806 | 7816 | ['smollm3', ['SmolLM3Model', SmolLM3Model]], |
7807 | 7817 | ['exaone', ['ExaoneModel', ExaoneModel]], |
7808 | 7818 | ['olmo', ['OlmoModel', OlmoModel]], |
@@ -7908,6 +7918,7 @@ const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([ |
7908 | 7918 | ['gpt_neox', ['GPTNeoXForCausalLM', GPTNeoXForCausalLM]], |
7909 | 7919 | ['codegen', ['CodeGenForCausalLM', CodeGenForCausalLM]], |
7910 | 7920 | ['llama', ['LlamaForCausalLM', LlamaForCausalLM]], |
| 7921 | + ['lfm2', ['Lfm2ForCausalLM', Lfm2ForCausalLM]], |
7911 | 7922 | ['smollm3', ['SmolLM3ForCausalLM', SmolLM3ForCausalLM]], |
7912 | 7923 | ['exaone', ['ExaoneForCausalLM', ExaoneForCausalLM]], |
7913 | 7924 | ['olmo', ['OlmoForCausalLM', OlmoForCausalLM]], |
|
0 commit comments