Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi-input example #94

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions examples/3_Multiple_Inputs/pt2ts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Load a pytorch model and convert it to TorchScript."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could mention the methods included for torchscript conversion here (tracing and scripting), and that scripting is the recommended method


from typing import Optional
import torch

# FPTLIB-TODO
# Add a module import with your model here:
import run_wavenet as rwn


def script_to_torchscript(
model: torch.nn.Module, filename: Optional[str] = "scripted_model.pt"
) -> None:
"""
Save pyTorch model to TorchScript using scripting.

Parameters
----------
model : torch.NN.Module
a pyTorch model
filename : str
name of file to save to
"""
# FIXME: torch.jit.optimize_for_inference() when PyTorch issue #81085 is resolved
scripted_model = torch.jit.script(model)
print(scripted_model.code)
scripted_model.save(filename)


def trace_to_torchscript(
model: torch.nn.Module,
dummy_input: torch.Tensor,
filename: Optional[str] = "traced_model.pt",
) -> None:
"""
Save pyTorch model to TorchScript using tracing.

Parameters
----------
model : torch.NN.Module
a pyTorch model
dummy_input : torch.Tensor
appropriate size Tensor to act as input to model
filename : str
name of file to save to
"""
# FIXME: torch.jit.optimize_for_inference() when PyTorch issue #81085 is resolved
traced_model = torch.jit.trace(model, dummy_input)
# traced_model.save(filename)
frozen_model = torch.jit.freeze(traced_model)
## print(frozen_model.graph)
## print(frozen_model.code)
frozen_model.save(filename)


def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Module:
"""
Load a TorchScript from file.

Parameters
----------
filename : str
name of file containing TorchScript model
"""
model = torch.jit.load(filename)

return model


if __name__ == "__main__":
# FPTLIB-TODO
# Load a pre-trained PyTorch model
# Insert code here to load your model from file as `trained_model`:
trained_model = rwn.initialize()

# Switch-off some specific layers/parts of the model that behave
# differently during training and inference.
# This may have been done by the user already, so just make sure here.
trained_model.eval()

# FPTLIB-TODO
# Generate a dummy input Tensor `dummy_input` to the model of appropriate size.
# trained_model_dummy_input = torch.ones((512, 42))
trained_model_dummy_input_u = torch.ones((512, 40), dtype=torch.float64)
trained_model_dummy_input_l = torch.ones((512, 1), dtype=torch.float64)
trained_model_dummy_input_p = torch.ones((512, 1), dtype=torch.float64)

# Run model over dummy input
# If something isn't working This will generate an error
trained_model_dummy_output = trained_model(
trained_model_dummy_input_u,
trained_model_dummy_input_l,
trained_model_dummy_input_p,
)

# FPTLIB-TODO
# If you want to save for inference on GPU uncomment the following 4 lines:
# device = torch.device('cuda')
# model = model.to(device)
# model.eval()
# dummy_input = dummy_input.to(device)

# FPTLIB-TODO
# Set the name of the file you want to save the torchscript model to
saved_ts_filename = "saved_model.pt"

# FPTLIB-TODO
# Save the pytorch model using either scripting (recommended where possible) or tracing
# -----------
# Scripting
# -----------
script_to_torchscript(trained_model, filename=saved_ts_filename)

# -----------
# Tracing
# -----------
# trace_to_torchscript(trained_model, trained_model_dummy_input, filename=saved_ts_filename)

# Load torchscript and run model as a test
testing_input_u = 2.0 * trained_model_dummy_input_u
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a line to why the input is scaled by a factor of 2

testing_input_l = 2.0 * trained_model_dummy_input_l
testing_input_p = 2.0 * trained_model_dummy_input_p
trained_model_testing_output = trained_model(
testing_input_u, testing_input_l, testing_input_p
)
ts_model = load_torchscript(filename=saved_ts_filename)
ts_model_output = ts_model(testing_input_u, testing_input_l, testing_input_p)

if torch.all(ts_model_output.eq(trained_model_testing_output)):
print("Saved TorchScript model working as expected in a basic test.")
print("Users should perform further validation as appropriate.")
else:
raise RuntimeError(
"Saved Torchscript model is not performing as expected.\n"
"Consider using scripting if you used tracing, or investigate further."
)
2 changes: 2 additions & 0 deletions examples/3_Multiple_Inputs/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
torch
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding minimum required versions would be good?

numpy
82 changes: 82 additions & 0 deletions examples/3_Multiple_Inputs/run_wavenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""
Contains all python commands MiMA will use.

It needs in the same directory as `wavenet.py` which describes the
model architecture, and `wavenet_weights.pkl` which contains the model weights.
"""

from torch import load, device, no_grad, reshape, zeros, tensor, float64
import wavenet as m


# Initialize everything
def initialize(path_weights_stats="wavenet_weights.pkl"):
"""
Initialize a WaveNet model and load weights.

Parameters
----------
path_weights_stats : str
path to pickled object that contains weights and statistics (means and stds).

"""
device_str = "cpu"
checkpoint = load(path_weights_stats, map_location=device(device_str))
model = m.WaveNet(checkpoint).to(device_str)

# Load weights and set to evaluation mode.
model.load_state_dict(checkpoint["weights"])
model.eval()
return model


# Compute drag
def compute_reshape_drag(*args):
"""
Compute the drag from inputs using a neural net.

