-
Notifications
You must be signed in to change notification settings - Fork 145
/
Copy pathbuilder.py
3294 lines (2825 loc) · 189 KB
/
builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# Modifications Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved
"""
Run this script to create the desired ONNX model.
"""
from onnx import helper, numpy_helper, TensorProto, external_data_helper, save_model
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer, QuantFormat
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import numpy as np
import torch
import argparse
import gc
import json
import os
import textwrap
class Model:
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
    """
    Read the Hugging Face config and set up all state used to build the ONNX graph.

    Args:
        config: Hugging Face model config whose attributes (hidden size, number of heads,
            rope scaling, quantization config, etc.) are read below.
        io_dtype: TensorProto dtype used for the floating-point graph inputs/outputs
            (it is stored directly in the input/output type maps).
        onnx_dtype: Target precision of the saved model (e.g. "int4", "fp16", "fp32").
        ep: Execution provider name; used as a key into `self.ep_attrs`.
        cache_dir: Directory where external tensor data is written while building.
        extra_options: Dict of optional builder settings (filename, int4 settings, etc.).
    """
    self.context_length = config.seq_length if hasattr(config, "seq_length") else config.max_position_embeddings

    # `rope_scaling` is a mapping (it is indexed with `[...]` and tested with `in` below), so
    # membership must be tested with `in`. Using `hasattr` on it never finds the key and would
    # silently fall back to `self.context_length`.
    rope_scaling = getattr(config, "rope_scaling", None)
    if hasattr(config, "original_max_position_embeddings"):
        self.original_context_length = config.original_max_position_embeddings
    elif rope_scaling is not None and "original_max_position_embeddings" in rope_scaling:
        self.original_context_length = rope_scaling["original_max_position_embeddings"]
    else:
        self.original_context_length = self.context_length

    self.window_size = config.sliding_window if hasattr(config, "sliding_window") else -1  # default is -1 in GroupQueryAttention kernel
    self.intermediate_size = config.ffn_hidden_size if hasattr(config, "ffn_hidden_size") else config.intermediate_size
    self.hidden_size = config.hidden_size
    self.num_kv_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.multi_query_group_num if hasattr(config, "multi_query_group_num") else config.num_attention_heads
    self.num_attn_heads = config.num_attention_heads
    self.head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
    self.num_layers = int(extra_options["num_hidden_layers"]) if "num_hidden_layers" in extra_options else config.num_hidden_layers if hasattr(config, "num_hidden_layers") else config.num_layers
    self.vocab_size = config.vocab_size
    self.activation = config.hidden_activation if hasattr(config, "hidden_activation") and config.hidden_activation is not None else config.hidden_act

    self.model_name_or_path = config._name_or_path
    self.model_type = config.architectures[0]
    self.io_dtype = io_dtype      # {'fp16', 'fp32'}
    self.onnx_dtype = onnx_dtype  # {"int4", "fp16", "fp32"}
    self.quant_type = config.quantization_config["quant_method"] if hasattr(config, "quantization_config") else None
    self.adapter_path = extra_options.get("adapter_path", None)

    self.cache_dir = cache_dir
    self.filename = extra_options.get("filename", "model.onnx")
    self.hf_token = parse_hf_token(extra_options.get("hf_token", "true"))
    self.extra_options = extra_options

    self.inputs = []
    self.outputs = []
    self.initializers = []
    self.value_infos = []
    self.nodes = []

    # EP-specific variables
    self.ep = ep
    self.ep_attrs = {
        "cpu": {},
        "cuda": {
            "enable_cuda_graph": "1" if extra_options.get("enable_cuda_graph", False) else "0",  # "1" if the model is able to enable cuda graph, "0" otherwise
        },
        "rocm": {
            "tunable_op_enable": "1",
            "tunable_op_tuning_enable": "1",
        },
        "dml": {},
        "web": {},
    }

    # Map input names to their types and shapes
    self.input_names = ["input_ids", "attention_mask", "position_ids"]
    self.input_types = {
        "input_ids": TensorProto.INT64,                                                                      # For standard models
        "attention_mask": TensorProto.INT64,                                                                 # For standard models
        "position_ids": TensorProto.INT64,                                                                   # For standard models
        "inputs_embeds": self.io_dtype,                                                                      # For standard models where you want to remove the embedding layer from the model (note that `inputs_embeds` is written this way to match Hugging Face format)
        "past_key_values.key": self.io_dtype,                                                                # For standard models (note that `past_key_values.key` is written this way to match Hugging Face format)
        "past_key_values.value": self.io_dtype,                                                              # For standard models (note that `past_key_values.value` is written this way to match Hugging Face format)
    }
    self.input_shapes = {
        "input_ids": ["batch_size", "sequence_length"],                                                      # For standard models
        "attention_mask": ["batch_size", "total_sequence_length"],                                           # For standard models
        "position_ids": ["batch_size", "sequence_length"],                                                   # For standard models
        "inputs_embeds": ["batch_size", "sequence_length", self.hidden_size],                                # For standard models where you want to remove the embedding layer from the model (note that `inputs_embeds` is written this way to match Hugging Face format)
        "past_key_values.key": ["batch_size", self.num_kv_heads, "past_sequence_length", self.head_size],    # For standard models (note that `past_key_values.key` is written this way to match Hugging Face format)
        "past_key_values.value": ["batch_size", self.num_kv_heads, "past_sequence_length", self.head_size],  # For standard models (note that `past_key_values.value` is written this way to match Hugging Face format)
    }
    self.exclude_embeds = extra_options.get("exclude_embeds", False)
    if self.exclude_embeds:
        self.input_names = [name.replace("input_ids", "inputs_embeds") for name in self.input_names]

    # Map output names to their types and shapes
    self.output_names = ["logits"]
    self.output_types = {
        "hidden_states": self.io_dtype,                                                                      # For standard models where you want to remove the language modeling head from the model (note that `hidden_states` is written this way to match Hugging Face format)
        "logits": self.io_dtype,                                                                             # For standard models
        "present.key": self.io_dtype,                                                                        # For standard models (note that `present.key` is written this way to match Hugging Face format)
        "present.value": self.io_dtype,                                                                      # For standard models (note that `present.value` is written this way to match Hugging Face format)
    }
    self.output_shapes = {
        "hidden_states": ["batch_size", "sequence_length", self.hidden_size],                                # For standard models where you want to remove the language modeling head from the model (note that `hidden_states` is written this way to match Hugging Face format)
        "logits": ["batch_size", "sequence_length", self.vocab_size],                                        # For standard models
        "present.key": ["batch_size", self.num_kv_heads, "total_sequence_length", self.head_size],           # For standard models (note that `present.key` is written this way to match Hugging Face format)
        "present.value": ["batch_size", self.num_kv_heads, "total_sequence_length", self.head_size],         # For standard models (note that `present.value` is written this way to match Hugging Face format)
    }
    self.exclude_lm_head = extra_options.get("exclude_lm_head", False)
    self.include_hidden_states = extra_options.get("include_hidden_states", False)
    if self.exclude_lm_head:
        self.output_names = [name.replace("logits", "hidden_states") for name in self.output_names]
    elif self.include_hidden_states:
        self.output_names = ["hidden_states"] + self.output_names

    # Store names of nodes already created
    self.node_names = set()

    # Map TensorProto dtypes to NumPy dtypes
    self.to_numpy_dtype = {
        TensorProto.INT8: np.uint8,
        TensorProto.INT32: np.int32,
        TensorProto.INT64: np.int64,
        TensorProto.FLOAT16: np.float16,
        TensorProto.FLOAT: np.float32,
    }

    # Map TensorProto dtypes to string dtypes
    self.to_str_dtype = {
        TensorProto.INT8: "TensorProto.INT8",
        TensorProto.INT32: "TensorProto.INT32",
        TensorProto.INT64: "TensorProto.INT64",
        TensorProto.FLOAT16: "TensorProto.FLOAT16",
        TensorProto.FLOAT: "TensorProto.FLOAT",
    }

    # Mask-specific variables
    # TODO: Reconcile differences between `seqlens_k` and `key_total_seq_lens` in the GroupQueryAttention and SparseAttention implementations. Ideally the same subgraph can be shared for both.
    self.mask_attrs = {
        "mask_name": "",            # Name of node that outputs 4D causal attention mask (used as add_qk in MultiHeadAttention)
        "seqlens_k": "",            # Sum of each row in attention mask - 1 (used as input to GroupQueryAttention)
        "total_seq_len": "",        # Size of total sequence length in attention mask (used as input to GroupQueryAttention and SparseAttention)
        "block_row_indices": "",    # Row indices of CSR format of block mask (used as input to SparseAttention)
        "block_col_indices": "",    # Col indices of CSR format of block mask (used as input to SparseAttention)
        "key_total_seq_lens": "",   # Sum of each row in attention mask (used as input to SparseAttention)
    }

    # Embedding-specific variables
    self.embed_attrs = {
        "scale": 1,                 # Scale value to multiply output of Embedding layer by
    }

    # LayerNorm-specific variables
    epsilon = config.rms_norm_eps if hasattr(config, "rms_norm_eps") else 1e-06
    self.layernorm_attrs = {
        "simple": True,             # Use SimplifiedLayerNorm/SkipSimplifiedLayerNorm vs. LayerNorm/SkipLayerNorm
        "first_layernorm": True,    # 1st LayerNorm = LayerNorm, then SkipLayerNorm for all subsequent LayerNorms
        "last_layernorm": False,    # Last LayerNorm = SkipLayerNorm with only output 0 (no output 3)
        "root_input": "",           # Root input from parent node for LayerNorm and SkipLayerNorm
        "skip_input": "",           # Skip input from parent node for SkipLayerNorm
        "output_0": "",             # Output 0 for LayerNorm and SkipLayerNorm
        "output_3": "",             # Output 3 for SkipLayerNorm
        "add_offset": 0,            # Offset value for LayerNorm weight
        "epsilon": epsilon,         # Epsilon value to avoid `sqrt(0)` in LayerNorm
    }

    # MatMul-specific variables
    is_lora = hasattr(config, "peft_type") and config.peft_type == "LORA"
    self.matmul_attrs = {
        "use_lora": is_lora,        # Use LoRA/QLoRA format
    }

    # RotaryEmbedding-specific variables
    position_scale = config.rope_position_scale if hasattr(config, "rope_position_scale") else 1
    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    rope_theta = config.rope_theta if hasattr(config, "rope_theta") else config.rope_embedding_base if hasattr(config, "rope_embedding_base") else 10000
    self.rotemb_attrs = {
        "create_rotary_embedding_caches": True,          # Create cos/sin caches for rotary embeddings
        "cache_length": self.context_length,             # Cache length to use when creating cos/sin caches for rotary embeddings
        "theta": rope_theta,                             # Base value if calculating cos/sin caches from scratch
        "partial_rotary_factor": partial_rotary_factor,  # Factor for partial rotary embeddings
        "interleaved": 0,                                # Interleave the rotary embeddings (e.g. [0, 0, 0, 1, 1, 1] to [0, 1, 0, 1, 0, 1], RotaryEmbedding kernel expects a default value of 0)
        "num_heads": 0,                                  # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
        "rotary_embedding_dim": 0,                       # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
        "rescale_factors": 1,                            # Rescale factors when calculating `inv_freq` in rotary embeddings
        "t_dtype": torch.int64,                          # Torch dtype when calculating `t` in rotary embeddings
        "position_scale": position_scale,                # Scale value when calculating `t` in rotary embeddings
        "mscale": 1,                                     # Magnitude scaling factor when scaling `emb.cos()/emb.sin()` in rotary embeddings
        "mscale_policy": "",                             # Magnitude scaling policy when scaling `emb.cos()/emb.sin()` in rotary embeddings
    }
    if rope_scaling is not None:
        if "short_factor" in rope_scaling:
            # For models with multiple rotary embedding caches (e.g. Phi-3 mini 128K)
            self.rotemb_attrs["mscale_policy"] = rope_scaling["type"]
            short_factor = torch.tensor(rope_scaling["short_factor"], dtype=torch.float32)
            long_factor = torch.tensor(rope_scaling["long_factor"], dtype=torch.float32)

            short_mscale = rope_scaling["short_mscale"] if "short_mscale" in rope_scaling else 0
            long_mscale = rope_scaling["long_mscale"] if "long_mscale" in rope_scaling else 0
            # A non-positive mscale means "not provided": derive it from the context length ratio
            short_mscale = short_mscale if short_mscale > 0 else self.make_mscale(self.context_length / self.original_context_length)
            long_mscale = long_mscale if long_mscale > 0 else self.make_mscale(self.context_length / self.original_context_length)

            self.rotemb_attrs["multi_cache"] = {
                "short_factor": short_factor,                # Short factor when calculating `inv_freq` in rotary embeddings
                "long_factor": long_factor,                  # Long factor when calculating `inv_freq` in rotary embeddings
                "short_mscale": short_mscale,                # Magnitude scaling for short factor when scaling `emb.cos()/emb.sin()` in rotary embeddings
                "long_mscale": long_mscale,                  # Magnitude scaling for long factor when scaling `emb.cos()/emb.sin()` in rotary embeddings
            }
        elif "low_freq_factor" in rope_scaling:
            # For models that rescale `inv_freq` using `low_freq_factor` and `high_freq_factor` (e.g. LLaMA-3.1)
            factor = rope_scaling["factor"] if "factor" in rope_scaling else 0
            low_freq_factor = rope_scaling["low_freq_factor"] if "low_freq_factor" in rope_scaling else 0
            high_freq_factor = rope_scaling["high_freq_factor"] if "high_freq_factor" in rope_scaling else 0

            self.rotemb_attrs["rescale_inv_freq"] = {
                "factor": factor,                            # Scale factor when calculating `new_freq` in rotary embeddings
                "low_freq_factor": low_freq_factor,          # Low freq factor when calculating `low_freq_wavelen` in rotary embeddings
                "high_freq_factor": high_freq_factor,        # High freq factor when calculating `high_freq_wavelen` in rotary embeddings
            }

    # Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.)
    softcap = config.attn_logit_softcapping if hasattr(config, "attn_logit_softcapping") else 0.0  # default is 0.0 in GroupQueryAttention kernel

    # Block-sparse attention-specific variables
    sparse_block_size = config.blocksparse_block_size if hasattr(config, "blocksparse_block_size") else 0
    kernel_block_size = config.blocksparse_triton_kernel_block_size if hasattr(config, "blocksparse_triton_kernel_block_size") else 0
    local_blocks = config.blocksparse_num_local_blocks if hasattr(config, "blocksparse_num_local_blocks") else 0
    vert_block_stride = config.blocksparse_vert_stride if hasattr(config, "blocksparse_vert_stride") else 0
    homo_head = config.blocksparse_homo_head_pattern if hasattr(config, "blocksparse_homo_head_pattern") else False

    self.attention_attrs = {
        "q_path": "",                                    # Q path to attention
        "k_path": "",                                    # K path to attention
        "v_path": "",                                    # V path to attention
        "op_type": "MultiHeadAttention",                 # Attention op to use
        "scale": 1 / np.sqrt(self.head_size),            # Scale value after calculating Q x K' in attention
        "softcap": softcap,                              # Softcap value to prevent values from exploding in attention
        "use_rotemb_in_attn": False,                     # Use rotary embeddings within attention (instead of a separate RotaryEmbedding op)
        "use_packed_matmul": False,                      # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V)
        "block_sparse": {                                # Block-sparse attention-specific variables
            "sparse_block_size": sparse_block_size,      # Sparse block size for SparseAttention op
            "kernel_block_size": kernel_block_size,      # Kernel block size for sparse attention
            "local_blocks": local_blocks,                # Number of local blocks for sparse attention
            "vert_stride": vert_block_stride,            # Vertical stride to use for sparse attention
            "homo_head": homo_head,                      # Use homo head pattern for sparse attention
        }
    }
    valid_gqa_configurations = [
        ("cpu", TensorProto.FLOAT),
        ("cuda", TensorProto.FLOAT16),
        ("rocm", TensorProto.FLOAT16),
        ("dml", TensorProto.FLOAT16),
    ]
    if (self.ep, self.io_dtype) in valid_gqa_configurations:
        # Change model settings for GroupQueryAttention
        self.attention_attrs["op_type"] = "GroupQueryAttention"
        print("GroupQueryAttention (GQA) is used in this model.")

        # DML doesn't support packed Q/K/V for GQA yet
        # Packed MatMul with LoRA/QLoRA is not currently supported
        self.attention_attrs["use_packed_matmul"] = self.ep != "dml" and not self.matmul_attrs["use_lora"]

        # GQA + Rot.Emb. does not require `position ids` as input
        if self.ep != "dml":
            self.attention_attrs["use_rotemb_in_attn"] = True
            self.input_names.remove("position_ids")

    self.past_present_share_buffer = self.attention_attrs["op_type"] == "GroupQueryAttention"

    # MLP-specific variables
    self.mlp_attrs = {
        "use_proj": True,           # Use projection style for MLP (GateProj/UpProj/DownProj)
        "use_fc": False,            # Use fully-connected style for MLP (FC1/FC2)
        "output_0": "",             # Output 0 for MLP layer
    }

    # MoE-specific variables
    num_experts = config.num_experts if hasattr(config, "num_experts") else 1
    self.moe_attrs = {
        "num_experts": num_experts,             # Number of experts in MoE layer
        "top_k": 1,                             # Number of experts to select in MoE layer
        "activation_type": "relu",              # Activation function for MoE layer
        "normalize_routing_weights": False,     # Normalize routing weights in MoE layer
        "use_sparse_mixer": False,              # Use SparseMixer in MoE layer. Used in Phi3 MoE
        "use_int4": True,                       # Use INT4 quantization in MoE layer, otherwise use INT8
    }

    # LM head-specific variables
    self.lm_head_attrs = {
        "scale": 1,                 # Scale value to multiply output of LM head by
        "mask": None,               # LM head mask for tokens in the vocabulary
    }
    if hasattr(config, "dummy_token_indices"):
        # Create LM head mask for tokens in the vocabulary
        dummy_tokens_mask = torch.zeros(self.vocab_size).bool()
        dummy_tokens_mask[config.dummy_token_indices] = True
        self.lm_head_attrs["mask"] = dummy_tokens_mask

    # Quantization-specific variables (INT4, INT8, etc.)
    self.quant_attrs = {
        "int4": {
            "accuracy_level": int(extra_options.get("int4_accuracy_level", 0)),  # Default is 0 for non-QDQ formats, default is 4 for QDQ formats
            "block_size": int(extra_options.get("int4_block_size", 32)),
            "is_symmetric": extra_options.get("int4_is_symmetric", True),
            "op_types_to_quantize": extra_options.get("int4_op_types_to_quantize", ("MatMul", )),
        },
        "use_qdq": extra_options.get("use_qdq", False),  # Use QDQ format
    }
    if self.quant_type is not None:
        # Create quantized attributes from quantization config
        self.quant_attrs["bits"] = config.quantization_config["bits"]
        self.quant_attrs["group_size"] = config.quantization_config["group_size"]
        self.quant_attrs["use_g_idx"] = config.quantization_config["desc_act"] if "desc_act" in config.quantization_config else False
