From caea0ac348946acd402c609d7233c435c839eb28 Mon Sep 17 00:00:00 2001 From: Samuel Ainsworth Date: Sat, 4 Jan 2025 18:11:38 -0500 Subject: [PATCH] re-enable vit_b_16 test --- tests/test_models.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 98c7e56..12e4529 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -83,7 +83,6 @@ def test_torchvision_models_resnet18(): # to reasonable expectations. -@pytest.mark.skip(reason="https://github.com/jax-ml/jax/issues/25066") @pytest.mark.skipif(not is_network_reachable(), reason="Network is not reachable") def test_torchvision_models_vit_b_16(): import torchvision @@ -92,8 +91,7 @@ def test_torchvision_models_vit_b_16(): model.eval() parameters = {k: t2j(v) for k, v in model.named_parameters()} - # buffers = {k: t2j(v) for k, v in model.named_buffers()} - # assert len(buffers.keys()) == 0 + assert len(dict(model.named_buffers()).keys()) == 0 input_batch = random.normal(random.PRNGKey(123), (1, 3, 224, 224)) res_torch = model(j2t(input_batch))