Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROCm] Fixed linker issues related to fp8 buffer_comparator functions #66

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions xla/service/gpu/buffer_comparator.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= buffer_length) return;
__hip_fp8_e4m3_fnuz elem_a_fp8, elem_b_fp8;
Expand All @@ -123,13 +124,17 @@ __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,

if (rel_error > rel_error_threshold || isnan(rel_error))
atomicAdd(mismatch_count, 1);
#else
abort();
#endif // defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
}

__global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,
__hip_fp8_storage_t* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= buffer_length) return;
__hip_fp8_e5m2_fnuz elem_a_fp8, elem_b_fp8;
Expand All @@ -145,6 +150,9 @@ __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,

if (rel_error > rel_error_threshold || isnan(rel_error))
atomicAdd(mismatch_count, 1);
#else
abort();
#endif // defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
}
#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200

Expand Down
Loading