diff --git a/xla/backends/cpu/runtime/BUILD b/xla/backends/cpu/runtime/BUILD index e8f3a955c1b35..0467f6ce99b57 100644 --- a/xla/backends/cpu/runtime/BUILD +++ b/xla/backends/cpu/runtime/BUILD @@ -685,6 +685,7 @@ cc_library( deps = [ ":thunk", "//xla:shape_util", + "//xla:status_macros", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -733,6 +734,7 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "//xla/tsl/framework/contraction:eigen_contraction_kernel", + "//xla/tsl/platform:logging", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", diff --git a/xla/backends/cpu/runtime/dot_lib.cc b/xla/backends/cpu/runtime/dot_lib.cc index 067cdbef49811..05aaca671a474 100644 --- a/xla/backends/cpu/runtime/dot_lib.cc +++ b/xla/backends/cpu/runtime/dot_lib.cc @@ -25,10 +25,12 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "xla/layout_util.h" #include "xla/runtime/buffer_use.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status_macros.h" #include "xla/util.h" namespace xla::cpu { @@ -39,7 +41,7 @@ absl::InlinedVector DotBufferUses(const DotSlices& slices) { BufferUse::Write(slices.out_buffer)}; } -absl::StatusOr GetDotShape(DotDimensionNumbers dot_dimensions, +absl::StatusOr GetDotShape(const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape, const Shape& rhs_shape, const Shape& out_shape) { @@ -95,4 +97,48 @@ absl::StatusOr GetDotShape(DotDimensionNumbers dot_dimensions, }; } +absl::StatusOr GetDotCanonicalDims( + const DotDimensionNumbers& dot_dimensions, const DotShape& dot_shape) { + // Copy from the original dot dimension numbers. 
+  absl::InlinedVector<int64_t, 2> lhs_contracting_dims;
+  absl::InlinedVector<int64_t, 2> rhs_contracting_dims;
+
+  lhs_contracting_dims.assign(
+      dot_dimensions.lhs_contracting_dimensions().begin(),
+      dot_dimensions.lhs_contracting_dimensions().end());
+  rhs_contracting_dims.assign(
+      dot_dimensions.rhs_contracting_dimensions().begin(),
+      dot_dimensions.rhs_contracting_dimensions().end());
+
+  // Adjust contracting dimensions for leading batch dimensions.
+  for (int64_t& dim : lhs_contracting_dims)
+    dim -= dot_dimensions.lhs_batch_dimensions_size();
+  for (int64_t& dim : rhs_contracting_dims)
+    dim -= dot_dimensions.rhs_batch_dimensions_size();
+
+  // Non-contracting dots should never make it here.
+  TF_RET_CHECK(lhs_contracting_dims.size() == 1);
+  TF_RET_CHECK(rhs_contracting_dims.size() == 1);
+  TF_RET_CHECK(lhs_contracting_dims[0] < 2);
+  TF_RET_CHECK(rhs_contracting_dims[0] < 2);
+
+  auto is_column_major = [](const Shape& shape) {
+    return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
+  };
+
+  return DotCanonicalDims{
+      /*m=*/dot_shape.lhs_matmul_shape.rank() <= 1
+          ? int64_t{1}
+          : dot_shape.lhs_matmul_shape.dimensions(1 - lhs_contracting_dims[0]),
+      /*k=*/dot_shape.lhs_matmul_shape.dimensions(lhs_contracting_dims[0]),
+      /*n=*/dot_shape.rhs_matmul_shape.rank() <= 1
+          ? int64_t{1}
+          : dot_shape.rhs_matmul_shape.dimensions(1 - rhs_contracting_dims[0]),
+      /*lhs_column_major=*/is_column_major(dot_shape.lhs_matmul_shape),
+      /*lhs_canonical=*/dot_shape.lhs_matmul_shape.rank() <= 1 ||
+          lhs_contracting_dims[0] == 1,
+      /*rhs_column_major=*/is_column_major(dot_shape.rhs_matmul_shape),
+      /*rhs_canonical=*/rhs_contracting_dims[0] == 0};
+}
+
 }  // namespace xla::cpu
