Skip to content

Commit

Permalink
add mnist_rnn to test script for CI (pytorch#1086)
Browse files Browse the repository at this point in the history
* fix device mismatch issue pytorch#1071

* fix device mismatch issue pytorch#1071

* add mnist_rnn to test script for CI

* support dry_run in test()
  • Loading branch information
hudeven authored Oct 27, 2022
1 parent 5d4b584 commit 9aad148
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
14 changes: 11 additions & 3 deletions mnist_rnn/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import print_function

import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms


class Net(nn.Module):
Expand Down Expand Up @@ -51,9 +53,11 @@ def train(args, model, device, train_loader, optimizer, epoch):
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
if args.dry_run:
break


def test(model, device, test_loader):
def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
Expand All @@ -64,6 +68,8 @@ def test(model, device, test_loader):
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
if args.dry_run:
break

test_loss /= len(test_loader.dataset)

Expand All @@ -87,6 +93,8 @@ def main():
help='learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--dry-run', action='store_true', default=False,
help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
Expand Down Expand Up @@ -121,7 +129,7 @@ def main():
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
test(args, model, device, test_loader)
scheduler.step()

if args.save_model:
Expand Down
6 changes: 6 additions & 0 deletions run_python_examples.sh
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ function mnist_hogwild() {
python main.py --epochs 1 --dry-run $CUDA_FLAG || error "mnist hogwild failed"
}

function mnist_rnn() {
start
python main.py --epochs 1 --dry-run || error "mnist rnn example failed"
}

function regression() {
start
python main.py --epochs 1 $CUDA_FLAG || error "regression failed"
Expand Down Expand Up @@ -190,6 +195,7 @@ function run_all() {
imagenet
mnist
mnist_hogwild
mnist_rnn
regression
reinforcement_learning
siamese_network
Expand Down

0 comments on commit 9aad148

Please sign in to comment.