Skip to content

Commit fbfb80b

Browse files
[XLA:GPU] update nvjitlink and compilation_provider tests to support cuda 12.8
Update test expectations due to cuda 12.8 changes nvJitLinkCreate behavior to fail when an invalid sm architecture is provided Update nvJitLinkDestroy usage as asan detects a leakage if it is not run after nvJitLinkCreate failure. PiperOrigin-RevId: 743222408
1 parent f191703 commit fbfb80b

File tree

3 files changed

+20
-5
lines changed

3 files changed

+20
-5
lines changed

third_party/xla/xla/stream_executor/cuda/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,7 @@ xla_cc_test(
930930
":nvjitlink_support",
931931
"//xla/stream_executor:device_description",
932932
"//xla/stream_executor/gpu:gpu_asm_opts",
933+
"//xla/tsl/platform:status_matchers",
933934
"@com_google_absl//absl/status",
934935
"@com_google_absl//absl/strings",
935936
"@com_google_absl//absl/types:span",

third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,25 @@ absl::StatusOr<std::vector<uint8_t>> CompileAndLinkUsingLibNvJitLink(
163163
absl::c_transform(cli_args, std::back_inserter(cli_args_ptrs),
164164
[](const std::string& s) { return s.c_str(); });
165165

166-
nvJitLinkHandle link_handle{};
167-
RETURN_IF_NVJITLINK_ERROR(nvJitLinkCreate(&link_handle, cli_args_ptrs.size(),
168-
cli_args_ptrs.data()));
166+
nvJitLinkHandle link_handle = nullptr;
167+
nvJitLinkResult create_result =
168+
nvJitLinkCreate(&link_handle, cli_args_ptrs.size(), cli_args_ptrs.data());
169+
169170
absl::Cleanup link_handle_cleaner = [&link_handle] {
170-
CHECK_EQ(nvJitLinkDestroy(&link_handle), NVJITLINK_SUCCESS);
171+
if (link_handle != nullptr) {
172+
CHECK_EQ(nvJitLinkDestroy(&link_handle), NVJITLINK_SUCCESS);
173+
}
171174
};
172175

176+
if (create_result != NVJITLINK_SUCCESS) {
177+
TF_ASSIGN_OR_RETURN(std::string error_log,
178+
nvJitLinkGetErrorLog(link_handle));
179+
180+
VLOG(3) << "libnvjitlink error log output: " << error_log;
181+
182+
return ToStatus(create_result, error_log);
183+
}
184+
173185
for (auto& image : inputs) {
174186
nvJitLinkInputType input_type = image.type == NvJitLinkInput::Type::kPtx
175187
? NVJITLINK_INPUT_PTX

third_party/xla/xla/stream_executor/cuda/nvjitlink_test.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ limitations under the License.
3030
#include "xla/stream_executor/cuda/nvjitlink_support.h"
3131
#include "xla/stream_executor/device_description.h"
3232
#include "xla/stream_executor/gpu/gpu_asm_opts.h"
33+
#include "xla/tsl/platform/status_matchers.h"
3334
#include "tsl/platform/status_matchers.h"
3435
#include "tsl/platform/test.h"
3536

@@ -165,7 +166,8 @@ TEST_F(NvJitLinkTest, IdentifiesUnsupportedArchitecture) {
165166
EXPECT_THAT(
166167
CompileAndLinkHelper(stream_executor::CudaComputeCapability{100, 0},
167168
{kStandalonePtx}),
168-
tsl::testing::StatusIs(absl::StatusCode::kUnimplemented));
169+
tsl::testing::StatusIs(testing::AnyOf(absl::StatusCode::kUnknown,
170+
absl::StatusCode::kUnimplemented)));
169171
}
170172

171173
TEST_F(NvJitLinkTest, LinkingTwoCompilationUnitsSucceeds) {

0 commit comments

Comments
 (0)