import os
import torch
import torchvision
import numpy as np
import cv2
import math
from opt import get_opts, BLENDER_PATH, ROOT_DIR
import time
import glob
from tqdm import tqdm
import json
from blender import blend_all
# from lighting.ldr2hdr import convert_ldr2hdr
from lighting.difflight import get_envmap_from_single_view
from inpaint.inpaint_anything import inpaint_img
from sugar.sugar_scene.cameras import load_gs_cameras, GSCamera
from sugar.gaussian_splatting.scene.gaussian_model import GaussianModel
# from sugar.gaussian_splatting.scene.cameras import Camera
from sugar.gaussian_splatting.utils.graphics_utils import focal2fov, fov2focal
# from sugar.gaussian_splatting.arguments import PipelineParams
from sugar.sugar_scene.gs_model import PipelineParams, OptimizationParams
from sugar.gaussian_splatting.gaussian_renderer import render
from sugar.gaussian_splatting.render import generate_video_from_frames, depth2img
from sugar.gaussian_splatting.render_panorama import render_panorama
from rich.console import Console
# from blender.static_rendering import run_blender_render as render_all_from_blender
from gaussians_utils import load_gaussians, merge_two_gaussians, transform_gaussians, get_center_of_mesh, get_center_of_mesh_2
from random import randint
from sugar.gaussian_splatting.utils.loss_utils import l1_loss, ssim
from PIL import Image
from sugar.sugar_utils.general_utils import PILtoTorch
import copy
from inpaint.retrain_utils import compute_lpips_loss, init_lpips_model, is_large_mask
import open3d as o3d
import trimesh
CONSOLE = Console(width=120)
class SceneRepresentation():
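    '''
    Wraps a reconstructed 3DGS / SuGaR scene together with its camera trajectory and the
    Blender-based editing pipeline (object insertion, fire/smoke, physics events, compositing).

    Typical usage (a minimal sketch; see the __main__ block at the bottom of this file):
        hparams = get_opts()
        scene = SceneRepresentation(hparams)
        scene.render_from_3DGS(render_video=True)  # render RGB / depth / normal frames with 3DGS
        scene.render_scene()                       # run the Blender effects and composite the results
    '''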
def __init__(self, hparams):
self.hparams = hparams
self.load_scene()
self.load_cameras()
self.dataset_dir = hparams.source_path
self.results_dir = hparams.model_path
os.makedirs(os.path.join(self.results_dir), exist_ok=True)
custom_traj_name = hparams.custom_traj_name if hparams.custom_traj_name is not None else 'training_cameras'
self.traj_results_dir = os.path.join(self.results_dir, 'custom_camera_path', custom_traj_name)
os.makedirs(os.path.join(self.traj_results_dir), exist_ok=True)
self.tracking_results_dir = os.path.join(self.results_dir, 'track_with_deva', custom_traj_name)
os.makedirs(self.tracking_results_dir, exist_ok=True)
self.blender_output_dir = os.path.join(self.traj_results_dir, 'blender_output', hparams.blender_output_dir_name)
os.makedirs(self.blender_output_dir, exist_ok=True)
self.cache_dir = os.path.join(ROOT_DIR, '_cache')
os.makedirs(self.cache_dir, exist_ok=True)
self.cfg_path = os.path.join(self.blender_output_dir, hparams.blender_config_name)
self.custom_traj_name = custom_traj_name
self.scene_scale = float(hparams.scene_scale) if not hparams.waymo_scene else 1.0
self.anchor_frame_idx = hparams.anchor_frame_idx if hparams.anchor_frame_idx is not None else 0
self.inserted_objects = []
self.fire_objects = []
self.smoke_objects = []
self.events = []
self.blender_cfg = {}
self.rb_transform_info = None
self.blender_cache_dir = os.path.join(
self.cache_dir,
'blender_rendering',
self.dataset_dir.rstrip('/').split('/')[-1], # scene name
self.custom_traj_name
)
bg_color = [1,1,1] if self.hparams.white_background else [0, 0, 0]
self.background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
self.pipe = PipelineParams()
self.DINO_THRESHOLD = hparams.deva_dino_threshold
self.scene_mesh_path_for_blender = hparams.scene_mesh_path
self.total_frames = self.cameras['c2w'].shape[0] if hparams.render_type == 'MULTI_VIEW' else self.hparams.num_frames
self.fps = 15
self.camera_position = self.cameras['c2w'][self.anchor_frame_idx][:3, 3].copy()
self.camera_rotation = self.cameras['c2w'][self.anchor_frame_idx][:3, :3].copy()
self.waymo_scene = hparams.waymo_scene
def insert_object(self, object_info):
assert isinstance(object_info, dict)
self.inserted_objects.append(object_info)
def load_cameras(self):
'''
        Reference: loadCustomCameras() in line 104 of sugar/gaussian_splatting/scene/__init__.py
'''
# Option 1: Load cameras from custom camera trajectory
if self.hparams.custom_traj_name is not None:
# get the info of custom camera trajectory
custom_traj_folder = os.path.join(self.hparams.source_path, "custom_camera_path")
with open(os.path.join(custom_traj_folder, self.hparams.custom_traj_name + '.json'), 'r') as f:
custom_traj = json.load(f)
# get camera poses and intrinsics
fx, fy, cx, cy = custom_traj["fl_x"], custom_traj["fl_y"], custom_traj["cx"], custom_traj["cy"]
w, h = custom_traj["w"], custom_traj["h"]
c2w_dict = {}
for frame in custom_traj["frames"]:
c2w_dict[frame["filename"]] = np.array(frame["transform_matrix"])
c2w_dict = dict(sorted(c2w_dict.items()))
if self.hparams.downscale_factor > 1.0:
h = round(h / self.hparams.downscale_factor)
w = round(w / self.hparams.downscale_factor)
fx = fx / self.hparams.downscale_factor
fy = fy / self.hparams.downscale_factor
cx = cx / self.hparams.downscale_factor
cy = cy / self.hparams.downscale_factor
# camera list
custom_cameras = []
for cam_idx, (filename, c2w) in enumerate(tqdm(c2w_dict.items(), desc="Loading custom cameras")):
w2c = np.linalg.inv(c2w)
R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code
T = w2c[:3, 3]
FovY = focal2fov(fy, h)
FovX = focal2fov(fx, w)
view = GSCamera(
colmap_id=cam_idx, R=R, T=T,
FoVx=FovX, FoVy=FovY, image=None, gt_alpha_mask=None,
image_name='{0:05d}'.format(cam_idx), uid=cam_idx,
image_height=h, image_width=w)
custom_cameras.append(view)
# store information for blender rendering
self.cameras = {
'cameras': custom_cameras,
'img_wh': (w, h),
'K': np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]),
'c2w': np.array([c2w_dict[frame] for frame in c2w_dict]),
'c2w_dict': c2w_dict,
}
# Option 2: Load cameras from dataset
else:
camera_list = load_gs_cameras(self.hparams.source_path, self.hparams.model_path, self.hparams.downscale_factor)
tmp_camera = camera_list[0]
h, w = tmp_camera.image_height, tmp_camera.image_width
cx, cy = w / 2, h / 2
fx, fy = fov2focal(tmp_camera.FoVx, w), fov2focal(tmp_camera.FoVy, h)
c2w_list = []
c2w_dict = {}
for cam in camera_list:
c2w = np.zeros((4,4))
c2w[:3, :3] = cam.R.transpose()
c2w[:3, 3] = cam.T
c2w[3, 3] = 1.0
c2w_list.append(c2w)
c2w_dict[cam.image_name + '.png'] = c2w
self.cameras = {
'cameras': camera_list,
'img_wh': (w, h),
'K': np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]),
'c2w': np.array(c2w_list),
'c2w_dict': c2w_dict,
}
def load_scene(self):
'''
Reference: convert_refined_sugar_into_gaussians() in line 2617 of sugar/sugar_scene/sugar_model.py
'''
if self.hparams.gaussians_ckpt_path.endswith('.pt'):
# Load gaussians parameters from sugar checkpoint
CONSOLE.print(f"\nLoading the coarse SuGaR model from path {self.hparams.gaussians_ckpt_path}...")
gaussians = GaussianModel(self.hparams.max_sh_degree)
checkpoint = torch.load(self.hparams.gaussians_ckpt_path, map_location=gaussians.get_xyz.device)
with torch.no_grad():
xyz = checkpoint['state_dict']['_points'].cpu().numpy()
opacities = checkpoint['state_dict']['all_densities'].cpu().numpy()
features_dc = checkpoint['state_dict']['_sh_coordinates_dc'].cpu().numpy()
features_extra = checkpoint['state_dict']['_sh_coordinates_rest'].cpu().numpy()
scales = checkpoint['state_dict']['_scales'].cpu().numpy()
rots = checkpoint['state_dict']['_quaternions'].cpu().numpy()
_set_require_grad = False
gaussians._xyz = torch.nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(_set_require_grad))
gaussians._features_dc = torch.nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").requires_grad_(_set_require_grad))
gaussians._features_rest = torch.nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").requires_grad_(_set_require_grad))
gaussians._opacity = torch.nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(_set_require_grad))
gaussians._scaling = torch.nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(_set_require_grad))
gaussians._rotation = torch.nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(_set_require_grad))
gaussians.active_sh_degree = self.hparams.max_sh_degree
        elif self.hparams.gaussians_ckpt_path.endswith('.ply'):
            # Load gaussians parameters from vanilla 3DGS checkpoint
            CONSOLE.print(f"\nLoading the vanilla 3DGS model from path {self.hparams.gaussians_ckpt_path}...")
            gaussians = GaussianModel(self.hparams.max_sh_degree - 1) # SuGaR: 4, vanilla 3DGS: 3
            gaussians.load_ply(self.hparams.gaussians_ckpt_path)
        else:
            # fail fast on unsupported checkpoint formats instead of hitting a NameError below
            raise ValueError(f"Unsupported gaussians checkpoint format: {self.hparams.gaussians_ckpt_path}")
        self.gaussians = gaussians
def render_scene(self, skip_render_3DGS=False):
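        '''
        Full rendering pass: run the Blender effects pipeline, re-render the scene with 3DGS unless
        skipped (and always when rigid-body transforms or melting meshes were produced), then
        composite the Blender output with the 3DGS frames via blend_all.blend_frames().
        '''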
self.render_from_blender()
if (
not skip_render_3DGS or
self.rb_transform_info is not None or
os.path.exists(os.path.join(self.blender_output_dir, 'melting_meshes'))
):
self.render_from_3DGS(post_rendering=True)
blend_all.blend_frames(self.blender_output_dir, self.cfg_path)
def save_cfg(self, cfg, cfg_path):
with open(cfg_path, 'w') as f:
json.dump(cfg, f, indent=4)
def set_basic_blender_cfg(self):
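        '''Collect the scene-level settings (resolution, intrinsics, camera poses, scene mesh, render type, ...) consumed by blender/all_rendering.py into self.blender_cfg.'''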
new_cfg = {}
new_cfg['edit_text'] = self.hparams.edit_text
new_cfg['blender_cache_dir'] = self.blender_cache_dir
new_cfg['im_width'], new_cfg['im_height'] = self.cameras['img_wh']
new_cfg['K'] = self.cameras['K'].tolist()
new_cfg['c2w'] = self.cameras['c2w'].tolist()
new_cfg['scene_mesh_path'] = self.scene_mesh_path_for_blender
new_cfg['is_uv_mesh'] = self.hparams.is_uv_mesh
new_cfg['output_dir_name'] = self.hparams.blender_output_dir_name
new_cfg['render_type'] = self.hparams.render_type
new_cfg['num_frames'] = self.hparams.num_frames
new_cfg['anchor_frame_idx'] = self.anchor_frame_idx
new_cfg['emitter_mesh_path'] = self.hparams.emitter_mesh_path
new_cfg['is_indoor_scene'] = self.hparams.is_indoor_scene
new_cfg['waymo_scene'] = self.waymo_scene
self.blender_cfg.update(new_cfg)
def render_from_blender(self):
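        '''
        Write the Blender config (environment map, sun direction, inserted objects, fire/smoke,
        events), invoke Blender headlessly on blender/all_rendering.py, then reload the config to
        pick up any rigid-body transforms that the simulation wrote back.
        '''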
self.set_basic_blender_cfg()
hdr_env_map_path, sun_dir = self.render_global_env_map()
self.blender_cfg['global_env_map_path'] = hdr_env_map_path
self.blender_cfg['sun_dir'] = sun_dir.tolist() if sun_dir is not None else None
self.blender_cfg['insert_object_info'] = []
for obj in self.inserted_objects:
obj['pos'] = obj['pos'].tolist()
obj['rot'] = obj['rot'].tolist()
if obj['material'] is not None and obj['material']['rgb'] is not None:
obj['material']['rgb'] = obj['material']['rgb'].tolist()
if obj['animation'] is not None and obj['animation']['type'] == 'trajectory':
obj['animation']['points'] = [point.tolist() for point in obj['animation']['points']]
self.blender_cfg['insert_object_info'].append(obj)
self.blender_cfg['fire_objects'] = self.fire_objects
self.blender_cfg['smoke_objects'] = self.smoke_objects
self.blender_cfg['events'] = self.events
self.save_cfg(self.blender_cfg, self.cfg_path)
torch.cuda.empty_cache() # release gpu memory for blender
        os.system(f'{BLENDER_PATH} --background --python ./blender/all_rendering.py -- --input_config_path={self.cfg_path}')
# check if rigid body transform is added to the blender config
with open(self.cfg_path, 'r') as f:
self.blender_cfg = json.load(f)
if 'rb_transform' in self.blender_cfg:
self.rb_transform_info = self.blender_cfg['rb_transform']
# def render_local_env_map(self, origin):
# origin = torch.FloatTensor(origin)
# env_map_dir = os.path.join(self.results_dir, 'panorama', str(math.floor(time.time()))) # use current timestamp as the name of the env map
# ldr_env_map_path = render_panorama(self.gaussians, self.pipe, self.background, origin, env_map_dir)
# ldr_env_map_path = inpaint_img(ldr_env_map_path)
# hdr_env_map_path = convert_ldr2hdr(ldr_env_map_path)
# return hdr_env_map_path
def render_global_env_map(self):
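        '''
        Estimate an HDR environment map from the anchor frame via get_envmap_from_single_view(),
        reusing a cached .exr when one already exists. For Waymo scenes, additionally estimate the
        sunlight direction from a low-exposure panorama image.
        '''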
assert self.anchor_frame_idx is not None # anchor frame index must be specified
image_path = os.path.join(self.traj_results_dir, 'images', '{0:05d}.png'.format(self.anchor_frame_idx))
output_dir = os.path.join(self.results_dir, 'hdr', self.hparams.custom_traj_name)
c2w = self.cameras['c2w'][self.anchor_frame_idx]
hdr_env_map_path = os.path.join(output_dir, '{0:05d}_rotate.exr'.format(self.anchor_frame_idx))
if not os.path.exists(hdr_env_map_path):
hdr_env_map_path = get_envmap_from_single_view(image_path, output_dir, c2w)
else:
print('HDR environment map already exists, skip rendering...')
# TODO: get the sunlight direction for waymo scenes
sun_dir = None
if self.waymo_scene:
ev_image_path = os.path.join(output_dir, 'envmap', '{0:05d}_ev-50.png'.format(self.anchor_frame_idx))
sun_dir = self.get_sunlight_direction(ev_image_path, c2w)
print('Sunlight direction: ', sun_dir)
return hdr_env_map_path, sun_dir
def get_sunlight_direction(self, img_path, c2w):
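        '''
        Estimate the sun direction from an equirectangular (panoramic) image: take the brightest
        pixel, convert its pixel coordinates to spherical angles and then to a unit vector, rotate
        that vector into world coordinates with c2w, and negate it so the result is the direction
        in which the sunlight travels.
        '''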
image = Image.open(img_path).convert('L')
# image = image.filter(ImageFilter.GaussianBlur(3))
image = np.array(image)
max_index = np.unravel_index(np.argmax(image), image.shape) # Find the index of the maximum intensity value
y, x = max_index # max_index will contain the (y, x) coordinates of the pixel with the highest intensity
h, w = image.shape
theta = (x / w) * 2 * np.pi # convert to spherical coordinates
phi = (y / h) * np.pi
x = np.sin(phi) * np.cos(theta)
y = np.sin(phi) * np.sin(theta)
z = np.cos(phi)
dir_vector = np.array([x, y, z])
dir_vector = dir_vector / np.linalg.norm(dir_vector)
dir_vector = c2w[:3, :3] @ dir_vector # rotate the direction vector to the world coordinate
dir_vector = dir_vector / np.linalg.norm(dir_vector)
dir_vector = -dir_vector
return dir_vector
def render_from_3DGS(self, render_video=False, post_rendering=False):
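        '''
        Render RGB, depth, and normal maps for every camera with the 3DGS renderer. For SINGLE_VIEW
        post-rendering, the anchor camera is repeated for every frame. If Blender produced rigid-body
        transforms or melting meshes for inserted objects, the corresponding object Gaussians are
        transformed / filtered per frame and merged into the scene Gaussians before rendering.
        Optionally assembles the frames into videos.
        '''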
self.load_scene() # reload the scene to get the latest gaussians
camera_views = self.cameras['cameras'] # a list of Camera objects
if post_rendering and self.hparams.render_type == 'SINGLE_VIEW':
camera_views = [copy.deepcopy(self.cameras['cameras'][self.anchor_frame_idx]) for _ in range(self.total_frames)]
for cam_idx, view in enumerate(camera_views):
camera_views[cam_idx].image_name = '{0:05d}'.format(cam_idx)
render_path = os.path.join(self.traj_results_dir, "images")
os.makedirs(render_path, exist_ok=True)
depth_path = os.path.join(self.traj_results_dir, "depth")
os.makedirs(depth_path, exist_ok=True)
normal_path = os.path.join(self.traj_results_dir, "normal")
os.makedirs(normal_path, exist_ok=True)
with torch.no_grad():
for idx, view in tqdm(enumerate(camera_views), desc="Rendering progress"):
if self.rb_transform_info is not None:
all_gaussians = copy.deepcopy(self.gaussians)
for obj_id, obj_rb_info in self.rb_transform_info.items():
if "{0:03d}".format(idx + 1) not in obj_rb_info:
continue
rb_transform = obj_rb_info["{0:03d}".format(idx + 1)] # frame index starts from 001
obj_info = [obj for obj in self.blender_cfg['insert_object_info'] if obj['object_id'] == obj_id][0]
obj_gaussians_path = os.path.join('/'.join(obj_info['object_path'].split('/')[:-2]), 'object_gaussians.ply')
center = torch.Tensor(rb_transform['pos']).cuda()
rotation = torch.Tensor(rb_transform['rot']).cuda()
scaling = rb_transform['scale']
initial_center = torch.Tensor(get_center_of_mesh_2(obj_info['object_path'])).cuda()
object_gaussians = load_gaussians(obj_gaussians_path, self.hparams.max_sh_degree - 1)
transformed_gaussians = transform_gaussians(object_gaussians, center, rotation, scaling, initial_center)
all_gaussians = merge_two_gaussians(all_gaussians, transformed_gaussians)
elif os.path.exists(os.path.join(self.blender_cache_dir, self.hparams.blender_output_dir_name, 'melting_meshes')):
all_gaussians = copy.deepcopy(self.gaussians)
mesh_output_dir = os.path.join(self.blender_cache_dir, self.hparams.blender_output_dir_name, 'melting_meshes')
for obj_id in sorted(os.listdir(mesh_output_dir)):
melting_mesh_dir = os.path.join(mesh_output_dir, obj_id)
obj_info = [
obj for obj in self.blender_cfg['insert_object_info']
if obj['object_id'] == obj_id
][0]
orig_mesh_path = obj_info['object_path']
orig_gaussians_path = os.path.join('/'.join(orig_mesh_path.split('/')[:-2]), 'object_gaussians.ply')
orig_mesh = trimesh.load_mesh(orig_mesh_path)
orig_gaussians = load_gaussians(orig_gaussians_path, self.hparams.max_sh_degree - 1)
# associate closest triangle in the original mesh to each Gaussian center
orig_mesh_o3d = o3d.t.geometry.RaycastingScene()
orig_mesh_o3d.add_triangles(o3d.t.geometry.TriangleMesh.from_legacy(orig_mesh.as_open3d))
gaussians_xyz = orig_gaussians._xyz.detach().cpu().numpy()
ret_dict = orig_mesh_o3d.compute_closest_points(
o3d.core.Tensor.from_numpy(gaussians_xyz.astype(np.float32))
)
triangle_ids_from_gaussians = ret_dict['primitive_ids'].cpu().numpy()
# iterate over the melting meshes
melting_mesh_paths = [
os.path.join(melting_mesh_dir, '{0:03d}_obj.stl'.format(idx + 1)),
os.path.join(melting_mesh_dir, '{0:03d}_obj_dup.stl'.format(idx + 1))
]
for melting_mesh_path in melting_mesh_paths:
if not os.path.exists(melting_mesh_path):
continue
                            melting_mesh = trimesh.load_mesh(melting_mesh_path)  # has been seen to raise "ValueError: PLY is unexpected length!"
# melting_mesh = o3d.io.read_triangle_mesh(melting_mesh_path)
# associate closest triangle in the original mesh to each vertex in the melting mesh
ret_dict = orig_mesh_o3d.compute_closest_points(
o3d.core.Tensor.from_numpy(np.array(melting_mesh.triangles_center).astype(np.float32))
)
# ret_dict = orig_mesh_o3d.compute_closest_points(
# o3d.core.Tensor.from_numpy(np.array(melting_mesh.vertices).astype(np.float32))
# )
triangle_ids_from_melting = ret_dict['primitive_ids'].cpu().numpy()
# keep the Gaussians sharing the same closest triangle with the melting mesh
matching_gaussians_mask = np.isin(triangle_ids_from_gaussians, triangle_ids_from_melting)
# create new Gaussians and merge the new Gaussians with the existing ones
new_gaussians = copy.deepcopy(orig_gaussians)
new_gaussians._xyz = orig_gaussians._xyz[matching_gaussians_mask]
new_gaussians._features_dc = orig_gaussians._features_dc[matching_gaussians_mask]
new_gaussians._features_rest = orig_gaussians._features_rest[matching_gaussians_mask]
new_gaussians._scaling = orig_gaussians._scaling[matching_gaussians_mask]
new_gaussians._rotation = orig_gaussians._rotation[matching_gaussians_mask]
new_gaussians._opacity = orig_gaussians._opacity[matching_gaussians_mask]
all_gaussians = merge_two_gaussians(all_gaussians, new_gaussians)
else:
all_gaussians = self.gaussians
result = render(view, all_gaussians, self.pipe, self.background)
# rgb image
rgba_img = result["render"]
torchvision.utils.save_image(rgba_img, os.path.join(render_path, view.image_name + ".png"))
# depth map
depth_raw = result["depth"].cpu().numpy()
depth_raw = depth_raw.squeeze()
np.save(os.path.join(depth_path, view.image_name + ".npy"), depth_raw.astype(np.float32))
depth_img = depth2img(depth_raw, scale=3.0)
cv2.imwrite(os.path.join(depth_path, view.image_name + ".png"), depth_img)
# normal map
normal = result["normal"].cpu().numpy()
normal = (normal + 1) / 2
normal = (normal * 255).astype(np.uint8)
cv2.imwrite(os.path.join(normal_path, view.image_name + ".png"), cv2.cvtColor(normal, cv2.COLOR_RGB2BGR))
# generate video from frames
if render_video:
rgb_frames_path = sorted(glob.glob(os.path.join(render_path, '*.png')))
generate_video_from_frames(rgb_frames_path, os.path.join(self.traj_results_dir, 'render_rgb.mp4'), fps=15)
depth_frames_path = sorted(glob.glob(os.path.join(depth_path, '*.png')))
generate_video_from_frames(depth_frames_path, os.path.join(self.traj_results_dir, 'render_depth.mp4'), fps=15)
normal_frames_path = sorted(glob.glob(os.path.join(normal_path, '*.png')))
generate_video_from_frames(normal_frames_path, os.path.join(self.traj_results_dir, 'render_normal.mp4'), fps=15)
def training_3DGS_for_inpainting(self, gaussians_path, image_dir, mask_dir, output_dir, transforms_path):
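        '''
        Fine-tune a 3DGS model so that inpainted regions become multi-view consistent: load the
        Gaussians and the inpainted training views, then optimize with L1 + D-SSIM when no (or only
        a small) inpainting mask exists, and with masked L1 + LPIPS inside large masks, densifying
        and pruning periodically. The result is saved to inpaint_gaussians.ply in output_dir.
        '''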
gaussians = GaussianModel(self.hparams.max_sh_degree - 1)
gaussians.load_ply(gaussians_path)
opt = OptimizationParams()
pipe = PipelineParams()
gaussians.training_setup(opt)
gaussians.max_radii2D = torch.zeros((gaussians.get_xyz.shape[0]), device="cuda")
# get training cameras
cameraList = []
with open(transforms_path, 'r') as f:
transforms = json.load(f)
fx, fy, cx, cy = transforms["fl_x"], transforms["fl_y"], transforms["cx"], transforms["cy"]
w, h = transforms["w"], transforms["h"]
for idx, info in tqdm(enumerate(transforms["frames"]), desc="Loading custom cameras"):
filename = info["filename"]
c2w = np.array(info["transform_matrix"])
w2c = np.linalg.inv(np.array(c2w))
R = np.transpose(w2c[:3,:3])
T = w2c[:3, 3]
FovY = focal2fov(fy, h)
FovX = focal2fov(fx, w)
image = Image.open(os.path.join(image_dir, filename))
image = PILtoTorch(image, (w, h))
view = GSCamera(
colmap_id=idx, R=R, T=T,
FoVx=FovX, FoVy=FovY, image=image, gt_alpha_mask=None,
image_name=filename, uid=idx,
image_height=h, image_width=w)
cameraList.append(view)
LPIPS = init_lpips_model()
viewpoint_stack = None
first_iter = 0
        last_iter = 2000 # shortened from the original 5000 to prevent overfitting
for iteration in tqdm(range(first_iter, last_iter + 1), desc="Re-training for inpainting progress"):
gaussians.update_learning_rate(iteration)
# Pick a random Camera
if not viewpoint_stack:
viewpoint_stack = cameraList.copy()
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
result = render(viewpoint_cam, gaussians, pipe, self.background)
image, viewspace_point_tensor, visibility_filter, radii = \
result["render"], \
result["viewspace_points"], \
result["visibility_filter"], \
result["radii"]
# get the boolean mask
mask2d = None
mask2d_path = os.path.join(mask_dir, viewpoint_cam.image_name)
if os.path.exists(mask2d_path):
mask2d = Image.open(mask2d_path)
mask2d = torch.from_numpy(np.array(mask2d) / 255.0).unsqueeze(0).cuda()
mask2d = mask2d.repeat(4, 1, 1)
mask2d = mask2d > 0.0
            # RGB Loss and LPIPS Loss (adapted from Gaussian Grouping)
gt_image = viewpoint_cam.original_image.cuda()
if mask2d is None or not is_large_mask(mask2d): # use L1-loss if no mask provided or the mask is not large enough
loss_rgb = l1_loss(image, gt_image)
loss = (1.0 - opt.lambda_dssim) * loss_rgb + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
else:
loss_rgb = l1_loss(image[~mask2d], gt_image[~mask2d])
loss_lpips = compute_lpips_loss(LPIPS, image[:3, ...], gt_image[:3, ...], mask2d[0, ...])
loss = (1.0 - opt.lambda_dssim) * loss_rgb + opt.lambda_dssim * loss_lpips
loss.backward()
with torch.no_grad():
# Densification
if iteration < 5000:
# Keep track of max radii in image-space for pruning
gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
if iteration % 300 == 0:
size_threshold = 20
min_opacity = 0.1 # 0.005 would create floaters due to multi-view inconsistency in inpainting
gaussians.densify_and_prune(opt.densify_grad_threshold, min_opacity, 1.1, size_threshold) # 1.1 since we normalize the camera in sdf rendering (1 * 1.1)
# gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
# if iteration % opt.opacity_reset_interval == 0:
# gaussians.reset_opacity()
# Optimizer step
gaussians.optimizer.step()
gaussians.optimizer.zero_grad(set_to_none = True)
# save the gaussians to .ply file
gaussians.save_ply(os.path.join(output_dir, 'inpaint_gaussians.ply'))
if __name__ == '__main__':
hparams = get_opts()
scene_representation = SceneRepresentation(hparams)
##### Test rendering from Blender #####
# scene_representation.render_scene()
# with open(scene_representation.cfg_path, 'r') as f:
# scene_representation.blender_cfg = json.load(f)
# scene_representation.rb_transform_info = scene_representation.blender_cfg['rb_transform']
# scene_representation.render_scene()
##### Test rendering from 3DGS #####
scene_representation.load_scene()
scene_representation.render_from_3DGS(render_video=True)
##### Pre-render all environment map #####
# scene_representation.render_global_env_map()
##### Estimate scene scale #####
# scene_representation.estimate_scene_scale()
##### Test mesh extraction #####
# TEXT_PROMPT = 'bulldozer'
# # scene_representation.render_from_3DGS()
# # from tracking.demo_with_text import run_deva
# # run_deva(os.path.join(scene_representation.traj_results_dir, 'images'), scene_representation.tracking_results_dir, TEXT_PROMPT, scene_representation.DINO_THRESHOLD)
# from extract.extract_object import extract_object_from_scene, inpaint_object
# id = str([x for x in os.listdir(os.path.join(scene_representation.tracking_results_dir, '_'.join(TEXT_PROMPT.split(' ')))) if x.isdigit()][0])
# # extract_object_from_scene(scene_representation, TEXT_PROMPT, id)
# inpaint_object(scene_representation, TEXT_PROMPT, id)
# save_dir = os.path.join(scene_representation.results_dir, 'object_instance', scene_representation.custom_traj_name, '_'.join(TEXT_PROMPT.split(' ')), id)
# scene_representation.training_3DGS_for_inpainting(
# os.path.join(save_dir, 'removal_gaussians.ply'),
# os.path.join(save_dir, 'render_inpaint_lama'),
# os.path.join(save_dir, 'render_inpaint_mask'),
# save_dir,
# os.path.join(save_dir, 'inpaint_camera_poses.json')
# )
# gaussians = GaussianModel(scene_representation.hparams.max_sh_degree - 1)
# gaussians.load_ply(os.path.join(save_dir, 'inpaint_gaussians.ply'))
# scene_representation.gaussians = gaussians
# scene_representation.render_from_3DGS(render_video=True)
##### Test rigid body simulation of existing objects in the scene #####
# object_gaussians_path = 'output/garden_norm_aniso_0.1_pseudo_normal_0.01_alpha_0.0/object_instance/vase_with_flowers/18040383/object_gaussians.ply'
# object_gaussians = load_gaussians(object_gaussians_path, scene_representation.hparams.max_sh_degree - 1)
# rb_transform_file_path = 'output/garden_norm_aniso_0.1_pseudo_normal_0.01_alpha_0.0/custom_camera_path/transforms_001/blend_results_vase_drop/rb_transform.json'
# with open(rb_transform_file_path, 'r') as f:
# rb_transform = json.load(f)
# scene_representation.rb_transform_info = rb_transform
# scene_representation.object_mesh_path = 'output/garden_norm_aniso_0.1_pseudo_normal_0.01_alpha_0.0/object_instance/vase_with_flowers/18040383/object_mesh/object_mesh.obj'
# scene_representation.object_gaussians = object_gaussians
# scene_representation.render_from_3DGS()