Skip to content

Commit ed70d23

Browse files
[Comp] Support masked_fill_decomp (#73225)
* support masked_fill_decomp * update decomp interface conf * fix
1 parent 68b8791 commit ed70d23

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
"unbind",
7979
"unsqueeze",
8080
"unstack",
81+
"masked_fill",
8182
]
8283
decomp_rule_interface_declare_gen_op_list = (
8384
GENERATE_IMPL_DECOMP + MANUAL_IMPL_DECOMP

paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,38 @@ Tensor full_like_decomp(const Tensor& x,
600600
}
601601
}
602602

603+
template <typename T>
604+
Tensor masked_fill_decomp(const Tensor& x,
605+
const Tensor& mask,
606+
const Tensor& v) {
607+
if (has_dynamic_shape(x.shape()) || has_dynamic_shape(mask.shape())) {
608+
// NOTE: use add operator to get broadcast shape implicitly,
609+
// which is not efficient yet, should be improved in the future.
610+
Tensor dummy_x =
611+
backend::full_with_tensor<T>(shape64<T>(x), 0.0, x.dtype(), x.place());
612+
Tensor dummy_y = backend::full_with_tensor<T>(
613+
shape64<T>(mask), 0.0, x.dtype(), x.place());
614+
Tensor dummy = dummy_x + dummy_y;
615+
Tensor mask_expanded = backend::expand<T>(mask, shape64<T>(dummy));
616+
Tensor v_expanded = backend::expand<T>(v, shape64<T>(dummy));
617+
return where<T>(mask_expanded, v_expanded, x);
618+
619+
} else {
620+
auto out_dims = phi::funcs::BroadcastTwoDims(x.dims(), mask.dims());
621+
std::vector<int64_t> out_shape = common::vectorize(out_dims);
622+
Tensor x_expanded = x;
623+
if (x.dims() != out_dims) {
624+
x_expanded = expand<T>(x_expanded, out_shape);
625+
}
626+
Tensor mask_expanded = mask;
627+
if (mask.dims() != out_dims) {
628+
mask_expanded = expand<T>(mask, out_shape);
629+
}
630+
Tensor v_expanded = expand<T>(v, out_shape);
631+
return where<T>(mask_expanded, v_expanded, x_expanded);
632+
}
633+
}
634+
603635
template <typename T>
604636
std::tuple<Tensor, Tensor> dropout_decomp(
605637
const Tensor& x,

0 commit comments

Comments
 (0)