Takes in input arguments passed from MiMA and outputs drag in shape desired by MiMA.
Reshaping & porting to torch.tensor type, and applying model.forward is performed.

Parameters
----------
model : nn.Module
WaveNet model ready to be deployed.
wind :
U or V (128, num_col, 40)
lat :
latitudes (num_col)
p_surf :
surface pressure (128, num_col)
Y_out :
output prellocated in MiMA (128, num_col, 40)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spelling of prelocated

num_col :
# of latitudes on this proc

Returns
-------
Y_out :
Results to be returned to MiMA
"""
model, wind, lat, p_surf, Y_out, num_col = args

# Reshape and put all input variables together [wind, lat, p_surf]
wind_T = tensor(wind)

# lat_T = zeros((imax * num_col, 1), dtype=float64)
lat_T = tensor(lat)

# pressure_T = zeros((imax * num_col, 1), dtype=float64)
pressure_T = tensor(p_surf)

# Apply model.
with no_grad():
# Ensure evaluation mode (leave training mode and stop using current batch stats)
# model.eval() # Set during initialisation
assert model.training is False
temp = model(wind_T, lat_T, pressure_T)

# Place in output array for MiMA.
Y_out[:, :] = temp

return Y_out
112 changes: 112 additions & 0 deletions examples/3_Multiple_Inputs/wavenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Module defining the pytorch WaveNet architecture for coupling to MiMA."""

import torch
from torch import nn


class WaveNet(nn.Module):
"""Neural network architecture following Espinosa et al. (2022)."""

def __init__(
self,
checkpoint,
n_in: int = 42,
n_out: int = 40,
branch_dims=None,
) -> None:
"""
Initialize a WaveNet model.

Parameters
----------
checkpoint: dict
dictionary containing weights & statistics.
n_in : int
Number of input features.
n_out : int
Number of output features.
branch_dims : Union[list, None]
List of dimensions of the layers to include in each of the level-specific branches.

"""
if branch_dims is None:
branch_dims = [64, 32]

super().__init__()

shared = [nn.BatchNorm1d(n_in), nn.Linear(n_in, 256), nn.ReLU()]
for _ in range(4):
shared.extend([nn.Linear(256, 256)])
shared.extend([nn.ReLU()])

shared.extend([nn.Linear(256, branch_dims[0])])
shared.extend([nn.ReLU()])

# All data gets fed through shared, then extra layers defined in branches for each z-level
branches = []
for _ in range(n_out):
args: list[nn.Module] = []
for in_features, out_features in zip(branch_dims[:-1], branch_dims[1:]):
args.extend([nn.Linear(in_features, out_features)])
args.extend([nn.ReLU()])

args.extend([nn.Linear(branch_dims[-1], 1)])
branches.append(nn.Sequential(*args))

self.shared = nn.Sequential(*shared)
self.branches = nn.ModuleList(branches)

self.shared.apply(_xavier_init)
for branch in self.branches:
branch.apply(_xavier_init)

self.double()
self.means = checkpoint["means"]
self.stds = checkpoint["stds"]
del checkpoint

def forward(
self, wind: torch.Tensor, lat: torch.Tensor, pressure: torch.Tensor
) -> torch.Tensor:
"""
Apply the network to a `Tensor` of input features.

Parameters
----------
wind : torch.Tensor
Tensor of of input wind flattened to (n_lat*n_lon, 40).
lat : torch.Tensor
Tensor of of input features flattened to (n_lat*n_lon, 1).
pressure : torch.Tensor
Tensor of of input features flattened to (n_lat*n_lon, 1).

Returns
-------
output : torch.Tensor
Tensor of predicted outputs.

"""
Z, levels = self.shared(torch.cat((wind, lat, pressure), 1)), []

for branch in self.branches:
levels.append(branch(Z).squeeze())
Y = torch.vstack(levels).T

# Un-standardize
Y *= self.stds
Y += self.means
return Y


def _xavier_init(layer: nn.Module) -> None:
"""
Apply Xavier initialization to a layer if it is an `nn.Linear`.

Parameters
----------
layer : nn.Module
Linear to potentially initialize.

"""
if isinstance(layer, nn.Linear):
nn.init.xavier_uniform_(layer.weight)
18 changes: 18 additions & 0 deletions examples/3_Multiple_Inputs/wavenet_infer_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Script to test WaveNet NN."""

import numpy as np
import run_wavenet as rwn


IMAX = 128
NUM_COL = 4

# Generate the four input tensors and populate with random data
wind = np.random.randn(IMAX * NUM_COL, 40)
lat = np.random.randn(IMAX * NUM_COL, 1)
ps = np.random.randn(IMAX * NUM_COL, 1)
Y_out = np.zeros((IMAX * NUM_COL, 40))

# Initialise and run the model
model = rwn.initialize()
Y_out = rwn.compute_reshape_drag(model, wind, lat, ps, Y_out, NUM_COL)
Binary file added examples/3_Multiple_Inputs/wavenet_weights.pkl
Binary file not shown.
Loading