@@ -600,6 +600,38 @@ Tensor full_like_decomp(const Tensor& x,
   }
 }
 
+template <typename T>
+Tensor masked_fill_decomp(const Tensor& x,
+                          const Tensor& mask,
+                          const Tensor& v) {
+  if (has_dynamic_shape(x.shape()) || has_dynamic_shape(mask.shape())) {
+    // NOTE: use the add operator to get the broadcast shape implicitly,
+    // which is not efficient yet and should be improved in the future.
+    Tensor dummy_x =
+        backend::full_with_tensor<T>(shape64<T>(x), 0.0, x.dtype(), x.place());
+    Tensor dummy_y = backend::full_with_tensor<T>(
+        shape64<T>(mask), 0.0, x.dtype(), x.place());
+    Tensor dummy = dummy_x + dummy_y;
+    Tensor mask_expanded = backend::expand<T>(mask, shape64<T>(dummy));
+    Tensor v_expanded = backend::expand<T>(v, shape64<T>(dummy));
+    return where<T>(mask_expanded, v_expanded, x);
+
+  } else {
+    auto out_dims = phi::funcs::BroadcastTwoDims(x.dims(), mask.dims());
+    std::vector<int64_t> out_shape = common::vectorize(out_dims);
+    Tensor x_expanded = x;
+    if (x.dims() != out_dims) {
+      x_expanded = expand<T>(x_expanded, out_shape);
+    }
+    Tensor mask_expanded = mask;
+    if (mask.dims() != out_dims) {
+      mask_expanded = expand<T>(mask, out_shape);
+    }
+    Tensor v_expanded = expand<T>(v, out_shape);
+    return where<T>(mask_expanded, v_expanded, x_expanded);
+  }
+}
+
 template <typename T>
 std::tuple<Tensor, Tensor> dropout_decomp(
     const Tensor& x,
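For context, masked_fill(x, mask, v) writes v into the positions of x where mask is true, and the decomposition above reproduces that with expand (to bring x, mask, and v to a common broadcast shape) followed by an element-wise where select. Below is a minimal standalone sketch of that intended semantics in plain C++ on flat buffers; the helper name masked_fill_reference is illustrative only, not part of the Paddle API, and it assumes the inputs have already been broadcast to the same shape, which is exactly what the expand calls in the decomposition take care of.

#include <cstddef>
#include <iostream>
#include <vector>

// Reference semantics of masked_fill after broadcasting:
// out[i] = mask[i] ? v[i] : x[i], i.e. where(mask, v, x).
std::vector<float> masked_fill_reference(const std::vector<float>& x,
                                         const std::vector<bool>& mask,
                                         const std::vector<float>& v) {
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = mask[i] ? v[i] : x[i];  // element-wise select
  }
  return out;
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
  std::vector<bool> mask = {true, false, true, false};
  std::vector<float> v = {-1.f, -1.f, -1.f, -1.f};  // fill value broadcast to x's shape
  for (float e : masked_fill_reference(x, mask, v)) std::cout << e << ' ';
  std::cout << '\n';  // prints: -1 2 -1 4
}

The dynamic-shape branch cannot compute the broadcast shape at build time, which is why it materializes dummy_x + dummy_y solely to obtain the broadcast shape for the expand calls, as the NOTE in the code points out.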