def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
    """
    Build the GenAI config dictionary and write it to `<out_dir>/genai_config.json`.

    Args:
        model_name_or_path: Model identifier/path passed to `from_pretrained`.
        extra_kwargs: Extra keyword arguments forwarded to `from_pretrained`.
        out_dir: Directory in which `genai_config.json` is written.
    """
    try:
        config = GenerationConfig.from_pretrained(model_name_or_path, token=self.hf_token, trust_remote_code=True, **extra_kwargs)
    except Exception:
        # Fall back to the model config when no generation config can be loaded.
        # `Exception` (not a bare `except`) so Ctrl-C / SystemExit still propagate.
        config = AutoConfig.from_pretrained(model_name_or_path, token=self.hf_token, trust_remote_code=True, **extra_kwargs)

    inputs = dict(zip(self.input_names, self.input_names))
    inputs.update({
        "past_key_names": "past_key_values.%d.key",
        "past_value_names": "past_key_values.%d.value",
    })
    outputs = dict(zip(self.output_names, self.output_names))
    outputs.update({
        "present_key_names": "present.%d.key",
        "present_value_names": "present.%d.value",
    })
    if "hidden_states" in outputs:
        # Remove 'hidden_states' from 'outputs' entry in config since ORT GenAI doesn't use it
        del outputs["hidden_states"]

    # Model type is the architecture name up to (not including) the first "For",
    # e.g. "LlamaForCausalLM" -> "llama". `partition` keeps the whole name when "For"
    # is absent, whereas slicing with `find(...) == -1` would drop the last character.
    model_type = self.model_type.partition("For")[0].lower()

    genai_config = {
        "model": {
            "bos_token_id": config.bos_token_id if hasattr(config, "bos_token_id") else 1,  # config.bos_token_id not present in ChatGLM model configs.
            "context_length": self.context_length,
            "decoder": {
                "session_options" : {
                    "log_id": "onnxruntime-genai",
                    "provider_options" : [],
                },
                "filename": self.filename,
                "head_size": self.head_size,
                "hidden_size": self.hidden_size,
                "inputs": inputs,
                "outputs": outputs,
                "num_attention_heads": self.num_attn_heads,
                "num_hidden_layers": self.num_layers,
                "num_key_value_heads": self.num_kv_heads,
            },
            "eos_token_id": config.eos_token_id,
            # Use the explicit pad token if set, otherwise the (first) EOS token
            "pad_token_id": config.pad_token_id if hasattr(config, "pad_token_id") and config.pad_token_id is not None else config.eos_token_id[0] if isinstance(config.eos_token_id, list) else config.eos_token_id,
            "type": model_type,
            "vocab_size": self.vocab_size,
        },
        "search": {
            "diversity_penalty": config.diversity_penalty if hasattr(config, "diversity_penalty") else 0.0,
            "do_sample": config.do_sample if hasattr(config, "do_sample") else False,
            "early_stopping": True,
            "length_penalty": config.length_penalty if hasattr(config, "length_penalty") else 1.0,
            "max_length": self.context_length,
            "min_length": 0,
            "no_repeat_ngram_size": config.no_repeat_ngram_size if hasattr(config, "no_repeat_ngram_size") else 0,
            "num_beams": config.num_beams if hasattr(config, "num_beams") else 1,
            "num_return_sequences": config.num_return_sequences if hasattr(config, "num_return_sequences") else 1,
            "past_present_share_buffer": False if "config_only" in self.extra_options else self.past_present_share_buffer,
            "repetition_penalty": config.repetition_penalty if hasattr(config, "repetition_penalty") else 1.0,
            "temperature": config.temperature if hasattr(config, "temperature") else 1.0,
            "top_k": 1,
            "top_p": config.top_p if hasattr(config, "top_p") else 1.0,
        },
    }

    # Non-CPU execution providers get their options recorded in the session options
    if self.ep != "cpu":
        ep_options = { self.ep : self.ep_attrs[self.ep] }
        genai_config["model"]["decoder"]["session_options"]["provider_options"].append(ep_options)

    if self.extra_options.get("include_prompt_templates", False):
        prompt_templates = self._get_prompt_templates(model_name_or_path, extra_kwargs)
        if prompt_templates is not None:
            genai_config["model"]["prompt_templates"] = prompt_templates

    print(f"Saving GenAI config in {out_dir}")
    with open(os.path.join(out_dir,"genai_config.json"), "w") as f:
        json.dump(genai_config, f, indent=4)
