Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
xingmingyyj committed Dec 26, 2023
1 parent bfa83bf commit 59d4818
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -817,16 +817,19 @@ void NceGradInferMeta(const MetaTensor& input,
auto x_dims = input.dims();
if (input_grad != nullptr) {
input_grad->set_dims(x_dims);
input_grad->set_dtype(input.dtype());
}

auto w_dims = weight.dims();
if (weight_grad) {
weight_grad->set_dims(w_dims);
weight_grad->set_dtype(weight.dtype());
}

auto bias_dims = bias.dims();
if (bias_grad) {
bias_grad->set_dims(bias_dims);
bias_grad->set_dtype(bias.dtype());
}
}

Expand Down
1 change: 1 addition & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3260,6 +3260,7 @@ void NceInferMeta(const MetaTensor& input,
out_dims.push_back(x_dims[0]);
out_dims.push_back(1);
cost->set_dims(common::make_ddim(out_dims));
cost->set_dtype(DataType::FLOAT32);

if (!is_test) {
// set dims of output(SampleOut)
Expand Down

0 comments on commit 59d4818

Please sign in to comment.