Skip to content

Commit

Permalink
add floatimm range check for fp16 and fp32
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif committed Aug 24, 2022
1 parent 5e5d84d commit bcf0ac2
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 25 deletions.
7 changes: 6 additions & 1 deletion include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
if (t.is_uint()) {
// Use IntImm if it is a small integer
uint64_t uval = static_cast<uint64_t>(value);
if (static_cast<int64_t>(value) < 0) {
if (value < static_cast<ValueType>(0)) {
LOG(FATAL) << "cannot make uint from negative value " << value;
} else if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
return IntImm(t, static_cast<int64_t>(value), span);
Expand All @@ -934,6 +934,11 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
return PrimExpr();
}

// Specialization for bool literals: lower a C++ bool onto the integer
// constant path so it is materialized the same way as 0/1 integers.
template <>
inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) {
  const int as_int = value ? 1 : 0;
  return MakeConstScalar(t, as_int, span);
}

template <typename ValueType, typename>
inline PrimExpr make_const(DataType t, ValueType value, Span span) {
if (t.lanes() == 1) {
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/runtime/object_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,14 @@ def _scalar_type_inference(value):
elif isinstance(value, bool):
dtype = "bool"
elif isinstance(value, float):
# We intentionally convert the float to float32 since it's more common in DL.
dtype = "float32"
# We intentionally prefer to convert the float to float32 since it is more common in DL.
if -3.40282347e38 <= value <= 3.40282347e38:
dtype = "float32"
else:
dtype = "float64"
elif isinstance(value, int):
# We intentionally prefer to convert the Python int to int32 since it is more common in DL.
if -2147483648 <= value < 2147483648:
if -2147483648 <= value <= 2147483647:
dtype = "int32"
else:
dtype = "int64"
Expand Down
17 changes: 17 additions & 0 deletions src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
#include <tvm/te/tensor.h>
#include <tvm/tir/expr.h>

#include "../support/scalars.h"

namespace tvm {

PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {}
Expand Down Expand Up @@ -116,6 +118,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

FloatImm::FloatImm(DataType dtype, double value, Span span) {
ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar.";

// check range for float32 and float16 since they have specified range.
if (!std::isinf(value) && !std::isnan(value)) {
if (dtype.bits() == 32) {
ICHECK_GE(value, std::numeric_limits<float>::lowest())
<< "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
ICHECK_LE(value, std::numeric_limits<float>::max())
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
} else if (dtype.is_float16()) {
ICHECK_GE(value, -support::kMaxFloat16)
<< "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
ICHECK_LE(value, support::kMaxFloat16)
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
}
}
ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
node->dtype = dtype;
node->value = value;
Expand Down
4 changes: 0 additions & 4 deletions src/support/scalars.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,6 @@ IntImm ValueToIntImm(int64_t value, int width) {
}
}

// 2^15 * (1 + 1023/1024)
// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format
constexpr double kMaxFloat16 = 65504.0;

FloatImm ValueToFloatImm(double value, int width) {
if (width == 16) {
if (!std::isinf(value) && (value < -kMaxFloat16 || value > kMaxFloat16)) {
Expand Down
4 changes: 4 additions & 0 deletions src/support/scalars.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ std::string FloatImmToString(const FloatImm& float_imm);
IntImm ValueToIntImm(int64_t value, int width);
FloatImm ValueToFloatImm(double value, int width);

// 2^15 * (1 + 1023/1024)
// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format
constexpr double kMaxFloat16 = 65504.0;

} // namespace support
} // namespace tvm

Expand Down
70 changes: 53 additions & 17 deletions tests/python/unittest/test_tir_imm_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,35 @@ def compare_float_value(value, expect):
assert math.isinf(expect)


@pytest.mark.parametrize(
    "dtype, literals",
    [
        ["float16", [-65504.0, 3.14, 65504.0, np.inf, np.nan]],
        ["bfloat16", [-3.38953139e38, 3.38953139e38, 3.14]],
        ["float32", [np.finfo("float32").min, 3.14, np.finfo("float32").max, np.inf, np.nan]],
        ["float64", [np.finfo("float64").min, 3.14, np.finfo("float64").max, np.inf, np.nan]],
    ],
)
def test_tir_make_floatimm(dtype, literals):
    """In-range (and inf/nan) literals build a FloatImm whose value round-trips."""
    for literal in literals:
        node = tir.const(literal, dtype)
        compare_float_value(node.value, literal)


@pytest.mark.parametrize(
    "dtype, literals",
    [
        ["float16", [-65505.0, 65505.0]],
        ["float32", [-3.402e39, 3.402e39]],
    ],
)
def test_tir_invalid_floatimm(dtype, literals):
    """Currently only fp16 and fp32 have range check."""
    for out_of_range in literals:
        # A literal beyond the dtype's representable range must be rejected.
        with pytest.raises(tvm.TVMError):
            tir.const(out_of_range, dtype)


@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
@pytest.mark.parametrize("literal", [3.14, np.nan, np.inf])
def test_tir_special_floatimms(dtype, literal):
Expand All @@ -111,23 +140,9 @@ def test_tir_special_floatimms(dtype, literal):


@tvm.testing.requires_llvm()
def test_tir_floatimm_overflow():
# Behavior check: if literal value is out of dtype range, the
def test_tir_too_large_literal_f64():
# Behavior check: if literal f64 value is out of dtype range, the
# object is still constructed, and eval to infinity.
@T.prim_func
def imm_overflow_fp16() -> T.float16:
T.evaluate(T.ret(T.float16(65536), dtype="float16"))

f = tvm.build(imm_overflow_fp16, target="llvm")
assert math.isinf(f())

@T.prim_func
def imm_overflow_fp32() -> T.float32:
T.evaluate(T.ret(T.float32(3.4028e39), dtype="float32"))

f = tvm.build(imm_overflow_fp32, target="llvm")
assert math.isinf(f())

@T.prim_func
def imm_overflow_fp64() -> T.float64:
T.evaluate(T.ret(T.float64(1.7976e309), dtype="float64"))
Expand All @@ -136,6 +151,27 @@ def imm_overflow_fp64() -> T.float64:
assert math.isinf(f())


@pytest.mark.parametrize(
    "literal, expect_dtype",
    [
        (256, "int32"),
        (2147483647, "int32"),
        (-2147483648, "int32"),
        (2147483648, "int64"),
        (-2147483649, "int64"),
        (3.14159, "float32"),
        (np.finfo("float32").min, "float32"),
        (np.finfo("float32").max, "float32"),
        (-3.402e39, "float64"),
        (3.402e39, "float64"),
    ],
)
def test_tir_const_auto_dtype(literal, expect_dtype):
    """With dtype=None, tir.const infers a dtype wide enough for the literal."""
    imm = tir.const(literal, dtype=None)
    assert imm.dtype == expect_dtype
    assert imm.value == literal


@tvm.testing.requires_llvm()
def test_tir_floatimm_const_fold():
# Behavior check: folding fp32 match platform f32 arithmetic
Expand All @@ -149,7 +185,7 @@ def float_imm_multiply(x: T.float32, y: T.float32) -> T.float32:
for x, y in [(3.14e30, 3.14e30), (-3.14e30, 3.14e30)]:
assert float(tir.const(x, "float32") * tir.const(y, "float32")) == fmul(x, y)

seed = random.randrange(sys.maxsize)
seed = random.randint(0, 2147483648)
print(
"\nThis test is intentionally non-deterministic, "
"if it fails please report it in github issue together with this seed {}\n".format(seed)
Expand Down

0 comments on commit bcf0ac2

Please sign in to comment.