def save_processing(self, model_name_or_path, extra_kwargs, out_dir):
    """
    Load the tokenizer for `model_name_or_path` and save its files into `out_dir`.

    `extra_kwargs` is forwarded to `AutoTokenizer.from_pretrained`.
    """
    hf_tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        token=self.hf_token,
        trust_remote_code=True,
        **extra_kwargs,
    )
    print(f"Saving processing files in {out_dir} for GenAI")
    hf_tokenizer.save_pretrained(out_dir)
def _get_prompt_templates(self, hf_name, extra_kwargs):
    """
    Derive per-role prompt templates from the tokenizer's chat template.

    The chat template is rendered for growing conversations (system, system+user,
    system+user+assistant) with '{Content}' as the message body; each role's template is
    the text that role adds to the previous rendering.

    Returns:
        Dict with "system", "user", "assistant" and "prompt" templates, or None if
        anything fails (the error is printed).
    """
    try:
        # disable end of sentence padding with eos_token=None
        tokenizer = AutoTokenizer.from_pretrained(hf_name, token=self.hf_token, trust_remote_code=True, eos_token=None, **extra_kwargs)

        def render(*roles):
            # Render a conversation with one placeholder message per role
            messages = [{'role': role, 'content': '{Content}'} for role in roles]
            return tokenizer.apply_chat_template(messages, tokenize=False)

        sys_text = render('system')
        sys_user_text = render('system', 'user')
        full_text = render('system', 'user', 'assistant')

        # Each rendering must extend the previous one, otherwise the suffix slicing below is meaningless
        assert sys_user_text.startswith(sys_text), "Chat templates may contain padding tokens, leading to incorrect prompt templates"
        assert full_text.startswith(sys_user_text), "Chat templates may contain padding tokens, leading to incorrect prompt templates"

        # "prompt" = everything after the system part, cut right before the assistant's content
        after_system = full_text[len(sys_text):]
        return {
            "system": sys_text,
            "user": sys_user_text[len(sys_text):],
            "assistant": full_text[len(sys_user_text):],
            "prompt": after_system[:after_system.rfind('{Content}')]
        }
    except Exception as e:
        print(f"Failed to get prompt templates. Error: {e}")
        return None
