-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add python components of WaveNet multi-input example.
- Loading branch information
1 parent
50667db
commit 66fedb0
Showing
6 changed files
with
353 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
"""Load a pytorch model and convert it to TorchScript.""" | ||
|
||
from typing import Optional | ||
import torch | ||
|
||
# FPTLIB-TODO | ||
# Add a module import with your model here: | ||
import run_wavenet as rwn | ||
|
||
|
||
def script_to_torchscript( | ||
model: torch.nn.Module, filename: Optional[str] = "scripted_model.pt" | ||
) -> None: | ||
""" | ||
Save pyTorch model to TorchScript using scripting. | ||
Parameters | ||
---------- | ||
model : torch.NN.Module | ||
a pyTorch model | ||
filename : str | ||
name of file to save to | ||
""" | ||
# FIXME: torch.jit.optimize_for_inference() when PyTorch issue #81085 is resolved | ||
scripted_model = torch.jit.script(model) | ||
print(scripted_model.code) | ||
scripted_model.save(filename) | ||
|
||
|
||
def trace_to_torchscript( | ||
model: torch.nn.Module, | ||
dummy_input: torch.Tensor, | ||
filename: Optional[str] = "traced_model.pt", | ||
) -> None: | ||
""" | ||
Save pyTorch model to TorchScript using tracing. | ||
Parameters | ||
---------- | ||
model : torch.NN.Module | ||
a pyTorch model | ||
dummy_input : torch.Tensor | ||
appropriate size Tensor to act as input to model | ||
filename : str | ||
name of file to save to | ||
""" | ||
# FIXME: torch.jit.optimize_for_inference() when PyTorch issue #81085 is resolved | ||
traced_model = torch.jit.trace(model, dummy_input) | ||
# traced_model.save(filename) | ||
frozen_model = torch.jit.freeze(traced_model) | ||
## print(frozen_model.graph) | ||
## print(frozen_model.code) | ||
frozen_model.save(filename) | ||
|
||
|
||
def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Module: | ||
""" | ||
Load a TorchScript from file. | ||
Parameters | ||
---------- | ||
filename : str | ||
name of file containing TorchScript model | ||
""" | ||
model = torch.jit.load(filename) | ||
|
||
return model | ||
|
||
|
||
if __name__ == "__main__": | ||
# FPTLIB-TODO | ||
# Load a pre-trained PyTorch model | ||
# Insert code here to load your model from file as `trained_model`: | ||
trained_model = rwn.initialize() | ||
|
||
# Switch-off some specific layers/parts of the model that behave | ||
# differently during training and inference. | ||
# This may have been done by the user already, so just make sure here. | ||
trained_model.eval() | ||
|
||
# FPTLIB-TODO | ||
# Generate a dummy input Tensor `dummy_input` to the model of appropriate size. | ||
# trained_model_dummy_input = torch.ones((512, 42)) | ||
trained_model_dummy_input_u = torch.ones((512, 40), dtype=torch.float64) | ||
trained_model_dummy_input_l = torch.ones((512, 1), dtype=torch.float64) | ||
trained_model_dummy_input_p = torch.ones((512, 1), dtype=torch.float64) | ||
|
||
# Run model over dummy input | ||
# If something isn't working This will generate an error | ||
trained_model_dummy_output = trained_model( | ||
trained_model_dummy_input_u, | ||
trained_model_dummy_input_l, | ||
trained_model_dummy_input_p, | ||
) | ||
|
||
# FPTLIB-TODO | ||
# If you want to save for inference on GPU uncomment the following 4 lines: | ||
# device = torch.device('cuda') | ||
# model = model.to(device) | ||
# model.eval() | ||
# dummy_input = dummy_input.to(device) | ||
|
||
# FPTLIB-TODO | ||
# Set the name of the file you want to save the torchscript model to | ||
saved_ts_filename = "saved_model.pt" | ||
|
||
# FPTLIB-TODO | ||
# Save the pytorch model using either scripting (recommended where possible) or tracing | ||
# ----------- | ||
# Scripting | ||
# ----------- | ||
script_to_torchscript(trained_model, filename=saved_ts_filename) | ||
|
||
# ----------- | ||
# Tracing | ||
# ----------- | ||
# trace_to_torchscript(trained_model, trained_model_dummy_input, filename=saved_ts_filename) | ||
|
||
# Load torchscript and run model as a test | ||
testing_input_u = 2.0 * trained_model_dummy_input_u | ||
testing_input_l = 2.0 * trained_model_dummy_input_l | ||
testing_input_p = 2.0 * trained_model_dummy_input_p | ||
trained_model_testing_output = trained_model( | ||
testing_input_u, testing_input_l, testing_input_p | ||
) | ||
ts_model = load_torchscript(filename=saved_ts_filename) | ||
ts_model_output = ts_model(testing_input_u, testing_input_l, testing_input_p) | ||
|
||
if torch.all(ts_model_output.eq(trained_model_testing_output)): | ||
print("Saved TorchScript model working as expected in a basic test.") | ||
print("Users should perform further validation as appropriate.") | ||
else: | ||
raise RuntimeError( | ||
"Saved Torchscript model is not performing as expected.\n" | ||
"Consider using scripting if you used tracing, or investigate further." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
torch | ||
numpy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
""" | ||
Contains all python commands MiMA will use. | ||
It needs in the same directory as `wavenet.py` which describes the | ||
model architecture, and `wavenet_weights.pkl` which contains the model weights. | ||
""" | ||
|
||
from torch import load, device, no_grad, reshape, zeros, tensor, float64 | ||
import wavenet as m | ||
|
||
|
||
# Initialize everything | ||
def initialize(path_weights_stats="wavenet_weights.pkl"): | ||
""" | ||
Initialize a WaveNet model and load weights. | ||
Parameters | ||
__________ | ||
path_weights_stats : pickled object that contains weights and statistics (means and stds). | ||
""" | ||
|
||
device_str = "cpu" | ||
checkpoint = load(path_weights_stats, map_location=device(device_str)) | ||
model = m.WaveNet(checkpoint).to(device_str) | ||
|
||
# Load weights and set to evaluation mode. | ||
model.load_state_dict(checkpoint["weights"]) | ||
model.eval() | ||
return model | ||
|
||
|
||
# Compute drag | ||
def compute_reshape_drag(*args): | ||
""" | ||
Compute the drag from inputs using a neural net. | ||
Takes in input arguments passed from MiMA and outputs drag in shape desired by MiMA. | ||
Reshaping & porting to torch.tensor type, and applying model.forward is performed. | ||
Parameters | ||
__________ | ||
model : nn.Module | ||
WaveNet model ready to be deployed. | ||
wind : | ||
U or V (128, num_col, 40) | ||
lat : | ||
latitudes (num_col) | ||
p_surf : | ||
surface pressure (128, num_col) | ||
Y_out : | ||
output prellocated in MiMA (128, num_col, 40) | ||
num_col : | ||
# of latitudes on this proc | ||
Returns | ||
------- | ||
Y_out : | ||
Results to be returned to MiMA | ||
""" | ||
model, wind, lat, p_surf, Y_out, num_col = args | ||
|
||
# Reshape and put all input variables together [wind, lat, p_surf] | ||
wind_T = tensor(wind) | ||
|
||
# lat_T = zeros((imax * num_col, 1), dtype=float64) | ||
lat_T = tensor(lat) | ||
|
||
# pressure_T = zeros((imax * num_col, 1), dtype=float64) | ||
pressure_T = tensor(p_surf) | ||
|
||
# Apply model. | ||
with no_grad(): | ||
# Ensure evaluation mode (leave training mode and stop using current batch stats) | ||
# model.eval() # Set during initialisation | ||
assert model.training is False | ||
temp = model(wind_T, lat_T, pressure_T) | ||
|
||
# Place in output array for MiMA. | ||
Y_out[:, :] = temp | ||
|
||
return Y_out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
"""Module defining the pytorch WaveNet architecture for coupling to MiMA. """ | ||
|
||
import torch | ||
from torch import nn | ||
|
||
|
||
class WaveNet(nn.Module): | ||
"""Neural network architecture following Espinosa et al. (2022).""" | ||
|
||
def __init__( | ||
self, | ||
checkpoint, | ||
n_in: int = 42, | ||
n_out: int = 40, | ||
branch_dims=None, | ||
) -> None: | ||
""" | ||
Initialize a WaveNet model. | ||
Parameters | ||
---------- | ||
checkpoint: dict | ||
dictionary containing weights & statistics. | ||
n_in : int | ||
Number of input features. | ||
n_out : int | ||
Number of output features. | ||
branch_dims : Union[list, None] | ||
List of dimensions of the layers to include in each of the level-specific branches. | ||
""" | ||
|
||
if branch_dims is None: | ||
branch_dims = [64, 32] | ||
|
||
super().__init__() | ||
|
||
shared = [nn.BatchNorm1d(n_in), nn.Linear(n_in, 256), nn.ReLU()] | ||
for _ in range(4): | ||
shared.extend([nn.Linear(256, 256)]) | ||
shared.extend([nn.ReLU()]) | ||
|
||
shared.extend([nn.Linear(256, branch_dims[0])]) | ||
shared.extend([nn.ReLU()]) | ||
|
||
# All data gets fed through shared, then extra layers defined in branches for each z-level | ||
branches = [] | ||
for _ in range(n_out): | ||
args: list[nn.Module] = [] | ||
for in_features, out_features in zip(branch_dims[:-1], branch_dims[1:]): | ||
args.extend([nn.Linear(in_features, out_features)]) | ||
args.extend([nn.ReLU()]) | ||
|
||
args.extend([nn.Linear(branch_dims[-1], 1)]) | ||
branches.append(nn.Sequential(*args)) | ||
|
||
self.shared = nn.Sequential(*shared) | ||
self.branches = nn.ModuleList(branches) | ||
|
||
self.shared.apply(_xavier_init) | ||
for branch in self.branches: | ||
branch.apply(_xavier_init) | ||
|
||
self.double() | ||
self.means = checkpoint["means"] | ||
self.stds = checkpoint["stds"] | ||
del checkpoint | ||
|
||
def forward( | ||
self, wind: torch.Tensor, lat: torch.Tensor, pressure: torch.Tensor | ||
) -> torch.Tensor: | ||
""" | ||
Apply the network to a `Tensor` of input features. | ||
Parameters | ||
---------- | ||
wind : torch.Tensor | ||
Tensor of of input wind flattened to (n_lat*n_lon, 40). | ||
lat : torch.Tensor | ||
Tensor of of input features flattened to (n_lat*n_lon, 1). | ||
pressure : torch.Tensor | ||
Tensor of of input features flattened to (n_lat*n_lon, 1). | ||
Returns | ||
------- | ||
output : torch.Tensor | ||
Tensor of predicted outputs. | ||
""" | ||
|
||
Z, levels = self.shared(torch.cat((wind, lat, pressure), 1)), [] | ||
|
||
for branch in self.branches: | ||
levels.append(branch(Z).squeeze()) | ||
Y = torch.vstack(levels).T | ||
|
||
# Un-standardize | ||
Y *= self.stds | ||
Y += self.means | ||
return Y | ||
|
||
|
||
def _xavier_init(layer: nn.Module) -> None: | ||
""" | ||
Apply Xavier initialization to a layer if it is an `nn.Linear`. | ||
Parameters | ||
---------- | ||
layer : nn.Module | ||
Linear to potentially initialize. | ||
""" | ||
|
||
if isinstance(layer, nn.Linear): | ||
nn.init.xavier_uniform_(layer.weight) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
"""Script to test WaveNet NN""" | ||
|
||
import numpy as np | ||
import run_wavenet as rwn | ||
|
||
|
||
IMAX = 128 | ||
NUM_COL = 4 | ||
|
||
# Generate the four input tensors and populate with random data | ||
wind = np.random.randn(IMAX * NUM_COL, 40) | ||
lat = np.random.randn(IMAX * NUM_COL, 1) | ||
ps = np.random.randn(IMAX * NUM_COL, 1) | ||
Y_out = np.zeros((IMAX * NUM_COL, 40)) | ||
|
||
# Initialise and run the model | ||
model = rwn.initialize() | ||
Y_out = rwn.compute_reshape_drag(model, wind, lat, ps, Y_out, NUM_COL) |
Binary file not shown.