Skip to content

Commit

Permalink
torchvision tutorial: update deprecated pretrained=True to `weights…
Browse files Browse the repository at this point in the history
…="DEFAULT"` (#1998)
  • Loading branch information
YoniChechik authored Aug 10, 2022
1 parent d5f7a40 commit 770665f
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions intermediate_source/torchvision_tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ way of doing it:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
# load a model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
# replace the classifier with a new one, that has
# num_classes which is user-defined
Expand All @@ -242,7 +242,7 @@ way of doing it:
# load a pre-trained model for classification and return
# only the features
backbone = torchvision.models.mobilenet_v2(pretrained=True).features
backbone = torchvision.models.mobilenet_v2(weights="DEFAULT").features
# FasterRCNN needs to know the number of
# output channels in a backbone. For mobilenet_v2, it's 1280
# so we need to add it here
Expand Down Expand Up @@ -291,7 +291,7 @@ be using Mask R-CNN:
def get_model_instance_segmentation(num_classes):
# load an instance segmentation model pre-trained on COCO
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
Expand Down Expand Up @@ -344,7 +344,7 @@ expects during training and inference time on sample data.

.. code:: python
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=2, shuffle=True, num_workers=4,
Expand Down

0 comments on commit 770665f

Please sign in to comment.