def save_model(self, out_dir):
    """
    Assemble the ONNX model from the collected graph pieces and save it into `out_dir`.

    The external tensor data written to `self.cache_dir` during building is loaded into
    the model, the cached `.bin` files are removed, the model is optionally quantized
    to INT4, and the result is saved with a single external data file.
    """
    print(f"Saving ONNX model in {out_dir}")
    gc.collect()

    # Create ONNX model
    onnx_opset_version = 21 if self.quant_attrs["use_qdq"] else 14
    onnx_opset = self.clear_field(helper.make_operatorsetid('', onnx_opset_version), 'domain')
    contrib_opset = helper.make_operatorsetid('com.microsoft', 1)
    main_graph = self.make_graph(
        name="main_graph",
        inputs=self.inputs,
        outputs=self.outputs,
        initializer=self.initializers,
        value_info=self.value_infos,
        nodes=self.nodes,
    )
    model = helper.make_model(
        opset_imports=[onnx_opset, contrib_opset],
        ir_version=7,
        producer_name="onnxruntime-genai",
        producer_version="0.0.0",
        graph=main_graph,
    )

    # Load external data into ONNX model
    external_data_helper.load_external_data_for_model(model, self.cache_dir)

    # Delete external data files on disk before re-saving
    for entry in os.listdir(self.cache_dir):
        if entry.endswith(".bin"):
            os.remove(os.path.join(self.cache_dir, entry))

    # Delete temporary cache dir if empty
    if len(os.listdir(self.cache_dir)) == 0:
        os.rmdir(self.cache_dir)

    # Quantize ONNX model to desired precision
    # TODO: Replace by quantizing the MatMuls as they are created
    # Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path
    prequantized_qdq = self.quant_type is not None and self.quant_attrs["use_qdq"]
    if self.onnx_dtype == "int4" and not prequantized_qdq:
        model = self.to_int4(model)

    # Save ONNX model with only one external data file and delete any existing duplicate copies
    out_path = os.path.join(out_dir, self.filename)
    data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")
    for stale_path in (out_path, data_path):
        if os.path.exists(stale_path):
            print(f"Overwriting {stale_path}")
            os.remove(stale_path)

    save_model(
        model,
        out_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=os.path.basename(data_path),
        size_threshold=0,
        convert_attribute=False,
    )
def to_int4(self, model):
    """
    Quantize `model` with `MatMul4BitsQuantizer` using the settings in
    `self.quant_attrs` and return the quantized model.
    """
    int4_settings = self.quant_attrs["int4"]
    if self.quant_attrs["use_qdq"]:
        output_format = QuantFormat.QDQ
    else:
        output_format = QuantFormat.QOperator

    quantizer = MatMul4BitsQuantizer(
        model=model,
        block_size=int4_settings["block_size"],
        is_symmetric=int4_settings["is_symmetric"],
        accuracy_level=int4_settings["accuracy_level"],
        nodes_to_exclude=[],
        quant_format=output_format,
        op_types_to_quantize=int4_settings["op_types_to_quantize"],
    )
    quantizer.process()
    return quantizer.model.model
def clear_field(self, proto, field):
    """
    Call `proto.ClearField(field)` and hand back the same `proto` object,
    so the call can be used inline inside a larger expression.
    """
    proto.ClearField(field)
    return proto
def order_repeated_field(self, repeated_proto, key_name, order):
    """
    Sort `repeated_proto` in place so its items follow `order`.

    Each item is ranked by the position of its `key_name` attribute within `order`.
    """
    ranking = list(order)

    def rank(item):
        # Position of this item's key in the requested ordering
        return ranking.index(getattr(item, key_name))

    repeated_proto.sort(key=rank)
def make_external_tensor(self, np_data, name, unpack_int4=False, **kwargs):
    """
    Register `np_data` as an initializer named `name` whose data lives in an external file.

    The raw bytes are written to `<cache_dir>/<name>.bin` and removed from the in-memory
    tensor. When `unpack_int4` is set and the target dtype is 'int4', the tensor is
    relabelled as UINT4 with its last dimension doubled.
    """
    bin_name = f"{name}.bin"

    initializer = numpy_helper.from_array(np_data)
    initializer.name = name
    external_data_helper.set_external_data(initializer, location=bin_name)

    # Dump the raw bytes to the cache directory, then drop them from the proto
    bin_path = os.path.join(self.cache_dir, bin_name)
    with open(bin_path, "wb") as bin_file:
        bin_file.write(initializer.raw_data)
    initializer.ClearField("raw_data")
    initializer.data_location = TensorProto.EXTERNAL

    if unpack_int4 and self.onnx_dtype == 'int4':
        # presumably the source array packs two 4-bit values per byte — TODO confirm
        initializer.data_type = TensorProto.UINT4
        initializer.dims[-1] *= 2

    self.initializers.append(initializer)
def make_node(self, op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs):
    """
    Add a node to the graph unless a node with the same `name` already exists.

    Any input whose name starts with "/model/constants" and has not been created yet is
    first materialized through `make_constant`.

    Note:
        Because existing names are skipped, functions that build similar subgraphs with
        the same naming schema (e.g. attention mask subgraphs) can request their nodes
        unconditionally; this method decides whether each proposed node actually needs
        to be added to the graph.
    """
    # Save any constants as nodes
    for tensor_name in inputs:
        if not tensor_name.startswith("/model/constants"):
            continue
        if tensor_name in self.node_names:
            continue
        self.make_constant(tensor_name)

    # Make node only if it does not already exist
    if name in self.node_names:
        return

    new_node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs)
    if doc_string == '':
        new_node.doc_string = ''
    # Keep attributes in the order the caller passed them
    self.order_repeated_field(new_node.attribute, 'name', kwargs.keys())
    self.nodes.append(new_node)
    self.node_names.add(name)
