Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added center cropping and resize ops for PPO agents #365

Merged
merged 10 commits into from
Apr 15, 2020
1 change: 0 additions & 1 deletion habitat_baselines/agents/ppo_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def reset(self):

def act(self, observations):
batch = batch_obs([observations], device=self.device)

with torch.no_grad():
(
_,
Expand Down
145 changes: 143 additions & 2 deletions habitat_baselines/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import glob
import numbers
import os
from collections import defaultdict
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
from gym.spaces import Box

from habitat import logger
from habitat.utils.visualizations.utils import images_to_video
from habitat_baselines.common.tensorboard_utils import TensorboardWriter

Expand Down Expand Up @@ -53,12 +57,60 @@ def forward(self, x):
return CustomFixedCategorical(logits=x)


class ResizeCenterCropper(nn.Module):
def __init__(self, size, channels_last: bool = False):
r"""An nn module the resizes and center crops your input.
Args:
size: A sequence (w, h) or int of the size you wish to resize/center_crop.
If int, assumes square crop
channels_list: indicates if channels is the last dimension
"""
super().__init__()
if isinstance(size, numbers.Number):
size = (int(size), int(size))
assert len(size) == 2, "forced input size must be len of 2 (w, h)"
self._size = size
self.channels_last = channels_last

def transform_observation_space(
self, observation_space, trans_keys=["rgb", "depth", "semantic"]
):
size = self._size
observation_space = copy.deepcopy(observation_space)
if size:
for key in observation_space.spaces:
if (
key in trans_keys
and observation_space.spaces[key].shape != size
):
logger.info(
"Overwriting CNN input size of %s: %s" % (key, size)
)
observation_space.spaces[key] = overwrite_gym_box_shape(
observation_space.spaces[key], size
)
self.observation_space = observation_space
return observation_space

def forward(self, input: torch.Tensor):
if self._size is None:
return input

return center_crop(
image_resize_shortest_edge(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How to disable image_resize_shortest_edge functionality in current setup?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There isn't a way, that's why it's ResizeCenterCropper not just CenterCropper. Currently there is no way to disable it.

input, max(self._size), channels_last=self.channels_last
),
self._size,
channels_last=self.channels_last,
)


def linear_decay(epoch: int, total_num_updates: int) -> float:
r"""Returns a multiplicative factor for linear value decay

Args:
epoch: current epoch number
total_num_updates: total number of epochs
total_num_updates: total number of

Returns:
multiplicative factor that decreases param value linearly
Expand Down Expand Up @@ -174,3 +226,92 @@ def generate_video(
tb_writer.add_video_from_np_images(
f"episode{episode_id}", checkpoint_idx, images, fps=fps
)


def image_resize_shortest_edge(img, size: int, channels_last: bool = False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please, specify type for img.

Suggested change
def image_resize_shortest_edge(img, size: int, channels_last: bool = False):
def image_resize_shortest_edge(img: torch.Tensor, size: int, channels_last: bool = False):

"""Resizes an img so that the shortest side is length of size while
preserving aspect ratio.

Args:
img: the array object that needs to be resized (HWC) or (NHWC)
size: the size that you want the shortest edge to be resize to
channels: a boolean that channel is the last dimension
Returns:
The resized array as a torch tensor.
"""
no_batch_dim = len(img.shape) == 3
if len(img.shape) < 3 or len(img.shape) > 5:
raise NotImplementedError()
img = _to_tensor(img)
if no_batch_dim:
img = img.unsqueeze(0) # Adds a batch dimension
if channels_last:
# NHWC
h, w = img.shape[-3:-1]
if len(img.shape) == 4:
img = img.permute(0, 3, 1, 2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to do this permutations?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because PyTorch only accepts NCHW channel order for that function.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Skylion007, for each img.permute can we add inline comments like # NHWC = >NCHW. Then it will be easier to support the code.

else:
img = img.permute(0, 1, 4, 2, 3)
else:
# NCHW
h, w = img.shape[-2:]

# Percentage resize
if w > h:
Skylion007 marked this conversation as resolved.
Show resolved Hide resolved
percent = size / h
Skylion007 marked this conversation as resolved.
Show resolved Hide resolved
else:
percent = size / w
h *= percent
Skylion007 marked this conversation as resolved.
Show resolved Hide resolved
w *= percent
h = int(h)
w = int(w)
img = torch.nn.functional.interpolate(
img.float(), size=(h, w), mode="area"
).to(dtype=img.dtype)
if channels_last:
if len(img.shape) == 4:
img = img.permute(0, 2, 3, 1)
else:
img = img.permute(0, 1, 3, 4, 2)
if no_batch_dim:
img = img.squeeze(dim=0) # Removes the batch dimension
return img


def center_crop(img, size, channels_last: bool = False):
"""Performs a center crop on an image.
Skylion007 marked this conversation as resolved.
Show resolved Hide resolved

Args:
img: the array object that needs to be resized (either batched or unbatched)
size: A sequence (w, h) or a python(int) that you want cropped
channels_last: If the channels are the last dimension.
Returns:
the resized array
"""
if channels_last:
# NHWC
h, w = img.shape[-3:-1]
else:
# NCHW
h, w = img.shape[-2:]

if isinstance(size, numbers.Number):
size = (int(size), int(size))
assert len(size) == 2, "size should be (h,w) you wish to resize to"
cropx, cropy = size

startx = w // 2 - (cropx // 2)
starty = h // 2 - (cropy // 2)
if channels_last:
return img[..., starty : starty + cropy, startx : startx + cropx, :]
else:
return img[..., starty : starty + cropy, startx : startx + cropx]


def overwrite_gym_box_shape(box: Box, shape) -> Box:
if box.shape == shape:
return box
shape = list(shape) + list(box.shape[len(shape) :])
low = box.low if np.isscalar(box.low) else np.min(box.low)
high = box.high if np.isscalar(box.high) else np.max(box.high)
return Box(low=low, high=high, shape=shape, dtype=box.dtype)
5 changes: 4 additions & 1 deletion habitat_baselines/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,10 @@ def get_config(

for config_path in config_paths:
config.merge_from_file(config_path)

if opts:
for k, v in zip(opts[0::2], opts[1::2]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the logic behind this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's impossible to overwrite the BASE_TASK_CONFIG from the command line without since the BASE_TASK_CONFIG is used before its args are overwritten by the command line. Likewise, moving the code for that to this point would make it impossible to overwrite TASK_CONFIG variables from the command line. As such, BASE_TASK_CONFIG must be extracted and overwritten and then the remaining config parameters can be overwritten.

if k == "BASE_TASK_CONFIG_PATH":
config.BASE_TASK_CONFIG_PATH = v
config.TASK_CONFIG = get_task_config(config.BASE_TASK_CONFIG_PATH)
if opts:
config.CMD_TRAILING_OPTS = opts
Expand Down
1 change: 0 additions & 1 deletion habitat_baselines/config/pointnav/ddppo_pointnav.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ RL:
use_normalized_advantage: False

hidden_size: 512

DDPPO:
sync_frac: 0.6
# The PyTorch distributed backend to use
Expand Down
22 changes: 21 additions & 1 deletion habitat_baselines/rl/ddppo/policy/resnet_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gym.spaces import Box

from habitat_baselines.common.utils import CategoricalNet, Flatten
from habitat_baselines.common.utils import (
CategoricalNet,
Flatten,
ResizeCenterCropper,
)
from habitat_baselines.rl.ddppo.policy import resnet
from habitat_baselines.rl.ddppo.policy.running_mean_and_var import (
RunningMeanAndVar,
Expand All @@ -30,6 +36,7 @@ def __init__(
resnet_baseplanes=32,
backbone="resnet50",
normalize_visual_inputs=False,
obs_transform=ResizeCenterCropper(size=(256, 256)),
):
super().__init__(
PointNavResNetNet(
Expand All @@ -42,6 +49,7 @@ def __init__(
backbone=backbone,
resnet_baseplanes=resnet_baseplanes,
normalize_visual_inputs=normalize_visual_inputs,
obs_transform=obs_transform,
),
action_space.n,
)
Expand All @@ -56,9 +64,16 @@ def __init__(
spatial_size=128,
make_backbone=None,
normalize_visual_inputs=False,
obs_transform=ResizeCenterCropper(size=(256, 256)),
):
super().__init__()

self.obs_transform = obs_transform
if self.obs_transform is not None:
observation_space = self.obs_transform.transform_observation_space(
observation_space
)

if "rgb" in observation_space.spaces:
self._n_input_rgb = observation_space.spaces["rgb"].shape[2]
spatial_size = observation_space.spaces["rgb"].shape[0] // 2
Expand Down Expand Up @@ -140,6 +155,9 @@ def forward(self, observations):

cnn_input.append(depth_observations)

if self.obs_transform:
cnn_input = [self.obs_transform(inp) for inp in cnn_input]

x = torch.cat(cnn_input, dim=1)
x = F.avg_pool2d(x, 2)

Expand All @@ -165,6 +183,7 @@ def __init__(
backbone,
resnet_baseplanes,
normalize_visual_inputs,
obs_transform=ResizeCenterCropper(size=(256, 256)),
):
super().__init__()
self.goal_sensor_uuid = goal_sensor_uuid
Expand All @@ -187,6 +206,7 @@ def __init__(
ngroups=resnet_baseplanes // 2,
make_backbone=getattr(resnet, backbone),
normalize_visual_inputs=normalize_visual_inputs,
obs_transform=obs_transform,
)

if not self.visual_encoder.is_blind:
Expand Down
20 changes: 18 additions & 2 deletions habitat_baselines/rl/models/simple_cnn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
import torch
import torch.nn as nn
from gym.spaces import Box

from habitat_baselines.common.utils import Flatten
from habitat_baselines.common.utils import Flatten, ResizeCenterCropper


class SimpleCNN(nn.Module):
Expand All @@ -15,8 +16,20 @@ class SimpleCNN(nn.Module):
output_size: The size of the embedding vector
"""

def __init__(self, observation_space, output_size):
def __init__(
self,
observation_space,
output_size,
obs_transform: nn.Module = ResizeCenterCropper(size=(256, 256)),
):
super().__init__()

self.obs_transform = obs_transform
if self.obs_transform is not None:
observation_space = obs_transform.transform_observation_space(
observation_space
)

if "rgb" in observation_space.spaces:
self._n_input_rgb = observation_space.spaces["rgb"].shape[2]
else:
Expand Down Expand Up @@ -142,6 +155,9 @@ def forward(self, observations):
depth_observations = depth_observations.permute(0, 3, 1, 2)
cnn_input.append(depth_observations)

if self.obs_transform:
cnn_input = [self.obs_transform(inp) for inp in cnn_input]

cnn_input = torch.cat(cnn_input, dim=1)

return self.cnn(cnn_input)
38 changes: 25 additions & 13 deletions test/test_baseline_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

import habitat
from habitat.config import Config as CN

try:
from habitat_baselines.agents import ppo_agents
Expand All @@ -25,28 +26,39 @@
not baseline_installed, reason="baseline sub-module not installed"
)
def test_ppo_agents():

