Skip to content

Commit

Permalink
Merge branch 'main' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
jb-ye authored Dec 12, 2024
2 parents ffd3edd + 555d554 commit e898aea
Show file tree
Hide file tree
Showing 16 changed files with 309 additions and 200 deletions.
6 changes: 5 additions & 1 deletion nerfstudio/cameras/camera_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ def get_interpolated_camera_path(cameras: Cameras, steps: int, order_poses: bool
"""
Ks = cameras.get_intrinsics_matrices()
poses = cameras.camera_to_worlds
poses, Ks = get_interpolated_poses_many(poses, Ks, steps_per_transition=steps, order_poses=order_poses)
times = cameras.times
poses, Ks, times = get_interpolated_poses_many(
poses, Ks, times, steps_per_transition=steps, order_poses=order_poses
)

cameras = Cameras(
fx=Ks[:, 0, 0],
Expand All @@ -48,6 +51,7 @@ def get_interpolated_camera_path(cameras: Cameras, steps: int, order_poses: bool
cy=Ks[0, 1, 2],
camera_type=cameras.camera_type[0],
camera_to_worlds=poses,
times=times,
)
return cameras

Expand Down
52 changes: 42 additions & 10 deletions nerfstudio/cameras/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,48 +206,74 @@ def get_interpolated_k(
return Ks


def get_ordered_poses_and_k(
def get_interpolated_time(
time_a: Float[Tensor, "1"], time_b: Float[Tensor, "1"], steps: int = 10
) -> List[Float[Tensor, "1"]]:
"""
Returns interpolated time between two camera poses with specified number of steps.
Args:
time_a: camera time 1
time_b: camera time 2
steps: number of steps the interpolated pose path should contain
"""
times: List[Float[Tensor, "1"]] = []
ts = np.linspace(0, 1, steps)
for t in ts:
new_t = time_a * (1.0 - t) + time_b * t
times.append(new_t)
return times


def get_ordered_poses_and_k_and_time(
poses: Float[Tensor, "num_poses 3 4"],
Ks: Float[Tensor, "num_poses 3 3"],
) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"]]:
times: Optional[Float[Tensor, "num_poses 1"]] = None,
) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"], Optional[Float[Tensor, "num_poses 1"]]]:
"""
Returns ordered poses and intrinsics by euclidian distance between poses.
Args:
poses: list of camera poses
Ks: list of camera intrinsics
times: list of camera times
Returns:
tuple of ordered poses and intrinsics
tuple of ordered poses, intrinsics and times
"""

poses_num = len(poses)

ordered_poses = torch.unsqueeze(poses[0], 0)
ordered_ks = torch.unsqueeze(Ks[0], 0)
ordered_times = torch.unsqueeze(times[0], 0) if times is not None else None

# remove the first pose from poses
poses = poses[1:]
Ks = Ks[1:]
times = times[1:] if times is not None else None

for _ in range(poses_num - 1):
distances = torch.norm(ordered_poses[-1][:, 3] - poses[:, :, 3], dim=1)
idx = torch.argmin(distances)
ordered_poses = torch.cat((ordered_poses, torch.unsqueeze(poses[idx], 0)), dim=0)
ordered_ks = torch.cat((ordered_ks, torch.unsqueeze(Ks[idx], 0)), dim=0)
ordered_times = torch.cat((ordered_times, torch.unsqueeze(times[idx], 0)), dim=0) if times is not None else None # type: ignore
poses = torch.cat((poses[0:idx], poses[idx + 1 :]), dim=0)
Ks = torch.cat((Ks[0:idx], Ks[idx + 1 :]), dim=0)
times = torch.cat((times[0:idx], times[idx + 1 :]), dim=0) if times is not None else None

return ordered_poses, ordered_ks
return ordered_poses, ordered_ks, ordered_times


def get_interpolated_poses_many(
poses: Float[Tensor, "num_poses 3 4"],
Ks: Float[Tensor, "num_poses 3 3"],
times: Optional[Float[Tensor, "num_poses 1"]] = None,
steps_per_transition: int = 10,
order_poses: bool = False,
) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"]]:
) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"], Optional[Float[Tensor, "num_poses 1"]]]:
"""Return interpolated poses for many camera poses.
Args:
Expand All @@ -261,21 +287,27 @@ def get_interpolated_poses_many(
"""
traj = []
k_interp = []
time_interp = [] if times is not None else None

if order_poses:
poses, Ks = get_ordered_poses_and_k(poses, Ks)
poses, Ks, times = get_ordered_poses_and_k_and_time(poses, Ks, times)

for idx in range(poses.shape[0] - 1):
pose_a = poses[idx].cpu().numpy()
pose_b = poses[idx + 1].cpu().numpy()
poses_ab = get_interpolated_poses(pose_a, pose_b, steps=steps_per_transition)
traj += poses_ab
traj += get_interpolated_poses(pose_a, pose_b, steps=steps_per_transition)
k_interp += get_interpolated_k(Ks[idx], Ks[idx + 1], steps=steps_per_transition)
if times is not None:
time_interp += get_interpolated_time(times[idx], times[idx + 1], steps=steps_per_transition) # type: ignore

traj = np.stack(traj, axis=0)
k_interp = torch.stack(k_interp, dim=0)

return torch.tensor(traj, dtype=torch.float32), torch.tensor(k_interp, dtype=torch.float32)
time_interp = torch.stack(time_interp, dim=0) if time_interp is not None else None
return (
torch.tensor(traj, dtype=torch.float32),
torch.tensor(k_interp, dtype=torch.float32),
torch.tensor(time_interp, dtype=torch.float32) if time_interp is not None else None,
)


def normalize(x: torch.Tensor) -> Float[Tensor, "*batch"]:
Expand Down
6 changes: 3 additions & 3 deletions nerfstudio/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def train(self) -> None:

self._init_viewer_state()
with TimeWriter(writer, EventName.TOTAL_TRAIN_TIME):
num_iterations = self.config.max_num_iterations
num_iterations = self.config.max_num_iterations - self._start_step
step = 0
self.stop_training = False
for step in range(self._start_step, self._start_step + num_iterations):
Expand Down Expand Up @@ -478,8 +478,8 @@ def save_checkpoint(self, step: int) -> None:
)
# possibly delete old checkpoints
if self.config.save_only_latest_checkpoint:
# delete everything else in the checkpoint folder
for f in self.checkpoint_dir.glob("*"):
# delete every other checkpoint in the checkpoint folder
for f in self.checkpoint_dir.glob("*.ckpt"):
if f != ckpt_path:
f.unlink()

Expand Down
15 changes: 9 additions & 6 deletions nerfstudio/field_components/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@

from nerfstudio.field_components.base_field_component import FieldComponent
from nerfstudio.utils.external import TCNN_EXISTS, tcnn
from nerfstudio.utils.math import components_from_spherical_harmonics, expected_sin, generate_polyhedron_basis
from nerfstudio.utils.math import expected_sin, generate_polyhedron_basis
from nerfstudio.utils.printing import print_tcnn_speed_warning
from nerfstudio.utils.spherical_harmonics import MAX_SH_DEGREE, components_from_spherical_harmonics


class Encoding(FieldComponent):
Expand Down Expand Up @@ -756,14 +757,16 @@ class SHEncoding(Encoding):
"""Spherical harmonic encoding
Args:
levels: Number of spherical harmonic levels to encode.
levels: Number of spherical harmonic levels to encode. (level = sh degree + 1)
"""

def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "torch") -> None:
super().__init__(in_dim=3)

if levels <= 0 or levels > 4:
raise ValueError(f"Spherical harmonic encoding only supports 1 to 4 levels, requested {levels}")
if levels <= 0 or levels > MAX_SH_DEGREE + 1:
raise ValueError(
f"Spherical harmonic encoding only supports 1 to {MAX_SH_DEGREE + 1} levels, requested {levels}"
)

self.levels = levels

Expand All @@ -778,7 +781,7 @@ def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "
)

@classmethod
def get_tcnn_encoding_config(cls, levels) -> dict:
def get_tcnn_encoding_config(cls, levels: int) -> dict:
"""Get the encoding configuration for tcnn if implemented"""
encoding_config = {
"otype": "SphericalHarmonics",
Expand All @@ -792,7 +795,7 @@ def get_out_dim(self) -> int:
@torch.no_grad()
def pytorch_fwd(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
"""Forward pass using pytorch. Significantly slower than TCNN implementation."""
return components_from_spherical_harmonics(levels=self.levels, directions=in_tensor)
return components_from_spherical_harmonics(degree=self.levels - 1, directions=in_tensor)

def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
if self.tcnn_encoding is not None:
Expand Down
5 changes: 3 additions & 2 deletions nerfstudio/model_components/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@

from nerfstudio.cameras.rays import RaySamples
from nerfstudio.utils import colors
from nerfstudio.utils.math import components_from_spherical_harmonics, safe_normalize
from nerfstudio.utils.math import safe_normalize
from nerfstudio.utils.spherical_harmonics import components_from_spherical_harmonics

BackgroundColor = Union[Literal["random", "last_sample", "black", "white"], Float[Tensor, "3"], Float[Tensor, "*bs 3"]]
BACKGROUND_COLOR_OVERRIDE: Optional[Float[Tensor, "3"]] = None
Expand Down Expand Up @@ -268,7 +269,7 @@ def forward(
sh = sh.view(*sh.shape[:-1], 3, sh.shape[-1] // 3)

levels = int(math.sqrt(sh.shape[-1]))
components = components_from_spherical_harmonics(levels=levels, directions=directions)
components = components_from_spherical_harmonics(degree=levels - 1, directions=directions)

rgb = sh * components[..., None, :] # [..., num_samples, 3, sh_components]
rgb = torch.sum(rgb, dim=-1) # [..., num_samples, 3]
Expand Down
89 changes: 3 additions & 86 deletions nerfstudio/models/splatfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@

from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

import numpy as np
import torch
from gsplat.strategy import DefaultStrategy

Expand All @@ -42,70 +40,10 @@
from nerfstudio.model_components.lib_bilagrid import BilateralGrid, color_correct, slice, total_variation_loss
from nerfstudio.models.base_model import Model, ModelConfig
from nerfstudio.utils.colors import get_color
from nerfstudio.utils.math import k_nearest_sklearn, random_quat_tensor
from nerfstudio.utils.misc import torch_compile
from nerfstudio.utils.rich_utils import CONSOLE


def num_sh_bases(degree: int) -> int:
"""
Returns the number of spherical harmonic bases for a given degree.
"""
assert degree <= 4, "We don't support degree greater than 4."
return (degree + 1) ** 2


def quat_to_rotmat(quat):
assert quat.shape[-1] == 4, quat.shape
w, x, y, z = torch.unbind(quat, dim=-1)
mat = torch.stack(
[
1 - 2 * (y**2 + z**2),
2 * (x * y - w * z),
2 * (x * z + w * y),
2 * (x * y + w * z),
1 - 2 * (x**2 + z**2),
2 * (y * z - w * x),
2 * (x * z - w * y),
2 * (y * z + w * x),
1 - 2 * (x**2 + y**2),
],
dim=-1,
)
return mat.reshape(quat.shape[:-1] + (3, 3))


def random_quat_tensor(N):
"""
Defines a random quaternion tensor of shape (N, 4)
"""
u = torch.rand(N)
v = torch.rand(N)
w = torch.rand(N)
return torch.stack(
[
torch.sqrt(1 - u) * torch.sin(2 * math.pi * v),
torch.sqrt(1 - u) * torch.cos(2 * math.pi * v),
torch.sqrt(u) * torch.sin(2 * math.pi * w),
torch.sqrt(u) * torch.cos(2 * math.pi * w),
],
dim=-1,
)


def RGB2SH(rgb):
"""
Converts from RGB values [0,1] to the 0th spherical harmonic coefficient
"""
C0 = 0.28209479177387814
return (rgb - 0.5) / C0


def SH2RGB(sh):
"""
Converts from the 0th spherical harmonic coefficient to RGB values [0,1]
"""
C0 = 0.28209479177387814
return sh * C0 + 0.5
from nerfstudio.utils.spherical_harmonics import RGB2SH, SH2RGB, num_sh_bases


def resize_image(image: torch.Tensor, d: int):
Expand Down Expand Up @@ -243,8 +181,7 @@ def populate_modules(self):
means = torch.nn.Parameter(self.seed_points[0]) # (Location, Color)
else:
means = torch.nn.Parameter((torch.rand((self.config.num_random, 3)) - 0.5) * self.config.random_scale)
distances, _ = self.k_nearest_sklearn(means.data, 3)
distances = torch.from_numpy(distances)
distances, _ = k_nearest_sklearn(means.data, 3)
# find the average of the three nearest neighbors for each point and use that as the scale
avg_dist = distances.mean(dim=-1, keepdim=True)
scales = torch.nn.Parameter(torch.log(avg_dist.repeat(1, 3)))
Expand Down Expand Up @@ -392,26 +329,6 @@ def load_state_dict(self, dict, **kwargs): # type: ignore
self.gauss_params[name] = torch.nn.Parameter(torch.zeros(new_shape, device=self.device))
super().load_state_dict(dict, **kwargs)

def k_nearest_sklearn(self, x: torch.Tensor, k: int):
"""
Find k-nearest neighbors using sklearn's NearestNeighbors.
x: The data tensor of shape [num_samples, num_features]
k: The number of neighbors to retrieve
"""
# Convert tensor to numpy array
x_np = x.cpu().numpy()

# Build the nearest neighbors model
from sklearn.neighbors import NearestNeighbors

nn_model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric="euclidean").fit(x_np)

# Find the k-nearest neighbors
distances, indices = nn_model.kneighbors(x_np)

# Exclude the point itself from the result and return
return distances[:, 1:].astype(np.float32), indices[:, 1:].astype(np.float32)

def set_crop(self, crop_box: Optional[OrientedBox]):
self.crop_box = crop_box

Expand Down
4 changes: 2 additions & 2 deletions nerfstudio/process_data/equirect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def equirect2persp(img: torch.Tensor, fov: int, theta: int, phi: int, hd: int, w
return remap_cubic(img, lon, lat, border_mode="wrap")


def _crop_bottom(bound_arr: list, fov: int, crop_factor: float) -> List[float]:
def _crop_top(bound_arr: list, fov: int, crop_factor: float) -> List[float]:
"""Returns a list of vertical bounds with the bottom cropped.
Args:
Expand All @@ -184,7 +184,7 @@ def _crop_bottom(bound_arr: list, fov: int, crop_factor: float) -> List[float]:
return bound_arr


def _crop_top(bound_arr: list, fov: int, crop_factor: float) -> List[float]:
def _crop_bottom(bound_arr: list, fov: int, crop_factor: float) -> List[float]:
"""Returns a list of vertical bounds with the top cropped.
Args:
Expand Down
10 changes: 10 additions & 0 deletions nerfstudio/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from nerfstudio.configs.method_configs import AnnotatedBaseConfigUnion
from nerfstudio.engine.trainer import TrainerConfig
from nerfstudio.utils import comms, profiler
from nerfstudio.utils.available_devices import get_available_devices
from nerfstudio.utils.rich_utils import CONSOLE

DEFAULT_TIMEOUT = timedelta(minutes=30)
Expand Down Expand Up @@ -226,6 +227,15 @@ def launch(
def main(config: TrainerConfig) -> None:
"""Main function."""

# Check if the specified device type is available
available_device_types = get_available_devices()
if config.machine.device_type not in available_device_types:
raise RuntimeError(
f"Specified device type '{config.machine.device_type}' is not available. "
f"Available device types: {available_device_types}. "
"Please specify a valid device type using the CLI option: --machine.device_type [cuda|mps|cpu]"
)

if config.data:
CONSOLE.log("Using --data alias for --data.pipeline.datamanager.data")
config.pipeline.datamanager.data = config.data
Expand Down
Loading

0 comments on commit e898aea

Please sign in to comment.