[QEff Finetune]: Refactor the finetune main __call__ #289
base: main
Conversation
Force-pushed from 2f19722 to 48061ee
Force-pushed from 3ff66eb to c0d2315
copy of #314
Force-pushed from 7f2d367 to b2ee39a
@pytest.mark.on_qaic
@pytest.mark.skip(reason="eager docker not available in sdk")
@pytest.mark.parametrize(
Please add the pytest markers @pytest.mark.finetune and @pytest.mark.cli; it will help in executing the tests in stages. For example:
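A minimal sketch of what that could look like (assuming "finetune" and "cli" are registered as custom markers in the project's pytest configuration; the test name and body are illustrative):

import pytest

@pytest.mark.on_qaic
@pytest.mark.finetune
@pytest.mark.cli
@pytest.mark.skip(reason="eager docker not available in sdk")
def test_finetune_cli():
    # Staged execution: run only this stage with `pytest -m finetune` or `pytest -m cli`.
    ...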
-    finetune(**kwargs)
+    results = finetune(**kwargs)

     assert np.allclose(results["avg_train_prep"], 1.002326, atol=1e-5), "Train perplexity is not matching."
avg_train_prep should be changed to avg_train_metric to match the changes in PR 292.
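A hedged sketch of the updated assertion, assuming PR 292 renames the returned key to avg_train_metric (expected value and tolerance kept from the current test):

import numpy as np

# `results` is the dict returned by `results = finetune(**kwargs)` above; key name assumed per PR 292.
assert np.allclose(results["avg_train_metric"], 1.002326, atol=1e-5), "Train metric is not matching."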
@@ -40,7 +40,7 @@ def train(
     optimizer,
     lr_scheduler,
     gradient_accumulation_steps,
-    train_config: TRAIN_CONFIG,
+    train_config: TrainConfig,
     device,
There is no need to pass all three of train_config.gradient_accumulation_steps, train_config, and train_config.device; passing only train_config is enough. For example:
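A sketch of the suggested simplification, reading both values from train_config inside train() (the parameter list here is abbreviated and illustrative, not the full signature from train_utils.py):

def train(model, optimizer, lr_scheduler, train_config):
    # Derive these from the config instead of accepting them as separate arguments.
    gradient_accumulation_steps = train_config.gradient_accumulation_steps
    device = train_config.device
    ...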
Force-pushed from b8182a6 to d0fff22
Signed-off-by: vbaddi <quic_vbaddi@quicinc.com> Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
Force-pushed from d0fff22 to e27deeb
        - Ensures types match expected values (int, float, list, etc.).
    """
    if config_type.lower() != "lora":
        raise ValueError(f"Unsupported config_type: {config_type}. Only 'lora' is supported.")
Since we are not doing LoRA finetuning in the BERT case, this will raise an error.
    Args:
        config_data (Dict[str, Any]): The configuration dictionary loaded from YAML/JSON.
        config_type (str): Type of config to validate ("lora" for LoraConfig, default: "lora").
We need to add a config_type value corresponding to BERT, since we don't do LoRA fine-tuning for it. One possible shape is sketched below.
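An illustrative sketch, under the assumption that the validator looks like the snippet above (the function name mirrors the docstring shown here and the "none" value is only an example, not the actual PR code):

from typing import Any, Dict

def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> None:
    # Hypothetical variant: a "none" config_type skips PEFT validation (e.g. BERT full finetuning).
    config_type = config_type.lower()
    if config_type == "none":
        return
    if config_type != "lora":
        raise ValueError(f"Unsupported config_type: {config_type}. Expected 'lora' or 'none'.")
    # ... existing LoRA field and type checks continue here ...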
    # local_args = {k: v for k, v in locals().items() if v is not None and k != "peft_config_file" and k != "kwargs"}
    update_config(train_config, **kwargs)

    lora_config = LoraConfig()
This line is not required.
    longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset)
    lora_config = LoraConfig()

    update_config(lora_config, **kwargs)
Why do we need to update lora_config here with kwargs?
    train_config = args[0]
    assert max_train_step >= train_config.gradient_accumulation_steps, (
        "Total training step should be more than 4 which is gradient accumulation steps."
In place of '4', please use train_config.gradient_accumulation_steps in the message. If the user passes a different value for train_config.gradient_accumulation_steps, the hard-coded 4 will be confusing. For example:
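A sketch of the message with the configured value interpolated (max_train_step and train_config as in the snippet above):

assert max_train_step >= train_config.gradient_accumulation_steps, (
    f"Total training steps should be at least gradient_accumulation_steps "
    f"({train_config.gradient_accumulation_steps})."
)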
    train_config = args[0]
    assert max_train_step >= train_config.gradient_accumulation_steps, (
This assertion will fail. #107 should only be validated if max_train_step > 0, as the default value of max_train_step is 0. Please refer to: https://github.com/quic/efficient-transformers/blob/main/QEfficient/finetune/utils/train_utils.py#L174. See the sketch below.
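A sketch of the guard being asked for, validating only when a cap is actually set (variable names as in the snippet above):

# max_train_step defaults to 0, meaning no cap, so only validate when it is set.
if max_train_step > 0:
    assert max_train_step >= train_config.gradient_accumulation_steps, (
        f"max_train_step ({max_train_step}) must be at least gradient_accumulation_steps "
        f"({train_config.gradient_accumulation_steps})."
    )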
In line 24, max_train_step is set to 20, so this assertion is correct, but line 24 can be changed to the keyword form max_train_step = 20 for interpretability, and similarly for the other params. See the sketch below.
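An illustrative sketch of the keyword style (the value 20 comes from the discussion; everything else here is only an example of the pattern, not the actual test configuration):

# Naming each value at the call site makes the intent clear.
kwargs = {
    "max_train_step": 20,  # explicit name instead of a bare positional 20
    # ... other params passed by name in the same way ...
}
results = finetune(**kwargs)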
@@ -44,132 +51,139 @@
 warnings.filterwarnings("ignore")


-def main(**kwargs):
+def setup_distributed_training(config: TrainConfig) -> None:
Better to use the variable name train_config in place of config to maintain uniformity in the code; different names can cause confusion.
    if not hasattr(model, "base_model_prefix"):
        raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")


def load_model_and_tokenizer(config: TrainConfig) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
Better to use the variable name train_config in place of 'config' here as well, to maintain uniformity in the code; different names can cause confusion. A sketch of the rename is below.
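A sketch of the rename applied to both new helpers (bodies omitted; the TrainConfig import path is an assumption):

from transformers import AutoModelForCausalLM, AutoTokenizer

from QEfficient.finetune.configs.training import TrainConfig  # assumed import path


def setup_distributed_training(train_config: TrainConfig) -> None:
    ...


def load_model_and_tokenizer(train_config: TrainConfig) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    ...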
Command (using the default LoRA config):

python -m QEfficient.cloud.finetune \
    --model_name "meta-llama/Llama-3.2-1B" \
    --lr 5e-4