Skip to content

Commit

Permalink
Apply cuda fix
Browse files Browse the repository at this point in the history
  • Loading branch information
atalman committed Mar 1, 2023
1 parent ee468d6 commit 338dda9
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions test/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path

import torch
import torch.nn as nn
import torchvision
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
Expand All @@ -27,10 +28,9 @@ def smoke_test_torchvision_read_decode() -> None:
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")

def smoke_test_compile() -> None:
import torch.nn as nn
model = resnet50().cuda()
model = torch.compile(model)
x = torch.randn(1, 3, 224, 224).cuda()
x = torch.randn(1, 3, 224, 224, device="cuda")
out = model(x)
print(out.shape)

Expand Down Expand Up @@ -66,12 +66,10 @@ def main() -> None:
smoke_test_torchvision_resnet50_classify()
if torch.cuda.is_available():
smoke_test_torchvision_resnet50_classify("cuda")
<<<<<<< HEAD
smoke_test_compile()
if torch.backends.mps.is_available():
smoke_test_torchvision_resnet50_classify("mps")
=======
smoke_test_compile()
>>>>>>> 2b8667d9a4 (Add smoke test Using a simple RN50 with torch.compile)



if __name__ == "__main__":
Expand Down

0 comments on commit 338dda9

Please sign in to comment.