Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[prototype] Switch to spatial_size #6736

Merged
merged 10 commits into from
Oct 11, 2022
Merged

Conversation

datumbox
Copy link
Contributor

@datumbox datumbox commented Oct 10, 2022

This PR:

  • Renames image_size to spatial_size everywhere in the code-base
  • Adds get_num_channels_video and get_spatial_size_* kernels for videos, masks and bboxes
  • Adds get_num_frames dispatcher and get_num_frames_video kernel for JIT.
  • Eliminates the over-use of query_chw to make things work with bboxes and masks

@datumbox datumbox force-pushed the prototype/image_size branch from 0e2240c to 973fe25 Compare October 10, 2022 15:53
@datumbox datumbox requested review from pmeier and vfdev-5 October 10, 2022 17:40
@datumbox datumbox changed the title [WIP] [prototype] Switch to spatial_size [prototype] Switch to spatial_size Oct 10, 2022
Copy link
Contributor Author

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding comments to places where they didn't happen automatically with the IDE:

Comment on lines +31 to +32
def get_num_channels_video(video: torch.Tensor) -> int:
return get_num_channels_image_tensor(video)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addition of get_num_channels_video kernel.

Comment on lines +62 to +72
def get_spatial_size_video(video: torch.Tensor) -> List[int]:
return get_spatial_size_image_tensor(video)


def get_spatial_size_mask(mask: torch.Tensor) -> List[int]:
return get_spatial_size_image_tensor(mask)


@torch.jit.unused
def get_spatial_size_bounding_box(bounding_box: features.BoundingBox) -> List[int]:
return list(bounding_box.spatial_size)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addition of the get_spatial_size_* kernels. The one of BBox can't have a JIT-scriptable implementation as it relies on Tensor Subclassing to retrieve this info.

Comment on lines 78 to 81
elif isinstance(inpt, (features.Image, features.Video, features.BoundingBox, features.Mask)):
return list(inpt.spatial_size)
else:
return get_spatial_size_image_pil(inpt)
return get_spatial_size_image_pil(inpt) # type: ignore[no-any-return]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactoring to avoid the getattr idiom. After that, mypy complains for the PIL kernel. It's unclear to be why it thinks we return Any. The get_spatial_size_video returns a List[int].

@@ -153,7 +153,7 @@ class RandomCutmix(_BaseMixupCutmix):
def _get_params(self, sample: Any) -> Dict[str, Any]:
lam = float(self._dist.sample(()))

_, H, W = query_chw(sample)
H, W = query_hw(sample)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the query_chw in favour of query_hw where possible. This happens in multiple places in the code-base.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method was renamed to query_spatial_size after #6736 (review)

@@ -100,7 +100,7 @@ def __init__(
self.p = p

def _get_params(self, sample: Any) -> Dict[str, Any]:
num_channels, _, _ = query_chw(sample)
num_channels, *_ = query_chw(sample)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't introduce yet another method for extracting channels only. This is indeed less elegant but doesn't introduce any limitations as the input is required to have channels. This happens in one more place in the codebase.

Comment on lines 101 to 114
def query_hw(sample: Any) -> Tuple[int, int]:
flat_sample, _ = tree_flatten(sample)
hws = {
tuple(get_spatial_size(item))
for item in flat_sample
if isinstance(item, (features.Image, PIL.Image.Image, features.Video, features.Mask, features.BoundingBox))
or features.is_simple_tensor(item)
}
if not hws:
raise TypeError("No image, video, mask or bounding box was found in the sample")
elif len(hws) > 1:
raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(hws))}")
h, w = hws.pop()
return h, w
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lots of code duplication with query_chw. The two methods differ on the callable, the checked types, the error messages and the return type. I was tempted to write something that passes a callable and tries to reduce duplicate code but it become unnecessarily complex. Happy to implement other approaches if you have better ideas.

@datumbox
Copy link
Contributor Author

I got a few failures on Windows. They don't look like related at a first glance but then this PR touches too many things, so I'm not 100% sure. I'll check again tomorrow.

Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK to me, thanks @datumbox !
Just a minor suggestion to rename query_hw to query_spatial_size... (not blocking)

@datumbox datumbox merged commit 4d4711d into pytorch:main Oct 11, 2022
@datumbox datumbox deleted the prototype/image_size branch October 11, 2022 09:10
facebook-github-bot pushed a commit that referenced this pull request Oct 17, 2022
Summary:
* Change `image_size` to `spatial_size`

* Fix linter

* Fixing more tests.

* Adding get_num_channels_video and get_spatial_size_* kernels for video, masks and bboxes.

* Refactor get_spatial_size

* Reduce the usage of `query_chw` where possible

* Rename `query_chw` to `query_spatial_size`

* Adding `get_num_frames` dispatcher and kernel.

* Adding jit-scriptability tests

Reviewed By: NicolasHug

Differential Revision: D40427485

fbshipit-source-id: 2401fe20877177459fe23181655c9cf429cb0cc5
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants