-
Notifications
You must be signed in to change notification settings - Fork 0
/
datasets.py
138 lines (104 loc) · 5.22 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch
import torchvision
from torch.utils.data import Dataset
import torch.nn.functional as F
import utils
import pathlib
import trimesh
import numpy as np
from mesh_to_sdf import sample_sdf_near_surface
import minexr
class SDFDataset(Dataset):
    """
    Signed-distance samples drawn from a voxel grid written by SDFGen.

    Obtained from https://github.com/mrharicot
    Reads files generated by https://github.com/christopherbatty/SDFGen

    File layout parsed here: a line with the grid size ``ni nj nk``, a line
    with the world-space origin, a line with the voxel spacing ``dx``, then
    one SDF value per line with the first (x) index varying fastest.

    All samples are generated once, at construction:

    * every voxel whose |SDF| is below ``surface_threshold`` is repeated
      ``surface_oversample`` times;
    * as many voxels as there are near-surface voxels are drawn uniformly from
      the whole grid;
    * all of these voxel positions are jittered with Gaussian noise of
      standard deviation ``noise_std`` (in voxels);
    * the SDF is trilinearly interpolated at the jittered positions, with
      positions outside the grid clamped to its border.

    Args:
        sdf_path: path to the SDFGen text file.
        tsdf_min, tsdf_max: the grid values are clamped to this range.
        surface_threshold: |SDF| below which a voxel counts as near-surface.
        surface_oversample: number of copies of each near-surface voxel.
        noise_std: standard deviation of the positional jitter, in voxels.

    Attributes:
        voxel_coords_hwd3: world position of every voxel, shape (nk, nj, ni, 3),
            last axis in (x, y, z) order.
        sdfs_hwd1: clamped SDF grid, shape (nk, nj, ni, 1).
        random_coord_samples_N3: world-space sample positions, shape (N, 3).
        random_sdf_samples_N1: interpolated SDF at those positions, shape (N, 1).

    Raises:
        ValueError: if no voxel is within ``surface_threshold`` of the surface.
    """

    def __init__(self, sdf_path, tsdf_min=-1.0, tsdf_max=1.0,
                 surface_threshold=0.003, surface_oversample=10, noise_std=1.0):
        super().__init__()
        volume_size, origin, dx, sdf_values = self._read_sdfgen(sdf_path)

        # Values are stored x-fastest, so reshaping to (nk, nj, ni) puts z on
        # dim 0 and x on dim 2; the trailing axis is a singleton channel.
        sdfs_hwd1 = sdf_values.reshape(volume_size[::-1])[..., None]
        # clamp TSDF
        sdfs_hwd1.clamp_(tsdf_min, tsdf_max)

        # Integer voxel coordinates in (x, y, z) order for every grid cell.
        depth, height, width = sdfs_hwd1.shape[:3]
        zz, yy, xx = torch.meshgrid(torch.arange(depth),
                                    torch.arange(height),
                                    torch.arange(width),
                                    indexing="ij")
        voxel_int_coords_hwd3 = torch.stack([xx, yy, zz], dim=-1).float()
        voxel_int_coords_N3 = voxel_int_coords_hwd3.reshape(-1, 3)
        self.voxel_coords_hwd3 = origin[None, None, None] + dx * voxel_int_coords_hwd3
        self.sdfs_hwd1 = sdfs_hwd1

        # Find the voxels close to the surface.
        surface_voxel_ids_N = torch.where(sdfs_hwd1.flatten().abs() < surface_threshold)[0]
        if surface_voxel_ids_N.numel() == 0:
            raise ValueError(
                f"No voxel in {sdf_path} has |SDF| < {surface_threshold}; cannot draw samples."
            )
        surface_voxel_coords_N3 = voxel_int_coords_N3[surface_voxel_ids_N]

        # Oversampled surface voxels plus an equal number of uniformly random
        # voxels, then jittered. (randint before randn_like, as before.)
        random_surface_coords_N3 = surface_voxel_coords_N3.repeat(surface_oversample, 1)
        random_ids_N = torch.randint(voxel_int_coords_N3.shape[0],
                                     (surface_voxel_coords_N3.shape[0],))
        random_coords_N3 = torch.cat([random_surface_coords_N3,
                                      voxel_int_coords_N3[random_ids_N]], dim=0)
        random_coords_N3 += torch.randn_like(random_coords_N3) * noise_std

        # With align_corners=False, grid_sample maps normalized -1/+1 to the
        # outer edges of the corner voxels, so voxel index i lies at
        # (2i + 1) / size - 1. Using 2i / size - 1 would read every value half
        # a voxel too low on each axis.
        sizes_xyz = torch.tensor([width, height, depth], dtype=random_coords_N3.dtype)
        random_uv_coords_N3 = (2 * random_coords_N3 + 1) / sizes_xyz - 1
        self.random_sdf_samples_N1 = F.grid_sample(
            # (1, 1, D, H, W); [..., 0] drops only the channel axis, so grids
            # with a size-1 spatial dimension keep their 5-D shape.
            sdfs_hwd1[..., 0][None, None],
            # (1, 1, 1, N, 3) with the last axis in (x, y, z) order.
            random_uv_coords_N3[None, None, None],
            align_corners=False,
            padding_mode="border",
        ).view(-1, 1)
        self.random_coord_samples_N3 = origin[None] + dx * random_coords_N3

    @staticmethod
    def _read_sdfgen(sdf_path):
        """Parse an SDFGen text file into (volume_size, origin, dx, flat value tensor)."""
        with open(sdf_path, "r") as f:
            # Skip blank lines (e.g. trailing newlines) that float() would reject.
            lines = [line for line in f.read().splitlines() if line.strip()]
        size_line, origin_line, dx_line, *value_lines = lines
        volume_size = [int(tok) for tok in size_line.split()]
        origin = torch.tensor([float(tok) for tok in origin_line.split()])
        dx = float(dx_line)
        values = torch.tensor([float(v) for v in value_lines])
        return volume_size, origin, dx, values

    def __len__(self):
        return self.random_sdf_samples_N1.shape[0]

    def __getitem__(self, idx):
        # Returns (world-space position (3,), interpolated SDF (1,)).
        return self.random_coord_samples_N3[idx], self.random_sdf_samples_N1[idx]
