@@ -1713,41 +1713,45 @@ TEST(StreamExecutorGpuClientTest, NvshmemMemoryTest) {
17131713 client_options.allowed_devices = {0 };
17141714 client_options.num_nodes = 1 ;
17151715 client_options.kv_store = std::make_shared<InMemoryKeyValueStore>();
1716- TF_ASSERT_OK_AND_ASSIGN (auto client,
1716+ TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<PjRtClient> client,
17171717 GetStreamExecutorGpuClient (client_options));
17181718 xla::CompileOptions options;
17191719 options.executable_build_options .mutable_debug_options ()
17201720 ->set_xla_gpu_experimental_enable_nvshmem (true );
17211721
1722- TF_ASSERT_OK_AND_ASSIGN (auto executable,
1722+ TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<xla::PjRtLoadedExecutable> executable,
17231723 CompileExecutable (kProgram , *client, options));
17241724 std::vector<int32_t > data{1 , 2 , 3 , 4 };
17251725 Shape shape = ShapeUtil::MakeShapeWithDenseLayout (S32, {1 , 4 },
17261726 /* minor_to_major=*/ {1 , 0 });
17271727 shape.mutable_layout ()->set_memory_space (Layout::kDefaultMemorySpace );
17281728
1729- auto device = client->addressable_devices ()[0 ];
1729+ PjRtDevice* const device = client->addressable_devices ()[0 ];
17301730 TF_EXPECT_OK (device->default_memory_space ());
17311731 TF_ASSERT_OK_AND_ASSIGN (
1732- auto input, client->BufferFromHostBuffer (
1733- data.data (), shape.element_type (), shape.dimensions (),
1734- /* byte_strides=*/ std::nullopt ,
1735- PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall ,
1736- /* on_done_with_host_buffer=*/ nullptr , device));
1732+ std::unique_ptr<PjRtBuffer> input,
1733+ client->BufferFromHostBuffer (
1734+ data.data (), shape.element_type (), shape.dimensions (),
1735+ /* byte_strides=*/ std::nullopt ,
1736+ PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall ,
1737+ /* on_done_with_host_buffer=*/ nullptr , *device->default_memory_space (),
1738+ /* device_layout=*/ nullptr ));
17371739 EXPECT_EQ (input->memory_space ()->kind (), " device" );
17381740
1739- TF_ASSERT_OK_AND_ASSIGN (auto memory_kinds,
1740- executable->GetOutputMemoryKinds ());
1741+ TF_ASSERT_OK_AND_ASSIGN (
1742+ std::vector<std::vector<absl::string_view>> memory_kinds,
1743+ executable->GetOutputMemoryKinds ());
17411744 EXPECT_EQ (memory_kinds.size (), 1 );
17421745 EXPECT_EQ (memory_kinds[0 ].size (), 1 );
17431746 EXPECT_EQ (memory_kinds[0 ][0 ], " device" );
17441747
17451748 TF_ASSERT_OK_AND_ASSIGN (
1746- auto result, executable->Execute ({{input.get ()}}, ExecuteOptions ()));
1749+ std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> result,
1750+ executable->Execute ({{input.get ()}}, ExecuteOptions ()));
17471751 std::vector<std::unique_ptr<xla::PjRtBuffer>>& result_buffers = result[0 ];
17481752 EXPECT_EQ (result_buffers[0 ]->memory_space ()->kind (), " device" );
17491753 Shape result_shape = result_buffers[0 ]->on_device_shape ();
1750- auto memory_space = result_shape.layout ().memory_space ();
1754+ int64_t memory_space = result_shape.layout ().memory_space ();
17511755 EXPECT_EQ (memory_space, 1 );
17521756}
17531757
0 commit comments