Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fisheye camera model #486

Merged
merged 5 commits into from
Sep 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 299 additions & 0 deletions habitat_baselines/common/obs_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def angle2sphere(
)


# TODO Measure Inheritance of CubeMap2Equirec + CubeMap2FishEye into same abstract class
@baseline_registry.register_obs_transformer()
class CubeMap2Equirec(ObservationTransformer):
r"""This is an experimental use of ObservationTransformer that converts a cubemap
Expand Down Expand Up @@ -442,6 +443,304 @@ def forward(
return observations


class Cube2Fisheye(nn.Module):
r"""This is the implementation to generate fisheye images from cubemap images.
The camera model is based on the Double Sphere Camera Model (Usenko et. al.;3DV 2018).
Paper: https://arxiv.org/abs/1807.08957
"""

def __init__(
self,
fish_h: int,
fish_w: int,
fish_fov: float,
cx: float,
cy: float,
fx: float,
fy: float,
xi: float,
alpha: float,
):
"""Args:
fish_h: (int) the height of the generated fisheye
fish_w: (int) the width of the generated fisheye
fish_fov: (float) the fov of the generated fisheye in degrees
cx, cy: (float) the optical center of the generated fisheye
fx, fy, xi, alpha: (float) the fisheye camera model parameters
"""
super(Cube2Fisheye, self).__init__()
self.fish_h = fish_h
self.fish_w = fish_w
self.fish_fov = fish_fov
self.fish_param = [cx, cy, fx, fy, xi, alpha]
self.grids = self.generate_grid(
fish_h, fish_w, fish_fov, self.fish_param
)
self._grids_cache = None

def generate_grid(
self,
fish_h: int,
fish_w: int,
fish_fov: float,
fish_param: List[float],
) -> torch.Tensor:
# Project on sphere
xyz_on_sphere, fov_mask = self.get_points_on_sphere(
fish_h, fish_w, fish_fov, fish_param
)

# Rotate so that each face will be in front of camera
rotations = [
np.array([[-1, 0, 0], [0, 1, 0], [0, 0, -1]]), # Back
np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]]), # Down
np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), # Front
np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]]), # Left
np.array([[0, 0, 1], [0, 1, 0], [-1, 0, 0]]), # Right
np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]), # Up
]
# Generate grid
grids = []
not_assigned_mask = torch.full(
(fish_h, fish_w), True, dtype=torch.bool
)
_h, _w, _ = xyz_on_sphere.shape
for rot in rotations:
R = torch.from_numpy(rot.T).float()
rotate_on_sphere = torch.matmul(
xyz_on_sphere.view((-1, 3)), R
).view(_h, _w, 3)

# Project points on z=1 plane
grid = rotate_on_sphere / torch.abs(rotate_on_sphere[..., 2:3])
mask = torch.abs(grid).max(-1)[0] <= 1 # -1 <= grid.xy <= 1
mask *= grid[..., 2] == 1
# Take care of FoV
mask *= fov_mask
# Make sure each point is only assigned to single face
mask *= not_assigned_mask
# Values bigger than one will be ignored by grid_sample
grid[~mask] = 2
# Update not_assigned_mask
not_assigned_mask *= ~mask
grid_xy = -grid[..., :2].unsqueeze(0)
grids.append(grid_xy)
grids = torch.cat(grids, dim=0)
return grids

def _to_fisheye(self, batch: torch.Tensor) -> torch.Tensor:
"""Takes a batch of cubemaps stacked in proper order and converts thems to fisheye, reduces batch size by 6"""
batch_size, ch, _H, _W = batch.shape
if batch_size == 0 or batch_size % 6 != 0:
raise ValueError("Batch size should be 6x")
output = torch.nn.functional.grid_sample(
batch,
self._grids_cache,
align_corners=True,
padding_mode="zeros",
)
output = output.view(
batch_size // 6, 6, ch, self.fish_h, self.fish_w
).sum(dim=1)
return output # batch_size // 6, ch, self.fish_h, self.fish_w

# Convert input cubic tensor to output fisheye image
def to_fisheye_tensor(self, batch: torch.Tensor) -> torch.Tensor:
batch_size = batch.size()[0]

# Check whether batch size is 6x
if batch_size == 0 or batch_size % 6 != 0:
raise ValueError("Batch size should be 6x")

# to(device) is a NOOP after the first call
self.grids = self.grids.to(batch.device)

# Cache the repeated grids for subsequent batches
if (
self._grids_cache is None
or self._grids_cache.size()[0] != batch_size
):
self._grids_cache = self.grids.repeat(batch_size // 6, 1, 1, 1)
assert self._grids_cache.size()[0] == batch_size
self._grids_cache = self._grids_cache.to(batch.device)
return self._to_fisheye(batch)

def forward(self, batch: torch.Tensor) -> torch.Tensor:
return self.to_fisheye_tensor(batch)

def get_points_on_sphere(
self,
fish_h: int,
fish_w: int,
fish_fov: float,
fish_param: List[float],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Unpack parameters
cx, cy, fx, fy, xi, alpha = fish_param
fov_rad = fish_fov / 180 * np.pi
fov_cos = np.cos(fov_rad / 2)

# Calculate unprojection
v, u = torch.meshgrid([torch.arange(fish_h), torch.arange(fish_w)])
mx = (u - cx) / fx
my = (v - cy) / fy
r2 = mx * mx + my * my
mz = (1 - alpha * alpha * r2) / (
alpha * torch.sqrt(1 - (2 * alpha - 1) * r2) + 1 - alpha
)
mz2 = mz * mz

k1 = mz * xi + torch.sqrt(mz2 + (1 - xi * xi) * r2)
k2 = mz2 + r2
k = k1 / k2

# Unprojected unit vectors
unprojected_unit = k.unsqueeze(-1) * torch.stack([mx, my, mz], dim=-1)
unprojected_unit[..., 2] -= xi
# Coordinate transformation between camera and habitat
unprojected_unit[..., 0] *= -1
unprojected_unit[..., 1] *= -1

# Calculate fov
z_axis = torch.tensor([0, 0, 1], dtype=torch.float32)
unprojected_fov_cos = torch.matmul(unprojected_unit, z_axis)
fov_mask = unprojected_fov_cos >= fov_cos
if alpha > 0.5:
fov_mask *= r2 <= (1 - (2 * alpha - 1))

return unprojected_unit, fov_mask


@baseline_registry.register_obs_transformer()
class CubeMap2Fisheye(ObservationTransformer):
r"""This is an experimental use of ObservationTransformer that converts a cubemap
output to a fisheye one through projection. This needs to be fed
a list of 6 cameras at various orientations but will be able to stitch a
fisheye image out of these inputs. The code below will generate a config that
has the 6 sensors in the proper orientations. This code also assumes a 90
FOV.

