Skip to content

Commit

Permalink
check torch version for mixed precision example (horovod#3238)
Browse files Browse the repository at this point in the history
chongxiaoc authored Oct 23, 2021

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent e3328ce commit 43eaaa3
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions examples/pytorch/pytorch_mnist.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import os
from distutils.version import LooseVersion
from filelock import FileLock

import torch.multiprocessing as mp
@@ -158,6 +159,10 @@ def test():
if args.use_mixed_precision:
raise ValueError("Mixed precision is only supported with cuda enabled.")

if (args.use_mixed_precision and LooseVersion(torch.__version__)
< LooseVersion('1.6.0')):
raise ValueError("""Mixed precision is using torch.cuda.amp.autocast(),
which requires torch >= 1.6.0""")

# Horovod: limit # of CPU threads to be used per worker.
torch.set_num_threads(1)

0 comments on commit 43eaaa3

Please sign in to comment.