From 7e3c8a77855d466ec33708d44eba11c7a4097786 Mon Sep 17 00:00:00 2001 From: sheng Date: Fri, 22 Dec 2023 18:29:35 +0800 Subject: [PATCH] fix camera intrisics --- scene/cameras.py | 8 ++++++-- scene/dataset_readers.py | 11 ++++++++++- utils/camera_utils.py | 3 ++- utils/graphics_utils.py | 6 +++--- 4 files changed, 21 insertions(+), 7 deletions(-) diff --git a/scene/cameras.py b/scene/cameras.py index abf6e5242..88c40a602 100644 --- a/scene/cameras.py +++ b/scene/cameras.py @@ -15,7 +15,7 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix class Camera(nn.Module): - def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, + def __init__(self, colmap_id, R, T, FoVx, FoVy, cx, cy, image, gt_alpha_mask, image_name, uid, trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" ): @@ -27,6 +27,8 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, self.T = T self.FoVx = FoVx self.FoVy = FoVy + self.cx = cx + self.cy = cy self.image_name = image_name try: @@ -52,7 +54,9 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, self.scale = scale self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() - self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() + self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, + fovX=self.FoVx, fovY=self.FoVy, + cx=self.cx, cy=self.cy).transpose(0, 1).cuda() self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) self.camera_center = self.world_view_transform.inverse()[3, :3] diff --git a/scene/dataset_readers.py b/scene/dataset_readers.py index 2a6f904a9..3ae967302 100644 --- a/scene/dataset_readers.py +++ b/scene/dataset_readers.py @@ -29,6 +29,8 @@ class CameraInfo(NamedTuple): T: np.array FovY: np.array FovX: np.array + cx: np.array + cy: np.array image: np.array image_path: str image_name: str @@ -84,9 +86,13 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): if intr.model=="SIMPLE_PINHOLE": focal_length_x = intr.params[0] + cx = intr.params[1] + cy = intr.params[2] FovY = focal2fov(focal_length_x, height) FovX = focal2fov(focal_length_x, width) elif intr.model=="PINHOLE": + cx = intr.params[2] + cy = intr.params[3] focal_length_x = intr.params[0] focal_length_y = intr.params[1] FovY = focal2fov(focal_length_y, height) @@ -94,11 +100,14 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): else: assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" + cx = (cx - width / 2) / width * 2 + cy = (cy - height / 2) / height * 2 + image_path = os.path.join(images_folder, os.path.basename(extr.name)) image_name = os.path.basename(image_path).split(".")[0] image = Image.open(image_path) - cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, + cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, cx=cx, cy=cy, image=image, image_path=image_path, image_name=image_name, width=width, height=height) cam_infos.append(cam_info) sys.stdout.write('\n') diff --git a/utils/camera_utils.py b/utils/camera_utils.py index 1a54d0ada..d80b084f3 100644 --- a/utils/camera_utils.py +++ b/utils/camera_utils.py @@ -47,7 +47,8 @@ def loadCam(args, id, cam_info, resolution_scale): loaded_mask = resized_image_rgb[3:4, ...] return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, - FoVx=cam_info.FovX, FoVy=cam_info.FovY, + FoVx=cam_info.FovX, FoVy=cam_info.FovY, + cx=cam_info.cx, cy=cam_info.cy, image=gt_image, gt_alpha_mask=loaded_mask, image_name=cam_info.image_name, uid=id, data_device=args.data_device) diff --git a/utils/graphics_utils.py b/utils/graphics_utils.py index b4627d837..7ac05e04a 100644 --- a/utils/graphics_utils.py +++ b/utils/graphics_utils.py @@ -48,7 +48,7 @@ def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): Rt = np.linalg.inv(C2W) return np.float32(Rt) -def getProjectionMatrix(znear, zfar, fovX, fovY): +def getProjectionMatrix(znear, zfar, fovX, fovY, cx, cy): tanHalfFovY = math.tan((fovY / 2)) tanHalfFovX = math.tan((fovX / 2)) @@ -63,8 +63,8 @@ def getProjectionMatrix(znear, zfar, fovX, fovY): P[0, 0] = 2.0 * znear / (right - left) P[1, 1] = 2.0 * znear / (top - bottom) - P[0, 2] = (right + left) / (right - left) - P[1, 2] = (top + bottom) / (top - bottom) + P[0, 2] = cx + P[1, 2] = cy P[3, 2] = z_sign P[2, 2] = z_sign * zfar / (zfar - znear) P[2, 3] = -(zfar * znear) / (zfar - znear)