Skip to content

Commit

Permalink
Update old pretrained TorchVision API in ao tutorials (pytorch#313)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#313

For TorchVision models, pretrained parameters have been deprecated in favor of "Multi-weight support API" - see https://pytorch.org/vision/0.15/models.html

Differential Revision: D58117114
  • Loading branch information
kit1980 authored and facebook-github-bot committed Jun 4, 2024
1 parent d75f450 commit 308af03
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
5 changes: 3 additions & 2 deletions tutorials/quantize_vit/run_vit_b.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import torch
import torchvision.models.vision_transformer as models

from torchao.utils import benchmark_model, profiler_runner
from torchvision import models

torch.set_float32_matmul_precision("high")
# Load Vision Transformer model
model = models.vit_b_16(pretrained=True)
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)

# Set the model to evaluation mode
model.eval().cuda().to(torch.bfloat16)
Expand Down
5 changes: 3 additions & 2 deletions tutorials/quantize_vit/run_vit_b_quant.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import torch
import torchao
import torchvision.models.vision_transformer as models

from torchao.utils import benchmark_model, profiler_runner
from torchvision import models

torch.set_float32_matmul_precision("high")
# Load Vision Transformer model
model = models.vit_b_16(pretrained=True)
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)

# Set the model to evaluation mode
model.eval().cuda().to(torch.bfloat16)
Expand Down

0 comments on commit 308af03

Please sign in to comment.