Add python components of WaveNet multi-input example.
jatkinson1000 committed Mar 18, 2024
1 parent 50667db commit 66fedb0
Showing 6 changed files with 353 additions and 0 deletions.
136 changes: 136 additions & 0 deletions examples/3_Multiple_Inputs/pt2ts.py
@@ -0,0 +1,136 @@
"""Load a pytorch model and convert it to TorchScript."""

import torch

# FPTLIB-TODO
# Add a module import with your model here:
import run_wavenet as rwn


def script_to_torchscript(
    model: torch.nn.Module, filename: str = "scripted_model.pt"
) -> None:
    """
    Save a PyTorch model to TorchScript using scripting.

    Parameters
    ----------
    model : torch.nn.Module
        a PyTorch model
    filename : str
        name of file to save to
    """
    # FIXME: torch.jit.optimize_for_inference() when PyTorch issue #81085 is resolved
    scripted_model = torch.jit.script(model)
    print(scripted_model.code)
    scripted_model.save(filename)


def trace_to_torchscript(
    model: torch.nn.Module,
    dummy_input: torch.Tensor,
    filename: str = "traced_model.pt",
) -> None:
    """
    Save a PyTorch model to TorchScript using tracing.

    Parameters
    ----------
    model : torch.nn.Module
        a PyTorch model
    dummy_input : torch.Tensor
        appropriately sized Tensor to act as input to the model
    filename : str
        name of file to save to
    """
    # FIXME: torch.jit.optimize_for_inference() when PyTorch issue #81085 is resolved
    traced_model = torch.jit.trace(model, dummy_input)
    # traced_model.save(filename)
    frozen_model = torch.jit.freeze(traced_model)
    # print(frozen_model.graph)
    # print(frozen_model.code)
    frozen_model.save(filename)


def load_torchscript(filename: str = "saved_model.pt") -> torch.nn.Module:
    """
    Load a TorchScript model from file.

    Parameters
    ----------
    filename : str
        name of file containing the TorchScript model

    Returns
    -------
    model : torch.nn.Module
        the loaded TorchScript model
    """
    model = torch.jit.load(filename)

    return model


if __name__ == "__main__":
# FPTLIB-TODO
# Load a pre-trained PyTorch model
# Insert code here to load your model from file as `trained_model`:
trained_model = rwn.initialize()

    # Switch off specific layers/parts of the model that behave
    # differently during training and inference.
    # This may have been done by the user already, so just make sure here.
trained_model.eval()

# FPTLIB-TODO
# Generate a dummy input Tensor `dummy_input` to the model of appropriate size.
# trained_model_dummy_input = torch.ones((512, 42))
trained_model_dummy_input_u = torch.ones((512, 40), dtype=torch.float64)
trained_model_dummy_input_l = torch.ones((512, 1), dtype=torch.float64)
trained_model_dummy_input_p = torch.ones((512, 1), dtype=torch.float64)

# Run model over dummy input
    # If something isn't working, this will generate an error
trained_model_dummy_output = trained_model(
trained_model_dummy_input_u,
trained_model_dummy_input_l,
trained_model_dummy_input_p,
)

    # FPTLIB-TODO
    # If you want to save for inference on GPU uncomment the following lines:
    # device = torch.device("cuda")
    # trained_model = trained_model.to(device)
    # trained_model.eval()
    # trained_model_dummy_input_u = trained_model_dummy_input_u.to(device)
    # trained_model_dummy_input_l = trained_model_dummy_input_l.to(device)
    # trained_model_dummy_input_p = trained_model_dummy_input_p.to(device)

# FPTLIB-TODO
# Set the name of the file you want to save the torchscript model to
saved_ts_filename = "saved_model.pt"

# FPTLIB-TODO
    # Save the PyTorch model using either scripting (recommended where possible) or tracing
# -----------
# Scripting
# -----------
script_to_torchscript(trained_model, filename=saved_ts_filename)

# -----------
# Tracing
# -----------
    # trace_to_torchscript(trained_model, (trained_model_dummy_input_u, trained_model_dummy_input_l, trained_model_dummy_input_p), filename=saved_ts_filename)

# Load torchscript and run model as a test
testing_input_u = 2.0 * trained_model_dummy_input_u
testing_input_l = 2.0 * trained_model_dummy_input_l
testing_input_p = 2.0 * trained_model_dummy_input_p
trained_model_testing_output = trained_model(
testing_input_u, testing_input_l, testing_input_p
)
ts_model = load_torchscript(filename=saved_ts_filename)
ts_model_output = ts_model(testing_input_u, testing_input_l, testing_input_p)

if torch.all(ts_model_output.eq(trained_model_testing_output)):
print("Saved TorchScript model working as expected in a basic test.")
print("Users should perform further validation as appropriate.")
else:
raise RuntimeError(
"Saved Torchscript model is not performing as expected.\n"
"Consider using scripting if you used tracing, or investigate further."
)
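
As a quick sanity check beyond the in-script test above, the exported TorchScript file can be loaded in a fresh Python session and called with the same three inputs. A minimal sketch, assuming pt2ts.py has already been run in this directory and produced saved_model.pt:

import torch

ts_model = torch.jit.load("saved_model.pt")
ts_model.eval()

# The scripted WaveNet expects three float64 inputs:
# wind (N, 40), latitude (N, 1) and surface pressure (N, 1).
wind = torch.ones((512, 40), dtype=torch.float64)
lat = torch.ones((512, 1), dtype=torch.float64)
p_surf = torch.ones((512, 1), dtype=torch.float64)

with torch.no_grad():
    drag = ts_model(wind, lat, p_surf)
print(drag.shape)  # expected: torch.Size([512, 40])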
2 changes: 2 additions & 0 deletions examples/3_Multiple_Inputs/requirements.txt
@@ -0,0 +1,2 @@
torch
numpy
82 changes: 82 additions & 0 deletions examples/3_Multiple_Inputs/run_wavenet.py
@@ -0,0 +1,82 @@
"""
Contains all python commands MiMA will use.
It needs in the same directory as `wavenet.py` which describes the
model architecture, and `wavenet_weights.pkl` which contains the model weights.
"""

from torch import load, device, no_grad, reshape, zeros, tensor, float64
import wavenet as m


# Initialize everything
def initialize(path_weights_stats: str = "wavenet_weights.pkl"):
    """
    Initialize a WaveNet model and load weights.

    Parameters
    ----------
    path_weights_stats : str
        path to the pickled file that contains weights and statistics (means and stds).
    """

device_str = "cpu"
checkpoint = load(path_weights_stats, map_location=device(device_str))
model = m.WaveNet(checkpoint).to(device_str)

# Load weights and set to evaluation mode.
model.load_state_dict(checkpoint["weights"])
model.eval()
return model


# Compute drag
def compute_reshape_drag(*args):
    """
    Compute the drag from inputs using a neural net.

    Takes in input arguments passed from MiMA and outputs drag in the shape desired by MiMA.
    Reshaping, conversion to torch.tensor type, and application of model.forward are performed.

    Parameters
    ----------
    model : nn.Module
        WaveNet model ready to be deployed.
    wind :
        U or V (128, num_col, 40)
    lat :
        latitudes (num_col)
    p_surf :
        surface pressure (128, num_col)
    Y_out :
        output preallocated in MiMA (128, num_col, 40)
    num_col :
        number of latitudes (columns) on this processor

    Returns
    -------
    Y_out :
        Results to be returned to MiMA
    """
model, wind, lat, p_surf, Y_out, num_col = args

# Reshape and put all input variables together [wind, lat, p_surf]
wind_T = tensor(wind)

# lat_T = zeros((imax * num_col, 1), dtype=float64)
lat_T = tensor(lat)

# pressure_T = zeros((imax * num_col, 1), dtype=float64)
pressure_T = tensor(p_surf)

# Apply model.
with no_grad():
# Ensure evaluation mode (leave training mode and stop using current batch stats)
# model.eval() # Set during initialisation
assert model.training is False
temp = model(wind_T, lat_T, pressure_T)

# Place in output array for MiMA.
Y_out[:, :] = temp

return Y_out
115 changes: 115 additions & 0 deletions examples/3_Multiple_Inputs/wavenet.py
@@ -0,0 +1,115 @@
"""Module defining the pytorch WaveNet architecture for coupling to MiMA. """

import torch
from torch import nn


class WaveNet(nn.Module):
"""Neural network architecture following Espinosa et al. (2022)."""

def __init__(
self,
checkpoint,
n_in: int = 42,
n_out: int = 40,
branch_dims=None,
) -> None:
"""
Initialize a WaveNet model.
Parameters
----------
checkpoint: dict
dictionary containing weights & statistics.
n_in : int
Number of input features.
n_out : int
Number of output features.
branch_dims : Union[list, None]
List of dimensions of the layers to include in each of the level-specific branches.
"""

if branch_dims is None:
branch_dims = [64, 32]

super().__init__()

shared = [nn.BatchNorm1d(n_in), nn.Linear(n_in, 256), nn.ReLU()]
for _ in range(4):
shared.extend([nn.Linear(256, 256)])
shared.extend([nn.ReLU()])

shared.extend([nn.Linear(256, branch_dims[0])])
shared.extend([nn.ReLU()])

# All data gets fed through shared, then extra layers defined in branches for each z-level
branches = []
for _ in range(n_out):
args: list[nn.Module] = []
for in_features, out_features in zip(branch_dims[:-1], branch_dims[1:]):
args.extend([nn.Linear(in_features, out_features)])
args.extend([nn.ReLU()])

args.extend([nn.Linear(branch_dims[-1], 1)])
branches.append(nn.Sequential(*args))

self.shared = nn.Sequential(*shared)
self.branches = nn.ModuleList(branches)

self.shared.apply(_xavier_init)
for branch in self.branches:
branch.apply(_xavier_init)

self.double()
self.means = checkpoint["means"]
self.stds = checkpoint["stds"]
del checkpoint

def forward(
self, wind: torch.Tensor, lat: torch.Tensor, pressure: torch.Tensor
) -> torch.Tensor:
"""
Apply the network to a `Tensor` of input features.
Parameters
----------
wind : torch.Tensor
Tensor of of input wind flattened to (n_lat*n_lon, 40).
lat : torch.Tensor
Tensor of of input features flattened to (n_lat*n_lon, 1).
pressure : torch.Tensor
Tensor of of input features flattened to (n_lat*n_lon, 1).
Returns
-------
output : torch.Tensor
Tensor of predicted outputs.
"""

Z, levels = self.shared(torch.cat((wind, lat, pressure), 1)), []

for branch in self.branches:
levels.append(branch(Z).squeeze())
Y = torch.vstack(levels).T

# Un-standardize
Y *= self.stds
Y += self.means
return Y


def _xavier_init(layer: nn.Module) -> None:
"""
Apply Xavier initialization to a layer if it is an `nn.Linear`.
Parameters
----------
layer : nn.Module
Linear to potentially initialize.
"""

if isinstance(layer, nn.Linear):
nn.init.xavier_uniform_(layer.weight)
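
To illustrate the tensor shapes the architecture expects, the class can also be instantiated on its own with placeholder statistics. This is only a sketch: the means and stds below are dummies standing in for the trained statistics shipped in wavenet_weights.pkl, and it assumes it is run from the same directory as wavenet.py:

import torch
from wavenet import WaveNet

# Placeholder checkpoint: zero means and unit stds for the 40 output levels.
dummy_checkpoint = {
    "means": torch.zeros(40, dtype=torch.float64),
    "stds": torch.ones(40, dtype=torch.float64),
}
model = WaveNet(dummy_checkpoint)
model.eval()

# n_in = 42 features per column: 40 wind levels + latitude + surface pressure.
wind = torch.randn(8, 40, dtype=torch.float64)
lat = torch.randn(8, 1, dtype=torch.float64)
pressure = torch.randn(8, 1, dtype=torch.float64)

with torch.no_grad():
    drag = model(wind, lat, pressure)
print(drag.shape)  # (8, 40): one predicted value per vertical level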
18 changes: 18 additions & 0 deletions examples/3_Multiple_Inputs/wavenet_infer_python.py
@@ -0,0 +1,18 @@
"""Script to test WaveNet NN"""

import numpy as np
import run_wavenet as rwn


IMAX = 128
NUM_COL = 4

# Generate the three input arrays and an output array, populating the inputs with random data
wind = np.random.randn(IMAX * NUM_COL, 40)
lat = np.random.randn(IMAX * NUM_COL, 1)
ps = np.random.randn(IMAX * NUM_COL, 1)
Y_out = np.zeros((IMAX * NUM_COL, 40))

# Initialise and run the model
model = rwn.initialize()
Y_out = rwn.compute_reshape_drag(model, wind, lat, ps, Y_out, NUM_COL)
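
A couple of simple assertions could be appended to the script above to confirm the result looks sensible; a minimal sketch (the expected shape follows the arrays created above):

# Optional basic checks on the result.
assert Y_out.shape == (IMAX * NUM_COL, 40), "unexpected output shape"
assert np.isfinite(Y_out).all(), "output contains non-finite values"
print("WaveNet test output shape:", Y_out.shape)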
Binary file added examples/3_Multiple_Inputs/wavenet_weights.pkl