def make_value_info(self, name, dtype, shape):
    """Create a tensor value info for `name` with `dtype`/`shape` and record it in `self.value_infos`."""
    self.value_infos.append(helper.make_tensor_value_info(name, dtype, shape=shape))
def make_graph(self, *args, doc_string=None, **kwargs):
    """
    Thin wrapper over `helper.make_graph` that forwards all arguments.

    When `doc_string` is exactly '', the graph's `doc_string` field is assigned ''
    explicitly after construction.
    """
    graph_proto = helper.make_graph(*args, doc_string=doc_string, **kwargs)
    if doc_string == '':
        # presumably so the empty doc string is explicitly set on the proto — TODO confirm
        graph_proto.doc_string = ''
    return graph_proto
def make_inputs_and_outputs(self):
    """
    Build the graph's input and output value infos and store them on
    `self.inputs` / `self.outputs`.

    Model-specific inputs/outputs come first (in `input_names` / `output_names` order),
    followed by the per-layer KV cache entries (key then value for each layer).
    """
    # Add model-specific inputs to list of model inputs
    graph_inputs = [
        helper.make_tensor_value_info(input_name, self.input_types[input_name], shape=self.input_shapes[input_name])
        for input_name in self.input_names
    ]

    # Add model-specific outputs to list of model outputs
    graph_outputs = [
        helper.make_tensor_value_info(output_name, self.output_types[output_name], shape=self.output_shapes[output_name])
        for output_name in self.output_names
    ]

    # Add KV cache to inputs and outputs
    for layer_id in range(self.num_layers):
        for kind in ("key", "value"):
            # Add KV cache to inputs
            past_key = f"past_key_values.{kind}"
            graph_inputs.append(helper.make_tensor_value_info(f"past_key_values.{layer_id}.{kind}", self.input_types[past_key], shape=self.input_shapes[past_key]))

        for kind in ("key", "value"):
            # Add KV cache to outputs
            present_key = f"present.{kind}"
            graph_outputs.append(helper.make_tensor_value_info(f"present.{layer_id}.{kind}", self.output_types[present_key], shape=self.output_shapes[present_key]))

    self.inputs = graph_inputs
    self.outputs = graph_outputs
def make_constant(self, name):
    """Create a `Constant` node (plus value info) from an encoded constant name.

    Format of name is "/model/constants/{dtype}/{shape}/{num}" where:
      - `{dtype}` is an expression that evaluates to the ONNX dtype,
      - `{shape}` is "0D" for a scalar (anything else yields a 1D tensor),
      - `{num}` is an expression that evaluates to the value; a tuple becomes
        the elements of a 1D tensor, any other value becomes a single element.

    The node output is `name` itself, and `name` is added to `self.node_names`.
    """
    path = name.split("/")
    # NOTE(review): `eval` is applied to pieces of the name. This is only safe
    # because constant names are produced by the builder itself; never route
    # externally supplied strings through this method.
    onnx_dtype, dims, num = eval(path[-3]), path[-2], eval(path[-1])
    np_dtype = self.to_numpy_dtype[onnx_dtype]

    if dims == "0D":
        data = num
    elif isinstance(num, tuple):
        data = list(num)
    else:
        data = [num]
    array = np.array(data, dtype=np_dtype)

    value = numpy_helper.from_array(array, name=name.replace("constants", "numpy_helper"))
    node_name = name.replace("constants", "constant_nodes")
    self.make_node("Constant", inputs=[], outputs=[name], name=node_name, value=value)
    # Declare the real shape of the constant: [] for scalars, [n] for 1D
    # tensors (previously every constant was declared as a scalar).
    self.make_value_info(name, onnx_dtype, shape=list(array.shape))
    self.node_names.add(name)
def make_gather(self, name, inputs, axis):
    """Add a `Gather` node along `axis`; its output `{name}/output_0` is declared as a scalar INT64."""
    gathered = f"{name}/output_0"
    self.make_node("Gather", inputs=inputs, outputs=[gathered], name=name, axis=axis)
    self.make_value_info(gathered, TensorProto.INT64, shape=[])
def make_reshape(self, name, inputs, dtype, shape):
    """Add a `Reshape` node and register value info (`dtype`, `shape`) for `{name}/output_0`."""
    reshaped = f"{name}/output_0"
    self.make_node("Reshape", inputs=inputs, outputs=[reshaped], name=name)
    self.make_value_info(reshaped, dtype, shape=shape)
def make_shape(self, name, root_input, shape):
    """Add a `Shape` node over `root_input`; the output `{name}/output_0` is declared INT64 with `shape`."""
    shape_output = f"{name}/output_0"
    self.make_node("Shape", inputs=[root_input], outputs=[shape_output], name=name)
    self.make_value_info(shape_output, TensorProto.INT64, shape=shape)
def make_constant_of_shape(self, name, root_input, value, dtype, shape):
    """Add a `ConstantOfShape` node filled with `value` and register value info for `{name}/output_0`."""
    filled = f"{name}/output_0"
    self.make_node("ConstantOfShape", inputs=[root_input], outputs=[filled], name=name, value=value)
    self.make_value_info(filled, dtype, shape=shape)
def make_unsqueeze(self, name, inputs, dtype, shape):
    """Add an `Unsqueeze` node and register value info (`dtype`, `shape`) for `{name}/output_0`."""
    unsqueezed = f"{name}/output_0"
    self.make_node("Unsqueeze", inputs=inputs, outputs=[unsqueezed], name=name)
    self.make_value_info(unsqueezed, dtype, shape=shape)
def make_squeeze(self, name, inputs):
    """Add a `Squeeze` node; its output `{name}/output_0` is declared as a scalar INT64."""
    squeezed = f"{name}/output_0"
    self.make_node("Squeeze", inputs=inputs, outputs=[squeezed], name=name)
    self.make_value_info(squeezed, TensorProto.INT64, shape=[])
def make_concat(self, name, inputs, dtype, shape, axis=0):
    """Add a `Concat` node along `axis` (default 0) and register value info for `{name}/output_0`."""
    joined = f"{name}/output_0"
    self.make_node("Concat", inputs=inputs, outputs=[joined], name=name, axis=axis)
    self.make_value_info(joined, dtype, shape=shape)
def make_tile(self, name, inputs, dtype, shape):
    """Add a `Tile` node and register value info (`dtype`, `shape`) for `{name}/output_0`."""
    tiled = f"{name}/output_0"
    self.make_node("Tile", inputs=inputs, outputs=[tiled], name=name)
    self.make_value_info(tiled, dtype, shape=shape)
def make_equal(self, name, inputs, shape):
    """Add an `Equal` node; the output `{name}/output_0` is declared BOOL with `shape`."""
    comparison = f"{name}/output_0"
    self.make_node("Equal", inputs=inputs, outputs=[comparison], name=name)
    self.make_value_info(comparison, TensorProto.BOOL, shape=shape)
def make_greater(self, name, inputs, shape):
    """Add a `Greater` node; the output `{name}/output_0` is declared BOOL with `shape`."""
    comparison = f"{name}/output_0"
    self.make_node("Greater", inputs=inputs, outputs=[comparison], name=name)
    self.make_value_info(comparison, TensorProto.BOOL, shape=shape)
def make_greater_or_equal(self, name, inputs, shape):
    """Add a `GreaterOrEqual` node; the output `{name}/output_0` is declared BOOL with `shape`."""
    comparison = f"{name}/output_0"
    self.make_node("GreaterOrEqual", inputs=inputs, outputs=[comparison], name=name)
    self.make_value_info(comparison, TensorProto.BOOL, shape=shape)
def make_isinf(self, name, root_input, shape):
    """Add an `IsInf` node over `root_input`; the output `{name}/output_0` is declared BOOL with `shape`."""
    inf_mask = f"{name}/output_0"
    self.make_node("IsInf", inputs=[root_input], outputs=[inf_mask], name=name)
    self.make_value_info(inf_mask, TensorProto.BOOL, shape=shape)
