From 770665f504accac0e7244abd39f5066f9cdc77ab Mon Sep 17 00:00:00 2001 From: Yoni Chechik Date: Wed, 10 Aug 2022 20:11:22 +0300 Subject: [PATCH] torchvision tutorial: update deprecated `pretrained=True` to `weights="DEFAULT"` (#1998) --- intermediate_source/torchvision_tutorial.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/intermediate_source/torchvision_tutorial.rst b/intermediate_source/torchvision_tutorial.rst index d58d88daec..9e3d1b9655 100644 --- a/intermediate_source/torchvision_tutorial.rst +++ b/intermediate_source/torchvision_tutorial.rst @@ -221,7 +221,7 @@ way of doing it: from torchvision.models.detection.faster_rcnn import FastRCNNPredictor # load a model pre-trained on COCO - model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT") # replace the classifier with a new one, that has # num_classes which is user-defined @@ -242,7 +242,7 @@ way of doing it: # load a pre-trained model for classification and return # only the features - backbone = torchvision.models.mobilenet_v2(pretrained=True).features + backbone = torchvision.models.mobilenet_v2(weights="DEFAULT").features # FasterRCNN needs to know the number of # output channels in a backbone. For mobilenet_v2, it's 1280 # so we need to add it here @@ -291,7 +291,7 @@ be using Mask R-CNN: def get_model_instance_segmentation(num_classes): # load an instance segmentation model pre-trained on COCO - model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) + model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT") # get number of input features for the classifier in_features = model.roi_heads.box_predictor.cls_score.in_features @@ -344,7 +344,7 @@ expects during training and inference time on sample data. .. code:: python - model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT") dataset = PennFudanDataset('PennFudanPed', get_transform(train=True)) data_loader = torch.utils.data.DataLoader( dataset, batch_size=2, shuffle=True, num_workers=4,