From 43eaaa352d7109616672f80299052fa90ea08f3c Mon Sep 17 00:00:00 2001 From: chongxiaoc <74630762+chongxiaoc@users.noreply.github.com> Date: Sat, 23 Oct 2021 13:28:50 -0700 Subject: [PATCH] check torch version for mixed precision example (#3238) --- examples/pytorch/pytorch_mnist.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/pytorch/pytorch_mnist.py b/examples/pytorch/pytorch_mnist.py index ae31a2f426..0cece73bd1 100644 --- a/examples/pytorch/pytorch_mnist.py +++ b/examples/pytorch/pytorch_mnist.py @@ -1,5 +1,6 @@ import argparse import os +from distutils.version import LooseVersion from filelock import FileLock import torch.multiprocessing as mp @@ -158,6 +159,10 @@ def test(): if args.use_mixed_precision: raise ValueError("Mixed precision is only supported with cuda enabled.") + if (args.use_mixed_precision and LooseVersion(torch.__version__) + < LooseVersion('1.6.0')): + raise ValueError("""Mixed precision is using torch.cuda.amp.autocast(), + which requires torch >= 1.6.0""") # Horovod: limit # of CPU threads to be used per worker. torch.set_num_threads(1)