From 9aad148615b7519eadfa1a60356116a50561f192 Mon Sep 17 00:00:00 2001 From: Steven Liu Date: Thu, 27 Oct 2022 14:31:20 -0700 Subject: [PATCH] add mnist_rnn to test script for CI (#1086) * fix device mismatch issue #1071 * fix device mismatch issue #1071 * add mnist_rnn to test script for CI * support dry_run in test() --- mnist_rnn/main.py | 14 +++++++++++--- run_python_examples.sh | 6 ++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/mnist_rnn/main.py b/mnist_rnn/main.py index 4778574b09..57e86cd57d 100644 --- a/mnist_rnn/main.py +++ b/mnist_rnn/main.py @@ -1,11 +1,13 @@ from __future__ import print_function + import argparse + import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim -from torchvision import datasets, transforms from torch.optim.lr_scheduler import StepLR +from torchvision import datasets, transforms class Net(nn.Module): @@ -51,9 +53,11 @@ def train(args, model, device, train_loader, optimizer, epoch): print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item())) + if args.dry_run: + break -def test(model, device, test_loader): +def test(args, model, device, test_loader): model.eval() test_loss = 0 correct = 0 @@ -64,6 +68,8 @@ def test(model, device, test_loader): test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item() + if args.dry_run: + break test_loss /= len(test_loader.dataset) @@ -87,6 +93,8 @@ def main(): help='learning rate step gamma (default: 0.7)') parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training') + parser.add_argument('--dry-run', action='store_true', default=False, + help='quickly check a single pass') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument('--log-interval', type=int, default=10, metavar='N', @@ -121,7 +129,7 @@ def main(): scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) for epoch in range(1, args.epochs + 1): train(args, model, device, train_loader, optimizer, epoch) - test(model, device, test_loader) + test(args, model, device, test_loader) scheduler.step() if args.save_model: diff --git a/run_python_examples.sh b/run_python_examples.sh index 8f0777b27e..cc7c67b7ef 100755 --- a/run_python_examples.sh +++ b/run_python_examples.sh @@ -99,6 +99,11 @@ function mnist_hogwild() { python main.py --epochs 1 --dry-run $CUDA_FLAG || error "mnist hogwild failed" } +function mnist_rnn() { + start + python main.py --epochs 1 --dry-run || error "mnist rnn example failed" +} + function regression() { start python main.py --epochs 1 $CUDA_FLAG || error "regression failed" @@ -190,6 +195,7 @@ function run_all() { imagenet mnist mnist_hogwild + mnist_rnn regression reinforcement_learning siamese_network