Skip to content

Commit

Permalink
[i1] Implement packed_storage layout encoding attribute (#19354)
Browse files Browse the repository at this point in the history
* Make `packed_storage` a type of `iree_encoding` attribute, and make the
type converters accept it.
* `i1` tensors with `#iree_encoding.packed_storage` are interpreted as
packed `i1` tensors, the same as when specifying
`--iree-experimental-packed-i1-storage`. Other `i1` tensors are treated as
a non-packed data type and will be extended.
* `--iree-experimental-packed-i1-storage` is kept for testing purposes.
* We can drop this option once the frontend can emit `i1` tensors
with the attribute.

Signed-off-by: Alan Li <me@alanli.org>
  • Loading branch information
lialan authored Jan 10, 2025
1 parent 801e2c1 commit c793f90
Show file tree
Hide file tree
Showing 15 changed files with 117 additions and 35 deletions.
15 changes: 9 additions & 6 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
Expand Down Expand Up @@ -62,16 +63,18 @@ MaterializeEncodingConversionTarget::MaterializeEncodingConversionTarget(
// Mark any operation that has operands/results with encoding as
// illegal.
markUnknownOpDynamicallyLegal([](Operation *op) {
auto typeHasEncoding = [](Type t) -> bool {
auto typeHasDataTilingEncoding = [](Type t) -> bool {
auto tensorType = dyn_cast<RankedTensorType>(t);
return tensorType && tensorType.getEncoding();
if (!tensorType)
return false;
return getEncodingAttr(tensorType) != nullptr;
};
auto valueHasEncoding = [=](Value v) -> bool {
return typeHasEncoding(v.getType());
auto valueHasDataTilingEncoding = [=](Value v) -> bool {
return typeHasDataTilingEncoding(v.getType());
};
bool hasOperandOrResultsWithEncoding =
llvm::any_of(op->getOperands(), valueHasEncoding) ||
llvm::any_of(op->getResultTypes(), typeHasEncoding);
llvm::any_of(op->getOperands(), valueHasDataTilingEncoding) ||
llvm::any_of(op->getResultTypes(), typeHasDataTilingEncoding);
return !hasOperandOrResultsWithEncoding;
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ EncodingAttr getEncodingAttr(RankedTensorType type) {
return dyn_cast_or_null<EncodingAttr>(type.getEncoding());
}

/// Reports whether the encoding attached to `type` is a PackedStorageAttr.
/// A tensor type without an encoding, or with any other kind of encoding,
/// yields false.
bool hasPackedStorageAttr(RankedTensorType type) {
  auto packedStorage = dyn_cast_or_null<PackedStorageAttr>(type.getEncoding());
  return packedStorage != nullptr;
}

FailureOr<linalg::ContractionDimensions>
getEncodingContractionDims(EncodingAttr encoding) {
auto indexingMapsAttr = encoding.getUserIndexingMaps();
Expand Down
11 changes: 11 additions & 0 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def EncodingOpType : IREEEncoding_I32EnumAttr<"EncodingOpType",
def EncodingOpTypeAttr:
IREEEncoding_EnumAttr<EncodingOpType, "optype">;


// Parameterless encoding attribute, spelled `#iree_encoding.packed_storage`
// in the textual IR. It is attached as the encoding of a tensor type to mark
// that tensor as using packed storage.
def PackedStorageAttr : IREEEncoding_Attr<"PackedStorage"> {
// Keyword used after the dialect prefix when printing/parsing the attribute.
let mnemonic = "packed_storage";
let summary = [{Indicates packed storage data type.}];
let description = [{
This attribute indicates this is a back-to-back packed storage in memory.
This attribute takes no arguments.
}];
// The attribute has no parameters, so no custom verifier is declared.
let genVerifyDecl = 0;
}

def EncodingAttr :
IREEEncoding_Attr<"Encoding", [
DeclareAttrInterfaceMethods<IREEEncoding_EncodingLayoutAttrInterface, [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ namespace mlir::iree_compiler::IREE::Encoding {
/// Otherwise, returns null.
EncodingAttr getEncodingAttr(RankedTensorType type);

/// Returns true if the encoding of `type` is a `packed_storage` attribute.
/// Returns false when the type has no encoding or carries a different kind of
/// encoding attribute.
bool hasPackedStorageAttr(RankedTensorType type);

/// Returns the ContractionDimensions for the encoding user_indexing_maps.
FailureOr<linalg::ContractionDimensions>
getEncodingContractionDims(EncodingAttr encoding);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ iree_compiler_cc_library(
"Patterns.h",
],
deps = [
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/Stream/Conversion",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ iree_cc_library(
MLIRIR
MLIRTransformUtils
MLIRTransforms
iree::compiler::Dialect::Encoding::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::Stream::Conversion
iree::compiler::Dialect::Stream::IR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h"

#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
Expand Down Expand Up @@ -100,6 +101,13 @@ struct ConvertTensorImportOp
RankedTensorType tensorType,
ValueRange dynamicDims,
OpBuilder &builder) {
// If the encoding attr is about packed storage then we don't need
// assertion, because packed storage attribute is about memory layout and it
// doesn't affect the tensor shape.
if (IREE::Encoding::hasPackedStorageAttr(tensorType)) {
return success();
}

auto expectedElementType = builder.create<IREE::HAL::ElementTypeOp>(
loc, tensorType.getElementType());
auto expectedEncodingType = builder.create<IREE::HAL::EncodingTypeOp>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ static LogicalResult checkEncoding(Operation *op, RankedTensorType encodingType,
ValueRange encodingDims,
PatternRewriter &rewriter) {
auto encoding = encodingType.getEncoding();
if (encoding && !llvm::isa<IREE::Encoding::EncodingAttr>(encoding)) {
if (encoding && !llvm::isa<IREE::Encoding::EncodingAttr,
IREE::Encoding::PackedStorageAttr>(encoding)) {
return rewriter.notifyMatchFailure(op, [=](Diagnostic &d) {
d << "unsupported tensor encoding: " << encodingType;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ iree_lit_test_suite(
"encode_device_tensors_packing.mlir",
"encode_host_tensors.mlir",
"encode_host_tensors_packing.mlir",
"encode_host_tensors_packing_i1.mlir",
"encode_host_tensors_packing_i1_attr.mlir",
"encode_host_tensors_packing_i1_experimental_clopt.mlir",
"fold_globals.mlir",
"fold_uniform_operands.mlir",
"fuse_dispatch_bindings.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ iree_lit_test_suite(
"encode_device_tensors_packing.mlir"
"encode_host_tensors.mlir"
"encode_host_tensors_packing.mlir"
"encode_host_tensors_packing_i1.mlir"
"encode_host_tensors_packing_i1_attr.mlir"
"encode_host_tensors_packing_i1_experimental_clopt.mlir"
"fold_globals.mlir"
"fold_uniform_operands.mlir"
"fuse_dispatch_bindings.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: iree-opt --split-input-file --iree-stream-encode-host-tensors %s | FileCheck %s

#packed = #iree_encoding.packed_storage
// Size query for a packed i1 tensor whose element count (12) is not a
// multiple of 8; the expected folded result below is 2 bytes.
func.func @unaligned_i1_size() -> index {
%0 = stream.tensor.sizeof tensor<12xi1, #packed> : index
return %0 : index
}
// CHECK: func @unaligned_i1_size() -> index {
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK: return %[[C2]] : index

// -----

#packed = #iree_encoding.packed_storage
// Size query for a packed i1 tensor whose element count (24) is a multiple
// of 8; the expected folded result below is 3 bytes.
func.func @aligned_i1_size() -> index {
%0 = stream.tensor.sizeof tensor<24xi1, #packed> : index
return %0 : index
}

// CHECK: func @aligned_i1_size() -> index {
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: return %[[C3]] : index
61 changes: 45 additions & 16 deletions compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,25 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinTypes.h"

namespace mlir::iree_compiler {

// TODO(lialan): remove cl options once frontend can emit packed i1 tensors.
// Command-line switch that forces i1 tensors to be treated as packed storage
// even when their type does not carry `#iree_encoding.packed_storage`.
// Defaults to off. Each literal in the description ends with a space so the
// concatenated help text keeps its sentence boundaries.
llvm::cl::opt<bool> clEnableI1Support(
    "iree-experimental-packed-i1-storage",
    llvm::cl::desc(
        "Experimental feature: force to use packed storage for i1 tensors. "
        "Turning on this option will see i1 tensors as if it has "
        "#iree_encoding.packed_storage attribute. "
        "This is to allow an alternative way to test the packed storage "
        "feature before frontend can emit packed i1 tensors. "
        "This option can be dropped once the frontend can emit packed i1 "
        "tensors."),
    llvm::cl::init(false));

bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
namespace mlir::iree_compiler {

static bool needToPackSubByteElementBitWidthImpl(unsigned bitWidth,
bool isPackedStorage) {
// Enable i1 support if requested.
if (clEnableI1Support && bitWidth == 1) {
if (isPackedStorage && bitWidth == 1) {
return true;
}
// Require the original bit width to be some power of two for now to avoid
Expand All @@ -35,20 +43,31 @@ bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
return bitWidth < 8 && llvm::isPowerOf2_32(bitWidth) && bitWidth != 1;
}

/// Bit-width-only query: there is no tensor type to inspect here, so packed
/// i1 storage is considered only when the experimental command-line option is
/// set.
bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
  const bool forcePackedI1 = clEnableI1Support;
  return needToPackSubByteElementBitWidthImpl(bitWidth, forcePackedI1);
}

/// Returns true if the elements of |shapedType| are sub-byte values that
/// should be tightly packed together. Packed i1 storage is honored when
/// either the tensor type carries the packed_storage encoding attribute or
/// the experimental command-line option is set.
bool needToPackSubByteElements(RankedTensorType shapedType) {
  unsigned bitWidth = IREE::Util::getTypeBitWidth(shapedType.getElementType());
  // Two paths to enable packed storage for i1 tensors: the attribute or cl
  // option. The cl option will be dropped once frontend supports emitting
  // tensors with attributes.
  bool isPackedStorage =
      IREE::Encoding::hasPackedStorageAttr(shapedType) || clEnableI1Support;
  return needToPackSubByteElementBitWidthImpl(bitWidth, isPackedStorage);
}

Type legalizeStorageElementType(Type elementType) {
static Type legalizeStorageElementTypeImpl(Type elementType,
bool isPackedStorage) {
// Only handle integers; floats in MLIR all have aligned widths (today).
auto intType = dyn_cast<IntegerType>(elementType);
if (!intType)
return elementType;

// For sub-byte elements, default to pack them into bytes.
unsigned bitWidth = intType.getWidth();
if (needToPackSubByteElementBitWidth(bitWidth))
if (needToPackSubByteElementBitWidthImpl(bitWidth, isPackedStorage))
return elementType;

// Otherwise, extend them to the next power-of-two bit width.
Expand All @@ -60,6 +79,12 @@ Type legalizeStorageElementType(Type elementType) {
intType.getSignedness());
}

/// Element-type-only entry point: without a tensor type there is no encoding
/// to consult, so packed i1 storage is considered only when the experimental
/// command-line option is set.
Type legalizeStorageElementType(Type elementType) {
  const bool forcePackedI1 = clEnableI1Support;
  return legalizeStorageElementTypeImpl(elementType, forcePackedI1);
}

Value calculateStorageElementCountInBytes(Location loc,
RankedTensorType shapedType,
ValueRange dynamicDims,
Expand All @@ -72,13 +97,15 @@ Value calculateStorageElementCountInBytes(Location loc,
loc, builder, shapedType, dynamicDims);
}

Type alignedElementType =
legalizeStorageElementType(shapedType.getElementType());
bool isPackedStorage =
IREE::Encoding::hasPackedStorageAttr(shapedType) || clEnableI1Support;
Type alignedElementType = legalizeStorageElementTypeImpl(
shapedType.getElementType(), isPackedStorage);
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);

// Calculate all static dims first, if any.
int64_t staticCount = 1;
if (!needToPackSubByteElementBitWidth(elementBits)) {
if (!needToPackSubByteElementBitWidthImpl(elementBits, isPackedStorage)) {
staticCount *= IREE::Util::getRoundedElementByteWidth(alignedElementType);
}

Expand All @@ -93,13 +120,13 @@ Value calculateStorageElementCountInBytes(Location loc,
value = builder.createOrFold<arith::MulIOp>(loc, value, dim);
}
// Sub-byte packing requires putting multiple elements in the same byte.
if (needToPackSubByteElementBitWidth(elementBits)) {
if (needToPackSubByteElementBitWidthImpl(elementBits, isPackedStorage)) {
assert(8 % elementBits == 0);
unsigned byteElements = 8 / elementBits;
// TODO(antiagainst): We may want to emit runtime check to make sure this is
// divisible.
auto divisor = builder.create<arith::ConstantIndexOp>(loc, byteElements);
if (!clEnableI1Support && dynamicDims.empty() &&
if (!isPackedStorage && dynamicDims.empty() &&
(staticCount * elementBits) % 8 != 0) {
return nullptr;
}
Expand All @@ -113,12 +140,14 @@ Value calculateStorageElementOffsetInBytes(Location loc,
RankedTensorType originalType,
Value linearizedIndex,
OpBuilder &builder) {
Type alignedElementType =
legalizeStorageElementType(originalType.getElementType());
bool isPackedStorage =
IREE::Encoding::hasPackedStorageAttr(originalType) || clEnableI1Support;
Type alignedElementType = legalizeStorageElementTypeImpl(
originalType.getElementType(), isPackedStorage);
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);

// Sub-byte packing requires putting multiple elements in the same byte.
if (needToPackSubByteElementBitWidth(elementBits)) {
if (needToPackSubByteElementBitWidthImpl(elementBits, isPackedStorage)) {
Value byteElements =
builder.create<arith::ConstantIndexOp>(loc, 8 / elementBits);
// TODO(antiagainst): We may want to emit runtime check to make sure this is
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Utils/ElementPackingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace mlir::iree_compiler {
/// Returns true if the given |bitWidth|, if appearing at runtime-kernel
/// interface, is less than a byte that should be tightly packed together.
bool needToPackSubByteElementBitWidth(unsigned bitWidth);

/// Returns true if the given |shapedType|, if appearing at runtime-kernel
/// interface, has sub-byte element types that should be tightly packed
/// together.
Expand Down
16 changes: 6 additions & 10 deletions tests/e2e/subbyte_types/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,14 @@ package(
licenses = ["notice"], # Apache 2.0
)

LLVM_SRCS = enforce_glob(
# keep sorted
[
"subbyte_types.mlir",
],
include = ["*.mlir"],
exclude = [],
)

iree_check_single_backend_test_suite(
name = "check_llvm-cpu_subbyte_emulation",
srcs = LLVM_SRCS,
srcs = enforce_glob(
[
"subbyte_types.mlir",
],
include = ["*.mlir"],
),
compiler_flags = [
"--iree-llvmcpu-target-cpu=generic",
"--iree-experimental-packed-i1-storage",
Expand Down

0 comments on commit c793f90

Please sign in to comment.