Skip to content

Commit

Permalink
Fixes to passes for custom dispatch to work with bf16 type (iree-org#…
Browse files Browse the repository at this point in the history
…17242)

Custom dispatch for bf16 type stopped working because of a couple of
problems:

1. `--iree-codegen-expand-strided-metadata` doesn't work with `iree-opt`
when given an `iree_codegen.extract_strided_metadata` op due to a lack
of a dependency on the `memref` dialect.
2. `--iree-convert-bf16-to-uint16-buffers` doesn't convert types in an
`iree_codegen.extract_strided_metadata` op due to a lack of a dependency
on the `iree_codegen` dialect. As a result, type casts are introduced
between the `iree_codegen.extract_strided_metadata` and its
`hal.interface.binding.subspan` predecessor, which confuses
`--iree-codegen-expand-strided-metadata`, which fails to do its job of
eliminating the `iree_codegen.extract_strided_metadata`. This second
problem was described in iree-org#17177.

The fix for both problems is to add the missing dialect dependencies.

fixes: iree-org#17177
  • Loading branch information
daveliddell authored May 1, 2024
1 parent f54a861 commit 6233f4f
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Utils/ConversionUtils.h"
Expand Down Expand Up @@ -268,7 +269,8 @@ struct ConvertBf16ToUInt16BuffersPass final
target.addLegalOp<arith::TruncFOp, arith::ExtFOp, ModuleOp>();
target.addDynamicallyLegalDialect<arith::ArithDialect, func::FuncDialect,
IREE::HAL::HALDialect,
memref::MemRefDialect, scf::SCFDialect>(
memref::MemRefDialect, scf::SCFDialect,
IREE::Codegen::IREECodegenDialect>(
[&typeConverter](Operation *op) {
bool legal = typeConverter.isLegal(op);
LLVM_DEBUG(if (!legal) llvm::dbgs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

Expand Down Expand Up @@ -243,7 +244,7 @@ struct IREEExpandStridedMetadataPass
: public IREEExpandStridedMetadataBase<IREEExpandStridedMetadataPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<affine::AffineDialect, arith::ArithDialect,
IREE::Codegen::IREECodegenDialect>();
IREE::Codegen::IREECodegenDialect, memref::MemRefDialect>();
}

void runOnOperation() override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,25 @@ func.func @load_trunc_f32_bf16(%arg0 : memref<32xf32>, %arg1 : memref<32xbf16>)
vector.store %trunc, %arg1[%c0] : memref<32xbf16>, vector<4xbf16>
return
}

// -----

// Test that iree_codegen.extract_strided_metadata (or any other op from iree_codegen)
// is rewritten correctly, along with any following ops.
// See issue https://github.com/iree-org/iree/issues/17177
// CHECK-LABEL: module @extract_strided_metadata
module @extract_strided_metadata {
func.func private @external_func(memref<bf16>, index) attributes {llvm.bareptr = [true]}
// CHECK: func.func private @external_func(memref<i16>, index)
func.func @external_func_entry_point() attributes {translation_info = #iree_codegen.translation_info<CPUDefault>} {
%0 = hal.interface.constant.load[0] : i32
%1 = arith.index_castui %0 : i32 to index
%2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%1) flags(ReadOnly) : memref<1x8x768xbf16, strided<[6144, 768, 1], offset: ?>>
// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan {{.*}} : memref<1x8x768xi16,
%base_buffer, %offset, %sizes:3, %strides:3 = iree_codegen.extract_strided_metadata %2 : memref<1x8x768xbf16, strided<[6144, 768, 1], offset: ?>> -> memref<bf16>, index, index, index, index, index, index, index
// CHECK: {{.+}} = iree_codegen.extract_strided_metadata %[[SUBSPAN]] : memref<1x8x768xi16,
call @external_func(%base_buffer, %offset) : (memref<bf16>, index) -> ()
// CHECK: call @external_func({{.*}}) : (memref<i16>, index)
return
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,18 @@ func.func @resolve_global_memref() -> (memref<f32>, index, index, index, index,
// CHECK-DAG: %[[GET_GLOBAL:.+]] = memref.get_global @[[CONSTANT]]
// CHECK: %[[CAST:.+]] = memref.reinterpret_cast %[[GET_GLOBAL]] to offset: [0], sizes: [], strides: []
// CHECK: return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]

// -----

// Test for the part of the pass that converts iree_codegen to memref
func.func @external_func_entry_point() -> (memref<bf16>, index) {
%0 = hal.interface.constant.load[0] : i32
%1 = arith.index_castui %0 : i32 to index
%2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%1) flags(ReadOnly) : memref<1x8x768xbf16, strided<[6144, 768, 1], offset: ?>>
%base_buffer, %offset, %sizes:3, %strides:3 = iree_codegen.extract_strided_metadata %2 : memref<1x8x768xbf16, strided<[6144, 768, 1], offset: ?>> -> memref<bf16>, index, index, index, index, index, index, index
return %base_buffer, %offset : memref<bf16>, index
}

// CHECK-LABEL: func.func @external_func_entry_point()
// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan
// CHECK: %{{.+}} = memref.reinterpret_cast %[[SUBSPAN]]

0 comments on commit 6233f4f

Please sign in to comment.