Fix CodeGen Bug #10105

Merged: 17 commits, Apr 18, 2023
2 changes: 1 addition & 1 deletion oneflow/ir/include/OneFlow/Conversion/OneFlowToTosa.h
@@ -26,7 +26,7 @@ namespace oneflow {
std::unique_ptr<mlir::Pass> createLowerOneFlowToTosaPass();
std::unique_ptr<mlir::Pass> createLowerOneFlowToLinalgPass();
std::unique_ptr<mlir::Pass> createConvertToSignlessForTosaPass();
std::unique_ptr<mlir::Pass> createCastOneFlowInputToSignlessPass();
std::unique_ptr<mlir::Pass> createCastOneFlowOpsToSignlessPass();

} // namespace oneflow

2 changes: 1 addition & 1 deletion oneflow/ir/include/OneFlow/OneFlowOps.td
@@ -206,7 +206,7 @@ def OneFlow_ReturnOp : Op<OneFlow_Dialect, "return", [NoMemoryEffect, HasParent<
let hasVerifier = 1;
}

def OneFlow_NormalizationInferenceOp : OneFlow_NormalizationBaseOp<"normalization_infer", [DeclareOpInterfaceMethods<AlternativeOpTypeNameInterface>]> {
def OneFlow_NormalizationInferenceOp : OneFlow_NormalizationBaseOp<"normalization_infer", [DeclareOpInterfaceMethods<AlternativeOpTypeNameInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>]> {
let output = (outs
OneFlow_Tensor:$y
);
6 changes: 3 additions & 3 deletions oneflow/ir/include/OneFlow/OneFlowPasses.td
@@ -35,9 +35,9 @@ def OneFlowJobToFuncPass : Pass<"ofjob-to-func", "ModuleOp"> {
let dependentDialects = ["mlir::func::FuncDialect"];
}

def CastOneFlowInputToSignlessPass : Pass<"cast-ofinput-to-signless", "ModuleOp"> {
let summary = "cast oneflow input to singless";
let constructor = "mlir::oneflow::createCastOneFlowInputToSignlessPass()";
def CastOneFlowOpsToSignlessPass : Pass<"cast-ofops-to-signless", "ModuleOp"> {
let summary = "cast oneflow ops to signless";
let constructor = "mlir::oneflow::createCastOneFlowOpsToSignlessPass()";
let dependentDialects = ["mlir::func::FuncDialect", "mlir::BuiltinDialect"];
}

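As an editorial aside, here is a minimal sketch of how the renamed pass could be scheduled from C++ together with the TOSA lowering. The two create* functions are the ones declared in OneFlowToTosa.h above; the `lowerToTosa` wrapper and the pass ordering are illustrative assumptions, not code from this PR.

```cpp
#include "OneFlow/Conversion/OneFlowToTosa.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical driver: cast oneflow ops to signless types, then lower to TOSA.
mlir::LogicalResult lowerToTosa(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::oneflow::createCastOneFlowOpsToSignlessPass());
  pm.addPass(mlir::oneflow::createLowerOneFlowToTosaPass());
  return pm.run(module);
}
```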
183 changes: 110 additions & 73 deletions oneflow/ir/lib/OneFlow/Conversion/OneFlowToTosa.cpp
@@ -120,9 +120,8 @@ RankedTensorType CreateTransposeType(ShapedType output, ArrayRef<int32_t> perms)
return RankedTensorType::get(ranked_type, output.getElementType());
};

Value CreateBNOp(Location loc, ConversionPatternRewriter& rewriter, Value output, Value x,
Value CreateBNOp(Location loc, ConversionPatternRewriter& rewriter, Type output_type, Value x,
Value mean, Value variance, Value epsilon, Value gamma, Value beta) {
const auto output_type = output.getType();
// sub_op = sub(input, mean)
auto sub_op0 = rewriter.create<tosa::SubOp>(loc, output_type, x, mean);
// add_op0 = add(var, epsilon)
@@ -134,7 +133,7 @@ Value CreateBNOp(Location loc, ConversionPatternRewriter& rewriter, Value output
// op5 = mul(mul_op0, gamma)
auto mul_op1 = rewriter.create<tosa::MulOp>(loc, output_type, mul_op0, gamma, 0);
// op6 = add(mul_op1, beta)
auto batch_norm = rewriter.create<tosa::AddOp>(loc, output_type, mul_op1, beta);
Value batch_norm = rewriter.create<tosa::AddOp>(loc, output_type, mul_op1, beta);
return batch_norm;
};

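Editorial note: after this change CreateBNOp takes the result type directly instead of deriving it from a Value, but the arithmetic it emits is unchanged; the commented op sequence composes the standard inference-time batch normalization:

```latex
y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \beta
```

where $\mu$ and $\sigma^{2}$ are the moving mean and variance, $\gamma$ the scale (gamma), and $\beta$ the shift (beta).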
@@ -365,18 +364,27 @@ struct MaxPool2DOpLowering final : public OpConversionPattern<MaxPool2DOp> {
auto pad_pairs = get_pair_int64_from_array(op.getPadding());

auto loc = op.getLoc();
auto perms = {0, 2, 3, 1};

const auto kernel = rewriter.getDenseI64ArrayAttr({kernel_pairs.first, kernel_pairs.second});
const auto stride = rewriter.getDenseI64ArrayAttr({stride_pairs.first, stride_pairs.second});
const auto pad = rewriter.getDenseI64ArrayAttr(
{pad_pairs.first, pad_pairs.second, pad_pairs.first, pad_pairs.second});

auto input = CreateTransposeValue(loc, rewriter, op.getX(), perms);
auto output = CreateTransposeType(op.getY().getType().cast<ShapedType>(), perms);

auto max_pool2d = rewriter.create<tosa::MaxPool2dOp>(loc, output, input, kernel, stride, pad);
auto y = CreateTransposeValue(loc, rewriter, max_pool2d, {0, 3, 1, 2});
auto input = op.getX();
auto out_type = op.getY().getType().cast<ShapedType>();

Value y;
if (op.IsNCHW()) {
auto perms = {0, 2, 3, 1};
auto reverse_perms = {0, 3, 1, 2};
input = CreateTransposeValue(loc, rewriter, input, perms);
out_type = CreateTransposeType(out_type, perms);
auto max_pool2d =
rewriter.create<tosa::MaxPool2dOp>(loc, out_type, input, kernel, stride, pad);
y = CreateTransposeValue(loc, rewriter, max_pool2d, reverse_perms);
} else {
y = rewriter.create<tosa::MaxPool2dOp>(loc, out_type, input, kernel, stride, pad);
}

auto indice_output = convertToSignless(op->getContext(), op.getIndice().getType());
auto value = DenseElementsAttr::get(indice_output, rewriter.getZeroAttr(rewriter.getI64Type()));
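Editorial note on the layout handling just above: tosa.max_pool2d operates on NHWC tensors, so the NCHW branch transposes the input with {0, 2, 3, 1} and transposes the result back with {0, 3, 1, 2}. A tiny self-contained sketch of what those permutation vectors mean (the Permute helper is hypothetical, written only to illustrate the index mapping; it is not part of the PR):

```cpp
#include <array>
#include <cstdint>

// perm[i] selects which source dimension feeds output dimension i.
std::array<int64_t, 4> Permute(const std::array<int64_t, 4>& shape,
                               const std::array<int, 4>& perm) {
  std::array<int64_t, 4> out{};
  for (int i = 0; i < 4; ++i) out[i] = shape[perm[i]];
  return out;
}

// Permute({N, C, H, W}, {0, 2, 3, 1}) == {N, H, W, C}   // NCHW -> NHWC
// Permute({N, H, W, C}, {0, 3, 1, 2}) == {N, C, H, W}   // NHWC -> NCHW, restoring the original layout
```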
@@ -473,31 +481,31 @@ struct NormalizationInferenceOpLowering final
using OpConversionPattern<NormalizationInferenceOp>::OpConversionPattern;
LogicalResult matchAndRewrite(NormalizationInferenceOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto reshape_dim = [&](Type type, Value value) -> Value {
RankedTensorType in_type = value.getType().dyn_cast<RankedTensorType>();
RankedTensorType out_type = type.cast<RankedTensorType>();
SmallVector<int64_t> new_shape = {in_type.getShape()[0]};
for (auto i = 2; i < out_type.getRank(); ++i) new_shape.push_back(1);
auto new_type = RankedTensorType::get(new_shape, out_type.getElementType());
return rewriter.create<tosa::ReshapeOp>(op->getLoc(), new_type, value,
rewriter.getDenseI64ArrayAttr(new_shape));
};

auto loc = op->getLoc();
const auto out_type = op.getY().getType();

const auto epsilon_type = RankedTensorType::get({}, rewriter.getF32Type());
auto epsilon = rewriter.create<tosa::ConstOp>(
loc, epsilon_type, DenseElementsAttr::get(epsilon_type, op.getEpsilon()));
auto mean = reshape_dim(out_type, op.getMovingMean());
auto variance = reshape_dim(out_type, op.getMovingVariance());
auto gamma = reshape_dim(out_type, op.getGamma());
auto beta = reshape_dim(out_type, op.getBeta());
auto output = op.getY();
auto mean = op.getMovingMean();
auto variance = op.getMovingVariance();
auto gamma = op.getGamma();
auto beta = op.getBeta();
auto output_type = op.getY().getType();
auto x = op.getX();

if (op.IsNCHW()) {
const auto perms = {0, 2, 3, 1};
x = CreateTransposeValue(loc, rewriter, x, perms);
output_type = CreateTransposeType(output_type, perms);
}

auto batch_norm =
oneflow::CreateBNOp(loc, rewriter, output, x, mean, variance, epsilon, gamma, beta);
oneflow::CreateBNOp(loc, rewriter, output_type, x, mean, variance, epsilon, gamma, beta);

if (op.IsNCHW()) {
const auto reverse_perms = {0, 3, 1, 2};
batch_norm = CreateTransposeValue(loc, rewriter, batch_norm, reverse_perms);
}
rewriter.replaceOp(op, {batch_norm});
return success();
}
@@ -508,36 +516,31 @@ struct NormalizationOpLowering final : public OpConversionPattern<NormalizationO
using OpConversionPattern<NormalizationOp>::OpConversionPattern;
LogicalResult matchAndRewrite(NormalizationOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto reshape_dim = [&](Type type, Value value) -> Value {
const RankedTensorType in_type = value.getType().dyn_cast<RankedTensorType>();
const RankedTensorType out_type = type.cast<RankedTensorType>();
SmallVector<int64_t> new_shape = {in_type.getShape()[0]};
for (auto i = 2; i < out_type.getRank(); ++i) new_shape.push_back(1);
const auto new_type = RankedTensorType::get(new_shape, out_type.getElementType());
return rewriter.create<tosa::ReshapeOp>(op->getLoc(), new_type, value,
rewriter.getDenseI64ArrayAttr(new_shape));
};

auto loc = op->getLoc();
const auto out_type = op.getY().getType();

const auto epsilon_type = RankedTensorType::get({}, rewriter.getF32Type());
// epsilon = reshape(epsilon, shape_1)
auto epsilon = rewriter.create<tosa::ConstOp>(
loc, epsilon_type, DenseElementsAttr::get(epsilon_type, op.getEpsilon()));
// mean = reshape(mean, shape_0)
auto mean = reshape_dim(out_type, op.getMovingMean());
// variance= reshape(variance, shape_0)
auto variance = reshape_dim(out_type, op.getMovingVariance());
// scale = reshape(scale, shape_0)
auto gamma = reshape_dim(out_type, op.getGamma());
// beta = reshape(beta, shape_0)
auto beta = reshape_dim(out_type, op.getBeta());
auto output = op.getY();
auto mean = op.getMovingMean();
auto variance = op.getMovingVariance();
auto gamma = op.getGamma();
auto beta = op.getBeta();
auto output_type = op.getY().getType();
auto x = op.getX();

if (op.IsNCHW()) {
const auto perms = {0, 2, 3, 1};
x = CreateTransposeValue(loc, rewriter, x, perms);
output_type = CreateTransposeType(output_type, perms);
}

auto batch_norm =
oneflow::CreateBNOp(loc, rewriter, output, x, mean, variance, epsilon, gamma, beta);
oneflow::CreateBNOp(loc, rewriter, output_type, x, mean, variance, epsilon, gamma, beta);

if (op.IsNCHW()) {
const auto reverse_perms = {0, 3, 1, 2};
batch_norm = CreateTransposeValue(loc, rewriter, batch_norm, reverse_perms);
}
auto moving_mean = op.getMovingMean();
auto moving_variance = op.getMovingVariance();

@@ -570,23 +573,37 @@ struct Conv2DOpLowering final : public OpConversionPattern<Conv2DOp> {
auto loc = op.getLoc();
if (!bias) {
const auto output_shape = op.getOut().getType().cast<ShapedType>();
const auto output_channels = output_shape.getDimSize(1);
// support nhwc
const auto output_channels = output_shape.getDimSize(op.IsNCHW() ? 1 : 3);
const auto bias_elem_type = output_shape.getElementType();
const auto type = RankedTensorType::get(output_channels, bias_elem_type);
bias = rewriter.create<tosa::ConstOp>(
op.getLoc(), type, DenseElementsAttr::get(type, rewriter.getZeroAttr(bias_elem_type)));
}

auto perms = {0, 2, 3, 1};
auto in = CreateTransposeValue(loc, rewriter, op.getIn(), perms);
auto weight = CreateTransposeValue(loc, rewriter, op.getWeight(), perms);
const auto output = CreateTransposeType(op.getOut().getType().cast<ShapedType>(), perms);

auto conv2d =
rewriter.create<tosa::Conv2DOp>(loc, output, in, weight, bias, pad, stride, dilation);

auto res = CreateTransposeValue(loc, rewriter, conv2d, {0, 3, 1, 2});
rewriter.replaceOp(op, {res});
auto in = op.getIn();
auto weight = op.getWeight();
auto out_type = op.getOut().getType().cast<ShapedType>();
if (out_type.getRank() != 4) {
LOG(FATAL) << "Failed to lower oneflow op";
op->dump();
}
// support nhwc
if (op.IsNCHW()) {
const auto perms = {0, 2, 3, 1};
const auto reverse_perms = {0, 3, 1, 2};
in = CreateTransposeValue(loc, rewriter, in, perms);
weight = CreateTransposeValue(loc, rewriter, weight, perms);
out_type = CreateTransposeType(out_type, perms);
auto conv2d =
rewriter.create<tosa::Conv2DOp>(loc, out_type, in, weight, bias, pad, stride, dilation);

auto res = CreateTransposeValue(loc, rewriter, conv2d, reverse_perms);
rewriter.replaceOp(op, {res});
} else {
rewriter.replaceOpWithNewOp<tosa::Conv2DOp>(op, out_type, in, weight, bias, pad, stride,
dilation);
}
return success();
}
};
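A short note on the zero-bias branch above: tosa.conv2d takes an explicit bias operand, so when the OneFlow op carries no bias the lowering materializes an all-zero constant whose length is the output channel count, which lives in dimension 1 for NCHW output and dimension 3 for NHWC, exactly the `op.IsNCHW() ? 1 : 3` selection. A trivial illustration of that index choice (editorial helper, not PR code):

```cpp
#include <array>
#include <cstdint>

// Channel count of a 4-D conv output shape, mirroring `getDimSize(op.IsNCHW() ? 1 : 3)`.
int64_t OutputChannels(const std::array<int64_t, 4>& shape, bool is_nchw) {
  return shape[is_nchw ? 1 : 3];  // NCHW = {N, C, H, W}, NHWC = {N, H, W, C}
}
```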
@@ -623,29 +640,48 @@ struct CastInputConversion final : public OpRewritePattern<InputOp> {
if (isSignLessTensorOrOther(cast.getResult(0).getType())) { return failure(); }
}
}
LOG(ERROR) << "ok4";
InputOp cloned = rewriter.create<InputOp>(op->getLoc(), op.getResultTypes(), op->getOperands(),
op->getAttrs());
auto m = op->getParentOp();
m->dump();
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
op, convertToSignless(getContext(), op.getOutput().getType()), cloned.getOutput());
m->dump();
return success();
}
};

struct CastVariableConversion final : public OpRewritePattern<VariableOp> {
public:
explicit CastVariableConversion(mlir::MLIRContext* context)
: OpRewritePattern<VariableOp>(context, /*benefit=*/0) {}
mlir::LogicalResult matchAndRewrite(VariableOp op,
mlir::PatternRewriter& rewriter) const override {
auto outType = op.getOutput().getType();
if (isSignLessTensorOrOther(outType)) { return failure(); }
if (op->hasOneUse()) {
if (auto cast =
llvm::dyn_cast<UnrealizedConversionCastOp>(op.getOutput().use_begin()->getOwner())) {
if (isSignLessTensorOrOther(cast.getResult(0).getType())) { return failure(); }
}
}
if (op.getOutput().getUses().empty()) { return failure(); }
VariableOp cloned = rewriter.create<VariableOp>(op->getLoc(), op.getResultTypes(),
op->getOperands(), op->getAttrs());
[Inline review comment from a Contributor on the VariableOp created above: would naming `cloned` `casted` make the intent clearer?]
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
op, convertToSignless(getContext(), op.getOutput().getType()), cloned.getOutput());
return success();
}
};

namespace {

class CastOneFlowInputToSignlessPass
: public CastOneFlowInputToSignlessPassBase<CastOneFlowInputToSignlessPass> {
class CastOneFlowOpsToSignlessPass
: public CastOneFlowOpsToSignlessPassBase<CastOneFlowOpsToSignlessPass> {
void getDependentDialects(::mlir::DialectRegistry& registry) const override {
registry.insert<oneflow::OneFlowDialect>();
}
void runOnOperation() override {
Operation* op = getOperation();
RewritePatternSet patterns(&getContext());
patterns.add<oneflow::CastInputConversion>(op->getContext());
patterns.add<oneflow::CastInputConversion, oneflow::CastVariableConversion>(op->getContext());

(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}
@@ -691,12 +727,13 @@ void OneFlowLoweringToTosaPass::runOnOperation() {
});
RewritePatternSet patterns(context);

const auto mgr = ::oneflow::Singleton<::oneflow::VariableTensorMgr>::Get();
// check if the pass is triggered by python based on the presence of variable tensor manger
if (mgr) {
patterns.add<VariableOpLowering>(typeConverter, context);
} else {
patterns.add<VariableOpToConstLowering>(typeConverter, context, this->variableAsConstant);
if (fullyConvert) {
if (::oneflow::Singleton<::oneflow::VariableTensorMgr>::Get()) {
patterns.add<VariableOpLowering>(typeConverter, context);
} else {
patterns.add<VariableOpToConstLowering>(typeConverter, context, this->variableAsConstant);
}
}
patterns.add<CastOpLowering, ScalarMulByTensorOpLowering, ReluOpLowering, Conv2DOpLowering,
AvgPool2DOpLowering, ReshapeOpLowering, Add2OpLowering, MaxPool2DOpLowering,
@@ -768,8 +805,8 @@ void ConvertToSignlessForTosaPass::runOnOperation() {
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}

std::unique_ptr<Pass> createCastOneFlowInputToSignlessPass() {
return std::make_unique<CastOneFlowInputToSignlessPass>();
std::unique_ptr<Pass> createCastOneFlowOpsToSignlessPass() {
return std::make_unique<CastOneFlowOpsToSignlessPass>();
}

} // namespace oneflow
2 changes: 1 addition & 1 deletion oneflow/ir/lib/OneFlow/OneFlowInferReturnTypes.cpp
@@ -111,7 +111,7 @@ ::mlir::LogicalResult inferReturnTypesWithOpTypeName(
};
::oneflow::ParallelConf parallel_conf = user_op::getParallelConfFromAttrDictionary(attributes);
::oneflow::ParallelDesc parallel_desc{parallel_conf};
op->FillOpParallelDesc(parallel_desc);
CHECK_JUST(op->FillOpParallelDesc(parallel_desc));
CHECK_JUST(op->InferLogicalOutBlobDescs(GetLogicalBlobDesc4BnInOp, parallel_desc));
for (const auto& result_id : result_ids) {
const auto& arg_name = result_id.first;
31 changes: 31 additions & 0 deletions oneflow/ir/lib/OneFlow/Transform/AutoNHWCOps.cpp
@@ -95,10 +95,18 @@ llvm::SmallVector<Value, 4> BroadcastAddOp::NchwToNhwc(llvm::SmallVector<Value,

bool NormalizationOp::IsNCHW() { return this->getAxisAttr().getValue().getSExtValue() == 1; }

bool NormalizationInferenceOp::IsNCHW() {
return this->getAxisAttr().getValue().getSExtValue() == 1;
}

llvm::DenseSet<Value> NormalizationOp::OperandsToTranspose() { return {this->getX()}; }

llvm::DenseSet<Value> NormalizationInferenceOp::OperandsToTranspose() { return {this->getX()}; }

llvm::DenseSet<Value> NormalizationOp::ResultsToTranspose() { return {this->getY()}; }

llvm::DenseSet<Value> NormalizationInferenceOp::ResultsToTranspose() { return {this->getY()}; }

llvm::SmallVector<Value, 4> NormalizationOp::NchwToNhwc(llvm::SmallVector<Value, 4> value,
PatternRewriter& rewriter) {
auto normalization_op = *this;
@@ -122,6 +130,29 @@ llvm::SmallVector<Value, 4> NormalizationOp::NchwToNhwc(llvm::SmallVector<Value,
return results;
}

llvm::SmallVector<Value, 4> NormalizationInferenceOp::NchwToNhwc(llvm::SmallVector<Value, 4> value,
PatternRewriter& rewriter) {
auto normalization_op = *this;
SmallVector<Value, 4> operands;
operands.push_back(value[0]);
if (normalization_op.getMovingMean()) operands.push_back(normalization_op.getMovingMean());
if (normalization_op.getMovingVariance())
operands.push_back(normalization_op.getMovingVariance());
operands.push_back(normalization_op.getGamma());
operands.push_back(normalization_op.getBeta());
if (normalization_op.get_addToOutput()) operands.push_back(normalization_op.get_addToOutput());
NamedAttrList attributes = normalization_op->getAttrs();
attributes.set(normalization_op.getAxisAttrName(), rewriter.getSI32IntegerAttr(3));
auto res =
rewriter
.create<oneflow::NormalizationInferenceOp>(
normalization_op.getLoc(), getNHWCResultTypes(normalization_op), operands, attributes)
->getResults();
llvm::SmallVector<Value, 4> results;
results.push_back(res[0]);
return results;
}

bool MaxPool2DOp::IsNCHW() { return this->getDataFormat().str() == "channels_first"; }

llvm::DenseSet<Value> MaxPool2DOp::OperandsToTranspose() { return {this->getX()}; }
3 changes: 2 additions & 1 deletion oneflow/ir/oneflow-opt/oneflow-opt.cpp
@@ -67,8 +67,9 @@ int32_t main(int32_t argc, char** argv) {
mlir::registerBufferHostRegisterPassPass();
mlir::registerGpuCopyArgPassPass();
mlir::registerOneFlowJobToFuncPassPass();
mlir::registerCastOneFlowInputToSignlessPassPass();
mlir::registerCastOneFlowOpsToSignlessPassPass();
mlir::registerFuncToOneFlowJobPassPass();
mlir::registerAutoNhwcPass();
#ifdef WITH_MLIR_CUDA_CODEGEN
mlir::oneflow::registerGpuSerializeToCubinPass();
#endif // WITH_MLIR_CUDA_CODEGEN