From 3d01da3d5b561dbfcd93d753e6e60dc68905e8c3 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Fri, 26 May 2023 12:32:31 +0800 Subject: [PATCH] [METAL] Fix int8 vectorized cast Current codegen output `(half4)*(device uint*)A` tries to create a `int32` number and then cast it to `half4`, which is not the expected behavior. As Metal supports `uchar4` and `char4` types, we can direct use them to solve that problem. --- src/target/source/codegen_metal.cc | 5 ---- .../unittest/test_target_codegen_metal.py | 30 +++++++++++++++---- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index bd2b93016686..b7105e4bcdfc 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -220,11 +220,6 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) if (t.is_uint()) { os << 'u'; } - if (t.bits() == 8 && t.lanes() == 4) { - // directly 4 8 bit int in integer. - os << "int"; - return; - } switch (t.bits()) { case 8: os << "char"; diff --git a/tests/python/unittest/test_target_codegen_metal.py b/tests/python/unittest/test_target_codegen_metal.py index 3b1cdb4422c5..dcbbba8c9c9f 100644 --- a/tests/python/unittest/test_target_codegen_metal.py +++ b/tests/python/unittest/test_target_codegen_metal.py @@ -14,12 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm import te import numpy as np -import tvm.testing +import tvm import tvm.script +import tvm.testing +from tvm import te from tvm.script import tir as T @@ -149,7 +149,25 @@ def main(A: T.Buffer((6), "float32"), B: T.Buffer((6,), "float32")): np.testing.assert_allclose(b_nd.numpy(), a, atol=1e-5, rtol=1e-5) +@tvm.testing.requires_gpu +@tvm.testing.requires_metal +def test_vectorized_uint8(): + @T.prim_func + def func(A: T.Buffer((16), "uint8"), B: T.Buffer((16), "float32")): + for i in T.thread_binding(4, thread="threadIdx.x"): + for j in T.vectorized(4): + with T.block("block"): + vi = T.axis.spatial(16, i * 4 + j) + B[vi] = T.Cast("float32", A[vi]) + + dev = tvm.metal() + a = np.arange(16).astype("uint8") + a_nd = tvm.nd.array(a, dev) + b_nd = tvm.nd.empty((16,), "float32", dev) + f = tvm.build(func, target="metal") + f(a_nd, b_nd) + np.testing.assert_allclose(b_nd.numpy(), a.astype("float32"), atol=1e-5, rtol=1e-5) + + if __name__ == "__main__": - test_ramp() - test_metal_inf_nan() - test_metal_erf() + tvm.testing.main()