@@ -66,21 +66,21 @@ def test_finetune(
     results = finetune(**kwargs)

-    assert np.allclose(results["avg_train_prep"], 1.002326), "Train perplexity is not matching."
-    assert np.allclose(results["avg_train_loss"], 0.00232327), "Train loss is not matching."
-    assert np.allclose(results["avg_eval_prep"], 1.0193923), "Eval perplexity is not matching."
-    assert np.allclose(results["avg_eval_loss"], 0.0192067), "Eval loss is not matching."
-    assert results["avg_epoch_time"] < 30, "Training should complete within 30 seconds."
+    assert np.allclose(results["avg_train_prep"], 1.002326, atol=1e-5), "Train perplexity is not matching."
+    assert np.allclose(results["avg_train_loss"], 0.00232327, atol=1e-5), "Train loss is not matching."
+    assert np.allclose(results["avg_eval_prep"], 1.0193923, atol=1e-5), "Eval perplexity is not matching."
+    assert np.allclose(results["avg_eval_loss"], 0.0192067, atol=1e-5), "Eval loss is not matching."
+    assert results["avg_epoch_time"] < 60, "Training should complete within 60 seconds."

     train_config_spy.assert_called_once()
     generate_dataset_config_spy.assert_called_once()
     generate_peft_config_spy.assert_called_once()
-    update_config_spy.assert_called_once()
     get_custom_data_collator_spy.assert_called_once()
     get_longest_seq_length_spy.assert_called_once()
     print_model_size_spy.assert_called_once()
     train_spy.assert_called_once()

+    assert update_config_spy.call_count == 2
     assert get_dataloader_kwargs_spy.call_count == 2
     assert get_preprocessed_dataset_spy.call_count == 2
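A note on the tolerance change in the hunk above: `np.allclose` defaults to `rtol=1e-05, atol=1e-08`, so for near-zero reference values such as `avg_train_loss` the allowed error is roughly `rtol * |expected| + atol`, which is only a few times `1e-8`. Passing `atol=1e-5` explicitly accepts a small absolute run-to-run drift. A minimal sketch of the difference (the drift value here is illustrative, not from an actual test run):

```python
import numpy as np

expected = 0.00232327          # reference avg_train_loss asserted by the test
measured = expected + 5e-6     # hypothetical run-to-run drift

# Default tolerances (rtol=1e-05, atol=1e-08): allowed error is
# ~ rtol * |expected| + atol ≈ 3.3e-8, so a 5e-6 drift fails.
print(np.allclose(measured, expected))              # False

# With an explicit absolute tolerance the same drift is accepted.
print(np.allclose(measured, expected, atol=1e-5))   # True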
@@ -102,7 +102,7 @@ def test_finetune(
     else:
         assert eval_dataloader is None

-    args, kwargs = update_config_spy.call_args
+    args, kwargs = update_config_spy.call_args_list[0]
     train_config = args[0]
     assert max_train_step >= train_config.gradient_accumulation_steps, (
         "Total training step should be more than 4 which is gradient accumulation steps."