@@ -40,4 +40,5 @@
    'acos_grad',
    'put_along_axis_grad',
    'masked_fill_grad',
    'masked_select_grad',
]
131 changes: 97 additions & 34 deletions paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h
@@ -1353,49 +1353,112 @@ void masked_select_grad(const Tensor& x,
  if (x_grad) {
    auto promoted_x = ConvertToMT<T>(x);
    auto promoted_out_grad = ConvertToMT<T>(out_grad);
    if (has_dynamic_shape(x.shape()) || has_dynamic_shape(mask.shape())) {
      // clang-format off
      /**
       * expand_shape = broadcast(x, mask)
       * out = masked_select(expand(x, expand_shape), expand(mask, expand_shape))
       * Given dout, then:
       *   expand_dx_flat = scatter(zeros(expand_shape).reshape([-1]), index, dout, axis=0, overwrite=False)  # overwrite must be false for the broadcast case
       *   where index = masked_select(
       *       arange(expand_shape.prod()),
       *       expand(mask, expand_shape).reshape([-1]),
       *   )
       *   dx = reduce_as(expand_dx_flat.reshape(expand_shape), x)
       */
      // clang-format on
      // get broadcast shape
      auto dummy_x = backend::full_with_tensor<T>(
          shape64<T>(x), 0.0, x.dtype(), x.place());
      auto dummy_y = backend::full_with_tensor<T>(
          shape64<T>(mask), 0.0, x.dtype(), x.place());
      auto zeros = dummy_x + dummy_y;
      auto expand_shape = shape64<T>(zeros);

      // generate out_indices for scatter
      auto start = full_scalar<T>(0, DataType::INT64);
      auto end = full_scalar<T>(1, DataType::INT64);
      for (int i = 0; i < zeros.dims().size(); ++i) {
        end = end * get_slice<T>(expand_shape, i);
      }
      auto step = full_scalar<T>(1, DataType::INT64);

      Tensor expand_shape_numel =
          full<T>({1}, 1.0, DataType::INT64, expand_shape.place());
      for (int i = 0; i < zeros.dims().size(); i++) {
        expand_shape_numel = expand_shape_numel * get_slice<T>(expand_shape, i);
      }

      auto out_indices = masked_select<T>(
          backend::arange<T>(start, end, step, DataType::INT64, x.place()),
          backend::reshape<T>(backend::expand<T>(mask, expand_shape),
                              expand_shape_numel));

      // scatter out_grad into a flat zero tensor of the broadcast shape
      auto out_grad_shape = shape64<T>(out_grad);
      Tensor out_grad_numel =
          full<T>({1}, 1.0, out_grad_shape.dtype(), out_grad.place());
      for (int i = 0; i < out_grad.dims().size(); i++) {
        out_grad_numel = out_grad_numel * get_slice<T>(out_grad_shape, i);
      }
      auto expand_dx_flat =
          scatter<T>(backend::reshape<T>(zeros, expand_shape_numel),
                     out_indices,
                     out_grad,
                     false);
      // reshape to broadcast shape
      auto expand_dx = backend::reshape<T>(expand_dx_flat, expand_shape);
      // reduce to original x.shape
      auto dx = reduce_as<T>(expand_dx, x);
      // cast to original dtype
      auto res = cast<T>(dx, x.dtype());
      set_output<T>(res, x_grad);
    } else {
      auto x_num = 1;
      for (size_t i = 0; i < promoted_x.shape().size(); i++) {
        x_num *= promoted_x.shape()[i];
      }

      auto grad_num = 1;
      for (size_t i = 0; i < promoted_out_grad.shape().size(); i++) {
        grad_num *= promoted_out_grad.shape()[i];
      }

      auto end = full<T>({1}, x_num, promoted_x.dtype(), x.place());
      auto start = full<T>({1}, 0, promoted_x.dtype(), x.place());
      auto step = full<T>({1}, 1, promoted_x.dtype(), x.place());
      auto x_arange = backend::arange<T>(
          start, end, step, promoted_x.dtype(), promoted_x.place());

      auto x_arange_reshape = reshape<T>(x_arange, promoted_x.shape());

      auto x_index = masked_select<T>(x_arange_reshape, mask);

      auto index_num = x_index.shape()[0];

      auto grad_reshape = cast<T>(reshape<T>(promoted_out_grad, {grad_num}),
                                  promoted_x.dtype());

      auto grad_trans = grad_reshape;
      if (grad_num > index_num) {
        grad_trans = slice<T>(grad_reshape, {0}, {0}, {index_num}, {1}, {});
      } else if (grad_num < index_num) {
        auto pad_zeros = full<T>(
            {index_num - grad_num}, 0, promoted_x.dtype(), promoted_x.place());
        grad_trans = concat<T>({grad_reshape, pad_zeros}, 0);
      }

      auto input_tensor =
          full<T>({x_num}, 0, promoted_x.dtype(), promoted_x.place());
      auto index_tensor = cast<T>(x_index, DataType::INT64);
      auto update_tensor = grad_trans;
      auto x_output =
          scatter<T>(input_tensor, index_tensor, update_tensor, false);
      auto res = cast<T>(reshape<T>(x_output, promoted_x.shape()), x.dtype());
      set_output<T>(res, x_grad);
    }
  }
}
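
For reference, the broadcast-case decomposition described in the comment above can be sketched in plain NumPy. This is a minimal illustration of the same idea, not the Paddle primitive API; the helper name masked_select_grad_ref and the emulation of reduce_as via sum are assumptions made only for the example.

import numpy as np

def masked_select_grad_ref(x, mask, dout):
    """Sketch of the broadcast-case gradient: scatter dout into a zero
    tensor of the broadcast shape at the positions selected by the
    expanded mask, then reduce back to x.shape (reduce_as)."""
    expand_shape = np.broadcast_shapes(x.shape, mask.shape)
    numel = int(np.prod(expand_shape))

    # index = masked_select(arange(prod(expand_shape)), expanded mask)
    mask_flat = np.broadcast_to(mask, expand_shape).reshape(-1)
    index = np.arange(numel)[mask_flat]

    # expand_dx_flat = scatter(zeros, index, dout); indices are unique here
    expand_dx_flat = np.zeros(numel, dtype=dout.dtype)
    expand_dx_flat[index] = dout
    expand_dx = expand_dx_flat.reshape(expand_shape)

    # reduce_as(expand_dx, x): sum over the axes that were broadcast
    ndiff = len(expand_shape) - x.ndim
    if ndiff:
        expand_dx = expand_dx.sum(axis=tuple(range(ndiff)))
    axes = tuple(i for i, d in enumerate(x.shape)
                 if d == 1 and expand_dx.shape[i] != 1)
    if axes:
        expand_dx = expand_dx.sum(axis=axes, keepdims=True)
    return expand_dx

# x has shape (3, 2); mask has shape (3, 1) and broadcasts over x.
x = np.arange(6.0).reshape(3, 2)
mask = np.array([[True], [False], [True]])
out = np.broadcast_to(x, (3, 2))[np.broadcast_to(mask, (3, 2))]
dout = np.ones_like(out)
print(masked_select_grad_ref(x, mask, dout))
# [[1. 1.]
#  [0. 0.]
#  [1. 1.]]

The static-shape branch above follows the same scatter idea, but it builds the index with a host-side arange over x's element count and pads or slices dout so its length matches the number of selected positions.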

1 change: 1 addition & 0 deletions python/paddle/autograd/backward_utils.py
@@ -109,6 +109,7 @@
"pd_op.p_norm",
"pd_op.elu",
"pd_op.masked_fill",
"pd_op.masked_select",
]

