Skip to content

Commit 2c127b7

Browse files
committed
Fix gpu_kernel_test
1 parent 051b967 commit 2c127b7

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

xla/stream_executor/gpu/gpu_kernel_test.cc

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16+
#include <array>
1617
#include <cstddef>
1718
#include <cstdint>
1819
#include <memory>
@@ -39,6 +40,7 @@ limitations under the License.
3940
#include "xla/stream_executor/stream_executor.h"
4041
#include "xla/stream_executor/typed_kernel_factory.h"
4142
#include "xla/tsl/lib/core/status_test_util.h"
43+
#include "xla/tsl/platform/logging.h"
4244
#include "xla/tsl/platform/statusor.h"
4345

4446
namespace stream_executor::gpu {
@@ -164,7 +166,12 @@ TEST_F(GpuKernelTest, ArrayArgByValue) {
164166
)";
165167

166168
MultiKernelLoaderSpec spec(/*arity=*/2);
167-
spec.AddCudaPtxInMemory(copy_kernel, "copy_kernel");
169+
if (executor_->GetPlatform()->id() ==
170+
stream_executor::rocm::kROCmPlatformId) {
171+
spec.AddInProcessSymbol(internal::GetCopyKernel(), "copy_kernel");
172+
} else {
173+
spec.AddCudaPtxInMemory(copy_kernel, "copy_kernel");
174+
}
168175

169176
TF_ASSERT_OK_AND_ASSIGN(auto stream, executor_->CreateStream());
170177
TF_ASSERT_OK_AND_ASSIGN(auto kernel, executor_->LoadKernel(spec));
@@ -174,25 +181,22 @@ TEST_F(GpuKernelTest, ArrayArgByValue) {
174181
DeviceMemory<char> dst = executor_->AllocateArray<char>(kLength, 0);
175182
TF_ASSERT_OK(stream->MemZero(&dst, kLength));
176183

177-
struct ByValArg {
178-
std::byte storage[16];
179-
};
180-
ByValArg arg;
184+
std::array<std::byte, 16> storage;
181185
int i = 0;
182-
for (auto& element : arg.storage) {
186+
for (auto& element : storage) {
183187
element = static_cast<std::byte>(i++);
184188
}
185189

186190
// Launch kernel.
187-
auto args = stream_executor::PackKernelArgs(/*shmem_bytes=*/0, dst, arg);
191+
auto args = stream_executor::PackKernelArgs(/*shmem_bytes=*/0, dst, storage);
188192
TF_ASSERT_OK(kernel->Launch(ThreadDim(), BlockDim(), stream.get(), *args));
189193

190194
// Copy data back to host.
191195
std::byte dst_host[16] = {};
192196
TF_ASSERT_OK(stream->Memcpy(dst_host, dst, kLength));
193197
TF_ASSERT_OK(stream->BlockHostUntilDone());
194198

195-
EXPECT_THAT(dst_host, ::testing::ElementsAreArray(arg.storage));
199+
EXPECT_THAT(dst_host, ::testing::ElementsAreArray(storage));
196200
}
197201
} // namespace
198202
} // namespace stream_executor::gpu

xla/stream_executor/gpu/gpu_test_kernels.cu.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include "xla/stream_executor/gpu/gpu_test_kernels.h"
1717

18+
#include <array>
1819
#include <cstdint>
1920

2021
#include "xla/stream_executor/kernel_spec.h"
@@ -46,6 +47,14 @@ __global__ void AddI32Ptrs3(Ptrs3<int32_t> ptrs) {
4647
int index = threadIdx.x + blockIdx.x * blockDim.x;
4748
ptrs.c[index] = ptrs.a[index] + ptrs.b[index];
4849
}
50+
51+
__global__ void CopyKernel(std::byte* dst, std::array<std::byte, 16> byval) {
52+
if (threadIdx.x == 0) {
53+
for (int i = 0; i < 16; i++) {
54+
dst[i] = byval[i];
55+
}
56+
}
57+
}
4958
}
5059

5160
void* GetAddI32Kernel() { return reinterpret_cast<void*>(&AddI32); }
@@ -56,6 +65,8 @@ void* GetIncAndCmpKernel() { return reinterpret_cast<void*>(&IncAndCmp); }
5665

5766
void* GetAddI32Ptrs3Kernel() { return reinterpret_cast<void*>(&AddI32Ptrs3); }
5867

68+
void* GetCopyKernel() { return reinterpret_cast<void*>(&CopyKernel); }
69+
5970
} // namespace internal
6071

6172
MultiKernelLoaderSpec GetAddI32KernelSpec() {

xla/stream_executor/gpu/gpu_test_kernels.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ void* GetIncAndCmpKernel();
9898
// StreamExecutor arguments packing for custom C++ types.
9999
void* GetAddI32Ptrs3Kernel();
100100

101+
void* GetCopyKernel();
102+
101103
} // namespace internal
102104

103105
// Returns an in-process kernel loader spec for the `AddI32` kernel above.

0 commit comments

Comments
 (0)