Sensor order for cubemap stiching is Back, Down, Front, Left, Right, Up.
The output will be writen the UUID of the first sensor.
"""

def __init__(
self,
sensor_uuids: List[str],
fish_shape: Tuple[int],
fish_fov: float,
fish_params: Tuple[float],
channels_last: bool = False,
target_uuids: Optional[List[str]] = None,
):
r""":param sensor: List of sensor_uuids: Back, Down, Front, Left, Right, Up.
:param fish_shape: The shape of the fisheye output (height, width)
:param fish_fov: The FoV of the fisheye output in degrees
:param fish_params: The camera parameters of fisheye output (f, xi, alpha)
:param channels_last: Are the channels last in the input
:param target_uuids: Optional List of which of the sensor_uuids to overwrite
"""
super(CubeMap2Fisheye, self).__init__()
num_sensors = len(sensor_uuids)
assert (
num_sensors % 6 == 0 and num_sensors != 0
), f"{len(sensor_uuids)}: length of sensors is not a multiple of 6"
# TODO verify attributes of the sensors in the config if possible. Think about API design
assert (
len(fish_shape) == 2
), f"fish_shape must be a tuple of (height, width), given: {fish_shape}"
assert len(fish_params) == 3
self.sensor_uuids: List[str] = sensor_uuids
self.fish_shape: Tuple[int] = fish_shape
self.channels_last: bool = channels_last
# fisheye camera parameters
fx = fish_params[0] * min(fish_shape)
fy = fx
cx = fish_shape[1] / 2
cy = fish_shape[0] / 2
xi = fish_params[1]
alpha = fish_params[2]
self.c2fish: nn.Module = Cube2Fisheye(
fish_shape[0], fish_shape[1], fish_fov, cx, cy, fx, fy, xi, alpha
)

if target_uuids == None:
self.target_uuids: List[str] = self.sensor_uuids[::6]
else:
self.target_uuids: List[str] = target_uuids
# TODO support and test different FOVs than just 90

def transform_observation_space(
self,
observation_space: SpaceDict,
):
r"""Transforms the target UUID's sensor obs_space so it matches the new shape (FISH_H, FISH_W)"""
# Transforms the observation space to of the target UUID
for i, key in enumerate(self.target_uuids):
assert (
key in observation_space.spaces
), f"{key} not found in observation space: {observation_space.spaces}"
h, w = get_image_height_width(
observation_space.spaces[key], channels_last=True
)
assert (
h == w
), f"cubemap height and width must be the same, but is {h} and {w}"
logger.info(
f"Overwrite sensor: {key} from size of ({h}, {w}) to fisheye image of {self.fish_shape} from sensors: {self.sensor_uuids[i*6:(i+1)*6]}"
)
if (h, w) != self.fish_shape:
observation_space.spaces[key] = overwrite_gym_box_shape(
observation_space.spaces[key], self.fish_shape
)
return observation_space