def make_clip(self, name, inputs, dtype, shape):
    """Add a `Clip` node and register value info (`dtype`, `shape`) for `{name}/output_0`."""
    clipped = f"{name}/output_0"
    self.make_node("Clip", inputs=inputs, outputs=[clipped], name=name)
    self.make_value_info(clipped, dtype, shape=shape)
def make_where(self, name, inputs, dtype, shape):
    """Add a `Where` node and register value info (`dtype`, `shape`) for `{name}/output_0`."""
    selected = f"{name}/output_0"
    self.make_node("Where", inputs=inputs, outputs=[selected], name=name)
    self.make_value_info(selected, dtype, shape=shape)
def make_expand(self, name, inputs, dtype, shape):
    """Add an `Expand` node and register value info (`dtype`, `shape`) for `{name}/output_0`."""
    expanded = f"{name}/output_0"
    self.make_node("Expand", inputs=inputs, outputs=[expanded], name=name)
    self.make_value_info(expanded, dtype, shape=shape)
def make_reduce_sum(self, name, inputs, dtype, shape):
    """Add a `ReduceSum` node and register value info (`dtype`, `shape`) for `{name}/output_0`."""
    summed = f"{name}/output_0"
    self.make_node("ReduceSum", inputs=inputs, outputs=[summed], name=name)
    self.make_value_info(summed, dtype, shape=shape)
def make_reduce_max(self, name, inputs, dtype, shape):
    """Add a `ReduceMax` node with `keepdims=False` and register value info for `{name}/output_0`."""
    maximum = f"{name}/output_0"
    self.make_node("ReduceMax", inputs=inputs, outputs=[maximum], name=name, keepdims=False)
    self.make_value_info(maximum, dtype, shape=shape)
def make_cast(self, name, root_input, dtype, shape):
    """Add a `Cast` node converting `root_input` to `dtype`; value info for `{name}/output_0` uses the same `dtype`."""
    casted = f"{name}/output_0"
    self.make_node("Cast", inputs=[root_input], outputs=[casted], name=name, to=dtype)
    self.make_value_info(casted, dtype, shape=shape)
def make_add(self, name, inputs, dtype, shape):
    """Add an `Add` node and register value info (`dtype`, `shape`) for `{name}/output_0`."""
    total = f"{name}/output_0"
    self.make_node("Add", inputs=inputs, outputs=[total], name=name)
    self.make_value_info(total, dtype, shape=shape)
def make_sub(self, name, inputs, dtype, shape):
    """Add a `Sub` node and register value info (`dtype`, `shape`) for `{name}/output_0`."""
    difference = f"{name}/output_0"
    self.make_node("Sub", inputs=inputs, outputs=[difference], name=name)
    self.make_value_info(difference, dtype, shape=shape)
def make_less(self, name, inputs):
    """Add a `Less` node; the output `{name}/output_0` is declared BOOL with no shape (`shape=None`)."""
    comparison = f"{name}/output_0"
    self.make_node("Less", inputs=inputs, outputs=[comparison], name=name)
    self.make_value_info(comparison, TensorProto.BOOL, shape=None)
def make_range(self, name, inputs):
    """Add a `Range` node; the output `{name}/output_0` is declared INT64 with one symbolic dim "unk"."""
    sequence = f"{name}/output_0"
    self.make_node("Range", inputs=inputs, outputs=[sequence], name=name)
    self.make_value_info(sequence, TensorProto.INT64, shape=["unk"])
def make_slice(self, name, inputs, dtype, shape):
    """Add a `Slice` node and register value info (`dtype`, `shape`) for `{name}/output_0`."""
    sliced = f"{name}/output_0"
    self.make_node("Slice", inputs=inputs, outputs=[sliced], name=name)
    self.make_value_info(sliced, dtype, shape=shape)
def make_mul(self, name, inputs, dtype, shape):
    """Add a `Mul` node and register value info (`dtype`, `shape`) for `{name}/output_0`."""
    product = f"{name}/output_0"
    self.make_node("Mul", inputs=inputs, outputs=[product], name=name)
    self.make_value_info(product, dtype, shape=shape)
def make_transpose(self, name, root_input, dtype, shape, perm):
    """Add a `Transpose` node with permutation `perm` and register value info for `{name}/output_0`."""
    transposed = f"{name}/output_0"
    self.make_node("Transpose", inputs=[root_input], outputs=[transposed], name=name, perm=perm)
    self.make_value_info(transposed, dtype, shape=shape)
def make_div(self, name, inputs, dtype, shape):
    """Add a `Div` node and register value info (`dtype`, `shape`) for `{name}/output_0`."""
    quotient = f"{name}/output_0"
    self.make_node("Div", inputs=inputs, outputs=[quotient], name=name)
    self.make_value_info(quotient, dtype, shape=shape)
def make_tanh(self, name, root_input, dtype, shape):
    """Add a `Tanh` node over `root_input` and register value info for `{name}/output_0`."""
    activated = f"{name}/output_0"
    self.make_node("Tanh", inputs=[root_input], outputs=[activated], name=name)
    self.make_value_info(activated, dtype, shape=shape)
def make_matmul(self, matmul, basename, root_input, **kwargs):
    """Dispatch MatMul creation to the LoRA or the regular builder.

    A layer exposing a `base_layer` attribute is treated as a LoRA `MatMul`;
    anything else goes through `make_matmul_op`. Returns whatever the chosen
    builder returns.
    """
    is_lora = hasattr(matmul, "base_layer")
    builder = self.make_matmul_lora if is_lora else self.make_matmul_op
    return builder(matmul, basename, root_input, **kwargs)
def make_matmul_op(self, matmul, basename, root_input, **kwargs):
    """Pick the MatMul builder that matches `self.onnx_dtype`.

    - "fp16" / "fp32": `make_matmul_fp16_or_fp32`
    - "int4": `make_matmul_int4_qdq` when `self.quant_attrs["use_qdq"]` is
      truthy, otherwise `make_matmul_int4`

    Raises:
        NotImplementedError: for any other precision.
    """
    if self.onnx_dtype in {"fp16", "fp32"}:
        builder = self.make_matmul_fp16_or_fp32
    elif self.onnx_dtype == "int4":
        builder = self.make_matmul_int4_qdq if self.quant_attrs["use_qdq"] else self.make_matmul_int4
    else:
        raise NotImplementedError(f"The {self.onnx_dtype} precision is not currently supported.")
    return builder(matmul, basename, root_input, **kwargs)
def make_matmul_fp16_or_fp32(self, matmul, name, root_input, **kwargs):
    """Emit a plain `MatMul` node whose weight is stored as an external tensor.

    The weight tensor name is derived from `name` (leading "/" dropped, "/"
    replaced by "."), and the weight is saved transposed and cast to the numpy
    dtype mapped from `self.io_dtype`. With `logits=True` in `kwargs` the node
    writes to "logits" instead of `{name}/output_0`.

    Returns:
        `name`, unchanged.
    """
    weight_name = name[1:].replace("/", ".") + ".weight"
    transposed = matmul.weight.detach().numpy().transpose()
    self.make_external_tensor(transposed.astype(self.to_numpy_dtype[self.io_dtype]), weight_name)

    # Rows of the (untransposed) weight give the size of the output's last dim
    out_dim = matmul.weight.shape[0]
    if kwargs.get("logits", False):
        result = "logits"
    else:
        result = f"{name}/output_0"

    self.make_node("MatMul", inputs=[root_input, weight_name], outputs=[result], name=name)
    self.make_value_info(result, self.io_dtype, shape=['batch_size', 'sequence_length', out_dim])
    return name
