From 12cf2f17df40741ebd8eb610c2ab117c5ad64386 Mon Sep 17 00:00:00 2001 From: Alessandro Pappalardo Date: Tue, 15 Oct 2019 15:21:37 +0100 Subject: [PATCH] Fix loading for .pth checkpoints --- examples/imagenet_val.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/imagenet_val.py b/examples/imagenet_val.py index 1643b1a26..b8036f6f9 100644 --- a/examples/imagenet_val.py +++ b/examples/imagenet_val.py @@ -51,7 +51,7 @@ def main(): # Map model to be loaded to specified single gpu. loc = 'cuda:{}'.format(args.gpu) checkpoint = torch.load(args.resume, map_location=loc) - model.load_state_dict(checkpoint['state_dict'], strict=False) + model.load_state_dict(checkpoint, strict=False) valdir = os.path.join(args.imagenet_dir, 'val') normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])