File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
paddle/fluid/primitive/decomp_rule/decomp_rule Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -608,10 +608,10 @@ Tensor masked_fill_decomp(const Tensor& x,
608608 // NOTE: use add operator to get broadcast shape implicitly,
609609 // which is not efficient yet, should be improved in the future.
610610 Tensor dummy_x =
611- backend::full_with_tensor<T>(shape64<T>(x), 0.0 , x.dtype ());
612- Tensor dummy_y =
613- backend::full_with_tensor<T>( shape64<T>(mask), 0.0 , x.dtype ());
614- dummy = dummy_x + dummy_y;
611+ backend::full_with_tensor<T>(shape64<T>(x), 0.0 , x.dtype (), x. place () );
612+ Tensor dummy_y = backend::full_with_tensor<T>(
613+ shape64<T>(mask), 0.0 , x.dtype (), x. place ());
614+ Tensor dummy = dummy_x + dummy_y;
615615 Tensor mask_expanded = expand<T>(mask, shape64<T>(dummy));
616616 Tensor v_expanded = expand<T>(v, shape64<T>(dummy));
617617 return where<T>(mask_expanded, v_expanded, x);
You can’t perform that action at this time.
0 commit comments