From 90517e53c716632188de2d8978b158f95ac9102e Mon Sep 17 00:00:00 2001 From: NihalHarish Date: Thu, 22 Oct 2020 20:52:30 -0700 Subject: [PATCH] pytorch tmp --- tests/zero_code_change/pt_utils.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/zero_code_change/pt_utils.py b/tests/zero_code_change/pt_utils.py index 3d6cb78de..9bb1c073e 100644 --- a/tests/zero_code_change/pt_utils.py +++ b/tests/zero_code_change/pt_utils.py @@ -7,6 +7,7 @@ import torch.nn.functional as F import torchvision import torchvision.transforms as transforms +from packaging import version def get_dataloaders() -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: @@ -14,15 +15,26 @@ def get_dataloaders() -> Tuple[torch.utils.data.DataLoader, torch.utils.data.Dat [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] ) + # Temporary Change to allow the test to run with pytorch 1.7 RC3 + # Smdebug breaks when num_workers>0 for Pytorch 1.7.0 + if version.parse(torch.__version__) >= version.parse("1.7.0"): + num_workers = 0 + else: + num_workers = 2 + trainset = torchvision.datasets.CIFAR10( root="./data", train=True, download=True, transform=transform ) - trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) + trainloader = torch.utils.data.DataLoader( + trainset, batch_size=4, shuffle=True, num_workers=num_workers + ) testset = torchvision.datasets.CIFAR10( root="./data", train=False, download=True, transform=transform ) - testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2) + testloader = torch.utils.data.DataLoader( + testset, batch_size=4, shuffle=False, num_workers=num_workers + ) classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck") return trainloader, testloader