[xla:cpu] Add a flag to enable XNNPACK operations in XLA and connect XnnDotThunk to ThunkEmitter #20648

Merged 1 commit on Dec 20, 2024
2 changes: 2 additions & 0 deletions xla/backends/cpu/runtime/BUILD
@@ -685,6 +685,7 @@ cc_library(
deps = [
":thunk",
"//xla:shape_util",
"//xla:status_macros",
"//xla:types",
"//xla:util",
"//xla:xla_data_proto_cc",
@@ -733,6 +734,7 @@ cc_library(
"//xla/stream_executor:device_memory",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/framework/contraction:eigen_contraction_kernel",
"//xla/tsl/platform:logging",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
48 changes: 47 additions & 1 deletion xla/backends/cpu/runtime/dot_lib.cc
@@ -25,10 +25,12 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "xla/layout_util.h"
#include "xla/runtime/buffer_use.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/util.h"

namespace xla::cpu {
@@ -39,7 +41,7 @@ absl::InlinedVector<BufferUse, 4> DotBufferUses(const DotSlices& slices) {
BufferUse::Write(slices.out_buffer)};
}

absl::StatusOr<DotShape> GetDotShape(DotDimensionNumbers dot_dimensions,
absl::StatusOr<DotShape> GetDotShape(const DotDimensionNumbers& dot_dimensions,
const Shape& lhs_shape,
const Shape& rhs_shape,
const Shape& out_shape) {
@@ -95,4 +97,48 @@ absl::StatusOr<DotShape> GetDotShape(DotDimensionNumbers dot_dimensions,
};
}

absl::StatusOr<DotCanonicalDims> GetDotCanonicalDims(
const DotDimensionNumbers& dot_dimensions, const DotShape& dot_shape) {
// Copy from the original dot dimension numbers.
absl::InlinedVector<int64_t, 2> lhs_contracting_dims;
absl::InlinedVector<int64_t, 2> rhs_contracting_dims;

lhs_contracting_dims.assign(
dot_dimensions.lhs_contracting_dimensions().begin(),
dot_dimensions.lhs_contracting_dimensions().end());
rhs_contracting_dims.assign(
dot_dimensions.rhs_contracting_dimensions().begin(),
dot_dimensions.rhs_contracting_dimensions().end());

// Adjust contracting dimensions for leading batch dimensions.
for (int64_t& dim : lhs_contracting_dims)
dim -= dot_dimensions.lhs_batch_dimensions_size();
for (int64_t& dim : rhs_contracting_dims)
dim -= dot_dimensions.rhs_batch_dimensions_size();

// Non-contracting dots should never make it here.
TF_RET_CHECK(lhs_contracting_dims.size() == 1);
TF_RET_CHECK(rhs_contracting_dims.size() == 1);
TF_RET_CHECK(lhs_contracting_dims[0] < 2);
TF_RET_CHECK(rhs_contracting_dims[0] < 2);

auto is_column_major = [](const Shape& shape) {
return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
};

return DotCanonicalDims{
/*m=*/dot_shape.lhs_matmul_shape.rank() <= 1
? int64_t{1}
: dot_shape.lhs_matmul_shape.dimensions(1 - lhs_contracting_dims[0]),
/*k=*/dot_shape.lhs_matmul_shape.dimensions(lhs_contracting_dims[0]),
/*n=*/dot_shape.rhs_matmul_shape.rank() <= 1
? int64_t{1}
: dot_shape.rhs_matmul_shape.dimensions(1 - rhs_contracting_dims[0]),
/*lhs_column_major=*/is_column_major(dot_shape.lhs_matmul_shape),
/*lhs_canonical=*/dot_shape.lhs_matmul_shape.rank() <= 1 ||
lhs_contracting_dims[0] == 1,
/*rhs_column_major=*/is_column_major(dot_shape.rhs_matmul_shape),
/*rhs_canonical=*/rhs_contracting_dims[0] == 0};
}

} // namespace xla::cpu
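
For reference, a minimal standalone sketch of the batch-dimension adjustment above (the helper name and shapes are illustrative, not part of the PR): `GetDotShape` collapses leading batch dimensions out of the matmul shapes, so the contracting dimension indices from the original `DotDimensionNumbers` must be shifted left by the batch rank before they can be tested for canonicity.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical standalone version of the adjustment in GetDotCanonicalDims:
// shift each contracting dimension index left by the number of leading batch
// dimensions that were collapsed out of the matmul shape.
std::vector<int64_t> AdjustForBatchDims(std::vector<int64_t> contracting_dims,
                                        int64_t batch_rank) {
  for (int64_t& dim : contracting_dims) dim -= batch_rank;
  return contracting_dims;
}

int main() {
  // lhs f32[4,16,32]: batch dims {0}, contracting dims {2}. After the batch
  // dimension is collapsed, lhs_matmul_shape is f32[16,32] and the
  // contracting dimension becomes 2 - 1 = 1 (the canonical LHS case).
  std::vector<int64_t> lhs = AdjustForBatchDims({2}, /*batch_rank=*/1);
  // rhs f32[4,32,8]: contracting dims {1} -> 0 (the canonical RHS case).
  std::vector<int64_t> rhs = AdjustForBatchDims({1}, /*batch_rank=*/1);
  return (lhs[0] == 1 && rhs[0] == 0) ? 0 : 1;
}
```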
37 changes: 36 additions & 1 deletion xla/backends/cpu/runtime/dot_lib.h
@@ -20,6 +20,7 @@ limitations under the License.

#include "absl/container/inlined_vector.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/runtime/buffer_use.h"
#include "xla/service/buffer_assignment.h"
#include "xla/shape.h"
@@ -38,6 +39,8 @@ struct DotSlices {
Shape out_shape;
};

// TODO(ezhulenev): Merge DotCanonicalDims into DotShape.

// Shape of the batched dot operation supported by the XLA:CPU runtime.
struct DotShape {
// Product of batch dimensions.
@@ -49,16 +52,48 @@ struct DotShape {
Shape out_matmul_shape;
};

// Dot operation is implemented as a matrix-matrix multiply (row-major x
// row-major or col-major x col-major). For batched dot operations, it is
// implemented as multiple matrix multiplications repeated for each batch
// element.
struct DotCanonicalDims {
// The number of rows in the LHS.
int64_t m;

// The number of columns in the LHS, which also must be equal to the
// number of rows in the RHS.
int64_t k;

// The number of columns in the RHS.
int64_t n;

// True if the LHS matrix is column major.
bool lhs_column_major;

// True if the LHS contraction dimension is 1.
bool lhs_canonical;

// True if the RHS matrix is column major.
bool rhs_column_major;

// True if the RHS contraction dimension is 0.
bool rhs_canonical;
};

// Returns buffer uses of the dot operation.
absl::InlinedVector<BufferUse, 4> DotBufferUses(const DotSlices& slices);

// Verifies dot dimensions and shapes and returns the shape of the dot operation
// in a form that is convenient for the runtime implementation.
absl::StatusOr<DotShape> GetDotShape(DotDimensionNumbers dot_dimensions,
absl::StatusOr<DotShape> GetDotShape(const DotDimensionNumbers& dot_dimensions,
const Shape& lhs_shape,
const Shape& rhs_shape,
const Shape& out_shape);

// Returns canonical dot dimensions for the given dot shape.
absl::StatusOr<DotCanonicalDims> GetDotCanonicalDims(
const DotDimensionNumbers& dot_dimensions, const DotShape& dot_shape);

} // namespace xla::cpu

#endif // XLA_BACKENDS_CPU_RUNTIME_DOT_LIB_H_
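
To make the new struct concrete, here is one hypothetical instantiation (the mirror struct and all values are invented for illustration): for a plain row-major f32[16,32] x f32[32,8] -> f32[16,8] dot, contracting LHS dim 1 against RHS dim 0, the fields come out fully canonical and no transposes are needed at execution time.

```cpp
#include <cstdint>

// Local mirror of DotCanonicalDims, for illustration only.
struct CanonicalDims {
  int64_t m, k, n;
  bool lhs_column_major, lhs_canonical;
  bool rhs_column_major, rhs_canonical;
};

int main() {
  // f32[16,32] x f32[32,8] -> f32[16,8], both operands row-major,
  // LHS contracting on dim 1, RHS contracting on dim 0.
  CanonicalDims dims{/*m=*/16, /*k=*/32, /*n=*/8,
                     /*lhs_column_major=*/false, /*lhs_canonical=*/true,
                     /*rhs_column_major=*/false, /*rhs_canonical=*/true};
  // If the LHS instead contracted on dim 0, lhs_canonical would be false
  // and the thunk would run the matmul with the LHS transposed.
  return (dims.lhs_canonical && dims.rhs_canonical) ? 0 : 1;
}
```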
129 changes: 30 additions & 99 deletions xla/backends/cpu/runtime/dot_thunk.cc
@@ -33,75 +33,14 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/logging.h"
#include "xla/types.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
namespace {

// Dot operation is implemented as a matrix-matrix multiply (row-major x
// row-major or col-major x col-major). For batched dot operations, it is
// implemented as multiple matrix multiplications repeated for each batch
// element.
//
// We rely on col-major Eigen contraction and figure out how to represent dot
// operation as a contraction based on the dot dimension numbers.
struct MatMulDims {
// The number of rows in the LHS.
int64_t m;

// The number of columns in the LHS, which also must be equal to the
// number of rows in the RHS.
int64_t k;

// The number of columns in the RHS.
int64_t n;

// True if the LHS matrix is column major.
bool lhs_column_major;

// True if the LHS contraction dimension is 1.
bool lhs_canonical;

// True if the RHS matrix is column major.
bool rhs_column_major;

// True if the RHS contraction dimension is 0.
bool rhs_canonical;
};

} // namespace

static MatMulDims GetMatMulDims(
const Shape& lhs_shape, absl::Span<const int64_t> lhs_contracting_dims,
const Shape& rhs_shape, absl::Span<const int64_t> rhs_contracting_dims) {
// Non-contracting dots should never make it here.
CHECK_EQ(lhs_contracting_dims.size(), 1);
CHECK_EQ(rhs_contracting_dims.size(), 1);
CHECK_LT(lhs_contracting_dims[0], 2);
CHECK_LT(rhs_contracting_dims[0], 2);

auto is_column_major = [](const Shape& shape) {
return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
};

return MatMulDims{
/*m=*/lhs_shape.rank() <= 1
? 1LL
: lhs_shape.dimensions(1LL - lhs_contracting_dims[0]),
/*k=*/lhs_shape.dimensions(lhs_contracting_dims[0]),
/*n=*/rhs_shape.rank() <= 1
? 1LL
: rhs_shape.dimensions(1LL - rhs_contracting_dims[0]),
/*lhs_column_major=*/is_column_major(lhs_shape),
/*lhs_canonical=*/lhs_shape.rank() <= 1 || lhs_contracting_dims[0] == 1,
/*rhs_column_major=*/is_column_major(rhs_shape),
/*rhs_canonical=*/rhs_contracting_dims[0] == 0};
}

absl::StatusOr<std::unique_ptr<DotThunk>> DotThunk::Create(
Info info, DotDimensionNumbers dot_dimensions,
@@ -111,35 +50,26 @@ absl::StatusOr<std::unique_ptr<DotThunk>> DotThunk::Create(
TF_ASSIGN_OR_RETURN(DotShape dot_shape, GetDotShape(dot_dimensions, lhs_shape,
rhs_shape, out_shape));

TF_ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims,
GetDotCanonicalDims(dot_dimensions, dot_shape));

DotSlices dot_slices{lhs_buffer, std::move(lhs_shape),
rhs_buffer, std::move(rhs_shape),
out_buffer, std::move(out_shape)};

return absl::WrapUnique(new DotThunk(info, std::move(dot_dimensions),
std::move(dot_slices),
std::move(dot_shape)));
return absl::WrapUnique(
new DotThunk(info, std::move(dot_dimensions), std::move(dot_slices),
std::move(dot_shape), std::move(dot_canonical_dims)));
}

DotThunk::DotThunk(Info info, DotDimensionNumbers dot_dimensions,
DotSlices dot_slices, DotShape dot_shape)
DotSlices dot_slices, DotShape dot_shape,
DotCanonicalDims dot_canonical_dims)
: Thunk(Kind::kDot, info),
dot_dimensions_(std::move(dot_dimensions)),
dot_slices_(std::move(dot_slices)),
dot_shape_(std::move(dot_shape)) {
// Copy from the original dot dimension numbers.
lhs_matmul_contracting_dims_.assign(
dot_dimensions_.lhs_contracting_dimensions().begin(),
dot_dimensions_.lhs_contracting_dimensions().end());
rhs_matmul_contracting_dims_.assign(
dot_dimensions_.rhs_contracting_dimensions().begin(),
dot_dimensions_.rhs_contracting_dimensions().end());

// Adjust contracting dimensions for leading batch dimensions.
for (int64_t& dim : lhs_matmul_contracting_dims_)
dim -= dot_dimensions_.lhs_batch_dimensions_size();
for (int64_t& dim : rhs_matmul_contracting_dims_)
dim -= dot_dimensions_.rhs_batch_dimensions_size();
}
dot_shape_(std::move(dot_shape)),
dot_canonical_dims_(std::move(dot_canonical_dims)) {}

tsl::AsyncValueRef<DotThunk::ExecuteEvent> DotThunk::Execute(
const ExecuteParams& params) {
@@ -181,16 +111,12 @@ tsl::AsyncValueRef<DotThunk::ExecuteEvent> DotThunk::Execute(
dot_shape_.rhs_matmul_shape.ToString(true),
dot_shape_.out_matmul_shape.ToString(true));

MatMulDims matmul_dims =
GetMatMulDims(dot_shape_.lhs_matmul_shape, lhs_matmul_contracting_dims_,
dot_shape_.rhs_matmul_shape, rhs_matmul_contracting_dims_);

VLOG(3) << absl::StreamFormat(
" matmul dims: m=%d, k=%d, n=%d, lhs_column_major=%v, lhs_canonical=%v, "
"rhs_column_major=%v, rhs_canonical=%v",
matmul_dims.m, matmul_dims.k, matmul_dims.n, matmul_dims.lhs_column_major,
matmul_dims.lhs_canonical, matmul_dims.rhs_column_major,
matmul_dims.rhs_canonical);
dot_canonical_dims_.m, dot_canonical_dims_.k, dot_canonical_dims_.n,
dot_canonical_dims_.lhs_column_major, dot_canonical_dims_.lhs_canonical,
dot_canonical_dims_.rhs_column_major, dot_canonical_dims_.rhs_canonical);

if (params.intra_op_threadpool == nullptr) {
return InvalidArgument("Intra-op threadpool must be provided for DotThunk");
@@ -211,22 +137,27 @@ tsl::AsyncValueRef<DotThunk::ExecuteEvent> DotThunk::Execute(
void* lhs = lhs_data.opaque();
void* rhs = rhs_data.opaque();

bool transpose_lhs = !matmul_dims.lhs_canonical;
bool transpose_rhs = !matmul_dims.rhs_canonical;
int64_t m = dot_canonical_dims_.m;
int64_t n = dot_canonical_dims_.n;
int64_t k = dot_canonical_dims_.k;

bool transpose_lhs = !dot_canonical_dims_.lhs_canonical;
bool transpose_rhs = !dot_canonical_dims_.rhs_canonical;

CHECK_EQ(matmul_dims.lhs_column_major, matmul_dims.rhs_column_major);
if (!matmul_dims.lhs_column_major) {
std::swap(matmul_dims.m, matmul_dims.n);
CHECK_EQ(dot_canonical_dims_.lhs_column_major,
dot_canonical_dims_.rhs_column_major);
if (!dot_canonical_dims_.lhs_column_major) {
std::swap(m, n);
std::swap(lhs, rhs);
std::swap(transpose_lhs, transpose_rhs);
}

PrimitiveType element_type = dot_shape_.lhs_matmul_shape.element_type();
int64_t byte_width = primitive_util::ByteWidth(element_type);

int64_t lhs_stride = matmul_dims.m * matmul_dims.k * byte_width;
int64_t rhs_stride = matmul_dims.k * matmul_dims.n * byte_width;
int64_t out_stride = matmul_dims.m * matmul_dims.n * byte_width;
int64_t lhs_stride = m * k * byte_width;
int64_t rhs_stride = k * n * byte_width;
int64_t out_stride = m * n * byte_width;

auto batch_ptr = [&](void* ptr, int64_t stride, int64_t index) -> void* {
return static_cast<uint8_t*>(ptr) + stride * index;
@@ -238,9 +169,9 @@ tsl::AsyncValueRef<DotThunk::ExecuteEvent> DotThunk::Execute(
for (int64_t i = 0; i < dot_shape_.batch_size; ++i) {
TypedMatMul<decltype(type_tag)>(
params.intra_op_threadpool, batch_ptr(out, out_stride, i),
batch_ptr(lhs, lhs_stride, i), batch_ptr(rhs, rhs_stride, i),
matmul_dims.m, matmul_dims.n, matmul_dims.k, transpose_lhs,
transpose_rhs, [state]() mutable { state.CountDown(); });
batch_ptr(lhs, lhs_stride, i), batch_ptr(rhs, rhs_stride, i), m, n, k,
transpose_lhs, transpose_rhs,
[state]() mutable { state.CountDown(); });
}
};

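The m/n and operand swap in Execute above follows from the transpose identity (A·B)ᵀ = Bᵀ·Aᵀ: reinterpreting a row-major buffer as column-major yields exactly the transposed matrix, so a row-major C = A·B can be fed to the column-major Eigen contraction as Cᵀ = Bᵀ·Aᵀ. A minimal sketch of that canonicalization (the function name is illustrative, not from the PR):

```cpp
#include <cstdint>
#include <utility>

// Illustrative only: mirrors the layout canonicalization in
// DotThunk::Execute. The Eigen contraction kernel is column-major; a
// row-major buffer reinterpreted as column-major holds the transposed
// matrix, so computing C^T = B^T * A^T in column-major order produces the
// row-major C = A * B. Swapping the operands, the m/n extents, and the
// transpose flags expresses exactly that.
void CanonicalizeForColumnMajor(bool lhs_column_major, int64_t& m, int64_t& n,
                                void*& lhs, void*& rhs, bool& transpose_lhs,
                                bool& transpose_rhs) {
  if (!lhs_column_major) {
    std::swap(m, n);
    std::swap(lhs, rhs);
    std::swap(transpose_lhs, transpose_rhs);
  }
}
```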
3 changes: 2 additions & 1 deletion xla/backends/cpu/runtime/dot_thunk.h
@@ -52,7 +52,7 @@ class DotThunk final : public Thunk {

private:
DotThunk(Info info, DotDimensionNumbers dot_dimensions, DotSlices dot_slices,
DotShape dot_shape);
DotShape dot_shape, DotCanonicalDims dot_canonical_dims);

using DoneCallback = absl::AnyInvocable<void()>;

@@ -72,6 +72,7 @@ class DotThunk final : public Thunk {
DotDimensionNumbers dot_dimensions_;
DotSlices dot_slices_;
DotShape dot_shape_;
DotCanonicalDims dot_canonical_dims_;

// Contracting dimensions of the LHS and RHS matmul shapes.
absl::InlinedVector<int64_t, 2> lhs_matmul_contracting_dims_;
4 changes: 2 additions & 2 deletions xla/backends/cpu/runtime/xnnpack/BUILD
@@ -137,6 +137,8 @@ cc_library(
"//xla/stream_executor:device_memory",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/framework/contraction:eigen_contraction_kernel",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",
"@XNNPACK",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
@@ -149,8 +151,6 @@ cc_library(
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/profiler/lib:traceme",
],