Commit 0b7e62c

[API compatibility] add dtype conversion method (#74416)

Squashed commit messages:

* add tensor.bool
* add tensor.bool test
* update tensor.bool test
* update tensor.bool
* delete useless test code
* add dtype conversions method
* skip the complex128 test on XPU
* fix: Implement the byte function separately
* fix: update unit testing

1 parent 11c7cdd, commit 0b7e62c

File tree: 3 files changed, +386 -0 lines changed
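Taken together, the two diffs below give Paddle tensors PyTorch-style dtype aliases: every key in `_supported_dtype_conversions` becomes a method that casts to the mapped dtype, and `byte()`/`uint8()` get a hand-written implementation. A minimal usage sketch, assuming a Paddle build that includes this commit (the printed dtypes follow from the conversion table):

```python
import paddle

x = paddle.to_tensor([1.5, 2.5, 3.5])  # float32 by default

# Generated alias methods: thin wrappers around astype().
print(x.double().dtype)  # paddle.float64
print(x.half().dtype)    # paddle.float16
print(x.long().dtype)    # paddle.int64
print(x.bool().dtype)    # paddle.bool

# byte()/uint8() is special-cased: float (and complex) inputs are
# routed through int8 before the final cast to uint8.
print(x.byte().dtype)    # paddle.uint8
```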

python/paddle/base/dygraph/math_op_patch.py (73 additions, 0 deletions)
```diff
@@ -65,6 +65,36 @@
 _already_patch_eager_tensor = False
 
 
+_supported_dtype_conversions = {
+    # float
+    'float16': 'float16',
+    'half': 'float16',
+    'bfloat16': 'bfloat16',
+    'float32': 'float32',
+    'float': 'float32',
+    'float64': 'float64',
+    'double': 'float64',
+    # int
+    'int8': 'int8',
+    'char': 'int8',
+    # uint8 conversion is handled separately by byte()
+    # 'uint8': 'uint8',
+    # 'byte': 'uint8',
+    'int16': 'int16',
+    'short': 'int16',
+    'int32': 'int32',
+    'int': 'int32',
+    'int64': 'int64',
+    'long': 'int64',
+    # other
+    'bool': 'bool',
+    'complex64': 'complex64',
+    'complex128': 'complex128',
+    'cfloat': 'complex64',
+    'cdouble': 'complex128',
+}
+
+
 def monkey_patch_math_tensor():
     """
     Similar to monkey_patch_variable.
@@ -104,6 +134,44 @@ def astype(self: Tensor, dtype: DTypeLike) -> Tensor:
 
         return _C_ops.cast(self, dtype)
 
+    def byte(self: Tensor) -> Tensor:
+        # Paddle doesn't support casting float directly to uint8, so cast to int8 first
+        if self.is_floating_point():
+            tensor = astype(self, 'int8')
+            return astype(tensor, 'uint8')
+        elif self.is_complex():
+            real = astype(self.real(), 'int8')
+            return astype(real, 'uint8')
+        else:
+            return astype(self, 'uint8')
+
+    def _create_dtype_conversion_methods():
+        """
+        Batch-create all data type conversion methods.
+        """
+        methods = []
+
+        for method_name, target_dtype in _supported_dtype_conversions.items():
+
+            def make_conversion_method(dtype):
+                def conversion_method(self: Tensor) -> Tensor:
+                    return astype(self, dtype)
+
+                return conversion_method
+
+            method_impl = make_conversion_method(target_dtype)
+            method_impl.__name__ = method_name
+            method_impl.__doc__ = f"""
+            Cast a Tensor to {target_dtype} data type if it differs from the current dtype;
+            otherwise, return the original Tensor.
+            Returns:
+                Tensor: a new Tensor with {target_dtype} dtype
+            """
+
+            methods.append((method_name, method_impl))
+
+        return methods
+
     def _scalar_elementwise_op_(
         var: Tensor, scale: float, bias: float
     ) -> Tensor:
@@ -225,6 +293,8 @@ def _mT_(var: Tensor) -> Tensor:
         ('__len__', _len_),
         ('__index__', _index_),
         ('astype', astype),
+        ('byte', byte),
+        ('uint8', byte),
         ('dim', dim),
         ('ndimension', ndimension),
         ('ndim', _ndim),
@@ -235,6 +305,9 @@ def _mT_(var: Tensor) -> Tensor:
         ('__array_ufunc__', None),
     ]
 
+    dtype_conversion_methods = _create_dtype_conversion_methods()
+    eager_methods.extend(dtype_conversion_methods)
+
     eager_cpp_level_patch = [
         "__add__",
         "__radd__",
```

python/paddle/pir/math_op_patch.py (68 additions, 0 deletions)
```diff
@@ -37,6 +37,35 @@
     DataType.INT64,
 ]
 
+_supported_dtype_conversions = {
+    # float
+    'float16': 'float16',
+    'half': 'float16',
+    'bfloat16': 'bfloat16',
+    'float32': 'float32',
+    'float': 'float32',
+    'float64': 'float64',
+    'double': 'float64',
+    # int
+    'int8': 'int8',
+    'char': 'int8',
+    # uint8 conversion is handled separately by byte()
+    # 'uint8': 'uint8',
+    # 'byte': 'uint8',
+    'int16': 'int16',
+    'short': 'int16',
+    'int32': 'int32',
+    'int': 'int32',
+    'int64': 'int64',
+    'long': 'int64',
+    # other
+    'bool': 'bool',
+    'complex64': 'complex64',
+    'complex128': 'complex128',
+    'cfloat': 'complex64',
+    'cdouble': 'complex128',
+}
+
 SUPPORT_PROMOTION_OPS = [
     "__add__",
     "__radd__",
@@ -370,6 +399,41 @@ def astype(self, dtype):
 
         return _C_ops.cast(self, dtype)
 
+    def byte(self):
+        # Paddle doesn't support casting float directly to uint8, so cast to int8 first
+        if self.is_floating_point():
+            tensor = astype(self, 'int8')
+            return astype(tensor, 'uint8')
+        elif self.is_complex():
+            real = astype(self.real(), 'int8')
+            return astype(real, 'uint8')
+        else:
+            return astype(self, 'uint8')
+
+    def _create_dtype_conversion_methods():
+        """
+        Batch-create all data type conversion methods.
+        """
+        methods = []
+        for method_name, target_dtype in _supported_dtype_conversions.items():
+
+            def make_conversion_method(dtype):
+                def conversion_method(self):
+                    return astype(self, dtype)
+
+                return conversion_method
+
+            method_impl = make_conversion_method(target_dtype)
+            method_impl.__name__ = method_name
+            method_impl.__doc__ = f"""
+            Cast a Value to {target_dtype} data type if it differs from the current dtype;
+            otherwise, return the original Value.
+            Returns:
+                Value: a new Value with {target_dtype} dtype
+            """
+            methods.append((method_name, method_impl))
+        return methods
+
     def _scalar_add_(var, value):
         return paddle.scale(var, 1.0, value)
 
@@ -1109,6 +1173,8 @@ def register_hook(self, hook):
         ('ndimension', ndimension),
         ('ndim', _ndim),
         ('astype', astype),
+        ('byte', byte),
+        ('uint8', byte),
         ('size', _size_),
         ('T', _T_),
         ('mT', _mT_),
@@ -1253,6 +1319,8 @@ def register_hook(self, hook):
         ('__bool__', _bool_),
         ('__complex__', _complex_),
     ]
+    dtype_conversion_methods = _create_dtype_conversion_methods()
+    value_methods.extend(dtype_conversion_methods)
 
     global _already_patch_value
     if not _already_patch_value:
```
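The pir variant patches the same names onto `Value`, so the aliases are also usable while building a static graph. A minimal sketch, assuming static (pir) mode is active; `paddle.static.data` and `program_guard` are standard Paddle APIs, not part of this commit:

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')
    # Patched Value methods; each one lowers to a cast op in the graph.
    y = x.double()   # Value with dtype float64
    z = x.bool()     # Value with dtype bool
    print(y.dtype, z.dtype)
```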
