@@ -133,7 +133,7 @@ class PlacementPassTest {
133133
134134 auto pass = PassRegistry::Instance ().Get (" onednn_placement_pass" );
135135
136- pass->Set (" mkldnn_enabled_op_types " ,
136+ pass->Set (" onednn_enabled_op_types " ,
137137 new std::unordered_set<std::string>(onednn_enabled_op_types));
138138
139139 graph.reset (pass->Apply (graph.release ()));
@@ -143,8 +143,10 @@ class PlacementPassTest {
143143 for (auto * node : graph->Nodes ()) {
144144 if (node->IsOp ()) {
145145 auto * op = node->Op ();
146- if (op->HasAttr (" use_mkldnn" ) &&
147- PADDLE_GET_CONST (bool , op->GetAttr (" use_mkldnn" ))) {
146+ if ((op->HasAttr (" use_mkldnn" ) &&
147+ PADDLE_GET_CONST (bool , op->GetAttr (" use_mkldnn" ))) ||
148+ (op->HasAttr (" use_onednn" ) &&
149+ PADDLE_GET_CONST (bool , op->GetAttr (" use_onednn" )))) {
148150 ++use_onednn_true_count;
149151 }
150152 }
@@ -156,27 +158,27 @@ class PlacementPassTest {
156158 void PlacementNameTest () {
157159 auto pass = PassRegistry::Instance ().Get (" onednn_placement_pass" );
158160 EXPECT_EQ (static_cast <PlacementPassBase*>(pass.get ())->GetPlacementName (),
159- " MKLDNN " );
161+ " ONEDNN " );
160162 }
161163};
162164
163- TEST (MKLDNNPlacementPass , enable_conv_relu) {
165+ TEST (ONEDNNPlacementPass , enable_conv_relu) {
164166 // 2 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool
165167 PlacementPassTest ().MainTest ({" conv2d" , " relu" }, 4 );
166168}
167169
168- TEST (MKLDNNPlacementPass , enable_relu_pool) {
170+ TEST (ONEDNNPlacementPass , enable_relu_pool) {
169171 // 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool
170172 PlacementPassTest ().MainTest ({" relu" , " pool2d" }, 4 );
171173}
172174
173- TEST (MKLDNNPlacementPass , enable_all) {
175+ TEST (ONEDNNPlacementPass , enable_all) {
174176 // 2 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool +
175177 // 1 concat
176178 PlacementPassTest ().MainTest ({}, 6 );
177179}
178180
179- TEST (MKLDNNPlacementPass , placement_name) {
181+ TEST (ONEDNNPlacementPass , placement_name) {
180182 PlacementPassTest ().PlacementNameTest ();
181183}
182184
0 commit comments