[compiler] add inferReturnType functions for byteir.reshape and byteir.one_hot #451

Status: Closed · wants to merge 4 commits
@@ -30,10 +30,12 @@ void registerDotInferReturnTypeComponents();
void registerDotGeneralInferReturnTypeComponents();
void registerDynamicBroadcastInDimInferReturnTypeComponents();
void registerDynamicReshapeInferReturnTypeComponents();
void registerReshapeInferReturnTypeComponents();
void registerRealDynamicSliceInferReturnTypeComponents();
void registerReduceInferReturnTypeComponents();
void registerSoftmaxInferReturnTypeComponents();
void registerAddNInferReturnTypeComponents();
void registerOneHotInferReturnTypeComponents();
void registerTorchIndexSelectInferReturnTypeComponents();
void registerGeLUInferReturnTypeComponents();
void registerLayerNormInferReturnTypeComponents();
@@ -45,10 +47,12 @@ inline void registerAllMhloInferReturnTypeComponents() {
registerDotGeneralInferReturnTypeComponents();
registerDynamicBroadcastInDimInferReturnTypeComponents();
registerDynamicReshapeInferReturnTypeComponents();
registerReshapeInferReturnTypeComponents();
registerRealDynamicSliceInferReturnTypeComponents();
registerReduceInferReturnTypeComponents();
registerSoftmaxInferReturnTypeComponents();
registerAddNInferReturnTypeComponents();
registerOneHotInferReturnTypeComponents();
registerTorchIndexSelectInferReturnTypeComponents();
registerGeLUInferReturnTypeComponents();
registerLayerNormInferReturnTypeComponents();
5 changes: 4 additions & 1 deletion compiler/include/byteir/Dialect/mhlo/Util/CustomCallUtil.h
@@ -69,6 +69,10 @@ constexpr llvm::StringRef getOneHotName() {
return CUSTOM_CALL_NAME_PREFIX "one_hot";
}

constexpr llvm::StringRef getReshapeName() {
return CUSTOM_CALL_NAME_PREFIX "reshape";
Collaborator: Is it necessary to add byteir.reshape?

Collaborator (author): There are four reasons for switching to byteir.reshape instead of keeping the original tf.Reshape form:

  1. Where possible, convert regular ops (such as a reshape op) to custom_call@byteir.***, rather than custom_call@tf.***.
  2. The form custom_call@byteir.*** signals that the op cannot be converted to a regular mhlo op but is a case we have identified; custom_call@tf.*** does not convey this.
  3. When a reshape op from another front end cannot be converted into a regular mhlo.reshape or mhlo.dynamic_reshape op, it can also be lowered to byteir.reshape, giving a unified representation (see the sketch below).
  4. The getReshapeName function will be used in shape / bounded-shape inference.

}
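
To illustrate the unified form described in point 3 above, a byteir.reshape custom call might look like the following (a sketch only; the operand names and types are hypothetical, not taken from this PR):

// A front-end reshape that fits neither mhlo.reshape nor mhlo.dynamic_reshape
// is carried as a byteir.* custom call instead of a tf.* one:
%out = "mhlo.custom_call"(%input, %shape) {
  call_target_name = "byteir.reshape",
  byteir_attrs = {}
} : (tensor<?x24xf16>, tensor<3xi64>) -> tensor<?x24x1xf16>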

constexpr llvm::StringRef getAddNName() {
return CUSTOM_CALL_NAME_PREFIX "addn";
}
@@ -130,7 +134,6 @@ constexpr llvm::StringRef getStridedSliceName() {
constexpr llvm::StringRef getBatchMatMulName() {
return TF_NAME_PREFIX "BatchMatMul";
}

} // namespace mlir

#undef TF_NAME_PREFIX
1 change: 1 addition & 0 deletions compiler/lib/Dialect/mhlo/CMakeLists.txt
@@ -34,6 +34,7 @@ add_byteir_dialect_library(ByteIRMhloDynamicShapeOpRegister
DynamicShapeOpRegister/ReshapeLike.cpp
DynamicShapeOpRegister/Softmax.cpp
DynamicShapeOpRegister/AddN.cpp
DynamicShapeOpRegister/OneHot.cpp
DynamicShapeOpRegister/TorchIndexSelect.cpp
DynamicShapeOpRegister/ScatterNd.cpp
DynamicShapeOpRegister/StridedSlice.cpp
76 changes: 76 additions & 0 deletions compiler/lib/Dialect/mhlo/DynamicShapeOpRegister/OneHot.cpp
@@ -0,0 +1,76 @@
//===- OneHot.cpp ---------------------------------------------*--- C++ -*-===//
//
// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
#include "byteir/Dialect/mhlo/Util/CustomCallUtil.h"
#include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "dynamic-shape-op-register"

using namespace mlir;

void mlir::registerOneHotInferReturnTypeComponents() {
static InferReturnTypeComponentsRegistration shapeRegister(
getOneHotName(),
[](MLIRContext *context, std::optional<Location> loc,
ValueShapeRange operands, DictionaryAttr attr,
OpaqueProperties properties, RegionRange,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnTypes) {
ShapedType dataType = dyn_cast<ShapedType>(operands[0].getType());
if (!dataType) {
LLVM_DEBUG(llvm::dbgs() << loc << ": get dataType failed\n");
return failure();
}
int64_t axis = attr.getAs<DictionaryAttr>(getCustomCallAttrName())
.getAs<IntegerAttr>("axis")
.getInt();
int64_t depth = attr.getAs<DictionaryAttr>(getCustomCallAttrName())
.getAs<IntegerAttr>("depth")
.getInt();
Attribute onValue = attr.getAs<DictionaryAttr>(getCustomCallAttrName())
.getAs<Attribute>("on_value");
Type onValueType;
if (dyn_cast<IntegerAttr>(onValue)) {
onValueType = dyn_cast<IntegerAttr>(onValue).getType();
} else if (dyn_cast<FloatAttr>(onValue)) {
onValueType = dyn_cast<FloatAttr>(onValue).getType();
} else {
LLVM_DEBUG(llvm::dbgs()
<< loc << ": get output element type failed\n");
return failure();
}

auto dataShape = dataType.getShape();
llvm::SmallVector<int64_t> outShape;
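// Insert `depth` at position `axis`; axis == -1 (or axis == rank) appends
// the new dimension at the end, matching tf.OneHot semantics.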
for (int64_t i = 0; i < dataShape.size(); ++i) {
if (axis == i) {
outShape.push_back(depth);
}
outShape.push_back(dataShape[i]);
}
if (-1 == axis || axis >= dataShape.size()) {
outShape.push_back(depth);
}
inferredReturnTypes.emplace_back(outShape, onValueType);
return success();
});
}
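
As a concrete sketch of what this hook infers (shapes mirror the onehot_case0 test later in this PR; the element type is taken from on_value):

// axis = 1 with a rank-1 input appends the depth dimension, so the result
// type inferred for this call is tensor<150x16xf32>:
%0 = "mhlo.custom_call"(%indices) {
  call_target_name = "byteir.one_hot",
  byteir_attrs = {axis = 1 : i64, depth = 16 : i64,
                  off_value = 0.000000e+00 : f32, on_value = 1.000000e+00 : f32}
} : (tensor<150xi32>) -> tensor<150x16xf32>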
50 changes: 50 additions & 0 deletions compiler/lib/Dialect/mhlo/DynamicShapeOpRegister/ReshapeLike.cpp
@@ -17,6 +17,7 @@

#include "byteir/Dialect/Shape/IR/ShapeExtOps.h"
#include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
#include "byteir/Dialect/mhlo/Util/CustomCallUtil.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -129,3 +130,52 @@ void mlir::registerDynamicReshapeInferReturnTypeComponents() {
return success();
});
}

void mlir::registerReshapeInferReturnTypeComponents() {
static InferReturnTypeComponentsRegistration shapeRegister(
getReshapeName(),
[](MLIRContext *context, std::optional<Location> loc,
ValueShapeRange operands, DictionaryAttr, OpaqueProperties properties,
RegionRange,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnTypes) {
auto input = operands[0];
ShapedType inputType = dyn_cast<ShapedType>(input.getType());
if (!inputType) {
LLVM_DEBUG(llvm::dbgs() << loc << ": get inputType failed\n");
return failure();
}
mlir::ShapeAdaptor shapeAdaptor = operands.getValueAsShape(1);
if (!shapeAdaptor) {
return failure();
}
if (!inputType.hasStaticShape() && !shapeAdaptor.hasStaticShape()) {
LLVM_DEBUG(llvm::dbgs() << loc << ": shape is dynamic\n");
return failure();
}
llvm::SmallVector<int64_t> shape;
shapeAdaptor.getDims(shape);
int negativeNum = std::count_if(shape.begin(), shape.end(),
[](int64_t i) { return i < 0; });
if (negativeNum > 1) {
LLVM_DEBUG(llvm::dbgs() << loc << ": shape is dynamic\n");
return failure();
}
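// A single -1 entry is resolved from the element count: e.g. an input
// tensor<48x24xf16> with target shape [-1, 24, 1] gives 1152 / (24 * 1) = 48.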
if (negativeNum == 1) {
int64_t product = inputType.getNumElements();
int64_t dynamicDim = product;
for (auto dim : shape) {
if (dim > 0) {
dynamicDim /= dim;
}
}
for (int64_t i = 0; i < shape.size(); ++i) {
if (shape[i] < 0) {
shape[i] = dynamicDim;
}
}
}
inferredReturnTypes.emplace_back(shape, inputType.getElementType());

return success();
});
}
2 changes: 1 addition & 1 deletion frontends/tf-frontend/tf_mlir_ext/tests/fuse_tf_ops.mlir
@@ -119,7 +119,7 @@ func.func @replace_where_3D(%arg0: tensor<256x1xi64>, %arg1: tensor<256x24x8xf16>
}
// CHECK-LABEL: func.func @replace_where_3D(%arg0: tensor<256x1xi64>, %arg1: tensor<256x24x8xf16>) -> tensor<?x8xf16> {
// CHECK-DAG: %[[CST:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<[-9223372036854775808, 24, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<[-1, 24, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK-DAG: %[[CST_2:.*]] = "tf.Const"() <{value = dense<28800> : tensor<i64>}> : () -> tensor<i64>
// CHECK-DAG: %[[CST_3:.*]] = "tf.Const"() <{value = dense<86400> : tensor<i64>}> : () -> tensor<i64>
@@ -1,5 +1,5 @@
// RUN: tf-ext-opt
// -rewrite-to-custom-call="ops=softmax,log_softmax,gelu,erf,arg_max,arg_min,top_k,layer_norm,l2_norm,addn,one_hot,DynamicMaskStitch,DynamicPartition,DynamicStitch"
// -rewrite-to-custom-call="ops=softmax,log_softmax,gelu,erf,arg_max,arg_min,top_k,layer_norm,l2_norm,addn,one_hot,reshape,DynamicMaskStitch,DynamicPartition,DynamicStitch"
// %s | FileCheck %s

func.func @softmax_case0(%arg0: tensor<100x100xf32>) -> tensor<100x100xf32> {
@@ -731,3 +731,12 @@ func.func @onehot_case0(%arg0: tensor<150xi32>) -> tensor<150x16xf32> {
// CHECK-LABEL: func.func @onehot_case0(%arg0: tensor<150xi32>) -> tensor<150x16xf32> {
// CHECK: mhlo.custom_call
// CHECK-SAME: @byteir.one_hot
// CHECK-SAME: byteir_attrs = {axis = 1 : i64, depth = 16 : i64, off_value = 0.000000e+00 : f32, on_value = 1.000000e+00 : f32}

func.func @reshape_case0(%arg0: tensor<?x24xf16>) -> tensor<?x24x1xf16> {
%cst = "tf.Const"() <{value = dense<[-1, 24, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
%0 = "tf.Reshape"(%arg0, %cst) : (tensor<?x24xf16>, tensor<3xi64>) -> tensor<?x24x1xf16>
Collaborator: => compute_reshape_shape + mhlo.dynamic_reshape

return %0 : tensor<?x24x1xf16>
}
// CHECK-LABEL: func.func @reshape_case0
// CHECK: mhlo.custom_call
// CHECK-SAME: @byteir.reshape
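
For reference, the lowering the reviewer suggests instead of the custom call might look roughly like this (a sketch assuming mhlo.compute_reshape_shape is available; op spellings and types are illustrative, not taken from this PR):

// Resolve the -1 extent at runtime, then reshape dynamically:
%shape_of = shape.shape_of %arg0 : tensor<?x24xf16> -> tensor<2xindex>
%numel = shape.num_elements %shape_of : tensor<2xindex> -> index
%resolved = "mhlo.compute_reshape_shape"(%numel, %cst)
    : (index, tensor<3xi64>) -> tensor<3xi64>
%0 = "mhlo.dynamic_reshape"(%arg0, %resolved)
    : (tensor<?x24xf16>, tensor<3xi64>) -> tensor<?x24x1xf16>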
6 changes: 5 additions & 1 deletion frontends/tf-frontend/tf_mlir_ext/transforms/fuse_tf_ops.cc
@@ -174,7 +174,11 @@ Value replaceWhereStatic(PatternRewriter &rewriter, Location loc, Value input,
}
auto shapeType =
RankedTensorType::get({inputShape.size()}, rewriter.getIntegerType(64));
auto shapeAttr = DenseIntElementsAttr::get(shapeType, oneHotShape);
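// Dynamic extents arrive as ShapedType::kDynamic (printed as
// -9223372036854775808 in the old test output); clamp them to -1 so
// tf.Reshape sees a valid wildcard.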
SmallVector<int64_t> shapeVec;
for (auto s : oneHotShape) {
shapeVec.push_back((s < 0) ? -1 : s);
}
auto shapeAttr = DenseIntElementsAttr::get(shapeType, shapeVec);
Value shape = rewriter.create<TF::ConstOp>(loc, shapeAttr);
oneHotOutputType = oneHotOutputType.clone(oneHotShape);
oneHotOutput = rewriter.create<TF::ReshapeOp>(loc, oneHotOutputType,
@@ -63,6 +63,7 @@ namespace {
cb(one_hot, OneHot, CALL_TARGET_NAME_PREFIX) \
cb(repeat, Repeat, CALL_TARGET_NAME_PREFIX) \
cb(non_zero, Where, CALL_TARGET_NAME_PREFIX) \
cb(reshape, Reshape, CALL_TARGET_NAME_PREFIX) \
cb(DynamicMaskStitch, DynamicMaskStitch, CALL_TF_TARGET_NAME_PREFIX) \
cb(DynamicPartition, DynamicPartition, CALL_TF_TARGET_NAME_PREFIX) \
cb(DynamicStitch, DynamicStitch, CALL_TF_TARGET_NAME_PREFIX)
@@ -87,7 +88,8 @@ namespace {
cb(TF::ArgMaxOp, ArgMax) \
cb(TF::ErfOp, Erf) \
cb(TF::AddNOp, AddN) \
cb(TF::OneHotOp, OneHot)
cb(TF::OneHotOp, OneHot) \
cb(TF::ReshapeOp, Reshape)
// clang-format on

VALID_CUSTOM_CALL_OP(GEN_FUNCNAME) template <typename TF_OP> struct WrapName;
@@ -593,6 +595,39 @@ struct RewriteOneHot : public OpRewritePattern<TF::OneHotOp> {
}
};

//===----------------------------------------------------------------------===//
// Reshape Pattern
//===----------------------------------------------------------------------===//
struct RewriteReshape : public OpRewritePattern<TF::ReshapeOp> {
using OpRewritePattern<TF::ReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TF::ReshapeOp op,
PatternRewriter &rewriter) const override {
auto tensor = op.getTensor();
auto shape = op.getShape();
auto output = op.getOutput();
auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!tensorType || !outputType) {
return failure();
}
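// Rewrite only fully dynamic reshapes; a reshape with a static input or
// output shape is left to the regular lowering.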
if (tensorType.hasStaticShape() || outputType.hasStaticShape()) {
return failure();
}
mhlo::CustomCallOp customCallOp = rewriter.create<mlir::mhlo::CustomCallOp>(
op->getLoc(), op->getResults().getTypes(), op->getOperands(),
getReshapeNameWithPrefix(), false, rewriter.getStringAttr(""),
mhlo::CustomCallApiVersion{
mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL},
rewriter.getArrayAttr(ArrayRef<Attribute>{}),
mhlo::CustomCallSchedule{mhlo::CustomCallSchedule::NONE}, nullptr,
nullptr, rewriter.getArrayAttr(ArrayRef<Attribute>{}));
customCallOp->setAttr(getByteIRAttrs(), getCleanAttr(op));
rewriter.replaceOp(op.getOperation(), customCallOp->getResults());

return success();
}
};

//===----------------------------------------------------------------------===//
// Repeat Pattern
//===----------------------------------------------------------------------===//
@@ -882,6 +917,8 @@ struct RewriteToCustomCallOpsPass
// patterns with c++
validCustomCallOpSet[getOneHotName()].emplace_back(
std::make_unique<RewriteOneHot>(context, 1));
validCustomCallOpSet[getReshapeName()].emplace_back(
std::make_unique<RewriteReshape>(context, 1));
validCustomCallOpSet[getAddNName()].emplace_back(
std::make_unique<RewriteSimpleReplace<TF::AddNOp, false>>(context,
1));