Skip to content

Commit

Permalink
[OPENCL] Always use convert_T for type conversion (#14972)
Browse files Browse the repository at this point in the history
This PR changes the Cast in OpenCL to always relying on convert_T to get closer to the spec and more reliable.
  • Loading branch information
tqchen authored Jun 1, 2023
1 parent ca30b13 commit 7f02606
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 67 deletions.
10 changes: 6 additions & 4 deletions src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -370,17 +370,19 @@ std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType

std::string CodeGenOpenCL::CastTo(std::string value, DataType target) {
std::ostringstream os;
if (target.lanes() == 1) {
os << "((";
if (target == DataType::Bool()) {
os << "(";
os << "(";
this->PrintType(target, os);
os << ")" << value << ")";
} else { // convert vector type
return os.str();
} else {
os << "(";
os << "convert_";
this->PrintType(target, os);
os << "(" << value << "))";
return os.str();
}
return os.str();
}

void CodeGenOpenCL::VisitStmt_(const AllocateNode* op) {
Expand Down
80 changes: 36 additions & 44 deletions tests/python/unittest/test_target_codegen_opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ def check_if_then_else(dev, n, dtype):
max_lhs = tvm.tir.const(2, dtype=dtype)
max_rhs = tvm.tir.if_then_else(A[0] > 0, true_value, false_value)
C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C")
s = te.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target)

func = te.create_prim_func([A, C])
sch = tvm.tir.Schedule(func)
(x,) = sch.get_loops(sch.get_block("C"))
sch.bind(x, "threadIdx.x")
fun = tvm.build(sch.mod, target=target)
a = tvm.nd.empty((n,), A.dtype, dev)
c = tvm.nd.empty((n,), A.dtype, dev)
# Only need to test compiling here
Expand All @@ -48,9 +50,11 @@ def check_select(dev, n, dtype):
max_lhs = tvm.tir.const(2, dtype=dtype)
max_rhs = tvm.tir.Select(A[0] > 0, true_value, false_value)
C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C")
s = te.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target)
func = te.create_prim_func([A, C])
sch = tvm.tir.Schedule(func)
(x,) = sch.get_loops(sch.get_block("C"))
sch.bind(x, "threadIdx.x")
fun = tvm.build(sch.mod, target=target)

a = tvm.nd.empty((n,), A.dtype, dev)
c = tvm.nd.empty((n,), A.dtype, dev)
Expand All @@ -76,9 +80,11 @@ def check_inf_nan(dev, n, value, dtype):
A = te.placeholder((n,), name="A", dtype=dtype)
inf_value = tvm.tir.const(value, dtype=dtype)
C = te.compute((n,), lambda i: inf_value, name="C")
s = te.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target)
func = te.create_prim_func([A, C])
sch = tvm.tir.Schedule(func)
(x,) = sch.get_loops(sch.get_block("C"))
sch.bind(x, "threadIdx.x")
fun = tvm.build(sch.mod, target=target)
a = tvm.nd.empty((n,), A.dtype, dev)
c = tvm.nd.empty((n,), A.dtype, dev)
# Only need to test compiling here
Expand All @@ -102,9 +108,11 @@ def check_max(dev, n, dtype):
max_lhs = A[0] + tvm.tir.const(1, dtype=dtype)
max_rhs = tvm.tir.const(0, dtype=dtype)
C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C")
s = te.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target)
func = te.create_prim_func([A, C])
sch = tvm.tir.Schedule(func)
(x,) = sch.get_loops(sch.get_block("C"))
sch.bind(x, "threadIdx.x")
fun = tvm.build(sch.mod, target=target)

a = tvm.nd.empty((n,), A.dtype, dev)
c = tvm.nd.empty((n,), A.dtype, dev)
Expand Down Expand Up @@ -150,50 +158,34 @@ def check_type_casting(ctx, n, dtype):
tvm.tir.all(
*[
i // block_size == tvm.tir.const(3, "int32"),
i % block_size == tvm.tir.const(3, "int32"),
i % 3 == tvm.tir.const(1, "int32"),
]
),
tvm.tir.const(1, dtype),
tvm.tir.const(0, dtype),
),
name="C",
)
s = te.create_schedule(C.op)
(tx, vx) = s[C].split(s[C].op.axis[0], factor=block_size)
s[C].vectorize(vx)
thrx = te.thread_axis("threadIdx.x")

s[C].bind(tx, thrx)
fun = tvm.build(s, [C], target)

# NOTE: test simple convert pattern
func = te.create_prim_func([C])
sch = tvm.tir.Schedule(func)
(x,) = sch.get_loops(sch.get_block("C"))
tx, vx = sch.split(x, factors=[None, block_size])
sch.bind(tx, "threadIdx.x")
sch.vectorize(vx)

