Skip to content

Commit dc12f68

Browse files
NihalHarishsophiayue1116
authored andcommitted
Cache TF Versions (awslabs#421)
1 parent fef3357 commit dc12f68

File tree

4 files changed

+22
-12
lines changed

4 files changed

+22
-12
lines changed

smdebug/pytorch/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from smdebug.core.reduction_config import ALLOWED_NORMS, ALLOWED_REDUCTIONS
88
from smdebug.core.reductions import get_numpy_reduction
99

10+
# Cached Pytorch Version
11+
PT_VERSION = version.parse(torch.__version__)
12+
1013

1114
def get_reduction_of_data(reduction_name, tensor_data, tensor_name, abs=False):
1215
if isinstance(tensor_data, np.ndarray):
@@ -42,20 +45,20 @@ def is_pt_1_5():
4245
Determine whether the version of torch is 1.5.x
4346
:return: bool
4447
"""
45-
return version.parse("1.5.0") <= version.parse(torch.__version__) < version.parse("1.6.0")
48+
return version.parse("1.5.0") <= PT_VERSION < version.parse("1.6.0")
4649

4750

4851
def is_pt_1_6():
4952
"""
5053
Determine whether the version of torch is 1.6.x
5154
:return: bool
5255
"""
53-
return version.parse("1.6.0") <= version.parse(torch.__version__) < version.parse("1.7.0")
56+
return version.parse("1.6.0") <= PT_VERSION < version.parse("1.7.0")
5457

5558

5659
def is_pt_1_7():
5760
"""
5861
Determine whether the version of torch is 1.7.x
5962
:return: bool
6063
"""
61-
return version.parse("1.7.0") <= version.parse(torch.__version__) < version.parse("1.8.0")
64+
return version.parse("1.7.0") <= PT_VERSION < version.parse("1.8.0")

smdebug/tensorflow/session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .base_hook import TensorflowBaseHook
1414
from .tensor_ref import TensorType
1515
from .utils import (
16+
TF_VERSION,
1617
TFDistributionStrategy,
1718
build_fetches_tuple,
1819
extract_graph_summary,
@@ -241,7 +242,7 @@ def _is_not_supported(self):
241242
if self.distribution_strategy == TFDistributionStrategy.MIRRORED:
242243
from packaging import version
243244

244-
if version.parse(tf.__version__) < version.parse("1.14.0"):
245+
if TF_VERSION < version.parse("1.14.0"):
245246
self._hook_supported = False
246247
# in tf 1.13, we can't support mirrored strategy as
247248
# MirroredVariable does not have _values attribute

smdebug/tensorflow/utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
# First Party
1414
from smdebug.core.modes import ModeKeys
1515

16+
# Cached TF Version
17+
TF_VERSION = version.parse(tf.__version__)
18+
1619

1720
def does_tf_support_mixed_precision_training():
1821
# The Keras mixed precision API is first available in TensorFlow 2.1.0
1922
# See: https://www.tensorflow.org/guide/mixed_precision
20-
return version.parse(tf.__version__) >= version.parse("2.1.0")
23+
return TF_VERSION >= version.parse("2.1.0")
2124

2225

2326
def supported_tf_variables():
@@ -405,23 +408,23 @@ def get_keras_mode(mode):
405408

406409

407410
def is_tf_version_2x():
408-
return version.parse(tf.__version__) >= version.parse("2.0.0")
411+
return TF_VERSION >= version.parse("2.0.0")
409412

410413

411414
def is_tf_version_2_2_x():
412-
return version.parse("2.2.0") <= version.parse(tf.__version__) < version.parse("2.3.0")
415+
return version.parse("2.2.0") <= TF_VERSION < version.parse("2.3.0")
413416

414417

415418
def is_tf_version_2_3_x():
416-
return version.parse("2.3.0") <= version.parse(tf.__version__) < version.parse("2.4.0")
419+
return version.parse("2.3.0") <= TF_VERSION < version.parse("2.4.0")
417420

418421

419422
def is_tf_version_2_4_x():
420-
return version.parse("2.4.0") <= version.parse(tf.__version__) < version.parse("2.5.0")
423+
return version.parse("2.4.0") <= TF_VERSION < version.parse("2.5.0")
421424

422425

423426
def is_tf_version_greater_than_2_4_x():
424-
return version.parse("2.4.0") <= version.parse(tf.__version__)
427+
return version.parse("2.4.0") <= TF_VERSION
425428

426429

427430
def is_profiler_supported_for_tf_version():

tests/tensorflow2/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import tensorflow.compat.v2 as tf
44
from packaging import version
55

6+
# Cached TF Version
7+
TF_VERSION = version.parse(tf.__version__)
8+
69

710
def is_tf_2_2():
811
"""
@@ -12,12 +15,12 @@ def is_tf_2_2():
1215
number of tensor_names emitted by 1.
1316
:return: bool
1417
"""
15-
if version.parse(tf.__version__) >= version.parse("2.2.0"):
18+
if TF_VERSION >= version.parse("2.2.0"):
1619
return True
1720
return False
1821

1922

2023
def is_tf_2_3():
21-
if version.parse(tf.__version__) == version.parse("2.3.0"):
24+
if TF_VERSION == version.parse("2.3.0"):
2225
return True
2326
return False

0 commit comments

Comments
 (0)