import random
import warnings
+ from typing import Optional

import fire
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from peft import PeftModel, get_peft_model
from torch.optim.lr_scheduler import StepLR
+ from transformers import AutoModelForCausalLM, AutoTokenizer

- from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG
+ from QEfficient.finetune.configs.peft_config import LoraConfig
+ from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.utils.config_utils import (
    generate_dataset_config,
    generate_peft_config,
    get_dataloader_kwargs,
+     load_config_file,
    update_config,
+     validate_config,
)
from QEfficient.finetune.utils.dataset_utils import (
    get_custom_data_collator,
    get_preprocessed_dataset,
)
from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
from QEfficient.utils._utils import login_and_download_hf_lm
+ # Try importing QAIC-specific module, proceed without it if unavailable
try:
    import torch_qaic  # noqa: F401
except ImportError as e:
-     print(f"Warning: {e}. Moving ahead without these qaic modules.")
+     print(f"Warning: {e}. Proceeding without QAIC modules.")

+ # Suppress all warnings for cleaner output
+ warnings.filterwarnings("ignore")

- from transformers import AutoModelForCausalLM, AutoTokenizer

- # Suppress all warnings
- warnings.filterwarnings("ignore")
+ def setup_distributed_training(config: TrainConfig) -> None:
+     """Initialize distributed training environment if enabled.

+     Args:
+         config (TrainConfig): Training configuration object.

- def main(**kwargs):
+     Notes:
+         - If distributed data parallel (DDP) is disabled, this function does nothing.
+         - Asserts that the device is not CPU and carries no index, for DDP compatibility.
+         - Initializes the process group using the specified distributed backend.
+
+     Raises:
+         AssertionError: If device is CPU or includes an index with DDP enabled.
    """
-     Helper function to finetune the model on QAic.
+     if not config.enable_ddp:
+         return
+
+     torch_device = torch.device(config.device)
+     assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
+     assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
+
+     dist.init_process_group(backend=config.dist_backend)
+     # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
+     getattr(torch, torch_device.type).set_device(dist.get_rank())

-     .. code-block:: bash

-         python -m QEfficient.cloud.finetune OPTIONS
+ def setup_seeds(seed: int) -> None:
+     """Set random seeds across libraries for reproducibility.

+     Args:
+         seed (int): Seed value to set for random number generators.
+
+     Notes:
+         - Sets seeds for PyTorch, Python's random module, and NumPy.
    """
-     # update the configuration for the training process
-     train_config = TRAIN_CONFIG()
-     update_config(train_config, **kwargs)
-     device = train_config.device
+     torch.manual_seed(seed)
+     random.seed(seed)
+     np.random.seed(seed)

-     # dist init
-     if train_config.enable_ddp:
-         # TODO: may have to init qccl backend, next try run with torchrun command
-         torch_device = torch.device(device)
-         assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
-         assert torch_device.index is None, (
-             f"DDP requires specification of device type only, however provided device index as well: {torch_device}"
-         )
-         dist.init_process_group(backend=train_config.dist_backend)
-         # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
-         getattr(torch, torch_device.type).set_device(dist.get_rank())
-
-     # Set the seeds for reproducibility
-     torch.manual_seed(train_config.seed)
-     random.seed(train_config.seed)
-     np.random.seed(train_config.seed)
-
-     # Load the pre-trained model and setup its configuration
-     # config = AutoConfig.from_pretrained(train_config.model_name)
-     pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
+
+ def load_model_and_tokenizer(config: TrainConfig) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
+     """Load the pre-trained model and tokenizer from Hugging Face.
+
+     Args:
+         config (TrainConfig): Training configuration object containing model and tokenizer names.
+
+     Returns:
+         tuple: A tuple containing the loaded model (AutoModelForCausalLM) and tokenizer (AutoTokenizer).
+
+     Notes:
+         - Downloads the model if not already cached using login_and_download_hf_lm.
+         - Configures the model with FP16 precision and disables caching for training.
+         - Resizes model embeddings if tokenizer vocab size exceeds model embedding size.
+         - Sets pad_token_id to eos_token_id if not defined in the tokenizer.
+     """
+     pretrained_model_path = login_and_download_hf_lm(config.model_name)
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_path,
        use_cache=False,
        attn_implementation="sdpa",
        torch_dtype=torch.float16,
    )

-     # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
-         train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
+         config.model_name if config.tokenizer_name is None else config.tokenizer_name
    )
    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id

-     # If there is a mismatch between tokenizer vocab size and embedding matrix,
-     # throw a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-         print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+         print("WARNING: Resizing embedding matrix to match tokenizer vocab size.")
        model.resize_token_embeddings(len(tokenizer))

-     print_model_size(model, train_config)
+     return model, tokenizer

-     # print the datatype of the model parameters
-     # print(get_parameter_dtypes(model))
-
-     if train_config.use_peft:
-         # Load the pre-trained peft model checkpoint and setup its configuration
-         if train_config.from_peft_checkpoint:
-             model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
-             peft_config = model.peft_config
-         # Generate the peft config and start fine-tuning from original model
-         else:
-             peft_config = generate_peft_config(train_config, kwargs)
-             model = get_peft_model(model, peft_config)
-         model.print_trainable_parameters()
-
-     # Get the dataset utils
-     dataset_config = generate_dataset_config(train_config, kwargs)
-     dataset_processer = tokenizer

-     # Load and preprocess the dataset for training and validation
-     dataset_train = get_preprocessed_dataset(
-         dataset_processer, dataset_config, split="train", context_length=train_config.context_length
-     )
+ def apply_peft(model: AutoModelForCausalLM, train_config: TrainConfig, lora_config: LoraConfig) -> PeftModel:
+     """Apply Parameter-Efficient Fine-Tuning (PEFT) to the model if enabled."""
+     if not train_config.use_peft:
+         return model

-     dataset_val = get_preprocessed_dataset(
-         dataset_processer, dataset_config, split="test", context_length=train_config.context_length
-     )
+     if train_config.from_peft_checkpoint:
+         return PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
+
+     # Generate PEFT-compatible config from custom LoraConfig
+     peft_config = generate_peft_config(train_config, lora_config)
+     model = get_peft_model(model, peft_config)
+     model.print_trainable_parameters()
+     return model
+
+
+ def setup_dataloaders(
+     train_config: TrainConfig, dataset_config, tokenizer: AutoTokenizer, dataset_train, dataset_val
+ ) -> tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader]]:
+     """Set up training and validation DataLoaders.
+
+     Args:
+         train_config (TrainConfig): Training configuration object.
+         dataset_config: Configuration for the dataset (generated from train_config).
+         tokenizer (AutoTokenizer): Tokenizer for preprocessing data.
+         dataset_train: Preprocessed training dataset.
+         dataset_val: Preprocessed validation dataset.

-     # TODO: vbaddi, check if its necessary to do this?
-     # dataset_train = ConcatDataset(
-     #     dataset_train, chunk_size=train_config.context_length
-     # )
-     ##
-     train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
-     print("length of dataset_train", len(dataset_train))
-     custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
+     Returns:
+         tuple: A tuple of (train_dataloader, eval_dataloader), where eval_dataloader is None if validation is disabled.
+
+     Raises:
+         ValueError: If validation is enabled but the validation set is too small.
+
+     Notes:
+         - Applies a custom data collator if provided by get_custom_data_collator.
+         - Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits.
+     """
+     custom_data_collator = get_custom_data_collator(tokenizer, dataset_config)
+     train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
    if custom_data_collator:
-         print("custom_data_collator is used")
        train_dl_kwargs["collate_fn"] = custom_data_collator

-     # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
@@ -150,12 +175,7 @@ def main(**kwargs):

    eval_dataloader = None
    if train_config.run_validation:
-         # if train_config.batching_strategy == "packing":
-         #     dataset_val = ConcatDataset(
-         #         dataset_val, chunk_size=train_config.context_length
-         #     )
-
-         val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
+         val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
        if custom_data_collator:
            val_dl_kwargs["collate_fn"] = custom_data_collator

@@ -165,37 +185,90 @@ def main(**kwargs):
            pin_memory=True,
            **val_dl_kwargs,
        )
+         print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
        if len(eval_dataloader) == 0:
-             raise ValueError(
-                 f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
-             )
-         else:
-             print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
-
-         longest_seq_length, _ = get_longest_seq_length(
-             torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
-         )
-     else:
-         longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset)
+             raise ValueError("Eval set too small to load even one batch.")
+
+     return train_dataloader, eval_dataloader
+

+ def main(
+     model_name: str = None,
+     tokenizer_name: str = None,
+     batch_size_training: int = None,
+     lr: float = None,
+     peft_config_file: str = None,
+     **kwargs,
+ ) -> None:
+     """
+     Fine-tune a model on QAIC hardware with configurable training and LoRA parameters.
+
+     Args:
+         model_name (str, optional): Override default model name.
+         tokenizer_name (str, optional): Override default tokenizer name.
+         batch_size_training (int, optional): Override default training batch size.
+         lr (float, optional): Override default learning rate.
+         peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config.
+         **kwargs: Additional arguments to override TrainConfig.
+
+     Example:
+         .. code-block:: bash
+
+             # Using a YAML config file for PEFT
+             python -m QEfficient.cloud.finetune \\
+                 --model_name "meta-llama/Llama-3.2-1B" \\
+                 --lr 5e-4 \\
+                 --peft_config_file "lora_config.yaml"
+
+             # Using default LoRA config
+             python -m QEfficient.cloud.finetune \\
+                 --model_name "meta-llama/Llama-3.2-1B" \\
+                 --lr 5e-4
+     """
+     train_config = TrainConfig()
+     # local_args = {k: v for k, v in locals().items() if v is not None and k != "peft_config_file" and k != "kwargs"}
+     update_config(train_config, **kwargs)
+
+     lora_config = LoraConfig()
+     if peft_config_file:
+         peft_config_data = load_config_file(peft_config_file)
+         validate_config(peft_config_data, config_type="lora")
+         lora_config = LoraConfig(**peft_config_data)
+
+     setup_distributed_training(train_config)
+     setup_seeds(train_config.seed)
+     model, tokenizer = load_model_and_tokenizer(train_config)
+     print_model_size(model, train_config)
+     model = apply_peft(model, train_config, lora_config)
+
+     # Generate dataset config from train_config and any remaining kwargs overrides
+     dataset_config = generate_dataset_config(train_config, kwargs)
+     dataset_train = get_preprocessed_dataset(
+         tokenizer, dataset_config, split="train", context_length=train_config.context_length
+     )
+     dataset_val = get_preprocessed_dataset(
+         tokenizer, dataset_config, split="test", context_length=train_config.context_length
+     )
+     train_dataloader, eval_dataloader = setup_dataloaders(
+         train_config, dataset_config, tokenizer, dataset_train, dataset_val
+     )
+     dataset_for_seq_length = (
+         torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
+         if train_config.run_validation
+         else train_dataloader.dataset
+     )
+     longest_seq_length, _ = get_longest_seq_length(dataset_for_seq_length)
    print(
-         f"The longest sequence length in the train data is {longest_seq_length}, "
-         f"passed context length is {train_config.context_length} and overall model's context length is "
-         f"{model.config.max_position_embeddings}"
+         f"Longest sequence length: {longest_seq_length}, "
+         f"Context length: {train_config.context_length}, "
+         f"Model max context: {model.config.max_position_embeddings}"
    )
    model.to(train_config.device)
-     optimizer = optim.AdamW(
-         model.parameters(),
-         lr=train_config.lr,
-         weight_decay=train_config.weight_decay,
-     )
+     optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-
-     # wrap model with DDP
    if train_config.enable_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
-
-     _ = train(
+     train(
        model,
        train_dataloader,
        eval_dataloader,
@@ -208,8 +281,6 @@ def main(**kwargs):
        dist.get_rank() if train_config.enable_ddp else None,
        None,
    )
-
-     # finalize torch distributed
    if train_config.enable_ddp:
        dist.destroy_process_group()
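
For reference, a minimal usage sketch of the refactored entry point (not part of the diff): it assumes `QEfficient` is installed, that the module's `fire.Fire(main)` entry point (outside this hunk) is unchanged, and that `lora_config.yaml` is a hypothetical file whose keys match `QEfficient.finetune.configs.peft_config.LoraConfig`. The keyword arguments simply mirror the `main()` signature and `TrainConfig` overrides introduced above.

```python
# Sketch only: drives the refactored main() programmatically instead of via the CLI.
# "lora_config.yaml" is a hypothetical LoRA config file; omit peft_config_file to use
# the default LoraConfig, exactly as in the docstring example above.
from QEfficient.cloud.finetune import main

main(
    model_name="meta-llama/Llama-3.2-1B",  # overrides TrainConfig.model_name
    lr=5e-4,                               # overrides TrainConfig.lr
    peft_config_file="lora_config.yaml",   # optional custom LoRA settings
)
```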