@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
1515
16+ #include < array>
1617#include < cstddef>
1718#include < cstdint>
1819#include < memory>
@@ -39,6 +40,7 @@ limitations under the License.
3940#include " xla/stream_executor/stream_executor.h"
4041#include " xla/stream_executor/typed_kernel_factory.h"
4142#include " xla/tsl/lib/core/status_test_util.h"
43+ #include " xla/tsl/platform/logging.h"
4244#include " xla/tsl/platform/statusor.h"
4345
4446namespace stream_executor ::gpu {
@@ -164,7 +166,12 @@ TEST_F(GpuKernelTest, ArrayArgByValue) {
164166 )" ;
165167
166168 MultiKernelLoaderSpec spec (/* arity=*/ 2 );
167- spec.AddCudaPtxInMemory (copy_kernel, " copy_kernel" );
169+ if (executor_->GetPlatform ()->id () ==
170+ stream_executor::rocm::kROCmPlatformId ) {
171+ spec.AddInProcessSymbol (internal::GetCopyKernel (), " copy_kernel" );
172+ } else {
173+ spec.AddCudaPtxInMemory (copy_kernel, " copy_kernel" );
174+ }
168175
169176 TF_ASSERT_OK_AND_ASSIGN (auto stream, executor_->CreateStream ());
170177 TF_ASSERT_OK_AND_ASSIGN (auto kernel, executor_->LoadKernel (spec));
@@ -174,25 +181,22 @@ TEST_F(GpuKernelTest, ArrayArgByValue) {
174181 DeviceMemory<char > dst = executor_->AllocateArray <char >(kLength , 0 );
175182 TF_ASSERT_OK (stream->MemZero (&dst, kLength ));
176183
177- struct ByValArg {
178- std::byte storage[16 ];
179- };
180- ByValArg arg;
184+ std::array<std::byte, 16 > storage;
181185 int i = 0 ;
182- for (auto & element : arg. storage ) {
186+ for (auto & element : storage) {
183187 element = static_cast <std::byte>(i++);
184188 }
185189
186190 // Launch kernel.
187- auto args = stream_executor::PackKernelArgs (/* shmem_bytes=*/ 0 , dst, arg );
191+ auto args = stream_executor::PackKernelArgs (/* shmem_bytes=*/ 0 , dst, storage );
188192 TF_ASSERT_OK (kernel->Launch (ThreadDim (), BlockDim (), stream.get (), *args));
189193
190194 // Copy data back to host.
191195 std::byte dst_host[16 ] = {};
192196 TF_ASSERT_OK (stream->Memcpy (dst_host, dst, kLength ));
193197 TF_ASSERT_OK (stream->BlockHostUntilDone ());
194198
195- EXPECT_THAT (dst_host, ::testing::ElementsAreArray (arg. storage ));
199+ EXPECT_THAT (dst_host, ::testing::ElementsAreArray (storage));
196200}
197201} // namespace
198202} // namespace stream_executor::gpu
0 commit comments