
Commit c0d2315

vbaddi authored and meetkuma committed

refactor the finetune main __call__

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

1 parent 7b64b33 commit c0d2315

File tree

8 files changed, +385 -139 lines changed


QEfficient/cloud/finetune.py (+175 -104)
@@ -7,6 +7,7 @@
 
 import random
 import warnings
+from typing import Optional
 
 import fire
 import numpy as np
@@ -17,13 +18,17 @@
 import torch.utils.data
 from peft import PeftModel, get_peft_model
 from torch.optim.lr_scheduler import StepLR
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG
+from QEfficient.finetune.configs.peft_config import LoraConfig
+from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.utils.config_utils import (
     generate_dataset_config,
     generate_peft_config,
     get_dataloader_kwargs,
+    load_config_file,
     update_config,
+    validate_config,
 )
 from QEfficient.finetune.utils.dataset_utils import (
     get_custom_data_collator,
@@ -32,114 +37,134 @@
 from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
 from QEfficient.utils._utils import login_and_download_hf_lm
 
+# Try importing QAIC-specific module, proceed without it if unavailable
 try:
     import torch_qaic  # noqa: F401
 except ImportError as e:
-    print(f"Warning: {e}. Moving ahead without these qaic modules.")
+    print(f"Warning: {e}. Proceeding without QAIC modules.")
 
+# Suppress all warnings for cleaner output
+warnings.filterwarnings("ignore")
 
-from transformers import AutoModelForCausalLM, AutoTokenizer
 
-# Suppress all warnings
-warnings.filterwarnings("ignore")
+def setup_distributed_training(config: TrainConfig) -> None:
+    """Initialize distributed training environment if enabled.
 
+    Args:
+        config (TrainConfig): Training configuration object.
 
-def main(**kwargs):
+    Notes:
+        - If distributed data parallel (DDP) is disabled, this function does nothing.
+        - Ensures the device is not CPU and does not specify an index for DDP compatibility.
+        - Initializes the process group using the specified distributed backend.
+
+    Raises:
+        AssertionError: If device is CPU or includes an index with DDP enabled.
     """
-    Helper function to finetune the model on QAic.
+    if not config.enable_ddp:
+        return
+
+    torch_device = torch.device(config.device)
+    assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
+    assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
+
+    dist.init_process_group(backend=config.dist_backend)
+    # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
+    getattr(torch, torch_device.type).set_device(dist.get_rank())
 
-    .. code-block:: bash
 
-        python -m QEfficient.cloud.finetune OPTIONS
+def setup_seeds(seed: int) -> None:
+    """Set random seeds across libraries for reproducibility.
 
+    Args:
+        seed (int): Seed value to set for random number generators.
+
+    Notes:
+        - Sets seeds for PyTorch, Python's random module, and NumPy.
     """
-    # update the configuration for the training process
-    train_config = TRAIN_CONFIG()
-    update_config(train_config, **kwargs)
-    device = train_config.device
+    torch.manual_seed(seed)
+    random.seed(seed)
+    np.random.seed(seed)
 
-    # dist init
-    if train_config.enable_ddp:
-        # TODO: may have to init qccl backend, next try run with torchrun command
-        torch_device = torch.device(device)
-        assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
-        assert torch_device.index is None, (
-            f"DDP requires specification of device type only, however provided device index as well: {torch_device}"
-        )
-        dist.init_process_group(backend=train_config.dist_backend)
-        # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
-        getattr(torch, torch_device.type).set_device(dist.get_rank())
-
-    # Set the seeds for reproducibility
-    torch.manual_seed(train_config.seed)
-    random.seed(train_config.seed)
-    np.random.seed(train_config.seed)
-
-    # Load the pre-trained model and setup its configuration
-    # config = AutoConfig.from_pretrained(train_config.model_name)
-    pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
+
+def load_model_and_tokenizer(config: TrainConfig) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
+    """Load the pre-trained model and tokenizer from Hugging Face.
+
+    Args:
+        config (TrainConfig): Training configuration object containing model and tokenizer names.
+
+    Returns:
+        tuple: A tuple containing the loaded model (AutoModelForCausalLM) and tokenizer (AutoTokenizer).
+
+    Notes:
+        - Downloads the model if not already cached using login_and_download_hf_lm.
+        - Configures the model with FP16 precision and disables caching for training.
+        - Resizes model embeddings if tokenizer vocab size exceeds model embedding size.
+        - Sets pad_token_id to eos_token_id if not defined in the tokenizer.
+    """
+    pretrained_model_path = login_and_download_hf_lm(config.model_name)
     model = AutoModelForCausalLM.from_pretrained(
         pretrained_model_path,
         use_cache=False,
         attn_implementation="sdpa",
         torch_dtype=torch.float16,
     )
 
-    # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(
-        train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
+        config.model_name if config.tokenizer_name is None else config.tokenizer_name
     )
     if not tokenizer.pad_token_id:
         tokenizer.pad_token_id = tokenizer.eos_token_id
 
-    # If there is a mismatch between tokenizer vocab size and embedding matrix,
-    # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+        print("WARNING: Resizing embedding matrix to match tokenizer vocab size.")
         model.resize_token_embeddings(len(tokenizer))
 
-    print_model_size(model, train_config)
+    return model, tokenizer
 
-    # print the datatype of the model parameters
-    # print(get_parameter_dtypes(model))
-
-    if train_config.use_peft:
-        # Load the pre-trained peft model checkpoint and setup its configuration
-        if train_config.from_peft_checkpoint:
-            model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
-            peft_config = model.peft_config
-        # Generate the peft config and start fine-tuning from original model
-        else:
-            peft_config = generate_peft_config(train_config, kwargs)
-            model = get_peft_model(model, peft_config)
-            model.print_trainable_parameters()
-
-    # Get the dataset utils
-    dataset_config = generate_dataset_config(train_config, kwargs)
-    dataset_processer = tokenizer
 
-    # Load and preprocess the dataset for training and validation
-    dataset_train = get_preprocessed_dataset(
-        dataset_processer, dataset_config, split="train", context_length=train_config.context_length
-    )
+def apply_peft(model: AutoModelForCausalLM, train_config: TrainConfig, lora_config: LoraConfig) -> PeftModel:
+    """Apply Parameter-Efficient Fine-Tuning (PEFT) to the model if enabled."""
+    if not train_config.use_peft:
+        return model
 
-    dataset_val = get_preprocessed_dataset(
-        dataset_processer, dataset_config, split="test", context_length=train_config.context_length
-    )
+    if train_config.from_peft_checkpoint:
+        return PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
+
+    # Generate PEFT-compatible config from custom LoraConfig
+    peft_config = generate_peft_config(train_config, lora_config)
+    model = get_peft_model(model, peft_config)
+    model.print_trainable_parameters()
+    return model
+
+
+def setup_dataloaders(
+    train_config: TrainConfig, dataset_config, tokenizer: AutoTokenizer, dataset_train, dataset_val
+) -> tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader]]:
+    """Set up training and validation DataLoaders.
+
+    Args:
+        train_config (TrainConfig): Training configuration object.
+        dataset_config: Configuration for the dataset (generated from train_config).
+        tokenizer (AutoTokenizer): Tokenizer for preprocessing data.
+        dataset_train: Preprocessed training dataset.
+        dataset_val: Preprocessed validation dataset.
 
-    # TODO: vbaddi, check if its necessary to do this?
-    # dataset_train = ConcatDataset(
-    #     dataset_train, chunk_size=train_config.context_length
-    # )
-    ##
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
-    print("length of dataset_train", len(dataset_train))
-    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
+    Returns:
+        tuple: A tuple of (train_dataloader, eval_dataloader), where eval_dataloader is None if validation is disabled.
+
+    Raises:
+        ValueError: If validation is enabled but the validation set is too small.
+
+    Notes:
+        - Applies a custom data collator if provided by get_custom_data_collator.
+        - Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits.
+    """
+    custom_data_collator = get_custom_data_collator(tokenizer, dataset_config)
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
     if custom_data_collator:
-        print("custom_data_collator is used")
         train_dl_kwargs["collate_fn"] = custom_data_collator
 
-    # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
         num_workers=train_config.num_workers_dataloader,
@@ -150,12 +175,7 @@ def main(**kwargs):
 
     eval_dataloader = None
     if train_config.run_validation:
-        # if train_config.batching_strategy == "packing":
-        #     dataset_val = ConcatDataset(
-        #         dataset_val, chunk_size=train_config.context_length
-        #     )
-
-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
         if custom_data_collator:
             val_dl_kwargs["collate_fn"] = custom_data_collator
 
@@ -165,37 +185,90 @@ def main(**kwargs):
             pin_memory=True,
             **val_dl_kwargs,
         )
+        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         if len(eval_dataloader) == 0:
-            raise ValueError(
-                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
-            )
-        else:
-            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
-
-        longest_seq_length, _ = get_longest_seq_length(
-            torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
-        )
-    else:
-        longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset)
+            raise ValueError("Eval set too small to load even one batch.")
+
+    return train_dataloader, eval_dataloader
+
 
+def main(
+    model_name: str = None,
+    tokenizer_name: str = None,
+    batch_size_training: int = None,
+    lr: float = None,
+    peft_config_file: str = None,
+    **kwargs,
+) -> None:
+    """
+    Fine-tune a model on QAIC hardware with configurable training and LoRA parameters.
+
+    Args:
+        model_name (str, optional): Override default model name.
+        tokenizer_name (str, optional): Override default tokenizer name.
+        batch_size_training (int, optional): Override default training batch size.
+        lr (float, optional): Override default learning rate.
+        peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config.
+        **kwargs: Additional arguments to override TrainConfig.
+
+    Example:
+        .. code-block:: bash
+
+            # Using a YAML config file for PEFT
+            python -m QEfficient.cloud.finetune \\
+                --model_name "meta-llama/Llama-3.2-1B" \\
+                --lr 5e-4 \\
+                --peft_config_file "lora_config.yaml"
+
+            # Using default LoRA config
+            python -m QEfficient.cloud.finetune \\
+                --model_name "meta-llama/Llama-3.2-1B" \\
+                --lr 5e-4
+    """
+    train_config = TrainConfig()
+    # local_args = {k: v for k, v in locals().items() if v is not None and k != "peft_config_file" and k != "kwargs"}
+    update_config(train_config, **kwargs)
+
+    lora_config = LoraConfig()
+    if peft_config_file:
+        peft_config_data = load_config_file(peft_config_file)
+        validate_config(peft_config_data, config_type="lora")
+        lora_config = LoraConfig(**peft_config_data)
+
+    setup_distributed_training(train_config)
+    setup_seeds(train_config.seed)
+    model, tokenizer = load_model_and_tokenizer(train_config)
+    print_model_size(model, train_config)
+    model = apply_peft(model, train_config, lora_config)
+
+    # Pass an empty dict instead of kwargs to avoid irrelevant parameters
+    dataset_config = generate_dataset_config(train_config, kwargs)
+    dataset_train = get_preprocessed_dataset(
+        tokenizer, dataset_config, split="train", context_length=train_config.context_length
+    )
+    dataset_val = get_preprocessed_dataset(
+        tokenizer, dataset_config, split="test", context_length=train_config.context_length
+    )
+    train_dataloader, eval_dataloader = setup_dataloaders(
+        train_config, dataset_config, tokenizer, dataset_train, dataset_val
+    )
+    dataset_for_seq_length = (
+        torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
+        if train_config.run_validation
+        else train_dataloader.dataset
+    )
+    longest_seq_length, _ = get_longest_seq_length(dataset_for_seq_length)
     print(
-        f"The longest sequence length in the train data is {longest_seq_length}, "
-        f"passed context length is {train_config.context_length} and overall model's context length is "
-        f"{model.config.max_position_embeddings}"
+        f"Longest sequence length: {longest_seq_length}, "
+        f"Context length: {train_config.context_length}, "
+        f"Model max context: {model.config.max_position_embeddings}"
     )
     model.to(train_config.device)
-    optimizer = optim.AdamW(
-        model.parameters(),
-        lr=train_config.lr,
-        weight_decay=train_config.weight_decay,
-    )
+    optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-
-    # wrap model with DDP
     if train_config.enable_ddp:
         model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
-
-    _ = train(
+    train(
         model,
         train_dataloader,
         eval_dataloader,
@@ -208,8 +281,6 @@ def main(**kwargs):
         dist.get_rank() if train_config.enable_ddp else None,
         None,
     )
-
-    # finalize torch distributed
     if train_config.enable_ddp:
         dist.destroy_process_group()
 
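For context, below is a minimal sketch (not part of this commit) of how the peft_config_file path added to main() might be exercised end to end. The YAML keys shown (r, lora_alpha, lora_dropout) are assumptions that mirror peft's LoraConfig; the accepted schema is whatever QEfficient.finetune.configs.peft_config.LoraConfig and validate_config(..., config_type="lora") actually allow.

# Hedged sketch: drive the refactored entry point programmatically.
# Assumes QEfficient is installed and a QAIC device is available.
# The YAML keys below are illustrative, mirroring peft's LoraConfig;
# see QEfficient/finetune/configs/peft_config.py for the real schema.
from pathlib import Path

from QEfficient.cloud.finetune import main

Path("lora_config.yaml").write_text(
    "r: 8\n"
    "lora_alpha: 32\n"
    "lora_dropout: 0.05\n"
)

# peft_config_file is consumed directly by main(); other keyword arguments
# (for example device) fall into **kwargs and are applied to TrainConfig
# through update_config.
main(peft_config_file="lora_config.yaml", device="qaic")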