diff --git a/src/target/codegen_hip.cc b/src/target/codegen_hip.cc index c36f5bdc1..666ffa4fb 100644 --- a/src/target/codegen_hip.cc +++ b/src/target/codegen_hip.cc @@ -480,7 +480,7 @@ void CodeGenTileLangHIP::PrintVecElemLoad(const std::string &vec, DataType t, os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } else if (t.is_bfloat16()) { - os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" + os << "((bfloat16x2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } else if (t.lanes() > 4 && t.lanes() <= 8) { std::string type_name; diff --git a/src/tl_templates/hip/common.h b/src/tl_templates/hip/common.h index 4449bac57..25b30cc1b 100644 --- a/src/tl_templates/hip/common.h +++ b/src/tl_templates/hip/common.h @@ -67,7 +67,7 @@ using half_t = float16_t; using bfloat16_t = hip_bfloat16; struct bfloat16x2 { - bfloat16_t data[2]; + bfloat16_t x, y; }; struct bfloat16x4 { diff --git a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py index 556642bb2..b8690ce08 100644 --- a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py +++ b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py @@ -56,6 +56,7 @@ def tl_matmul( A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K) B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N) C_shared_shape = ( + block_M // micro_size_x, block_N // micro_size_y, micro_size_x, micro_size_y,