Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
jczaja committed Jul 29, 2022
1 parent ee3ed41 commit b91e447
Showing 1 changed file with 19 additions and 22 deletions.
41 changes: 19 additions & 22 deletions paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@ static paddle::framework::DDim ColumnMatrixDimsFromVector(
return y_dim.size() > 1 ? y_dim : phi::make_ddim({y_dim[0], 1});
}


phi::DDim GetDimForInput(const ExecutionContext &ctx,
std::string input_name) {
phi::DDim GetDimForInput(const ExecutionContext &ctx, std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
auto input_dims = ctx.Input<Tensor>(input_name)->dims();
Expand All @@ -114,16 +112,15 @@ phi::DDim GetDimForInput(const ExecutionContext &ctx,
if (it_zero != shape.end()) {
for (uint64_t i = 0; i < shape.size(); i++) {
if (shape[i] == 0) {
PADDLE_ENFORCE_LT(
i,
input_dims.size(),
paddle::platform::errors::InvalidArgument(
"The index of 0 in fused_reshape_%s ",
"should be less than output dim size, ",
"but the index is %d and output dim size is %d",
input_name,
i,
input_dims.size()));
PADDLE_ENFORCE_LT(i,
input_dims.size(),
paddle::platform::errors::InvalidArgument(
"The index of 0 in fused_reshape_%s ",
"should be less than output dim size, ",
"but the index is %d and output dim size is %d",
input_name,
i,
input_dims.size()));
shape[i] = input_dims.at(i);
}
}
Expand All @@ -134,8 +131,6 @@ phi::DDim GetDimForInput(const ExecutionContext &ctx,
return input_dims;
}



template <typename XT, typename YT, typename OT>
class MatMulMKLDNNHandler
: public paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul> {
Expand Down Expand Up @@ -606,13 +601,15 @@ std::vector<int64_t> GetInputStrides(const ExecutionContext &ctx,
if (it_zero != shape.end()) {
for (uint64_t i = 0; i < shape.size(); i++) {
if (shape[i] == 0) {
PADDLE_ENFORCE_LT(
i, input_dims.size(),
paddle::platform::errors::InvalidArgument(
"The index of 0 in fused_reshape_%s ",
"should be less than output dim size, ",
"but the index is %d and output dim size is %d", input_name,
i, input_dims.size()));
PADDLE_ENFORCE_LT(i,
input_dims.size(),
paddle::platform::errors::InvalidArgument(
"The index of 0 in fused_reshape_%s ",
"should be less than output dim size, ",
"but the index is %d and output dim size is %d",
input_name,
i,
input_dims.size()));
shape[i] = input_dims.at(i);
}
}
Expand Down

0 comments on commit b91e447

Please sign in to comment.