Commit e5135a9

Authored by Alex4210987, xinxyxiao, and LeiWang1999
fix bug&add amd examples (tile-ai#966)
* [Enhancement] Refactor buffer index handling for improved precision and clarity (tile-ai#668)
  - Enhanced buffer index handling to address precision issues by removing redundant operations.
  - Streamlined the logic for determining buffer overlaps, ensuring more accurate conflict detection.
  - Updated related documentation to reflect changes in buffer management practices.
* Remove obsolete test script for AMD example, streamlining the examples directory.
* Remove unused dtype_size variable in AMD example script to streamline code.
* Add input configuration file and update AMD example script for enhanced flexibility
  - Introduced a new input.txt file for configurable parameters.
  - Modified the example_amd_flash_attn_fwd.py script to allow for a wider range of configurations, including additional options for num_stages, enable_rasterization, and k_pack.
  - Streamlined the main function for better clarity and organization.
  - Added a new test script to facilitate running the example with specified parameters.
* Remove input configuration file and obsolete test script; enhance AMD example with swizzle layout annotations
  - Deleted input.txt and test.sh files as they are no longer needed.
  - Updated example_amd_flash_attn_fwd.py to include swizzle layout annotations for shared memory, improving bank conflict avoidance.
  - Reintroduced swizzle usage in the kernel for better performance.
* Refactor AMD example script for FlashAttention-2
  - Updated function names for clarity, changing `get_v2_configs` to `get_configs` and `fast_flashattn_v2` to `fast_flashattn`.
  - Streamlined the main function by renaming `main_v2` to `main` and adjusting the corresponding calls.
  - Removed outdated comments and improved code organization for better readability.
* Refactor formatting in AMD FlashAttention example script
  - Improved code readability by adjusting line breaks and indentation in the `fast_flashattn` function.
  - Streamlined the `main` function parameter formatting for consistency.
  - Removed unnecessary blank lines to enhance overall code organization.
* Update example_amd_flash_attn_fwd.py
* Enhance AMD example script and update CI workflows
  - Improved the `example_amd_flash_attn_fwd.py` script for better clarity and organization.
  - Added new CI workflows for AMD and documentation publishing.
  - Updated various requirements files to include necessary dependencies.
  - Introduced new test cases and examples for better coverage and functionality.
  - Refactored existing code for improved readability and maintainability.
* Remove redundant tool cache cleanup step in AMD CI workflow
* Remove `torch` dependency from `requirements-rocm.txt` to streamline requirements.
* Add new AMD FlashAttention example and test script
  - Introduced `example_amd_flash_attn_bwd.py` for backward attention computation using TileLang.
  - Added `test.sh` script to facilitate running the new example with specified parameters.
  - Enhanced the overall structure and organization of the example for better clarity and usability.
* Update configurations in `example_amd_flash_attn_fwd.py` for autotuner
  - Reduced the number of threads and `num_split_q` options for improved performance.
  - Adjusted `panel_size` options to streamline configuration settings.
* Update submodule 'tvm' to commit 6ccc74f622c7ec4ac25d430d0f6546e7b9edb217
* Update submodule 'tvm' to commit 14ff70ab142b9e5a31bbf9c7923c8a697d41e86c
* Add example for AMD Flash Attention backward pass implementation
  - Introduced a new example script `example_amd_flash_attn_bwd.py` demonstrating the forward and backward operations of Flash Attention using TileLang.
  - Implemented JIT-compiled functions for both forward and backward passes, including preprocessing and postprocessing steps.
  - Added a main function to facilitate testing and benchmarking of the attention mechanism with configurable parameters.
  - Included a reference implementation for validation against PyTorch's attention mechanism.
* Enhance AMD Flash Attention example with additional testing capabilities
  - Updated `example_amd_flash_attn_bwd.py` to include more comprehensive testing features for the Flash Attention implementation.
  - Improved the main function to allow for better parameter configuration and benchmarking.
  - Added validation checks against PyTorch's attention mechanism to ensure accuracy and reliability of the example.
* Update submodule TVM to commit a64a5926a6e59f5417ef2501f9d88b467337cf6a
* Refactor HIP intrinsic rules to CUDA
  - Updated file name from `intrin_rule_hip.cc` to `intrin_rule_cuda.cc` to reflect the change in focus from HIP to CUDA intrinsic rules.
  - Adjusted include paths for better organization and clarity in the code structure.
* Update AMD CI workflow to uninstall specific PyTorch packages before installation
  - Removed the installation of `flash_attn==2.5.8` to streamline the CI process.
  - Added a step to uninstall `torch`, `torchvision`, and `torchaudio` prior to installing pre-release versions, ensuring compatibility and reducing potential conflicts.
* Remove unused shared memory allocations in AMD Flash Attention backward example
  - Eliminated the allocation of shared memory for `dv_shared` and `dk_shared` in `example_amd_flash_attn_bwd.py` to streamline memory usage and improve performance.
* Remove unnecessary pip uninstall command from AMD CI workflow
  - Eliminated the step to uninstall `torch`, `torchvision`, and `torchaudio` in the AMD CI workflow, as it is no longer required for the installation of pre-release versions.
* Refactor DispatchHIPWarpActiveMask function in HIP intrinsic rules
  - Updated the return statement to use std::string for concatenation in the case of 16-bit types, improving code clarity.
  - Added a null check for the CallNode pointer in DispatchHIPWarpActiveMask to enhance robustness and prevent potential dereferencing issues.
* Refactor formatting of HIP intrinsic rule registrations
  - Adjusted the formatting of TVM_REGISTER_OP calls for better readability by aligning method chaining. No functional changes.
* Update file name and documentation for HIP intrinsic rules
  - Renamed the file from `intrin_rule_cuda.cc` back to `intrin_rule_hip.cc` to accurately reflect the focus on HIP intrinsic rules.
  - Updated the file documentation to clarify its purpose as related to HIP rather than CUDA.
* Enhance DispatchHIPShuffle function with clang-analyzer comments
  - Added NOLINTBEGIN and NOLINTEND comments to the DispatchHIPShuffle function to suppress clang-analyzer warnings related to inner pointer usage.
* lint fix
* fix
* Enhance autotuner configurations in example_amd_flash_attn_fwd.py by adding new block sizes, stages, and panel sizes. Update test script to use a relative Python path and adjust parameters for consistency.
* Add backward attention example to test script
  - Extended the test.sh script to include a new backward attention example using example_amd_flash_attn_bwd.py.
  - Added parameters for batch size, context length, and head dimensions to ensure consistency with the forward example.
  - Updated the command for the backward tile example to match the new configuration.
* Refactor FlashAttention implementation in example_amd_flash_attn_bwd.py and example_amd_flash_attn_fwd.py
  - Introduced new functions for forward and backward configurations to enhance autotuning capabilities.
  - Updated the FlashAttention forward and backward functions to improve performance and maintainability.
  - Adjusted test script parameters for consistency and clarity, including the addition of group handling.
  - Refined autotuner block sizes and stages for better performance tuning.
  - Updated the main function to reflect changes in parameter names and types.
* Enhance FlashAttention backward implementation in example_amd_flash_attn_bwd.py
  - Updated the backward function to return additional outputs, including log-sum-exp (LSE) values for improved gradient calculations.
  - Refined autotuner configurations by adding new block sizes and adjusting parameters.
  - Improved shared memory usage in the backward pass to optimize memory access patterns.
  - Enhanced correctness checks in the main function to include LSE validation alongside gradient checks.
* Enhance FlashAttention backward implementation in example_amd_flash_attn_bwd.py
  - Introduced a scaling factor for improved numerical stability in gradient calculations.
  - Optimized shared memory usage by adding new shared buffers for intermediate calculations.
  - Refined the handling of tensor fragments to improve performance and maintainability.
  - Updated the main function for compatibility with the new backward outputs; removed unnecessary parameters from the test script.
* Refactor FlashAttention implementation in example_amd_flash_attn_bwd.py and example_mha_bwd.py
  - Updated the forward and backward functions to improve numerical stability and performance.
  - Enhanced shared memory usage by optimizing buffer allocations and reducing unnecessary parameters.
  - Adjusted autotuner configurations for compatibility with the new output parameters.
  - Added debugging and benchmarking functions for correctness verification and performance analysis.
* Enhance FlashAttention backward implementation in example_amd_flash_attn_bwd.py
  - Updated scaling factor application for improved numerical stability in gradient calculations.
  - Refined tensor handling to ensure consistency with forward pass operations.
  - Optimized atomic operations for writing gradients to dK and dV using fp32 for better precision.
* Expand autotuner configurations in example_amd_flash_attn_bwd.py and update test.sh
  - Increased the range of block sizes and stages for forward and backward configurations.
  - Adjusted the test script to include additional parameters for batch size and head dimensions, consistent with the forward example.
* Enhance performance calculations and benchmarking in example_amd_flash_attn_bwd.py
  - Updated the FLOPs calculation to account for both forward and backward passes, clarifying the total computational cost.
  - Modified benchmarking functions to evaluate the complete forward and backward performance of both the reference and Tile-lang implementations.
  - Removed an unnecessary parameter from test.sh to streamline execution.
* Remove forward attention test commands from test.sh and retain backward attention execution for streamlined testing.
* Refactor FlashAttention forward and backward implementations in example_amd_flash_attn_bwd.py and example_amd_flash_attn_fwd.py
  - Updated the forward function to return both output and log-sum-exp (LSE) values for improved gradient calculations.
  - Enhanced autotuner configurations for the forward pass, including new parameters.
  - Refined scaling factor calculations for numerical stability in both forward and backward passes.
  - Adjusted the main function for compatibility with the new output requirements.
* Refactor FlashAttention implementation in example_amd_flash_attn_bwd.py
  - Removed outdated comments and improved clarity in the code.
  - Enhanced the forward function to consistently return output and log-sum-exp (LSE) values.
  - Refined tensor handling and scaling factor calculations for improved numerical stability.
* Enhance FlashAttention backward implementation in example_amd_flash_attn_bwd.py
  - Updated configuration parameters for backward calculations, including new options for block sizes, threads, and rasterization.
  - Added new parameters (k_pack, qk_coalesced_width, v_coalesced_width) to improve performance tuning and memory access patterns.
  - Modified tensor copy operations to utilize coalesced widths for optimized memory loads.
  - Enhanced GEMM operations with k_pack for improved computational efficiency.
* Refactor configuration and tensor operations in example_amd_flash_attn_bwd.py
  - Updated backward configuration parameters to include larger block sizes and a wider range of threads.
  - Removed parameters (k_pack, qk_coalesced_width, v_coalesced_width) from function signatures and tensor operations to simplify the implementation.
  - Optimized tensor copy operations by eliminating coalesced width specifications.
* Enhance HIP code generation and FP8 type support
  - Added support for additional FP8 types (e4m3, e4m3b11fnuz, e5m2fnuz, e8m0) in codegen_hip.cc to improve compatibility.
  - Updated error logging to include unsupported FP8 type details for better debugging.
  - Implemented handling for loop break and no-op register management in HIP within the VisitExpr_ method.
  - Introduced new FP8 vector types (e5 and e8) in hip_fp8.h.
  - Added overloads for AtomicAdd in common.h to support both pointer and value arguments.
* Enhance FP8 type support and clarify accumulator handling in HIP
  - Expanded FP8 type support in codegen_hip.cc to include additional float8 formats.
  - Updated gemm.h to clarify the handling of the accumulator when clear_accum is true.
  - Added comments in hip_fp8.h to indicate that E8M0 types are not supported in the current HIP version.
* Remove deprecated files and update print statements for clarity in example_amd_flash_attn_bwd.py
* Update print statement formatting for clarity in example_amd_flash_attn_bwd.py
* Remove redundant verification results summary print statement in example_amd_flash_attn_bwd.py for cleaner output.
* Fix formatting inconsistencies in example_amd_flash_attn_bwd.py and example_amd_flash_attn_fwd.py by adding spaces for improved readability in configuration parameters and print statements.
* Refactor and enhance HIP code generation for improved FP8 support
  - Reorganized and cleaned up code in codegen_hip.cc for better readability and maintainability.
  - Enhanced handling of FP8 types, including additional formats and improved error logging for unsupported types.
  - Updated the AtomicAdd function in common.h to streamline its implementation.
  - Refined the PrintVecElemLoadExpr method to handle volatile loads more effectively.
* Fix formatting issue in HIP code generation for the MFMA call
  - Adjusted the indentation of the MFMA call code block in codegen_hip.cc for readability and consistency.
* Refactor HIP code generation and enhance FP8 type handling
  - Reintroduced necessary includes and reorganized code in codegen_hip.cc.
  - Enhanced the GetFP8Type function to support additional FP8 formats with improved error handling for unsupported types.
  - Updated the PrintType and PrintVecElemLoadExpr methods to better manage type conversions and vector element loading.
  - Refined the AddFunction method to streamline function addition in the code generation process.
* Remove unnecessary blank line in example_amd_flash_attn_bwd.py for improved code cleanliness.
* Refactor backward attention implementation in example_amd_flash_attn_bwd.py
  - Updated the GEMM operation to use shared memory for improved performance.
  - Adjusted parallelization parameters to enhance efficiency in the backward pass.
* Fix formatting by removing an unnecessary blank line in example_amd_flash_attn_bwd.py.
* Add additional test cases for `assert_tl_matmul_correctness` with `float8_e4m3fnuz` and various configurations
* Refactor test case formatting for `assert_tl_matmul_correctness` in `test_tilelang_gemm_mfma_intrinsic.py`

---------

Co-authored-by: xinxyxiao <xinyxiao@amd.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent: b90e784 · commit: e5135a9

File tree

9 files changed: +625 −325 lines

examples/amd/example_amd_flash_attn_bwd.py

Lines changed: 525 additions & 285 deletions
Large diffs are not rendered by default.

examples/amd/example_amd_flash_attn_fwd.py

Lines changed: 4 additions & 16 deletions
```diff
@@ -34,7 +34,7 @@ def get_configs():
     block_N = [32, 64, 128, 256]
     threads = [128, 256, 512]
     num_split_q = [64, 128, 256]
-    num_stages = [0]
+    num_stages = [0, 1]
     enable_rasterization = [True]
     k_pack = [2]
     panel_size = [7, 8]
@@ -60,18 +60,6 @@ def get_configs():
             "qk_coalesced_width": qkw,
             "v_coalesced_width": vw,
         })
-    valid_configs.append({
-        'block_M': 64,
-        'block_N': 64,
-        'num_split_q': 64,
-        'threads': 256,
-        'num_stages': 1,
-        'enable_rasterization': True,
-        'k_pack': 2,
-        'panel_size': 64,
-        'qk_coalesced_width': 8,
-        'v_coalesced_width': 8,
-    })
     return valid_configs


@@ -95,7 +83,7 @@ def fast_flashattn(
     qk_coalesced_width: int,
     v_coalesced_width: int,
 ):
-    scale = (1.0 / dim)**0.5 * 1.44269504
+    scale = (1.0 / dim)**0.5
     head_kv = heads // groups
     q_shape = [batch, seq_len, heads, dim]
     kv_shape = [batch, seq_len, head_kv, dim]
@@ -185,15 +173,15 @@ def main(
             T.reduce_max(acc_s, m_i, dim=1, clear=False)

             for i in T.Parallel(block_M):
-                sf = T.exp2(m_prev[i] * scale - m_i[i] * scale)
+                sf = T.exp(m_prev[i] * scale - m_i[i] * scale)
                 l_i[i] *= sf
                 scale_factor[i] = sf

             for i, j in T.Parallel(block_M, dim):
                 acc_o[i, j] *= scale_factor[i]

             for i, j in T.Parallel(block_M, block_N):
-                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - m_i[i] * scale)
+                acc_s[i, j] = T.exp(acc_s[i, j] * scale - m_i[i] * scale)

             T.reduce_sum(acc_s, row_sum, dim=1)
             for i in T.Parallel(block_M):
```
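The forward-pass change above swaps `T.exp2` for `T.exp` and drops the `1.44269504` (= log2 e) factor from `scale`; the two formulations are mathematically equivalent, since exp2(x · log2 e) = exp(x). A small NumPy sketch of the online-softmax rescaling term (the sample values and `dim` are arbitrary, for illustration only):

```python
import numpy as np

LOG2_E = 1.44269504  # log2(e), the constant removed from `scale` in the diff

def rescale_exp(x, scale):
    # natural-exp form used after the change
    return np.exp(x * scale)

def rescale_exp2(x, scale):
    # base-2 form used before the change: log2(e) was folded into `scale`
    return np.exp2(x * (scale * LOG2_E))

x = np.linspace(-4.0, 4.0, 9)       # stand-in for m_prev[i] - m_i[i] values
scale = (1.0 / 128) ** 0.5          # head-dim scaling with an assumed dim=128
assert np.allclose(rescale_exp(x, scale), rescale_exp2(x, scale))
```

Hardware often evaluates `exp2` faster than `exp`, which is why FlashAttention kernels frequently fold log2(e) into the scale; this commit trades that for the plainer natural-exp form.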

examples/amd/test.sh

Lines changed: 0 additions & 10 deletions
This file was deleted.

examples/flash_attention/example_mha_bwd.py

Lines changed: 0 additions & 7 deletions
```diff
@@ -38,14 +38,10 @@ def flash_fwd(
         scores_sum = T.alloc_fragment([block_M], accum_dtype)
         logsum = T.alloc_fragment([block_M], accum_dtype)

-        T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
         T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
         T.fill(acc_o, 0)
         T.fill(logsum, 0)
         T.fill(scores_max, -T.infinity(accum_dtype))
-        # T.copy(Q_shared, Q_local)
-        # for i, j in T.Parallel(block_M, dim):
-        #     Q_local[i, j] *= scale
         loop_range = (
             T.ceildiv(
                 (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
@@ -192,9 +188,6 @@ def flash_bwd(

         T.annotate_layout({
             dQ: make_dq_layout(dQ),
-            K_shared: tilelang.layout.make_swizzled_layout(K_shared),
-            dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
-            dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
         })
         T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
         T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
```
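The deleted `make_swizzled_layout` annotations exist to avoid shared-memory bank conflicts. As a rough illustration of the idea (not TileLang's actual layout function), one common scheme XORs low row bits into the column index so that threads walking a column hit distinct banks:

```python
def xor_swizzle(row: int, col: int, width: int = 8) -> int:
    # Permute the columns within each row by XORing with the row index.
    # Illustrative only: the layout tilelang.layout.make_swizzled_layout
    # emits may differ in width and bit placement.
    return col ^ (row % width)

# XOR with a constant is a bijection, so each row is still a full
# permutation of the columns (no data is lost):
for row in range(8):
    cols = [xor_swizzle(row, c) for c in range(8)]
    assert sorted(cols) == list(range(8))

# A fixed logical column maps to 8 distinct physical columns across
# 8 rows, spreading column-wise accesses over distinct banks:
assert len({xor_swizzle(r, 0) for r in range(8)}) == 8
```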

src/target/codegen_hip.cc

Lines changed: 22 additions & 6 deletions
```diff
@@ -41,10 +41,18 @@ static std::string GetFP8Type(DataType type) {
     stream << "fp8_e4" << vec << "_t";
   } else if (type.code() == DataType::kFloat8_e4m3fnuz) {
     stream << "fp8_e4" << vec << "_t";
+  } else if (type.code() == DataType::kFloat8_e4m3) {
+    stream << "fp8_e4" << vec << "_t";
+  } else if (type.code() == DataType::kFloat8_e4m3b11fnuz) {
+    stream << "fp8_e4" << vec << "_t";
   } else if (type.code() == DataType::kFloat8_e5m2) {
     stream << "fp8_e5" << vec << "_t";
+  } else if (type.code() == DataType::kFloat8_e5m2fnuz) {
+    stream << "fp8_e5" << vec << "_t";
+  } else if (type.code() == DataType::kFloat8_e8m0fnu) {
+    stream << "fp8_e8" << vec << "_t";
   } else {
-    LOG(FATAL) << "Unsupported FP8 type in HIP codegen";
+    LOG(FATAL) << "Unsupported FP8 type in HIP codegen: " << type;
   }
   return stream.str();
 }
@@ -926,10 +934,10 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
       {"float8_e4m3fnuzx8", "long"},
       {"float32x16", "float32x16"}};
   std::string call_mfma_code = R"({
-    *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
-                  *((({B_dtype}*){b_ref}) + {b_bias}),
-                  *((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0);
-  })";
+  *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
+                *((({B_dtype}*){b_ref}) + {b_bias}),
+                *((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0);
+})";
   std::string mfma_buildin = "__builtin_amdgcn_mfma_" + prefix;
   Replacer replacer;

@@ -955,6 +963,13 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
                        op->args, true, os);
   } else if (op->op.same_as(tl::tl_gemm_sp())) {
     LOG(FATAL) << "tl_gemm_sp is not supported on HIP";
+  } else if (op->op.same_as(tl::loop_break())) {
+    this->PrintIndent();
+    this->stream << "break;\n";
+  } else if (op->op.same_as(tl::no_set_max_nreg())) {
+    // HIP doesn't need explicit register management like CUDA
+    // This is a no-op for HIP
+    return;
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
@@ -1160,7 +1175,8 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
     os << "bfloat16_t";
     os << '(' << std::scientific << op->value << 'f' << ')';
     return;
-  } else if (op->dtype.is_float8_e4m3fnuz()) {
+  } else if (op->dtype.is_float8_e4m3fnuz() || op->dtype.is_float8_e4m3() ||
+             op->dtype.is_float8_e4m3fn()) {
     os << "fp8_e4_t";
     os << '(' << std::scientific << op->value << 'f' << ')';
     return;
```
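The extended `GetFP8Type` branches collapse several TVM float8 dtype codes onto three HIP type families. The dispatch can be sketched as a plain lookup table (a Python paraphrase of the C++ above; the `_<lanes>` vector-suffix convention is assumed from names like `fp8_e4_2_t`):

```python
# Mapping implied by the extended GetFP8Type branches: each TVM float8
# dtype code selects one of three HIP fp8 type families (e4 / e5 / e8).
FP8_FAMILY = {
    "float8_e4m3fn": "fp8_e4",
    "float8_e4m3fnuz": "fp8_e4",
    "float8_e4m3": "fp8_e4",
    "float8_e4m3b11fnuz": "fp8_e4",
    "float8_e5m2": "fp8_e5",
    "float8_e5m2fnuz": "fp8_e5",
    "float8_e8m0fnu": "fp8_e8",
}

def hip_fp8_type(dtype: str, lanes: int = 1) -> str:
    # Mirrors the improved error path, which now names the offending type.
    if dtype not in FP8_FAMILY:
        raise ValueError(f"Unsupported FP8 type in HIP codegen: {dtype}")
    vec = "" if lanes == 1 else f"_{lanes}"
    return f"{FP8_FAMILY[dtype]}{vec}_t"

assert hip_fp8_type("float8_e4m3fnuz") == "fp8_e4_t"
assert hip_fp8_type("float8_e5m2", 4) == "fp8_e5_4_t"
```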

src/tl_templates/hip/common.h

Lines changed: 10 additions & 0 deletions
```diff
@@ -109,3 +109,13 @@ template <typename T1, typename T2>
 TL_DEVICE void AtomicAdd(T1 *address, T2 val) {
   atomicAdd(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
 }
+
+// Overload for when the first argument is a value instead of a pointer
+template <typename T1, typename T2>
+TL_DEVICE void AtomicAdd(T1 address, T2 val) {
+  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
+}
+
+template <typename T1, typename T2> TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val) {
+  return atomicAdd(&ref, static_cast<T1>(val));
+}
```

src/tl_templates/hip/gemm.h

Lines changed: 3 additions & 1 deletion
```diff
@@ -70,7 +70,9 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool TransposeA,
           typename B_type, typename C_type, typename AccDataType = float>
 class GemmTensorOp {
 public:
-  static_assert(!clear_accum, "clear_accum=true is not supported yet");
+  // Note: clear_accum=true is not fully supported in HIP implementation
+  // but we'll handle it by manually clearing the accumulator
+  // static_assert(!clear_accum, "clear_accum=true is not supported yet");

   static constexpr int micro_size_x = 16;
   static constexpr int micro_size_y = 16;
```
src/tl_templates/hip/hip_fp8.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
using fp8_e4_t = __hip_fp8_e4m3_fnuz;
66
using fp8_e4_2_t = __hip_fp8x2_e4m3_fnuz;
77

8+
// Additional FP8 types for compatibility
9+
using fp8_e5_t = __hip_fp8_e5m2_fnuz;
10+
using fp8_e5_2_t = __hip_fp8x2_e5m2_fnuz;
11+
// Note: E8M0 types are not supported in current HIP version
12+
// using fp8_e8_t = __hip_fp8_e8m0_fnuz;
13+
// using fp8_e8_2_t = __hip_fp8x2_e8m0_fnuz;
14+
815
// Simple wrapper that provides member access for generated code
916
struct fp8_e4_4_t {
1017
union {
@@ -43,6 +50,54 @@ struct __align__(16) fp8_e4_16_t {
4350
fp8_e4_8_t y;
4451
};
4552

53+
// FP8 E5M2 vector types
54+
struct fp8_e5_4_t {
55+
union {
56+
__hip_fp8x4_e5m2_fnuz data;
57+
struct {
58+
fp8_e5_t x, y, z, w;
59+
};
60+
};
61+
__device__ fp8_e5_4_t() = default;
62+
__device__ fp8_e5_4_t(const __hip_fp8x4_e5m2_fnuz &val) : data(val) {}
63+
__device__ operator __hip_fp8x4_e5m2_fnuz() const { return data; }
64+
};
65+
66+
struct __align__(8) fp8_e5_8_t {
67+
fp8_e5_4_t x;
68+
fp8_e5_4_t y;
69+
};
70+
71+
struct __align__(16) fp8_e5_16_t {
72+
fp8_e5_8_t x;
73+
fp8_e5_8_t y;
74+
};
75+
76+
// FP8 E8M0 vector types - not supported in current HIP version
77+
/*
78+
struct fp8_e8_4_t {
79+
union {
80+
__hip_fp8x4_e8m0_fnuz data;
81+
struct {
82+
fp8_e8_t x, y, z, w;
83+
};
84+
};
85+
__device__ fp8_e8_4_t() = default;
86+
__device__ fp8_e8_4_t(const __hip_fp8x4_e8m0_fnuz &val) : data(val) {}
87+
__device__ operator __hip_fp8x4_e8m0_fnuz() const { return data; }
88+
};
89+
90+
struct __align__(8) fp8_e8_8_t {
91+
fp8_e8_4_t x;
92+
fp8_e8_4_t y;
93+
};
94+
95+
struct __align__(16) fp8_e8_16_t {
96+
fp8_e8_8_t x;
97+
fp8_e8_8_t y;
98+
};
99+
*/
100+
46101
__device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
47102
fp8_e4_t w) {
48103
// reinterpret the 4 fp8_e4_t values to signed char value and shift
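For readers unfamiliar with the e5m2 layout these new wrappers carry: a float8_e5m2 value packs 1 sign bit, 5 exponent bits, and 2 mantissa bits. A minimal decoder for the plain OCP encoding (note the `_fnuz` variants used in this header differ: they shift the exponent bias and repurpose negative zero, so this sketch is illustrative only):

```python
def decode_fp8_e5m2(byte: int) -> float:
    # Decode an OCP-style float8_e5m2 byte: 1 sign, 5 exponent, 2 mantissa
    # bits, exponent bias 15. (HIP's *_fnuz variants use a different bias
    # and have no negative zero or infinities; this is the plain layout.)
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 2) & 0x1F
    mant = byte & 0x03
    if exp == 0:                       # zero and subnormals
        return sign * (mant / 4.0) * 2.0 ** -14
    if exp == 0x1F:                    # inf/NaN in the OCP encoding
        return sign * float("inf") if mant == 0 else float("nan")
    return sign * (1.0 + mant / 4.0) * 2.0 ** (exp - 15)

assert decode_fp8_e5m2(0x3C) == 1.0    # 0b0_01111_00
assert decode_fp8_e5m2(0xBC) == -1.0   # sign bit set
assert decode_fp8_e5m2(0x40) == 2.0    # 0b0_10000_00
```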

testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -238,6 +238,12 @@ def test_assert_tl_matmul():
         128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
     assert_tl_matmul_correctness(
         128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
+    assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3fnuz", "float16")
+    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32")
+    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2)
+    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False)
+    assert_tl_matmul_correctness(
+        128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2)


 if __name__ == "__main__":
```
