Skip to content

Commit 07635a9

Browse files
committed
fix issue when using cpu mode for testing
1 parent 8267567 commit 07635a9

File tree

4 files changed

+10
-4
lines changed

4 files changed

+10
-4
lines changed

joint_main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def main(config):
6161
parser.add_argument('--n_color', type=int, default=3)
6262
parser.add_argument('--lr', type=float, default=5e-5) # Learning rate resnet:5e-5, vgg:1e-4
6363
parser.add_argument('--wd', type=float, default=0.0005) # Weight decay
64-
parser.add_argument('--cuda', type=bool, default=True)
64+
parser.add_argument('--no-cuda', dest='cuda', action='store_false')
6565

6666
# Training settings
6767
parser.add_argument('--arch', type=str, default='resnet') # resnet or vgg

joint_solver.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ def __init__(self, train_loader, test_loader, config):
2525
self.build_model()
2626
if config.mode == 'test':
2727
print('Loading pre-trained model from %s...' % self.config.model)
28-
self.net.load_state_dict(torch.load(self.config.model))
28+
if self.config.cuda:
29+
self.net.load_state_dict(torch.load(self.config.model))
30+
else:
31+
self.net.load_state_dict(torch.load(self.config.model, map_location='cpu'))
2932
self.net.eval()
3033

3134
# print the network information and parameter numbers

main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def main(config):
5959
parser.add_argument('--n_color', type=int, default=3)
6060
parser.add_argument('--lr', type=float, default=5e-5) # Learning rate resnet:5e-5, vgg:1e-4
6161
parser.add_argument('--wd', type=float, default=0.0005) # Weight decay
62-
parser.add_argument('--cuda', type=bool, default=True)
62+
parser.add_argument('--no-cuda', dest='cuda', action='store_false')
6363

6464
# Training settings
6565
parser.add_argument('--arch', type=str, default='resnet') # resnet or vgg

solver.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ def __init__(self, train_loader, test_loader, config):
2525
self.build_model()
2626
if config.mode == 'test':
2727
print('Loading pre-trained model from %s...' % self.config.model)
28-
self.net.load_state_dict(torch.load(self.config.model))
28+
if self.config.cuda:
29+
self.net.load_state_dict(torch.load(self.config.model))
30+
else:
31+
self.net.load_state_dict(torch.load(self.config.model, map_location='cpu'))
2932
self.net.eval()
3033

3134
# print the network information and parameter numbers

0 commit comments

Comments
 (0)