-
Notifications
You must be signed in to change notification settings - Fork 83
Zcc mxnet test #31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Zcc mxnet test #31
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
3502ce4
Adding integration test for zero code change.
leleamol bd0a5a5
Merge branch 'master' of https://github.com/awslabs/sagemaker-debugge…
leleamol 8a36d9b
Ran the pre-commit
leleamol f1ef582
Addressed the review comments
leleamol 468bd9c
Merge branch 'master' of https://github.com/awslabs/sagemaker-debugge…
leleamol b5ae41b
Merge branch 'master' of https://github.com/awslabs/sagemaker-debugge…
leleamol d6fa284
Added the assert and optional validation which should be used only wh…
leleamol bd6843d
updated the code to use tensor_names
leleamol a0f0278
Merge branch 'master' of https://github.com/awslabs/sagemaker-debugge…
leleamol 4910fbb
Updated the assert to compare the loss values.
leleamol File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
# Standard Library | ||
import argparse | ||
import random | ||
|
||
# Third Party | ||
import mxnet as mx | ||
import numpy as np | ||
from mxnet import autograd, gluon | ||
from mxnet.gluon import nn | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser( | ||
description="Train a mxnet gluon model for FashonMNIST dataset" | ||
) | ||
parser.add_argument("--batch-size", type=int, default=256, help="Batch size") | ||
parser.add_argument("--epochs", type=int, default=1, help="Number of Epochs") | ||
parser.add_argument("--learning_rate", type=float, default=0.1) | ||
parser.add_argument( | ||
"--context", type=str, default="cpu", help="Context can be either cpu or gpu" | ||
) | ||
parser.add_argument( | ||
"--validate", type=bool, default=True, help="Run validation if running with smdebug" | ||
) | ||
|
||
opt = parser.parse_args() | ||
return opt | ||
|
||
|
||
def test(ctx, net, val_data): | ||
metric = mx.metric.Accuracy() | ||
for i, (data, label) in enumerate(val_data): | ||
data = data.as_in_context(ctx) | ||
label = label.as_in_context(ctx) | ||
output = net(data) | ||
metric.update([label], [output]) | ||
|
||
return metric.get() | ||
|
||
|
||
def train_model(net, epochs, ctx, learning_rate, momentum, train_data, val_data): | ||
# Collect all parameters from net and its children, then initialize them. | ||
net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx) | ||
# Trainer is for updating parameters with gradient. | ||
trainer = gluon.Trainer( | ||
net.collect_params(), "sgd", {"learning_rate": learning_rate, "momentum": momentum} | ||
) | ||
metric = mx.metric.Accuracy() | ||
loss = gluon.loss.SoftmaxCrossEntropyLoss() | ||
|
||
for epoch in range(epochs): | ||
# reset data iterator and metric at begining of epoch. | ||
metric.reset() | ||
for i, (data, label) in enumerate(train_data): | ||
# Copy data to ctx if necessary | ||
data = data.as_in_context(ctx) | ||
label = label.as_in_context(ctx) | ||
# Start recording computation graph with record() section. | ||
# Recorded graphs can then be differentiated with backward. | ||
with autograd.record(): | ||
output = net(data) | ||
L = loss(output, label) | ||
L.backward() | ||
# take a gradient step with batch_size equal to data.shape[0] | ||
trainer.step(data.shape[0]) | ||
# update metric at last. | ||
metric.update([label], [output]) | ||
|
||
if i % 100 == 0 and i > 0: | ||
name, acc = metric.get() | ||
print("[Epoch %d Batch %d] Training: %s=%f" % (epoch, i, name, acc)) | ||
|
||
name, acc = metric.get() | ||
print("[Epoch %d] Training: %s=%f" % (epoch, name, acc)) | ||
name, val_acc = test(ctx, net, val_data) | ||
print("[Epoch %d] Validation: %s=%f" % (epoch, name, val_acc)) | ||
|
||
|
||
def transformer(data, label): | ||
data = data.reshape((-1,)).astype(np.float32) / 255 | ||
return data, label | ||
|
||
|
||
def prepare_data(batch_size): | ||
train_data = gluon.data.DataLoader( | ||
gluon.data.vision.MNIST("/tmp", train=True, transform=transformer), | ||
batch_size=batch_size, | ||
shuffle=True, | ||
last_batch="discard", | ||
) | ||
|
||
val_data = gluon.data.DataLoader( | ||
gluon.data.vision.MNIST("/tmp", train=False, transform=transformer), | ||
batch_size=batch_size, | ||
shuffle=False, | ||
) | ||
return train_data, val_data | ||
|
||
|
||
# Create a model using gluon API. The hook is currently | ||
# supports MXNet gluon models only. | ||
def create_gluon_model(): | ||
net = nn.Sequential() | ||
with net.name_scope(): | ||
net.add(nn.Dense(128, activation="relu")) | ||
net.add(nn.Dense(64, activation="relu")) | ||
net.add(nn.Dense(10)) | ||
return net | ||
|
||
|
||
def validate(): | ||
try: | ||
from smdebug.trials import create_trial | ||
from smdebug.mxnet import get_hook | ||
|
||
hook = get_hook() | ||
out_dir = hook.out_dir | ||
print("Created the trial with out_dir {0}".format(out_dir)) | ||
tr = create_trial(out_dir) | ||
global_steps = tr.steps() | ||
print("Global steps: " + str(global_steps)) | ||
|
||
loss_tensor_name = tr.tensor_names(regex="softmaxcrossentropyloss._output_.")[0] | ||
print("Obtained the loss tensor " + loss_tensor_name) | ||
assert loss_tensor_name == "softmaxcrossentropyloss0_output_0" | ||
|
||
mean_loss_tensor_value_first_step = tr.tensor(loss_tensor_name).reduction_value( | ||
step_num=global_steps[0], reduction_name="mean", abs=False | ||
) | ||
|
||
mean_loss_tensor_value_last_step = tr.tensor(loss_tensor_name).reduction_value( | ||
step_num=global_steps[-1], reduction_name="mean", abs=False | ||
) | ||
|
||
print("Mean validation loss first step = " + str(mean_loss_tensor_value_first_step)) | ||
print("Mean validation loss last step = " + str(mean_loss_tensor_value_last_step)) | ||
assert mean_loss_tensor_value_first_step >= mean_loss_tensor_value_last_step | ||
|
||
except ImportError: | ||
print("smdebug libraries do not exist. Skipped Validation.") | ||
|
||
print("Validation Complete") | ||
|
||
|
||
def main(): | ||
opt = parse_args() | ||
mx.random.seed(128) | ||
random.seed(12) | ||
np.random.seed(2) | ||
|
||
context = mx.cpu() if opt.context.lower() == "cpu" else mx.gpu() | ||
# Create a Gluon Model. | ||
net = create_gluon_model() | ||
|
||
# Start the training. | ||
train_data, val_data = prepare_data(opt.batch_size) | ||
|
||
train_model( | ||
net=net, | ||
epochs=opt.epochs, | ||
ctx=context, | ||
learning_rate=opt.learning_rate, | ||
momentum=0.9, | ||
train_data=train_data, | ||
val_data=val_data, | ||
) | ||
if opt.validate: | ||
validate() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.