From 5d16bca0e2bd0e18bd907930de4e0d928f9a85e4 Mon Sep 17 00:00:00 2001 From: Mesh TensorFlow Team Date: Thu, 18 Mar 2021 15:22:30 -0700 Subject: [PATCH] Change default checkpoint saving dtype to float32 instead of bfloat16. Saving PiperOrigin-RevId: 363758862 --- mesh_tensorflow/transformer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py index ee929080..d1cd220f 100644 --- a/mesh_tensorflow/transformer/utils.py +++ b/mesh_tensorflow/transformer/utils.py @@ -87,7 +87,7 @@ def parse_gin_defaults_and_flags(skip_unknown=False, finalize_config=True): # this stupid VariableDtype class and stop passing it all over creation. @gin.configurable def get_variable_dtype( - master_dtype=tf.bfloat16, + master_dtype=tf.float32, slice_dtype=tf.float32, activation_dtype=tf.float32): """Datatypes to use for the run.