diff --git a/xla/backends/cpu/runtime/dot_lib.h b/xla/backends/cpu/runtime/dot_lib.h
index c269453336774..393a5b603fdb6 100644
--- a/xla/backends/cpu/runtime/dot_lib.h
+++ b/xla/backends/cpu/runtime/dot_lib.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
 #include "absl/status/statusor.h"
+#include "absl/types/span.h"
 #include "xla/runtime/buffer_use.h"
 #include "xla/service/buffer_assignment.h"
 #include "xla/shape.h"
@@ -38,6 +39,8 @@ struct DotSlices {
   Shape out_shape;
 };
 
+// TODO(ezhulenev): Merge DotCanonicalDims into DotShape.
+
 // Shape of the batched dot operation supported by the XLA:CPU runtime.
 struct DotShape {
   // Product of batch dimensions.
@@ -49,16 +52,48 @@ struct DotShape {
   Shape out_matmul_shape;
 };
 
+// Dot operation is implemented as a matrix-matrix multiply (row-major x
+// row-major or col-major x col-major). For batched dot operations, it is
+// implemented as multiple matrix multiplications repeated for each batch
+// element.
+struct DotCanonicalDims {
+  // The number of rows in the LHS.
+  int64_t m;
+
+  // The number of columns in the LHS, which also must be equal to the
+  // number of rows in the RHS.
+  int64_t k;
+
+  // The number of columns in the RHS.
+  int64_t n;
+
+  // True if the LHS matrix is column major.
+  bool lhs_column_major;
+
+  // True if the LHS contraction dimension is 1.
+  bool lhs_canonical;
+
+  // True if the RHS matrix is column major.
+  bool rhs_column_major;
+
+  // True if the RHS contraction dimension is 0.
+  bool rhs_canonical;
+};
+
 // Returns buffer uses of the dot operation.
 absl::InlinedVector<BufferUse, 4> DotBufferUses(const DotSlices& slices);
 
 // Verifies dot dimensions and shapes and returns the shape of the dot operation
 // in a form that is convenient for the runtime implementation.
-absl::StatusOr<DotShape> GetDotShape(DotDimensionNumbers dot_dimensions,
+absl::StatusOr<DotShape> GetDotShape(const DotDimensionNumbers& dot_dimensions,
                                      const Shape& lhs_shape,
                                      const Shape& rhs_shape,
                                      const Shape& out_shape);
 
+// Get canonical dot dimensions for the given dot shape.
+absl::StatusOr GetDotCanonicalDims( + const DotDimensionNumbers& dot_dimensions, const DotShape& dot_shape); + } // namespace xla::cpu #endif // XLA_BACKENDS_CPU_RUNTIME_DOT_LIB_H_ diff --git a/xla/backends/cpu/runtime/dot_thunk.cc b/xla/backends/cpu/runtime/dot_thunk.cc index cf3c10ed0efd0..00bcec6a2df83 100644 --- a/xla/backends/cpu/runtime/dot_thunk.cc +++ b/xla/backends/cpu/runtime/dot_thunk.cc @@ -33,75 +33,14 @@ limitations under the License. #include "xla/shape.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/logging.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { -namespace { - -// Dot operation is implemented as a matrix-matrix multiply (row-major x -// rowm-major or col-major x col-major). For batched dot operations, it is -// implemented as multiple matrix multiplications repeated for each batch -// element. -// -// We rely on col-major Eigen contraction and figure out how to represent dot -// operation as a contraction based on the dot dimension numbers. -struct MatMulDims { - // The number of rows in the LHS. - int64_t m; - - // The number of columns in the LHS, which also must be equal to the - // number of rows in the RHS. - int64_t k; - - // The number of columns in the RHS. - int64_t n; - - // True if the LHS matrix is column major. - bool lhs_column_major; - - // True if the LHS contraction dimension is 1. - bool lhs_canonical; - - // True if the RHS matrix is column major. - bool rhs_column_major; - - // True if the RHS contraction dimension is 0. - bool rhs_canonical; -}; - -} // namespace - -static MatMulDims GetMatMulDims( - const Shape& lhs_shape, absl::Span lhs_contracting_dims, - const Shape& rhs_shape, absl::Span rhs_contracting_dims) { - // Non-contracting dots should never make it here. 
- CHECK_EQ(lhs_contracting_dims.size(), 1); - CHECK_EQ(rhs_contracting_dims.size(), 1); - CHECK_LT(lhs_contracting_dims[0], 2); - CHECK_LT(rhs_contracting_dims[0], 2); - - auto is_column_major = [](const Shape& shape) { - return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0; - }; - - return MatMulDims{ - /*m=*/lhs_shape.rank() <= 1 - ? 1LL - : lhs_shape.dimensions(1LL - lhs_contracting_dims[0]), - /*k=*/lhs_shape.dimensions(lhs_contracting_dims[0]), - /*n=*/rhs_shape.rank() <= 1 - ? 1LL - : rhs_shape.dimensions(1LL - rhs_contracting_dims[0]), - /*lhs_column_major=*/is_column_major(lhs_shape), - /*lhs_canonical=*/lhs_shape.rank() <= 1 || lhs_contracting_dims[0] == 1, - /*rhs_column_major=*/is_column_major(rhs_shape), - /*rhs_canonical=*/rhs_contracting_dims[0] == 0}; -} absl::StatusOr> DotThunk::Create( Info info, DotDimensionNumbers dot_dimensions, @@ -111,35 +50,26 @@ absl::StatusOr> DotThunk::Create( TF_ASSIGN_OR_RETURN(DotShape dot_shape, GetDotShape(dot_dimensions, lhs_shape, rhs_shape, out_shape)); + TF_ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims, + GetDotCanonicalDims(dot_dimensions, dot_shape)); + DotSlices dot_slices{lhs_buffer, std::move(lhs_shape), rhs_buffer, std::move(rhs_shape), out_buffer, std::move(out_shape)}; - return absl::WrapUnique(new DotThunk(info, std::move(dot_dimensions), - std::move(dot_slices), - std::move(dot_shape))); + return absl::WrapUnique( + new DotThunk(info, std::move(dot_dimensions), std::move(dot_slices), + std::move(dot_shape), std::move(dot_canonical_dims))); } DotThunk::DotThunk(Info info, DotDimensionNumbers dot_dimensions, - DotSlices dot_slices, DotShape dot_shape) + DotSlices dot_slices, DotShape dot_shape, + DotCanonicalDims dot_canonical_dims) : Thunk(Kind::kDot, info), dot_dimensions_(std::move(dot_dimensions)), dot_slices_(std::move(dot_slices)), - dot_shape_(std::move(dot_shape)) { - // Copy from the original dot dimension numbers. 
- lhs_matmul_contracting_dims_.assign( - dot_dimensions_.lhs_contracting_dimensions().begin(), - dot_dimensions_.lhs_contracting_dimensions().end()); - rhs_matmul_contracting_dims_.assign( - dot_dimensions_.rhs_contracting_dimensions().begin(), - dot_dimensions_.rhs_contracting_dimensions().end()); - - // Adjust contracting dimensions for leading batch dimensions. - for (int64_t& dim : lhs_matmul_contracting_dims_) - dim -= dot_dimensions_.lhs_batch_dimensions_size(); - for (int64_t& dim : rhs_matmul_contracting_dims_) - dim -= dot_dimensions_.rhs_batch_dimensions_size(); -} + dot_shape_(std::move(dot_shape)), + dot_canonical_dims_(std::move(dot_canonical_dims)) {} tsl::AsyncValueRef DotThunk::Execute( const ExecuteParams& params) { @@ -181,16 +111,12 @@ tsl::AsyncValueRef DotThunk::Execute( dot_shape_.rhs_matmul_shape.ToString(true), dot_shape_.out_matmul_shape.ToString(true)); - MatMulDims matmul_dims = - GetMatMulDims(dot_shape_.lhs_matmul_shape, lhs_matmul_contracting_dims_, - dot_shape_.rhs_matmul_shape, rhs_matmul_contracting_dims_); - VLOG(3) << absl::StreamFormat( " matmul dims: m=%d, k=%d, n=%d, lhs_column_major=%v, lhs_canonical=%v, " "rhs_column_major=%v, rhs_canonical=%v", - matmul_dims.m, matmul_dims.k, matmul_dims.n, matmul_dims.lhs_column_major, - matmul_dims.lhs_canonical, matmul_dims.rhs_column_major, - matmul_dims.rhs_canonical); + dot_canonical_dims_.m, dot_canonical_dims_.k, dot_canonical_dims_.n, + dot_canonical_dims_.lhs_column_major, dot_canonical_dims_.lhs_canonical, + dot_canonical_dims_.rhs_column_major, dot_canonical_dims_.rhs_canonical); if (params.intra_op_threadpool == nullptr) { return InvalidArgument("Intra-op threadpool must be provided for DotThunk"); @@ -211,12 +137,17 @@ tsl::AsyncValueRef DotThunk::Execute( void* lhs = lhs_data.opaque(); void* rhs = rhs_data.opaque(); - bool transpose_lhs = !matmul_dims.lhs_canonical; - bool transpose_rhs = !matmul_dims.rhs_canonical; + int64_t m = dot_canonical_dims_.m; + int64_t n = 
dot_canonical_dims_.n; + int64_t k = dot_canonical_dims_.k; + + bool transpose_lhs = !dot_canonical_dims_.lhs_canonical; + bool transpose_rhs = !dot_canonical_dims_.rhs_canonical; - CHECK_EQ(matmul_dims.lhs_column_major, matmul_dims.rhs_column_major); - if (!matmul_dims.lhs_column_major) { - std::swap(matmul_dims.m, matmul_dims.n); + CHECK_EQ(dot_canonical_dims_.lhs_column_major, + dot_canonical_dims_.rhs_column_major); + if (!dot_canonical_dims_.lhs_column_major) { + std::swap(m, n); std::swap(lhs, rhs); std::swap(transpose_lhs, transpose_rhs); } @@ -224,9 +155,9 @@ tsl::AsyncValueRef DotThunk::Execute( PrimitiveType element_type = dot_shape_.lhs_matmul_shape.element_type(); int64_t byte_width = primitive_util::ByteWidth(element_type); - int64_t lhs_stride = matmul_dims.m * matmul_dims.k * byte_width; - int64_t rhs_stride = matmul_dims.k * matmul_dims.n * byte_width; - int64_t out_stride = matmul_dims.m * matmul_dims.n * byte_width; + int64_t lhs_stride = m * k * byte_width; + int64_t rhs_stride = k * n * byte_width; + int64_t out_stride = m * n * byte_width; auto batch_ptr = [&](void* ptr, int64_t stride, int64_t index) -> void* { return static_cast(ptr) + stride * index; @@ -238,9 +169,9 @@ tsl::AsyncValueRef DotThunk::Execute( for (int64_t i = 0; i < dot_shape_.batch_size; ++i) { TypedMatMul( params.intra_op_threadpool, batch_ptr(out, out_stride, i), - batch_ptr(lhs, lhs_stride, i), batch_ptr(rhs, rhs_stride, i), - matmul_dims.m, matmul_dims.n, matmul_dims.k, transpose_lhs, - transpose_rhs, [state]() mutable { state.CountDown(); }); + batch_ptr(lhs, lhs_stride, i), batch_ptr(rhs, rhs_stride, i), m, n, k, + transpose_lhs, transpose_rhs, + [state]() mutable { state.CountDown(); }); } }; diff --git a/xla/backends/cpu/runtime/dot_thunk.h b/xla/backends/cpu/runtime/dot_thunk.h index fbce0b397f044..15b5b97fd33c2 100644 --- a/xla/backends/cpu/runtime/dot_thunk.h +++ b/xla/backends/cpu/runtime/dot_thunk.h @@ -52,7 +52,7 @@ class DotThunk final : public Thunk { private: 
DotThunk(Info info, DotDimensionNumbers dot_dimensions, DotSlices dot_slices, - DotShape dot_shape); + DotShape dot_shape, DotCanonicalDims dot_canonical_dims); using DoneCallback = absl::AnyInvocable; @@ -72,6 +72,7 @@ class DotThunk final : public Thunk { DotDimensionNumbers dot_dimensions_; DotSlices dot_slices_; DotShape dot_shape_; + DotCanonicalDims dot_canonical_dims_; // Contracting dimensions of the LHS and RHS matmul shapes. absl::InlinedVector lhs_matmul_contracting_dims_; diff --git a/xla/backends/cpu/runtime/xnnpack/BUILD b/xla/backends/cpu/runtime/xnnpack/BUILD index 0157f9067f6bd..9f6a7f4c612cb 100644 --- a/xla/backends/cpu/runtime/xnnpack/BUILD +++ b/xla/backends/cpu/runtime/xnnpack/BUILD @@ -137,6 +137,8 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "//xla/tsl/framework/contraction:eigen_contraction_kernel", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", "@XNNPACK", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -149,8 +151,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], diff --git a/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc b/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc index ba9122f59d0dd..8f9d89aceb44b 100644 --- a/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc +++ b/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc @@ -36,17 +36,18 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { -static absl::Status DefineXnnSubgraph(xnn_subgraph_t subgraph, - const DotDimensionNumbers& dot_dimensions, - const DotShape& dot_shape) { +static absl::Status DefineXnnSubgraph( + xnn_subgraph_t subgraph, const DotDimensionNumbers& dot_dimensions, + const DotSlices& dot_slices, const DotShape& dot_shape, + const DotCanonicalDims& dot_canonical_dims) { uint32_t lhs_id = XNN_INVALID_VALUE_ID; uint32_t rhs_id = XNN_INVALID_VALUE_ID; uint32_t out_id = XNN_INVALID_VALUE_ID; @@ -55,9 +56,9 @@ static absl::Status DefineXnnSubgraph(xnn_subgraph_t subgraph, return {dims.begin(), dims.end()}; }; - std::vector lhs_dims = dims(dot_shape.lhs_matmul_shape.dimensions()); - std::vector rhs_dims = dims(dot_shape.rhs_matmul_shape.dimensions()); - std::vector out_dims = dims(dot_shape.out_matmul_shape.dimensions()); + std::vector lhs_dims = dims(dot_slices.lhs_shape.dimensions()); + std::vector rhs_dims = dims(dot_slices.rhs_shape.dimensions()); + std::vector out_dims = dims(dot_slices.out_shape.dimensions()); XNN_RETURN_IF_ERROR(xnn_define_tensor_value( subgraph, xnn_datatype_fp32, lhs_dims.size(), lhs_dims.data(), nullptr, @@ -71,13 +72,34 @@ static absl::Status DefineXnnSubgraph(xnn_subgraph_t subgraph, subgraph, xnn_datatype_fp32, out_dims.size(), out_dims.data(), nullptr, /*external_id=*/2, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &out_id)); - XNN_RETURN_IF_ERROR(xnn_define_batch_matrix_multiply(subgraph, lhs_id, rhs_id, - out_id, - /*flags=*/0)); + XNN_RETURN_IF_ERROR(xnn_define_batch_matrix_multiply( + subgraph, lhs_id, rhs_id, out_id, + /*flags=*/dot_canonical_dims.rhs_canonical ? 
0 : XNN_FLAG_TRANSPOSE_B)); return absl::OkStatus(); } +absl::StatusOr XnnDotThunk::IsSupported( + const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape, + const Shape& rhs_shape, const Shape& out_shape) { + // TODO(ezhulenev): Support other element types. + if (lhs_shape.element_type() != F32 || rhs_shape.element_type() != F32 || + out_shape.element_type() != F32) { + return false; + } + + TF_ASSIGN_OR_RETURN(DotShape dot_shape, GetDotShape(dot_dimensions, lhs_shape, + rhs_shape, out_shape)); + + TF_ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims, + GetDotCanonicalDims(dot_dimensions, dot_shape)); + + // XNNPACK does not support transposing LHS or col-major layouts. + return dot_canonical_dims.lhs_canonical && + !dot_canonical_dims.lhs_column_major && + !dot_canonical_dims.rhs_column_major; +} + absl::StatusOr> XnnDotThunk::Create( Info info, DotDimensionNumbers dot_dimensions, BufferAllocation::Slice lhs_buffer, Shape lhs_shape, @@ -88,21 +110,26 @@ absl::StatusOr> XnnDotThunk::Create( TF_ASSIGN_OR_RETURN(DotShape dot_shape, GetDotShape(dot_dimensions, lhs_shape, rhs_shape, out_shape)); + TF_ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims, + GetDotCanonicalDims(dot_dimensions, dot_shape)); + DotSlices dot_slices{lhs_buffer, std::move(lhs_shape), rhs_buffer, std::move(rhs_shape), out_buffer, std::move(out_shape)}; - return absl::WrapUnique(new XnnDotThunk(info, std::move(dot_dimensions), - std::move(dot_slices), - std::move(dot_shape))); + return absl::WrapUnique( + new XnnDotThunk(info, std::move(dot_dimensions), std::move(dot_slices), + std::move(dot_shape), std::move(dot_canonical_dims))); } XnnDotThunk::XnnDotThunk(Info info, DotDimensionNumbers dot_dimensions, - DotSlices dot_slices, DotShape dot_shape) + DotSlices dot_slices, DotShape dot_shape, + DotCanonicalDims dot_canonical_dims) : Thunk(Kind::kXnnDot, info), dot_dimensions_(std::move(dot_dimensions)), dot_slices_(std::move(dot_slices)), - dot_shape_(std::move(dot_shape)) {} + 
dot_shape_(std::move(dot_shape)), + dot_canonical_dims_(std::move(dot_canonical_dims)) {} tsl::AsyncValueRef XnnDotThunk::Execute( const ExecuteParams& params) { @@ -144,11 +171,19 @@ tsl::AsyncValueRef XnnDotThunk::Execute( dot_shape_.rhs_matmul_shape.ToString(true), dot_shape_.out_matmul_shape.ToString(true)); + VLOG(3) << absl::StreamFormat( + " matmul dims: m=%d, k=%d, n=%d, lhs_column_major=%v, lhs_canonical=%v, " + "rhs_column_major=%v, rhs_canonical=%v", + dot_canonical_dims_.m, dot_canonical_dims_.k, dot_canonical_dims_.n, + dot_canonical_dims_.lhs_column_major, dot_canonical_dims_.lhs_canonical, + dot_canonical_dims_.rhs_column_major, dot_canonical_dims_.rhs_canonical); + xnn_subgraph_t subgraph = nullptr; XNN_RETURN_IF_ERROR( xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph)); - TF_RETURN_IF_ERROR(DefineXnnSubgraph(subgraph, dot_dimensions_, dot_shape_)); + TF_RETURN_IF_ERROR(DefineXnnSubgraph(subgraph, dot_dimensions_, dot_slices_, + dot_shape_, dot_canonical_dims_)); xnn_workspace_t workspace = nullptr; XNN_RETURN_IF_ERROR(xnn_create_workspace(&workspace)); diff --git a/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h b/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h index c12194e870297..27ea46585f353 100644 --- a/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h +++ b/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h @@ -30,6 +30,12 @@ namespace xla::cpu { // Dot operation implemented on top of XNNPACK. class XnnDotThunk : public Thunk { public: + // Returns true if the dot operation is supported by XNNPACK. Returns an error + // if the dot operation shape is invalid. 
+ static absl::StatusOr IsSupported( + const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape, + const Shape& rhs_shape, const Shape& out_shape); + static absl::StatusOr> Create( Info info, DotDimensionNumbers dot_dimensions, BufferAllocation::Slice lhs_buffer, Shape lhs_shape, @@ -42,11 +48,13 @@ class XnnDotThunk : public Thunk { private: XnnDotThunk(Info info, DotDimensionNumbers dot_dimensions, - DotSlices dot_slices, DotShape dot_shape); + DotSlices dot_slices, DotShape dot_shape, + DotCanonicalDims dot_canonical_dims); DotDimensionNumbers dot_dimensions_; DotSlices dot_slices_; DotShape dot_shape_; + DotCanonicalDims dot_canonical_dims_; }; } // namespace xla::cpu diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 6302f89d5c404..b78b87c8a15df 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -88,6 +88,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_use_acl(true); #endif opts.set_xla_cpu_use_thunk_runtime(true); + opts.set_xla_cpu_use_xnnpack(false); opts.set_xla_cpu_parallel_codegen_split_count(32); opts.set_xla_cpu_copy_insertion_use_region_analysis(false); opts.set_xla_cpu_enable_concurrency_optimized_scheduler(false); @@ -922,6 +923,11 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_cpu_use_thunk_runtime), debug_options->xla_cpu_use_thunk_runtime(), "Use Thunk-based runtime for the CPU backend.")); + flag_list->push_back( + tsl::Flag("xla_cpu_use_xnnpack", + bool_setter_for(&DebugOptions::set_xla_cpu_use_xnnpack), + debug_options->xla_cpu_use_xnnpack(), + "Use XNNPACK for supported operations.")); flag_list->push_back(tsl::Flag( "xla_cpu_parallel_codegen_split_count", int32_setter_for(&DebugOptions::set_xla_cpu_parallel_codegen_split_count), diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 080ddc2a6c05b..6fbf6037cde06 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -909,6 +909,7 @@ cc_library( 
"//xla/backends/cpu/runtime:thunk", "//xla/backends/cpu/runtime:topk_thunk", "//xla/backends/cpu/runtime:while_thunk", + "//xla/backends/cpu/runtime/xnnpack:xnn_dot_thunk", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", diff --git a/xla/service/cpu/thunk_emitter.cc b/xla/service/cpu/thunk_emitter.cc index bd9f650bfa347..c2c2198a15012 100644 --- a/xla/service/cpu/thunk_emitter.cc +++ b/xla/service/cpu/thunk_emitter.cc @@ -52,6 +52,7 @@ limitations under the License. #include "xla/backends/cpu/runtime/thunk.h" #include "xla/backends/cpu/runtime/topk_thunk.h" #include "xla/backends/cpu/runtime/while_thunk.h" +#include "xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h" #include "xla/comparison_util.h" #include "xla/cpu_function_runtime.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -813,9 +814,23 @@ absl::StatusOr ThunkEmitter::EmitDotThunk( TF_ASSIGN_OR_RETURN(BufferAllocation::Slice out_slice, GetAllocationSlice(instruction)); - return ThunkSequence::Of( - ThunkInfo(instruction), dnums, lhs_slice, lhs->shape(), rhs_slice, - rhs->shape(), out_slice, instruction->shape()); + // Decide whether to use XNNPACK or Eigen. + bool use_xnn = hlo_module_config_.debug_options().xla_cpu_use_xnnpack(); + if (use_xnn) { + TF_ASSIGN_OR_RETURN( + use_xnn, XnnDotThunk::IsSupported(dnums, lhs->shape(), rhs->shape(), + instruction->shape())); + } + + if (use_xnn) { + return ThunkSequence::Of( + ThunkInfo(instruction), dnums, lhs_slice, lhs->shape(), rhs_slice, + rhs->shape(), out_slice, instruction->shape()); + } else { + return ThunkSequence::Of( + ThunkInfo(instruction), dnums, lhs_slice, lhs->shape(), rhs_slice, + rhs->shape(), out_slice, instruction->shape()); + } } } } diff --git a/xla/xla.proto b/xla/xla.proto index 1382558c1f7a4..448cc49c9d9e7 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -101,6 +101,9 @@ message DebugOptions { // When true, XLA:CPU uses the thunk runtime to execute compiled program. 
bool xla_cpu_use_thunk_runtime = 298; + // When true, XLA:CPU uses XNNPACK to execute supported operations. + bool xla_cpu_use_xnnpack = 359; + // Enabling this will enable optimizations that ignore the possibility of NaN. bool xla_enable_fast_math = 335; @@ -1098,7 +1101,7 @@ message DebugOptions { // be deterministic, although with additional overhead. bool xla_gpu_enable_scatter_determinism_expander = 345; - // Next id: 359 + // Next id: 360 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.