def make_matmul_int4(self, matmul, basename, root_input, **kwargs):
    """Emit a `MatMulNBits` (com.microsoft) node for a pre-quantized layer.

    Layers without a `qweight` attribute are not quantized here; they are
    emitted through `make_matmul_fp16_or_fp32` instead. For quantized layers
    the node is named `{basename}NBits`; `qweight` and `scales` are always
    saved as external tensors, while `qzeros` and `g_idx` are saved and wired
    in only when present and not None. With `logits=True` in `kwargs` the node
    writes to "logits".

    Returns:
        The name of the emitted node (or the fallback builder's return value).
    """
    if not hasattr(matmul, "qweight"):
        # TODO: quantize weights on-the-fly; for now the layer is saved in the
        # I/O precision and left for quantization at the end.
        return self.make_matmul_fp16_or_fp32(matmul, basename, root_input, **kwargs)

    name = f"{basename}NBits"
    tensor_prefix = name[1:].replace("/", ".")

    # Input weights are quantized, save quantized MatMul numpy weights for onnx model
    qweight_name = tensor_prefix + ".qweight"
    self.make_external_tensor(matmul.qweight.detach().numpy(), qweight_name)
    scales_name = tensor_prefix + ".scales"
    self.make_external_tensor(matmul.scales.detach().numpy().astype(self.to_numpy_dtype[self.io_dtype]), scales_name)
    node_inputs = [root_input, qweight_name, scales_name]

    # Optional inputs, appended in the order: zero points, then group indices
    if getattr(matmul, "qzeros", None) is not None:
        zeros_name = tensor_prefix + ".qzeros"
        self.make_external_tensor(matmul.qzeros.detach().numpy(), zeros_name)
        node_inputs.append(zeros_name)
    if getattr(matmul, "g_idx", None) is not None:
        g_idx_name = tensor_prefix + ".g_idx"
        self.make_external_tensor(matmul.g_idx.detach().numpy().astype(np.int32), g_idx_name)
        node_inputs.append(g_idx_name)

    result = "logits" if kwargs.get("logits", False) else f"{name}/output_0"
    self.make_node(
        "MatMulNBits", inputs=node_inputs, outputs=[result], name=name, domain="com.microsoft",
        accuracy_level=self.quant_attrs["int4"]["accuracy_level"],
        bits=matmul.bits, block_size=matmul.group_size, K=matmul.in_features, N=matmul.out_features,
    )
    self.make_value_info(result, self.io_dtype, shape=['batch_size', 'sequence_length', matmul.out_features])
    return name
