add support for lat/lon cond (#38)
Co-authored-by: Kirill <kirill.trapeznikov@str.us>
ktrapeznikov authored Jul 28, 2022
1 parent 1ac4994 commit 5c3c7cf
Showing 4 changed files with 347 additions and 43 deletions.
29 changes: 26 additions & 3 deletions gaia/config.py
@@ -107,29 +107,44 @@ def set_dataset_params(cli_args=dict()):
raise ValueError(f"unknown dataset {dataset}")

var_index_file = base + "_var_index.pt"

# possibly shared params

batch_size = cli_args.get("dataset_params", {}).get("batch_size", 24 * 96 * 144)
include_index = cli_args.get("dataset_params", {}).get("include_index", False)
subsample = cli_args.get("dataset_params", {}).get("subsample", 1)
space_filter = cli_args.get("dataset_params", {}).get("space_filter", None)

dataset_params = dict(
    train=dict(
        dataset_file=base + "_train.pt",
        batch_size=batch_size,
        shuffle=True,
        flatten=False,  # already flattened
        var_index_file=var_index_file,
        include_index=include_index,
        subsample=subsample,
        space_filter=space_filter,
    ),
    val=dict(
        dataset_file=base + "_val.pt",
        batch_size=batch_size,
        shuffle=False,
        flatten=False,  # already flattened
        var_index_file=var_index_file,
        include_index=include_index,
        subsample=subsample,
        space_filter=space_filter,
    ),
    test=dict(
        dataset_file=base + "_test.pt",
        batch_size=batch_size,
        shuffle=False,
        flatten=True,  # already flattened
        var_index_file=var_index_file,
        include_index=include_index,
        subsample=subsample,
        space_filter=space_filter,
    ),
    mean_thres=mean_thres,
)
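
The new keys can be overridden through cli_args when calling set_dataset_params; a minimal sketch (values illustrative, bounds follow the lat_bounds/lon_bounds handling added in gaia/data.py below):

cli_args = {
    "dataset_params": {
        "batch_size": 24 * 96 * 144,   # default value used by set_dataset_params
        "include_index": True,         # append the (lat, lon) index tensor (see gaia/data.py)
        "subsample": 2,                # keep every 2nd sample
        "space_filter": {"lat_bounds": [-30, 30], "lon_bounds": [0, 180]},
    }
}
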
@@ -212,6 +227,14 @@ def model_type_lookup(model_type):
"num_layers": 3,
"hidden_size": 128,
}
elif model_type == "fcn_with_index":
model_config = {
"model_type": "fcn_with_index",
"num_layers": 7,
"hidden_size": 512,
"dropout": 0.01,
"leaky_relu": 0.15
}
else:
raise ValueError(f"unknown model_type {model_type}")

95 changes: 91 additions & 4 deletions gaia/data.py
@@ -932,13 +932,37 @@ def get_dataset_v1(
return dataset_dict, data_loader


def unravel_index(flat_index, shape):
    # convert a flat index into a list of per-dimension indices for the given shape
    res = []

    # short-circuits on zero-dim tensors
    if shape == torch.Size([]):
        return 0

    for size in shape[::-1]:
        res.append(flat_index % size)
        flat_index = flat_index // size

    if len(res) == 1:
        return res[0]

    return res[::-1]
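
For orientation (not part of the commit): with the hardcoded 8 x 96 x 144 (time, lat, lon) grid used below, a flat index equals t * 96 * 144 + lat * 144 + lon, so for example:

# illustrative only: decompose the flat index for (t=1, lat=2, lon=3)
flat = torch.tensor([1 * 96 * 144 + 2 * 144 + 3])       # tensor([14115])
t, lat, lon = unravel_index(flat, shape=[8, 96, 144])   # tensor([1]), tensor([2]), tensor([3])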




def get_dataset(
    dataset_file,
    batch_size=1024,
    flatten=False,
    shuffle=False,
    var_index_file=None,
    include_index=False,
    subsample=1,
    space_filter=None,
):

dataset_dict = torch.load(dataset_file)
@@ -956,13 +980,76 @@ def get_dataset(
dataset_dict[v] = flatten_tensor(dataset_dict[v])


tensor_list = [dataset_dict["x"], dataset_dict["y"]]

if include_index or (space_filter is not None):
    # TODO: don't hard-code the expected grid shape
    num_ts, num_lats, num_lons = 8, 96, 144
    logger.warning(f"using hardcoded expected shape for unraveling the index: {num_ts,num_lats,num_lons}")

    if flatten:
        # index is already flattened: rebuild the (lat, lon) grid coordinates per sample
        num_samples = dataset_dict["index"].shape[0]

        lats = torch.ones(num_samples, 1, num_lons) * torch.arange(num_lats)[None, :, None]
        lons = torch.ones(num_samples, num_lats, 1) * torch.arange(num_lons)[None, None, :]
        index = torch.cat([lats.ravel().long()[:, None], lons.ravel().long()[:, None]], dim=-1)
    else:
        index = unravel_index(dataset_dict["index"], shape=[num_ts, num_lats, num_lons])
        index = torch.cat([i[:, None] for i in index[1:]], dim=-1)  # just want lats, lons

else:
    del dataset_dict["index"]


if include_index:
    tensor_list += [index]

if space_filter is not None:
    # filter out dataset
    logger.info(f"applying space filter {space_filter}")

    from gaia.plot import lats as lats_vals
    from gaia.plot import lons as lons_vals

    lats_vals = torch.tensor(lats_vals)
    lons_vals = torch.tensor(lons_vals)

    mask = torch.ones(len(tensor_list[0])).bool()

    if "lat_bounds" in space_filter:
        lat_min, lat_max = space_filter["lat_bounds"]
        temp = lats_vals[index[:, 0]]
        mask = mask & (temp <= lat_max) & (temp >= lat_min)

    if "lon_bounds" in space_filter:
        lon_min, lon_max = space_filter["lon_bounds"]
        temp = lons_vals[index[:, 1]]
        mask = mask & (temp <= lon_max) & (temp >= lon_min)

    assert mask.any()

    tensor_list = [t[mask, ...] for t in tensor_list]

if subsample > 1:
    tensor_list = [t[::subsample, ...] for t in tensor_list]
    logger.info(f"subsampling by factor of {subsample}")

logger.info(f"data size {len(tensor_list[0])}")


data_loader = DataLoader(
FastTensorDataset(
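
Taken together, a hedged usage sketch of the extended loader (the file path and bounds are placeholders; the return value is assumed to mirror get_dataset_v1 above):

# illustrative call only
dataset_dict, data_loader = get_dataset(
    "spcam_train.pt",                         # placeholder path
    batch_size=1024,
    include_index=True,                       # appends the (lat, lon) index tensor
    subsample=2,                              # keep every 2nd sample
    space_filter={"lat_bounds": [-30, 30]},   # bounds in the units of gaia.plot.lats
)
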
100 changes: 89 additions & 11 deletions gaia/layers.py
@@ -1,15 +1,18 @@
import torch
from gaia import get_logger

logger = get_logger(__name__)


class Normalization(torch.nn.Module):
def __init__(self, mean, std):
super().__init__()
z = std == 0

if z.any():
logger.warn(f"found zero {z.sum()} std values, replacing with ones")
std[z] = 1.
std[z] = 1.0

self.register_buffer("mean", mean[None, :, None, None])
self.register_buffer("std", std[None, :, None, None])
@@ -40,13 +43,15 @@ def __init__(
input_grid=None,
output_grid=None,
input_grid_index=None,
output_grid_index=None,
):
super().__init__()
self.linear = torch.nn.Linear(len(input_grid), len(output_grid), bias=False)
output_grid, input_grid = torch.tensor(output_grid), torch.tensor(input_grid)
with torch.no_grad():
self.linear.weight.data = make_interpolation_weights(
    output_grid, input_grid
)
self.input_grid_index = input_grid_index
self.output_grid_index = output_grid_index

@@ -100,12 +105,85 @@ def forward(self, x):


class Conv2dDS(torch.nn.Module):
def __init__(
self, nin, nout, kernel_size=3, kernels_per_layer=1, bias=True, padding="same"
):
super().__init__()
self.depthwise = torch.nn.Conv2d(
nin,
nin * kernels_per_layer,
kernel_size=kernel_size,
padding=padding,
groups=nin,
bias=bias,
)
self.pointwise = torch.nn.Conv2d(
nin * kernels_per_layer, nout, kernel_size=1, bias=bias
)

def forward(self, x):
out = self.depthwise(x)
out = self.pointwise(out)
return out


class MultiIndexEmbedding(torch.nn.Module):
def __init__(self, hidden_dim, index_shape, init_value=None):
super().__init__()
self.hidden_dim = hidden_dim
self.index_shape = index_shape
self.embeddings = torch.nn.ModuleList(
[torch.nn.Embedding(num_emb, hidden_dim) for num_emb in index_shape]
)
if init_value is not None:
for e in self.embeddings:
with torch.no_grad():
e.weight.data.fill_(init_value)

def forward(self, x):
out = torch.zeros(x.shape[0], self.hidden_dim, device=x.device)
for i, emb in enumerate(self.embeddings):
out += emb(x[:, i])

return out / len(self.embeddings)
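
MultiIndexEmbedding keeps one embedding table per index dimension and averages their lookups; a small sketch assuming the 96 x 144 lat/lon grid:

emb = MultiIndexEmbedding(hidden_dim=512, index_shape=(96, 144), init_value=1.0)
index = torch.tensor([[0, 0], [45, 72]])  # (lat, lon) for two samples
out = emb(index)                          # shape (2, 512); with init_value=1.0 every entry starts at 1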


class FCLayer(torch.nn.Module):
def __init__(
self,
ins,
outs,
batch_norm=False,
dropout=0.,
leaky_relu=-1,
index_shape=None,
):
super().__init__()

self.linear = torch.nn.Linear(ins, outs)
self.batch_norm = (
torch.nn.BatchNorm1d(outs) if batch_norm else torch.nn.Identity()
)
self.drop_out = (
torch.nn.Dropout(dropout) if dropout > 0 else torch.nn.Identity()
)
self.relu = (
torch.nn.LeakyReLU(leaky_relu) if leaky_relu > -1 else torch.nn.Identity()
)

if index_shape is not None:
self.scale = MultiIndexEmbedding(outs, index_shape, init_value=1.0)
self.bias = MultiIndexEmbedding(outs, index_shape, init_value=0.0)

def forward(self, x, index=None):

x = self.linear(x)
x = self.batch_norm(x)

if index is not None:
x = x * self.scale(index) + self.bias(index)

x = self.drop_out(x)
x = self.relu(x)

return x
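
When index_shape is set, FCLayer applies a per-location affine modulation after the linear step: activations are multiplied by a scale embedding (initialized to 1) and shifted by a bias embedding (initialized to 0) looked up from the (lat, lon) index, so training starts from the unconditioned layer. A minimal sketch with illustrative sizes:

layer = FCLayer(26, 512, dropout=0.01, leaky_relu=0.15, index_shape=(96, 144))
x = torch.randn(8, 26)                 # batch of 8 input vectors
index = torch.randint(0, 96, (8, 2))   # (lat, lon) per sample, here both drawn below 96
y = layer(x, index=index)              # (8, 512), scaled and shifted per location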
