
Commit 4483ada

Merge pull request #2681 from ROCm/r1.15-rocm61-bufcomparator-fix
buffer comparator fix
2 parents: fb3cab0 + 88d22d2

5 files changed: +72, -14 lines

tensorflow/compiler/xla/service/gpu/BUILD

Lines changed: 15 additions & 4 deletions

@@ -23,10 +23,7 @@ load(
     "//tensorflow/core/platform:default/cuda_build_defs.bzl",
     "if_cuda_is_configured",
 )
-load(
-    "@local_config_rocm//rocm:build_defs.bzl",
-    "if_rocm_is_configured",
-)
+load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", "rocm_copts")
 
 package(
     default_visibility = [":friends"],
@@ -1416,6 +1413,19 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "stream_executor_util_kernel",
+    srcs = ["stream_executor_util_kernel.cu.cc"],
+    tags = ["no_rocm"],
+    copts = rocm_copts(),
+    deps = if_rocm_is_configured([
+        "@local_config_rocm//rocm:rocm_headers",
+    ]) +
+    if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_headers",
+    ])
+)
+
 cc_library(
     name = "stream_executor_util",
     srcs = ["stream_executor_util.cc"],
@@ -1439,6 +1449,7 @@ cc_library(
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/stream_executor:kernel_spec",
         "//tensorflow/stream_executor:gpu_asm_opts",
+        ":stream_executor_util_kernel",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",

tensorflow/compiler/xla/service/gpu/stream_executor_util.cc

Lines changed: 9 additions & 9 deletions

@@ -414,28 +414,28 @@ static void InitializeTypedBuffer(se::Stream* stream,
     // Nothing more to do
     return;
   }
-#ifdef GOOGLE_CUDA
   // Repeat the host_buffer_size elements at the start of `buf` to the end
   CHECK_EQ(elements_to_fill, buffer.size() / sizeof(T) - host_buffer_size);
   se::StreamExecutor* executor = stream->parent();
-  auto kernel =
-      se::TypedKernelFactory<se::DeviceMemoryBase, int64, int64>::Create(
-          executor, "RepeatBufferKernel", repeat_buffer_kernel::kernel());
+
+  auto kernel =
+      executor->CreateTypedKernel<se::DeviceMemoryBase, int64_t, int64_t>(
+          "RepeatBufferKernel", repeat_buffer_kernel::kernel());
   if (!kernel.ok()) {
     LOG(FATAL) << "Could not create RepeatBufferKernel: " << kernel.status();
   }
   // Launch the kernel with at least host_buffer_bytes threads. Each thread
   // will read one byte of `host_buffer` from the start of `buffer`, where the
   // Memcpy call(s) above put it, and scatter it through the rest of `buffer`.
-  constexpr int64 host_buffer_bytes = host_buffer_size * sizeof(T);
+  constexpr int64_t host_buffer_bytes = host_buffer_size * sizeof(T);
   constexpr int threads_per_block = 256;
   constexpr int blocks_per_grid =
       (host_buffer_bytes + threads_per_block - 1) / threads_per_block;
-  TF_CHECK_OK(stream->ThenLaunch(se::ThreadDim(threads_per_block, 1, 1),
-                                 se::BlockDim(blocks_per_grid, 1, 1), *kernel,
+  stream->ThenLaunch(se::ThreadDim(threads_per_block, 1, 1),
+                     se::BlockDim(blocks_per_grid, 1, 1),
+                     *kernel.ValueOrDie(),
                      buffer, host_buffer_bytes,
-                     static_cast<int64>(buffer.size())));
-#endif
+                     static_cast<int64_t>(buffer.size()));
 }
 
 void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type,
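
Note on the launch dimensions: the ceil division above sizes the grid so that blocks_per_grid * threads_per_block >= host_buffer_bytes, guaranteeing one thread per byte of the already-copied prefix; surplus threads simply fail the kernel's bounds check. A minimal sketch of that arithmetic (the helper name BlocksForWork and the example sizes are hypothetical, not part of the change):

    #include <cstdint>

    // Round work_items up to whole blocks; extra threads in the last block
    // exit early inside the kernel via its bounds check.
    constexpr int64_t BlocksForWork(int64_t work_items, int threads_per_block) {
      return (work_items + threads_per_block - 1) / threads_per_block;
    }

    static_assert(BlocksForWork(1000, 256) == 4, "1000 bytes -> 4 blocks of 256");
    static_assert(BlocksForWork(1024, 256) == 4, "exact multiple -> no extra block");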

tensorflow/compiler/xla/service/gpu/stream_executor_util.h

Lines changed: 4 additions & 0 deletions

@@ -33,6 +33,10 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
+namespace repeat_buffer_kernel {
+void* kernel();
+}  // namespace repeat_buffer_kernel
+
 // Returns true if the given StreamExecutor is for a Volta or newer nvidia GPU.
 bool IsVoltaOrLater(const se::StreamExecutor& stream_exec);
 
tensorflow/compiler/xla/service/gpu/stream_executor_util_kernel.cu.cc (new file)

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+
+namespace xla {
+namespace gpu {
+namespace repeat_buffer_kernel {
+
+namespace {
+// Populate the last `buffer_size - repeat_size` bytes of `buffer` by repeating
+// the first `repeat_size` bytes. This should be launched with at least
+// `repeat_size` threads in total.
+__global__ void RepeatBufferKernel(char* buffer, int64_t repeat_size,
+                                   int64_t buffer_size) {
+  int64_t global_index = blockDim.x * blockIdx.x + threadIdx.x;
+  if (global_index >= repeat_size) {
+    return;
+  }
+  const char src_value = buffer[global_index];
+  for (int64_t dst_index = global_index + repeat_size; dst_index < buffer_size;
+       dst_index += repeat_size) {
+    buffer[dst_index] = src_value;
+  }
+}
+}  // namespace
+void* kernel() { return reinterpret_cast<void*>(RepeatBufferKernel); }
+}  // namespace repeat_buffer_kernel
+}  // namespace gpu
+}  // namespace xla
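
To see the effect of this kernel in isolation, here is a small standalone CUDA harness (a sketch only; the demo kernel, names, and sizes are hypothetical and not part of this change, though the scatter pattern mirrors RepeatBufferKernel):

    #include <cstdint>
    #include <cstdio>
    #include <cuda_runtime.h>

    // Same repeat/scatter pattern, reproduced for a standalone demo:
    // thread i copies byte i into every position i + k * repeat_size.
    __global__ void RepeatDemo(char* buffer, int64_t repeat_size,
                               int64_t buffer_size) {
      int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
      if (i >= repeat_size) return;
      const char src = buffer[i];
      for (int64_t dst = i + repeat_size; dst < buffer_size; dst += repeat_size) {
        buffer[dst] = src;
      }
    }

    int main() {
      constexpr int64_t kRepeat = 4, kTotal = 16;
      char host[kTotal] = {'a', 'b', 'c', 'd'};  // only the prefix matters
      char* dev = nullptr;
      cudaMalloc(&dev, kTotal);
      cudaMemcpy(dev, host, kRepeat, cudaMemcpyHostToDevice);  // copy the prefix only
      RepeatDemo<<<1, 256>>>(dev, kRepeat, kTotal);            // one block suffices here
      cudaMemcpy(host, dev, kTotal, cudaMemcpyDeviceToHost);
      cudaFree(dev);
      printf("%.*s\n", static_cast<int>(kTotal), host);        // prints: abcdabcdabcdabcd
      return 0;
    }

After the launch, byte i of the buffer equals byte i % repeat_size of the prefix, which is the postcondition InitializeTypedBuffer relies on when it seeds a large device buffer from a small host-generated chunk.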

third_party/gpus/rocm_configure.bzl

Lines changed: 2 additions & 1 deletion

@@ -191,7 +191,8 @@ def _rocm_include_path(repository_ctx, rocm_config):
     inc_dirs.append(rocm_config.rocm_toolkit_path + "/llvm/lib/clang/13.0.0/include")
     inc_dirs.append(rocm_config.rocm_toolkit_path + "/llvm/lib/clang/14.0.0/include")
     inc_dirs.append(rocm_config.rocm_toolkit_path + "/llvm/lib/clang/15.0.0/include")
-    inc_dirs.append(rocm_config.rocm_toolkit_path + "/llvm/lib/clang/15.0.0/include")
+    inc_dirs.append(rocm_config.rocm_toolkit_path + "/llvm/lib/clang/16.0.0/include")
+    inc_dirs.append(rocm_config.rocm_toolkit_path + "/llvm/lib/clang/17/include")
     inc_dirs.append(rocm_config.rocm_toolkit_path + "/lib/llvm/lib/clang/17/include")
     inc_dirs.append(rocm_config.rocm_toolkit_path + "/lib/llvm/lib/clang/18/include")
     inc_dirs.append(rocm_config.rocm_toolkit_path + "/lib/llvm/lib/clang/19/include")
