diff --git a/references/classification/train.py b/references/classification/train.py index 96703bfdf85..14360b042ed 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -221,7 +221,7 @@ def main(args): ) print("Creating model") - model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes) + model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes) model.to(device) if args.distributed and args.sync_bn: diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py index a66a47f8674..ed36e13a028 100644 --- a/references/classification/train_quantization.py +++ b/references/classification/train_quantization.py @@ -46,7 +46,11 @@ def main(args): print("Creating model", args.model) # when training quantized models, we always start from a pre-trained fp32 reference model - model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only) + prefix = "quantized_" + model_name = args.model + if not model_name.startswith(prefix): + model_name = prefix + model_name + model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only) model.to(device) if not (args.test_only or args.post_training_quantize): diff --git a/references/detection/train.py b/references/detection/train.py index 178f7460417..dea483c5f75 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -216,8 +216,8 @@ def main(args): if "rcnn" in args.model: if args.rpn_score_thresh is not None: kwargs["rpn_score_thresh"] = args.rpn_score_thresh - model = torchvision.models.detection.__dict__[args.model]( - weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs + model = torchvision.models.get_model( + args.model, weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs ) model.to(device) if args.distributed and args.sync_bn: diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 0327d92bdf9..be6ffe4ccef 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -215,7 +215,7 @@ def main(args): else: torch.backends.cudnn.benchmark = True - model = torchvision.models.optical_flow.__dict__[args.model](weights=args.weights) + model = torchvision.models.get_model(args.model, weights=args.weights) if args.distributed: model = model.to(args.local_rank) diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 95dfedb5e9a..0169a6ab43e 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -156,8 +156,12 @@ def main(args): dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn ) - model = torchvision.models.segmentation.__dict__[args.model]( - weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, aux_loss=args.aux_loss + model = torchvision.models.get_model( + args.model, + weights=args.weights, + weights_backbone=args.weights_backbone, + num_classes=num_classes, + aux_loss=args.aux_loss, ) model.to(device) if args.distributed: diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 4da8331a1c6..9dff282d4f1 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -246,7 +246,7 @@ def main(args): ) print("Creating model") - model = torchvision.models.video.__dict__[args.model](weights=args.weights) + model = torchvision.models.get_model(args.model, weights=args.weights) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)