Skip to content

Commit

Permalink
Format tensor on cpu (#8548)
Browse files Browse the repository at this point in the history
* Format tensor on cpu

* use tensor.detach
  • Loading branch information
liujuncheng authored Jul 4, 2022
1 parent 66027d0 commit d2e40b4
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions python/oneflow/framework/tensor_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,14 @@ def get_summarized_data(self):
)


def _format_tensor_on_cpu(tensor):
if tensor.is_global:
device = tensor.placement.type
else:
device = tensor.device.type
return device != "cpu" and device != "cuda"


def _gen_tensor_str_template(tensor, is_meta):
is_meta = is_meta or tensor.is_lazy
prefix = "tensor("
Expand All @@ -349,10 +357,8 @@ def _gen_tensor_str_template(tensor, is_meta):
if tensor.is_global:
suffixes.append(f"placement={str(tensor.placement)}")
suffixes.append(f"sbp={str(tensor.sbp)}")
elif tensor.device.type == "cuda":
suffixes.append("device='" + str(tensor.device) + "'")
elif tensor.device.type != "cpu":
raise RunTimeError("unknow device type")
suffixes.append("device='" + str(tensor.device) + "'")
if tensor.is_lazy:
suffixes.append("is_lazy='True'")

Expand All @@ -366,7 +372,10 @@ def _gen_tensor_str_template(tensor, is_meta):
tensor_str = "..."
suffixes.append("size=" + str(tuple(tensor.shape)))
else:
tensor_str = _tensor_str(tensor, indent)
if _format_tensor_on_cpu(tensor):
tensor_str = _tensor_str(tensor.detach().to("cpu"), indent)
else:
tensor_str = _tensor_str(tensor, indent)

suffixes.append("dtype=" + str(tensor.dtype))
if tensor.grad_fn is not None:
Expand Down

0 comments on commit d2e40b4

Please sign in to comment.