diff --git a/python/oneflow/framework/tensor_str.py b/python/oneflow/framework/tensor_str.py
index eaba9a96dbb..d4923df937a 100644
--- a/python/oneflow/framework/tensor_str.py
+++ b/python/oneflow/framework/tensor_str.py
@@ -339,6 +339,14 @@ def get_summarized_data(self):
     )
 
 
+def _format_tensor_on_cpu(tensor):
+    if tensor.is_global:
+        device = tensor.placement.type
+    else:
+        device = tensor.device.type
+    return device != "cpu" and device != "cuda"
+
+
 def _gen_tensor_str_template(tensor, is_meta):
     is_meta = is_meta or tensor.is_lazy
     prefix = "tensor("
@@ -349,10 +357,8 @@ def _gen_tensor_str_template(tensor, is_meta):
     if tensor.is_global:
         suffixes.append(f"placement={str(tensor.placement)}")
         suffixes.append(f"sbp={str(tensor.sbp)}")
-    elif tensor.device.type == "cuda":
-        suffixes.append("device='" + str(tensor.device) + "'")
     elif tensor.device.type != "cpu":
-        raise RunTimeError("unknow device type")
+        suffixes.append("device='" + str(tensor.device) + "'")
     if tensor.is_lazy:
         suffixes.append("is_lazy='True'")
 
@@ -366,7 +372,10 @@ def _gen_tensor_str_template(tensor, is_meta):
         tensor_str = "..."
         suffixes.append("size=" + str(tuple(tensor.shape)))
     else:
-        tensor_str = _tensor_str(tensor, indent)
+        if _format_tensor_on_cpu(tensor):
+            tensor_str = _tensor_str(tensor.detach().to("cpu"), indent)
+        else:
+            tensor_str = _tensor_str(tensor, indent)
     suffixes.append("dtype=" + str(tensor.dtype))
 
     if tensor.grad_fn is not None:
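
Reviewer note: below is a minimal, self-contained sketch of the dispatch rule introduced by _format_tensor_on_cpu, using SimpleNamespace stand-ins rather than real oneflow tensors; the "npu" device type is a hypothetical placeholder for any accelerator other than cpu/cuda. It is not part of the patch, only an illustration of the new behavior.

from types import SimpleNamespace

def _format_tensor_on_cpu(tensor):
    # Mirror of the helper added above: global tensors are classified by their
    # placement type, local tensors by their device type; only devices other
    # than "cpu" and "cuda" are routed through a CPU copy before formatting.
    if tensor.is_global:
        device = tensor.placement.type
    else:
        device = tensor.device.type
    return device != "cpu" and device != "cuda"

local_cuda = SimpleNamespace(is_global=False, device=SimpleNamespace(type="cuda"))
local_other = SimpleNamespace(is_global=False, device=SimpleNamespace(type="npu"))
global_cpu = SimpleNamespace(is_global=True, placement=SimpleNamespace(type="cpu"))

assert _format_tensor_on_cpu(local_cuda) is False   # printed directly, keeps device suffix
assert _format_tensor_on_cpu(local_other) is True   # detached and copied to cpu for printing
assert _format_tensor_on_cpu(global_cpu) is False

Previously, printing a local tensor whose device was neither "cpu" nor "cuda" raised an exception (via the removed RunTimeError branch); with this change the values are formatted from a CPU copy while the device suffix still reports the original device.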