Skip to content

Commit

Permalink
Update references to use the new Model Registration API (#6369)
Browse files Browse the repository at this point in the history
* Expose on Hub the public methods of the registration API

* Limit methods and update docs.

* Update references to use the new Model Registration API
  • Loading branch information
datumbox authored Aug 8, 2022
1 parent c72b284 commit 1d0786b
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 8 deletions.
2 changes: 1 addition & 1 deletion references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def main(args):
)

print("Creating model")
model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes)
model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
model.to(device)

if args.distributed and args.sync_bn:
Expand Down
6 changes: 5 additions & 1 deletion references/classification/train_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ def main(args):

print("Creating model", args.model)
# when training quantized models, we always start from a pre-trained fp32 reference model
model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
prefix = "quantized_"
model_name = args.model
if not model_name.startswith(prefix):
model_name = prefix + model_name
model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
model.to(device)

if not (args.test_only or args.post_training_quantize):
Expand Down
4 changes: 2 additions & 2 deletions references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ def main(args):
if "rcnn" in args.model:
if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
model = torchvision.models.detection.__dict__[args.model](
weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs
model = torchvision.models.get_model(
args.model, weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs
)
model.to(device)
if args.distributed and args.sync_bn:
Expand Down
2 changes: 1 addition & 1 deletion references/optical_flow/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def main(args):
else:
torch.backends.cudnn.benchmark = True

model = torchvision.models.optical_flow.__dict__[args.model](weights=args.weights)
model = torchvision.models.get_model(args.model, weights=args.weights)

if args.distributed:
model = model.to(args.local_rank)
Expand Down
8 changes: 6 additions & 2 deletions references/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,12 @@ def main(args):
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
)

model = torchvision.models.segmentation.__dict__[args.model](
weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, aux_loss=args.aux_loss
model = torchvision.models.get_model(
args.model,
weights=args.weights,
weights_backbone=args.weights_backbone,
num_classes=num_classes,
aux_loss=args.aux_loss,
)
model.to(device)
if args.distributed:
Expand Down
2 changes: 1 addition & 1 deletion references/video_classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def main(args):
)

print("Creating model")
model = torchvision.models.video.__dict__[args.model](weights=args.weights)
model = torchvision.models.get_model(args.model, weights=args.weights)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
Expand Down

0 comments on commit 1d0786b

Please sign in to comment.