Merge pull request #163 from Rufaim/adding_ninapro_and_darcyflow
Adding ninapro and darcyflow datasets
Showing 11 changed files with 294 additions and 45 deletions.
@@ -0,0 +1,41 @@
import os
import json
import argparse


def main(vals_buildings, test_buildings):
    all_tasks = []
    # Each sub-directory next to this script is one Taskonomy building.
    dirs = [f.path for f in os.scandir(os.path.dirname(os.path.abspath(__file__))) if f.is_dir()]
    for d in dirs:
        taskname = os.path.basename(d)
        # Build path templates from the rgb domain, replacing the "_rgb." tag with a "{domain}" placeholder.
        templates = [f"{taskname}/{{domain}}/" + os.path.basename(f.path).replace("_rgb.", "_{domain}.")
                     for f in os.scandir(os.path.join(d, "rgb")) if f.is_file()]
        templates = sorted(templates)
        with open(d + ".json", "w") as f:
            json.dump(templates, f)

        all_tasks.append(taskname)

    # Assign each building to a split: --test and --val take precedence, everything else goes to train.
    train_tasks = []
    val_tasks = []
    test_tasks = []
    for task in all_tasks:
        if task in test_buildings:
            test_tasks.append(task)
        elif task in vals_buildings:
            val_tasks.append(task)
        else:
            train_tasks.append(task)

    # d still refers to the last building directory, so foldername is the common parent (the script directory).
    foldername = os.path.dirname(d)
    for s, f in zip([train_tasks, val_tasks, test_tasks], ["train_split.json", "val_split.json", "test_split.json"]):
        with open(os.path.join(foldername, f), "w") as file:
            json.dump(s, file)


if __name__ == '__main__':
    parser = argparse.ArgumentParser("Taskonomy splits generator")
    parser.add_argument("--val", nargs="*", type=str, default=[])
    parser.add_argument("--test", nargs="+", type=str, default=["uvalda", "merom", "stockman"])
    args = parser.parse_args()

    main(args.val, args.test)
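For reference, a minimal sketch of how the generated split files could be consumed (the helper name load_split_templates, the data_root placeholder, and the depth_zbuffer example domain are illustrative assumptions, not part of this PR):

import json
import os

def load_split_templates(data_root, split_file="train_split.json", domain="rgb"):
    # Read the building names in the split, then expand each building's
    # template list for the requested domain (e.g. "rgb", "depth_zbuffer").
    with open(os.path.join(data_root, split_file)) as f:
        buildings = json.load(f)
    paths = []
    for building in buildings:
        with open(os.path.join(data_root, building + ".json")) as f:
            templates = json.load(f)
        paths.extend(t.format(domain=domain) for t in templates)
    return paths

paths = load_split_templates("path/to/taskonomy", "val_split.json", domain="rgb")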
@@ -0,0 +1,66 @@
import os
import numpy as np
import scipy.io
# "import torch.utils.data" also binds the top-level "torch" name used below.
import torch.utils.data
import torchvision.transforms


class UnitGaussianNormalizer(object):
    def __init__(self, x, eps=0.00001):
        super(UnitGaussianNormalizer, self).__init__()

        # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T
        self.mean = torch.mean(x, 0)
        self.std = torch.std(x, 0)
        self.eps = eps

    def __call__(self, x):
        # Per-location z-score normalization using the statistics computed in __init__.
        x = (x - self.mean) / (self.std + self.eps)
        return x


# adapted from https://github.com/rtu715/NAS-Bench-360/blob/0d1af0ce37b5f656d6491beee724488c3fccf871/perceiver-io/perceiver/data/nb360/darcyflow.py#L73
def load_darcyflow_data(path):
    train_path = os.path.join(path, "piececonst_r421_N1024_smooth1.mat")
    test_path = os.path.join(path, "piececonst_r421_N1024_smooth2.mat")

    # Subsample the 421x421 grid with stride r, giving an s x s grid (s = 85 for r = 5).
    r = 5
    s = int(((421 - 1) / r) + 1)

    x_train, y_train = read_mat(train_path, r, s)
    x_test, y_test = read_mat(test_path, r, s)

    # Normalize inputs and targets with statistics computed on the training set only.
    x_normalizer = UnitGaussianNormalizer(x_train)
    x_train = x_normalizer(x_train)
    x_test = x_normalizer(x_test)

    y_normalizer = UnitGaussianNormalizer(y_train)
    y_train = y_normalizer(y_train)
    y_test = y_normalizer(y_test)

    x_train = x_train.reshape((-1, s, s, 1))
    x_test = x_test.reshape((-1, s, s, 1))

    trainset = torch.utils.data.TensorDataset(x_train, y_train)
    testset = torch.utils.data.TensorDataset(x_test, y_test)

    return trainset, testset


def read_mat(file_path, r, s):
    data = scipy.io.loadmat(file_path)
    x = read_mat_field(data, "coeff")[:, ::r, ::r][:, :s, :s]
    y = read_mat_field(data, "sol")[:, ::r, ::r][:, :s, :s]
    del data
    return x, y


def read_mat_field(mat, field):
    x = mat[field]
    x = x.astype(np.float32)
    return torch.from_numpy(x)


def darcyflow_transform(args):
    transform_list = []
    transform_list.append(torchvision.transforms.ToTensor())
    return torchvision.transforms.Compose(transform_list), torchvision.transforms.Compose(transform_list)
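A minimal usage sketch for the DarcyFlow loader above (the data directory is a placeholder and is assumed to already contain the two .mat files referenced in load_darcyflow_data):

import torch.utils.data

trainset, testset = load_darcyflow_data("path/to/darcyflow")
train_loader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True)

x, y = next(iter(train_loader))
# x: (16, 85, 85, 1) normalized coefficient fields; y: (16, 85, 85) normalized solution fields.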
@@ -0,0 +1,44 @@
import os
import numpy as np
import torch.utils.data
import torchvision.transforms


# adapted from https://github.com/rtu715/NAS-Bench-360/blob/0d1af0ce37b5f656d6491beee724488c3fccf871/perceiver-io/perceiver/data/nb360/ninapro.py#L64
class NinaPro(torch.utils.data.Dataset):
    def __init__(self, root, split="train", transform=None):
        self.root = root
        self.split = split
        self.transform = transform
        self.x = np.load(os.path.join(root, f"ninapro_{split}.npy")).astype(np.float32)
        # Insert a channel axis and move it to the end: (N, H, W) -> (N, H, W, 1).
        self.x = self.x[:, np.newaxis, :, :].transpose(0, 2, 3, 1)
        self.y = np.load(os.path.join(root, f"label_{split}.npy")).astype(int)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        x = self.x[idx, :]
        y = self.y[idx]

        if self.transform:
            x = self.transform(x)
        return x, y


def ninapro_transform(args, channels_last: bool = True):
    transform_list = []

    def channels_to_last(img: torch.Tensor):
        return img.permute(1, 2, 0).contiguous()

    # ToTensor turns the (H, W, 1) float array into a (1, H, W) tensor;
    # channels_to_last optionally restores the channels-last layout.
    transform_list.append(torchvision.transforms.ToTensor())

    if channels_last:
        transform_list.append(channels_to_last)

    return torchvision.transforms.Compose(transform_list), torchvision.transforms.Compose(transform_list)
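And a corresponding usage sketch for the NinaPro dataset above (the root directory is a placeholder and is assumed to already contain ninapro_train.npy and label_train.npy):

import torch.utils.data

train_transform, eval_transform = ninapro_transform(args=None)  # args is unused by ninapro_transform
trainset = NinaPro("path/to/ninapro", split="train", transform=train_transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)

x, y = next(iter(train_loader))  # x stays channels-last because channels_last defaults to True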