-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[prototype] Switch to spatial_size
#6736
Conversation
0e2240c
to
973fe25
Compare
spatial_size
spatial_size
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding comments to places where they didn't happen automatically with the IDE:
def get_num_channels_video(video: torch.Tensor) -> int: | ||
return get_num_channels_image_tensor(video) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addition of get_num_channels_video
kernel.
def get_spatial_size_video(video: torch.Tensor) -> List[int]: | ||
return get_spatial_size_image_tensor(video) | ||
|
||
|
||
def get_spatial_size_mask(mask: torch.Tensor) -> List[int]: | ||
return get_spatial_size_image_tensor(mask) | ||
|
||
|
||
@torch.jit.unused | ||
def get_spatial_size_bounding_box(bounding_box: features.BoundingBox) -> List[int]: | ||
return list(bounding_box.spatial_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addition of the get_spatial_size_*
kernels. The one of BBox can't have a JIT-scriptable implementation as it relies on Tensor Subclassing to retrieve this info.
elif isinstance(inpt, (features.Image, features.Video, features.BoundingBox, features.Mask)): | ||
return list(inpt.spatial_size) | ||
else: | ||
return get_spatial_size_image_pil(inpt) | ||
return get_spatial_size_image_pil(inpt) # type: ignore[no-any-return] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Refactoring to avoid the getattr
idiom. After that, mypy complains for the PIL kernel. It's unclear to be why it thinks we return Any
. The get_spatial_size_video
returns a List[int]
.
@@ -153,7 +153,7 @@ class RandomCutmix(_BaseMixupCutmix): | |||
def _get_params(self, sample: Any) -> Dict[str, Any]: | |||
lam = float(self._dist.sample(())) | |||
|
|||
_, H, W = query_chw(sample) | |||
H, W = query_hw(sample) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removing the query_chw
in favour of query_hw
where possible. This happens in multiple places in the code-base.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The method was renamed to query_spatial_size
after #6736 (review)
@@ -100,7 +100,7 @@ def __init__( | |||
self.p = p | |||
|
|||
def _get_params(self, sample: Any) -> Dict[str, Any]: | |||
num_channels, _, _ = query_chw(sample) | |||
num_channels, *_ = query_chw(sample) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't introduce yet another method for extracting channels only. This is indeed less elegant but doesn't introduce any limitations as the input is required to have channels. This happens in one more place in the codebase.
def query_hw(sample: Any) -> Tuple[int, int]: | ||
flat_sample, _ = tree_flatten(sample) | ||
hws = { | ||
tuple(get_spatial_size(item)) | ||
for item in flat_sample | ||
if isinstance(item, (features.Image, PIL.Image.Image, features.Video, features.Mask, features.BoundingBox)) | ||
or features.is_simple_tensor(item) | ||
} | ||
if not hws: | ||
raise TypeError("No image, video, mask or bounding box was found in the sample") | ||
elif len(hws) > 1: | ||
raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(hws))}") | ||
h, w = hws.pop() | ||
return h, w |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lots of code duplication with query_chw
. The two methods differ on the callable, the checked types, the error messages and the return type. I was tempted to write something that passes a callable and tries to reduce duplicate code but it become unnecessarily complex. Happy to implement other approaches if you have better ideas.
I got a few failures on Windows. They don't look like related at a first glance but then this PR touches too many things, so I'm not 100% sure. I'll check again tomorrow. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK to me, thanks @datumbox !
Just a minor suggestion to rename query_hw to query_spatial_size... (not blocking)
Summary: * Change `image_size` to `spatial_size` * Fix linter * Fixing more tests. * Adding get_num_channels_video and get_spatial_size_* kernels for video, masks and bboxes. * Refactor get_spatial_size * Reduce the usage of `query_chw` where possible * Rename `query_chw` to `query_spatial_size` * Adding `get_num_frames` dispatcher and kernel. * Adding jit-scriptability tests Reviewed By: NicolasHug Differential Revision: D40427485 fbshipit-source-id: 2401fe20877177459fe23181655c9cf429cb0cc5
This PR:
image_size
tospatial_size
everywhere in the code-baseget_num_channels_video
andget_spatial_size_*
kernels for videos, masks and bboxesget_num_frames
dispatcher andget_num_frames_video
kernel for JIT.query_chw
to make things work with bboxes and masks