Skip to content

Commit

Permalink
Add ClientLibraryTestRunnerMixin.
Browse files Browse the repository at this point in the history
`ClientLibraryTestRunnerMixin` is a sort-of replacement for
`ClientLibraryTestBase` to run tests on top of `HloTestBase` and friends (e.g.
`HloRunnerAgnosticTestBase`).  This is to enable a future migration to PjRt and
TFRT.

Due to `ClientLibraryTestBase` containing many client-specific calls, moving
tests is not as trivial as simply dropping in a new base class. The idea with
this class is just to make that migration simpler and to reduce (but not
eliminate) the amount of code changes required in tests.

Migration timeline for `ClientLibraryTestBase` tests:

1. `class XYZ: ClientLibraryTestBase` (starting point)
2. `class XYZ: ClientLibraryTestRunnerMixin<HloTestBase>` (intermediate state)
3. `class XYZ: ClientLibraryTestRunnerMixin<HloPjRtReferenceMixin<HloPjRtTestBase>>` (end state)

PiperOrigin-RevId: 715040700
  • Loading branch information
nvgrw authored and Google-ML-Automation committed Jan 13, 2025
1 parent 7fcea58 commit a49c585
Show file tree
Hide file tree
Showing 3 changed files with 391 additions and 0 deletions.
31 changes: 31 additions & 0 deletions xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,37 @@ cc_library(
],
)

cc_library(
name = "client_library_test_runner_mixin",
testonly = True,
hdrs = ["client_library_test_runner_mixin.h"],
deps = [
":hlo_runner_agnostic_test_base",
":literal_test_util",
"//xla:array2d",
"//xla:array3d",
"//xla:array4d",
"//xla:error_spec",
"//xla:execution_options_util",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
"//xla:types",
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
"//xla/hlo/builder:xla_builder",
"//xla/hlo/builder:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/service:hlo_module_config",
"//xla/tsl/lib/core:bitmap",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/types:span",
],
)

cc_library(
name = "llvm_irgen_test_base",
testonly = True,
Expand Down
354 changes: 354 additions & 0 deletions xla/tests/client_library_test_runner_mixin.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,354 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_TESTS_CLIENT_LIBRARY_TEST_RUNNER_MIXIN_H_
#define XLA_TESTS_CLIENT_LIBRARY_TEST_RUNNER_MIXIN_H_

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "absl/types/span.h"
#include "xla/array2d.h"
#include "xla/array3d.h"
#include "xla/array4d.h"
#include "xla/error_spec.h"
#include "xla/execution_options_util.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/service/hlo_module_config.h"
#include "xla/shape.h"
#include "xla/tests/hlo_runner_agnostic_test_base.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tsl/lib/core/bitmap.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
#include "xla/types.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"

namespace xla {

template <typename T>
constexpr inline bool is_floating_or_complex_v =
std::disjunction_v<is_specialized_floating_point<T>, is_complex<T>>;

// This class is designed to be used as a mixin for tests that formerly extended
// ClientLibraryTestBase. It is a partial re-implementation of
// ClientLibraryTestBase, but explicitly backed by an implementation of
// HloRunnerAgnosticTestBase. It also requires the use of
// HloRunnerAgnosticReferenceMixin, as it relies on RunAndCompare functionality.
//
// This class serves as a crucial bridging mechanism during the migration
// towards a single test base class and a migration away from stream executor.
//
// The reliance on templates / implementation as a mixin lets us switch out the
// underlying test base class and reference runner implementations incrementally
// and on a per-test basis instead of all at once.
template <typename T>
class ClientLibraryTestRunnerMixin : public T {
static_assert(
std::is_base_of_v<HloRunnerAgnosticTestBase, T> &&
T::has_reference_runner_mixin::value,
"Mixin must be used with a subclass of HloRunnerAgnosticTestBase and "
"HloRunnerAgnosticReferenceMixin.");

protected:
template <typename... BaseArgs>
explicit ClientLibraryTestRunnerMixin(BaseArgs&&... base_args)
: T(std::forward<BaseArgs>(base_args)...) {}
~ClientLibraryTestRunnerMixin() override = default;

// The float type used in this test.
PrimitiveType FloatType() const { return test_type_; }
void set_float_type(PrimitiveType type) { test_type_ = type; }

absl::StatusOr<Literal> ExecuteAndTransfer(
const XlaComputation& computation,
const absl::Span<Literal* const> arguments,
const Shape* const shape_with_output_layout = nullptr) {
ExecutionOptions execution_options = execution_options_;
if (shape_with_output_layout != nullptr) {
*execution_options.mutable_shape_with_output_layout() =
shape_with_output_layout->ToProto();
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> module,
BuildAndVerifyHloModule(computation, &execution_options));
return this->Execute(std::move(module), arguments);
}

absl::StatusOr<Literal> ExecuteAndTransfer(
XlaBuilder* const builder, const absl::Span<Literal* const> arguments,
const Shape* shape_with_output_layout = nullptr) {
// Build the computation, as a convenience.
TF_ASSIGN_OR_RETURN(XlaComputation computation, builder->Build());
return ExecuteAndTransfer(std::move(computation), arguments,
shape_with_output_layout);
}

// Run a computation and return its value as a string. If an error
// occurs, then instead return the error as a string.
std::string ExecuteToString(XlaBuilder* const builder,
const absl::Span<Literal* const> arguments) {
const absl::StatusOr<Literal> result =
ExecuteAndTransfer(builder, arguments);
if (!result.ok()) {
return result.status().ToString();
} else {
return result.value().ToString();
}
}

// Compare with reference.
// Side effect: EXPECT_OK
void ComputeAndCompare(XlaBuilder* const builder,
const absl::Span<Literal* const> arguments,
const std::optional<ErrorSpec> error = std::nullopt) {
ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder->Build());
ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
BuildAndVerifyHloModule(computation));
EXPECT_TRUE(this->RunAndCompare(std::move(module), arguments, error));
}

// Compare with literal.
// Side effect: EXPECT_OK
void ComputeAndCompareLiteral(
XlaBuilder* const builder, const Literal& expected,
const absl::Span<Literal* const> arguments,
const std::optional<ErrorSpec> error = std::nullopt) {
ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder->Build());
ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
BuildAndVerifyHloModule(computation));
ASSERT_OK_AND_ASSIGN(Literal actual,
this->Execute(std::move(module), arguments));
if (!error.has_value()) {
EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
} else {
EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, *error));
}
}

