From 4e07d5e7186626dbc56f5a6d63c5dc259f9eb9d8 Mon Sep 17 00:00:00 2001 From: Mesh TensorFlow Team Date: Sun, 21 Mar 2021 13:51:06 -0700 Subject: [PATCH] Change default checkpoint saving dtype to float32 instead of bfloat16. Saving PiperOrigin-RevId: 364201025 --- mesh_tensorflow/transformer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py index d1cd220f..ee929080 100644 --- a/mesh_tensorflow/transformer/utils.py +++ b/mesh_tensorflow/transformer/utils.py @@ -87,7 +87,7 @@ def parse_gin_defaults_and_flags(skip_unknown=False, finalize_config=True): # this stupid VariableDtype class and stop passing it all over creation. @gin.configurable def get_variable_dtype( - master_dtype=tf.float32, + master_dtype=tf.bfloat16, slice_dtype=tf.float32, activation_dtype=tf.float32): """Datatypes to use for the run.