@@ -59,7 +59,7 @@ class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
5959 std::vector<int64_t > x_strides (x_dims.size () - 3 , 1 );
6060 std::vector<int64_t > y_strides (x_dims.size () - 3 , 1 );
6161 std::vector<int64_t > out_strides (x_dims.size () - 3 , 1 );
62- std::vector<int64_t > out_ddims (x_dims.size () - 3 , 1 );
62+ std::vector<int64_t > out_dims (x_dims.size () - 3 , 1 );
6363
6464 x_strides.reserve (x_dims.size ());
6565 y_strides.reserve (x_dims.size ());
@@ -78,20 +78,20 @@ class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
7878 }
7979
8080 out_strides.insert (out_strides.end (), {M * N, N, 1 });
81- out_ddims .insert (out_ddims .end (),
82- {std::max (x_dims[MB_idx], y_dims[MB_idx]), M, N});
81+ out_dims .insert (out_dims .end (),
82+ {std::max (x_dims[MB_idx], y_dims[MB_idx]), M, N});
8383
8484 for (int i = x_dims.size () - 4 ; i >= 0 ; --i) {
85- out_ddims [i] = std::max (x_dims[i], y_dims[i]);
85+ out_dims [i] = std::max (x_dims[i], y_dims[i]);
8686 x_strides[i] = x_dims[i + 1 ] * x_strides[i + 1 ];
8787 y_strides[i] = y_dims[i + 1 ] * y_strides[i + 1 ];
8888
89- out_strides[i] = out_ddims [i + 1 ] * out_strides[i + 1 ];
89+ out_strides[i] = out_dims [i + 1 ] * out_strides[i + 1 ];
9090 }
9191
9292 auto x_md = memory::desc (x_dims, OneDNNGetDataType<XT>(), x_strides);
9393 auto y_md = memory::desc (y_dims, OneDNNGetDataType<YT>(), y_strides);
94- auto out_md = memory::desc (out_ddims , OneDNNGetDataType<OT>(), out_strides);
94+ auto out_md = memory::desc (out_dims , OneDNNGetDataType<OT>(), out_strides);
9595
9696 this ->AcquireForwardPrimitiveDescriptor (x_md, y_md, out_md);
9797 }
0 commit comments