fix: remove similarity_from_cameras
AtticusZeller committed Jul 26, 2024
1 parent f468a50 commit 1b08cc0
Showing 12 changed files with 707 additions and 264 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -1,10 +1,10 @@
## GSLoc-Slam
[ ] clean config in base.py
1. tracking
1. [x] normalize pcd and pose via PCA -> noise
1. [x] normalize pcd and pose via PCA
2. [x] update depth_gt with a proper method
3. [x] attention layer for loss, and huber loss -> did not work
4. [x] find an early stop condition !!! -> depth loss
4. [x] find an early stop condition !!! -> depth loss and later than 50 steps
5. [ ] refactor for better order
6. [ ] total dataset eval
1. [x] simply
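
The PCA normalization named in the checklist item above can be sketched as follows. This is an illustrative reconstruction assuming torch, not the repository's actual `normalize_2C` implementation; the function name, the unit-extent scaling, and the reflection fix are assumptions.

```python
import torch
from torch import Tensor


def pca_normalize(points: Tensor, c2w: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    """Rotate a point cloud onto its principal axes, re-center it, and scale
    it to roughly unit extent; apply the same similarity to a camera pose."""
    centroid = points.mean(dim=0)
    centered = points - centroid
    # Rows of vh are the principal directions of the centered cloud.
    _, _, vh = torch.linalg.svd(centered, full_matrices=False)
    rotation = vh
    if torch.linalg.det(rotation) < 0:
        # Flip one axis so the transform is a proper rotation, not a reflection.
        rotation = torch.cat([rotation[:-1], -rotation[-1:]], dim=0)
    rotated = centered @ rotation.T
    scale = rotated.abs().max()
    normalized = rotated / scale

    # Compose the same rigid transform with the pose, then rescale translation.
    transform = torch.eye(4, dtype=c2w.dtype, device=c2w.device)
    transform[:3, :3] = rotation
    transform[:3, 3] = -rotation @ centroid
    new_c2w = transform @ c2w
    new_c2w[:3, 3] = new_c2w[:3, 3] / scale
    return normalized, new_c2w, scale
```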
35 changes: 33 additions & 2 deletions src/eval/logger.py
@@ -114,6 +114,9 @@ def plot_rgbd(
rastered_color: Tensor | None = None,
color_loss: dict | None = None,
silhouette_loss: dict | None = None,
normal: Tensor | None = None,  # ground-truth normal map
rastered_normal: Tensor | None = None,  # rasterized normal map
normal_loss: dict | None = None,  # normal loss shown in the subplot title
fig_title="RGBD Visualization",
):
# Ensure depth tensors have a batch dimension
@@ -126,9 +129,9 @@ def plot_rgbd(

# Determine Plot Aspect Ratio
aspect_ratio = depth.shape[2] / depth.shape[1]
fig_height = 8
fig_height = 12
fig_width = aspect_ratio * 14 / 1.55
fig, axs = plt.subplots(2, 3, figsize=(fig_width, fig_height))
fig, axs = plt.subplots(3, 3, figsize=(fig_width, fig_height))

if color is not None:
if color.dim() == 3 and color.shape[1] == 3: # (H, C, W)
@@ -186,6 +189,34 @@ def plot_rgbd(
axs[1, 2].imshow(diff_depth.squeeze(), cmap="jet", vmin=0, vmax=6)
axs[1, 2].set_title("Diff Depth L1")

# Add normal map visualization if provided
if normal is not None and rastered_normal is not None:
# Ensure normal tensors are in the correct shape (H, W, 3)
if normal.dim() == 4:
normal = normal.squeeze(0)
if rastered_normal.dim() == 4:
rastered_normal = rastered_normal.squeeze(0)

# Map unit normal components from [-1, 1] into RGB in [0, 1]
normal_rgb = (normal * 0.5 + 0.5).detach().cpu()
rastered_normal_rgb = (rastered_normal * 0.5 + 0.5).detach().cpu()

axs[2, 0].imshow(normal_rgb)
axs[2, 0].set_title("Ground Truth Normal")

axs[2, 1].imshow(rastered_normal_rgb)
if normal_loss is not None:
axs[2, 1].set_title(
f"Rasterized Normal\n{normal_loss['type']}: {normal_loss['value']:.4f}"
)
else:
axs[2, 1].set_title("Rasterized Normal")

# Calculate and display the per-pixel normal difference, averaged over
# the three components so the "jet" colormap and vmin/vmax take effect
# (matplotlib ignores cmap/vmin/vmax for RGB-shaped input)
normal_diff = torch.abs(normal - rastered_normal).mean(dim=-1).detach().cpu()
axs[2, 2].imshow(normal_diff, cmap="jet", vmin=0, vmax=1)
axs[2, 2].set_title("Normal Difference")

for ax in axs.flatten():
if ax.get_visible():
ax.axis("off")
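
A hypothetical call using the three new arguments. The logger instance, the depth keyword names, and the tensors are assumptions; the normal-loss dict keys ("type", "value") match the f-string above.

```python
logger.plot_rgbd(
    depth=depth_gt,                   # assumed name: (1, H, W) ground-truth depth
    rastered_depth=rendered_depth,    # assumed name: (1, H, W) rendered depth
    color=rgb_gt,                     # (H, W, 3) in [0, 1]
    rastered_color=rendered_rgb,      # (H, W, 3)
    normal=normal_gt,                 # (H, W, 3), components in [-1, 1]
    rastered_normal=rendered_normal,  # (H, W, 3)
    normal_loss={"type": "L1", "value": 0.0123},
    fig_title="RGBD Visualization",
)
```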
62 changes: 42 additions & 20 deletions src/my_gsplat/datasets/Image.py
@@ -1,11 +1,15 @@
import kornia
import numpy as np
import torch
import torch.nn.functional as F
from numpy.compat import long
from numpy.typing import NDArray
from torch import Tensor
import torch.nn.functional as F
from ..utils import DEVICE, to_tensor, visualize_point_cloud

from ..geometry import depth_to_points
from ..utils import (
DEVICE,
to_tensor,
)


class RGBDImage:
@@ -30,7 +34,7 @@ def __init__(
depth: np.ndarray,
K: np.ndarray,
depth_scale: float,
pose: NDArray[np.float64],
pose: NDArray[np.float32],
):
if rgb.shape[0] != depth.shape[0] or rgb.shape[1] != depth.shape[1]:
raise ValueError(
@@ -39,20 +43,36 @@ def __init__(
self._rgb = to_tensor(rgb, device=DEVICE)
self._depth = to_tensor(depth / depth_scale, device=DEVICE)
self._K = to_tensor(K, device=DEVICE)

self._pose = to_tensor(pose, device=DEVICE)
self._pcd = self._project_pcds(include_homogeneous=False)
self._pcd = depth_to_points(self._depth, self._K)

# NOTE: remove outliers
# self._pcd, inlier_mask = remove_outliers(self._pcd, verbose=True)
# self._colors = (self._rgb / 255.0).reshape(-1, 3)[inlier_mask] # N,3

self._colors = (self._rgb / 255.0).reshape(-1, 3) # N,3
# Adjust camera pose for normalized point cloud

@property
def size(self):
return self._pcd.size(0)

@property
def color(self) -> Tensor:
def colors(self) -> Tensor:
"""
Normalized colors in [0, 1].
Returns
-------
colors: Tensor[torch.float32], shape=(n, 3)
"""
return self._colors

@property
def rgbs(self) -> Tensor:
"""
Returns
-------
color: Tensor[torch.float64], shape=(h, w, 3)
rgb: Tensor[torch.float32], shape=(h, w, 3)
"""
return self._rgb

@@ -61,7 +81,7 @@ def depth(self) -> Tensor:
"""
Returns
-------
depth: Tensor[torch.float64], shape=(h, w)
depth: Tensor[torch.float32], shape=(h, w)
Depth image in meters.
"""

@@ -78,7 +98,7 @@ def K(self) -> Tensor:
"""
Returns
-------
K: Tensor[torch.float64], shape=(3, 3)
K: Tensor[torch.float32], shape=(3, 3)
Camera intrinsic matrix.
"""
return self._K
@@ -95,7 +115,7 @@ def pose(self) -> Tensor:
Returns
-------
pose: Tensor[torch.float64] | None, shape=(4, 4)
pose: Tensor[torch.float32] | None, shape=(4, 4)
Camera pose matrix in world coordinates.
"""
return self._pose
@@ -141,7 +161,9 @@ def _project_pcds(
points: Tensor
The generated point cloud, shape=(h*w, 3) or (h*w, 4).
"""
points_3d = kornia.geometry.depth_to_3d_v2(self.depth, self.K).view(-1, 3)
points_3d = kornia.geometry.depth_to_3d_v2(
self.depth, self.K, normalize_points=True
).view(-1, 3)
if include_homogeneous:
points_3d = F.pad(points_3d, (0, 1), value=1)
return points_3d
@@ -150,7 +172,7 @@ def _color_pcds(
self,
colored: bool = False, # Optional color image
include_homogeneous: bool = True,
) -> NDArray[np.float64]:
) -> NDArray[np.float32]:
"""
Generate point clouds from depth image, optionally with color.
@@ -163,7 +185,7 @@
Returns
-------
NDArray[np.float64]
NDArray[np.float32]
The generated point cloud, shape=(h*w, 4) or (h*w, 6) or (h*w, 3) or (h*w, 7) depending on options.
"""
h, w = self._depth.shape[:2]
@@ -193,7 +215,7 @@

def _pointclouds(
self, stride: int = 1, include_homogeneous=True
) -> NDArray[np.float64]:
) -> NDArray[np.float32]:
"""
Generate point clouds from depth image.
Parameters
@@ -202,7 +224,7 @@ def _pointclouds(
include_homogeneous: bool, optional, whether to include homogeneous coordinate
Returns
-------
pcd: NDArray[np.float64], shape=(h*w, 4) or (h*w, 3)
pcd: NDArray[np.float32], shape=(h*w, 4) or (h*w, 3)
"""
i_indices, j_indices, depth_downsampled = self._grid_downsample(stride)
# Transform to camera coordinates
@@ -223,8 +245,8 @@ def _pointclouds(
def _camera_to_world(
self,
c2w: np.ndarray,
pcd_c: NDArray[np.float64] | None = None,
) -> NDArray[np.float64]:
pcd_c: NDArray[np.float32] | None = None,
) -> NDArray[np.float32]:
"""
Transform points from camera coordinates to world coordinates using the c2w matrix.
:param c2w: 4x4 transformation matrix from camera to world coordinates
@@ -240,7 +262,7 @@ def _camera_to_world(
def _grid_downsample(self, stride: int = 1) -> tuple[
NDArray[np.signedinteger | long],
NDArray[np.signedinteger | long],
NDArray[np.float64],
NDArray[np.float32],
]:
"""
Parameters
@@ -252,7 +274,7 @@ def _grid_downsample(self, stride: int = 1) -> tuple[
Pixel indices along the height axis.
j_indices: NDArray[np.signedinteger | long], shape=(h, w)
Pixel indices along the width axis.
depth_downsampled: NDArray[np.float64], shape=(h, w)
depth_downsampled: NDArray[np.float32], shape=(h, w)
Downsampled depth image.
"""
# Generate pixel indices
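
The `depth_to_points` helper imported from `..geometry` replaces the kornia-based projection, but its body is not part of this diff. Below is a self-contained sketch of the pinhole backprojection such a helper typically performs; the implementation details are assumptions, not the repository's code.

```python
import torch
from torch import Tensor


def depth_to_points(depth: Tensor, K: Tensor) -> Tensor:
    """Backproject an (H, W) depth map into (H*W, 3) camera-space points."""
    h, w = depth.shape
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]
    # One (u, v) pixel coordinate per depth sample.
    v, u = torch.meshgrid(
        torch.arange(h, device=depth.device, dtype=depth.dtype),
        torch.arange(w, device=depth.device, dtype=depth.dtype),
        indexing="ij",
    )
    z = depth
    x = (u - cx) * z / fx  # pinhole model: X = (u - cx) * Z / fx
    y = (v - cy) * z / fy
    return torch.stack((x, y, z), dim=-1).view(-1, 3)
```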
10 changes: 6 additions & 4 deletions src/my_gsplat/datasets/base.py
@@ -80,7 +80,7 @@ class OptimizationConfig:
lpips: LearnedPerceptualImagePatchSimilarity = None

early_stop: bool = True
patience = 80
patience = 200
best_eR = float("inf")
best_eT = float("inf")
best_loss = float("inf")
@@ -98,7 +98,8 @@ def init_loss(self):
@dataclass
class DepthLossConfig:
depth_loss: bool = False
depth_lambda: float = 0.5
depth_lambda: float = 0.8
normal_lambda: float = 0.0


@dataclass
@@ -177,14 +178,15 @@ class AlignData(TensorWrapper):

colors: Tensor # N,3
pixels: Tensor # H,W,3
points: Tensor # N,3
# points: Tensor # N,3
tar_points: Tensor
src_points: Tensor
src_depth: Tensor
tar_c2w: Tensor # 4,4
src_c2w: Tensor # 4,4
tar_nums: int # for slice tar and src
scale_factor: Tensor # for scale depth after rot normalized
# sphere_factor: Tensor # for scale depth after rot normalized
pca_factor: Tensor


@dataclass
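
`patience` grows from 80 to 200 steps, matching the README note that the depth-loss early stop should only fire later than 50 steps. A sketch of how such a check might be wired up; the counter attribute and the loop integration are assumptions, not repository code.

```python
def should_stop(cfg, step: int, loss: float, min_steps: int = 50) -> bool:
    """Stop once the loss has not improved for cfg.patience steps,
    but never before min_steps."""
    if loss < cfg.best_loss:
        cfg.best_loss = loss
        cfg.steps_since_best = 0  # assumed counter, not a field shown above
    else:
        cfg.steps_since_best = getattr(cfg, "steps_since_best", 0) + 1
    return cfg.early_stop and step > min_steps and cfg.steps_since_best >= cfg.patience


# Inside the tracking loop, with a combined loss along the lines of:
# loss = rgb_loss + depth_cfg.depth_lambda * depth_l1 + depth_cfg.normal_lambda * normal_err
```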
55 changes: 29 additions & 26 deletions src/my_gsplat/datasets/dataset.py
@@ -9,7 +9,10 @@
from ..utils import as_intrinsics_matrix, load_camera_cfg, to_tensor
from .base import AlignData, TrainData
from .Image import RGBDImage
from .normalize import normalize_2C, normalize_T
from .normalize import (
normalize_2C,
normalize_T,
)


class DataLoaderBase:
@@ -149,47 +152,49 @@ def __init__(self, name: str = "room0", normalize: bool = False):
# normalize points and pose
self.normalize = normalize

def __len__(self) -> int:
return super().__len__() - 1

def __getitem__(self, index: int) -> AlignData:
assert index < len(self)
tar, src = super().__getitem__(index), super().__getitem__(index + 1)
# transform to world
tar.points = transform_points(tar.pose, tar.points)
src.points = transform_points(tar.pose, src.points)
scale_factor = torch.scalar_tensor(1.0, device=tar.points.device)

# NOTE: PCA
pca_factor = torch.scalar_tensor(1.0, device=tar.points.device)
if self.normalize:
tar, src, scale_factor = normalize_2C(tar, src)

# NOTE: PCA
tar, src, pca_factor = normalize_2C(tar, src)
ks = self.K.unsqueeze(0) # [1, 3, 3]
h, w = src.depth.shape
src_rgbs = (src.color / 255.0).reshape(-1, 3)

# # NOTE: normalize_points_spherical
# tar.points, _ = normalize_points_spherical(tar.points)
# src.points, sphere_factor = normalize_points_spherical(src.points)
# tar.pose = adjust_pose_spherical(tar.pose, _)
# src.pose = adjust_pose_spherical(src.pose, sphere_factor)

# NOTE: project depth
src.depth = (
compute_depth_gt(
src.points,
src_rgbs,
src.colors,
ks,
c2w=tar.pose.unsqueeze(0),
height=h,
width=w,
)
/ scale_factor
)
# scene_scale_normed = scene_scale([tar_normed, src_normed])
# combined
points = torch.cat([tar.points, src.points], dim=0) # N,3
rgbs = torch.stack([tar.color / 255.0, src.color / 255.0], dim=0).reshape(
-1, 3
) # N,3

# / pca_factor
) # / sphere_factor
return AlignData(
scale_factor=scale_factor,
colors=rgbs,
pixels=src.color / 255.0,
points=points,
pca_factor=pca_factor,
# sphere_factor=sphere_factor,
colors=src.colors,
pixels=src.rgbs / 255.0,
# points=points,
tar_points=tar.points,
src_points=src.points,
src_depth=src.depth, # NOTE: depth need to be normalized
src_depth=src.depth,
tar_c2w=tar.pose,
src_c2w=src.pose,
tar_nums=tar.points.shape[0],
@@ -212,11 +217,9 @@ def __getitem__(self, index: int) -> TrainData:
assert index < len(self)
tar = super().__getitem__(index)

rgbs = (tar.color / 255.0).reshape(-1, 3) # N,3

return TrainData(
colors=rgbs,
pixels=tar.color / 255.0,
colors=tar.colors,
pixels=tar.rgbs / 255.0,
points=tar.points,
depth=tar.depth,
c2w=tar.pose,
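
With these changes, `__getitem__` returns per-pair AlignData built from `src.colors` and `src.rgbs`, carrying a `pca_factor` in place of the removed combined `points` and `scale_factor`. A hypothetical consumption loop; the dataset class name and the tracking call are assumptions based on the visible diff.

```python
dataset = Replica(name="room0", normalize=True)  # assumed class name; PCA normalization on
for i in range(len(dataset)):      # one item per (frame i, frame i+1) pair
    data = dataset[i]              # AlignData
    # data.tar_points / data.src_points: world-space clouds for the pair
    # data.src_depth: src depth re-rendered from tar's pose via compute_depth_gt
    # data.pca_factor: scale factor produced by the PCA normalization
    track_pair(data)               # assumed tracking step
```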