replace .view with .reshape (#6777)
pmeier authored Oct 15, 2022
1 parent e2fa1f9 commit f467349
Showing 5 changed files with 45 additions and 45 deletions.
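The sweep is mechanical: every .view / .view_as call on a tensor whose memory layout is not guaranteed becomes .reshape / .reshape_as. The distinction matters because .view never copies, so it raises a RuntimeError whenever the requested shape cannot be expressed over the tensor's existing strides (for example after a transpose), whereas .reshape returns a view when possible and falls back to a copy otherwise. A minimal sketch of the failure mode, not part of the commit itself:

    import torch

    t = torch.arange(6).reshape(2, 3)
    nc = t.t()                # transpose: same storage, non-contiguous strides

    print(nc.reshape(6))      # works: reshape copies when a view is impossible
    try:
        nc.view(6)            # view refuses to copy ...
    except RuntimeError as e:
        print(e)              # ... and raises instead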
10 changes: 5 additions & 5 deletions torchvision/prototype/transforms/_auto_augment.py
@@ -484,7 +484,7 @@ def forward(self, *inputs: Any) -> Any:

orig_dims = list(image_or_video.shape)
expected_ndim = 5 if isinstance(orig_image_or_video, features.Video) else 4
- batch = image_or_video.view([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
+ batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

# Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
@@ -497,9 +497,9 @@ def forward(self, *inputs: Any) -> Any:
# Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
combined_weights = self._sample_dirichlet(
torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
- ) * m[:, 1].view([batch_dims[0], -1])
+ ) * m[:, 1].reshape([batch_dims[0], -1])

- mix = m[:, 0].view(batch_dims) * batch
+ mix = m[:, 0].reshape(batch_dims) * batch
for i in range(self.mixture_width):
aug = batch
depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
@@ -517,8 +517,8 @@ def forward(self, *inputs: Any) -> Any:
aug = self._apply_image_or_video_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
- mix.add_(combined_weights[:, i].view(batch_dims) * aug)
- mix = mix.view(orig_dims).to(dtype=image_or_video.dtype)
+ mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
+ mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

if isinstance(orig_image_or_video, (features.Image, features.Video)):
mix = orig_image_or_video.wrap_like(orig_image_or_video, mix) # type: ignore[arg-type]
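The first hunk above pads the input with leading singleton dimensions until it reaches the batch rank the kernels expect. A standalone sketch of that pattern, with illustrative names:

    import torch

    def to_batch(t: torch.Tensor, expected_ndim: int) -> torch.Tensor:
        # Pad with leading 1-sized dims, e.g. (C, H, W) -> (1, C, H, W) when
        # expected_ndim is 4; a no-op if t already has enough dimensions.
        return t.reshape([1] * max(expected_ndim - t.ndim, 0) + list(t.shape))

    image = torch.rand(3, 32, 32)                    # (C, H, W)
    print(to_batch(image, expected_ndim=4).shape)    # torch.Size([1, 3, 32, 32])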
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/_misc.py
@@ -88,9 +88,9 @@ def _transform(
f"Got {inpt.device} vs {self.mean_vector.device}"
)

- flat_tensor = inpt.view(-1, n) - self.mean_vector
+ flat_tensor = inpt.reshape(-1, n) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
- return transformed_tensor.view(shape)
+ return transformed_tensor.reshape(shape)


class Normalize(Transform):
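The _transform above implements a whitening-style map: flatten the input to rows of length n, subtract the mean, multiply by the matrix, and restore the shape. reshape is the right call because the incoming tensor need not be contiguous. A self-contained sketch under those assumptions:

    import torch

    def linear_transform(inpt, mean_vector, transformation_matrix):
        # Flatten to (num_rows, n), apply the affine map, restore the shape.
        shape = inpt.shape
        n = transformation_matrix.shape[0]
        flat = inpt.reshape(-1, n) - mean_vector
        return torch.mm(flat, transformation_matrix).reshape(shape)

    x = torch.rand(4, 48)
    out = linear_transform(x, torch.zeros(48), torch.eye(48))
    print(torch.allclose(out, x))    # True: zero mean, identity matrix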
6 changes: 3 additions & 3 deletions torchvision/prototype/transforms/functional/_color.py
@@ -69,15 +69,15 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
shape = image.shape

if image.ndim > 4:
- image = image.view(-1, num_channels, height, width)
+ image = image.reshape(-1, num_channels, height, width)
needs_unsquash = True
else:
needs_unsquash = False

output = _FT._blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)

if needs_unsquash:
- output = output.view(shape)
+ output = output.reshape(shape)

return output

@@ -213,7 +213,7 @@ def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor:
zeros = lut.new_zeros((1, 1)).expand(shape[0], 1)
lut = torch.cat([zeros, lut[:, :-1]], dim=1)

- return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).view_as(img))
+ return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).reshape_as(img))


def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
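adjust_sharpness_image_tensor, along with several kernels in the next file, uses a squash/unsquash pattern: fold any extra leading dimensions into the batch dimension, run a kernel that only understands (N, C, H, W), then restore the original shape. A minimal sketch:

    import torch

    def squash_apply(image: torch.Tensor, kernel) -> torch.Tensor:
        shape = image.shape
        needs_unsquash = image.ndim > 4
        if needs_unsquash:                             # e.g. (B, T, C, H, W)
            image = image.reshape((-1,) + shape[-3:])  # -> (B * T, C, H, W)
        output = kernel(image)
        if needs_unsquash:
            output = output.reshape(shape)
        return output

    video = torch.rand(2, 8, 3, 16, 16)
    print(squash_apply(video, lambda t: t * 0.5).shape)   # torch.Size([2, 8, 3, 16, 16])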
66 changes: 33 additions & 33 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -38,13 +38,13 @@ def horizontal_flip_bounding_box(

bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
- ).view(-1, 4)
+ ).reshape(-1, 4)

bounding_box[:, [0, 2]] = spatial_size[1] - bounding_box[:, [2, 0]]

return convert_format_bounding_box(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
- ).view(shape)
+ ).reshape(shape)


def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
@@ -75,13 +75,13 @@ def vertical_flip_bounding_box(

bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
- ).view(-1, 4)
+ ).reshape(-1, 4)

bounding_box[:, [1, 3]] = spatial_size[0] - bounding_box[:, [3, 1]]

return convert_format_bounding_box(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
- ).view(shape)
+ ).reshape(shape)


def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
@@ -123,7 +123,7 @@ def resize_image_tensor(
extra_dims = image.shape[:-3]

if image.numel() > 0:
- image = image.view(-1, num_channels, old_height, old_width)
+ image = image.reshape(-1, num_channels, old_height, old_width)

image = _FT.resize(
image,
@@ -132,7 +132,7 @@ def resize_image_tensor(
antialias=antialias,
)

- return image.view(extra_dims + (num_channels, new_height, new_width))
+ return image.reshape(extra_dims + (num_channels, new_height, new_width))


@torch.jit.unused
@@ -168,7 +168,7 @@ def resize_bounding_box(
new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size)
ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
return (
- bounding_box.view(-1, 2, 2).mul(ratios).to(bounding_box.dtype).view(bounding_box.shape),
+ bounding_box.reshape(-1, 2, 2).mul(ratios).to(bounding_box.dtype).reshape(bounding_box.shape),
(new_height, new_width),
)

@@ -270,7 +270,7 @@ def affine_image_tensor(

num_channels, height, width = image.shape[-3:]
extra_dims = image.shape[:-3]
- image = image.view(-1, num_channels, height, width)
+ image = image.reshape(-1, num_channels, height, width)

angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

@@ -283,7 +283,7 @@ def affine_image_tensor(
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
- return output.view(extra_dims + (num_channels, height, width))
+ return output.reshape(extra_dims + (num_channels, height, width))


@torch.jit.unused
@@ -338,20 +338,20 @@ def _affine_bounding_box_xyxy(
dtype=dtype,
device=device,
)
- .view(2, 3)
+ .reshape(2, 3)
.T
)
# 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
# Single point structure is similar to
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
- points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
+ points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
# 2) Now let's transform the points using affine matrix
transformed_points = torch.matmul(points, transposed_affine_matrix)
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
# and compute bounding box from 4 transformed points:
- transformed_points = transformed_points.view(-1, 4, 2)
+ transformed_points = transformed_points.reshape(-1, 4, 2)
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
@@ -396,15 +396,15 @@ def affine_bounding_box(
original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
- ).view(-1, 4)
+ ).reshape(-1, 4)

out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, spatial_size, angle, translate, scale, shear, center)

# out_bboxes should be of shape [N boxes, 4]

return convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
- ).view(original_shape)
+ ).reshape(original_shape)


def affine_mask(
@@ -539,7 +539,7 @@ def rotate_image_tensor(

if image.numel() > 0:
image = _FT.rotate(
- image.view(-1, num_channels, height, width),
+ image.reshape(-1, num_channels, height, width),
matrix,
interpolation=interpolation.value,
expand=expand,
@@ -549,7 +549,7 @@ def rotate_image_tensor(
else:
new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)

- return image.view(extra_dims + (num_channels, new_height, new_width))
+ return image.reshape(extra_dims + (num_channels, new_height, new_width))


@torch.jit.unused
@@ -585,7 +585,7 @@ def rotate_bounding_box(
original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
- ).view(-1, 4)
+ ).reshape(-1, 4)

out_bboxes, spatial_size = _affine_bounding_box_xyxy(
bounding_box,
@@ -601,7 +601,7 @@ def rotate_bounding_box(
return (
convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
- ).view(original_shape),
+ ).reshape(original_shape),
spatial_size,
)

@@ -691,15 +691,15 @@ def _pad_with_scalar_fill(

if image.numel() > 0:
image = _FT.pad(
- img=image.view(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
+ img=image.reshape(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
)
new_height, new_width = image.shape[-2:]
else:
left, right, top, bottom = _FT._parse_pad_padding(padding)
new_height = height + top + bottom
new_width = width + left + right

- return image.view(extra_dims + (num_channels, new_height, new_width))
+ return image.reshape(extra_dims + (num_channels, new_height, new_width))


# TODO: This should be removed once pytorch pad supports non-scalar padding values
@@ -714,7 +714,7 @@ def _pad_with_vector_fill(

output = _pad_with_scalar_fill(image, padding, fill=0, padding_mode="constant")
left, right, top, bottom = _parse_pad_padding(padding)
- fill = torch.tensor(fill, dtype=image.dtype, device=image.device).view(-1, 1, 1)
+ fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)

if top > 0:
output[..., :top, :] = fill
@@ -863,15 +863,15 @@ def perspective_image_tensor(
shape = image.shape

if image.ndim > 4:
- image = image.view((-1,) + shape[-3:])
+ image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False

output = _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill)

if needs_unsquash:
- output = output.view(shape)
+ output = output.reshape(shape)

return output

@@ -898,7 +898,7 @@ def perspective_bounding_box(
original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
- ).view(-1, 4)
+ ).reshape(-1, 4)

dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
device = bounding_box.device
@@ -947,7 +947,7 @@ def perspective_bounding_box(
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
# Single point structure is similar to
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
- points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
+ points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
# 2) Now let's transform the points using perspective matrices
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
@@ -959,7 +959,7 @@ def perspective_bounding_box(

# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
# and compute bounding box from 4 transformed points:
- transformed_points = transformed_points.view(-1, 4, 2)
+ transformed_points = transformed_points.reshape(-1, 4, 2)
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
@@ -968,7 +968,7 @@ def perspective_bounding_box(

return convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
- ).view(original_shape)
+ ).reshape(original_shape)


def perspective_mask(
@@ -1027,15 +1027,15 @@ def elastic_image_tensor(
shape = image.shape

if image.ndim > 4:
- image = image.view((-1,) + shape[-3:])
+ image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False

output = _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)

if needs_unsquash:
- output = output.view(shape)
+ output = output.reshape(shape)

return output

@@ -1063,7 +1063,7 @@ def elastic_bounding_box(
original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
- ).view(-1, 4)
+ ).reshape(-1, 4)

# Question (vfdev-5): should we rely on good displacement shape and fetch image size from it
# Or add spatial_size arg and check displacement shape
@@ -1075,21 +1075,21 @@ def elastic_bounding_box(
inv_grid = id_grid - displacement

# Get points from bboxes
- points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
+ points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
index_x = torch.floor(points[:, 0] + 0.5).to(dtype=torch.long)
index_y = torch.floor(points[:, 1] + 0.5).to(dtype=torch.long)
# Transform points:
t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
transformed_points = (inv_grid[0, index_y, index_x, :] + 1) * 0.5 * t_size - 0.5

- transformed_points = transformed_points.view(-1, 4, 2)
+ transformed_points = transformed_points.reshape(-1, 4, 2)
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)

return convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
- ).view(original_shape)
+ ).reshape(original_shape)


def elastic_mask(
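The bounding-box kernels in this file all follow one recipe: reshape to (-1, 4) XYXY boxes, expand each box to its four corner points, transform the points, then take the axis-aligned min/max envelope and restore the original shape. A condensed sketch for a plain 2x3 affine matrix (identity here, so the box comes back unchanged):

    import torch

    def affine_boxes_xyxy(boxes: torch.Tensor, affine: torch.Tensor) -> torch.Tensor:
        # boxes: (N, 4) in XYXY; affine: (2, 3) matrix [[a, b, tx], [c, d, ty]].
        # 1) Four corners per box, flattened to (N * 4, 2).
        points = boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
        # 2) Homogeneous coordinates, then transform: (N * 4, 3) @ (3, 2).
        points = torch.cat([points, torch.ones(points.shape[0], 1)], dim=-1)
        transformed = torch.matmul(points, affine.T)
        # 3) Regroup per box and take the min/max envelope.
        transformed = transformed.reshape(-1, 4, 2)
        mins, _ = torch.min(transformed, dim=1)
        maxs, _ = torch.max(transformed, dim=1)
        return torch.cat([mins, maxs], dim=1)

    boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
    identity = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    print(affine_boxes_xyxy(boxes, identity))    # tensor([[10., 20., 30., 40.]])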
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/functional/_misc.py
@@ -65,15 +65,15 @@ def gaussian_blur_image_tensor(
shape = image.shape

if image.ndim > 4:
- image = image.view((-1,) + shape[-3:])
+ image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False

output = _FT.gaussian_blur(image, kernel_size, sigma)

if needs_unsquash:
- output = output.view(shape)
+ output = output.reshape(shape)

return output

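A final note on why the substitution is free on the hot path: when the input is already contiguous, reshape returns a view backed by the same storage, so none of these call sites pay for a copy in the common case:

    import torch

    t = torch.rand(2, 3, 4)
    u = t.reshape(6, 4)                   # contiguous input: no copy
    print(u.data_ptr() == t.data_ptr())   # True: same underlying storage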
