Skip to content

Commit 20cf5d2

Browse files
authored
support mixed precision training (#96)
1 parent 4512671 commit 20cf5d2

File tree

4 files changed

+107
-6
lines changed

4 files changed

+107
-6
lines changed

smdebug/tensorflow/keras.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
is_keras_optimizer,
4444
is_tf_version_2_3_x,
4545
is_tf_version_2x,
46+
supported_tf_variables,
4647
)
4748

4849
python_profiler = None
@@ -263,7 +264,7 @@ def _create_tensors_for_matching_collections(
263264
tensor_refs = []
264265
for coll in colls_with_tensor:
265266
if not tensor_refs:
266-
if isinstance(tensor, tf.Variable):
267+
if isinstance(tensor, supported_tf_variables()):
267268
tensor_refs.append(
268269
coll.add_variable(tensor, export_name=export_name, mode=mode)
269270
)
@@ -481,7 +482,7 @@ def save_gradients_from_logs(self, gradients):
481482
if isinstance(v, tf.Tensor):
482483
# Tensor.name is meaningless with eager execution
483484
layer_name = str(v.numpy(), "utf-8")
484-
elif isinstance(v, tf.Variable):
485+
elif isinstance(v, supported_tf_variables()):
485486
layer_name = v.name
486487
elif isinstance(v, bytes):
487488
layer_name = str(v, "utf-8")
@@ -861,6 +862,13 @@ def _save_layer_values(self, logs):
861862
for layer_name, layer_input, layer_output in logs:
862863
# Cast layer_name to str since it can also be of type bytes
863864
# when run with mirrored strategy
865+
if isinstance(layer_name, tf.Tensor):
866+
# Tensor.name is meaningless with eager execution
867+
layer_name = str(layer_name.numpy(), "utf-8")
868+
elif isinstance(layer_name, supported_tf_variables()):
869+
layer_name = layer_name.name
870+
elif isinstance(layer_name, bytes):
871+
layer_name = str(layer_name, "utf-8")
864872
if len(layer_input) == 1:
865873
# Layer Inputs are flattened and passed as a list into
866874
# the next layer. Unpacking it speeds up the _make_numpy fn.
@@ -1102,7 +1110,12 @@ def run(*args, **kwargs):
11021110
if (
11031111
(not grads or not vars)
11041112
or (not isinstance(grads, list) or not isinstance(vars, list))
1105-
or (not ((isinstance(vars[0], tf.Variable)) and hasattr(vars[0], "numpy")))
1113+
or (
1114+
not (
1115+
(isinstance(vars[0], supported_tf_variables()))
1116+
and hasattr(vars[0], "numpy")
1117+
)
1118+
)
11061119
or (not ((isinstance(grads[0], tf.Tensor)) and hasattr(grads[0], "numpy")))
11071120
):
11081121
return grads

smdebug/tensorflow/tensor_ref.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
from smdebug.core.logger import get_logger
1010

1111
# Local
12-
from .utils import is_tf_version_2x
12+
from .utils import is_tf_version_2x, supported_tf_variables
1313

1414
logger = get_logger()
1515

1616

1717
def get_tf_names(arg):
18-
if isinstance(arg, tf.Variable):
18+
if isinstance(arg, supported_tf_variables()):
1919
tf_names = [arg.name]
2020
elif isinstance(arg, tf.Tensor):
2121
tf_names = [arg.name]
@@ -101,7 +101,11 @@ def from_variable(cls, variable, export_name=None, mode=None, original_tensor=No
101101
# for mirrored variable value this will be the mirrored variable
102102
original_tensor = variable
103103

104-
if is_tf_version_2x() and tf.executing_eagerly() and isinstance(variable, tf.Variable):
104+
if (
105+
is_tf_version_2x()
106+
and tf.executing_eagerly()
107+
and isinstance(variable, supported_tf_variables())
108+
):
105109
# In TF 2.X eager mode, TF throws an error if you try to access a tensor's name.
106110
# We need to pass it in as a variable, not a tensor, to maintain the name.
107111
tf_obj = variable

smdebug/tensorflow/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,21 @@
1313
from smdebug.core.modes import ModeKeys
1414

1515

16+
def does_tf_support_mixed_precision_training():
    """Report whether the installed TensorFlow is at least version 2.1.0.

    The Keras mixed precision API is first available in TensorFlow 2.1.0.
    See: https://www.tensorflow.org/guide/mixed_precision

    Returns:
        The result of comparing the parsed ``tf.__version__`` against the
        parsed minimum release ``"2.1.0"``.
    """
    minimum_release = version.parse("2.1.0")
    installed_release = version.parse(tf.__version__)
    return installed_release >= minimum_release
20+
21+
22+
def supported_tf_variables():
    """Return the variable type(s) to use in ``isinstance`` checks.

    Returns:
        ``(tf.Variable, AutoCastVariable)`` when the installed TensorFlow
        supports mixed precision training, otherwise ``tf.Variable`` alone.
        Either form can be passed directly as the second argument of
        ``isinstance``.

    Raises:
        ImportError: if mixed precision is reported as supported but the
            ``autocast_variable`` module cannot be found at either of the
            known import locations.
    """
    if not does_tf_support_mixed_precision_training():
        return tf.Variable

    # Imported lazily so that TensorFlow releases without the mixed precision
    # API never attempt to load the module.
    try:
        from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
    except ImportError:
        # NOTE(review): newer TensorFlow releases (2.4+, to be confirmed against
        # the release notes) ship this module outside the `experimental`
        # package; fall back to that location instead of failing.
        from tensorflow.python.keras.mixed_precision import autocast_variable

    return tf.Variable, autocast_variable.AutoCastVariable
29+
30+
1631
class ModelOutput:
1732
LABELS = "smdebug_y"
1833
PREDICTIONS = "smdebug_y_pred"
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Third Party
2+
import pytest
3+
import tensorflow as tf
4+
from tensorflow import keras
5+
from tensorflow.keras import layers
6+
7+
# First Party
8+
import smdebug.tensorflow as smd
9+
from smdebug.tensorflow.utils import does_tf_support_mixed_precision_training
10+
from smdebug.trials import create_trial
11+
12+
# Test Reference: https://github.com/tensorflow/docs/blob/master/site/en/guide/mixed_precision.ipynb
13+
14+
15+
@pytest.mark.skipif(
    does_tf_support_mixed_precision_training() is False,
    reason="The Keras mixed precision API is first available in TensorFlow 2.1.0",
)
def test_mixed_precision_training(out_dir):
    """Train a small Keras model under the ``mixed_float16`` policy with a
    ``KerasHook`` attached and check how many tensors the hook saved.

    Args:
        out_dir: directory the hook writes to and the trial reads from
            (supplied by the test fixture of the same name).
    """
    from tensorflow.keras.mixed_precision import experimental as mixed_precision

    hook = smd.KerasHook(out_dir=out_dir, save_all=True)

    # set_policy changes the process-wide policy. Remember the current one so
    # it can be restored and does not leak into tests that run afterwards.
    previous_policy = mixed_precision.global_policy()
    mixed_precision.set_policy(mixed_precision.Policy("mixed_float16"))
    try:
        inputs = keras.Input(shape=(784,), name="digits")
        if tf.config.list_physical_devices("GPU"):
            # The model will run with 4096 units on a GPU
            num_units = 4096
        else:
            # Use fewer units on CPUs so the model finishes in a reasonable amount of time
            # The model will run with 64 units on a CPU
            num_units = 64
        dense1 = layers.Dense(num_units, activation="relu", name="dense_1")
        x = dense1(inputs)
        dense2 = layers.Dense(num_units, activation="relu", name="dense_2")
        x = dense2(x)

        # CORRECT: softmax and model output are float32
        x = layers.Dense(10, name="dense_logits")(x)
        outputs = layers.Activation("softmax", dtype="float32", name="predictions")(x)

        # The linear activation is an identity function. So this simply casts 'outputs'
        # to float32. In this particular case, 'outputs' is already float32 so this is a
        # no-op.
        outputs = layers.Activation("linear", dtype="float32")(outputs)

        model = keras.Model(inputs=inputs, outputs=outputs)
        model.compile(
            loss="sparse_categorical_crossentropy",
            optimizer=keras.optimizers.RMSprop(),
            metrics=["accuracy"],
        )

        (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
        x_train = x_train.reshape(60000, 784).astype("float32") / 255
        x_test = x_test.reshape(10000, 784).astype("float32") / 255

        # Only the fit call has the hook attached; the return values of fit and
        # evaluate are not needed for the assertion below.
        model.fit(
            x_train, y_train, batch_size=8192, epochs=5, callbacks=[hook], validation_split=0.2
        )
        model.evaluate(x_test, y_test, verbose=2)
    finally:
        mixed_precision.set_policy(previous_policy)

    trial = create_trial(out_dir)
    assert len(trial.tensor_names()) == 30

0 commit comments

Comments
 (0)