def make_dequantize_linear(self, dequantize_name, quantized_op):
    """Emit a `DequantizeLinear` node for a pre-quantized op and return its output name.

    The quantized weight, the scales and (when present and not None) the zero
    points of `quantized_op` are saved as external tensors whose names are
    derived from `dequantize_name` (leading "/" dropped, "/" replaced by ".").

    Returns:
        `{dequantize_name}/output_0`, the name of the dequantized output.
    """
    # Input weights are quantized, save quantized MatMul numpy weights for onnx model
    qweight = dequantize_name[1:].replace("/", ".") + ".qweight"
    qweight_npy = quantized_op.qweight.detach().numpy()
    # Merge the last two dims of the quantized weight into a single dim
    qweight_npy = qweight_npy.reshape(*qweight_npy.shape[:-2], qweight_npy.shape[-2] * qweight_npy.shape[-1])
    # NOTE(review): the meaning of the third positional argument (True) is defined by
    # `make_external_tensor`, which is not visible here — confirm before changing.
    self.make_external_tensor(qweight_npy, qweight, True)
    scales = dequantize_name[1:].replace("/", ".") + ".scales"
    scales_npy = quantized_op.scales.detach().numpy().astype(self.to_numpy_dtype[self.io_dtype])
    # One scale per group along the last dim. The `* 2` presumably reflects two
    # 4-bit values packed per stored element — TODO confirm against the quantizer.
    scales_npy = scales_npy.reshape(*qweight_npy.shape[:-1], qweight_npy.shape[-1] * 2 // quantized_op.group_size)
    self.make_external_tensor(scales_npy, scales)
    dequantize_inputs = [qweight, scales]
    if hasattr(quantized_op, "qzeros") and quantized_op.qzeros is not None:
        zeros = dequantize_name[1:].replace("/", ".") + ".qzeros"
        zeros_npy = quantized_op.qzeros.detach().numpy()
        # Zero points keep the packed layout, hence no `* 2` here (presumably
        # packed the same way as the weights — TODO confirm).
        zeros_npy = zeros_npy.reshape(*qweight_npy.shape[:-1], qweight_npy.shape[-1] // quantized_op.group_size)
        self.make_external_tensor(zeros_npy, zeros, True)
        dequantize_inputs.append(zeros)
    dequantize_output = f"{dequantize_name}/output_0"
    # Dequantization is blocked along the last axis with one block per quantization group
    self.make_node("DequantizeLinear", inputs=dequantize_inputs, outputs=[dequantize_output], name=dequantize_name, block_size=quantized_op.group_size, axis=-1)
    # Declared output shape: the scales' leading dims, with the last dim expanded by the group size
    self.make_value_info(dequantize_output, self.io_dtype, shape=[*scales_npy.shape[:-1], scales_npy.shape[-1] * quantized_op.group_size])
    return dequantize_output
def make_matmul_int4_qdq(self, matmul, matmul_name, root_input, **kwargs):
    """Emit a DequantizeLinear -> Transpose -> MatMul chain for a pre-quantized layer.

    Layers without a `qweight` attribute are emitted through
    `make_matmul_fp16_or_fp32` instead. With `logits=True` in `kwargs` the
    final MatMul writes to "logits".

    Returns:
        `matmul_name` (or the fallback builder's return value).
    """
    if not hasattr(matmul, "qweight"):
        # TODO: quantize weights on-the-fly; for now the layer is saved in the
        # I/O precision and left for quantization at the end.
        return self.make_matmul_fp16_or_fp32(matmul, matmul_name, root_input, **kwargs)

    dequantized = self.make_dequantize_linear(f"{matmul_name}/DequantizeLinear", matmul)

    # Add a transpose instead of transposing the weights offline. The reason for this is that it is more natural and usually more performant to
    # compute quantized matmul when the weights are transposed. In most implementations, the transpose should usually be converted to a "transposeB"
    # attribute on the MatMul itself. A more natural way to represent this would have been to use Gemm since it already supports a transB attribute,
    # but unfortunately Gemm doesn't support batches.
    packed_shape = matmul.qweight.detach().numpy().shape
    weight_shape_after_transpose = [packed_shape[1] * packed_shape[2] * 2, packed_shape[0]]
    transpose_name = f"{matmul_name}/Transpose"
    self.make_transpose(transpose_name, dequantized, self.io_dtype, weight_shape_after_transpose, [1, 0])

    if kwargs.get("logits", False):
        result = "logits"
    else:
        result = f"{matmul_name}/output_0"
    self.make_node("MatMul", inputs=[root_input, f"{transpose_name}/output_0"], outputs=[result], name=matmul_name)
    self.make_value_info(result, self.io_dtype, shape=['batch_size', 'sequence_length', matmul.out_features])
    return matmul_name
def make_matmul_lora(self, matmul, basename, root_input, **kwargs):
    """Emit the MatMul-LoRA subgraph and return the name of its final Add node.

                root_input
                    |
             +------+------+
             |             |
       MatMul_LoRA_A     MatMul
             |             |
       MatMul_LoRA_B       |
             |             |
             +------+------+
                    |
               Add_LoRA_Add

    Note: the "default" LoRA B weight is multiplied in place by
    `matmul.scaling["default"]` before it is emitted, so the passed-in layer
    is modified by this call.
    """
    *parent_parts, leaf = basename.split("/")

    def sibling(tag):
        # Insert `tag` just before the last path component of `basename`
        return "/".join(parent_parts + [tag, leaf])

    # LoRA path: root_input -> A -> B
    lora_a_node = self.make_matmul_op(matmul.lora_A.default, sibling("lora_A"), root_input=root_input)
    lora_a_output = f"{lora_a_node}/output_0"

    # Fold the LoRA scaling into the B weight (in place) before saving it
    matmul.lora_B.default.weight *= matmul.scaling["default"]
    lora_b_node = self.make_matmul_op(matmul.lora_B.default, sibling("lora_B"), root_input=lora_a_output)
    lora_b_output = f"{lora_b_node}/output_0"

    # Regular path: root_input -> base layer
    out_dim = matmul.base_layer.weight.shape[0]
    base_node = self.make_matmul_op(matmul.base_layer, basename, root_input, **kwargs)

    # Sum of both paths
    add_name = "/".join(parent_parts + ["lora", "Add"])
    self.make_add(
        add_name,
        [f"{base_node}/output_0", lora_b_output],
        dtype=self.io_dtype,
        shape=["batch_size", "sequence_length", out_dim],
    )
    return add_name
def make_packed_matmul(self, q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs):
    """Pick the packed Q/K/V MatMul builder that matches `self.onnx_dtype`.

    - "fp16" / "fp32": `make_packed_matmul_fp16_or_fp32`
    - "int4": `make_packed_matmul_int4`

    Raises:
        NotImplementedError: for any other precision.
    """
    if self.onnx_dtype in {"fp16", "fp32"}:
        builder = self.make_packed_matmul_fp16_or_fp32
    elif self.onnx_dtype == "int4":
        builder = self.make_packed_matmul_int4
    else:
        raise NotImplementedError(f"The {self.onnx_dtype} precision is not currently supported.")
    return builder(q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs)
def make_packed_matmul_fp16_or_fp32(self, q_matmul, k_matmul, v_matmul, name, root_input, **kwargs):
    """Pack the Q, K and V weights into one MatMul and emit it via `make_matmul`.

    N_q = num_attention_heads * head_size, N_kv = num_key_value_heads * head_size, H = hidden_size.
    The three weights of shape N_q x H, N_kv x H, N_kv x H are stacked into a
    single (N_q+N_kv+N_kv) x H weight. The packed weight is kept as
    (N_q+N_kv+N_kv) x H rather than H x (N_q+N_kv+N_kv) because `make_matmul`
    will apply a transpose before saving.

    Returns:
        Whatever `make_matmul` returns for the packed layer.
    """
    q_rows, hidden = q_matmul.weight.shape
    kv_rows, _ = k_matmul.weight.shape

    # Minimal stand-in exposing only the `weight` attribute of the packed layer
    class PackedMatMul:
        def __init__(self, weight):
            self.weight = weight

    stacked = torch.concatenate(
        [layer.weight.detach().cpu() for layer in (q_matmul, k_matmul, v_matmul)],
        dim=0,
    )
    packed = PackedMatMul(stacked.reshape(q_rows + kv_rows + kv_rows, hidden))
    return self.make_matmul(packed, name, root_input, **kwargs)
def make_packed_matmul_int4(self, q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs):
    """Pack pre-quantized Q, K and V layers into one `MatMulNBits` node.

    When `q_matmul` has no `qweight` attribute the layers are not quantized and
    the packed MatMul is emitted through `make_packed_matmul_fp16_or_fp32`.
    Otherwise `qweight` and `scales` (and `qzeros`, when every layer provides
    them) are concatenated along dim 0 and the packed layer is emitted through
    `make_matmul_int4`, which names the node `{basename}NBits`.

    `bits`, `group_size`, `in_features` and `g_idx` are taken from `q_matmul`;
    `out_features` is the sum over the three layers.

    Returns:
        The name returned by the builder that emitted the node.

    Raises:
        ValueError: if only some of the three layers provide zero points.
    """
    if not hasattr(q_matmul, "qweight"):
        # TODO: quantize weights on-the-fly; for now the layers are saved in the
        # I/O precision and left for quantization at the end.
        return self.make_packed_matmul_fp16_or_fp32(q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs)

    layers = (q_matmul, k_matmul, v_matmul)

    def pack(attr):
        return torch.concatenate([getattr(layer, attr).detach().cpu() for layer in layers], dim=0)

    # Zero points are optional: only pack them when all three layers have them,
    # and refuse to silently drop them when the layers disagree.
    zeros_present = [getattr(layer, "qzeros", None) is not None for layer in layers]
    if any(zeros_present) and not all(zeros_present):
        raise ValueError("Cannot pack Q/K/V MatMuls when only some of them have zero points (qzeros).")

    # Create dummy PackedMatMul class
    class PackedMatMul:
        def __init__(self):
            self.qweight = pack("qweight")
            self.scales = pack("scales")
            self.qzeros = pack("qzeros") if all(zeros_present) else None
            self.g_idx = getattr(q_matmul, "g_idx", None)

            self.in_features = q_matmul.in_features
            self.out_features = q_matmul.out_features + k_matmul.out_features + v_matmul.out_features
            self.bits = q_matmul.bits
            self.group_size = q_matmul.group_size

    # The packed layer has `qweight`, so `make_matmul_int4` takes its quantized
    # path and emits the same tensors and MatMulNBits node this method used to
    # build by hand.
    return self.make_matmul_int4(PackedMatMul(), basename, root_input, **kwargs)
def make_add_bias(self, add, name, root_input, **kwargs):
    """Emit an `Add` node that adds the bias array `add` to `root_input`.

    The bias is cast to the numpy dtype mapped from `self.io_dtype` and saved
    as an external tensor whose name is derived from `name` (leading "/"
    dropped, "/" replaced by ".", ".bias" appended).

    With a truthy `logits` entry in `kwargs` the node writes to "logits";
    otherwise it writes to `{name}/output_0` through `make_add`. The flag is
    read with `kwargs.get("logits", False)` like the MatMul builders, so an
    explicit `logits=False` no longer routes the output to "logits".
    """
    bias = name[1:].replace("/", ".") + ".bias"
    self.make_external_tensor(add.astype(self.to_numpy_dtype[self.io_dtype]), bias)

    add_bias_inputs = [root_input, bias]
    shape = ['batch_size', 'sequence_length', add.shape[0]]

    if kwargs.get("logits", False):
        output = "logits"
        self.make_node("Add", inputs=add_bias_inputs, outputs=[output], name=name)
        self.make_value_info(output, dtype=self.io_dtype, shape=shape)
    else:
        self.make_add(name, add_bias_inputs, dtype=self.io_dtype, shape=shape)
def make_packed_add(self, q_add, k_add, v_add, name, root_input, **kwargs):
    """Pack the Q, K and V biases into one bias and emit it via `make_add_bias`.

    The three Adds of shape N_q, N_kv and N_kv are concatenated along axis 0
    and flattened into a single Add of shape N_q + N_kv + N_kv.
    """
    packed_bias = np.concatenate([q_add, k_add, v_add], axis=0)
    packed_bias = packed_bias.flatten()
    self.make_add_bias(packed_bias, name, root_input, **kwargs)
def make_embedding(self, embedding):
    """Emit the token-embedding lookup and record its output for the first layernorm.

    The embedding matrix is cast to the numpy dtype mapped from `self.io_dtype`
    and saved as the external tensor "model.embed_tokens.weight"; a `Gather`
    node looks it up with "input_ids". When `self.embed_attrs["scale"]` is not
    1, a `Mul` node scales the gathered embeddings by that constant. The name of
    the final output is stored as both "root_input" and "skip_input" in
    `self.layernorm_attrs`.
    """
    weight_name = "model.embed_tokens.weight"
    self.make_external_tensor(embedding.astype(self.to_numpy_dtype[self.io_dtype]), weight_name)

    prefix = "/model/embed_tokens"
    gather_name = f"{prefix}/Gather"
    embed_output = f"{gather_name}/output_0"
    self.make_node('Gather', inputs=[weight_name, 'input_ids'], outputs=[embed_output], name=gather_name)
    self.make_value_info(embed_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])

    scale = self.embed_attrs["scale"]
    if scale != 1:
        # Scale the embeddings
        mul_name = f"{prefix}/Mul"
        scale_constant = f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/{scale}"
        scaled_output = f"{mul_name}/output_0"
        self.make_node('Mul', inputs=[embed_output, scale_constant], outputs=[scaled_output], name=mul_name)
        self.make_value_info(scaled_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
        embed_output = scaled_output

    for key in ("root_input", "skip_input"):
        self.layernorm_attrs[key] = embed_output