Skip to content

Commit 0448bae

Browse files
fix
1 parent 82db05d commit 0448bae

File tree

1 file changed

+4
-7
lines changed
  • paddle/fluid/primitive/decomp_rule/decomp_rule

1 file changed

+4
-7
lines changed

paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -608,10 +608,10 @@ Tensor masked_fill_decomp(const Tensor& x,
608608
// NOTE: use add operator to get broadcast shape implicitly,
609609
// which is not efficient yet, should be improved in the future.
610610
Tensor dummy_x =
611-
backend::full_with_tensor<T>(shape64<T>(x), 0.0, x.dtype());
612-
Tensor dummy_y =
613-
backend::full_with_tensor<T>(shape64<T>(mask), 0.0, x.dtype());
614-
dummy = dummy_x + dummy_y;
611+
backend::full_with_tensor<T>(shape64<T>(x), 0.0, x.dtype(), x.place());
612+
Tensor dummy_y = backend::full_with_tensor<T>(
613+
shape64<T>(mask), 0.0, x.dtype(), x.place());
614+
Tensor dummy = dummy_x + dummy_y;
615615
Tensor mask_expanded = expand<T>(mask, shape64<T>(dummy));
616616
Tensor v_expanded = expand<T>(v, shape64<T>(dummy));
617617
return where<T>(mask_expanded, v_expanded, x);
@@ -622,14 +622,11 @@ Tensor masked_fill_decomp(const Tensor& x,
622622
if (x.dims() != out_dims) {
623623
x_expanded = expand<T>(x_expanded, out_dims);
624624
}
625-
626625
Tensor mask_expanded = mask;
627626
if (mask.dims() != out_dims) {
628627
mask_expanded = expand<T>(mask, out_dims);
629628
}
630-
631629
Tensor v_expanded = expand<T>(v, out_dims);
632-
633630
return where<T>(mask_expanded, v_expanded, x_expanded);
634631
}
635632
}

0 commit comments

Comments
 (0)