From 98b0f426dcd0dd3d2bc0c915745643a26ccc6fc5 Mon Sep 17 00:00:00 2001 From: Kirill Trapeznikov Date: Wed, 8 Jun 2022 00:30:51 +0000 Subject: [PATCH] minor fixes --- create_dataset.py | 5 +++++ gaia/data.py | 7 ++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/create_dataset.py b/create_dataset.py index 9f1a020..76af6ab 100644 --- a/create_dataset.py +++ b/create_dataset.py @@ -6,4 +6,9 @@ if __name__=="__main__": NCDataConstructor.default_data(split="train", workers =32, prefix=cam4, train_years=3, save_location=".", cache = ".") + NCDataConstructor.default_data(split="test", workers =32, prefix=cam4, train_years=3, save_location=".", cache = ".") + NCDataConstructor.default_data(split="train", workers =32, prefix=spcam, train_years=2, save_location=".", cache = ".") + NCDataConstructor.default_data(split="test", workers =32, prefix=spcam, train_years=2, save_location=".", cache = ".") + + # NCDataConstructor.default_data(split="train") \ No newline at end of file diff --git a/gaia/data.py b/gaia/data.py index 2e38eec..e9a472a 100644 --- a/gaia/data.py +++ b/gaia/data.py @@ -494,6 +494,7 @@ def default_data( ), outputs="PRECT,PRECC,PTEQ,PTTEND".split(","), flatten=split == "train", + shuffle = split == "train", subsample_factor=4, compute_stats=True, cache = os.path.join(cache,split), @@ -689,7 +690,11 @@ def clean_up_file(self, dataset): def subsample_data(self, xi, yi, subsample_factor): size = xi.shape[0] new_size = size // subsample_factor - shuffled_index = torch.randperm(size)[:new_size] + if self.shuffle: + shuffled_index = torch.randperm(size)[:new_size] + else: + shuffled_index = torch.arange(0, size, subsample_factor) + xi = xi[shuffled_index, ...] yi = yi[shuffled_index, ...] return xi, yi, shuffled_index