Skip to content

Commit

Permalink
OLMo 7B integration test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
2015aroras committed Apr 13, 2024
1 parent 450f038 commit 4df56a4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/models/olmo/test_modeling_olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,10 @@ def test_model_7b_logits(self):
model = OLMoForCausalLM.from_pretrained("allenai/OLMo-7B-hf")
out = model(torch.tensor(input_ids)).logits
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[0.0271, 0.0249, -0.0578, -0.0870, 0.0167, 0.0710, 0.1002, 0.0677]])
EXPECTED_MEAN = torch.tensor([[0.0266, -0.0012, -0.0589, -0.0868, 0.0284, 0.0609, 0.0836, 0.0552]])
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
# slicing logits[0, 0, 0:30]
EXPECTED_SLICE = torch.tensor([-1.7433, -1.6685, 7.4941, 6.1506, 0.1364, -0.1127, 1.3224, 4.5458, 4.2068, 5.8296, 7.4723, 2.7925, 3.1245, 10.8872, 10.0758, 10.6717, 7.0945, 1.2398, 3.6766, 4.2365, 2.5655, 2.2222, 1.7418, 0.5223, 0.7753, 1.0938, 0.6723, 6.2522, 6.2264, 1.8105]) # fmt: skip
EXPECTED_SLICE = torch.tensor([-1.5364, -1.5505, 6.6963, 4.6085, -0.5580, 0.3295, 2.0886, 4.2162, 3.3459, 4.6850, 6.9190, 2.5602, 2.6904, 9.5714, 9.6181, 9.7108, 6.0352, 1.0118, 3.6572, 3.8474, 1.9998, 1.7586, 1.6130, 0.8734, 0.5737, 1.2120, 0.2762, 5.5266, 5.8129, 1.8287]) # fmt: skip
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-2, rtol=1e-2)

@unittest.skip("Logits are not yet correct, will update!")
Expand Down

0 comments on commit 4df56a4

Please sign in to comment.