I see the issue when running the unit tests on an NVIDIA V100 (GPU). Here is the link for more details.
Briefly:
```
=========================== short test summary info ============================
FAILED opt/gemma/gemma/layers_test.py::EinsumTest::test_rmsnorm0 - AssertionE...
FAILED opt/gemma/gemma/modules_test.py::FeedForwardTest::test_ffw0 - Assertio...
FAILED opt/gemma/gemma/positional_embeddings_test.py::PositionalEmbeddingsTest::test_adds_positional_embeddings0
================== 3 failed, 13 passed, 2 warnings in 35.61s ===================
```
Some details:
1. test_rmsnorm0 ([link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/9099672951/job/25013672689?pr=590#step:7:348)). This looks like a floating-point precision (epsilon) error. I don't think it's a good idea to compare the expected array of floats with the computed one for exact equality. Would it be possible to allow a tolerance between the expected and computed arrays, e.g. `rtol=1e-5, atol=1e-5`?
2. test_ffw0 ([link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/9099672951/job/25013672689?pr=590#step:7:415)) is similar to the previous one.
3. test_adds_positional_embeddings0 ([link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/9099672951/job/25013672689?pr=590#step:7:486)). IMHO, JAX does not handle it correctly on GPUs.
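To illustrate the tolerance suggested in item 1, here is a minimal sketch of a tolerance-based comparison. It uses NumPy's `np.testing.assert_allclose`. The helper name `assert_close` and the sample values are hypothetical, made up for illustration, and are not taken from the gemma tests.

```python
import numpy as np


def assert_close(actual, expected, rtol=1e-5, atol=1e-5):
    """Hypothetical helper: passes when, elementwise,
    |actual - expected| <= atol + rtol * |expected|."""
    np.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol)


# Made-up values standing in for a test's expected output.
expected = np.array([0.1, 0.2, 0.3], dtype=np.float32)

# Simulate a small backend-dependent rounding difference (assumption:
# the GPU result differs from the reference only in the last few bits).
computed = expected + np.float32(1e-7)

# An exact comparison is sensitive to this difference...
exact_match = bool(np.array_equal(computed, expected))

# ...while the tolerance-based comparison accepts it.
assert_close(computed, expected)
```

With a check like this, the test would still fail on a genuinely wrong result (for example, an array that is off by 1.0) while tolerating last-bit differences between devices.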
Thank you for your help! Hope it's fixable! =)
DwarKapex changed the title from "Issue with unit tests on NVIdia A100 (GPU)" to "Issue with unit tests on NVIdia V100 (GPU)" on May 15, 2024.