diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 7a9b786c4e9b..4a6c38211fee 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1071,6 +1071,16 @@ def masked_fill_(self, mask, value): self.ivy_array = self.masked_fill(mask, value).ivy_array return self + def masked_scatter(self, mask, source): + ret = self.clone() + ret.index_put(torch_frontend.nonzero(mask, as_tuple=True), source) + return ret + + + def masked_scatter_(self, mask, source): + self.index_put(torch_frontend.nonzero(mask, as_tuple=True), source) + return self + @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") def index_add_(self, dim, index, source, *, alpha=1): self.ivy_array = torch_frontend.index_add( @@ -2300,10 +2310,16 @@ def corrcoef(self): def index_put(self, indices, values, accumulate=False): ret = self.clone() + def _set_add(index): + ret[index] += values + + def _set(index): + ret[index] = values + if accumulate: - ret[indices[0]] += values + ivy.map(fn=_set_add, unique={"index": indices}) else: - ret[indices[0]] = values + ivy.map(fn=_set, unique={"index": indices}) return ret def index_put_(self, indices, values, accumulate=False):