Skip to content

Commit 521db35

Browse files
Fix CI unittest asserts (#4234)
1 parent e2c97a8 commit 521db35

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/test_sft_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,12 +1466,12 @@ def test_train_vlm_prompt_completion_gemma(self):
14661466
trainer.train()
14671467

14681468
# Check that the training loss is not None
1469-
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
1469+
assert trainer.state.log_history[-1]["train_loss"] is not None
14701470

14711471
# Check the params have changed
14721472
for n, param in previous_trainable_params.items():
14731473
new_param = trainer.model.get_parameter(n)
1474-
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")
1474+
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
14751475

14761476
# Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing.
14771477
# To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.

0 commit comments

Comments
 (0)