diff --git a/test/smoke_test.py b/test/smoke_test.py index e8ee178d95e..9ffc9117773 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -4,6 +4,7 @@ from pathlib import Path import torch +import torch.nn as nn import torchvision from torchvision.io import read_image from torchvision.models import resnet50, ResNet50_Weights @@ -26,6 +27,12 @@ def smoke_test_torchvision_read_decode() -> None: if img_png.ndim != 3 or img_png.numel() < 100: raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}") +def smoke_test_compile() -> None: + model = resnet50().cuda() + model = torch.compile(model) + x = torch.randn(1, 3, 224, 224, device="cuda") + out = model(x) + print(f"torch.compile model output: {out.shape}") def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None: img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device) @@ -54,14 +61,18 @@ def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None: def main() -> None: print(f"torchvision: {torchvision.__version__}") + print(f"torch.cuda.is_available: {torch.cuda.is_available()}") smoke_test_torchvision() smoke_test_torchvision_read_decode() smoke_test_torchvision_resnet50_classify() if torch.cuda.is_available(): smoke_test_torchvision_resnet50_classify("cuda") + smoke_test_compile() + if torch.backends.mps.is_available(): smoke_test_torchvision_resnet50_classify("mps") + if __name__ == "__main__": main()