// Compare with literal.
// Side effect: EXPECT_OK
void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
absl::Span<Literal* const> arguments,
std::optional<ErrorSpec> error = std::nullopt) {
return ComputeAndCompareLiteral(builder, expected, arguments, error);
}

template <typename NativeT>
void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
absl::Span<Literal* const> arguments) {
Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
ComputeAndCompareLiteral(builder, expected_literal, arguments);
}

template <typename NativeT>
void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
absl::Span<Literal* const> arguments,
ErrorSpec error) {
static_assert(
is_floating_or_complex_v<NativeT>,
"Float or complex type required when specifying an ErrorSpec");
Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
ComputeAndCompareLiteral(builder, expected_literal, arguments, error);
}

template <typename NativeT>
void ComputeAndCompareR1(XlaBuilder* builder,
absl::Span<const NativeT> expected,
absl::Span<Literal* const> arguments) {
Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
ComputeAndCompareLiteral(builder, expected_literal, arguments);
}

template <typename NativeT>
void ComputeAndCompareR1(XlaBuilder* builder,
absl::Span<const NativeT> expected,
absl::Span<Literal* const> arguments,
ErrorSpec error) {
static_assert(
is_floating_or_complex_v<NativeT>,
"Float or complex type required when specifying an ErrorSpec");
Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
ComputeAndCompareLiteral(builder, expected_literal, arguments, error);
}

void ComputeAndCompareR1(XlaBuilder* builder,
const tsl::core::Bitmap& expected,
absl::Span<Literal* const> arguments) {
Literal expected_literal = LiteralUtil::CreateR1(expected);
ComputeAndCompareLiteral(builder, expected_literal, arguments);
}

template <typename NativeT>
void ComputeAndCompareR2(XlaBuilder* builder,
const Array2D<NativeT>& expected,
absl::Span<Literal* const> arguments) {
Literal expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
ComputeAndCompareLiteral(builder, expected_literal, arguments);
}

template <typename NativeT>
void ComputeAndCompareR2(XlaBuilder* builder,
const Array2D<NativeT>& expected,
absl::Span<Literal* const> arguments,
ErrorSpec error) {
static_assert(
is_floating_or_complex_v<NativeT>,
"Float or complex type required when specifying an ErrorSpec");
Literal expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
ComputeAndCompareLiteral(builder, expected_literal, arguments, error);
}

