Commit b8182a6

quic-meetkuma authored and meetkuma committed
Updated test case based on recent commits
1 parent efacb97

File tree

1 file changed: +7 −7 lines changed


tests/finetune/test_finetune.py (+7 −7)
@@ -66,21 +66,21 @@ def test_finetune(
 
     results = finetune(**kwargs)
 
-    assert np.allclose(results["avg_train_prep"], 1.002326), "Train perplexity is not matching."
-    assert np.allclose(results["avg_train_loss"], 0.00232327), "Train loss is not matching."
-    assert np.allclose(results["avg_eval_prep"], 1.0193923), "Eval perplexity is not matching."
-    assert np.allclose(results["avg_eval_loss"], 0.0192067), "Eval loss is not matching."
-    assert results["avg_epoch_time"] < 30, "Training should complete within 30 seconds."
+    assert np.allclose(results["avg_train_prep"], 1.002326, atol=1e-5), "Train perplexity is not matching."
+    assert np.allclose(results["avg_train_loss"], 0.00232327, atol=1e-5), "Train loss is not matching."
+    assert np.allclose(results["avg_eval_prep"], 1.0193923, atol=1e-5), "Eval perplexity is not matching."
+    assert np.allclose(results["avg_eval_loss"], 0.0192067, atol=1e-5), "Eval loss is not matching."
+    assert results["avg_epoch_time"] < 60, "Training should complete within 60 seconds."
 
     train_config_spy.assert_called_once()
     generate_dataset_config_spy.assert_called_once()
     generate_peft_config_spy.assert_called_once()
-    update_config_spy.assert_called_once()
     get_custom_data_collator_spy.assert_called_once()
     get_longest_seq_length_spy.assert_called_once()
     print_model_size_spy.assert_called_once()
     train_spy.assert_called_once()
 
+    assert update_config_spy.call_count == 2
     assert get_dataloader_kwargs_spy.call_count == 2
     assert get_preprocessed_dataset_spy.call_count == 2

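For reference, a minimal sketch (not part of this commit) of what the explicit atol buys: np.allclose treats values as equal when |a − b| <= atol + rtol * |b|, with defaults rtol=1e-5 and atol=1e-8, so raising atol to 1e-5 roughly doubles the acceptance window for metrics near 1.0. The numbers below are invented for illustration.

    import numpy as np

    # Hypothetical values: a pinned reference metric and a slightly drifted rerun.
    expected = 1.002326
    observed = 1.002341  # differs by 1.5e-5

    # np.allclose passes when |a - b| <= atol + rtol * |b|
    # (defaults rtol=1e-5, atol=1e-8), i.e. a window of ~1.0e-5 here.
    print(np.allclose(observed, expected))             # False: 1.5e-5 is outside ~1.0e-5
    print(np.allclose(observed, expected, atol=1e-5))  # True: window widens to ~2.0e-5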
@@ -102,7 +102,7 @@ def test_finetune(
     else:
         assert eval_dataloader is None
 
-    args, kwargs = update_config_spy.call_args
+    args, kwargs = update_config_spy.call_args_list[0]
     train_config = args[0]
     assert max_train_step >= train_config.gradient_accumulation_steps, (
         "Total training step should be more than 4 which is gradient accumulation steps."

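The switch from call_args to call_args_list[0] follows from update_config now being called twice: in unittest.mock, call_args holds only the most recent call, while call_args_list keeps them all in order. A small self-contained sketch (the call arguments below are invented for illustration):

    from unittest.mock import MagicMock

    update_config = MagicMock()  # stand-in for the spied helper; calls are hypothetical
    update_config("train_config", max_train_step=4)
    update_config("dataset_config")

    assert update_config.call_count == 2

    # call_args reflects only the most recent call ...
    args, kwargs = update_config.call_args
    assert args == ("dataset_config",)

    # ... while call_args_list[0] recovers the first call, as the updated test needs.
    args, kwargs = update_config.call_args_list[0]
    assert args == ("train_config",) and kwargs == {"max_train_step": 4}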