From 338dda954d12e76a094d5724533680ddcaa47d7a Mon Sep 17 00:00:00 2001 From: atalman Date: Wed, 1 Mar 2023 06:10:34 -0800 Subject: [PATCH] Apply cuda fix --- test/smoke_test.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/test/smoke_test.py b/test/smoke_test.py index 3701f661add..31af998487e 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -4,6 +4,7 @@ from pathlib import Path import torch +import torch.nn as nn import torchvision from torchvision.io import read_image from torchvision.models import resnet50, ResNet50_Weights @@ -27,10 +28,9 @@ def smoke_test_torchvision_read_decode() -> None: raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}") def smoke_test_compile() -> None: - import torch.nn as nn model = resnet50().cuda() model = torch.compile(model) - x = torch.randn(1, 3, 224, 224).cuda() + x = torch.randn(1, 3, 224, 224, device="cuda") out = model(x) print(out.shape) @@ -66,12 +66,10 @@ def main() -> None: smoke_test_torchvision_resnet50_classify() if torch.cuda.is_available(): smoke_test_torchvision_resnet50_classify("cuda") -<<<<<<< HEAD + smoke_test_compile() if torch.backends.mps.is_available(): smoke_test_torchvision_resnet50_classify("mps") -======= - smoke_test_compile() ->>>>>>> 2b8667d9a4 (Add smoke test Using a simple RN50 with torch.compile) + if __name__ == "__main__":