[Mosaic GPU] Clean up remnants of the early C++ dialect implementation. #25610

Open · wants to merge 1 commit into base: main
4 changes: 1 addition & 3 deletions jax/experimental/mosaic/gpu/layout_inference.py
@@ -164,16 +164,14 @@ def inference_step(op: ir.Operation):
 def to_default_layout(ty: ir.Type) -> ir.Attribute | None:
   if ir.VectorType.isinstance(ty):
     layout = WGStridedFragLayout.from_shaped_type(ty)
-  elif ir.RankedTensorType.isinstance(ty):
-    layout = WGStridedFragLayout.from_shaped_type(ty)
   else:
     return None
   return to_strided_fragmented_layout_attr(layout)
 
 def set_default_layout(op: ir.OpView):
   if should_have_layout(op) and not has_any_layout_set(op):
     # TODO(bchetioui): consistently set layouts only for supported argument
-    # types (i.e. skip non-vector/tensor typed arguments.)
+    # types (i.e. skip non-vector typed arguments.)
     in_layouts = []
     for operand in op.operands:
       if (layout := to_default_layout(operand.type)) is not None:
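A quick way to read this change: `to_default_layout` now hands back a strided-fragmented layout attribute only for `vector` types, and returns `None` for everything else, including the `tensor` types it previously accepted. A minimal sketch of the new behavior, assuming JAX's vendored MLIR Python bindings and a context with the Mosaic GPU dialect registered; `to_default_layout` is a module-internal helper, so importing it directly is an assumption made purely for illustration:

```python
from jax._src.lib.mlir import ir  # assumption: JAX's vendored MLIR bindings
from jax.experimental.mosaic.gpu.layout_inference import to_default_layout

with ir.Context(), ir.Location.unknown():
    # Assumes the Mosaic GPU dialect (which defines the layout attributes)
    # has been registered on the context.
    f32 = ir.F32Type.get()
    # Vector types still get a default WGStridedFragLayout attribute.
    assert to_default_layout(ir.VectorType.get([128], f32)) is not None
    # Ranked tensor types no longer do; before this change they were
    # assigned the same strided layout as vectors.
    assert to_default_layout(ir.RankedTensorType.get([128], f32)) is None
```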
5 changes: 1 addition & 4 deletions jax/experimental/mosaic/gpu/layouts.py
@@ -72,10 +72,7 @@ def to_splat_fragmented_layout_attr(layout: WGSplatFragLayout) -> ir.Attribute:
 def should_have_layout(op: ir.OpView) -> bool:
   """Returns 'true' if the operation should be assigned a layout."""
 
-  def is_array(v: ir.Value):
-    ty = v.type
-    return ir.RankedTensorType.isinstance(ty) or ir.VectorType.isinstance(ty)
-
+  is_array = lambda v: ir.VectorType.isinstance(v.type)
   return any(map(is_array, itertools.chain(op.operands, op.results)))  # type: ignore


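This is the matching predicate tightening: an op is layout-relevant only if some operand or result is `vector`-typed. A hedged usage sketch; the `test.op` name is a placeholder, which is the only reason unregistered dialects are allowed here:

```python
from jax._src.lib.mlir import ir  # assumption: JAX's vendored MLIR bindings
from jax.experimental.mosaic.gpu.layouts import should_have_layout

with ir.Context() as ctx, ir.Location.unknown():
    ctx.allow_unregistered_dialects = True  # only needed for the fake "test.op"
    f32 = ir.F32Type.get()
    vec_op = ir.Operation.create("test.op", results=[ir.VectorType.get([4], f32)])
    tns_op = ir.Operation.create("test.op", results=[ir.RankedTensorType.get([4], f32)])
    assert should_have_layout(vec_op.opview)
    # Tensor-typed results no longer count as "arrays" for layout purposes.
    assert not should_have_layout(tns_op.opview)
```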
24 changes: 0 additions & 24 deletions jaxlib/mosaic/dialect/gpu/BUILD
@@ -131,30 +131,6 @@ cc_library(
     ],
 )
 
-cc_test(
-    name = "mosaic_gpu_test",
-    srcs = ["mosaic_gpu_test.cc"],
-    deps = [
-        ":mosaic_gpu",
-        "//testing/base/public:gunit_main",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:BufferizationInterfaces",
-        "@llvm-project//mlir:DataLayoutInterfaces",
-        "@llvm-project//mlir:FuncDialect",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:LLVMCommonConversion",
-        "@llvm-project//mlir:LLVMDialect",
-        "@llvm-project//mlir:MemRefDialect",
-        "@llvm-project//mlir:SCFUtils",
-        "@llvm-project//mlir:Support",
-        "@tsl//tsl/platform:errors",
-    ],
-)
-
 gentbl_filegroup(
     name = "mosaic_gpu_python_gen_raw",
     tbl_outs = [
179 changes: 0 additions & 179 deletions jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
@@ -16,17 +16,12 @@ limitations under the License.
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h"

#include <cstdint>
#include <vector>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
@@ -35,23 +30,16 @@ limitations under the License.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/Diagnostics.h"
#include "tsl/platform/statusor.h"

// Generated definitions.
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.cc.inc"
@@ -66,186 +54,19 @@ limitations under the License.
 namespace mosaic_gpu {
 namespace {
 
 using ::mlir::FloatType;
-using ::mlir::ImplicitLocOpBuilder;
 using ::mlir::IntegerType;
 using ::mlir::MLIRContext;
 using ::mlir::Type;
-using ::mlir::TypeRange;
-using ::mlir::Value;
-using ::mlir::ValueRange;
-
-using Index = ::mlir::TypedValue<::mlir::IndexType>;
-using Integer = ::mlir::TypedValue<::mlir::IntegerType>;
-
-Integer ToI64(ImplicitLocOpBuilder& b, Index index) {
-  return llvm::cast<Integer>(
-      b.create<mlir::arith::IndexCastOp>(b.getI64Type(), index).getResult());
-}
-
-template <typename T>
-Value Constant(ImplicitLocOpBuilder& b, T scalar, IntegerType type) {
-  return b.create<mlir::arith::ConstantOp>(
-      type, mlir::IntegerAttr::get(type, scalar));
-}
-
-template <typename T>
-Value Constant(ImplicitLocOpBuilder& b, T scalar, FloatType type) {
-  return b.create<mlir::arith::ConstantOp>(type,
-                                           mlir::FloatAttr::get(type, scalar));
-}
-
-// Given a range of values of the same type, produces a LLVM array that contains
-// all of them in order. Returns a pointer to the start of the newly created
-// array.
-absl::StatusOr<Pointer> ToLLVMArray(ImplicitLocOpBuilder& b,
-                                    ValueRange values) {
-  if (values.empty()) {
-    return absl::InvalidArgumentError("Can not pack an empty array of values.");
-  }
-
-  Type element_type = values.front().getType();
-
-  MLIRContext* ctx = b.getContext();
-  mlir::LLVM::LLVMPointerType pointer_type =
-      mlir::LLVM::LLVMPointerType::get(ctx);
-  Pointer array_pointer = b.create<mlir::LLVM::AllocaOp>(
-      pointer_type, element_type, Constant(b, values.size(), b.getI64Type()));
-
-  for (auto [i, value] : llvm::enumerate(values)) {
-    if (value.getType() != element_type) {
-      return absl::InvalidArgumentError(absl::StrCat(
-          "Expected all values to have the same type, but got ",
-          MlirToString(value.getType()), " and ", MlirToString(element_type)));
-    }
-
-    auto element_pointer = llvm::cast<Pointer>(
-        b.create<mlir::LLVM::GEPOp>(
-             pointer_type, element_type, array_pointer,
-             mlir::ArrayRef<mlir::LLVM::GEPArg>(mlir::LLVM::GEPArg(i)))
-            .getResult());
-    b.create<mlir::LLVM::StoreOp>(value, element_pointer);
-  }
-
-  return array_pointer;
-}
-
-// Extracts a pointer to the start of the parameter memref.
-Pointer FromMemref(ImplicitLocOpBuilder& b, Memref memref) {
-  Index aligned_pointer_as_index =
-      b.create<mlir::memref::ExtractAlignedPointerAsIndexOp>(memref);
-
-  mlir::LLVM::LLVMPointerType pointer_type =
-      mlir::LLVM::LLVMPointerType::get(b.getContext());
-
-  Value alloc_pointer = b.create<mlir::LLVM::IntToPtrOp>(
-      pointer_type, ToI64(b, aligned_pointer_as_index));
-
-  Type tensor_element_type = memref.getType().getElementType();
-
-  return mlir::cast<Pointer>(
-      b.create<mlir::LLVM::GEPOp>(
-           pointer_type, tensor_element_type, alloc_pointer,
-           mlir::ArrayRef<mlir::LLVM::GEPArg>(
-               mlir::LLVM::GEPArg(ToI64(b, aligned_pointer_as_index))))
-          .getResult());
-}
-
 }  // anonymous namespace
 
-// TODO(bchetioui): add swizzling.
-absl::Status InitTmaDescriptor(mlir::OpBuilder& builder,
-                               Pointer host_pointer_to_descriptor,
-                               Memref gmem_ref,
-                               mlir::ArrayRef<int64_t> slice_shape) {
-  ImplicitLocOpBuilder b(
-      mlir::NameLoc::get(builder.getStringAttr("InitTmaDescriptor")), builder);
-
-  mlir::memref::ExtractStridedMetadataOp extract_strided_metadata_op =
-      b.create<mlir::memref::ExtractStridedMetadataOp>(gmem_ref);
-
-  Type tensor_element_type = gmem_ref.getType().getElementType();
-
-  Pointer tensor_base_pointer = FromMemref(b, gmem_ref);
-
-  int64_t tensor_rank = gmem_ref.getType().getRank();
-  ValueRange sizes = extract_strided_metadata_op.getSizes();
-  ValueRange strides = extract_strided_metadata_op.getStrides();
-
-  if (tensor_rank != slice_shape.size()) {
-    return absl::FailedPreconditionError(absl::StrCat(
-        "Slice shape should have the same rank as the target tensor "
-        "but got ",
-        slice_shape.size(), " != ", tensor_rank));
-  }
-
-  std::vector<Value> sizes_as_i64;
-  std::vector<Value> strides_as_i64;
-  std::vector<Value> slice_as_i64;
-  sizes_as_i64.reserve(tensor_rank);
-  strides_as_i64.reserve(tensor_rank);
-  slice_as_i64.reserve(tensor_rank);
-  for (auto [size, stride, slice_dim] :
-       llvm::zip(sizes, strides, slice_shape)) {
-    sizes_as_i64.push_back(ToI64(b, llvm::cast<Index>(size)));
-    strides_as_i64.push_back(ToI64(b, llvm::cast<Index>(stride)));
-    slice_as_i64.push_back(Constant(b, slice_dim, b.getI64Type()));
-  }
-
-  TF_ASSIGN_OR_RETURN(Pointer sizes_array, ToLLVMArray(b, sizes_as_i64));
-  TF_ASSIGN_OR_RETURN(Pointer strides_array, ToLLVMArray(b, strides_as_i64));
-  TF_ASSIGN_OR_RETURN(Pointer slice_array, ToLLVMArray(b, slice_as_i64));
-
-  IntegerType i64 = b.getI64Type();
-
-  int64_t elem_bitwidth = tensor_element_type.getIntOrFloatBitWidth();
-
-  if (elem_bitwidth < 8) {
-    return absl::UnimplementedError("Sub-byte types are not yet supported.");
-  }
-
-  // TODO(bchetioui): connect this to runtime.
-  b.create<mlir::func::CallOp>(
-      kRuntimeTmaDescriptorInitializerName, TypeRange{},
-      ValueRange{/*tma_desc=*/host_pointer_to_descriptor,
-                 /*base_addr=*/tensor_base_pointer,
-                 /*elem_bytewidth=*/Constant(b, elem_bitwidth / 8, i64),
-                 /*rank=*/Constant(b, tensor_rank, i64),
-                 /*sizes=*/sizes_array,
-                 /*strides=*/strides_array,
-                 // TODO(bchetioui): implement swizzling.
-                 /*swizzle_bytes=*/Constant(b, 0, i64),
-                 /*window_shape=*/slice_array});
-
-  return absl::OkStatus();
-}
-
-void DeclareRuntimeFunctions(mlir::OpBuilder& builder) {
-  MLIRContext* ctx = builder.getContext();
-  mlir::LLVM::LLVMPointerType ptr = mlir::LLVM::LLVMPointerType::get(ctx);
-  IntegerType i64 = builder.getI64Type();
-
-  builder
-      .create<mlir::func::FuncOp>(
-          builder.getUnknownLoc(), kRuntimeTmaDescriptorInitializerName,
-          builder.getFunctionType(
-              TypeRange{ptr, ptr, i64, i64, ptr, ptr, i64, ptr}, TypeRange{}))
-      .setVisibility(mlir::func::FuncOp::Visibility::Private);
-
-  builder
-      .create<mlir::func::FuncOp>(
-          builder.getUnknownLoc(), kRuntimeMemcpyAsyncH2DName,
-          builder.getFunctionType(TypeRange{ptr, ptr, i64, ptr}, TypeRange{}))
-      .setVisibility(mlir::func::FuncOp::Visibility::Private);
-}
-
 bool IsContiguous(mlir::MemRefType type) {
   return type.getLayout().isIdentity() ||
          (type.hasStaticShape() && type.getNumElements() > 0 &&
           mlir::memref::isStaticShapeAndContiguousRowMajor(type));
 }
 
 namespace {
 llvm::LogicalResult VerifyCommonLoadStoreOp(
     mlir::Location loc, mlir::MemRefType gmem_type, absl::string_view gmem_name,
     mlir::MemRefType smem_type, absl::string_view smem_name,
41 changes: 1 addition & 40 deletions jaxlib/mosaic/dialect/gpu/mosaic_gpu.h
@@ -16,20 +16,10 @@ limitations under the License.
 #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_H_
 #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_H_
 
-#include <cstdint>
-#include <string>
-
 #include "llvm/ADT/StringRef.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"  // IWYU pragma: keep
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LLVM.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
 
 // Generated definitions.
 #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.h.inc"  // IWYU pragma: keep
@@ -43,39 +33,10 @@ limitations under the License.

 namespace mosaic_gpu {
 
-using Memref = ::mlir::TypedValue<::mlir::MemRefType>;
-using Pointer = ::mlir::TypedValue<::mlir::LLVM::LLVMPointerType>;
-
 struct GlobalMemory : public mlir::SideEffects::Resource::Base<GlobalMemory> {
   llvm::StringRef getName() final { return "<GlobalMemory>"; }
 };
 
-constexpr absl::string_view kRuntimeTmaDescriptorInitializerName =
-    "mosaic_gpu_init_tma_desc";
-constexpr absl::string_view kRuntimeMemcpyAsyncH2DName =
-    "mosaic_gpu_memcpy_async_h2d";
-
-template <typename T>
-std::string MlirToString(T&& value) {
-  std::string result;
-  llvm::raw_string_ostream os(result);
-  value.print(os);
-  return result;
-}
-
-// Declares the runtime functions that can be called from the generated code.
-void DeclareRuntimeFunctions(mlir::OpBuilder& builder);
-
-// Given a target host pointer, a memref corresponding to the tensor we intend
-// to describe, and the shape of the slice we intend to load using the resulting
-// TMA descriptor, `InitTmaDescriptor` generates the TMA descriptor
-// initialization logic on the host. The resulting TMA descriptor will be
-// stored at `host_pointer_to_descriptor`.
-absl::Status InitTmaDescriptor(mlir::OpBuilder& builder,
-                               Pointer host_pointer_to_descriptor,
-                               Memref gmem_ref,
-                               mlir::ArrayRef<int64_t> slice_shape);
-
 }  // namespace mosaic_gpu
 
 #endif  // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_H_
6 changes: 3 additions & 3 deletions jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
@@ -327,7 +327,7 @@ def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", []> {
     memrefs. `a` and `b` must have the same element type and when `a` is in
     registers only F16 or BF16 are supported.
 
-    The `accumulator` must be a tensor with a FragmentedLayout. The WGMMA
+    The `accumulator` must be a vector with a FragmentedLayout. The WGMMA
     operation will be executed in the async proxy and any inputs in
     registers need to be synchronized with a memory fence.
 
@@ -338,10 +338,10 @@
   }];
 
   let arguments = (ins
-    TensorOf<[MosaicGPU_WGMMASupportedType]>:$accumulator,
+    VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>:$accumulator,
     AnyTypeOf<[
       MemRefOf<[MosaicGPU_WGMMASupportedType]>,
-      TensorOf<[MosaicGPU_WGMMASupportedType]>]>:$a,
+      VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>]>:$a,
     MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b,
 
     // Attributes
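Type-level consequence of the `.td` change: a well-formed `wgmma` now carries its accumulator, and `a` when it lives in registers, as built-in `vector` values rather than `tensor` values; `b` and memref-resident `a` are unchanged. A small sketch of operand types under the new constraints, again assuming the MLIR Python bindings; the shapes are made up for illustration:

```python
from jax._src.lib.mlir import ir  # assumption: JAX's vendored MLIR bindings

with ir.Context(), ir.Location.unknown():
    f32, bf16 = ir.F32Type.get(), ir.BF16Type.get()
    acc_ty = ir.VectorType.get([64, 128], f32)    # accumulator: now a vector
    a_reg_ty = ir.VectorType.get([64, 16], bf16)  # `a` in registers: F16/BF16 only
    # Before this change the same operands were ranked tensors, e.g.
    # ir.RankedTensorType.get([64, 128], f32), which the op now rejects.
```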