Skip to content

Commit

Permalink
revert deletion of validation checks on some train args
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
  • Loading branch information
fabianlim committed Mar 8, 2024
1 parent 7dfe174 commit e718d04
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,16 @@ def train(

logger = logging.get_logger("sft_trainer")

# Validate parameters
if (not isinstance(train_args.num_train_epochs, float)) or (
train_args.num_train_epochs <= 0
):
raise ValueError("num_train_epochs has to be an integer/float >= 1")
if (not isinstance(train_args.gradient_accumulation_steps, int)) or (
train_args.gradient_accumulation_steps <= 0
):
raise ValueError("gradient_accumulation_steps has to be an integer >= 1")

task_type = "CAUSAL_LM"
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
Expand Down

0 comments on commit e718d04

Please sign in to comment.