-
Notifications
You must be signed in to change notification settings - Fork 14
/
data_loader.py
95 lines (90 loc) · 3.92 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
import torch.nn as nn
import numpy as np
from nerf_utils.nerf import cumprod_exclusive, get_minibatches, get_ray_bundle, positional_encoding
from nerf_utils.tiny_nerf import VeryTinyNerfModel
from torchvision.datasets import mnist
from torchvision import transforms
import Lenet5
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from copy import deepcopy
def wrapper_dataset(config, args, device):
    """Build the train/test sample lists and the model for ``args.datatype``.

    Supported datatypes:

    * ``'tinynerf'`` -- loads the ``.npz`` archive at ``args.data_train_path``
      (entries ``images``, ``poses``, ``focal``), uses images 0-99 as
      training samples and image 101 as the single held-out test sample, and
      builds a ``VeryTinyNerfModel``.
    * ``'mnist'`` -- loads MNIST through torchvision (``download=True``) and
      builds a ``Lenet5.NetOriginal`` model.

    Args:
        config: Unused; kept for interface compatibility with callers.
        args: Namespace-like object. ``args.datatype`` selects the dataset;
            the ``'tinynerf'`` branch also reads ``args.data_train_path``.
        device: Torch device the NeRF tensors are moved to (the MNIST branch
            does not use it).

    Returns:
        ``(train_ds, test_ds, model)``: ``train_ds`` and ``test_ds`` are lists
        of dicts that each carry at least ``'input'`` and ``'output'`` keys.

    Raises:
        NotImplementedError: if ``args.datatype`` is not supported.
    """
    if args.datatype == 'tinynerf':
        return _build_tinynerf(args, device)
    if args.datatype == 'mnist':
        return _build_mnist()
    # Previously this fell through to the final ``return`` with unbound
    # locals (UnboundLocalError); fail with an explicit message instead.
    raise NotImplementedError(
        f"datatype {args.datatype!r} is not supported; implement on your own")


def _build_tinynerf(args, device):
    """Load the tiny-NeRF archive and return ``(train_ds, test_ds, model)``."""
    # ``with`` closes the .npz file handle once the arrays have been read.
    with np.load(args.data_train_path) as data:
        images = data["images"]
        poses = data["poses"]
        focal = data["focal"]

    # Camera extrinsics (poses) and focal length (intrinsics).
    tform_cam2world = torch.from_numpy(poses).to(device)
    focal_length = torch.from_numpy(focal).to(device)
    # Height and width of each image.
    height, width = images.shape[1:3]
    # Near and far clipping thresholds for depth values.
    near_thresh = 2.0
    far_thresh = 6.0
    # Hold image 101 out for testing; keep only the first three channels so
    # it matches the training images below.
    testimg = torch.from_numpy(images[101, ..., :3]).to(device)
    testpose = tform_cam2world[101]
    # Images 0-99 form the training set.
    train_images = torch.from_numpy(images[:100, ..., :3]).to(device)

    num_encoding_functions = 10
    model = VeryTinyNerfModel(num_encoding_functions=num_encoding_functions)

    # Fields shared by every sample (same keys and order as before).
    shared = {
        'height': height,
        'width': width,
        'focal_length': focal_length,
        'testpose': testpose,
        'near_thresh': near_thresh,
        'far_thresh': far_thresh,
        # Number of depth samples along each ray.
        'depth_samples_per_ray': 32,
        # Encoding function applied to the queried points.
        'encode': positional_encoding,
        'get_minibatches': get_minibatches,
        # Chunksize (Note: this isn't batchsize in the conventional sense.
        # This only specifies the number of rays to be queried in one go.
        # Backprop still happens only after all rays from the current
        # "bundle" are queried and rendered). Use chunksize of about 4096 to
        # fit in ~1.4 GB of GPU memory (when using 8 samples per ray).
        'chunksize': 4096,
        'num_encoding_functions': num_encoding_functions,
    }

    # ``clone()`` gives each sample its own copy of just its row. The old
    # ``deepcopy(batch)`` duplicated the *entire* storage behind every view
    # (all training images and all poses) once per sample.
    train_ds = [
        {**shared, 'input': pose.clone(), 'output': img.clone()}
        for img, pose in zip(train_images, tform_cam2world)
    ]
    test_ds = [{**shared, 'input': testpose, 'output': testimg}]
    return train_ds, test_ds, model


def _build_mnist():
    """Load MNIST and return ``(train_ds, test_ds, model)``."""
    model = Lenet5.NetOriginal()
    # NOTE(review): raw string keeps the exact original value (the
    # backslashes are literal characters). On POSIX this is a relative
    # directory literally named "\data\mnist" -- confirm the intended path.
    root = r"\data\mnist"
    train_dataset = mnist.MNIST(
        root, train=True, download=True, transform=ToTensor())
    test_dataset = mnist.MNIST(
        root, train=False, download=True, transform=ToTensor())
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1)
    # Train is materialized before test, as in the original.
    train_ds = _loader_to_samples(train_loader)
    test_ds = _loader_to_samples(test_loader)
    return train_ds, test_ds, model


def _loader_to_samples(loader):
    """Materialize a DataLoader into a list of ``{'input', 'output'}`` dicts."""
    samples = []
    for inputs, labels in loader:
        # Keep only channel 0 of the rank-4 batch while restoring the channel
        # axis (presumably (B, C, H, W) -> (B, 1, H, W); TODO confirm).
        # Each loader batch is a fresh tensor, so no defensive copy is needed.
        samples.append({'input': inputs[:, 0, :, :].unsqueeze(1),
                        'output': labels})
    return samples