From b1dada9724e01584fb71906a6feaeb8f25d83c03 Mon Sep 17 00:00:00 2001 From: Denis Davydenko Date: Tue, 24 Dec 2019 17:45:23 -0800 Subject: [PATCH] Fix for MXNet integ test --- .../mxnet_gluon_integration_test.py | 40 ++++++++++++++----- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/tests/zero_code_change/mxnet_gluon_integration_test.py b/tests/zero_code_change/mxnet_gluon_integration_test.py index 8ec312cc1..7fbe11b0c 100644 --- a/tests/zero_code_change/mxnet_gluon_integration_test.py +++ b/tests/zero_code_change/mxnet_gluon_integration_test.py @@ -8,6 +8,9 @@ from mxnet import autograd, gluon from mxnet.gluon import nn +# First Party +from smdebug.core.utils import SagemakerSimulator + def parse_args(): parser = argparse.ArgumentParser( @@ -155,17 +158,32 @@ def main(): # Start the training. train_data, val_data = prepare_data(opt.batch_size) - train_model( - net=net, - epochs=opt.epochs, - ctx=context, - learning_rate=opt.learning_rate, - momentum=0.9, - train_data=train_data, - val_data=val_data, - ) - if opt.validate: - validate() + json_file_contents = """ + { + "S3OutputPath": "s3://sagemaker-test", + "LocalPath": "/tmp/mxnet_integ_test", + "CollectionConfigurations": [ + { + "CollectionName": "losses", + "CollectionParameters": { + "save_interval": 100 + } + } + ] + } + """ + with SagemakerSimulator(json_file_contents=json_file_contents) as sim: + train_model( + net=net, + epochs=opt.epochs, + ctx=context, + learning_rate=opt.learning_rate, + momentum=0.9, + train_data=train_data, + val_data=val_data, + ) + if opt.validate: + validate() if __name__ == "__main__":