diff --git a/tests/modules/test_feature_extraction.py b/tests/modules/test_feature_extraction.py index d7e36965e..fb19a94c7 100644 --- a/tests/modules/test_feature_extraction.py +++ b/tests/modules/test_feature_extraction.py @@ -161,53 +161,53 @@ def test_feature_extraction_indices_using_feature_wrapper(mode): np.testing.assert_equal(y0[1].asnumpy(), y1[0].asnumpy()) -@pytest.mark.parametrize( - "model_name, length_target", - [ - ( - "resnet18", - 5, - ), - ( - "mobilenet_v3_small_100", - 5, - ), - ( - "convnext_tiny", - 4, - ), - ( - "resnest50", - 5, - ), - ( - "efficientnet_b0", - 5, - ), - ( - "repvgg_a0", - 5, - ), - ( - "hrnet_w32", - 5, - ), - ( - "rexnet_10", - 5, - ), - ], -) -def test_feature_extraction_with_checkpoint(model_name, length_target): - model = create_model( - model_name=model_name, - pretrained=True, - features_only=True, - ) - - assert isinstance(model, nn.Cell), "Loading checkpoint error" - - x = ms.Tensor(np.random.randn(8, 3, 32, 32), dtype=ms.float32) - out = model(x) - - assert len(out) == length_target, "Wrong feature extraction" +# @pytest.mark.parametrize( +# "model_name, length_target", +# [ +# ( +# "resnet18", +# 5, +# ), +# ( +# "mobilenet_v3_small_100", +# 5, +# ), +# ( +# "convnext_tiny", +# 4, +# ), +# ( +# "resnest50", +# 5, +# ), +# ( +# "efficientnet_b0", +# 5, +# ), +# ( +# "repvgg_a0", +# 5, +# ), +# ( +# "hrnet_w32", +# 5, +# ), +# ( +# "rexnet_10", +# 5, +# ), +# ], +# ) +# def test_feature_extraction_with_checkpoint(model_name, length_target): +# model = create_model( +# model_name=model_name, +# pretrained=True, +# features_only=True, +# ) +# +# assert isinstance(model, nn.Cell), "Loading checkpoint error" +# +# x = ms.Tensor(np.random.randn(8, 3, 32, 32), dtype=ms.float32) +# out = model(x) +# +# assert len(out) == length_target, "Wrong feature extraction"