class OBJDataset(Dataset):
    """
    SDF samples taken near the surface of a mesh via ``mesh_to_sdf``.

    Args:
        obj_filename: path to the mesh file; loaded with ``trimesh.load``.
        voxel_resolution: points per axis of the regular grid stored in
            ``voxel_coords_hwd3`` (which spans [-1, 1] on every axis).
        num_samples: number of points requested from
            ``sample_sdf_near_surface``.
        tsdf_min, tsdf_max: the sampled SDF values are clamped to this range.

    Items are ``(point, sdf)`` pairs taken from ``points_N3`` and
    ``sdf_values_N`` at the same index.
    """

    def __init__(self, obj_filename, voxel_resolution=256, num_samples=2**18, tsdf_min=-1.0, tsdf_max=1.0):
        mesh = trimesh.load(str(pathlib.Path(obj_filename)))

        # samples are inside the unit sphere
        sample_points, sample_sdfs = sample_sdf_near_surface(mesh, number_of_points=num_samples)
        self.points_N3 = torch.from_numpy(sample_points)
        # clamp_ works in place on the tensor that shares memory with sample_sdfs.
        self.sdf_values_N = torch.from_numpy(sample_sdfs).clamp_(tsdf_min, tsdf_max)

        # Regular lattice over [-1, 1]^3. With "ij" indexing the meshgrid
        # outputs vary along (dim0, dim1, dim2) = (z, y, x); stacking them as
        # (x, y, z) on the last axis gives shape (R, R, R, 3).
        axis = torch.linspace(-1.0, 1.0, voxel_resolution)
        grid_z, grid_y, grid_x = torch.meshgrid(axis, axis, axis, indexing="ij")
        self.voxel_coords_hwd3 = torch.stack((grid_x, grid_y, grid_z), dim=-1)

    def __len__(self):
        return self.points_N3.size(0)

    def __getitem__(self, idx):
        point = self.points_N3[idx]
        sdf = self.sdf_values_N[idx]
        return point, sdf
