Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
fix mnist-pytorch example (#1596)
Browse files Browse the repository at this point in the history
  • Loading branch information
chicm-ms authored Oct 10, 2019
1 parent f60bf1d commit b869dd4
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions examples/trials/mnist-pytorch/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
https://github.com/pytorch/examples/blob/master/mnist/main.py
"""

import os
import argparse
import logging
import nni
Expand Down Expand Up @@ -84,15 +85,18 @@ def main(args):
device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

data_dir = os.path.join(args['data_dir'], nni.get_trial_id())

train_loader = torch.utils.data.DataLoader(
datasets.MNIST(args['data_dir'], train=True, download=True,
datasets.MNIST(data_dir, train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args['batch_size'], shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(args['data_dir'], train=False, transform=transforms.Compose([
datasets.MNIST(data_dir, train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
Expand Down

0 comments on commit b869dd4

Please sign in to comment.