@classmethod
def from_config(cls, config):
cube2fish_config = config.RL.POLICY.OBS_TRANSFORMS.CUBE2FISH
if hasattr(cube2fish_config, "TARGET_UUIDS"):
# Optional Config Value to specify target UUID
target_uuids = cube2fish_config.TARGET_UUIDS
else:
target_uuids = None
return cls(
cube2fish_config.SENSOR_UUIDS,
fish_shape=(
cube2fish_config.HEIGHT,
cube2fish_config.WIDTH,
),
fish_fov=cube2fish_config.FOV,
fish_params=cube2fish_config.PARAMS,
target_uuids=target_uuids,
)

@torch.no_grad()
def forward(
self, observations: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
for i, target_sensor_uuid in enumerate(self.target_uuids):
# The UUID we are overwriting
assert target_sensor_uuid in self.sensor_uuids[i * 6 : (i + 1) * 6]
sensor_obs = [
observations[sensor]
for sensor in self.sensor_uuids[i * 6 : (i + 1) * 6]
]
target_obs = observations[target_sensor_uuid]
sensor_dtype = target_obs.dtype
# Stacking along axis makes the flattening go in the right order.
imgs = torch.stack(sensor_obs, axis=1)
imgs = torch.flatten(imgs, end_dim=1)
if not self.channels_last:
imgs = imgs.permute((0, 3, 1, 2)) # NHWC => NCHW
imgs = imgs.float() # NCHW
fisheye = self.c2fish(imgs) # Here is where the stiching happens
fisheye = fisheye.to(dtype=sensor_dtype)
if not self.channels_last:
fisheye = fisheye.permute((0, 2, 3, 1)) # NCHW => NHWC
observations[target_sensor_uuid] = fisheye
return observations


def get_active_obs_transforms(config: Config) -> List[ObservationTransformer]:
active_obs_transforms = []
if hasattr(config.RL.POLICY, "OBS_TRANSFORMS"):
Expand Down
6 changes: 6 additions & 0 deletions habitat_baselines/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@
_C.RL.POLICY.OBS_TRANSFORMS.CUBE2EQ.HEIGHT = 256
_C.RL.POLICY.OBS_TRANSFORMS.CUBE2EQ.WIDTH = 512
_C.RL.POLICY.OBS_TRANSFORMS.CUBE2EQ.SENSOR_UUIDS = list()
_C.RL.POLICY.OBS_TRANSFORMS.CUBE2FISH = CN()
_C.RL.POLICY.OBS_TRANSFORMS.CUBE2FISH.HEIGHT = 256
_C.RL.POLICY.OBS_TRANSFORMS.CUBE2FISH.WIDTH = 256
_C.RL.POLICY.OBS_TRANSFORMS.CUBE2FISH.FOV = 180
_C.RL.POLICY.OBS_TRANSFORMS.CUBE2FISH.PARAMS = (0.2, 0.2, 0.2)
_C.RL.POLICY.OBS_TRANSFORMS.CUBE2FISH.SENSOR_UUIDS = list()
# -----------------------------------------------------------------------------
# PROXIMAL POLICY OPTIMIZATION (PPO)
# -----------------------------------------------------------------------------
Expand Down
19 changes: 14 additions & 5 deletions test/test_baseline_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,14 @@ def test_trainers(test_cfg_path, mode, gpu2gpu, observation_transforms):
@pytest.mark.parametrize(
"test_cfg_path,mode",
[
["habitat_baselines/config/test/ppo_pointnav_test.yaml", "train"],
[
"habitat_baselines/config/test/ppo_pointnav_test.yaml",
"train",
],
],
)
def test_equirect_stiching(test_cfg_path, mode: str):
@pytest.mark.parametrize("camera", ["equirect", "fisheye"])
def test_cubemap_stiching(test_cfg_path: str, mode: str, camera: str):
meta_config = get_config(config_paths=test_cfg_path)
meta_config.defrost()
config = meta_config.TASK_CONFIG
Expand Down Expand Up @@ -127,9 +131,14 @@ def test_equirect_stiching(test_cfg_path, mode: str):

meta_config.TASK_CONFIG = config
meta_config.SENSORS = config.SIMULATOR.AGENT_0.SENSORS
meta_config.RL.POLICY.OBS_TRANSFORMS.CUBE2EQ.SENSOR_UUIDS = tuple(
sensor_uuids
)
if camera == "equirec":
meta_config.RL.POLICY.OBS_TRANSFORMS.CUBE2EQ.SENSOR_UUIDS = tuple(
sensor_uuids
)
elif camera == "fisheye":
meta_config.RL.POLICY.OBS_TRANSFORMS.CUBE2FISH.SENSOR_UUIDS = tuple(
sensor_uuids
)
meta_config.freeze()
execute_exp(meta_config, mode)
# Deinit processes group
Expand Down