diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py
index f7173047b77c..8be12435809c 100644
--- a/jax/experimental/mosaic/gpu/layout_inference.py
+++ b/jax/experimental/mosaic/gpu/layout_inference.py
@@ -164,8 +164,6 @@ def inference_step(op: ir.Operation):
 def to_default_layout(ty: ir.Type) -> ir.Attribute | None:
   if ir.VectorType.isinstance(ty):
     layout = WGStridedFragLayout.from_shaped_type(ty)
-  elif ir.RankedTensorType.isinstance(ty):
-    layout = WGStridedFragLayout.from_shaped_type(ty)
   else:
     return None
   return to_strided_fragmented_layout_attr(layout)
@@ -173,7 +171,7 @@ def to_default_layout(ty: ir.Type) -> ir.Attribute | None:
 def set_default_layout(op: ir.OpView):
   if should_have_layout(op) and not has_any_layout_set(op):
     # TODO(bchetioui): consistently set layouts only for supported argument
-    # types (i.e. skip non-vector/tensor typed arguments.)
+    # types (i.e. skip non-vector typed arguments.)
     in_layouts = []
     for operand in op.operands:
       if (layout := to_default_layout(operand.type)) is not None:
diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py
index ccae1ede1e88..2dc4d31298a4 100644
--- a/jax/experimental/mosaic/gpu/layouts.py
+++ b/jax/experimental/mosaic/gpu/layouts.py
@@ -72,10 +72,7 @@ def to_splat_fragmented_layout_attr(layout: WGSplatFragLayout) -> ir.Attribute:
 
 
 def should_have_layout(op: ir.OpView) -> bool:
   """Returns 'true' if the operation should be assigned a layout."""
-  def is_array(v: ir.Value):
-    ty = v.type
-    return ir.RankedTensorType.isinstance(ty) or ir.VectorType.isinstance(ty)
-
+  is_array = lambda v: ir.VectorType.isinstance(v.type)
   return any(map(is_array, itertools.chain(op.operands, op.results)))  # type: ignore
 
diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD
index 681ee708edd8..5db2d9c34274 100644
--- a/jaxlib/mosaic/dialect/gpu/BUILD
+++ b/jaxlib/mosaic/dialect/gpu/BUILD
@@ -131,30 +131,6 @@ cc_library(
     ],
 )
 
-cc_test(
-    name = "mosaic_gpu_test",
-    srcs = ["mosaic_gpu_test.cc"],
-    deps = [
-        ":mosaic_gpu",
-        "//testing/base/public:gunit_main",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:BufferizationInterfaces",
-        "@llvm-project//mlir:DataLayoutInterfaces",
-        "@llvm-project//mlir:FuncDialect",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:LLVMCommonConversion",
-        "@llvm-project//mlir:LLVMDialect",
-        "@llvm-project//mlir:MemRefDialect",
-        "@llvm-project//mlir:SCFUtils",
-        "@llvm-project//mlir:Support",
-        "@tsl//tsl/platform:errors",
-    ],
-)
-
 gentbl_filegroup(
     name = "mosaic_gpu_python_gen_raw",
     tbl_outs = [
diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
index b21f56327457..89eca3fe7ec8 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
@@ -16,17 +16,12 @@ limitations under the License.
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h" #include -#include -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/LogicalResult.h" #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/SCF/Utils/Utils.h" @@ -35,23 +30,16 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/IR/TypeRange.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "absl/algorithm/container.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/include/mlir/IR/BuiltinTypes.h" #include "mlir/include/mlir/IR/Diagnostics.h" -#include "tsl/platform/statusor.h" // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.cc.inc" @@ -66,178 +54,12 @@ limitations under the License. namespace mosaic_gpu { namespace { -using ::mlir::FloatType; -using ::mlir::ImplicitLocOpBuilder; using ::mlir::IntegerType; -using ::mlir::MLIRContext; using ::mlir::Type; -using ::mlir::TypeRange; -using ::mlir::Value; -using ::mlir::ValueRange; using Index = ::mlir::TypedValue<::mlir::IndexType>; using Integer = ::mlir::TypedValue<::mlir::IntegerType>; -Integer ToI64(ImplicitLocOpBuilder& b, Index index) { - return llvm::cast( - b.create(b.getI64Type(), index).getResult()); -} - -template -Value Constant(ImplicitLocOpBuilder& b, T scalar, IntegerType type) { - return b.create( - type, mlir::IntegerAttr::get(type, scalar)); -} - -template -Value Constant(ImplicitLocOpBuilder& b, T scalar, FloatType type) { - return b.create(type, - mlir::FloatAttr::get(type, scalar)); -} - -// Given a range of values of the same type, produces a LLVM array that contains -// all of them in order. Returns a pointer to the start of the newly created -// array. -absl::StatusOr ToLLVMArray(ImplicitLocOpBuilder& b, - ValueRange values) { - if (values.empty()) { - return absl::InvalidArgumentError("Can not pack an empty array of values."); - } - - Type element_type = values.front().getType(); - - MLIRContext* ctx = b.getContext(); - mlir::LLVM::LLVMPointerType pointer_type = - mlir::LLVM::LLVMPointerType::get(ctx); - Pointer array_pointer = b.create( - pointer_type, element_type, Constant(b, values.size(), b.getI64Type())); - - for (auto [i, value] : llvm::enumerate(values)) { - if (value.getType() != element_type) { - return absl::InvalidArgumentError(absl::StrCat( - "Expected all values to have the same type, but got ", - MlirToString(value.getType()), " and ", MlirToString(element_type))); - } - - auto element_pointer = llvm::cast( - b.create( - pointer_type, element_type, array_pointer, - mlir::ArrayRef(mlir::LLVM::GEPArg(i))) - .getResult()); - b.create(value, element_pointer); - } - - return array_pointer; -} - -// Extracts a pointer to the start of the parameter memref. 
-Pointer FromMemref(ImplicitLocOpBuilder& b, Memref memref) {
-  Index aligned_pointer_as_index =
-      b.create<mlir::memref::ExtractAlignedPointerAsIndexOp>(memref);
-
-  mlir::LLVM::LLVMPointerType pointer_type =
-      mlir::LLVM::LLVMPointerType::get(b.getContext());
-
-  Value alloc_pointer = b.create<mlir::LLVM::IntToPtrOp>(
-      pointer_type, ToI64(b, aligned_pointer_as_index));
-
-  Type tensor_element_type = memref.getType().getElementType();
-
-  return mlir::cast<Pointer>(
-      b.create<mlir::LLVM::GEPOp>(
-           pointer_type, tensor_element_type, alloc_pointer,
-           mlir::ArrayRef(
-               mlir::LLVM::GEPArg(ToI64(b, aligned_pointer_as_index))))
-          .getResult());
-}
-
-}  // anonymous namespace
-
-// TODO(bchetioui): add swizzling.
-absl::Status InitTmaDescriptor(mlir::OpBuilder& builder,
-                               Pointer host_pointer_to_descriptor,
-                               Memref gmem_ref,
-                               mlir::ArrayRef<int64_t> slice_shape) {
-  ImplicitLocOpBuilder b(
-      mlir::NameLoc::get(builder.getStringAttr("InitTmaDescriptor")), builder);
-
-  mlir::memref::ExtractStridedMetadataOp extract_strided_metadata_op =
-      b.create<mlir::memref::ExtractStridedMetadataOp>(gmem_ref);
-
-  Type tensor_element_type = gmem_ref.getType().getElementType();
-
-  Pointer tensor_base_pointer = FromMemref(b, gmem_ref);
-
-  int64_t tensor_rank = gmem_ref.getType().getRank();
-  ValueRange sizes = extract_strided_metadata_op.getSizes();
-  ValueRange strides = extract_strided_metadata_op.getStrides();
-
-  if (tensor_rank != slice_shape.size()) {
-    return absl::FailedPreconditionError(absl::StrCat(
-        "Slice shape should have the same rank as the target tensor "
-        "but got ",
-        slice_shape.size(), " != ", tensor_rank));
-  }
-
-  std::vector<Value> sizes_as_i64;
-  std::vector<Value> strides_as_i64;
-  std::vector<Value> slice_as_i64;
-  sizes_as_i64.reserve(tensor_rank);
-  strides_as_i64.reserve(tensor_rank);
-  slice_as_i64.reserve(tensor_rank);
-  for (auto [size, stride, slice_dim] :
-       llvm::zip(sizes, strides, slice_shape)) {
-    sizes_as_i64.push_back(ToI64(b, llvm::cast<Index>(size)));
-    strides_as_i64.push_back(ToI64(b, llvm::cast<Index>(stride)));
-    slice_as_i64.push_back(Constant(b, slice_dim, b.getI64Type()));
-  }
-
-  TF_ASSIGN_OR_RETURN(Pointer sizes_array, ToLLVMArray(b, sizes_as_i64));
-  TF_ASSIGN_OR_RETURN(Pointer strides_array, ToLLVMArray(b, strides_as_i64));
-  TF_ASSIGN_OR_RETURN(Pointer slice_array, ToLLVMArray(b, slice_as_i64));
-
-  IntegerType i64 = b.getI64Type();
-
-  int64_t elem_bitwidth = tensor_element_type.getIntOrFloatBitWidth();
-
-  if (elem_bitwidth < 8) {
-    return absl::UnimplementedError("Sub-byte types are not yet supported.");
-  }
-
-  // TODO(bchetioui): connect this to runtime.
-  b.create<mlir::func::CallOp>(
-      kRuntimeTmaDescriptorInitializerName, TypeRange{},
-      ValueRange{/*tma_desc=*/host_pointer_to_descriptor,
-                 /*base_addr=*/tensor_base_pointer,
-                 /*elem_bytewidth=*/Constant(b, elem_bitwidth / 8, i64),
-                 /*rank=*/Constant(b, tensor_rank, i64),
-                 /*sizes=*/sizes_array,
-                 /*strides=*/strides_array,
-                 // TODO(bchetioui): implement swizzling.
-                 /*swizzle_bytes=*/Constant(b, 0, i64),
-                 /*window_shape=*/slice_array});
-
-  return absl::OkStatus();
-}
-
-void DeclareRuntimeFunctions(mlir::OpBuilder& builder) {
-  MLIRContext* ctx = builder.getContext();
-  mlir::LLVM::LLVMPointerType ptr = mlir::LLVM::LLVMPointerType::get(ctx);
-  IntegerType i64 = builder.getI64Type();
-
-  builder
-      .create<mlir::func::FuncOp>(
-          builder.getUnknownLoc(), kRuntimeTmaDescriptorInitializerName,
-          builder.getFunctionType(
-              TypeRange{ptr, ptr, i64, i64, ptr, ptr, i64, ptr}, TypeRange{}))
-      .setVisibility(mlir::func::FuncOp::Visibility::Private);
-
-  builder
-      .create<mlir::func::FuncOp>(
-          builder.getUnknownLoc(), kRuntimeMemcpyAsyncH2DName,
-          builder.getFunctionType(TypeRange{ptr, ptr, i64, ptr}, TypeRange{}))
-      .setVisibility(mlir::func::FuncOp::Visibility::Private);
-}
 
 bool IsContiguous(mlir::MemRefType type) {
   return type.getLayout().isIdentity() ||
@@ -245,7 +67,6 @@ bool IsContiguous(mlir::MemRefType type) {
          mlir::memref::isStaticShapeAndContiguousRowMajor(type));
 }
 
-namespace {
 llvm::LogicalResult VerifyCommonLoadStoreOp(
     mlir::Location loc, mlir::MemRefType gmem_type, absl::string_view gmem_name,
     mlir::MemRefType smem_type, absl::string_view smem_name,
diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h
index 14c0d0295a8f..cae2edba462c 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h
@@ -16,20 +16,10 @@ limitations under the License.
 #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_H_
 #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_H_
 
-#include <cstdint>
-#include <string>
-
 #include "llvm/ADT/StringRef.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"  // IWYU pragma: keep
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LLVM.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
 
 // Generated definitions.
 #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.h.inc"  // IWYU pragma: keep
@@ -43,39 +33,10 @@ limitations under the License.
 
 namespace mosaic_gpu {
 
-using Memref = ::mlir::TypedValue<::mlir::MemRefType>;
-using Pointer = ::mlir::TypedValue<::mlir::LLVM::LLVMPointerType>;
-
 struct GlobalMemory : public mlir::SideEffects::Resource::Base<GlobalMemory> {
   llvm::StringRef getName() final { return "<GlobalMemory>"; }
 };
 
-constexpr absl::string_view kRuntimeTmaDescriptorInitializerName =
-    "mosaic_gpu_init_tma_desc";
-constexpr absl::string_view kRuntimeMemcpyAsyncH2DName =
-    "mosaic_gpu_memcpy_async_h2d";
-
-template <typename T>
-std::string MlirToString(T&& value) {
-  std::string result;
-  llvm::raw_string_ostream os(result);
-  value.print(os);
-  return result;
-}
-
-// Declares the runtime functions that can be called from the generated code.
-void DeclareRuntimeFunctions(mlir::OpBuilder& builder);
-
-// Given a target host pointer, a memref corresponding to the tensor we intend
-// to describe, and the shape of the slice we intend to load using the resulting
-// TMA descriptor, `InitTmaDescriptor` generates the TMA descriptor
-// initialization logic on the host. The resulting TMA descriptor will be
-// stored at `host_pointer_to_descriptor`.
-absl::Status InitTmaDescriptor(mlir::OpBuilder& builder,
-                               Pointer host_pointer_to_descriptor,
-                               Memref gmem_ref,
-                               mlir::ArrayRef<int64_t> slice_shape);
-
 }  // namespace mosaic_gpu
 
 #endif  // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_H_
diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
index 278c5bee71c0..89fe4abc2d55 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
@@ -327,7 +327,7 @@ def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", []> {
     memrefs. `a` and `b` must have the same element type and when `a` is in
     registers only F16 or BF16 are supported.
 
-    The `accumulator` must be a tensor with a FragmentedLayout. The WGMMA
+    The `accumulator` must be a vector with a FragmentedLayout. The WGMMA
     operation will be executed in the async proxy and any inputs in
     registers need to be synchronized with a memory fence.
 
@@ -338,10 +338,10 @@ def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", []> {
   }];
 
   let arguments = (ins
-    TensorOf<[MosaicGPU_WGMMASupportedType]>:$accumulator,
+    VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>:$accumulator,
     AnyTypeOf<[
      MemRefOf<[MosaicGPU_WGMMASupportedType]>,
-      TensorOf<[MosaicGPU_WGMMASupportedType]>]>:$a,
+      VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>]>:$a,
     MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b,
 
     // Attributes
diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc
deleted file mode 100644
index 34f6241661d5..000000000000
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc
+++ /dev/null
@@ -1,198 +0,0 @@
-/* Copyright 2024 The JAX Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h"
-
-#include <cstdint>
-#include <string>
-#include <vector>
-
-#include "testing/base/public/gmock.h"
-#include "testing/base/public/gunit.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "llvm/include/llvm/ADT/ArrayRef.h"
-#include "llvm/include/llvm/ADT/SmallVector.h"
-#include "mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h"
-#include "mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h"
-#include "mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
-#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/include/mlir/Dialect/SCF/Utils/Utils.h"
-#include "mlir/include/mlir/IR/Builders.h"
-#include "mlir/include/mlir/IR/BuiltinOps.h"
-#include "mlir/include/mlir/IR/Diagnostics.h"
-#include "mlir/include/mlir/IR/MLIRContext.h"
-#include "mlir/include/mlir/IR/OwningOpRef.h"
-#include "mlir/include/mlir/IR/Types.h"
-#include "mlir/include/mlir/IR/Value.h"
-#include "mlir/include/mlir/IR/Verifier.h"
-#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h"
-#include "mlir/include/mlir/Support/LLVM.h"
-#include "tsl/platform/errors.h"
-
-namespace mosaic_gpu {
-namespace {
-
-using ::testing::HasSubstr;
-using ::testing::UnorderedElementsAre;
-using ::testing::status::StatusIs;
-
-template <typename T1, typename T2, typename... Ts>
-absl::StatusOr<mlir::func::FuncOp> FromCppFunc(
-    mlir::ModuleOp module,
-    absl::Status (*f)(mlir::OpBuilder&, mlir::TypedValue<T1>,
-                      mlir::TypedValue<T2>, Ts...),
-    T1 type1, T2 type2, Ts... varargs) {
-  mlir::MLIRContext* context = module->getContext();
-  mlir::OpBuilder b(context);
-  b.setInsertionPointToEnd(module.getBody());
-
-  auto fn = b.create<mlir::func::FuncOp>(
-      b.getUnknownLoc(), "function_wrapper",
-      b.getFunctionType({type1, type2}, std::nullopt));
-  fn.addEntryBlock();
-  b.setInsertionPointToStart(&fn.front());
-
-  TF_RETURN_IF_ERROR(f(b, mlir::cast<mlir::TypedValue<T1>>(fn.getArgument(0)),
-                       mlir::cast<mlir::TypedValue<T2>>(fn.getArgument(1)),
-                       varargs...));
-
-  b.create<mlir::func::ReturnOp>(b.getUnknownLoc());
-
-  if (mlir::failed(mlir::verify(module))) {
-    return absl::InternalError("Failed to verify generated module");
-  }
-
-  return fn;
-}
-
-class MosaicGpuTest : public ::testing::Test {
- public:
-  MosaicGpuTest()
-      : builder_(&context_),
-        module_(mlir::OwningOpRef<mlir::ModuleOp>(
-            mlir::ModuleOp::create(builder_.getUnknownLoc(), "module"))) {
-    RegisterErrorRecordingHandler();
-    context_.loadDialect<MosaicGPUDialect, mlir::func::FuncDialect,
-                         mlir::memref::MemRefDialect,
-                         mlir::LLVM::LLVMDialect>();
-    builder_.setInsertionPointToEnd(module_->getBody());
-    mosaic_gpu::DeclareRuntimeFunctions(builder_);
-  }
-
-  void ExpectLastErrorContains(absl::string_view substring) {
-    EXPECT_THAT(last_error_message_, HasSubstr(substring));
-  }
-
- protected:
-  mlir::MLIRContext context_;
-  mlir::OpBuilder builder_;
-  mlir::OwningOpRef<mlir::ModuleOp> module_;
-  std::string last_error_message_;
-
- private:
-  void RegisterErrorRecordingHandler() {
-    // Make sure to make the context single-threaded to avoid race conditions
-    // when recording the last error message.
-    context_.disableMultithreading();
-    mlir::DiagnosticEngine& diagnostic_engine = context_.getDiagEngine();
-    diagnostic_engine.registerHandler([&](mlir::Diagnostic& diagnostic) {
-      last_error_message_ = diagnostic.str();
-    });
-  }
-};
-
-TEST_F(MosaicGpuTest, InitTmaDescriptorRequiresSliceShapeHasTheCorrectRank) {
-  std::vector<int64_t> shape{1, 2, 3};
-  std::vector<int64_t> slice_shape{1, 2};
-
-  mlir::LLVM::LLVMPointerType pointer_type =
-      mlir::LLVM::LLVMPointerType::get(&context_);
-  mlir::MemRefType memref_type =
-      mlir::MemRefType::get(shape, builder_.getF32Type());
-
-  EXPECT_THAT(
-      FromCppFunc(*module_, mosaic_gpu::InitTmaDescriptor, pointer_type,
-                  memref_type, mlir::ArrayRef<int64_t>(slice_shape)),
-      StatusIs(
-          absl::StatusCode::kFailedPrecondition,
-          HasSubstr(
-              "Slice shape should have the same rank as the target tensor")));
-}
-
-TEST_F(MosaicGpuTest, InitTmaDescriptorGracefullyRejectsSubByteTypes) {
-  std::vector<int64_t> shape{1, 2, 3};
-  std::vector<int64_t> slice_shape{1, 2, 3};
-
-  mlir::LLVM::LLVMPointerType pointer_type =
-      mlir::LLVM::LLVMPointerType::get(&context_);
-  mlir::MemRefType memref_type =
-      mlir::MemRefType::get(shape, builder_.getI4Type());
-
-  EXPECT_THAT(FromCppFunc(*module_, mosaic_gpu::InitTmaDescriptor, pointer_type,
-                          memref_type, mlir::ArrayRef<int64_t>(slice_shape)),
-              StatusIs(absl::StatusCode::kUnimplemented,
-                       HasSubstr("Sub-byte types are not yet supported")));
-}
-
-TEST_F(MosaicGpuTest, InitTmaDescriptorProducesACallToRuntime) {
-  std::vector<int64_t> shape{1, 2, 3};
-  std::vector<int64_t> slice_shape{1, 2, 3};
-
-  mlir::LLVM::LLVMPointerType pointer_type =
-      mlir::LLVM::LLVMPointerType::get(&context_);
-  mlir::MemRefType memref_type =
-      mlir::MemRefType::get(shape, builder_.getF32Type());
-
-  absl::StatusOr<mlir::func::FuncOp> fn_or =
-      FromCppFunc(*module_, mosaic_gpu::InitTmaDescriptor, pointer_type,
-                  memref_type, mlir::ArrayRef<int64_t>(slice_shape));
-  ASSERT_OK(fn_or);
-
-  llvm::SmallVector<mlir::func::CallOp> call_ops =
-      llvm::to_vector(fn_or->getBlocks().front().getOps<mlir::func::CallOp>());
-  EXPECT_EQ(call_ops.size(), 1);
-  EXPECT_EQ(call_ops.front().getCallee().str(),
-            mosaic_gpu::kRuntimeTmaDescriptorInitializerName);
-}
-
-TEST_F(MosaicGpuTest, RuntimeFunctionsAreRegistered) {
-  // Deliberately introduce a new module to explicitly register the runtime
-  // functions.
-  mlir::OwningOpRef<mlir::ModuleOp> module_op =
-      mlir::ModuleOp::create(builder_.getUnknownLoc(), "new_module");
-  builder_.setInsertionPointToEnd(module_op->getBody());
-  mosaic_gpu::DeclareRuntimeFunctions(builder_);
-
-  llvm::SmallVector<mlir::func::FuncOp> func_ops =
-      llvm::to_vector(module_op->getBody()->getOps<mlir::func::FuncOp>());
-  EXPECT_EQ(func_ops.size(), 2);
-
-  absl::flat_hash_set<std::string> func_names;
-  for (mlir::func::FuncOp& func_op : func_ops) {
-    func_names.insert(func_op.getSymName().str());
-  }
-
-  EXPECT_THAT(func_names, UnorderedElementsAre(
-                              mosaic_gpu::kRuntimeTmaDescriptorInitializerName,
-                              mosaic_gpu::kRuntimeMemcpyAsyncH2DName));
-}
-
-
-}  // anonymous namespace
-}  // namespace mosaic_gpu
diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py
index 188239ca3629..708e35bc522c 100644
--- a/tests/mosaic/gpu_dialect_test.py
+++ b/tests/mosaic/gpu_dialect_test.py
@@ -487,7 +487,7 @@ def test_async_store_op_slice_lengths_size_must_match_source_rank(self):
   def test_wgmma_types_match(self):
     with ir.InsertionPoint(self.module.body):
       func.FuncOp.from_py_func(
-          ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
+          ir.VectorType.get([128, 160], ir.BF16Type.get()),
           ir.MemRefType.get([2, 4, 64, 32], ir.F16Type.get()),
           ir.MemRefType.get([4, 5, 32, 32], ir.BF16Type.get()),
           name="wgmma",
@@ -509,7 +509,7 @@ def test_wgmma_b_rank_is_4(self):
     with ir.InsertionPoint(self.module.body):
       func.FuncOp.from_py_func(
-          ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
+          ir.VectorType.get([128, 160], ir.BF16Type.get()),
           ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),
           ir.MemRefType.get([5, 32, 32], ir.BF16Type.get()),
           name="wgmma",
@@ -531,7 +531,7 @@ def test_wgmma_b_shape_dim_3(self):
     with ir.InsertionPoint(self.module.body):
       func.FuncOp.from_py_func(
-          ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
+          ir.VectorType.get([128, 160], ir.BF16Type.get()),
           ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),
           ir.MemRefType.get([4, 5, 32, 16], ir.BF16Type.get()),
           name="wgmma",
@@ -554,7 +554,7 @@ def test_wgmma_b_shape_dim_2(self):
     with ir.InsertionPoint(self.module.body):
       func.FuncOp.from_py_func(
-          ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
+          ir.VectorType.get([128, 160], ir.BF16Type.get()),
          ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),
          ir.MemRefType.get([4, 5, 64, 32], ir.BF16Type.get()),
          name="wgmma",
diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py
index e1d7f1a0d39a..ecc316b83a70 100644
--- a/tests/mosaic/gpu_layout_inference_test.py
+++ b/tests/mosaic/gpu_layout_inference_test.py
@@ -45,13 +45,12 @@ def setUp(self):
     self.enter_context(ir.Location.unknown())
     self.module = ir.Module.create()
 
-  @parameterized.parameters(ir.RankedTensorType, ir.VectorType)
-  def test_infer_layout_default(self, type_constructor):
+  def test_infer_layout_default(self):
     shape = (16, 8)
     elt_type = ir.BF16Type.get()
 
     with ir.InsertionPoint(self.module.body):
-      ab_type = type_constructor.get(shape, elt_type)
+      ab_type = ir.VectorType.get(shape, elt_type)
       const_zero = ir.FloatAttr.get(elt_type, 0)
       const_one = ir.FloatAttr.get(elt_type, 1)
       a = arith.ConstantOp(
@@ -80,13 +79,12 @@ def test_infer_layout_default(self, type_constructor):
           op.attributes["out_layouts"], [layout] * len(op.results)
       )
 
-  @parameterized.parameters(ir.RankedTensorType, ir.VectorType)
-  def test_infer_layout_for_pointwise_op(self, type_constructor):
+  def test_infer_layout_for_pointwise_op(self):
     shape = (4, 8)
     elt_type = ir.BF16Type.get()
 
     with ir.InsertionPoint(self.module.body):
-      ab_type = type_constructor.get(shape, elt_type)
+      ab_type = ir.VectorType.get(shape, elt_type)
       const_zero = ir.FloatAttr.get(elt_type, 0)
       const_one = ir.FloatAttr.get(elt_type, 1)
       a = arith.ConstantOp(
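
Reviewer note (not part of the patch): a minimal sketch of the invariant this change establishes. After it, only `ir.VectorType` values are eligible for Mosaic GPU layout assignment; `ir.RankedTensorType` values no longer are. The `is_layout_eligible` name below is illustrative (the real check is the `is_array` lambda in `jax/experimental/mosaic/gpu/layouts.py` shown above), and the import assumes JAX's bundled MLIR Python bindings.

from jaxlib.mlir import ir

with ir.Context():
  bf16 = ir.BF16Type.get()
  vec = ir.VectorType.get([16, 8], bf16)            # still gets a default layout
  tensor = ir.RankedTensorType.get([16, 8], bf16)   # no longer considered

  # Mirrors the new vector-only predicate in should_have_layout.
  is_layout_eligible = lambda ty: ir.VectorType.isinstance(ty)
  assert is_layout_eligible(vec)
  assert not is_layout_eligible(tensor)  # was eligible before this change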