agent_config = ppo_agents.get_default_config()
agent_config.MODEL_PATH = ""
agent_config.defrost()
config_env = habitat.get_config(config_paths=CFG_TEST)
if not os.path.exists(config_env.SIMULATOR.SCENE):
pytest.skip("Please download Habitat test data to data folder.")

benchmark = habitat.Benchmark(config_paths=CFG_TEST)

for input_type in ["blind", "rgb", "depth", "rgbd"]:
config_env.defrost()
config_env.SIMULATOR.AGENT_0.SENSORS = []
if input_type in ["rgb", "rgbd"]:
config_env.SIMULATOR.AGENT_0.SENSORS += ["RGB_SENSOR"]
if input_type in ["depth", "rgbd"]:
config_env.SIMULATOR.AGENT_0.SENSORS += ["DEPTH_SENSOR"]
config_env.freeze()
del benchmark._env
benchmark._env = habitat.Env(config=config_env)
agent_config.INPUT_TYPE = input_type

agent = ppo_agents.PPOAgent(agent_config)
habitat.logger.info(benchmark.evaluate(agent, num_episodes=10))
for resolution in [256, 384]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use resolution for @pytest.mark.parametrize like here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would require reconstructing the agent and benchmark for every iteration? I think there is a reason this is already done in a loop before I added the code (it would make the test a lot longer).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, do you want to test when h <> w?

config_env.defrost()
config_env.SIMULATOR.AGENT_0.SENSORS = []
if input_type in ["rgb", "rgbd"]:
config_env.SIMULATOR.AGENT_0.SENSORS += ["RGB_SENSOR"]
agent_config.RESOLUTION = resolution
config_env.SIMULATOR.RGB_SENSOR.WIDTH = resolution
config_env.SIMULATOR.RGB_SENSOR.HEIGHT = resolution
if input_type in ["depth", "rgbd"]:
config_env.SIMULATOR.AGENT_0.SENSORS += ["DEPTH_SENSOR"]
agent_config.RESOLUTION = resolution
config_env.SIMULATOR.DEPTH_SENSOR.WIDTH = resolution
config_env.SIMULATOR.DEPTH_SENSOR.HEIGHT = resolution

config_env.freeze()

del benchmark._env
benchmark._env = habitat.Env(config=config_env)
agent_config.INPUT_TYPE = input_type

agent = ppo_agents.PPOAgent(agent_config)
habitat.logger.info(benchmark.evaluate(agent, num_episodes=10))


@pytest.mark.skipif(
Expand Down