Skip to content

Commit

Permalink
fix masked_select infer shape (#33167)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzSean authored Jun 2, 2021
1 parent 47774d9 commit 9d4722c
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions paddle/fluid/operators/masked_select_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ class MaskedSelectOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "Input", "MaskedSelect");
OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "MaskedSelect");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Out", "MaskedSelect");
framework::DDim output_dims(ctx->GetInputDim("X"));
ctx->SetOutputDim("Y", output_dims);

// output will only be a 1-D Tensor
ctx->SetOutputDim("Y", framework::make_ddim({-1}));
ctx->ShareLoD("X", /*->*/ "Y");
}

Expand Down

0 comments on commit 9d4722c

Please sign in to comment.