[QEff Finetune]: Refactor the finetune main __call__ #289


Draft: wants to merge 3 commits into main from add_peft_yaml_path

Conversation

@vbaddi (Contributor) commented on Feb 27, 2025

  • Refactor the finetune main API
  • Add support to override the PEFT config via a YAML/JSON file
  • Add validation of the correctness of the PEFT config (a sketch of this flow follows the commands below)
  • Miscellaneous nit changes

Example lora_config.yaml:
r: 16
lora_alpha: 64
target_modules:
  - q_proj
  - v_proj
  - k_proj
bias: none
task_type: CAUSAL_LM
lora_dropout: 0.1

Command (with the PEFT config override):

python -m QEfficient.cloud.finetune \
    --model_name "meta-llama/Llama-3.2-1B" \
    --lr 5e-4 \
    --peft_config_file "lora_config.yaml"

Command (using the default LoRA config):

python -m QEfficient.cloud.finetune \
    --model_name "meta-llama/Llama-3.2-1B" \
    --lr 5e-4
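
For reference, a minimal sketch of how a YAML/JSON override like the one above could be loaded and validated against LoraConfig; the helper names here are illustrative and not necessarily the functions this PR adds:

import json
from dataclasses import fields

import yaml  # pyyaml
from peft import LoraConfig


def load_config_file(path: str) -> dict:
    # Hypothetical helper: parse a YAML or JSON PEFT config file into a dict.
    if path.endswith((".yaml", ".yml")):
        with open(path, "r") as f:
            return yaml.safe_load(f)
    if path.endswith(".json"):
        with open(path, "r") as f:
            return json.load(f)
    raise ValueError(f"Unsupported config file extension: {path}")


def validate_and_build_lora_config(config_data: dict) -> LoraConfig:
    # Hypothetical validation: reject keys LoraConfig does not define, then let
    # LoraConfig's own __init__ type-check the values.
    valid_keys = {f.name for f in fields(LoraConfig)}
    unknown = set(config_data) - valid_keys
    if unknown:
        raise ValueError(f"Unknown LoraConfig fields: {sorted(unknown)}")
    return LoraConfig(**config_data)


# e.g. lora_config = validate_and_build_lora_config(load_config_file("lora_config.yaml"))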

@vbaddi self-assigned this on Feb 27, 2025
@vbaddi force-pushed the add_peft_yaml_path branch from 2f19722 to 48061ee on February 27, 2025
@vbaddi added the enhancement (New feature or request) label on Mar 19, 2025
@quic-amitraj marked this pull request as draft on April 11, 2025
@ochougul closed this on Apr 15, 2025
@ochougul (Contributor): copy of #314

@pytest.mark.on_qaic
@pytest.mark.skip(reason="eager docker not available in sdk")
@pytest.mark.parametrize(
Review comment (Contributor):
Please add the pytest markers @pytest.mark.finetune and @pytest.mark.cli. They will help in executing the tests in stages.
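
For illustration, the suggested markers would sit alongside the existing ones roughly as follows (the test name and body are placeholders; the custom markers would also need to be registered in the pytest configuration):

import pytest

@pytest.mark.on_qaic
@pytest.mark.finetune
@pytest.mark.cli
@pytest.mark.skip(reason="eager docker not available in sdk")
def test_finetune_cli():
    # Placeholder body; the real test is the parametrized one quoted above.
    ...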

finetune(**kwargs)
results = finetune(**kwargs)

assert np.allclose(results["avg_train_prep"], 1.002326, atol=1e-5), "Train perplexity is not matching."
Review comment (Contributor):
avg_train_prep needs to be changed to avg_train_metric, in line with the changes in PR 292.

@@ -40,7 +40,7 @@ def train(
optimizer,
lr_scheduler,
gradient_accumulation_steps,
train_config: TRAIN_CONFIG,
train_config: TrainConfig,
device,
Review comment (Contributor):
There is no need to pass all three of train_config.gradient_accumulation_steps, train_config, and train_config.device; passing only train_config is enough.
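
A sketch of the slimmer signature the reviewer seems to have in mind; parameters not visible in the hunk above are elided:

def train(
    optimizer,
    lr_scheduler,
    train_config: "TrainConfig",  # the dataclass this PR introduces; import omitted
    # model, dataloaders, etc. from the full signature are elided here
):
    # Read these off train_config instead of passing them separately.
    gradient_accumulation_steps = train_config.gradient_accumulation_steps
    device = train_config.device
    ...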

@quic-meetkuma force-pushed the add_peft_yaml_path branch 5 times, most recently from b8182a6 to d0fff22 on April 21, 2025
Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>
Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
- Ensures types match expected values (int, float, list, etc.).
"""
if config_type.lower() != "lora":
    raise ValueError(f"Unsupported config_type: {config_type}. Only 'lora' is supported.")
Review comment (Contributor):
Since we are not doing LoRA fine-tuning in the BERT case, this will raise an error.


Args:
config_data (Dict[str, Any]): The configuration dictionary loaded from YAML/JSON.
config_type (str): Type of config to validate ("lora" for LoraConfig, default: "lora").
Review comment (Contributor):
A config_type value corresponding to BERT needs to be added, since we don't do LoRA fine-tuning for it.
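
One possible shape of the validation that would address both comments, assuming a second config_type value (named "none" here purely for illustration) is introduced for runs such as BERT that skip LoRA:

def validate_config(config_data: dict, config_type: str = "lora") -> None:
    # Illustrative only: tolerate a non-LoRA config_type instead of always raising.
    config_type = config_type.lower()
    if config_type == "lora":
        # ... the existing LoraConfig key/type checks would go here ...
        pass
    elif config_type == "none":  # assumed value for BERT / no-PEFT runs
        return
    else:
        raise ValueError(f"Unsupported config_type: {config_type}. Expected 'lora' or 'none'.")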

# local_args = {k: v for k, v in locals().items() if v is not None and k != "peft_config_file" and k != "kwargs"}
update_config(train_config, **kwargs)

lora_config = LoraConfig()
Review comment (Contributor):
this line is not required.

longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset)
lora_config = LoraConfig()

update_config(lora_config, **kwargs)
@quic-mamta (Contributor) commented on Apr 23, 2025:
Why do we need to update lora_config here with kwargs?

train_config = args[0]
assert max_train_step >= train_config.gradient_accumulation_steps, (
"Total training step should be more than 4 which is gradient accumulation steps."
Review comment (Contributor):
In place of '4', please use train_config.gradient_accumulation_steps in the message. If the user passes a different value for train_config.gradient_accumulation_steps, the hard-coded 4 will be confusing.
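
That is, roughly, reusing the variables from the quoted snippet:

assert max_train_step >= train_config.gradient_accumulation_steps, (
    f"Total training steps ({max_train_step}) should be at least "
    f"gradient_accumulation_steps ({train_config.gradient_accumulation_steps})."
)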

train_config = args[0]
assert max_train_step >= train_config.gradient_accumulation_steps, (
@quic-swatia (Contributor) commented on Apr 23, 2025:
This assertion will fail. #107 should only be validated if max_train_step > 0, as the default value of max_train_step is 0. Please refer to: https://github.com/quic/efficient-transformers/blob/main/QEfficient/finetune/utils/train_utils.py#L174

@quic-mamta (Contributor) commented on Apr 23, 2025:
In line 24, the max_train_step value is set to 20, so this assertion is correct; but line 24 could be written as max_train_step = 20 for interpretability, and similarly for the other params.
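
If the guard suggested above is adopted, the check would only run when an explicit limit is set, along these lines:

if max_train_step > 0:  # default of 0 means no explicit limit, so skip the check
    assert max_train_step >= train_config.gradient_accumulation_steps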

@@ -44,132 +51,139 @@
warnings.filterwarnings("ignore")


def main(**kwargs):
def setup_distributed_training(config: TrainConfig) -> None:
Review comment (Contributor):
Better to use the variable name train_config in place of config to maintain uniformity in the code; different names can cause confusion.


if not hasattr(model, "base_model_prefix"):
    raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
def load_model_and_tokenizer(config: TrainConfig) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
Review comment (Contributor):
Better to use the variable name train_config in place of 'config' to maintain uniformity in the code; different names can cause confusion.
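
That is, for both helpers in this diff, something like the following (TrainConfig import omitted, bodies elided):

from transformers import AutoModelForCausalLM, AutoTokenizer

def setup_distributed_training(train_config: "TrainConfig") -> None:
    ...

def load_model_and_tokenizer(train_config: "TrainConfig") -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    ...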
