Add multi-input example #94
Changes from all commits
@@ -0,0 +1,136 @@
"""Load a pytorch model and convert it to TorchScript.""" | ||
|
||
from typing import Optional | ||
import torch | ||
|
||
# FPTLIB-TODO | ||
# Add a module import with your model here: | ||
import run_wavenet as rwn | ||
|
||
|
||
def script_to_torchscript( | ||
model: torch.nn.Module, filename: Optional[str] = "scripted_model.pt" | ||
) -> None: | ||
""" | ||
Save pyTorch model to TorchScript using scripting. | ||
|
||
Parameters | ||
---------- | ||
model : torch.NN.Module | ||
a pyTorch model | ||
filename : str | ||
name of file to save to | ||
""" | ||
# FIXME: torch.jit.optimize_for_inference() when PyTorch issue #81085 is resolved | ||
scripted_model = torch.jit.script(model) | ||
print(scripted_model.code) | ||
scripted_model.save(filename) | ||
|
||
|
||
def trace_to_torchscript( | ||
model: torch.nn.Module, | ||
dummy_input: torch.Tensor, | ||
filename: Optional[str] = "traced_model.pt", | ||
) -> None: | ||
""" | ||
Save pyTorch model to TorchScript using tracing. | ||
|
||
Parameters | ||
---------- | ||
model : torch.NN.Module | ||
a pyTorch model | ||
dummy_input : torch.Tensor | ||
appropriate size Tensor to act as input to model | ||
filename : str | ||
name of file to save to | ||
""" | ||
# FIXME: torch.jit.optimize_for_inference() when PyTorch issue #81085 is resolved | ||
traced_model = torch.jit.trace(model, dummy_input) | ||
# traced_model.save(filename) | ||
frozen_model = torch.jit.freeze(traced_model) | ||
## print(frozen_model.graph) | ||
## print(frozen_model.code) | ||
frozen_model.save(filename) | ||
|
||
|
||
def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Module: | ||
""" | ||
Load a TorchScript from file. | ||
|
||
Parameters | ||
---------- | ||
filename : str | ||
name of file containing TorchScript model | ||
""" | ||
model = torch.jit.load(filename) | ||
|
||
return model | ||
|
||
|
||
if __name__ == "__main__": | ||
# FPTLIB-TODO | ||
# Load a pre-trained PyTorch model | ||
# Insert code here to load your model from file as `trained_model`: | ||
trained_model = rwn.initialize() | ||
|
||
# Switch-off some specific layers/parts of the model that behave | ||
# differently during training and inference. | ||
# This may have been done by the user already, so just make sure here. | ||
trained_model.eval() | ||
|
||
# FPTLIB-TODO | ||
# Generate a dummy input Tensor `dummy_input` to the model of appropriate size. | ||
# trained_model_dummy_input = torch.ones((512, 42)) | ||
trained_model_dummy_input_u = torch.ones((512, 40), dtype=torch.float64) | ||
trained_model_dummy_input_l = torch.ones((512, 1), dtype=torch.float64) | ||
trained_model_dummy_input_p = torch.ones((512, 1), dtype=torch.float64) | ||
|
||
# Run model over dummy input | ||
# If something isn't working This will generate an error | ||
trained_model_dummy_output = trained_model( | ||
trained_model_dummy_input_u, | ||
trained_model_dummy_input_l, | ||
trained_model_dummy_input_p, | ||
) | ||
|
||
# FPTLIB-TODO | ||
# If you want to save for inference on GPU uncomment the following 4 lines: | ||
# device = torch.device('cuda') | ||
# model = model.to(device) | ||
# model.eval() | ||
# dummy_input = dummy_input.to(device) | ||
|
||
# FPTLIB-TODO | ||
# Set the name of the file you want to save the torchscript model to | ||
saved_ts_filename = "saved_model.pt" | ||
|
||
# FPTLIB-TODO | ||
# Save the pytorch model using either scripting (recommended where possible) or tracing | ||
# ----------- | ||
# Scripting | ||
# ----------- | ||
script_to_torchscript(trained_model, filename=saved_ts_filename) | ||
|
||
# ----------- | ||
# Tracing | ||
# ----------- | ||
# trace_to_torchscript(trained_model, trained_model_dummy_input, filename=saved_ts_filename) | ||
|
||
# Load torchscript and run model as a test | ||
testing_input_u = 2.0 * trained_model_dummy_input_u | ||
Review comment: Maybe add a line explaining why the input is scaled by a factor of 2.
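    # (Presumably the factor of 2 is just to make the test inputs differ from the
    # dummy inputs used for the conversion above.)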
    testing_input_l = 2.0 * trained_model_dummy_input_l
    testing_input_p = 2.0 * trained_model_dummy_input_p
    trained_model_testing_output = trained_model(
        testing_input_u, testing_input_l, testing_input_p
    )
    ts_model = load_torchscript(filename=saved_ts_filename)
    ts_model_output = ts_model(testing_input_u, testing_input_l, testing_input_p)

    if torch.all(ts_model_output.eq(trained_model_testing_output)):
        print("Saved TorchScript model working as expected in a basic test.")
        print("Users should perform further validation as appropriate.")
    else:
        raise RuntimeError(
            "Saved TorchScript model is not performing as expected.\n"
            "Consider using scripting if you used tracing, or investigate further."
        )
@@ -0,0 +1,2 @@
torch

Review comment: Adding minimum required versions would be good?

numpy
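If version floors are added, they could take the usual requirement-specifier form; the specific numbers below are placeholders for illustration, not values taken from this PR:

torch>=1.12
numpy>=1.21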
@@ -0,0 +1,82 @@
""" | ||
Contains all python commands MiMA will use. | ||
|
||
It needs in the same directory as `wavenet.py` which describes the | ||
model architecture, and `wavenet_weights.pkl` which contains the model weights. | ||
""" | ||
|
||
from torch import load, device, no_grad, reshape, zeros, tensor, float64 | ||
import wavenet as m | ||
|
||
|
||
# Initialize everything | ||
def initialize(path_weights_stats="wavenet_weights.pkl"): | ||
""" | ||
Initialize a WaveNet model and load weights. | ||
|
||
Parameters | ||
---------- | ||
path_weights_stats : str | ||
path to pickled object that contains weights and statistics (means and stds). | ||
|
||
""" | ||
device_str = "cpu" | ||
checkpoint = load(path_weights_stats, map_location=device(device_str)) | ||
model = m.WaveNet(checkpoint).to(device_str) | ||
|
||
# Load weights and set to evaluation mode. | ||
model.load_state_dict(checkpoint["weights"]) | ||
model.eval() | ||
return model | ||
|
||
|
||
# Compute drag | ||
def compute_reshape_drag(*args): | ||
""" | ||
Compute the drag from inputs using a neural net. | ||
|
||
Takes in input arguments passed from MiMA and outputs drag in shape desired by MiMA. | ||
Reshaping & porting to torch.tensor type, and applying model.forward is performed. | ||
|
||
Parameters | ||
---------- | ||
model : nn.Module | ||
WaveNet model ready to be deployed. | ||
wind : | ||
U or V (128, num_col, 40) | ||
lat : | ||
latitudes (num_col) | ||
p_surf : | ||
surface pressure (128, num_col) | ||
Y_out : | ||
output prellocated in MiMA (128, num_col, 40) | ||
Review comment: spelling of prelocated
    num_col :
        number of latitudes on this proc

    Returns
    -------
    Y_out :
        Results to be returned to MiMA
    """
    model, wind, lat, p_surf, Y_out, num_col = args

    # Reshape and put all input variables together [wind, lat, p_surf]
    wind_T = tensor(wind)

    # lat_T = zeros((imax * num_col, 1), dtype=float64)
    lat_T = tensor(lat)

    # pressure_T = zeros((imax * num_col, 1), dtype=float64)
    pressure_T = tensor(p_surf)

    # Apply model.
    with no_grad():
        # Ensure evaluation mode (leave training mode and stop using current batch stats)
        # model.eval()  # Set during initialisation
        assert model.training is False
        temp = model(wind_T, lat_T, pressure_T)

    # Place in output array for MiMA.
    Y_out[:, :] = temp

    return Y_out
@@ -0,0 +1,112 @@
"""Module defining the pytorch WaveNet architecture for coupling to MiMA.""" | ||
|
||
import torch | ||
from torch import nn | ||
|
||
|
||
class WaveNet(nn.Module): | ||
"""Neural network architecture following Espinosa et al. (2022).""" | ||
|
||
def __init__( | ||
self, | ||
checkpoint, | ||
n_in: int = 42, | ||
n_out: int = 40, | ||
branch_dims=None, | ||
) -> None: | ||
""" | ||
Initialize a WaveNet model. | ||
|
||
Parameters | ||
---------- | ||
checkpoint: dict | ||
dictionary containing weights & statistics. | ||
n_in : int | ||
Number of input features. | ||
n_out : int | ||
Number of output features. | ||
branch_dims : Union[list, None] | ||
List of dimensions of the layers to include in each of the level-specific branches. | ||
|
||
""" | ||
if branch_dims is None: | ||
branch_dims = [64, 32] | ||
|
||
super().__init__() | ||
|
||
shared = [nn.BatchNorm1d(n_in), nn.Linear(n_in, 256), nn.ReLU()] | ||
for _ in range(4): | ||
shared.extend([nn.Linear(256, 256)]) | ||
shared.extend([nn.ReLU()]) | ||
|
||
shared.extend([nn.Linear(256, branch_dims[0])]) | ||
shared.extend([nn.ReLU()]) | ||
|
||
# All data gets fed through shared, then extra layers defined in branches for each z-level | ||
branches = [] | ||
for _ in range(n_out): | ||
args: list[nn.Module] = [] | ||
for in_features, out_features in zip(branch_dims[:-1], branch_dims[1:]): | ||
args.extend([nn.Linear(in_features, out_features)]) | ||
args.extend([nn.ReLU()]) | ||
|
||
args.extend([nn.Linear(branch_dims[-1], 1)]) | ||
branches.append(nn.Sequential(*args)) | ||
|
||
self.shared = nn.Sequential(*shared) | ||
self.branches = nn.ModuleList(branches) | ||
|
||
self.shared.apply(_xavier_init) | ||
for branch in self.branches: | ||
branch.apply(_xavier_init) | ||
|
||
self.double() | ||
self.means = checkpoint["means"] | ||
self.stds = checkpoint["stds"] | ||
del checkpoint | ||
|
||
def forward( | ||
self, wind: torch.Tensor, lat: torch.Tensor, pressure: torch.Tensor | ||
) -> torch.Tensor: | ||
""" | ||
Apply the network to a `Tensor` of input features. | ||
|
||
Parameters | ||
---------- | ||
wind : torch.Tensor | ||
Tensor of of input wind flattened to (n_lat*n_lon, 40). | ||
lat : torch.Tensor | ||
Tensor of of input features flattened to (n_lat*n_lon, 1). | ||
pressure : torch.Tensor | ||
Tensor of of input features flattened to (n_lat*n_lon, 1). | ||
|
||
Returns | ||
------- | ||
output : torch.Tensor | ||
Tensor of predicted outputs. | ||
|
||
""" | ||
Z, levels = self.shared(torch.cat((wind, lat, pressure), 1)), [] | ||
|
||
for branch in self.branches: | ||
levels.append(branch(Z).squeeze()) | ||
Y = torch.vstack(levels).T | ||
|
||
# Un-standardize | ||
Y *= self.stds | ||
Y += self.means | ||
return Y | ||
|
||
|
||
def _xavier_init(layer: nn.Module) -> None: | ||
""" | ||
Apply Xavier initialization to a layer if it is an `nn.Linear`. | ||
|
||
Parameters | ||
---------- | ||
layer : nn.Module | ||
Linear to potentially initialize. | ||
|
||
""" | ||
if isinstance(layer, nn.Linear): | ||
nn.init.xavier_uniform_(layer.weight) |
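As a quick sanity check of the architecture above, a minimal sketch that builds a WaveNet with freshly initialised weights and verifies the expected output shape. The checkpoint dict here is a hypothetical stand-in containing only the "means" and "stds" entries read by the constructor, and their shapes (one value per output level) are an assumption, not the structure of the real wavenet_weights.pkl:

import torch
from wavenet import WaveNet

# Hypothetical stand-in for the checkpoint: only the statistics used by __init__,
# assumed here to be one value per output level.
checkpoint = {
    "means": torch.zeros(40, dtype=torch.float64),
    "stds": torch.ones(40, dtype=torch.float64),
}

model = WaveNet(checkpoint)
model.eval()

# A batch of 8 flattened columns: wind (8, 40), lat (8, 1), pressure (8, 1).
wind = torch.randn(8, 40, dtype=torch.float64)
lat = torch.randn(8, 1, dtype=torch.float64)
pressure = torch.randn(8, 1, dtype=torch.float64)

with torch.no_grad():
    out = model(wind, lat, pressure)

print(out.shape)  # expected: torch.Size([8, 40])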
@@ -0,0 +1,18 @@
"""Script to test WaveNet NN.""" | ||
|
||
import numpy as np | ||
import run_wavenet as rwn | ||
|
||
|
||
IMAX = 128 | ||
NUM_COL = 4 | ||
|
||
# Generate the four input tensors and populate with random data | ||
wind = np.random.randn(IMAX * NUM_COL, 40) | ||
lat = np.random.randn(IMAX * NUM_COL, 1) | ||
ps = np.random.randn(IMAX * NUM_COL, 1) | ||
Y_out = np.zeros((IMAX * NUM_COL, 40)) | ||
|
||
# Initialise and run the model | ||
model = rwn.initialize() | ||
Y_out = rwn.compute_reshape_drag(model, wind, lat, ps, Y_out, NUM_COL) |
Review comment: You could mention the methods included for TorchScript conversion here (tracing and scripting), and that scripting is the recommended method.
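For reference, a minimal sketch of the two conversion routes provided by the conversion script above, using the helper functions and dummy inputs defined there; passing the three example inputs to tracing as a single tuple is an assumption based on torch.jit.trace accepting a tuple of example inputs:

# Scripting (recommended where possible): no example inputs are needed.
script_to_torchscript(trained_model, filename="saved_model.pt")

# Tracing: example inputs matching the model's forward signature are required.
# For this multi-input model they are passed as a tuple.
trace_to_torchscript(
    trained_model,
    (
        trained_model_dummy_input_u,
        trained_model_dummy_input_l,
        trained_model_dummy_input_p,
    ),
    filename="traced_model.pt",
)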