template <typename NativeT>
void ComputeAndCompareR3(XlaBuilder* builder,
const Array3D<NativeT>& expected,
absl::Span<Literal* const> arguments) {
Literal expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
ComputeAndCompareLiteral(builder, expected_literal, arguments);
}

template <typename NativeT>
void ComputeAndCompareR3(XlaBuilder* builder,
const Array3D<NativeT>& expected,
absl::Span<Literal* const> arguments,
ErrorSpec error) {
static_assert(
is_floating_or_complex_v<NativeT>,
"Float or complex type required when specifying an ErrorSpec");
Literal expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
ComputeAndCompareLiteral(builder, expected_literal, arguments, error);
}

template <typename NativeT>
void ComputeAndCompareR4(XlaBuilder* builder,
const Array4D<NativeT>& expected,
absl::Span<Literal* const> arguments) {
Literal expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
ComputeAndCompareLiteral(builder, expected_literal, arguments);
}

template <typename NativeT>
void ComputeAndCompareR4(XlaBuilder* builder,
const Array4D<NativeT>& expected,
absl::Span<Literal* const> arguments,
ErrorSpec error) {
static_assert(
is_floating_or_complex_v<NativeT>,
"Float or complex type required when specifying an ErrorSpec");
Literal expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
ComputeAndCompareLiteral(builder, expected_literal, arguments, error);
}

Literal CreateParameterAndTransferLiteral(const int64_t parameter_number,
const Literal& literal,
const std::string& name,
XlaBuilder* const builder,
XlaOp* const data_handle) {
Literal param_literal = MaybeConvertLiteralToTestType(literal);
*data_handle =
Parameter(builder, parameter_number, param_literal.shape(), name);
return param_literal;
}

template <typename NativeT>
Literal CreateR0Parameter(NativeT value, int64_t parameter_number,
const std::string& name, XlaBuilder* builder,
XlaOp* data_handle) {
Literal literal = LiteralUtil::CreateR0(value);
literal = MaybeConvertLiteralToTestType(literal);
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return literal;
}

template <typename NativeT>
Literal CreateR1Parameter(absl::Span<const NativeT> values,
int64_t parameter_number, const std::string& name,
XlaBuilder* builder, XlaOp* data_handle) {
Literal literal = LiteralUtil::CreateR1(values);
literal = MaybeConvertLiteralToTestType(literal);
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return literal;
}

Literal MaybeConvertLiteralToTestType(const Literal& literal) const {
switch (test_type_) {
case BF16:
return LiteralUtil::ConvertF32ToBF16(literal);
case F32:
return literal.Clone();
case F8E5M2:
return LiteralUtil::ConvertF32ToF8E5M2(literal);
case F8E4M3FN:
return LiteralUtil::ConvertF32ToF8E4M3FN(literal);
default:
LOG(FATAL) << "Unsupported test type: " << test_type_;
}
}

void SetFastMathDisabled(const bool disabled) {
auto* opts = execution_options_.mutable_debug_options();
opts->set_xla_cpu_enable_fast_math(!disabled);
opts->set_xla_cpu_enable_fast_min_max(!disabled);
opts->set_xla_gpu_enable_fast_min_max(!disabled);
}

// Provides mutable access to the execution DebugOptions field; this lets
// tests tweak the options that will be used to compile/run the graph.
DebugOptions* mutable_debug_options() {
return execution_options_.mutable_debug_options();
}

private:
absl::StatusOr<std::unique_ptr<HloModule>> BuildAndVerifyHloModule(
const XlaComputation& computation,
const ExecutionOptions* execution_options = nullptr) const {
if (execution_options == nullptr) {
execution_options = &execution_options_;
}
TF_ASSIGN_OR_RETURN(
HloModuleConfig module_config,
HloModule::CreateModuleConfigFromProto(
computation.proto(), execution_options->debug_options(),
execution_options));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(computation.proto(), module_config));
TF_RETURN_IF_ERROR(this->verifier().Run(module.get()).status());
return module;
}

PrimitiveType test_type_ = F32;
ExecutionOptions execution_options_ = CreateDefaultExecutionOptions();
};

} // namespace xla

#endif // XLA_TESTS_CLIENT_LIBRARY_TEST_RUNNER_MIXIN_H_
Loading

0 comments on commit a49c585

Please sign in to comment.