Skip to content

Commit

Permalink
Fix linting on SLSA for models
Browse files Browse the repository at this point in the history
Signed-off-by: Mihai Maruseac <mihaimaruseac@google.com>
  • Loading branch information
mihaimaruseac committed Oct 24, 2023
1 parent c44908a commit 8b3d9cc
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 244 deletions.
36 changes: 18 additions & 18 deletions slsa_for_models/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,29 @@


def readOptions():
parser = argparse.ArgumentParser('Train CIFAR10 models with TF/PT')
model_formats = list(tf.supported_models().keys())
model_formats += list(pt.supported_models().keys())
parser.add_argument('model', choices=model_formats,
help='Model to generate (name implies framework)')
return parser.parse_args()
parser = argparse.ArgumentParser('Train CIFAR10 models with TF/PT')
model_formats = list(tf.supported_models().keys())
model_formats += list(pt.supported_models().keys())
parser.add_argument('model', choices=model_formats,
help='Model to generate (name implies framework)')
return parser.parse_args()


def main(args):
model_formats = list(tf.supported_models().keys())
for model_format in model_formats:
if args.model == model_format:
return tf.model_pipeline(args.model)
model_formats = list(tf.supported_models().keys())
for model_format in model_formats:
if args.model == model_format:
return tf.model_pipeline(args.model)

model_formats = list(pt.supported_models().keys())
for model_format in model_formats:
if args.model == model_format:
return pt.model_pipeline(args.model)
model_formats = list(pt.supported_models().keys())
for model_format in model_formats:
if args.model == model_format:
return pt.model_pipeline(args.model)

# we should not reach this case in the normal flow, but cover all corners
raise ValueError("Model format not supported")
# we should not reach this case in the normal flow, but cover all corners
raise ValueError("Model format not supported")


if __name__ == '__main__':
args = readOptions()
main(args)
args = readOptions()
main(args)
248 changes: 124 additions & 124 deletions slsa_for_models/pytorch_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,159 +25,159 @@


def pretraining():
"""Perform setup required before training.
Does the lazy loading of TensorFlow too, to prevent compatibility issues with
mixing TensorFlow and PyTorch imports.
"""
global torch
global nn
global F
global optim
global torchvision
global transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
"""Perform setup required before training.
Does the lazy loading of TensorFlow too, to prevent compatibility issues
with mixing TensorFlow and PyTorch imports.
"""
global torch
global nn
global F
global optim
global torchvision
global transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


def load_data():
"""Load the CIFAR10 data.
"""Load the CIFAR10 data.
Obtains both the train and the test splits. According to
https://www.cs.toronto.edu/~kriz/cifar.html, there should be 50000 training
images and 10000 test ones. Each image is 32x32 RGB.
Obtains both the train and the test splits. According to
https://www.cs.toronto.edu/~kriz/cifar.html, there should be 50000 training
images and 10000 test ones. Each image is 32x32 RGB.
Data is normalized to be in range [-1, 1].
Data is normalized to be in range [-1, 1].
Returns iterators to train and test sets.
"""
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
Returns iterators to train and test sets.
"""
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

batch_size = 4
num_workers = 2
batch_size = 4
num_workers = 2

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=True,
num_workers=num_workers)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
shuffle=True,
num_workers=num_workers)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=True,
num_workers=num_workers)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
shuffle=True,
num_workers=num_workers)

return trainloader, testloader
return trainloader, testloader


def create_model():
"""Create a Torch NN model.
The model is taken from the tutorial at
https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html.
Returns the model.
"""
# Train a model based on tutorial from
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html.
# We inline the class to be able to use lazy loading of PyTorch modules.
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)


def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

return MyModel()
"""Create a Torch NN model.
The model is taken from the tutorial at
https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html.
Returns the model.
"""
# Train a model based on tutorial from
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html.
# We inline the class to be able to use lazy loading of PyTorch modules.
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

return MyModel()


def prepare_model(model):
"""Prepare model for training with loss and optimizer."""
# We only need to return loss and optimizer
loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
return loss, optimizer
"""Prepare model for training with loss and optimizer."""
# We only need to return loss and optimizer
loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
return loss, optimizer


def train_model(model, loss, optimizer, train):
"""Train a model on the training set."""
num_epochs = 2
batch_size = 2000
for epoch in range(num_epochs):
running_loss = 0.0
for i, data in enumerate(train, 1):
x, y = data
optimizer.zero_grad()
outputs = model(x)
loss_score = loss(outputs, y)
loss_score.backward()
optimizer.step()
running_loss += loss_score.item()
if i % batch_size == 0:
print(f'[{epoch}, {i:5d}], loss: {running_loss / batch_size :.3f}')
"""Train a model on the training set."""
num_epochs = 2
batch_size = 2000
for epoch in range(num_epochs):
running_loss = 0.0
for i, data in enumerate(train, 1):
x, y = data
optimizer.zero_grad()
outputs = model(x)
loss_score = loss(outputs, y)
loss_score.backward()
optimizer.step()
running_loss += loss_score.item()
if i % batch_size == 0:
print(f'[{epoch}, {i:5d}], '
f'loss: {running_loss / batch_size :.3f}')
running_loss = 0.0


def score_model(model, test):
"""Score a trained model on the test set."""
correct = 0
total = 0
with torch.no_grad():
for data in test:
x, y = data
outputs = model(x)
_, predicted = torch.max(outputs.data, 1)
total += y.size(0)
correct += (predicted == y).sum().item()
print(f'Test accuracy: {correct / total}')
"""Score a trained model on the test set."""
correct = 0
total = 0
with torch.no_grad():
for data in test:
x, y = data
outputs = model(x)
_, predicted = torch.max(outputs.data, 1)
total += y.size(0)
correct += (predicted == y).sum().item()
print(f'Test accuracy: {correct / total}')


def supported_models():
"""Returns supported model types paired with method to save them."""
return {
'pytorch_model.pth': lambda m, p: torch.save(m.state_dict(), p),
'pytorch_full_model.pth': lambda m, p: torch.save(m, p),
'pytorch_jitted_model.pt': lambda m, p: torch.jit.script(m).save(p),
}
"""Returns supported model types paired with method to save them."""
return {
'pytorch_model.pth': lambda m, p: torch.save(m.state_dict(), p),
'pytorch_full_model.pth': lambda m, p: torch.save(m, p),
'pytorch_jitted_model.pt': lambda m, p: torch.jit.script(m).save(p),
}


def save_model(model, model_format):
"""Save the model after training to be transferred to production.
"""Save the model after training to be transferred to production.
Saves in the requested format, if supported by PyTorch.
"""
saver = supported_models().get(model_format, None)
if not saver:
raise ValueError('Requested a model format not supported by PyTorch')
saver(model, './' + model_format)
Saves in the requested format, if supported by PyTorch.
"""
saver = supported_models().get(model_format, None)
if not saver:
raise ValueError('Requested a model format not supported by PyTorch')
saver(model, './' + model_format)


def model_pipeline(model_format):
"""Train a model and save it in the requested format."""
pretraining()
data = load_data()
model = create_model()
loss, optimizer = prepare_model(model)
train_model(model, loss, optimizer, data[0])
score_model(model, data[1])
save_model(model, model_format)
"""Train a model and save it in the requested format."""
pretraining()
data = load_data()
model = create_model()
loss, optimizer = prepare_model(model)
train_model(model, loss, optimizer, data[0])
score_model(model, data[1])
save_model(model, model_format)
Loading

0 comments on commit 8b3d9cc

Please sign in to comment.