class EXRDataset(Dataset):
    """
    Per-pixel (coordinate, colour) pairs from the ``Color.R/G/B`` channels of an EXR file.

    Each item is ``(coord_2, rgb_3)``:

    * ``coord_2`` is the pixel's ``(x, y)`` position normalized to [0, 1]
      (the first column/row maps to 0, the last to 1);
    * ``rgb_3`` is that pixel's colour as read from the file.

    Pixels are enumerated in row-major order.
    """

    def __init__(self, exr_filename):
        super().__init__()
        with open(exr_filename, 'rb') as exr_file:
            exr_img = minexr.load(exr_file)
            # NOTE(review): assumes select() yields an (H, W, 3) array — confirm against minexr.
            rgb_hw3 = exr_img.select(['Color.R', 'Color.G', 'Color.B'])
        self._set_image(torch.from_numpy(rgb_hw3))

    def _set_image(self, rgb_hw3):
        """Store an (H, W, 3) image tensor and build the flat coordinate/colour tables."""
        self.rgb_hw3 = rgb_hw3
        height, width = rgb_hw3.shape[:2]
        yy, xx = torch.meshgrid(
            torch.linspace(0.0, 1.0, height),
            torch.linspace(0.0, 1.0, width),
            indexing="ij"
        )
        # Stack on the last axis so each row of the (H*W, 2) table is one
        # pixel's (x, y) pair. Stacking on axis 0 and reshaping would instead
        # pair up neighbouring x values.
        self.rgb_indices_N2 = torch.stack([xx, yy], dim=-1).reshape(-1, 2)
        # Flat colour table addressed by the same integer index.
        self.rgb_N3 = rgb_hw3.reshape(-1, 3)

    def __len__(self):
        return self.rgb_indices_N2.shape[0]

    def __getitem__(self, idx):
        # Look the colour up by the integer item index. The normalized float
        # coordinates cannot be used as tensor indices.
        return self.rgb_indices_N2[idx], self.rgb_N3[idx]
class ImageDataset(Dataset):
    """
    Per-pixel (coordinate, colour) pairs from an image file decoded by torchvision.

    Each item is ``(coord_2, rgb_3)``:

    * ``coord_2`` is the pixel's ``(y, x)`` position normalized to [0, 1]
      (the first row/column maps to 0, the last to 1);
    * ``rgb_3`` is its 8-bit colour scaled to [0, 1].

    Pixels are enumerated in row-major order.
    """

    def __init__(self, img_filename):
        super().__init__()
        # Force 3-channel decoding. Grayscale or RGBA files would otherwise
        # decode to 1 or 4 channels and break the (-1, 3) colour table.
        image_chw = torchvision.io.read_image(
            str(img_filename), mode=torchvision.io.ImageReadMode.RGB
        )
        self._set_image(image_chw.permute(1, 2, 0).float() / 255.0)

    def _set_image(self, rgb_hw3):
        """Store an (H, W, 3) float image and build the flat coordinate/colour tables."""
        self.rgb_hw3 = rgb_hw3
        height, width = rgb_hw3.shape[:2]
        yy, xx = torch.meshgrid(
            torch.linspace(0.0, 1.0, height),
            torch.linspace(0.0, 1.0, width),
            indexing="ij"
        )
        # One (y, x) pair per pixel, row-major, aligned with rgb_N3.
        self.rgb_indices_N2 = torch.stack([yy, xx], dim=-1).reshape(-1, 2)
        self.rgb_N3 = rgb_hw3.reshape(-1, 3)

    def __len__(self):
        return self.rgb_indices_N2.shape[0]

    def __getitem__(self, idx):
        return self.rgb_indices_N2[idx], self.rgb_N3[idx]