Added center cropping and resize ops for PPO agents #365
Codecov Report
@@ Coverage Diff @@
## master #365 +/- ##
==========================================
+ Coverage 77.24% 77.27% +0.03%
==========================================
Files 108 108
Lines 7330 7425 +95
==========================================
+ Hits 5662 5738 +76
- Misses 1668 1687 +19
@@ -144,7 +144,10 @@ def get_config(

    for config_path in config_paths:
        config.merge_from_file(config_path)

    if opts:
        for k, v in zip(opts[0::2], opts[1::2]):
What is the logic behind this change?
Without this change it's impossible to overwrite BASE_TASK_CONFIG from the command line, since BASE_TASK_CONFIG is used before its value is overwritten by the command-line args. Likewise, moving that code to this point would make it impossible to overwrite TASK_CONFIG variables from the command line. As such, BASE_TASK_CONFIG must be extracted and overwritten first, and then the remaining config parameters can be overwritten.
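To illustrate the ordering described above, here is a minimal sketch using plain dicts instead of the real config object. The names `split_base_task_opts`, `build_config`, and the `load_task_config` callback are hypothetical and only stand in for whatever the actual `get_config` does.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple


def split_base_task_opts(
    opts: List[Any], key: str = "BASE_TASK_CONFIG"
) -> Tuple[Optional[Any], List[Any]]:
    """Pull the base-task override out of a flat [KEY, VALUE, ...] list."""
    base_value = None
    remaining: List[Any] = []
    for k, v in zip(opts[0::2], opts[1::2]):
        if k == key:
            base_value = v
        else:
            remaining.extend([k, v])
    return base_value, remaining


def build_config(
    defaults: Dict[str, Any],
    load_task_config: Callable[[Any], Dict[str, Any]],
    opts: Optional[List[Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical stand-in for get_config, showing only the override order."""
    config = dict(defaults)
    remaining: List[Any] = []
    if opts:
        # Step 1: apply the base-task override before it is consumed.
        base_value, remaining = split_base_task_opts(opts)
        if base_value is not None:
            config["BASE_TASK_CONFIG"] = base_value
    # Step 2: load the task config from the (possibly overridden) base value.
    config["TASK_CONFIG"] = dict(load_task_config(config["BASE_TASK_CONFIG"]))
    # Step 3: apply the remaining overrides, which may target TASK_CONFIG.
    for k, v in zip(remaining[0::2], remaining[1::2]):
        *parents, leaf = k.split(".")
        target = config
        for parent in parents:
            target = target[parent]
        target[leaf] = v
    return config
```

If step 1 were skipped, the task config would be loaded from the default base value; if step 3 ran before step 2, the loaded task config would replace any `TASK_CONFIG.*` overrides.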
habitat_baselines/common/utils.py
Outdated
@@ -53,6 +56,51 @@ def forward(self, x):
        return CustomFixedCategorical(logits=x)


class ResizeCenterCropper(nn.Module):
    def __init__(
        self, force_input_size: Optional[tuple], channels_first: bool = True
The name force_input_size is a little confusing. Is that the input size for the transformation, or the size of the input to the next level? Maybe use naming similar to torchvision.transforms.CenterCrop (https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.CenterCrop):
size (sequence or python:int) – Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is made.
Regarding channels_first, maybe call it something like NCHW and add a docstring, because the channels are not really first, and to be more explicit.
It also resizes the input though; it doesn't just center crop. I also kept it that way because, at least for center cropping, I wrote the code so that both ...CHW and ...HWC layouts (regardless of the number of leading dimensions) are supported. channels_first is a convention from TensorFlow/Keras.
I changed it to channels_last, since that is true regardless of the number of dimensions.
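As a rough sketch of how a single channels_last flag can cover both ...CHW and ...HWC layouts, the helper below computes only the crop window from a shape. The name `center_crop_slices` is hypothetical, and clamping the crop to the image size is an assumption rather than something confirmed by this PR.

```python
from typing import Sequence, Tuple, Union


def center_crop_slices(
    shape: Sequence[int],
    size: Union[int, Tuple[int, int]],
    channels_last: bool = False,
) -> Tuple[slice, slice]:
    """Return (rows, cols) slices for a centered crop of the H and W axes."""
    if isinstance(size, int):
        size = (size, size)  # int means a square crop, as in torchvision
    if channels_last:
        h, w = shape[-3:-1]  # ...HWC: H and W sit just before the channel axis
    else:
        h, w = shape[-2:]  # ...CHW: H and W are the last two axes
    # Assumption: never crop to more than the image actually has.
    crop_h, crop_w = min(size[0], h), min(size[1], w)
    top = (h - crop_h) // 2
    left = (w - crop_w) // 2
    return slice(top, top + crop_h), slice(left, left + crop_w)
```

With a tensor or array, the slices would be applied as `img[..., rows, cols]` for channels-first input and `img[..., rows, cols, :]` for channels-last input, which works for any number of leading dimensions.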
# NHWC
h, w = img.shape[-3:-1]
if len(img.shape) == 4:
    img = img.permute(0, 3, 1, 2)
Why do we need these permutations?
Because PyTorch only accepts NCHW channel order for that function.
@Skylion007, for each img.permute, can we add an inline comment like # NHWC => NCHW? Then the code will be easier to maintain.
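For reference, the axis bookkeeping behind `img.permute(0, 3, 1, 2)` can be checked on shapes alone. This is a shape-only sketch; `permute_shape` and the two constants are hypothetical names, and it relies only on the fact that output axis i of a permute takes input axis order[i].

```python
from typing import Sequence, Tuple

NHWC_TO_NCHW = (0, 3, 1, 2)  # NHWC => NCHW
NCHW_TO_NHWC = (0, 2, 3, 1)  # NCHW => NHWC (the inverse permutation)


def permute_shape(shape: Sequence[int], order: Sequence[int]) -> Tuple[int, ...]:
    """Shape that results from permuting axes: output axis i is input axis order[i]."""
    return tuple(shape[axis] for axis in order)
```

Applying `NHWC_TO_NCHW` before the resize and `NCHW_TO_NHWC` afterwards returns the batch to its original layout.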
habitat_baselines/common/utils.py
Outdated
    return observations


def overwrite_gym_box(box: Box, shape: tuple) -> Box:
Maybe overwrite_gym_box_height_width.
It can also overwrite the channels, if specified, or any other aspect of the shape, though.
    return input

return center_crop(
    image_resize_shortest_edge(
How can the image_resize_shortest_edge functionality be disabled in the current setup?
Currently there is no way to disable it; that's why it's ResizeCenterCropper and not just CenterCropper.
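A shape-only sketch of why the resize step comes first: scaling the shortest edge to the target size guarantees the following center crop always has enough pixels for a square output. The helper names are hypothetical, and the rounding rule is an assumption that may differ from the actual implementation.

```python
from typing import Tuple


def resize_shortest_edge_shape(h: int, w: int, size: int) -> Tuple[int, int]:
    """(H, W) after scaling so the shorter edge equals `size`, keeping aspect ratio."""
    if h <= w:
        return size, max(1, round(w * size / h))
    return max(1, round(h * size / w)), size


def resize_center_crop_shape(h: int, w: int, size: int) -> Tuple[int, int]:
    """(H, W) after resizing the shortest edge to `size` and center cropping to a square."""
    resized_h, resized_w = resize_shortest_edge_shape(h, w, size)
    return min(size, resized_h), min(size, resized_w)
```

For a 480x640 frame and a target of 256, the resize yields 256x341 and the crop then trims the width to 256, so a non-square sensor resolution still reaches the network as a square input.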

agent = ppo_agents.PPOAgent(agent_config)
habitat.logger.info(benchmark.evaluate(agent, num_episodes=10))
for resolution in [256, 384]:
Use resolution for @pytest.mark.parametrize like here.
Wouldn't that require reconstructing the agent and benchmark for every parametrized case? I think that is why this was already done in a loop before I added my code (parametrizing would make the test a lot longer).
Makes sense. Do you want to test the case where h != w?
Requested some changes; please add docstrings to ResizeCenterCropper as well.
habitat_baselines/common/utils.py
Outdated
@@ -174,3 +226,92 @@ def generate_video(
        tb_writer.add_video_from_np_images(
            f"episode{episode_id}", checkpoint_idx, images, fps=fps
        )


def image_resize_shortest_edge(img, size: int, channels_last: bool = False):
Please specify the type for img.
- def image_resize_shortest_edge(img, size: int, channels_last: bool = False):
+ def image_resize_shortest_edge(img: torch.Tensor, size: int, channels_last: bool = False):
@Skylion007 thank you for following on all the comments. Looks much better now. Left some suggestions.
Quick question: do you have a sense, or a benchmark, of how performance will be affected if the resize and crop are enabled?
Thank you!
@mathfac it's not noticeably affected, and it shouldn't be, as all those operations are performed on the GPU. (Cropping is a straight-up array view, and resizing uses an optimized GPU kernel.)
@Skylion007, it looks like the habitat baseline config part needed to configure ResizeCrop is missing.
We don't even have a good way to deserialize arbitrary transforms for this yet, and walking the config down to that scope would require significant code changes at the moment, so I plan on leaving it hard-coded for now.
Tested performance and functionality for ObjectNav.
* --cov-append for second pytest Add coverage flags in just the CI, not globally * Add tests without CUDA
…#365) Allows people to train PPO agents with non-square screen resolution. Add config options to automatically resize and crop the image in a performant manner.
Motivation and Context
Allows people to train PPO agents with non-square screen resolution. Add config options to automatically resize and crop the image in a performant manner.
Feedback requested on how to best implement this API and configs.
How Has This Been Tested
Tested by running the utility functions on various local agents.