/**
* @file Definitions of all models available in Transformers.js.
*
* **Example:** Load and run an `AutoModel`.
*
* ```javascript
* import { AutoModel, AutoTokenizer } from '@huggingface/transformers';
*
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
* let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');
*
* let inputs = await tokenizer('I love transformers!');
* let { logits } = await model(inputs);
* // Tensor {
* // data: Float32Array(183132) [-7.117443084716797, -7.107812881469727, -7.092104911804199, ...]
* // dims: (3) [1, 6, 30522],
* // type: "float32",
* // size: 183132,
* // }
* ```
*
* We also provide other `AutoModel`s (listed below), which you can use in the same way as in the Python library. For example:
*
* **Example:** Load and run an `AutoModelForSeq2SeqLM`.
* ```javascript
* import { AutoModelForSeq2SeqLM, AutoTokenizer } from '@huggingface/transformers';
*
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');
* let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small');
*
* let { input_ids } = await tokenizer('translate English to German: I love transformers!');
* let outputs = await model.generate(input_ids);
* let decoded = tokenizer.decode(outputs[0], { skip_special_tokens: true });
* // 'Ich liebe Transformatoren!'
* ```
*
* @module models
*/
import {
AutoConfig,
getKeyValueShapes,
} from './configs.js';
import {
deviceToExecutionProviders,
createInferenceSession,
isONNXTensor,
isONNXProxy,
} from './backends/onnx.js';
import {
DATA_TYPES,
DEFAULT_DEVICE_DTYPE_MAPPING,
DEFAULT_DTYPE_SUFFIX_MAPPING,
isWebGpuFp16Supported,
} from './utils/dtypes.js';
import {
Callable,
} from './utils/generic.js';
import {
isIntegralNumber,
mergeArrays,
pick,
} from './utils/core.js';
import {
getModelFile,
getModelJSON,
} from './utils/hub.js';
import {
GITHUB_ISSUE_URL,
} from './utils/constants.js';
import {
LogitsProcessorList,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
WhisperTimeStampLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
NoBadWordsLogitsProcessor,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
ClassifierFreeGuidanceLogitsProcessor,
} from './generation/logits_process.js';
import {
GenerationConfig,
} from './generation/configuration_utils.js';
import {
cat,
full_like,
mean,
ones,
ones_like,
stack,
std_mean,
Tensor,
zeros_like,
} from './utils/tensor.js';
import { dynamic_time_warping, medianFilter } from './utils/maths.js';
import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
import { LogitsSampler } from './generation/logits_sampler.js';
import { apis } from './env.js';
import { WhisperGenerationConfig } from './models/whisper/generation_whisper.js';
import { whisper_language_to_code } from './models/whisper/common_whisper.js';
//////////////////////////////////////////////////
// Model types: used internally
const MODEL_TYPES = {
EncoderOnly: 0,
EncoderDecoder: 1,
Seq2Seq: 2,
Vision2Seq: 3,
DecoderOnly: 4,
MaskGeneration: 5,
ImageTextToText: 6,
Musicgen: 7,
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
// Helper functions
// NOTE: These will be populated fully later
const MODEL_TYPE_MAPPING = new Map();
const MODEL_NAME_TO_CLASS_MAPPING = new Map();
const MODEL_CLASS_TO_NAME_MAPPING = new Map();
/**
* Constructs an InferenceSession using a model file located at the specified path.
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
* @param {string} fileName The name of the model file.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
* @returns {Promise<{buffer: Uint8Array, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
* @private
*/
async function getSession(pretrained_model_name_or_path, fileName, options) {
const custom_config = options.config?.['transformers.js_config'] ?? {};
let device = options.device ?? custom_config.device;
if (device && typeof device !== 'string') {
if (device.hasOwnProperty(fileName)) {
device = device[fileName];
} else {
console.warn(`device not specified for "${fileName}". Using the default device.`);
device = null;
}
}
// If the device is not specified, we use the default (supported) execution providers.
const selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */(
device ?? (apis.IS_NODE_ENV ? 'cpu' : 'wasm')
);
const executionProviders = deviceToExecutionProviders(selectedDevice);
// If options.dtype is specified, we use it to choose the suffix for the model file.
// Otherwise, we use the default dtype for the device.
let dtype = options.dtype ?? custom_config.dtype;
if (typeof dtype !== 'string') {
if (dtype && dtype.hasOwnProperty(fileName)) {
dtype = dtype[fileName];
} else {
dtype = DEFAULT_DEVICE_DTYPE_MAPPING[selectedDevice] ?? DATA_TYPES.fp32;
console.warn(`dtype not specified for "${fileName}". Using the default dtype (${dtype}) for this device (${selectedDevice}).`);
}
}
const selectedDtype = /** @type {import("./utils/dtypes.js").DataType} */(dtype);
if (!DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(selectedDtype)) {
throw new Error(`Invalid dtype: ${selectedDtype}. Should be one of: ${Object.keys(DATA_TYPES).join(', ')}`);
} else if (selectedDtype === DATA_TYPES.fp16 && selectedDevice === 'webgpu' && !(await isWebGpuFp16Supported())) {
throw new Error(`The device (${selectedDevice}) does not support fp16.`);
}
// Only valid for models with a decoder
const kv_cache_dtype = custom_config.kv_cache_dtype
? (typeof custom_config.kv_cache_dtype === 'string'
? custom_config.kv_cache_dtype
: custom_config.kv_cache_dtype[selectedDtype] ?? 'float32')
: undefined;
if (kv_cache_dtype && !['float32', 'float16'].includes(kv_cache_dtype)) {
throw new Error(`Invalid kv_cache_dtype: ${kv_cache_dtype}. Should be one of: float32, float16`);
}
const session_config = {
dtype: selectedDtype,
kv_cache_dtype,
}
// Construct the model file name
const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`;
const session_options = { ...options.session_options };
// Overwrite `executionProviders` if not specified
session_options.executionProviders ??= executionProviders;
// Overwrite `freeDimensionOverrides` if specified in config and not set in session options
const free_dimension_overrides = custom_config.free_dimension_overrides;
if (free_dimension_overrides) {
session_options.freeDimensionOverrides ??= free_dimension_overrides;
} else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) {
console.warn(
'WebNN does not currently support dynamic shapes and requires `free_dimension_overrides` to be set in config.json as a field within "transformers.js_config". ' +
'When `free_dimension_overrides` is not set, you may experience significant performance degradation.'
);
}
const bufferPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options);
// Handle ONNX external data files
const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
/** @type {Promise<{path: string, data: Uint8Array}>[]} */
let externalDataPromises = [];
if (use_external_data_format && (
use_external_data_format === true ||
(
typeof use_external_data_format === 'object' &&
use_external_data_format.hasOwnProperty(fileName) &&
use_external_data_format[fileName] === true
)
)) {
if (apis.IS_NODE_ENV) {
throw new Error('External data format is not yet supported in Node.js');
}
const path = `${fileName}${suffix}.onnx_data`;
const fullPath = `${options.subfolder ?? ''}/${path}`;
// Use an async IIFE instead of an async promise executor so that errors from
// `getModelFile` reject the pushed promise instead of being swallowed.
externalDataPromises.push((async () => {
const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options);
return { path, data };
})());
} else if (session_options.externalData !== undefined) {
externalDataPromises = session_options.externalData.map(async (ext) => {
// if the external data is a string, fetch the file and replace the string with its content
if (typeof ext.data === "string") {
const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options);
return { ...ext, data: ext_buffer };
}
return ext;
});
}
if (externalDataPromises.length > 0) {
session_options.externalData = await Promise.all(externalDataPromises);
}
if (selectedDevice === 'webgpu') {
const shapes = getKeyValueShapes(options.config, {
prefix: 'present',
});
if (Object.keys(shapes).length > 0 && !isONNXProxy()) {
// Only set preferredOutputLocation if shapes are present and we aren't proxying ONNX
/** @type {Record<string, import('onnxruntime-common').Tensor.DataLocation>} */
const preferredOutputLocation = {};
for (const key in shapes) {
preferredOutputLocation[key] = 'gpu-buffer';
}
session_options.preferredOutputLocation = preferredOutputLocation;
}
}
const buffer = await bufferPromise;
return { buffer, session_options, session_config };
}
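// Illustrative sketch (hypothetical model id): `device` and `dtype` may be given either as a
// single string or as an object keyed by model file name, which `getSession` resolves per file:
//
//   const model = await AutoModelForSeq2SeqLM.from_pretrained('user/model', {
//       device: { encoder_model: 'webgpu', decoder_model_merged: 'wasm' },
//       dtype: { encoder_model: 'fp16', decoder_model_merged: 'fp32' },
//   });
//
// Files not listed in these objects fall back to the defaults resolved above (with a console warning).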
/**
* Helper function to create multiple InferenceSession objects.
*
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
* @param {Record<string, string>} names The names of the model files to load.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
* @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of InferenceSession objects.
* @private
*/
async function constructSessions(pretrained_model_name_or_path, names, options) {
return Object.fromEntries(await Promise.all(
Object.keys(names).map(async (name) => {
const { buffer, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
const session = await createInferenceSession(buffer, session_options, session_config);
return [name, session];
})
));
}
/**
* Helper function to load multiple optional configuration files
* @param {string} pretrained_model_name_or_path The path to the directory containing the config file.
* @param {Record<string, string>} names The names of the config files to load.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the configs.
* @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of configuration objects.
* @private
*/
async function getOptionalConfigs(pretrained_model_name_or_path, names, options) {
return Object.fromEntries(await Promise.all(
Object.keys(names).map(async (name) => {
const config = await getModelJSON(pretrained_model_name_or_path, names[name], false, options);
return [name, config];
})
));
}
/**
* Validate model inputs
* @param {Object} session The InferenceSession object that will be run.
* @param {Object} inputs The inputs to check.
* @returns {Record<string, Tensor>} The checked inputs.
* @throws {Error} If any inputs are missing.
* @private
*/
function validateInputs(session, inputs) {
/**
* NOTE: Create either a shallow or deep copy based on `onnx.wasm.proxy`
* @type {Record<string, Tensor>}
*/
const checkedInputs = Object.create(null);
const missingInputs = [];
for (const inputName of session.inputNames) {
const tensor = inputs[inputName];
// Rare case where one of the model's input names corresponds to a built-in
// object name (e.g., toString), which would cause a simple (!tensor) check to fail,
// because it's not undefined but a function.
if (!(tensor instanceof Tensor)) {
missingInputs.push(inputName);
continue;
}
// NOTE: When `env.wasm.proxy` is true, the tensor is moved across the Worker
// boundary, transferring ownership to the worker and invalidating the tensor.
// So, in this case, we simply sacrifice a clone for it.
checkedInputs[inputName] = isONNXProxy() ? tensor.clone() : tensor;
}
if (missingInputs.length > 0) {
throw new Error(
`An error occurred during model execution: "Missing the following inputs: ${missingInputs.join(', ')}".`);
}
const numInputsProvided = Object.keys(inputs).length;
const numInputsNeeded = session.inputNames.length;
if (numInputsProvided > numInputsNeeded) {
// No missing inputs, but too many inputs were provided.
// Warn the user and ignore the extra inputs.
let ignored = Object.keys(inputs).filter(inputName => !session.inputNames.includes(inputName));
console.warn(`WARNING: Too many inputs were provided (${numInputsProvided} > ${numInputsNeeded}). The following inputs will be ignored: "${ignored.join(', ')}".`);
}
return checkedInputs;
}
/**
* Executes an InferenceSession using the specified inputs.
* NOTE: `inputs` must contain at least the input names of the model.
* - If additional inputs are passed, they will be ignored.
* - If inputs are missing, an error will be thrown.
*
* @param {Object} session The InferenceSession object to run.
* @param {Object} inputs An object that maps input names to input tensors.
* @returns {Promise<Object>} A Promise that resolves to an object that maps output names to output tensors.
* @private
*/
async function sessionRun(session, inputs) {
const checkedInputs = validateInputs(session, inputs);
try {
// pass the original ort tensor
const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
let output = await session.run(ortFeed);
output = replaceTensors(output);
return output;
} catch (e) {
// This usually occurs when the inputs are of the wrong type.
console.error(`An error occurred during model execution: "${e}".`);
console.error('Inputs given to model:', checkedInputs);
throw e;
}
}
/**
* Replaces ONNX Tensor objects with custom Tensor objects to support additional functions.
* @param {Object} obj The object to replace tensor objects in.
* @returns {Object} The object with tensor objects replaced by custom Tensor objects.
* @private
*/
function replaceTensors(obj) {
for (let prop in obj) {
if (isONNXTensor(obj[prop])) {
obj[prop] = new Tensor(obj[prop]);
} else if (typeof obj[prop] === 'object') {
replaceTensors(obj[prop]);
}
}
return obj;
}
/**
* Converts an array or Tensor of integers to an int64 Tensor.
* @param {Array|Tensor} items The input integers to be converted.
* @returns {Tensor} The int64 Tensor with the converted values.
* @throws {Error} If the input array is empty or the input is a batched Tensor and not all sequences have the same length.
* @private
*/
function toI64Tensor(items) {
if (items instanceof Tensor) {
return items;
}
// items is an array
if (items.length === 0) {
throw Error("items must be non-empty");
}
if (Array.isArray(items[0])) {
// batched
if (items.some(x => x.length !== items[0].length)) {
throw Error("Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' and/or 'truncation=True' to have batched tensors with the same length.")
}
return new Tensor('int64',
BigInt64Array.from(items.flat().map(x => BigInt(x))),
[items.length, items[0].length]
);
} else {
//flat
return new Tensor('int64',
BigInt64Array.from(items.map(x => BigInt(x))),
[1, items.length]
);
}
}
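// For example (values shown for illustration only):
//   toI64Tensor([[1, 2, 3], [4, 5, 6]]) -> Tensor { type: 'int64', dims: [2, 3], data: BigInt64Array [1n, ..., 6n] }
//   toI64Tensor([1, 2, 3])              -> Tensor { type: 'int64', dims: [1, 3], data: BigInt64Array [1n, 2n, 3n] }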
/**
* Creates a boolean tensor with a single value.
* @param {boolean} value The value of the tensor.
* @returns {Tensor} The boolean tensor.
* @private
*/
function boolTensor(value) {
return new Tensor('bool', [value], [1]);
}
// JS doesn't support mixins, so we define some reused functions here and allow `this` to be passed in.
/**
* Perform forward pass on the seq2seq model (both encoder and decoder).
* @param {Object} self The seq2seq model object.
* @param {Object} model_inputs The input object for the model containing encoder and decoder inputs.
* @returns {Promise<Seq2SeqLMOutput>} Promise that resolves with the output of the seq2seq model.
* @private
*/
async function seq2seqForward(self, model_inputs) {
let { encoder_outputs, input_ids, decoder_input_ids, ...other_decoder_inputs } = model_inputs;
// Encode if needed
if (!encoder_outputs) {
const encoder_inputs = pick(model_inputs, self.sessions['model'].inputNames);
// Encoder outputs are not given, so we must compute them.
encoder_outputs = (await encoderForward(self, encoder_inputs)).last_hidden_state;
}
other_decoder_inputs.input_ids = decoder_input_ids;
other_decoder_inputs.encoder_hidden_states = encoder_outputs;
if (self.sessions['decoder_model_merged'].inputNames.includes('encoder_attention_mask')) {
other_decoder_inputs.encoder_attention_mask = model_inputs.attention_mask
}
const decoderResults = await decoderForward(self, other_decoder_inputs, true);
return decoderResults;
}
/**
* Forward pass of an encoder model.
* @param {Object} self The encoder model.
* @param {Object} model_inputs The input data to be used for the forward pass.
* @returns {Promise<Object>} The model's outputs.
* @private
*/
async function encoderForward(self, model_inputs) {
const session = self.sessions['model'];
const encoderFeeds = pick(model_inputs, session.inputNames);
if (session.inputNames.includes('inputs_embeds') && !encoderFeeds.inputs_embeds) {
if (!model_inputs.input_ids) {
throw new Error('Both `input_ids` and `inputs_embeds` are missing in the model inputs.');
}
encoderFeeds.inputs_embeds = await self.encode_text({ input_ids: model_inputs.input_ids });
}
if (session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) {
// Assign default `token_type_ids` (all zeroes) to the `encoderFeeds` if the model expects it,
// but they weren't created by the tokenizer.
encoderFeeds.token_type_ids = new Tensor(
'int64',
new BigInt64Array(encoderFeeds.input_ids.data.length),
encoderFeeds.input_ids.dims
)
}
return await sessionRun(session, encoderFeeds);
}
/**
* Forward pass of a decoder model.
* @param {Object} self The decoder model.
* @param {Object} model_inputs The input data to be used for the forward pass.
* @returns {Promise<Object>} The logits and past key values.
* @private
*/
async function decoderForward(self, model_inputs, is_encoder_decoder = false) {
const session = self.sessions[
is_encoder_decoder ? 'decoder_model_merged' : 'model'
]
const { past_key_values, ...new_model_inputs } = model_inputs;
if (session.inputNames.includes('use_cache_branch')) {
new_model_inputs.use_cache_branch = boolTensor(!!past_key_values);
}
if (session.inputNames.includes('position_ids') && new_model_inputs.attention_mask && !new_model_inputs.position_ids) {
new_model_inputs.position_ids = createPositionIds(new_model_inputs, past_key_values);
}
// Unpack the `past_key_values` object into model inputs
self.addPastKeyValues(new_model_inputs, past_key_values);
// Select only the inputs that are needed for the current session
const fixed = pick(new_model_inputs, session.inputNames);
return await sessionRun(session, fixed);
}
/**
* Forward pass of an image-text-to-text model.
* @param {Object} self The image-text-to-text model.
* @param {Object} model_inputs The input data to be used for the forward pass.
* @param {Tensor} [model_inputs.input_ids=null]
* @param {Tensor} [model_inputs.attention_mask=null]
* @param {Tensor} [model_inputs.pixel_values=null]
* @param {Tensor} [model_inputs.position_ids=null]
* @param {Tensor} [model_inputs.inputs_embeds=null]
* @param {Tensor} [model_inputs.past_key_values=null]
* @param {Object} [model_inputs.generation_config=null]
* @param {Object} [model_inputs.logits_processor=null]
* @returns {Promise<Object>} The model's outputs (logits and past key values).
* @private
*/
async function imageTextToTextForward(self, {
// Produced by the tokenizer/processor:
input_ids = null,
attention_mask = null,
pixel_values = null,
// Used during generation:
position_ids = null,
inputs_embeds = null,
past_key_values = null,
// Generic generation parameters
generation_config = null,
logits_processor = null,
// TODO: needed?
...kwargs
}) {
if (!inputs_embeds) {
// 1. Extract the input embeddings
inputs_embeds = await self.encode_text({ input_ids });
// 2. Possibly, merge text and images
if (pixel_values && input_ids.dims[1] !== 1) {
const image_features = await self.encode_image({ pixel_values });
({ inputs_embeds, attention_mask } = self._merge_input_ids_with_image_features({
image_features,
inputs_embeds,
input_ids,
attention_mask,
}));
} else if (past_key_values && pixel_values && input_ids.dims[1] === 1) {
// This is the case when we are generating with cache
const target_length = input_ids.dims[1]; // always 1
const past_length = Object.values(past_key_values)[0].dims.at(-2);
attention_mask = cat([
ones([input_ids.dims[0], past_length]),
attention_mask.slice(null, [attention_mask.dims[1] - target_length, attention_mask.dims[1]]),
], 1);
}
}
const outputs = await decoderForward(self, {
inputs_embeds,
past_key_values,
attention_mask,
position_ids,
generation_config,
logits_processor,
}, true);
return outputs;
}
function createPositionIds(model_inputs, past_key_values = null) {
// If the model supports providing position_ids, we create position_ids on the fly for batch generation,
// by computing the cumulative sum of the attention mask along the sequence length dimension.
//
// Equivalent to:
// position_ids = attention_mask.long().cumsum(-1) - 1
// position_ids.masked_fill_(attention_mask == 0, 1)
// if past_key_values:
// position_ids = position_ids[:, -input_ids.shape[1] :]
const { input_ids, inputs_embeds, attention_mask } = model_inputs;
const [bz, seq_len] = attention_mask.dims;
const data = new BigInt64Array(attention_mask.data.length);
for (let i = 0; i < bz; ++i) {
const start = i * seq_len;
let sum = BigInt(0);
for (let j = 0; j < seq_len; ++j) {
const index = start + j;
if (attention_mask.data[index] === 0n) {
data[index] = BigInt(1);
} else { // === 1n
data[index] = sum;
sum += attention_mask.data[index];
}
}
}
let position_ids = new Tensor('int64', data, attention_mask.dims);
if (past_key_values) {
const offset = -(input_ids ?? inputs_embeds).dims.at(1);
position_ids = position_ids.slice(null, [offset, null]);
}
return position_ids;
}
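// Worked example (illustrative): for a left-padded batch with
//   attention_mask = [[0, 0, 1, 1, 1]]
// the loop above produces
//   position_ids  = [[1, 1, 0, 1, 2]]
// i.e. padding positions are filled with 1 and real tokens are numbered from 0.
// When `past_key_values` are present, only the positions for the new tokens are kept (the final slice).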
function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
if (model_inputs.past_key_values) {
const past_length = Object.values(model_inputs.past_key_values)[0].dims.at(-2);
const { input_ids, attention_mask } = model_inputs;
// Keep only the unprocessed tokens:
// 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
// some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
// input)
if (attention_mask && attention_mask.dims[1] > input_ids.dims[1]) {
// NOTE: not needed since we only pass the generated tokens to the next forward pass
// const offset = -(attention_mask.dims[1] - past_length);
// model_inputs.input_ids = input_ids.slice(null, [offset, null]);
}
// 2 - If past_length is smaller than the length of input_ids, then input_ids holds all input tokens.
// We can discard input_ids based on the past_length.
else if (past_length < input_ids.dims[1]) {
// NOTE: Required for phi models.
// See https://github.com/huggingface/transformers/issues/30809#issuecomment-2111918479 for more information.
model_inputs.input_ids = input_ids.slice(null, [past_length, null]);
}
// 3 - Otherwise (past_length >= input_ids.dims[1]), assume input_ids only has unprocessed tokens.
else {
if (
// NOTE: Only used by VLMs (!= so that null matches undefined)
self.config.image_token_index != null &&
// Equivalent to `self.config.image_token_index in input_ids` (== so that int matches bigint)
input_ids.data.some(x => x == self.config.image_token_index)
) {
// TODO: Support multiple image tokens
const num_image_tokens = self.config.num_image_tokens;
if (!num_image_tokens) {
throw new Error('`num_image_tokens` is missing in the model configuration.');
}
const num_new_tokens = input_ids.dims[1] - (past_length - num_image_tokens);
model_inputs.input_ids = input_ids.slice(null, [-num_new_tokens, null]);
// TODO: The attention mask should be formed from the attention mask passed in model_inputs
model_inputs.attention_mask = ones([1, past_length + num_new_tokens]);
}
}
}
return model_inputs;
}
function encoder_decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
if (model_inputs.past_key_values) {
input_ids = input_ids.map(x => [x.at(-1)]);
}
return {
...model_inputs,
decoder_input_ids: toI64Tensor(input_ids),
};
}
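// For example (illustrative values): once `past_key_values` exist, only the most recently
// generated token per sequence is fed back to the decoder, e.g. input_ids [[5, 9, 23], [7, 12, 40]]
// becomes decoder_input_ids [[23], [40]].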
function image_text_to_text_prepare_inputs_for_generation(self, ...args) {
if (self.config.is_encoder_decoder) {
return encoder_decoder_prepare_inputs_for_generation(self, ...args);
} else {
return decoder_prepare_inputs_for_generation(self, ...args);
}
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
/**
* A base class for pre-trained models that provides the model configuration and an ONNX session.
*/
export class PreTrainedModel extends Callable {
main_input_name = 'input_ids';
forward_params = ['input_ids', 'attention_mask'];
/**
* Creates a new instance of the `PreTrainedModel` class.
* @param {import('./configs.js').PretrainedConfig} config The model configuration.
* @param {Record<string, any>} sessions The inference sessions for the model.
* @param {Record<string, Object>} configs Additional configuration files (e.g., generation_config.json).
*/
constructor(config, sessions, configs) {
super();
this.config = config;
this.sessions = sessions;
this.configs = configs;
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
const modelType = MODEL_TYPE_MAPPING.get(modelName);
this.can_generate = false;
this._forward = null;
this._prepare_inputs_for_generation = null;
switch (modelType) {
case MODEL_TYPES.DecoderOnly:
this.can_generate = true;
this._forward = decoderForward;
this._prepare_inputs_for_generation = decoder_prepare_inputs_for_generation;
break;
case MODEL_TYPES.Seq2Seq:
case MODEL_TYPES.Vision2Seq:
case MODEL_TYPES.Musicgen:
this.can_generate = true;
this._forward = seq2seqForward;
this._prepare_inputs_for_generation = encoder_decoder_prepare_inputs_for_generation;
break;
case MODEL_TYPES.EncoderDecoder:
this._forward = seq2seqForward;
break;
case MODEL_TYPES.ImageTextToText:
this.can_generate = true;
this._forward = imageTextToTextForward;
this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation;
break;
default:
// should be MODEL_TYPES.EncoderOnly
this._forward = encoderForward;
break;
}
if (this.can_generate) {
this.forward_params.push('past_key_values');
}
/** @type {import('./configs.js').TransformersJSConfig} */
this.custom_config = this.config['transformers.js_config'] ?? {};
}
/**
* Disposes of all the ONNX sessions that were created during inference.
* @returns {Promise<unknown[]>} An array of promises, one for each ONNX session that is being disposed.
* @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry
*/
async dispose() {
const promises = [];
for (const session of Object.values(this.sessions)) {
if (session?.handler?.dispose) {
promises.push(session.handler.dispose())
}
}
return await Promise.all(promises);
}
/**
* Instantiate one of the model classes of the library from a pretrained model.
*
* The model class to instantiate is selected based on the `model_type` property of the config object
* (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible)
*
* @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either:
* - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
* Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
* user or organization name, like `dbmdz/bert-base-german-cased`.
* - A path to a *directory* containing model weights, e.g., `./my_model_directory/`.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
*
* @returns {Promise<PreTrainedModel>} A new instance of the `PreTrainedModel` class.
*/
static async from_pretrained(pretrained_model_name_or_path, {
progress_callback = null,
config = null,
cache_dir = null,
local_files_only = false,
revision = 'main',
model_file_name = null,
subfolder = 'onnx',
device = null,
dtype = null,
use_external_data_format = null,
session_options = {},
} = {}) {
let options = {
progress_callback,
config,
cache_dir,
local_files_only,
revision,
model_file_name,
subfolder,
device,
dtype,
use_external_data_format,
session_options,
}
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
const modelType = MODEL_TYPE_MAPPING.get(modelName);
config = options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);
let info;
if (modelType === MODEL_TYPES.DecoderOnly) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: options.model_file_name ?? 'model',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'encoder_model',
decoder_model_merged: 'decoder_model_merged',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.MaskGeneration) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'vision_encoder',
prompt_encoder_mask_decoder: 'prompt_encoder_mask_decoder',
}, options),
]);
} else if (modelType === MODEL_TYPES.EncoderDecoder) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'encoder_model',
decoder_model_merged: 'decoder_model_merged',
}, options),
]);
} else if (modelType === MODEL_TYPES.ImageTextToText) {
const sessions = {
embed_tokens: 'embed_tokens',
vision_encoder: 'vision_encoder',
decoder_model_merged: 'decoder_model_merged',
}
if (config.is_encoder_decoder) {
sessions['model'] = 'encoder_model';
}
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, sessions, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.Musicgen) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'text_encoder',
decoder_model_merged: 'decoder_model_merged',
encodec_decode: 'encodec_decode',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else { // should be MODEL_TYPES.EncoderOnly
if (modelType !== MODEL_TYPES.EncoderOnly) {
console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
}
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: options.model_file_name ?? 'model',
}, options),
]);
}
// @ts-ignore
return new this(config, ...info);
}
/**
* Runs the model with the provided inputs
* @param {Object} model_inputs Object containing input tensors
* @returns {Promise<Object>} Object containing output tensors
*/
async _call(model_inputs) {
return await this.forward(model_inputs);
}
/**
* Forward method for a pretrained model. If not overridden by a subclass, the correct forward method
* will be chosen based on the model type.
* @param {Object} model_inputs The input data to the model in the format specified in the ONNX model.
* @returns {Promise<Object>} The output data from the model in the format specified in the ONNX model.
* @throws {Error} This method must be implemented in subclasses.
*/
async forward(model_inputs) {
return await this._forward(this, model_inputs);
}
/**
* Get the model's generation config, if it exists.
* @returns {GenerationConfig|null} The model's generation config if it exists, otherwise `null`.
*/
get generation_config() {
return this.configs?.generation_config ?? null;
}
/**
* This function returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`]
* instances used for multinomial sampling.
* @param {GenerationConfig} generation_config The generation config.
* @returns {LogitsProcessorList} The list of logits warpers to use.
*/
_get_logits_warper(generation_config) {
// instantiate warpers list
const warpers = new LogitsProcessorList();
if (generation_config.temperature !== null && generation_config.temperature !== 1.0) {
warpers.push(new TemperatureLogitsWarper(generation_config.temperature));
}
if (generation_config.top_k !== null && generation_config.top_k !== 0) {
// TODO: add min_tokens_to_keep
warpers.push(new TopKLogitsWarper(generation_config.top_k));
}
if (generation_config.top_p !== null && generation_config.top_p < 1.0) {
// TODO: add min_tokens_to_keep
warpers.push(new TopPLogitsWarper(generation_config.top_p));
}
return warpers;
}
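// For example (illustrative values): a generation config with
//   { temperature: 0.7, top_k: 50, top_p: 0.9 }
// produces a LogitsProcessorList containing a TemperatureLogitsWarper(0.7),
// a TopKLogitsWarper(50), and a TopPLogitsWarper(0.9), applied in that order.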
/**
* @param {GenerationConfig} generation_config
* @param {number} input_ids_seq_length The starting sequence length for the input ids.
* @returns {LogitsProcessorList}
* @private
*/
_get_logits_processor(
generation_config,
input_ids_seq_length,
// encoder_input_ids, TODO
// prefix_allowed_tokens_fn, TODO
logits_processor = null
) {
const processors = new LogitsProcessorList();
// if (generation_config.diversity_penalty !== null && generation_config.diversity_penalty > 0.0) {
// processors.push(new HammingDiversityLogitsProcessor(