Skip to content

Commit

Permalink
fix special case
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqiu committed Jan 25, 2021
1 parent c940239 commit 7e656f3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 28 deletions.
38 changes: 16 additions & 22 deletions python/paddle/fluid/data_feeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,25 @@
from .framework import _cpu_num, _cuda_ids
__all__ = ['DataFeeder']

_PADDLE_DTYPE_2_NUMPY_DTYPE = {
core.VarDesc.VarType.BOOL: 'bool',
core.VarDesc.VarType.FP16: 'float16',
core.VarDesc.VarType.FP32: 'float32',
core.VarDesc.VarType.FP64: 'float64',
core.VarDesc.VarType.INT8: 'int8',
core.VarDesc.VarType.INT16: 'int16',
core.VarDesc.VarType.INT32: 'int32',
core.VarDesc.VarType.INT64: 'int64',
core.VarDesc.VarType.UINT8: 'uint8',
core.VarDesc.VarType.COMPLEX64: 'complex64',
core.VarDesc.VarType.COMPLEX128: 'complex128',
}


def convert_dtype(dtype):
if isinstance(dtype, core.VarDesc.VarType):
if dtype == core.VarDesc.VarType.BOOL:
return 'bool'
elif dtype == core.VarDesc.VarType.FP16:
return 'float16'
elif dtype == core.VarDesc.VarType.FP32:
return 'float32'
elif dtype == core.VarDesc.VarType.FP64:
return 'float64'
elif dtype == core.VarDesc.VarType.INT8:
return 'int8'
elif dtype == core.VarDesc.VarType.INT16:
return 'int16'
elif dtype == core.VarDesc.VarType.INT32:
return 'int32'
elif dtype == core.VarDesc.VarType.INT64:
return 'int64'
elif dtype == core.VarDesc.VarType.UINT8:
return 'uint8'
elif dtype == core.VarDesc.VarType.COMPLEX64:
return 'complex64'
elif dtype == core.VarDesc.VarType.COMPLEX128:
return 'complex128'
if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
return _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
elif isinstance(dtype, type):
if dtype in [
np.bool, np.float16, np.float32, np.float64, np.int8, np.int16,
Expand Down
18 changes: 12 additions & 6 deletions python/paddle/fluid/dygraph/varbase_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .base import switch_to_static_graph
from .math_op_patch import monkey_patch_math_varbase
from .parallel import scale_loss
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE


def monkey_patch_varbase():
Expand Down Expand Up @@ -320,13 +320,19 @@ def __bool__(self):
("__name__", "Tensor")):
setattr(core.VarBase, method_name, method)

def dtype_str(dtype):
prefix = 'paddle.'
return prefix + convert_dtype(dtype)

# NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
# So, we need to overwrite it to custom one.
# So, we need to overwrite it to a more readable one.
# See details in https://github.com/pybind/pybind11/issues/2537.
origin = getattr(core.VarDesc.VarType, "__repr__")

def dtype_str(dtype):
if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
prefix = 'paddle.'
return prefix + _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
else:
# for example, paddle.fluid.core.VarDesc.VarType.LOD_TENSOR
return origin(dtype)

setattr(core.VarDesc.VarType, "__repr__", dtype_str)

# patch math methods for varbase
Expand Down

1 comment on commit 7e656f3

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.