diff --git a/habitat_baselines/common/baseline_registry.py b/habitat_baselines/common/baseline_registry.py
index 6bff6d035a..2f5b2e6c8d 100644
--- a/habitat_baselines/common/baseline_registry.py
+++ b/habitat_baselines/common/baseline_registry.py
@@ -101,5 +101,49 @@ def get_policy(cls, name: str):
         r"""Get the RL policy with :p:`name`."""
         return cls._get_impl("policy", name)
 
+    @classmethod
+    def register_obs_transformer(
+        cls, to_register=None, *, name: Optional[str] = None
+    ):
+        r"""Register an Observation Transformer with :p:`name`.
+
+        :param name: Key with which the observation transformer will be
+            registered. If :py:`None` will use the name of the class
+
+        .. code:: py
+
+            from habitat_baselines.common.obs_transformers import ObservationTransformer
+            from habitat_baselines.common.baseline_registry import (
+                baseline_registry
+            )
+
+            @baseline_registry.register_obs_transformer
+            class MyObsTransformer(ObservationTransformer):
+                pass
+
+
+            # or
+
+            @baseline_registry.register_obs_transformer(name="MyTransformer")
+            class MyObsTransformer(ObservationTransformer):
+                pass
+
+        """
+        from habitat_baselines.common.obs_transformers import (
+            ObservationTransformer,
+        )
+
+        return cls._register_impl(
+            "obs_transformer",
+            to_register,
+            name,
+            assert_type=ObservationTransformer,
+        )
+
+    @classmethod
+    def get_obs_transformer(cls, name: str):
+        r"""Get the Observation Transformer with :p:`name`."""
+        return cls._get_impl("obs_transformer", name)
+
 
 baseline_registry = BaselineRegistry()
diff --git a/habitat_baselines/common/obs_transformers.py b/habitat_baselines/common/obs_transformers.py
new file mode 100644
index 0000000000..3936384928
--- /dev/null
+++ b/habitat_baselines/common/obs_transformers.py
@@ -0,0 +1,630 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+
+# LICENSE file in the root directory of this source tree.
+
+r"""This module defines various ObservationTransformers that can be used
+to transform the output of the simulator before it is fed into the
+policy of the neural network. This can include various useful preprocessing,
+such as faking a semantic sensor using RGB input and MaskRCNN, or faking
+a depth sensor using RGB input. You can also stitch together multiple sensors.
+This code runs efficiently on batches of inputs to these networks.
+ObservationTransformers all run as nn.Modules and can be used as preprocessing
+steps for encoders or any other neural networks.
+Assumes the input is on CUDA.
+
+They also implement a function that transforms the observation space to help
+fake or modify sensor input from the simulator.
+
+This module's API is experimental and likely to change.
+"""
+import abc
+import copy
+import math
+import numbers
+from typing import Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from gym.spaces.dict_space import Dict as SpaceDict
+from torch import nn
+
+from habitat.config import Config
+from habitat.core.logging import logger
+from habitat_baselines.common.baseline_registry import baseline_registry
+from habitat_baselines.common.utils import (
+    center_crop,
+    get_image_height_width,
+    image_resize_shortest_edge,
+    overwrite_gym_box_shape,
+)
+
+
+class ObservationTransformer(nn.Module, metaclass=abc.ABCMeta):
+    """This is the base ObservationTransformer class that all other observation
+    transformers should extend. from_config must be implemented by the transformer.
+    transform_observation_space is only needed if the observation_space changes
+    (e.g. resolution, range, or number of channels)."""
+
+    def transform_observation_space(
+        self, observation_space: SpaceDict, **kwargs
+    ):
+        return observation_space
+
+    @classmethod
+    @abc.abstractmethod
+    def from_config(cls, config: Config):
+        pass
+
+    def forward(
+        self, observations: Dict[str, torch.Tensor]
+    ) -> Dict[str, torch.Tensor]:
+        return observations
+
+
+@baseline_registry.register_obs_transformer()
+class ResizeShortestEdge(ObservationTransformer):
+    r"""An nn module that resizes the shortest edge of the input while maintaining aspect ratio.
+    This module assumes that all images in the batch are of the same size.
+    """
+
+    def __init__(
+        self,
+        size: int,
+        channels_last: bool = True,
+        trans_keys: Tuple[str] = ("rgb", "depth", "semantic"),
+    ):
+        """Args:
+        size: The size you want to resize the shortest edge to
+        channels_last: indicates if channels is the last dimension
+        """
+        super(ResizeShortestEdge, self).__init__()
+        self._size: int = size
+        self.channels_last: bool = channels_last
+        self.trans_keys: Tuple[str] = trans_keys
+
+    def transform_observation_space(
+        self,
+        observation_space: SpaceDict,
+    ):
+        size = self._size
+        observation_space = copy.deepcopy(observation_space)
+        if size:
+            for key in observation_space.spaces:
+                if key in self.trans_keys:
+                    # In the observation space dict, the channels are always last
+                    h, w = get_image_height_width(
+                        observation_space.spaces[key], channels_last=True
+                    )
+                    if size == min(h, w):
+                        continue
+                    scale = size / min(h, w)
+                    new_h = int(h * scale)
+                    new_w = int(w * scale)
+                    new_size = (new_h, new_w)
+                    logger.info(
+                        "Resizing observation of %s: from %s to %s"
+                        % (key, (h, w), new_size)
+                    )
+                    observation_space.spaces[key] = overwrite_gym_box_shape(
+                        observation_space.spaces[key], new_size
+                    )
+        return observation_space
+
+    def _transform_obs(self, obs: torch.Tensor) -> torch.Tensor:
+        return image_resize_shortest_edge(
+            obs, self._size, channels_last=self.channels_last
+        )
+
+    @torch.no_grad()
+    def forward(
+        self, observations: Dict[str, torch.Tensor]
+    ) -> Dict[str, torch.Tensor]:
+        if self._size is not None:
+            observations.update(
+                {
+                    sensor: self._transform_obs(observations[sensor])
+                    for sensor in self.trans_keys
+                    if sensor in observations
+                }
+            )
+        return observations
+
+    @classmethod
+    def from_config(cls, config: Config):
+        return cls(config.RL.POLICY.OBS_TRANSFORMS.RESIZE_SHORTEST_EDGE.SIZE)
+
+
+@baseline_registry.register_obs_transformer()
+class CenterCropper(ObservationTransformer):
+    """A simple nn.Module observation transformer that center crops your input."""
+
+    def __init__(
+        self,
+        size: Union[int, Tuple[int]],
+        channels_last: bool = True,
+        trans_keys: Tuple[str] = ("rgb", "depth", "semantic"),
+    ):
+        """Args:
+        size: A sequence (h, w) or int of the size you wish to resize/center_crop.
+              If int, assumes square crop
+        channels_last: indicates if channels is the last dimension
+        trans_keys: The list of sensors it will try to center crop.
+ """ + super().__init__() + if isinstance(size, numbers.Number): + size = (int(size), int(size)) + assert len(size) == 2, "forced input size must be len of 2 (h, w)" + self._size = size + self.channels_last = channels_last + self.trans_keys = trans_keys # TODO: Add to from_config constructor + + def transform_observation_space( + self, + observation_space: SpaceDict, + ): + size = self._size + observation_space = copy.deepcopy(observation_space) + if size: + for key in observation_space.spaces: + if ( + key in self.trans_keys + and observation_space.spaces[key].shape[-3:-1] != size + ): + h, w = get_image_height_width( + observation_space.spaces[key], channels_last=True + ) + logger.info( + "Center cropping observation size of %s from %s to %s" + % (key, (h, w), size) + ) + + observation_space.spaces[key] = overwrite_gym_box_shape( + observation_space.spaces[key], size + ) + return observation_space + + def _transform_obs(self, obs: torch.Tensor) -> torch.Tensor: + return center_crop( + obs, + self._size, + channels_last=self.channels_last, + ) + + @torch.no_grad() + def forward( + self, observations: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + if self._size is not None: + observations.update( + { + sensor: self._transform_obs(observations[sensor]) + for sensor in self.trans_keys + if sensor in observations + } + ) + return observations + + @classmethod + def from_config(cls, config: Config): + cc_config = config.RL.POLICY.OBS_TRANSFORMS.CENTER_CROPPER + return cls( + ( + cc_config.HEIGHT, + cc_config.WIDTH, + ) + ) + + +class Cube2Equirec(nn.Module): + """This is the backend Cube2Equirec nn.module that does the stiching. + Inspired from https://github.com/fuenwang/PanoramaUtility and + optimized for modern PyTorch.""" + + def __init__( + self, equ_h: int, equ_w: int, cube_length: int, fov: int = 90 + ): + """Args: + equ_h: (int) the height of the generated equirect + equ_w: (int) the width of the generated equirect + cube_length: (int) the length of each side of the cubemap + fov: (int) the FOV of each camera making the cubemap + """ + super(Cube2Equirec, self).__init__() + self.cube_h = cube_length + self.cube_w = cube_length + self.equ_h = equ_h + self.equ_w = equ_w + self.fov = fov + self.fov_rad = self.fov * np.pi / 180 + + # Compute the parameters for projection + assert self.cube_w == self.cube_h + self.radius = int(0.5 * cube_length) + + # Map equirectangular pixel to longitude and latitude + # NOTE: Make end a full length since arange have a right open bound [a, b) + theta_start = math.pi - (math.pi / equ_w) + theta_end = -math.pi + theta_step = 2 * math.pi / equ_w + theta_range = torch.arange(theta_start, theta_end, -theta_step) + + phi_start = 0.5 * math.pi - (0.5 * math.pi / equ_h) + phi_end = -0.5 * math.pi + phi_step = math.pi / equ_h + phi_range = torch.arange(phi_start, phi_end, -phi_step) + + # Stack to get the longitude latitude map + self.theta_map = theta_range.unsqueeze(0).repeat(equ_h, 1) + self.phi_map = phi_range.unsqueeze(-1).repeat(1, equ_w) + self.lonlat_map = torch.stack([self.theta_map, self.phi_map], dim=-1) + + # Get mapping relation (h, w, face) (orientation map) + # [back, down, front, left, right, up] => [0, 1, 2, 3, 4, 5] + + # Project each face to 3D cube and convert to pixel coordinates + self.grid, self.orientation_mask = self.get_grid2() + + def get_grid2(self): + # Get the point of equirectangular on 3D ball + x_3d = ( + self.radius * torch.cos(self.phi_map) * torch.sin(self.theta_map) + ).view(self.equ_h, self.equ_w, 1) + y_3d = 
(self.radius * torch.sin(self.phi_map)).view( + self.equ_h, self.equ_w, 1 + ) + z_3d = ( + self.radius * torch.cos(self.phi_map) * torch.cos(self.theta_map) + ).view(self.equ_h, self.equ_w, 1) + + self.grid_ball = torch.cat([x_3d, y_3d, z_3d], 2).view( + self.equ_h, self.equ_w, 3 + ) + + # Compute the down grid + radius_ratio_down = torch.abs(y_3d / self.radius) + grid_down_raw = self.grid_ball / radius_ratio_down.view( + self.equ_h, self.equ_w, 1 + ).expand(-1, -1, 3) + grid_down_w = ( + -grid_down_raw[:, :, 0].clone() / self.radius + ).unsqueeze(-1) + grid_down_h = ( + -grid_down_raw[:, :, 2].clone() / self.radius + ).unsqueeze(-1) + grid_down = torch.cat([grid_down_w, grid_down_h], 2).unsqueeze(0) + mask_down = ( + ((grid_down_w <= 1) * (grid_down_w >= -1)) + * ((grid_down_h <= 1) * (grid_down_h >= -1)) + * (grid_down_raw[:, :, 1] == -self.radius).unsqueeze(2) + ).float() + + # Compute the up grid + radius_ratio_up = torch.abs(y_3d / self.radius) + grid_up_raw = self.grid_ball / radius_ratio_up.view( + self.equ_h, self.equ_w, 1 + ).expand(-1, -1, 3) + grid_up_w = (-grid_up_raw[:, :, 0].clone() / self.radius).unsqueeze(-1) + grid_up_h = (grid_up_raw[:, :, 2].clone() / self.radius).unsqueeze(-1) + grid_up = torch.cat([grid_up_w, grid_up_h], 2).unsqueeze(0) + mask_up = ( + ((grid_up_w <= 1) * (grid_up_w >= -1)) + * ((grid_up_h <= 1) * (grid_up_h >= -1)) + * (grid_up_raw[:, :, 1] == self.radius).unsqueeze(2) + ).float() + + # Compute the front grid + radius_ratio_front = torch.abs(z_3d / self.radius) + grid_front_raw = self.grid_ball / radius_ratio_front.view( + self.equ_h, self.equ_w, 1 + ).expand(-1, -1, 3) + grid_front_w = ( + -grid_front_raw[:, :, 0].clone() / self.radius + ).unsqueeze(-1) + grid_front_h = ( + -grid_front_raw[:, :, 1].clone() / self.radius + ).unsqueeze(-1) + grid_front = torch.cat([grid_front_w, grid_front_h], 2).unsqueeze(0) + mask_front = ( + ((grid_front_w <= 1) * (grid_front_w >= -1)) + * ((grid_front_h <= 1) * (grid_front_h >= -1)) + * (torch.round(grid_front_raw[:, :, 2]) == self.radius).unsqueeze( + 2 + ) + ).float() + + # Compute the back grid + radius_ratio_back = torch.abs(z_3d / self.radius) + grid_back_raw = self.grid_ball / radius_ratio_back.view( + self.equ_h, self.equ_w, 1 + ).expand(-1, -1, 3) + grid_back_w = (grid_back_raw[:, :, 0].clone() / self.radius).unsqueeze( + -1 + ) + grid_back_h = ( + -grid_back_raw[:, :, 1].clone() / self.radius + ).unsqueeze(-1) + grid_back = torch.cat([grid_back_w, grid_back_h], 2).unsqueeze(0) + mask_back = ( + ((grid_back_w <= 1) * (grid_back_w >= -1)) + * ((grid_back_h <= 1) * (grid_back_h >= -1)) + * (torch.round(grid_back_raw[:, :, 2]) == -self.radius).unsqueeze( + 2 + ) + ).float() + + # Compute the right grid + radius_ratio_right = torch.abs(x_3d / self.radius) + grid_right_raw = self.grid_ball / radius_ratio_right.view( + self.equ_h, self.equ_w, 1 + ).expand(-1, -1, 3) + grid_right_w = ( + -grid_right_raw[:, :, 2].clone() / self.radius + ).unsqueeze(-1) + grid_right_h = ( + -grid_right_raw[:, :, 1].clone() / self.radius + ).unsqueeze(-1) + grid_right = torch.cat([grid_right_w, grid_right_h], 2).unsqueeze(0) + mask_right = ( + ((grid_right_w <= 1) * (grid_right_w >= -1)) + * ((grid_right_h <= 1) * (grid_right_h >= -1)) + * (torch.round(grid_right_raw[:, :, 0]) == -self.radius).unsqueeze( + 2 + ) + ).float() + + # Compute the left grid + radius_ratio_left = torch.abs(x_3d / self.radius) + grid_left_raw = self.grid_ball / radius_ratio_left.view( + self.equ_h, self.equ_w, 1 + ).expand(-1, -1, 3) + grid_left_w = 
(grid_left_raw[:, :, 2].clone() / self.radius).unsqueeze( + -1 + ) + grid_left_h = ( + -grid_left_raw[:, :, 1].clone() / self.radius + ).unsqueeze(-1) + grid_left = torch.cat([grid_left_w, grid_left_h], 2).unsqueeze(0) + mask_left = ( + ((grid_left_w <= 1) * (grid_left_w >= -1)) + * ((grid_left_h <= 1) * (grid_left_h >= -1)) + * (torch.round(grid_left_raw[:, :, 0]) == self.radius).unsqueeze(2) + ).float() + + # Face map contains numbers correspond to that face + orientation_mask = ( + mask_back * 0 + + mask_down * 1 + + mask_front * 2 + + mask_left * 3 + + mask_right * 4 + + mask_up * 5 + ) + + return ( + torch.cat( + [ + grid_back, + grid_down, + grid_front, + grid_left, + grid_right, + grid_up, + ], + 0, + ), + orientation_mask, + ) + + # Convert cubic images to equirectangular + def _to_equirec(self, batch: torch.Tensor): + batch_size, ch, _H, _W = batch.shape + if batch_size != 6: + raise ValueError("Batch size mismatch!!") + + output = torch.zeros( + 1, ch, self.equ_h, self.equ_w, device=batch.device + ) + + for ori in range(6): + grid = self.grid[ori, :, :, :].unsqueeze( + 0 + ) # 1, self.equ_h, self.equ_w, 2 + mask = (self.orientation_mask == ori).unsqueeze( + 0 + ) # 1, self.equ_h, self.equ_w, 1 + + masked_grid = grid * mask.float().expand( + -1, -1, -1, 2 + ) # 1, self.equ_h, self.equ_w, 2 + + source_image = batch[ori].unsqueeze(0) # 1, ch, H, W + + sampled_image = torch.nn.functional.grid_sample( + source_image, + masked_grid, + align_corners=False, + padding_mode="border", + ) # 1, ch, self.equ_h, self.equ_w + + sampled_image_masked = sampled_image * ( + mask.float() + .view(1, 1, self.equ_h, self.equ_w) + .expand(1, ch, -1, -1) + ) + output = ( + output + sampled_image_masked + ) # 1, ch, self.equ_h, self.equ_w + + return output + + # Convert input cubic tensor to output equirectangular image + def to_equirec_tensor(self, batch: torch.Tensor): + # Move the params to the right device. NOOP after first call + self.grid = self.grid.to(batch.device) + self.orientation_mask = self.orientation_mask.to(batch.device) + # Check whether batch size is 6x + batch_size = batch.size()[0] + if batch_size % 6 != 0: + raise ValueError("Batch size should be 6x") + + processed = [] + for idx in range(int(batch_size / 6)): + target = batch[idx * 6 : (idx + 1) * 6, :, :, :] + target_processed = self._to_equirec(target) + processed.append(target_processed) + + output = torch.cat(processed, 0) + return output + + def forward(self, batch: torch.Tensor): + return self.to_equirec_tensor(batch) + + +@baseline_registry.register_obs_transformer() +class CubeMap2Equirec(ObservationTransformer): + r"""This is an experimental use of ObservationTransformer that converts a cubemap + output to an equirectangular one through projection. This needs to be fed + a list of 6 cameras at various orientations but will be able to stitch a + 360 sensor out of these inputs. The code below will generate a config that + has the 6 sensors in the proper orientations. This code also assumes a 90 + FOV. + + Sensor order for cubemap stiching is Back, Down, Front, Left, Right, Up. + The output will be writen the UUID of the first sensor. + """ + + def __init__( + self, + sensor_uuids: List[str], + eq_shape: Tuple[int], + cubemap_length: int, + channels_last: bool = False, + target_uuids: Optional[List[str]] = None, + ): + r""":param sensor: List of sensor_uuids: Back, Down, Front, Left, Right, Up. 
+ :param eq_shape: The shape of the equirectangular output (height, width) + :param cubemap_length: int length of the each side of the cubemap + """ + super(CubeMap2Equirec, self).__init__() + num_sensors = len(sensor_uuids) + assert ( + num_sensors % 6 == 0 and num_sensors != 0 + ), f"{len(sensor_uuids)}: length of sensors is not a multiple of 6" + # TODO verify attributes of the sensors in the config if possible. Think about API design + assert ( + len(eq_shape) == 2 + ), f"eq_shape must be a tuple of (height, width), given: {eq_shape}" + assert ( + cubemap_length > 0 + ), f"cubemap_length must be greater than 0: provided {cubemap_length}" + self.sensor_uuids: List[str] = sensor_uuids + self.eq_shape: Tuple[int] = eq_shape + self.cubemap_length: int = cubemap_length + self.channels_last: bool = channels_last + self.c2eq: nn.Module = Cube2Equirec( + eq_shape[0], eq_shape[1], cubemap_length + ) + if target_uuids == None: + self.target_uuids: List[str] = self.sensor_uuids[::6] + else: + self.target_uuids: List[str] = target_uuids + # TODO support and test different FOVs than just 90 + + def transform_observation_space( + self, + observation_space: SpaceDict, + ): + r"""Transforms the target UUID's sensor obs_space so it matches the new shape (EQ_H, EQ_W)""" + # Transforms the observation space to of the target UUID + observation_space = copy.deepcopy(observation_space) + for i, key in enumerate(self.target_uuids): + assert ( + key in observation_space.spaces + ), f"{key} not found in observation space: {observation_space.spaces}" + c = self.cubemap_length + logger.info( + f"Overwrite sensor: {key} from size of ({c}, {c}) to equirect image of {self.eq_shape} from sensors: {self.sensor_uuids[i*6:(i+1)*6]}" + ) + if (c, c) != self.eq_shape: + observation_space.spaces[key] = overwrite_gym_box_shape( + observation_space.spaces[key], self.eq_shape + ) + return observation_space + + @classmethod + def from_config(cls, config): + cube2eq_config = config.RL.POLICY.OBS_TRANSFORMS.CUBE2EQ + if hasattr(cube2eq_config, "TARGET_UUIDS"): + # Optional Config Value to specify target UUID + target_uuids = cube2eq_config.TARGET_UUIDS + else: + target_uuids = None + return cls( + cube2eq_config.SENSOR_UUIDS, + eq_shape=( + cube2eq_config.HEIGHT, + cube2eq_config.WIDTH, + ), + cubemap_length=cube2eq_config.CUBE_LENGTH, + target_uuids=target_uuids, + ) + + @torch.no_grad() + def forward( + self, observations: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + for i in range(0, len(self.target_uuids), 6): + # The UUID we are overwriting + target_sensor_uuid = self.target_uuids[i // 6] + assert target_sensor_uuid in self.sensor_uuids[i : i + 6] + sensor_obs = [ + observations[sensor] for sensor in self.sensor_uuids[i : i + 6] + ] + sensor_dtype = observations[target_sensor_uuid].dtype + # Stacking along axis makes the flattening go in the right order. 
+ imgs = torch.stack(sensor_obs, axis=1) + imgs = torch.flatten(imgs, end_dim=1) + if not self.channels_last: + imgs = imgs.permute((0, 3, 1, 2)) # NHWC => NCHW + imgs = imgs.float() + equirect = self.c2eq(imgs) # Here is where the stiching happens + equirect = equirect.to(dtype=sensor_dtype) + if not self.channels_last: + equirect = equirect.permute((0, 2, 3, 1)) # NCHW => NHWC + observations[target_sensor_uuid] = equirect + return observations + + +def get_active_obs_transforms(config: Config) -> List[ObservationTransformer]: + active_obs_transforms = [] + if hasattr(config.RL.POLICY, "OBS_TRANSFORMS"): + obs_transform_names = ( + config.RL.POLICY.OBS_TRANSFORMS.ENABLED_TRANSFORMS + ) + for obs_transform_name in obs_transform_names: + obs_trans_cls = baseline_registry.get_obs_transformer( + obs_transform_name + ) + obs_transform = obs_trans_cls.from_config(config) + active_obs_transforms.append(obs_transform) + return active_obs_transforms + + +def apply_obs_transforms_batch( + batch: Dict[str, torch.Tensor], + obs_transforms: Iterable[ObservationTransformer], +) -> Dict[str, torch.Tensor]: + for obs_transform in obs_transforms: + batch = obs_transform(batch) + return batch + + +def apply_obs_transforms_obs_space( + obs_space: SpaceDict, obs_transforms: Iterable[ObservationTransformer] +) -> SpaceDict: + for obs_transform in obs_transforms: + obs_space = obs_transform.transform_observation_space(obs_space) + return obs_space diff --git a/habitat_baselines/common/utils.py b/habitat_baselines/common/utils.py index 6fb6df16f7..c782d022eb 100644 --- a/habitat_baselines/common/utils.py +++ b/habitat_baselines/common/utils.py @@ -4,19 +4,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import copy import glob import numbers import os from collections import defaultdict -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch from gym.spaces import Box from torch import nn as nn -from habitat import logger from habitat.utils.visualizations.utils import images_to_video from habitat_baselines.common.tensorboard_utils import TensorboardWriter @@ -57,54 +55,6 @@ def forward(self, x): return CustomFixedCategorical(logits=x) -class ResizeCenterCropper(nn.Module): - def __init__(self, size, channels_last: bool = False): - r"""An nn module the resizes and center crops your input. - Args: - size: A sequence (w, h) or int of the size you wish to - resize/center_crop. 
If int, assumes square crop - channels_list: indicates if channels is the last dimension - """ - super().__init__() - if isinstance(size, numbers.Number): - size = (int(size), int(size)) - assert len(size) == 2, "forced input size must be len of 2 (w, h)" - self._size = size - self.channels_last = channels_last - - def transform_observation_space( - self, observation_space, trans_keys=("rgb", "depth", "semantic") - ): - size = self._size - observation_space = copy.deepcopy(observation_space) - if size: - for key in observation_space.spaces: - if ( - key in trans_keys - and observation_space.spaces[key].shape != size - ): - logger.info( - "Overwriting CNN input size of %s: %s" % (key, size) - ) - observation_space.spaces[key] = overwrite_gym_box_shape( - observation_space.spaces[key], size - ) - self.observation_space = observation_space - return observation_space - - def forward(self, input: torch.Tensor) -> torch.Tensor: - if self._size is None: - return input - - return center_crop( - image_resize_shortest_edge( - input, max(self._size), channels_last=self.channels_last - ), - self._size, - channels_last=self.channels_last, - ) - - def linear_decay(epoch: int, total_num_updates: int) -> float: r"""Returns a multiplicative factor for linear value decay @@ -127,8 +77,10 @@ def _to_tensor(v) -> torch.Tensor: return torch.tensor(v, dtype=torch.float) +@torch.no_grad() def batch_obs( - observations: List[Dict], device: Optional[torch.device] = None + observations: List[Dict], + device: Optional[torch.device] = None, ) -> Dict[str, torch.Tensor]: r"""Transpose a batch of observation dicts to a dict of batched observations. @@ -139,7 +91,7 @@ def batch_obs( Will not move the tensors if None Returns: - transposed dict of lists of observations. + transposed dict of torch.Tensor of observations. """ batch = defaultdict(list) @@ -148,11 +100,7 @@ def batch_obs( batch[sensor].append(_to_tensor(obs[sensor])) for sensor in batch: - batch[sensor] = ( - torch.stack(batch[sensor], dim=0) - .to(device=device) - .to(dtype=torch.float) - ) + batch[sensor] = torch.stack(batch[sensor], dim=0).to(device=device) return batch @@ -283,17 +231,14 @@ def image_resize_shortest_edge( raise NotImplementedError() if no_batch_dim: img = img.unsqueeze(0) # Adds a batch dimension + h, w = get_image_height_width(img, channels_last=channels_last) if channels_last: - h, w = img.shape[-3:-1] if len(img.shape) == 4: # NHWC -> NCHW img = img.permute(0, 3, 1, 2) else: # NDHWC -> NDCHW img = img.permute(0, 1, 4, 2, 3) - else: - # ..HW - h, w = img.shape[-2:] # Percentage resize scale = size / min(h, w) @@ -314,28 +259,24 @@ def image_resize_shortest_edge( return img -def center_crop(img, size, channels_last: bool = False): +def center_crop( + img, size: Union[int, Tuple[int]], channels_last: bool = False +): """Performs a center crop on an image. Args: - img: the array object that needs to be resized (either batched or - unbatched) - size: A sequence (w, h) or a python(int) that you want cropped + img: the array object that needs to be resized (either batched or unbatched) + size: A sequence (h, w) or a python(int) that you want cropped channels_last: If the channels are the last dimension. 
Returns: the resized array """ - if channels_last: - # NHWC - h, w = img.shape[-3:-1] - else: - # NCHW - h, w = img.shape[-2:] + h, w = get_image_height_width(img, channels_last=channels_last) if isinstance(size, numbers.Number): size = (int(size), int(size)) assert len(size) == 2, "size should be (h,w) you wish to resize to" - cropx, cropy = size + cropy, cropx = size startx = w // 2 - (cropx // 2) starty = h // 2 - (cropy // 2) @@ -345,6 +286,20 @@ def center_crop(img, size, channels_last: bool = False): return img[..., starty : starty + cropy, startx : startx + cropx] +def get_image_height_width( + img: Union[np.ndarray, torch.Tensor], channels_last: bool = False +): + if img.shape is None or len(img.shape) < 3 or len(img.shape) > 5: + raise NotImplementedError() + if channels_last: + # NHWC + h, w = img.shape[-3:-1] + else: + # NCHW + h, w = img.shape[-2:] + return h, w + + def overwrite_gym_box_shape(box: Box, shape) -> Box: if box.shape == shape: return box diff --git a/habitat_baselines/config/default.py b/habitat_baselines/config/default.py index 9517ef56b5..f0b91df2e5 100644 --- a/habitat_baselines/config/default.py +++ b/habitat_baselines/config/default.py @@ -59,6 +59,21 @@ _C.RL.POLICY = CN() _C.RL.POLICY.name = "PointNavBaselinePolicy" # ----------------------------------------------------------------------------- +# OBS_TRANSFORMS CONFIG +# ----------------------------------------------------------------------------- +_C.RL.POLICY.OBS_TRANSFORMS = CN() +_C.RL.POLICY.OBS_TRANSFORMS.ENABLED_TRANSFORMS = tuple() +_C.RL.POLICY.OBS_TRANSFORMS.CENTER_CROPPER = CN() +_C.RL.POLICY.OBS_TRANSFORMS.CENTER_CROPPER.HEIGHT = 256 +_C.RL.POLICY.OBS_TRANSFORMS.CENTER_CROPPER.WIDTH = 256 +_C.RL.POLICY.OBS_TRANSFORMS.RESIZE_SHORTEST_EDGE = CN() +_C.RL.POLICY.OBS_TRANSFORMS.RESIZE_SHORTEST_EDGE.SIZE = 256 +_C.RL.POLICY.OBS_TRANSFORMS.CUBE2EQ = CN() +_C.RL.POLICY.OBS_TRANSFORMS.CUBE2EQ.HEIGHT = 256 +_C.RL.POLICY.OBS_TRANSFORMS.CUBE2EQ.WIDTH = 512 +_C.RL.POLICY.OBS_TRANSFORMS.CUBE2EQ.CUBE_LENGTH = 256 +_C.RL.POLICY.OBS_TRANSFORMS.CUBE2EQ.SENSOR_UUIDS = list() +# ----------------------------------------------------------------------------- # PROXIMAL POLICY OPTIMIZATION (PPO) # ----------------------------------------------------------------------------- _C.RL.PPO = CN() diff --git a/habitat_baselines/rl/ddppo/algo/ddppo_trainer.py b/habitat_baselines/rl/ddppo/algo/ddppo_trainer.py index 57c194e14d..9236a8ae7d 100644 --- a/habitat_baselines/rl/ddppo/algo/ddppo_trainer.py +++ b/habitat_baselines/rl/ddppo/algo/ddppo_trainer.py @@ -22,6 +22,11 @@ from habitat_baselines.common.baseline_registry import baseline_registry from habitat_baselines.common.env_utils import construct_envs from habitat_baselines.common.environments import get_env_class +from habitat_baselines.common.obs_transformers import ( + apply_obs_transforms_batch, + apply_obs_transforms_obs_space, + get_active_obs_transforms, +) from habitat_baselines.common.rollout_storage import RolloutStorage from habitat_baselines.common.tensorboard_utils import TensorboardWriter from habitat_baselines.common.utils import batch_obs, linear_decay @@ -70,7 +75,15 @@ def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None: logger.add_filehandler(self.config.LOG_FILE) policy = baseline_registry.get_policy(self.config.RL.POLICY.name) - self.actor_critic = policy.from_config(self.config, self.envs) + self.obs_transforms = get_active_obs_transforms(self.config) + observation_space = self.envs.observation_spaces[0] + observation_space = 
apply_obs_transforms_obs_space( + observation_space, self.obs_transforms + ) + self.actor_critic = policy.from_config( + self.config, observation_space, self.envs.action_spaces[0] + ) + self.obs_space = observation_space self.actor_critic.to(self.device) if ( @@ -188,8 +201,9 @@ def train(self) -> None: observations = self.envs.reset() batch = batch_obs(observations, device=self.device) + batch = apply_obs_transforms_batch(batch, self.obs_transforms) - obs_space = self.envs.observation_spaces[0] + obs_space = self.obs_space if self._static_encoder: self._encoder = self.actor_critic.net.visual_encoder obs_space = SpaceDict( diff --git a/habitat_baselines/rl/ddppo/policy/resnet_policy.py b/habitat_baselines/rl/ddppo/policy/resnet_policy.py index d522955121..7dac7015d2 100644 --- a/habitat_baselines/rl/ddppo/policy/resnet_policy.py +++ b/habitat_baselines/rl/ddppo/policy/resnet_policy.py @@ -5,12 +5,16 @@ # LICENSE file in the root directory of this source tree. +from typing import Dict, Tuple + import numpy as np import torch from gym import spaces +from gym.spaces.dict_space import Dict as SpaceDict from torch import nn as nn from torch.nn import functional as F +from habitat.config import Config from habitat.tasks.nav.nav import ( EpisodicCompassSensor, EpisodicGPSSensor, @@ -22,7 +26,7 @@ ) from habitat.tasks.nav.object_nav_task import ObjectGoalSensor from habitat_baselines.common.baseline_registry import baseline_registry -from habitat_baselines.common.utils import Flatten, ResizeCenterCropper +from habitat_baselines.common.utils import Flatten from habitat_baselines.rl.ddppo.policy import resnet from habitat_baselines.rl.ddppo.policy.running_mean_and_var import ( RunningMeanAndVar, @@ -35,16 +39,15 @@ class PointNavResNetPolicy(Policy): def __init__( self, - observation_space, + observation_space: SpaceDict, action_space, - hidden_size=512, - num_recurrent_layers=2, - rnn_type="LSTM", - resnet_baseplanes=32, - backbone="resnet50", - normalize_visual_inputs=False, - obs_transform=ResizeCenterCropper(size=(256, 256)), # noqa : B008 - force_blind_policy=False, + hidden_size: int = 512, + num_recurrent_layers: int = 2, + rnn_type: str = "LSTM", + resnet_baseplanes: int = 32, + backbone: str = "resnet50", + normalize_visual_inputs: bool = False, + force_blind_policy: bool = False, **kwargs ): super().__init__( @@ -57,22 +60,23 @@ def __init__( backbone=backbone, resnet_baseplanes=resnet_baseplanes, normalize_visual_inputs=normalize_visual_inputs, - obs_transform=obs_transform, force_blind_policy=force_blind_policy, ), action_space.n, ) @classmethod - def from_config(cls, config, envs): + def from_config( + cls, config: Config, observation_space: SpaceDict, action_space + ): return cls( - observation_space=envs.observation_spaces[0], - action_space=envs.action_spaces[0], + observation_space=observation_space, + action_space=action_space, hidden_size=config.RL.PPO.hidden_size, rnn_type=config.RL.DDPPO.rnn_type, num_recurrent_layers=config.RL.DDPPO.num_recurrent_layers, backbone=config.RL.DDPPO.backbone, - normalize_visual_inputs="rgb" in envs.observation_spaces[0].spaces, + normalize_visual_inputs="rgb" in observation_space.spaces, force_blind_policy=config.FORCE_BLIND_POLICY, ) @@ -80,22 +84,15 @@ def from_config(cls, config, envs): class ResNetEncoder(nn.Module): def __init__( self, - observation_space, - baseplanes=32, - ngroups=32, - spatial_size=128, + observation_space: SpaceDict, + baseplanes: int = 32, + ngroups: int = 32, + spatial_size: int = 128, make_backbone=None, - 
normalize_visual_inputs=False, - obs_transform=ResizeCenterCropper(size=(256, 256)), # noqa: B008 + normalize_visual_inputs: bool = False, ): super().__init__() - self.obs_transform = obs_transform - if self.obs_transform is not None: - observation_space = self.obs_transform.transform_observation_space( - observation_space - ) - if "rgb" in observation_space.spaces: self._n_input_rgb = observation_space.spaces["rgb"].shape[2] spatial_size = observation_space.spaces["rgb"].shape[0] // 2 @@ -157,7 +154,7 @@ def layer_init(self): if layer.bias is not None: nn.init.constant_(layer.bias, val=0) - def forward(self, observations): + def forward(self, observations: Dict[str, torch.Tensor]) -> torch.Tensor: if self.is_blind: return None @@ -177,9 +174,6 @@ def forward(self, observations): cnn_input.append(depth_observations) - if self.obs_transform: - cnn_input = [self.obs_transform(inp) for inp in cnn_input] - x = torch.cat(cnn_input, dim=1) x = F.avg_pool2d(x, 2) @@ -196,16 +190,15 @@ class PointNavResNetNet(Net): def __init__( self, - observation_space, + observation_space: SpaceDict, action_space, - hidden_size, - num_recurrent_layers, - rnn_type, + hidden_size: int, + num_recurrent_layers: int, + rnn_type: str, backbone, resnet_baseplanes, - normalize_visual_inputs, - obs_transform=ResizeCenterCropper(size=(256, 256)), # noqa: B008 - force_blind_policy=False, + normalize_visual_inputs: bool, + force_blind_policy: bool = False, ): super().__init__() @@ -288,7 +281,6 @@ def __init__( ngroups=resnet_baseplanes // 2, make_backbone=getattr(resnet, backbone), normalize_visual_inputs=normalize_visual_inputs, - obs_transform=obs_transform, ) self.goal_visual_fc = nn.Sequential( @@ -309,7 +301,6 @@ def __init__( ngroups=resnet_baseplanes // 2, make_backbone=getattr(resnet, backbone), normalize_visual_inputs=normalize_visual_inputs, - obs_transform=obs_transform, ) if not self.visual_encoder.is_blind: @@ -342,7 +333,13 @@ def is_blind(self): def num_recurrent_layers(self): return self.state_encoder.num_recurrent_layers - def forward(self, observations, rnn_hidden_states, prev_actions, masks): + def forward( + self, + observations: Dict[str, torch.Tensor], + rnn_hidden_states, + prev_actions, + masks, + ) -> Tuple[torch.Tensor]: x = [] if not self.is_blind: if "visual_features" in observations: diff --git a/habitat_baselines/rl/models/simple_cnn.py b/habitat_baselines/rl/models/simple_cnn.py index 90f87e9698..c8ca5fff10 100644 --- a/habitat_baselines/rl/models/simple_cnn.py +++ b/habitat_baselines/rl/models/simple_cnn.py @@ -1,8 +1,10 @@ +from typing import Dict + import numpy as np import torch from torch import nn as nn -from habitat_baselines.common.utils import Flatten, ResizeCenterCropper +from habitat_baselines.common.utils import Flatten class SimpleCNN(nn.Module): @@ -19,18 +21,9 @@ def __init__( self, observation_space, output_size, - obs_transform: nn.Module = ResizeCenterCropper( # noqa: B008 - size=(256, 256) - ), ): super().__init__() - self.obs_transform = obs_transform - if self.obs_transform is not None: - observation_space = obs_transform.transform_observation_space( - observation_space - ) - if "rgb" in observation_space.spaces: self._n_input_rgb = observation_space.spaces["rgb"].shape[2] else: @@ -141,7 +134,7 @@ def layer_init(self): def is_blind(self): return self._n_input_rgb + self._n_input_depth == 0 - def forward(self, observations): + def forward(self, observations: Dict[str, torch.Tensor]): cnn_input = [] if self._n_input_rgb > 0: rgb_observations = observations["rgb"] @@ 
-156,9 +149,6 @@ def forward(self, observations): depth_observations = depth_observations.permute(0, 3, 1, 2) cnn_input.append(depth_observations) - if self.obs_transform: - cnn_input = [self.obs_transform(inp) for inp in cnn_input] - cnn_input = torch.cat(cnn_input, dim=1) return self.cnn(cnn_input) diff --git a/habitat_baselines/rl/ppo/policy.py b/habitat_baselines/rl/ppo/policy.py index 542e8a0192..64f436825c 100644 --- a/habitat_baselines/rl/ppo/policy.py +++ b/habitat_baselines/rl/ppo/policy.py @@ -7,8 +7,10 @@ import torch from gym import spaces +from gym.spaces.dict_space import Dict as SpaceDict from torch import nn as nn +from habitat.config import Config from habitat.tasks.nav.nav import ( ImageGoalSensor, IntegratedPointGoalGPSAndCompassSensor, @@ -79,7 +81,7 @@ def evaluate_actions( @classmethod @abc.abstractmethod - def from_config(cls, config, envs): + def from_config(cls, config, observation_space, action_space): pass @@ -97,20 +99,28 @@ def forward(self, x): @baseline_registry.register_policy class PointNavBaselinePolicy(Policy): def __init__( - self, observation_space, action_space, hidden_size=512, **kwargs + self, + observation_space: SpaceDict, + action_space, + hidden_size: int = 512, + **kwargs ): super().__init__( PointNavBaselineNet( - observation_space=observation_space, hidden_size=hidden_size + observation_space=observation_space, + hidden_size=hidden_size, + **kwargs, ), action_space.n, ) @classmethod - def from_config(cls, config, envs): + def from_config( + cls, config: Config, observation_space: SpaceDict, action_space + ): return cls( - observation_space=envs.observation_spaces[0], - action_space=envs.action_spaces[0], + observation_space=observation_space, + action_space=action_space, hidden_size=config.RL.PPO.hidden_size, ) @@ -141,7 +151,11 @@ class PointNavBaselineNet(Net): goal vector with CNN's output and passes that through RNN. 
""" - def __init__(self, observation_space, hidden_size): + def __init__( + self, + observation_space: SpaceDict, + hidden_size: int, + ): super().__init__() if ( diff --git a/habitat_baselines/rl/ppo/ppo_trainer.py b/habitat_baselines/rl/ppo/ppo_trainer.py index ecfa5a1ff0..e7f4514800 100644 --- a/habitat_baselines/rl/ppo/ppo_trainer.py +++ b/habitat_baselines/rl/ppo/ppo_trainer.py @@ -20,6 +20,11 @@ from habitat_baselines.common.baseline_registry import baseline_registry from habitat_baselines.common.env_utils import construct_envs from habitat_baselines.common.environments import get_env_class +from habitat_baselines.common.obs_transformers import ( + apply_obs_transforms_batch, + apply_obs_transforms_obs_space, + get_active_obs_transforms, +) from habitat_baselines.common.rollout_storage import RolloutStorage from habitat_baselines.common.tensorboard_utils import TensorboardWriter from habitat_baselines.common.utils import ( @@ -42,6 +47,7 @@ def __init__(self, config=None): self.actor_critic = None self.agent = None self.envs = None + self.obs_transforms = [] if config is not None: logger.info(f"config: {config}") @@ -60,7 +66,15 @@ def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None: logger.add_filehandler(self.config.LOG_FILE) policy = baseline_registry.get_policy(self.config.RL.POLICY.name) - self.actor_critic = policy.from_config(self.config, self.envs) + observation_space = self.envs.observation_spaces[0] + self.obs_transforms = get_active_obs_transforms(self.config) + observation_space = apply_obs_transforms_obs_space( + observation_space, self.obs_transforms + ) + self.obs_space = observation_space + self.actor_critic = policy.from_config( + self.config, observation_space, self.envs.action_spaces[0] + ) self.actor_critic.to(self.device) self.agent = PPO( @@ -187,6 +201,8 @@ def _collect_rollout_step( t_update_stats = time.time() batch = batch_obs(observations, device=self.device) + batch = apply_obs_transforms_batch(batch, self.obs_transforms) + rewards = torch.tensor( rewards, dtype=torch.float, device=current_episode_reward.device ) @@ -289,7 +305,7 @@ def train(self) -> None: rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, - self.envs.observation_spaces[0], + self.obs_space, self.envs.action_spaces[0], ppo_cfg.hidden_size, ) @@ -297,6 +313,7 @@ def train(self) -> None: observations = self.envs.reset() batch = batch_obs(observations, device=self.device) + batch = apply_obs_transforms_batch(batch, self.obs_transforms) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) @@ -473,6 +490,7 @@ def _eval_checkpoint( observations = self.envs.reset() batch = batch_obs(observations, device=self.device) + batch = apply_obs_transforms_batch(batch, self.obs_transforms) current_episode_reward = torch.zeros( self.envs.num_envs, 1, device=self.device @@ -541,6 +559,7 @@ def _eval_checkpoint( list(x) for x in zip(*outputs) ] batch = batch_obs(observations, device=self.device) + batch = apply_obs_transforms_batch(batch, self.obs_transforms) not_done_masks = torch.tensor( [[0.0] if done else [1.0] for done in dones], @@ -594,7 +613,10 @@ def _eval_checkpoint( # episode continues elif len(self.config.VIDEO_OPTION) > 0: - frame = observations_to_image(observations[i], infos[i]) + # TODO move normalization / channel changing out of the policy and undo it here + frame = observations_to_image( + {k: v[i] for k, v in batch.items()}, infos[i] + ) rgb_frames[i].append(frame) ( diff --git a/habitat_baselines/run.py 
b/habitat_baselines/run.py index 433ad10efd..a5c71b4633 100644 --- a/habitat_baselines/run.py +++ b/habitat_baselines/run.py @@ -10,6 +10,7 @@ import numpy as np import torch +from habitat.config import Config from habitat_baselines.common.baseline_registry import baseline_registry from habitat_baselines.config.default import get_config @@ -39,19 +40,12 @@ def main(): run_exp(**vars(args)) -def run_exp(exp_config: str, run_type: str, opts=None) -> None: - r"""Runs experiment given mode and config - +def execute_exp(config: Config, run_type: str) -> None: + r"""This function runs the specified config with the specified runtype Args: - exp_config: path to config file. - run_type: "train" or "eval. - opts: list of strings of additional config options. - - Returns: - None. + config: Habitat.config + runtype: str {train or eval} """ - config = get_config(exp_config, opts) - random.seed(config.TASK_CONFIG.SEED) np.random.seed(config.TASK_CONFIG.SEED) torch.manual_seed(config.TASK_CONFIG.SEED) @@ -66,5 +60,20 @@ def run_exp(exp_config: str, run_type: str, opts=None) -> None: trainer.eval() +def run_exp(exp_config: str, run_type: str, opts=None) -> None: + r"""Runs experiment given mode and config + + Args: + exp_config: path to config file. + run_type: "train" or "eval. + opts: list of strings of additional config options. + + Returns: + None. + """ + config = get_config(exp_config, opts) + execute_exp(config, run_type) + + if __name__ == "__main__": main() diff --git a/test/test_baseline_agents.py b/test/test_baseline_agents.py index 62b9f54015..e0072cc060 100644 --- a/test/test_baseline_agents.py +++ b/test/test_baseline_agents.py @@ -4,6 +4,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import itertools import os import pytest @@ -23,7 +24,16 @@ @pytest.mark.skipif( not baseline_installed, reason="baseline sub-module not installed" ) -def test_ppo_agents(): +@pytest.mark.parametrize( + "input_type,resolution", + [ + (i_type, resolution) + for i_type, resolution in itertools.product( + ["blind", "rgb", "depth", "rgbd"], [256, 384] + ) + ], +) +def test_ppo_agents(input_type, resolution): agent_config = ppo_agents.get_default_config() agent_config.MODEL_PATH = "" @@ -34,29 +44,27 @@ def test_ppo_agents(): benchmark = habitat.Benchmark(config_paths=CFG_TEST) - for input_type in ["blind", "rgb", "depth", "rgbd"]: - for resolution in [256, 384]: - config_env.defrost() - config_env.SIMULATOR.AGENT_0.SENSORS = [] - if input_type in ["rgb", "rgbd"]: - config_env.SIMULATOR.AGENT_0.SENSORS += ["RGB_SENSOR"] - agent_config.RESOLUTION = resolution - config_env.SIMULATOR.RGB_SENSOR.WIDTH = resolution - config_env.SIMULATOR.RGB_SENSOR.HEIGHT = resolution - if input_type in ["depth", "rgbd"]: - config_env.SIMULATOR.AGENT_0.SENSORS += ["DEPTH_SENSOR"] - agent_config.RESOLUTION = resolution - config_env.SIMULATOR.DEPTH_SENSOR.WIDTH = resolution - config_env.SIMULATOR.DEPTH_SENSOR.HEIGHT = resolution - - config_env.freeze() - - del benchmark._env - benchmark._env = habitat.Env(config=config_env) - agent_config.INPUT_TYPE = input_type - - agent = ppo_agents.PPOAgent(agent_config) - habitat.logger.info(benchmark.evaluate(agent, num_episodes=10)) + config_env.defrost() + config_env.SIMULATOR.AGENT_0.SENSORS = [] + if input_type in ["rgb", "rgbd"]: + config_env.SIMULATOR.AGENT_0.SENSORS += ["RGB_SENSOR"] + agent_config.RESOLUTION = resolution + config_env.SIMULATOR.RGB_SENSOR.WIDTH = resolution + config_env.SIMULATOR.RGB_SENSOR.HEIGHT = resolution + if input_type in ["depth", "rgbd"]: + config_env.SIMULATOR.AGENT_0.SENSORS += ["DEPTH_SENSOR"] + agent_config.RESOLUTION = resolution + config_env.SIMULATOR.DEPTH_SENSOR.WIDTH = resolution + config_env.SIMULATOR.DEPTH_SENSOR.HEIGHT = resolution + + config_env.freeze() + + del benchmark._env + benchmark._env = habitat.Env(config=config_env) + agent_config.INPUT_TYPE = input_type + + agent = ppo_agents.PPOAgent(agent_config) + habitat.logger.info(benchmark.evaluate(agent, num_episodes=10)) @pytest.mark.skipif( diff --git a/test/test_baseline_trainers.py b/test/test_baseline_trainers.py index 7aef7f2bb1..555b4a3173 100644 --- a/test/test_baseline_trainers.py +++ b/test/test_baseline_trainers.py @@ -5,7 +5,9 @@ # LICENSE file in the root directory of this source tree. 
import itertools +import math import random +from copy import deepcopy from glob import glob import pytest @@ -16,25 +18,39 @@ from habitat_baselines.common.base_trainer import BaseRLTrainer from habitat_baselines.config.default import get_config - from habitat_baselines.run import run_exp + from habitat_baselines.run import execute_exp, run_exp baseline_installed = True except ImportError: baseline_installed = False +def _powerset(s): + return [ + combo + for r in range(len(s) + 1) + for combo in itertools.combinations(s, r) + ] + + @pytest.mark.skipif( not baseline_installed, reason="baseline sub-module not installed" ) @pytest.mark.parametrize( - "test_cfg_path,mode,gpu2gpu", + "test_cfg_path,mode,gpu2gpu,observation_transforms", itertools.product( glob("habitat_baselines/config/test/*"), ["train", "eval"], [True, False], + _powerset( + [ + "CenterCropper", + "ResizeShortestEdge", + ] + ), ), ) -def test_trainers(test_cfg_path, mode, gpu2gpu): +def test_trainers(test_cfg_path, mode, gpu2gpu, observation_transforms): if gpu2gpu: try: import habitat_sim @@ -47,7 +63,12 @@ def test_trainers(test_cfg_path, mode, gpu2gpu): run_exp( test_cfg_path, mode, - ["TASK_CONFIG.SIMULATOR.HABITAT_SIM_V0.GPU_GPU", str(gpu2gpu)], + [ + "TASK_CONFIG.SIMULATOR.HABITAT_SIM_V0.GPU_GPU", + str(gpu2gpu), + "RL.POLICY.OBS_TRANSFORMS.ENABLED_TRANSFORMS", + str(tuple(observation_transforms)), + ], ) # Deinit processes group @@ -55,6 +76,67 @@ def test_trainers(test_cfg_path, mode, gpu2gpu): torch.distributed.destroy_process_group() +@pytest.mark.skipif( + not baseline_installed, reason="baseline sub-module not installed" +) +@pytest.mark.parametrize( + "test_cfg_path,mode", + itertools.product( + glob("habitat_baselines/config/test/*pointnav_test.yaml"), + ["train", "eval"], + ), +) +def test_equirect_stiching(test_cfg_path, mode: str): + meta_config = get_config(config_paths=test_cfg_path) + meta_config.defrost() + config = meta_config.TASK_CONFIG + CAMERA_NUM = 6 + orient = [ + [0, math.pi, 0], # Back + [-math.pi / 2, 0, 0], # Down + [0, 0, 0], # Front + [0, math.pi / 2, 0], # Right + [0, 3 / 2 * math.pi, 0], # Left + [math.pi / 2, 0, 0], # Up + ] + sensor_uuids = [] + + if "RGB_SENSOR" in config.SIMULATOR.AGENT_0.SENSORS: + config.SIMULATOR.RGB_SENSOR.ORIENTATION = orient[0] + for camera_id in range(1, CAMERA_NUM): + camera_template = f"RGB_{camera_id}" + camera_config = deepcopy(config.SIMULATOR.RGB_SENSOR) + camera_config.ORIENTATION = orient[camera_id] + + camera_config.UUID = camera_template.lower() + sensor_uuids.append(camera_config.UUID) + setattr(config.SIMULATOR, camera_template, camera_config) + config.SIMULATOR.AGENT_0.SENSORS.append(camera_template) + + if "DEPTH_SENSOR" in config.SIMULATOR.AGENT_0.SENSORS: + config.SIMULATOR.DEPTH_SENSOR.ORIENTATION = orient[0] + for camera_id in range(1, CAMERA_NUM): + camera_template = f"DEPTH_{camera_id}" + camera_config = deepcopy(config.SIMULATOR.DEPTH_SENSOR) + camera_config.ORIENTATION = orient[camera_id] + camera_config.UUID = camera_template.lower() + sensor_uuids.append(camera_config.UUID) + + setattr(config.SIMULATOR, camera_template, camera_config) + config.SIMULATOR.AGENT_0.SENSORS.append(camera_template) + + meta_config.TASK_CONFIG = config + meta_config.SENSORS = config.SIMULATOR.AGENT_0.SENSORS + meta_config.RL.POLICY.OBS_TRANSFORMS.CUBE2EQ.SENSOR_UUIDS = tuple( + sensor_uuids + ) + meta_config.freeze() + execute_exp(meta_config, mode) + # Deinit processes group + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + 
+ @pytest.mark.skipif( not baseline_installed, reason="baseline sub-module not installed" )
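For reference, a minimal usage sketch of how the pieces introduced in this diff fit together. The helpers (`get_active_obs_transforms`, `apply_obs_transforms_obs_space`, `apply_obs_transforms_batch`), the registered transformer names, and the `RL.POLICY.OBS_TRANSFORMS` config keys are the ones added above; the config path, the resolutions, and the random NHWC batch are illustrative assumptions only, not part of the change.

import torch

from habitat_baselines.common.obs_transformers import (
    apply_obs_transforms_batch,
    apply_obs_transforms_obs_space,
    get_active_obs_transforms,
)
from habitat_baselines.config.default import get_config

# Illustrative config path; any habitat_baselines config works.
config = get_config("habitat_baselines/config/pointnav/ppo_pointnav.yaml")
config.defrost()
# Transformers are looked up in the baseline_registry by class name.
config.RL.POLICY.OBS_TRANSFORMS.ENABLED_TRANSFORMS = (
    "ResizeShortestEdge",
    "CenterCropper",
)
config.RL.POLICY.OBS_TRANSFORMS.RESIZE_SHORTEST_EDGE.SIZE = 256
config.RL.POLICY.OBS_TRANSFORMS.CENTER_CROPPER.HEIGHT = 224
config.RL.POLICY.OBS_TRANSFORMS.CENTER_CROPPER.WIDTH = 224
config.freeze()

obs_transforms = get_active_obs_transforms(config)

# The trainer transforms the observation space once before building the policy:
# obs_space = apply_obs_transforms_obs_space(envs.observation_spaces[0], obs_transforms)

# ...and applies the same transforms to every batch produced by batch_obs.
# A fake NHWC float RGB batch stands in for simulator output here.
batch = {"rgb": torch.rand(4, 480, 640, 3)}
batch = apply_obs_transforms_batch(batch, obs_transforms)
assert batch["rgb"].shape == (4, 224, 224, 3)

This mirrors what the updated PPO/DDPPO trainers do: the transformed observation space is computed in `_setup_actor_critic_agent` before the policy is built, and `apply_obs_transforms_batch` is called on every batch in the rollout and eval loops.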