Skip to content

Commit d166bd0

Browse files
authored
#fix: training deepseek-distill-qwen-1.5b in bfloat16 fails because NumPy 1.26 does not support bfloat16 (#2029)
1 parent 7a0885f commit d166bd0

File tree

1 file changed

+86
-2
lines changed

1 file changed

+86
-2
lines changed

mindnlp/core/ops/other.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from .reduction import any
1111
from .comparison import eq
1212

13+
from mindspore._c_expression import typing
14+
from mindspore._c_expression.typing import Type
15+
1316
# atleast_2d
1417

1518

@@ -698,8 +701,89 @@ def masked_fill(input, mask, value):
698701
masked_fill_ = _get_cache_prim(ops.MaskedFill)()
699702
return masked_fill_(input, mask, mindspore.tensor(value, dtype=input.dtype))
700703

701-
def finfo(dtype):
702-
return np.finfo(mindspore.dtype_to_nptype(dtype))
704+
def is_complex(weight):
    """Return True if ``weight.dtype`` is complex64 or complex128."""
    complex_dtypes = (mindspore.complex64, mindspore.complex128)
    return weight.dtype in complex_dtypes
706+
707+
# Short aliases for the MindSpore type objects that key the lookup tables
# below. The limits are hard-coded instead of being queried from NumPy so
# that bfloat16 can be described as well.
dtype = Type
float16 = typing.kFloat16
float32 = typing.kFloat32
bfloat16 = typing.kBFloat16

# Storage width of each format, in bits.
# (This table used to be empty, so every ``finfo(...).bits`` lookup raised
# KeyError.)
bits_map = {
    float32: 32,
    float16: 16,
    bfloat16: 16,
}

# Most negative finite value.
min_map = {
    float32: -3.40282e+38,
    float16: -65504,
    bfloat16: -3.38953e+38,
}

# Largest finite value.
max_map = {
    float32: 3.40282e+38,
    float16: 65504,
    bfloat16: 3.38953e+38,
}

# Machine epsilon: gap between 1.0 and the next representable value.
eps_map = {
    float32: 1.19209e-07,
    float16: 0.000976562,
    bfloat16: 0.0078125,
}

# Smallest positive normal value (same numbers as ``smallest_normal_map``).
tiny_map = {
    float32: 1.17549e-38,
    float16: 6.10352e-05,
    bfloat16: 1.17549e-38,
}

# Smallest positive normal value.
smallest_normal_map = {
    float32: 1.17549e-38,
    float16: 6.10352e-05,
    bfloat16: 1.17549e-38,
}

# Approximate decimal resolution of each format.
resolution_map = {
    float32: 1e-06,
    float16: 0.001,
    bfloat16: 0.01,
}
751+
752+
class finfo:
    """Floating-point limits for a MindSpore dtype.

    Values for dtypes listed in the module-level ``*_map`` tables are read
    from those tables. Any other dtype (for example float64) is delegated to
    ``np.finfo`` via ``mindspore.dtype_to_nptype`` instead of failing with a
    bare ``KeyError``.

    Lookups are lazy: constructing the object never raises; an unsupported
    dtype only surfaces an error when one of the properties is read.

    Args:
        dtype: the MindSpore dtype to describe.
    """

    def __init__(self, dtype):
        self._dtype = dtype

    def _lookup(self, table, attr):
        """Return ``table[dtype]``, or the numpy ``finfo`` attribute ``attr``
        when the dtype has no entry in ``table``."""
        try:
            return table[self._dtype]
        except KeyError:
            # Not one of the hard-coded formats: let numpy describe it.
            np_info = np.finfo(mindspore.dtype_to_nptype(self._dtype))
            return getattr(np_info, attr)

    @property
    def bits(self):
        """Number of bits occupied by the type."""
        return self._lookup(bits_map, "bits")

    @property
    def min(self):
        """Most negative finite value."""
        return self._lookup(min_map, "min")

    @property
    def max(self):
        """Largest finite value."""
        return self._lookup(max_map, "max")

    @property
    def eps(self):
        """Gap between 1.0 and the next representable value."""
        return self._lookup(eps_map, "eps")

    @property
    def tiny(self):
        """Smallest positive normal value."""
        return self._lookup(tiny_map, "tiny")

    @property
    def smallest_normal(self):
        """Smallest positive normal value."""
        return self._lookup(smallest_normal_map, "smallest_normal")

    @property
    def resolution(self):
        """Approximate decimal resolution of the type."""
        return self._lookup(resolution_map, "resolution")

    @property
    def dtype(self):
        """String form of the dtype this object describes."""
        return str(self._dtype)
703787

704788
def iinfo(dtype):
    """Return numpy's ``iinfo`` for the numpy type matching ``dtype``.

    ``dtype`` is converted with ``mindspore.dtype_to_nptype`` before being
    handed to ``np.iinfo``.
    """
    np_dtype = mindspore.dtype_to_nptype(dtype)
    return np.iinfo(np_dtype)

0 commit comments

Comments
 (0)