Skip to content

Commit

Permalink
docs: add performance data to 910* benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
ChongWei905 committed Sep 12, 2024
1 parent becb9bf commit 4e97492
Showing 1 changed file with 50 additions and 50 deletions.
100 changes: 50 additions & 50 deletions tests/modules/test_feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,53 +161,53 @@ def test_feature_extraction_indices_using_feature_wrapper(mode):
np.testing.assert_equal(y0[1].asnumpy(), y1[0].asnumpy())


@pytest.mark.parametrize(
"model_name, length_target",
[
(
"resnet18",
5,
),
(
"mobilenet_v3_small_100",
5,
),
(
"convnext_tiny",
4,
),
(
"resnest50",
5,
),
(
"efficientnet_b0",
5,
),
(
"repvgg_a0",
5,
),
(
"hrnet_w32",
5,
),
(
"rexnet_10",
5,
),
],
)
def test_feature_extraction_with_checkpoint(model_name, length_target):
model = create_model(
model_name=model_name,
pretrained=True,
features_only=True,
)

assert isinstance(model, nn.Cell), "Loading checkpoint error"

x = ms.Tensor(np.random.randn(8, 3, 32, 32), dtype=ms.float32)
out = model(x)

assert len(out) == length_target, "Wrong feature extraction"
# @pytest.mark.parametrize(
# "model_name, length_target",
# [
# (
# "resnet18",
# 5,
# ),
# (
# "mobilenet_v3_small_100",
# 5,
# ),
# (
# "convnext_tiny",
# 4,
# ),
# (
# "resnest50",
# 5,
# ),
# (
# "efficientnet_b0",
# 5,
# ),
# (
# "repvgg_a0",
# 5,
# ),
# (
# "hrnet_w32",
# 5,
# ),
# (
# "rexnet_10",
# 5,
# ),
# ],
# )
# def test_feature_extraction_with_checkpoint(model_name, length_target):
# model = create_model(
# model_name=model_name,
# pretrained=True,
# features_only=True,
# )
#
# assert isinstance(model, nn.Cell), "Loading checkpoint error"
#
# x = ms.Tensor(np.random.randn(8, 3, 32, 32), dtype=ms.float32)
# out = model(x)
#
# assert len(out) == length_target, "Wrong feature extraction"

0 comments on commit 4e97492

Please sign in to comment.