@@ -608,10 +608,10 @@ Tensor masked_fill_decomp(const Tensor& x,
608608 // NOTE: use add operator to get broadcast shape implicitly,
609609 // which is not efficient yet, should be improved in the future.
610610 Tensor dummy_x =
611- backend::full_with_tensor<T>(shape64<T>(x), 0.0 , x.dtype ());
612- Tensor dummy_y =
613- backend::full_with_tensor<T>( shape64<T>(mask), 0.0 , x.dtype ());
614- dummy = dummy_x + dummy_y;
611+ backend::full_with_tensor<T>(shape64<T>(x), 0.0 , x.dtype (), x. place () );
612+ Tensor dummy_y = backend::full_with_tensor<T>(
613+ shape64<T>(mask), 0.0 , x.dtype (), x. place ());
614+ Tensor dummy = dummy_x + dummy_y;
615615 Tensor mask_expanded = expand<T>(mask, shape64<T>(dummy));
616616 Tensor v_expanded = expand<T>(v, shape64<T>(dummy));
617617 return where<T>(mask_expanded, v_expanded, x);
@@ -622,14 +622,11 @@ Tensor masked_fill_decomp(const Tensor& x,
622622 if (x.dims () != out_dims) {
623623 x_expanded = expand<T>(x_expanded, out_dims);
624624 }
625-
626625 Tensor mask_expanded = mask;
627626 if (mask.dims () != out_dims) {
628627 mask_expanded = expand<T>(mask, out_dims);
629628 }
630-
631629 Tensor v_expanded = expand<T>(v, out_dims);
632-
633630 return where<T>(mask_expanded, v_expanded, x_expanded);
634631 }
635632}
0 commit comments