fun = tvm.build(sch.mod, target=target)
c = tvm.nd.empty((n,), dtype, ctx)
assembly = fun.imported_modules[0].get_source()

if dtype == "float32":
false_branch = "((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f))"
true_branch = "((float4)(1.000000e+00f, 1.000000e+00f, 1.000000e+00f, 1.000000e+00f))"
lcond = "convert_int4(((convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))"
rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))"
cond = "({} && {})".format(lcond, rcond)
select = "select({}, {}, {})".format(false_branch, true_branch, cond)
count = assembly.count(select)
assert count == 1
fun(c)

elif dtype == "float16":
false_branch = "((half4)((half)0.000000e+00f, (half)0.000000e+00f, (half)0.000000e+00f, (half)0.000000e+00f))"
true_branch = "((half4)((half)1.000000e+00f, (half)1.000000e+00f, (half)1.000000e+00f, (half)1.000000e+00f))"
lcond = "convert_short4(((convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))"
rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))))"
cond = "({} && {})".format(lcond, rcond)
select = "select({}, {}, {})".format(false_branch, true_branch, cond)
count = assembly.count(select)
assert count == 1
fun(c)
lcond = "convert_int4(((convert_uint4(((uint4)(((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3)))))"
rcond = "(convert_uint4(((((int4)(((convert_int(get_local_id(0))))+(1*0), ((convert_int(get_local_id(0))))+(1*1), ((convert_int(get_local_id(0))))+(1*2), ((convert_int(get_local_id(0))))+(1*3))) % ((int4)(3, 3, 3, 3))) == ((int4)(1, 1, 1, 1))))))))"
pattern_cond = "({} && {})".format(lcond, rcond)
assert assembly.count(pattern_cond) != 0
fun(c)

dev = tvm.device(target, 0)

check_type_casting(dev, 16, "float32")
check_type_casting(dev, 32, "float32")
# fp16 is not yet supported in ci
# check_type_casting(dev, 16, "float16")

