Skip to content

Commit

Permalink
try torch.where over boolean masking
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Dec 19, 2023
1 parent 11c019b commit e6a54bf
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,8 +595,7 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)] # type: ignore[arg-type]
fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
if mode == "nearest":
bool_mask = mask < 0.5
float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
float_img = torch.where(mask < 0.5, fill_img.expand_as(float_img), float_img)
else: # 'bilinear'
# The following is mathematically equivalent to:
# img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
Expand Down

0 comments on commit e6a54bf

Please sign in to comment.