From 5c3c7cf4044b755094cc1ce5f65b1648f3c42c99 Mon Sep 17 00:00:00 2001 From: ktrapeznikov Date: Thu, 28 Jul 2022 16:05:49 -0400 Subject: [PATCH] add support for lat/lon cond (#38) Co-authored-by: Kirill --- gaia/config.py | 29 ++++++++- gaia/data.py | 95 ++++++++++++++++++++++++++-- gaia/layers.py | 100 +++++++++++++++++++++++++---- gaia/models.py | 166 +++++++++++++++++++++++++++++++++++++++++-------- 4 files changed, 347 insertions(+), 43 deletions(-) diff --git a/gaia/config.py b/gaia/config.py index 016bb20..be448de 100644 --- a/gaia/config.py +++ b/gaia/config.py @@ -107,7 +107,13 @@ def set_dataset_params(cli_args=dict()): raise ValueError(f"unknown dataset {dataset}") var_index_file = base + "_var_index.pt" + + #possibly shared params + batch_size = cli_args.get('dataset_params',{}).get("batch_size",24 * 96 * 144) + include_index = cli_args.get('dataset_params',{}).get("include_index",False) + subsample = cli_args.get('dataset_params',{}).get("subsample",1) + space_filter = cli_args.get('dataset_params',{}).get("space_filter",None) dataset_params = dict( train=dict( @@ -115,21 +121,30 @@ def set_dataset_params(cli_args=dict()): batch_size=batch_size, shuffle=True, flatten=False, # already flattened - var_index_file=var_index_file + var_index_file=var_index_file, + include_index = include_index, + subsample = subsample, + space_filter =space_filter ), val=dict( dataset_file=base + "_val.pt", batch_size=batch_size, shuffle=False, flatten=False, # already flattened - var_index_file=var_index_file + var_index_file=var_index_file, + include_index = include_index, + subsample = subsample, + space_filter =space_filter ), test=dict( dataset_file=base+'_test.pt', batch_size=batch_size, shuffle=False, flatten=True, # already flattened - var_index_file=var_index_file + var_index_file=var_index_file, + include_index = include_index, + subsample = subsample, + space_filter =space_filter ), mean_thres=mean_thres ) @@ -212,6 +227,14 @@ def model_type_lookup(model_type): "num_layers": 3, "hidden_size": 128, } + elif model_type == "fcn_with_index": + model_config = { + "model_type": "fcn_with_index", + "num_layers": 7, + "hidden_size": 512, + "dropout": 0.01, + "leaky_relu": 0.15 + } else: raise ValueError diff --git a/gaia/data.py b/gaia/data.py index 2060361..0863770 100644 --- a/gaia/data.py +++ b/gaia/data.py @@ -932,6 +932,28 @@ def get_dataset_v1( return dataset_dict, data_loader +def unravel_index(flat_index, shape): + # flat_index = operator.index(flat_index) + res = [] + + # Short-circuits on zero dim tensors + if shape == torch.Size([]): + return 0 + + for size in shape[::-1]: + res.append(flat_index % size) + flat_index = flat_index // size + + # return torch.cat(res + + if len(res) == 1: + return res[0] + + return res[::-1] + + + + def get_dataset( dataset_file, batch_size=1024, @@ -939,6 +961,8 @@ def get_dataset( shuffle = False, var_index_file = None, include_index = False, + subsample = 1, + space_filter = None ): dataset_dict = torch.load(dataset_file) @@ -956,13 +980,76 @@ def get_dataset( dataset_dict[v] = flatten_tensor(dataset_dict[v]) - if not include_index: - del dataset_dict["index"] + tensor_list = [dataset_dict["x"], dataset_dict["y"]] - tensor_list = [dataset_dict["x"], dataset_dict["y"]] + if include_index or (space_filter is not None): + #TODO dont hard code this + num_ts,num_lats,num_lons = 8,96,144 + logger.warning(f"using hardcoded expected shape for unraveling the index: {num_ts,num_lats,num_lons}") + + if flatten: + + # index is flattened + # shape = 
dataset_dict["x"].shape + # if len(dataset_dict["x"].shape) == 3: + # samples, timesteps, channels = dataset_dict["x"].shape + # else: + # samples, channels = dataset_dict["x"].shape + + num_samples = dataset_dict["index"].shape[0] + + lats = torch.ones(num_samples,1,num_lons)*torch.arange(num_lats)[None,:,None] + lons = torch.ones(num_samples,num_lats,1)*torch.arange(num_lons)[None,None,:] + index = torch.cat([lats.ravel().long()[:,None], lons.ravel().long()[:,None]],dim=-1) + else: + index = unravel_index(dataset_dict["index"], shape=[num_ts, num_lats, num_lons]) + index = torch.cat([i[:,None] for i in index[1:]],dim=-1) #just want lats, lons + + else: - pass + del dataset_dict["index"] + + + if include_index: + tensor_list += [index] + + + if space_filter is not None: + # filter out dataset + + logger.info(f"applying space filter {space_filter}") + + from gaia.plot import lats as lats_vals + from gaia.plot import lons as lons_vals + + lats_vals = torch.tensor(lats_vals) + lons_vals = torch.tensor(lons_vals) + + mask = torch.ones(len(tensor_list[0])).bool() + + if "lat_bounds" in space_filter: + lat_min,lat_max = space_filter["lat_bounds"] + temp = lats_vals[index[:,0]] + mask = mask & (temp <= lat_max) & (temp>=lat_min) + + if "lon_bounds" in space_filter: + lon_min,lon_max = space_filter["lon_bounds"] + temp = lons_vals[index[:,1]] + mask = mask & (temp <= lon_max) & (temp>=lon_min) + + assert mask.any() + + tensor_list = [t[mask,...] for t in tensor_list] + + + if subsample>1: + tensor_list = [t[::subsample,...] for t in tensor_list] + logger.info(f"subsampling by factor of {subsample}") + + + logger.info(f"data size {len(tensor_list[0])}") + data_loader = DataLoader( FastTensorDataset( diff --git a/gaia/layers.py b/gaia/layers.py index 3af8693..45ebb4d 100644 --- a/gaia/layers.py +++ b/gaia/layers.py @@ -1,7 +1,10 @@ +from turtle import forward import torch from gaia import get_logger + logger = get_logger(__name__) + class Normalization(torch.nn.Module): def __init__(self, mean, std): super().__init__() @@ -9,7 +12,7 @@ def __init__(self, mean, std): if z.any(): logger.warn(f"found zero {z.sum()} std values, replacing with ones") - std[z] = 1. 
+ std[z] = 1.0 self.register_buffer("mean", mean[None, :, None, None]) self.register_buffer("std", std[None, :, None, None]) @@ -40,13 +43,15 @@ def __init__( input_grid=None, output_grid=None, input_grid_index=None, - output_grid_index=None + output_grid_index=None, ): super().__init__() self.linear = torch.nn.Linear(len(input_grid), len(output_grid), bias=False) output_grid, input_grid = torch.tensor(output_grid), torch.tensor(input_grid) with torch.no_grad(): - self.linear.weight.data = make_interpolation_weights(output_grid, input_grid) + self.linear.weight.data = make_interpolation_weights( + output_grid, input_grid + ) self.input_grid_index = input_grid_index self.output_grid_index = output_grid_index @@ -100,12 +105,85 @@ def forward(self, x): class Conv2dDS(torch.nn.Module): - def __init__(self, nin, nout, kernel_size=3, kernels_per_layer=1, bias = True, padding = "same" ): - super().__init__() - self.depthwise = torch.nn.Conv2d(nin, nin * kernels_per_layer, kernel_size=kernel_size, padding=padding, groups=nin, bias = bias) - self.pointwise = torch.nn.Conv2d(nin * kernels_per_layer, nout, kernel_size=1, bias = bias) - - def forward(self, x): - out = self.depthwise(x) - out = self.pointwise(out) + def __init__( + self, nin, nout, kernel_size=3, kernels_per_layer=1, bias=True, padding="same" + ): + super().__init__() + self.depthwise = torch.nn.Conv2d( + nin, + nin * kernels_per_layer, + kernel_size=kernel_size, + padding=padding, + groups=nin, + bias=bias, + ) + self.pointwise = torch.nn.Conv2d( + nin * kernels_per_layer, nout, kernel_size=1, bias=bias + ) + + def forward(self, x): + out = self.depthwise(x) + out = self.pointwise(out) return out + + +class MultiIndexEmbedding(torch.nn.Module): + def __init__(self, hidden_dim, index_shape, init_value=None): + super().__init__() + self.hidden_dim = hidden_dim + self.index_shape = index_shape + self.embeddings = torch.nn.ModuleList( + [torch.nn.Embedding(num_emb, hidden_dim) for num_emb in index_shape] + ) + if init_value is not None: + for e in self.embeddings: + with torch.no_grad(): + e.weight.data.fill_(init_value) + + def forward(self, x): + out = torch.zeros(x.shape[0], self.hidden_dim, device=x.device) + for i, emb in enumerate(self.embeddings): + out += emb(x[:, i]) + + return out / len(self.embeddings) + + +class FCLayer(torch.nn.Module): + def __init__( + self, + ins, + outs, + batch_norm=False, + dropout=0., + leaky_relu=-1, + index_shape=None, + ): + super().__init__() + + self.linear = torch.nn.Linear(ins, outs) + self.batch_norm = ( + torch.nn.BatchNorm1d(outs) if batch_norm else torch.nn.Identity() + ) + self.drop_out = ( + torch.nn.Dropout(dropout) if dropout > 0 else torch.nn.Identity() + ) + self.relu = ( + torch.nn.LeakyReLU(leaky_relu) if leaky_relu > -1 else torch.nn.Identity() + ) + + if index_shape is not None: + self.scale = MultiIndexEmbedding(outs, index_shape, init_value=1.0) + self.bias = MultiIndexEmbedding(outs, index_shape, init_value=0.0) + + def forward(self, x, index=None): + + x = self.linear(x) + x = self.batch_norm(x) + + if index is not None: + x = x * self.scale(index) + self.bias(index) + + x = self.drop_out(x) + x = self.relu(x) + + return x diff --git a/gaia/models.py b/gaia/models.py index fda0b96..977b3a1 100644 --- a/gaia/models.py +++ b/gaia/models.py @@ -9,7 +9,9 @@ from gaia.data import SCALING_FACTORS, flatten_tensor, unflatten_tensor from gaia.layers import ( Conv2dDS, + FCLayer, InterpolateGrid1D, + MultiIndexEmbedding, Normalization, ResDNNLayer, make_interpolation_weights, @@ -20,8 
+22,10 @@ import torch.nn.functional as F from gaia.optim import get_cosine_schedule_with_warmup from gaia import get_logger + logger = get_logger(__name__) + class TrainingModel(LightningModule): def __init__( self, @@ -38,7 +42,7 @@ def __init__( ignore_input_variables=None, interpolate=None, predict_hidden_states=False, - lr_schedule = None, + lr_schedule=None, **kwargs, ): super().__init__() @@ -82,7 +86,9 @@ def __init__( elif model_type == "fcn_history": self.model = FcnHistory(**model_config) elif model_type == "conv1d": - self.model = ConvNet1D(input_index=input_index, output_index=output_index, **model_config) + self.model = ConvNet1D( + input_index=input_index, output_index=output_index, **model_config + ) elif model_type == "resdnn": self.model = ResDNN(**model_config) elif model_type == "encoderdecoder": @@ -91,6 +97,8 @@ def __init__( self.model = TransformerModel( input_index=input_index, output_index=output_index, **model_config ) + elif model_type == "fcn_with_index": + self.model = FcnWithIndex(**model_config) else: raise ValueError("unknown model_type") @@ -231,15 +239,18 @@ def configure_optimizers(self): int(0.1 * self.trainer.estimated_stepping_batches), self.trainer.estimated_stepping_batches, ), - "interval": "step" + "interval": "step", } else: raise ValueError(f"unknown lr scheduler {self.hparams.lr_schedule}") return out - def forward(self, x): - return self.model(x) + def forward(self, x, index=None): + if index is not None: + return self.model(x, index=index) + else: + return self.model(x) def select_input_variables(self, x): if self.hparams.ignore_input_variables is None: @@ -248,7 +259,14 @@ def select_input_variables(self, x): return x[:, self.input_variable_index, ...] def handle_batch(self, batch): - x, y = batch + + x, y = batch[:2] + + index = None + + if len(batch) > 2: + index = batch[2] + num_dims = len(x.shape) if num_dims == 3 or num_dims == 5: @@ -272,11 +290,11 @@ def handle_batch(self, batch): if self.hparams.memory_variables is not None: # not using all variables for history y1 = y1[:, self.memory_variable_index, ...] 
- return [x, y1], y2 + res = [x, y1], y2 else: # dont use history - return x, y2 + res = x, y2 else: @@ -289,14 +307,16 @@ def handle_batch(self, batch): x = self.select_input_variables(x) - return x, y + res = x, y + + return res + (index,) def step(self, batch, mode="train"): # x, y = batch # x = self.input_normalize(x) # y = self.output_normalize(y) - x, y = self.handle_batch(batch) + x, y, index = self.handle_batch(batch) if len(y.shape) == 2: reduce_dims = [0] @@ -305,13 +325,11 @@ def step(self, batch, mode="train"): else: raise ValueError("wrong size of x") - yhat = self(x) + yhat = self(x, index=index) if self.hparams.interpolate is not None: yhat = self.interpolate_model_to_data_output(yhat) - - loss = OrderedDict() mse = F.mse_loss(y, yhat, reduction="none") @@ -342,17 +360,15 @@ def step(self, batch, mode="train"): loss_name = f"skill_ave_trunc_{k}_{i:02}" loss[loss_name] = skill[i] - if self.hparams.loss_output_weights is not None: num_dims = len(mse.shape) if num_dims == 4: - mse = mse * self.loss_output_weights[None,:, None, None] + mse = mse * self.loss_output_weights[None, :, None, None] elif num_dims == 2: - mse = mse * self.loss_output_weights[None,:] + mse = mse * self.loss_output_weights[None, :] else: raise ValueError("wrong number of dims in mse") - loss["mse"] = mse.mean() for k, v in loss.items(): @@ -394,9 +410,18 @@ def __init__( output_size: int = 26 * 2, dropout: float = 0.01, leaky_relu: float = 0.15, + use_index = False, model_type=None, ): super().__init__() + + if use_index: + # add lon/lat as an additional input + from gaia.plot import lats, lons + self.register_buffer("lats",torch.tensor(lats)/90.) + self.register_buffer("lons",torch.tensor(lons)/180.) + input_size = input_size+2 + self.hidden_size = hidden_size self.input_size = input_size self.output_size = output_size @@ -404,6 +429,7 @@ def __init__( self.leaky_relu = leaky_relu self.num_layers = num_layers self.model = self.make_model() + def make_model(self): if self.num_layers == 1: @@ -427,8 +453,75 @@ def make_layer(ins, outs): layers = [input_layer] + intermediate_layers + [output_layer] return torch.nn.Sequential(*layers) - def forward(self, x): - return self.model(x) + def forward(self, x, index = None): + if index is None: + return self.model(x) + else: + lats = self.lats[index[:,0]] + lons = self.lons[index[:,1]] + x = torch.cat([x,lats[:,None], lons[:,None]],dim = 1) + return self.model(x) + + +class FcnWithIndex(torch.nn.Module): + def __init__( + self, + input_size: int = 26 * 2, + num_layers: int = 7, + hidden_size: int = 512, + output_size: int = 26 * 2, + dropout: float = 0.01, + leaky_relu: float = 0.15, + index_shape=None, + model_type=None, + ): + super().__init__() + self.hidden_size = hidden_size + self.input_size = input_size + self.output_size = output_size + self.dropout = dropout + self.leaky_relu = leaky_relu + self.num_layers = num_layers + self.index_shape = index_shape + + self.layers = torch.nn.ModuleList(self.make_model()) + + def make_model(self): + if self.num_layers == 1: + return [FCLayer( + self.input_size, self.output_size, index_shape=self.index_shape + )] + + input_layer = FCLayer( + self.input_size, + self.hidden_size, + batch_norm=True, + dropout=self.dropout, + leaky_relu=self.leaky_relu, + index_shape=self.index_shape, + ) + + intermediate_layers = [ + FCLayer( + self.hidden_size, + self.hidden_size, + batch_norm=True, + dropout=self.dropout, + leaky_relu=self.leaky_relu, + index_shape=self.index_shape, + ) + for _ in range(self.num_layers - 2) + ] + 
output_layer = FCLayer( + self.hidden_size, self.output_size, index_shape=self.index_shape + ) + layers = [input_layer] + intermediate_layers + [output_layer] + return layers + + def forward(self, x, index = None): + for layer in self.layers: + x = layer(x,index = index) + return x class EncoderDecoder(torch.nn.Module): @@ -441,6 +534,8 @@ def __init__( dropout: float = 0.01, leaky_relu: float = 0.15, bottleneck_dim: int = 32, + encoder_layers=None, + index_shape=None, model_type=None, ): super().__init__() @@ -453,8 +548,9 @@ def __init__( self.num_layers = num_layers self.bottleneck_dim = bottleneck_dim self.scale = 1 + if encoder_layers is None: + encoder_layers = ceil(self.num_layers / 2) - encoder_layers = ceil(self.num_layers / 2) decoder_layers = self.num_layers - encoder_layers self.encoder = FcnBaseline( @@ -475,9 +571,22 @@ def __init__( leaky_relu=leaky_relu, ) - def forward(self, x, return_hidden_state=False): + if index_shape is not None: + self.index_embedding_scale = MultiIndexEmbedding( + bottleneck_dim, index_shape, init_value=1.0 + ) + self.index_embedding_bias = MultiIndexEmbedding( + bottleneck_dim, index_shape, init_value=0.0 + ) + + def forward(self, x, return_hidden_state=False, index=None): b = self.encoder(x) + if index is not None: + scale = self.index_embedding_scale(index) + bias = self.index_embedding_bias(index) + b = b * scale + bias + if self.scale > 1 and not self.training: b = unflatten_tensor(b) s = self.scale @@ -865,16 +974,19 @@ def __init__( self.model = self.make_model() def make_model(self): - if self.conv_type == 'conv2d': + if self.conv_type == "conv2d": conv = torch.nn.Conv2d elif self.conv_type == "conv2d_ds": conv = Conv2dDS else: raise ValueError(f"unknown {self.conv_typ}") + def make_layer(ins, outs): - + layer = torch.nn.Sequential( - conv(ins, outs, kernel_size=self.kernel_size, bias=True, padding = "same"), + conv( + ins, outs, kernel_size=self.kernel_size, bias=True, padding="same" + ), torch.nn.BatchNorm2d(outs), torch.nn.Dropout(self.dropout), torch.nn.LeakyReLU(self.leaky_relu), @@ -887,7 +999,11 @@ def make_layer(ins, outs): for _ in range(self.num_layers - 2) ] output_layer = conv( - self.hidden_size, self.output_size, kernel_size=self.kernel_size, padding = "same", bias=True + self.hidden_size, + self.output_size, + kernel_size=self.kernel_size, + padding="same", + bias=True, ) layers = [input_layer] + intermediate_layers + [output_layer]
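
Usage note (illustrative, not part of the patch): a minimal sketch of how the new dataset options and the index-conditioned model added above fit together. It assumes the gaia package is importable; the lat/lon bounds, subsample factor, random batch, and index_shape value are illustrative choices, with the 96x144 grid taken from the shape hardcoded in get_dataset.

import torch

from gaia.models import FcnWithIndex

# Dataset side: the shared keys that set_dataset_params reads from
# cli_args["dataset_params"] and forwards to get_dataset for every split.
# space_filter is a dict keyed by "lat_bounds" / "lon_bounds" (in degrees).
cli_args = {
    "dataset_params": {
        "include_index": True,   # append the (lat, lon) index tensor to each batch
        "subsample": 4,          # keep every 4th sample after filtering
        "space_filter": {"lat_bounds": (-30.0, 30.0), "lon_bounds": (0.0, 180.0)},
    }
}

# Model side: FcnWithIndex modulates every FCLayer with a per-(lat, lon) scale
# and bias learned by MultiIndexEmbedding; index_shape is the (lats, lons) grid.
model = FcnWithIndex(
    input_size=26 * 2,
    output_size=26 * 2,
    num_layers=7,
    hidden_size=512,
    dropout=0.01,
    leaky_relu=0.15,
    index_shape=(96, 144),
)

x = torch.randn(8, 26 * 2)      # a small batch of flattened input columns
index = torch.stack(            # integer (lat, lon) grid indices, one per sample
    [torch.randint(0, 96, (8,)), torch.randint(0, 144, (8,))], dim=-1
)
y = model(x, index=index)       # -> shape (8, 52)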