Expand Down
38 changes: 19 additions & 19 deletions tests/python/unittest/test_target_texture_codegen_opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,25 +1466,25 @@ class TestSimpleTextureToScalarFP16:
["global.texture", (1, 1, 40, 40, 4)],
["", (1, 4, 40, 40)],
[
"float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((int)get_local_id(0)) % 40), (((((int)get_group_id(0)) & 1) * 20) + (((int)get_local_id(0)) / 40)))));",
"out[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = ((half)((float*)&v_)[(((int)get_group_id(0)) >> 1)]);",
"float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)(((convert_int(get_local_id(0))) % 40), ((((convert_int(get_group_id(0))) & 1) * 20) + ((convert_int(get_local_id(0))) / 40)))));",
"out[(((convert_int(get_group_id(0))) * 800) + (convert_int(get_local_id(0))))] = (convert_half(((float*)&v_)[((convert_int(get_group_id(0))) >> 1)]));",
],
),
# 2. Buffer (NCHW4c) -> Cast(FP16) -> Buffer (NCHW)
(
["", (1, 1, 40, 40, 4)],
["", (1, 4, 40, 40)],
[
"out[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = ((half)p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))]);"
"out[(((convert_int(get_group_id(0))) * 800) + (convert_int(get_local_id(0))))] = (convert_half(p0_comp[(((((convert_int(get_group_id(0))) & 1) * 3200) + ((convert_int(get_local_id(0))) * 4)) + ((convert_int(get_group_id(0))) >> 1))]));"
],
),
# 3. Texture (NCHW4c) -> Cast(FP16) -> Texture (NCHW4c)
(
["global.texture", (1, 1, 40, 40, 4)],
["global.texture", (1, 1, 40, 40, 4)],
[
"float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((((int)get_group_id(0)) * 24) + ((int)get_local_id(0))) % 40), (((((int)get_group_id(0)) * 8) + (((int)get_local_id(0)) >> 3)) / 5))));",
"write_imageh(out, (int2)((((((int)get_group_id(0)) * 24) + ((int)get_local_id(0))) % 40), (((((int)get_group_id(0)) * 8) + (((int)get_local_id(0)) >> 3)) / 5)), (convert_half4(v_)));",
"float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)(((((convert_int(get_group_id(0))) * 24) + (convert_int(get_local_id(0)))) % 40), ((((convert_int(get_group_id(0))) * 8) + ((convert_int(get_local_id(0))) >> 3)) / 5))));",
"write_imageh(out, (int2)(((((convert_int(get_group_id(0))) * 24) + (convert_int(get_local_id(0)))) % 40), ((((convert_int(get_group_id(0))) * 8) + ((convert_int(get_local_id(0))) >> 3)) / 5)), (convert_half4(v_)));",
],
),
)
Expand All @@ -1507,16 +1507,16 @@ class TestSimpleTextureToScalarFP32:
["global.texture", (1, 1, 40, 40, 4)],
["", (1, 4, 40, 40)],
[
"float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((int)get_local_id(0)) % 40), (((((int)get_group_id(0)) & 1) * 20) + (((int)get_local_id(0)) / 40)))));",
"out[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = ((float*)&v_)[(((int)get_group_id(0)) >> 1)];",
"float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)(((convert_int(get_local_id(0))) % 40), ((((convert_int(get_group_id(0))) & 1) * 20) + ((convert_int(get_local_id(0))) / 40)))));",
"out[(((convert_int(get_group_id(0))) * 800) + (convert_int(get_local_id(0))))] = ((float*)&v_)[((convert_int(get_group_id(0))) >> 1)];",
],
),
# 2. Buffer (NCHW4c) -> Buffer (NCHW)
(
["", (1, 1, 40, 40, 4)],
["", (1, 4, 40, 40)],
[
"out[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))];"
"out[(((convert_int(get_group_id(0))) * 800) + (convert_int(get_local_id(0))))] = p0_comp[(((((convert_int(get_group_id(0))) & 1) * 3200) + ((convert_int(get_local_id(0))) * 4)) + ((convert_int(get_group_id(0))) >> 1))];"
],
),
)
Expand Down Expand Up @@ -1619,25 +1619,25 @@ class TestTextureToScalarReuseSSAFP16:
["global.texture", (1, 1, 40, 40, 4)],
["", (1, 4, 40, 40)],
[
"float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((int)get_local_id(0)) % 40), (((((int)get_group_id(0)) & 1) * 20) + (((int)get_local_id(0)) / 40)))));",
"out_sum[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = (((half)((float*)&v_)[(((int)get_group_id(0)) >> 1)]) + (((half)((float*)&v_)[(((int)get_group_id(0)) >> 1)]) + ((half)((float*)&v_)[(((int)get_group_id(0)) >> 1)])));",
"float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)(((convert_int(get_local_id(0))) % 40), ((((convert_int(get_group_id(0))) & 1) * 20) + ((convert_int(get_local_id(0))) / 40)))));",
"out_sum[(((convert_int(get_group_id(0))) * 800) + (convert_int(get_local_id(0))))] = ((convert_half(((float*)&v_)[((convert_int(get_group_id(0))) >> 1)])) + ((convert_half(((float*)&v_)[((convert_int(get_group_id(0))) >> 1)])) + (convert_half(((float*)&v_)[((convert_int(get_group_id(0))) >> 1)]))));",
],
),
# 2. Buffer (NCHW4c) -> Cast(FP16) -> Buffer (NCHW)
(
["", (1, 1, 40, 40, 4)],
["", (1, 4, 40, 40)],
[
"out_sum[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = (((half)p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))]) + (((half)p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))]) + ((half)p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))])));"
" out_sum[(((convert_int(get_group_id(0))) * 800) + (convert_int(get_local_id(0))))] = ((convert_half(p0_comp[(((((convert_int(get_group_id(0))) & 1) * 3200) + ((convert_int(get_local_id(0))) * 4)) + ((convert_int(get_group_id(0))) >> 1))])) + ((convert_half(p0_comp[(((((convert_int(get_group_id(0))) & 1) * 3200) + ((convert_int(get_local_id(0))) * 4)) + ((convert_int(get_group_id(0))) >> 1))])) + (convert_half(p0_comp[(((((convert_int(get_group_id(0))) & 1) * 3200) + ((convert_int(get_local_id(0))) * 4)) + ((convert_int(get_group_id(0))) >> 1))]))));"
],
),
# 3. Texture (NCHW4c) -> Cast(FP16) -> Texture (NCHW4c)
(
["global.texture", (1, 1, 40, 40, 4)],
["global.texture", (1, 1, 40, 40, 4)],
[
"float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((((int)get_group_id(0)) * 24) + ((int)get_local_id(0))) % 40), (((((int)get_group_id(0)) * 8) + (((int)get_local_id(0)) >> 3)) / 5))));",
"write_imageh(out_sum, (int2)((((((int)get_group_id(0)) * 24) + ((int)get_local_id(0))) % 40), (((((int)get_group_id(0)) * 8) + (((int)get_local_id(0)) >> 3)) / 5)), ((convert_half4(v_)) + ((convert_half4(v_)) + (convert_half4(v_)))));",
"float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)(((((convert_int(get_group_id(0))) * 24) + (convert_int(get_local_id(0)))) % 40), ((((convert_int(get_group_id(0))) * 8) + ((convert_int(get_local_id(0))) >> 3)) / 5))));",
"write_imageh(out_sum, (int2)(((((convert_int(get_group_id(0))) * 24) + (convert_int(get_local_id(0)))) % 40), ((((convert_int(get_group_id(0))) * 8) + ((convert_int(get_local_id(0))) >> 3)) / 5)), ((convert_half4(v_)) + ((convert_half4(v_)) + (convert_half4(v_)))));",
],
),
)
Expand All @@ -1660,16 +1660,16 @@ class TestTextureToScalarReuseSSAFP32:
["global.texture", (1, 1, 40, 40, 4)],
["", (1, 4, 40, 40)],
[
"float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((int)get_local_id(0)) % 40), (((((int)get_group_id(0)) & 1) * 20) + (((int)get_local_id(0)) / 40)))));",
"out_sum[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = (((float*)&v_)[(((int)get_group_id(0)) >> 1)] + (((float*)&v_)[(((int)get_group_id(0)) >> 1)] + ((float*)&v_)[(((int)get_group_id(0)) >> 1)]));",
"float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)(((convert_int(get_local_id(0))) % 40), ((((convert_int(get_group_id(0))) & 1) * 20) + ((convert_int(get_local_id(0))) / 40)))));",
"out_sum[(((convert_int(get_group_id(0))) * 800) + (convert_int(get_local_id(0))))] = (((float*)&v_)[((convert_int(get_group_id(0))) >> 1)] + (((float*)&v_)[((convert_int(get_group_id(0))) >> 1)] + ((float*)&v_)[((convert_int(get_group_id(0))) >> 1)]));",
],
),
# 2. Buffer (NCHW4c) -> Buffer (NCHW)
(
["", (1, 1, 40, 40, 4)],
["", (1, 4, 40, 40)],
[
"out_sum[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = (p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))] + (p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))] + p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))]));"
"out_sum[(((convert_int(get_group_id(0))) * 800) + (convert_int(get_local_id(0))))] = (p0_comp[(((((convert_int(get_group_id(0))) & 1) * 3200) + ((convert_int(get_local_id(0))) * 4)) + ((convert_int(get_group_id(0))) >> 1))] + (p0_comp[(((((convert_int(get_group_id(0))) & 1) * 3200) + ((convert_int(get_local_id(0))) * 4)) + ((convert_int(get_group_id(0))) >> 1))] + p0_comp[(((((convert_int(get_group_id(0))) & 1) * 3200) + ((convert_int(get_local_id(0))) * 4)) + ((convert_int(get_group_id(0))) >> 1))]));"
],
),
)
Expand All @@ -1693,10 +1693,10 @@ class TestLocalArrayToTexture:
(1, 2, 38, 38, 4),
[
"float out_local[4];",
"float4 v_ = READ_IMAGEF(p1_comp, image_sampler, ((int2)((((((int)get_group_id(0)) * 14) + ((int)get_local_id(0))) % 38), ((((((int)get_group_id(0)) * 64) + (((int)get_local_id(0)) >> 1)) % 722) / 19))));",
"float4 v__1 = READ_IMAGEF(p2_comp, image_sampler, ((int2)(rw, ((((((((int)get_group_id(0)) * 32) + (((int)get_local_id(0)) >> 2)) / 361) * 12) + (rcb * 3)) + rh))));",
"float4 v_ = READ_IMAGEF(p1_comp, image_sampler, ((int2)(((((convert_int(get_group_id(0))) * 14) + (convert_int(get_local_id(0)))) % 38), (((((convert_int(get_group_id(0))) * 64) + ((convert_int(get_local_id(0))) >> 1)) % 722) / 19))));",
"float4 v__1 = READ_IMAGEF(p2_comp, image_sampler, ((int2)(rw, (((((((convert_int(get_group_id(0))) * 32) + ((convert_int(get_local_id(0))) >> 2)) / 361) * 12) + (rcb * 3)) + rh))));",
"out_local[cb_c] = (out_local[cb_c] + (((float*)&v_)[rcb] * ((float*)&v__1)[cb_c]));",
"write_imagef(out, (int2)((((((int)get_group_id(0)) * 14) + ((int)get_local_id(0))) % 38), (((((int)get_group_id(0)) * 64) + (((int)get_local_id(0)) >> 1)) / 19)), vload4(0, out_local + 0));",
"write_imagef(out, (int2)(((((convert_int(get_group_id(0))) * 14) + (convert_int(get_local_id(0)))) % 38), ((((convert_int(get_group_id(0))) * 64) + ((convert_int(get_local_id(0))) >> 1)) / 19)), vload4(0, out_local + 0));",
],
),
)
Expand Down

0 comments on commit 7f02606

Please sign in to comment.