diff --git a/mesh_tensorflow/tpu_variables.py b/mesh_tensorflow/tpu_variables.py index e5b1ba2d..01ee6d3f 100644 --- a/mesh_tensorflow/tpu_variables.py +++ b/mesh_tensorflow/tpu_variables.py @@ -24,7 +24,7 @@ # pylint: disable=g-direct-tensorflow-import from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_conversion_registry +from tensorflow.python.framework import tensor_conversion from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_resource_variable_ops @@ -224,7 +224,7 @@ def _tensor_conversion(var, dtype=None, name=None, as_ref=False): return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access -tensor_conversion_registry.register_tensor_conversion_function( +tensor_conversion.register_tensor_conversion_function( ReplicatedVariable, _tensor_conversion) if not TF_23: