Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ktrapeznikov committed Jun 8, 2022
1 parent 59be221 commit 98b0f42
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
5 changes: 5 additions & 0 deletions create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,9 @@

if __name__=="__main__":
NCDataConstructor.default_data(split="train", workers =32, prefix=cam4, train_years=3, save_location=".", cache = ".")
NCDataConstructor.default_data(split="test", workers =32, prefix=cam4, train_years=3, save_location=".", cache = ".")
NCDataConstructor.default_data(split="train", workers =32, prefix=spcam, train_years=2, save_location=".", cache = ".")
NCDataConstructor.default_data(split="test", workers =32, prefix=spcam, train_years=2, save_location=".", cache = ".")


# NCDataConstructor.default_data(split="train")
7 changes: 6 additions & 1 deletion gaia/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ def default_data(
),
outputs="PRECT,PRECC,PTEQ,PTTEND".split(","),
flatten=split == "train",
shuffle = split == "train",
subsample_factor=4,
compute_stats=True,
cache = os.path.join(cache,split),
Expand Down Expand Up @@ -689,7 +690,11 @@ def clean_up_file(self, dataset):
def subsample_data(self, xi, yi, subsample_factor):
size = xi.shape[0]
new_size = size // subsample_factor
shuffled_index = torch.randperm(size)[:new_size]
if self.shuffle:
shuffled_index = torch.randperm(size)[:new_size]
else:
shuffled_index = torch.arange(0, size, subsample_factor)

xi = xi[shuffled_index, ...]
yi = yi[shuffled_index, ...]
return xi, yi, shuffled_index
Expand Down

0 comments on commit 98b0f42

Please sign in to comment.