Skip to content

Commit

Permalink
[METAL] Fix int8 vectorized cast (#14962)
Browse files Browse the repository at this point in the history
The current codegen output `(half4)*(device uint*)A` tries to create an `int32`
number and then cast it to `half4`, which is not the expected behavior.

Since Metal supports the `uchar4` and `char4` types, we can use them directly to
solve that problem.
  • Loading branch information
Hzfengsy authored May 26, 2023
1 parent 1aeb34a commit 6198c7f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
5 changes: 0 additions & 5 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,6 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
if (t.is_uint()) {
os << 'u';
}
if (t.bits() == 8 && t.lanes() == 4) {
// directly 4 8 bit int in integer.
os << "int";
return;
}
switch (t.bits()) {
case 8:
os << "char";
Expand Down
30 changes: 24 additions & 6 deletions tests/python/unittest/test_target_codegen_metal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import numpy as np

import tvm.testing
import tvm
import tvm.script
import tvm.testing
from tvm import te
from tvm.script import tir as T


Expand Down Expand Up @@ -149,7 +149,25 @@ def main(A: T.Buffer((6), "float32"), B: T.Buffer((6,), "float32")):
np.testing.assert_allclose(b_nd.numpy(), a, atol=1e-5, rtol=1e-5)


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_vectorized_uint8():
    """Check a 4-lane vectorized ``uint8`` -> ``float32`` cast on Metal.

    The kernel binds 4 threads to ``threadIdx.x``; each thread handles 4
    consecutive elements through a vectorized inner loop, casting every
    ``uint8`` input element to ``float32``. The device result is compared
    against the same cast performed by NumPy on the host.
    """

    @T.prim_func
    def func(A: T.Buffer((16), "uint8"), B: T.Buffer((16), "float32")):
        for i in T.thread_binding(4, thread="threadIdx.x"):
            for j in T.vectorized(4):
                with T.block("block"):
                    vi = T.axis.spatial(16, i * 4 + j)
                    B[vi] = T.Cast("float32", A[vi])

    metal_dev = tvm.metal()

    # Host-side input 0..15 and the reference result of the cast.
    host_input = np.arange(16).astype("uint8")
    expected = host_input.astype("float32")

    # Device buffers: input copied from the host, output left uninitialized.
    device_input = tvm.nd.array(host_input, metal_dev)
    device_output = tvm.nd.empty((16,), "float32", metal_dev)

    kernel = tvm.build(func, target="metal")
    kernel(device_input, device_output)

    np.testing.assert_allclose(device_output.numpy(), expected, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
    # Delegate to the shared TVM test entry point instead of calling a
    # hand-maintained subset of tests: the manual list ran only three tests
    # and, combined with tvm.testing.main(), would execute them twice.
    tvm.testing.main()

0 comments on commit 6198c7f

Please sign in to comment.