Skip to content

Commit 4bf93cd

Browse files
committed
Fix
1 parent 4ad8416 commit 4bf93cd

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

test/ir/inference/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ elseif(WITH_ONEDNN)
311311
test_mkldnn_conv3d_op
312312
test_mkldnn_depthwise_conv_pass
313313
test_mkldnn_shape_op
314-
test_mkldnn_shuffle_channel_op)
314+
test_onednn_shuffle_channel_op)
315315
foreach(target ${PIR_COVERAGE_MKLDNN_TESTS})
316316
py_test_modules(${target}_pir MODULES ${target} ENVS FLAGS_enable_pir_api=1)
317317
set_tests_properties(${target} PROPERTIES LABELS "RUN_TYPE=INFER")
@@ -335,7 +335,7 @@ elseif(WITH_ONEDNN)
335335
endforeach()
336336

337337
set_tests_properties(test_mkldnn_shape_op_pir PROPERTIES TIMEOUT 300)
338-
set_tests_properties(test_mkldnn_shuffle_channel_op_pir PROPERTIES TIMEOUT
338+
set_tests_properties(test_onednn_shuffle_channel_op_pir PROPERTIES TIMEOUT
339339
300)
340340
set_tests_properties(test_onednn_conv_bias_fuse_pass_pir PROPERTIES TIMEOUT
341341
300)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
from functools import partial
17+
18+
import hypothesis.strategies as st
19+
import numpy as np
20+
from auto_scan_test import OnednnAutoScanTest
21+
from hypothesis import given
22+
from program_config import OpConfig, ProgramConfig, TensorConfig
23+
24+
25+
class TestMKLDNNShuffleChannelOp(OnednnAutoScanTest):
26+
def is_program_valid(self, program_config: ProgramConfig) -> bool:
27+
return True
28+
29+
def sample_program_configs(self, *args, **kwargs):
30+
def generate_input(*args, **kwargs):
31+
return np.random.random(kwargs['in_shape']).astype(np.float32)
32+
33+
shuffle_channel_op = OpConfig(
34+
type="shuffle_channel",
35+
inputs={"X": ["input_data"]},
36+
outputs={"Out": ["output_data"]},
37+
attrs={"group": kwargs['group']},
38+
)
39+
40+
program_config = ProgramConfig(
41+
ops=[shuffle_channel_op],
42+
weights={},
43+
inputs={
44+
"input_data": TensorConfig(
45+
data_gen=partial(generate_input, *args, **kwargs)
46+
),
47+
},
48+
outputs=["output_data"],
49+
)
50+
51+
yield program_config
52+
53+
def sample_predictor_configs(self, program_config):
54+
config = self.create_inference_config(use_onednn=True)
55+
yield config, (1e-5, 1e-5)
56+
57+
@given(
58+
group=st.sampled_from([1, 2, 8, 32, 128]),
59+
in_shape=st.sampled_from([[5, 512, 2, 3], [2, 256, 5, 4]]),
60+
)
61+
def test(self, *args, **kwargs):
62+
self.run_test(quant=False, *args, **kwargs)
63+
64+
65+
if __name__ == "__main__":
66+
unittest.main()

0 commit comments

Comments
 (0)