Skip to content

Commit

Permalink
[torch-frontend] update torch-mlir (#139)
Browse files Browse the repository at this point in the history
* update torch-mlir to latest
* enable `torch_frontend.compile` to emit stablehlo versioned bytecode.

Co-authored-by: Jiawei Wu <wujiawei.aml@bytedance.com>
  • Loading branch information
qingyunqu and Vremold authored Mar 26, 2024
1 parent 156c573 commit 9651131
Show file tree
Hide file tree
Showing 12 changed files with 84 additions and 61 deletions.
2 changes: 1 addition & 1 deletion frontends/torch-frontend/scripts/envsetup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ PROJ_DIR="$ROOT_PROJ_DIR/frontends/torch-frontend"
TORCH_MLIR_ROOT="$PROJ_DIR/third_party/torch-mlir"

function load_pytorch_llvm_prebuilt() {
TORCH_FRONTEND_LLVM_INSTALL_DIR="/data00/llvm_libraries/0cb024b357aff294b1ba0f9d3de8f48ab684962b/llvm_build"
TORCH_FRONTEND_LLVM_INSTALL_DIR="/data00/llvm_libraries/e5ed7b6e2fd368b722b6359556cd0125881e7638/llvm_build"
}

function apply_patches() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,6 @@ diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/
index c09900ce..f080990b 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9101,6 +9101,7 @@ def Torch_AtenCloneOp : Torch_Op<"aten.clone", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
+ let hasFolder = 1;
}

def Torch_AtenLiftFreshCopyOp : Torch_Op<"aten.lift_fresh_copy", [
@@ -15185,3 +15186,90 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
}];
}
Expand Down
24 changes: 0 additions & 24 deletions frontends/torch-frontend/third_party/patches/torch_ops_cpp.patch

This file was deleted.

9 changes: 3 additions & 6 deletions frontends/torch-frontend/third_party/patches/tuple.patch
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
index 30cc4db4..96d04fd1 100644
index 2891a22e..096635c2 100644
--- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
+++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
@@ -194,9 +194,8 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
TypeConverter typeConverter;
@@ -195,7 +195,7 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion(
- [](Torch::TupleType type,
- SmallVectorImpl<Type> &types) -> LogicalResult {
[](Torch::TupleType type, SmallVectorImpl<Type> &types) -> LogicalResult {
- llvm::append_range(types, type.getContainedTypes());
+ [](Torch::TupleType type, SmallVectorImpl<Type> &types) -> LogicalResult {
+ // llvm::append_range(types, type.getContainedTypes());
return success();
});
Expand Down
2 changes: 1 addition & 1 deletion frontends/torch-frontend/third_party/torch-mlir
Submodule torch-mlir updated 252 files
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

#include "./PassDetail.h"

#include <unordered_set>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
Expand Down Expand Up @@ -109,8 +111,17 @@ stablehlo::ReduceOp createSingleOpReduce(PatternRewriter &rewriter,
auto inputType = input.getType().cast<RankedTensorType>();
stablehlo::ConstantOp initValue = createInitialValueForReduceOp<OP>(
rewriter, loc, inputType.getElementType());

std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
SmallVector<int64_t> outputShape;
for (int64_t i = 0; i < inputType.getRank(); i++) {
if (dimsSet.find(i) == dimsSet.end()) {
outputShape.push_back(inputType.getDimSize(i));
}
}
stablehlo::ReduceOp reduceOp = rewriter.create<stablehlo::ReduceOp>(
loc, input, initValue.getOutput(), rewriter.getI64TensorAttr(dims));
loc, RankedTensorType::get(outputShape, inputType.getElementType()),
input, initValue.getOutput(), rewriter.getDenseI64ArrayAttr(dims));

Block &block = reduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputType.getElementType());
Expand Down Expand Up @@ -329,13 +340,13 @@ class ConvertAtenGroupNormOp : public OpConversionPattern<AtenOpT> {
// group_norm weight/bias
if (!weight.getType().template isa<Torch::NoneType>()) {
Value bcastWeight = rewriter.create<stablehlo::BroadcastInDimOp>(
op->getLoc(), outType, weight, rewriter.getI64TensorAttr({1}));
op->getLoc(), outType, weight, rewriter.getDenseI64ArrayAttr({1}));
result =
rewriter.create<stablehlo::MulOp>(op->getLoc(), result, bcastWeight);
}
if (!bias.getType().template isa<Torch::NoneType>()) {
Value bcastBias = rewriter.create<stablehlo::BroadcastInDimOp>(
op->getLoc(), outType, bias, rewriter.getI64TensorAttr({1}));
op->getLoc(), outType, bias, rewriter.getDenseI64ArrayAttr({1}));
result =
rewriter.create<stablehlo::AddOp>(op->getLoc(), result, bcastBias);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,10 @@ struct ConvertAtenMaxPool2dWithIndicesBackwardOp
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];

DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()),
stablehloStride);
DenseI64ArrayAttr windowDimensions =
rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
DenseI64ArrayAttr windowStrides =
rewriter.getDenseI64ArrayAttr(stablehloStride);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
Expand Down Expand Up @@ -315,7 +310,7 @@ struct ConvertAtenPowScalarOp : public OpConversionPattern<AtenPowScalarOp> {
if (!lhsType) {
lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy);
}
DenseIntElementsAttr bcastDimensions;
DenseI64ArrayAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
auto loc = op.getLoc();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ add_mlir_library(TorchFrontendPipelines
TorchMLIRTorchConversionPasses
TorchFrontendConversion
TorchFrontendTransforms
StablehloPasses
)
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "torch-frontend/Pipelines/Pipelines.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "stablehlo/transforms/Passes.h"
#include "torch-frontend/Conversion/Passes.h"
#include "torch-frontend/Transforms/Passes.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
Expand All @@ -33,6 +34,8 @@ void mlir::torch_frontend::createTorchToMhloPipeline(OpPassManager &pm) {
pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloExt());
pm.addNestedPass<func::FuncOp>(
createConvertTorchToStablehloPass(false, false));
pm.addNestedPass<func::FuncOp>(
stablehlo::createChloLegalizeToStablehloPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());

// Clean up any non-canonical code introduced above.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
//===----------------------------------------------------------------------===//

#include "torch-frontend/Transforms/CanonicalizeExt.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "stablehlo/dialect/StablehloOps.h"
Expand Down Expand Up @@ -61,6 +62,18 @@ LogicalResult foldConstantConvertOp(stablehlo::ConvertOp op,
return success();
}

// Rewrite pattern: replaces an `arith.constant` whose value is an ElementsAttr
// with a `stablehlo.constant` carrying the same attribute. Returns failure()
// (leaving the op untouched) when the value is not an ElementsAttr.
LogicalResult replaceArithConstantOpWithMhlo(arith::ConstantOp op,
                                             PatternRewriter &rewriter) {
  // Guard: only elements-valued constants are rewritten.
  auto elements = llvm::dyn_cast<ElementsAttr>(op.getValue());
  if (!elements)
    return failure();

  // Build the stablehlo constant at the original location and forward its
  // result to every user of the arith constant.
  auto stablehloConstant =
      rewriter.create<stablehlo::ConstantOp>(op->getLoc(), elements);
  rewriter.replaceOp(op, stablehloConstant.getOutput());
  return success();
}

namespace {

struct CanonicalizeExtPass : public CanonicalizeExtBase<CanonicalizeExtPass> {
Expand All @@ -85,6 +98,8 @@ struct CanonicalizeExtPass : public CanonicalizeExtBase<CanonicalizeExtPass> {

// Add conditional canonicalizer too
owningPatterns.add(foldConstantConvertOp);
// TODO: remove this pattern once torch-to-stablehlo no longer involves the arith dialect
owningPatterns.add(replaceArithConstantOpWithMhlo);

patterns = FrozenRewritePatternSet(std::move(owningPatterns),
disabledPatterns, enabledPatterns);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch
import torch as tu

import torch_frontend
from torch_frontend import compile
from torch_frontend._mlir_libs._stablehlo import deserialize_portable_artifact

def serialize_helper(module, inputs):
    """Compile ``module`` with output type ``"stablehlo+0.16.2"`` and print the result.

    The value returned by ``compile`` is passed through
    ``deserialize_portable_artifact`` and the deserialized form is printed.

    Args:
        module: the model handed to ``compile``.
        inputs: the example inputs handed to ``compile``.
    """
    artifact = compile(module, inputs, "stablehlo+0.16.2")
    print(deserialize_portable_artifact(artifact))

# ==============================================================================
class SoftmaxModule(torch.nn.Module):
    """Minimal module whose forward pass is a single ``aten._softmax`` over dim 1."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Softmax along dim 1, with half_to_float disabled.
        result = torch.ops.aten._softmax(x, dim=1, half_to_float=False)
        return result

def test_softmax():
    """Run ``SoftmaxModule`` through ``serialize_helper`` with one random 3x4 input."""
    serialize_helper(SoftmaxModule(), [tu.rand(3, 4)])
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch_frontend import ir
from torch_frontend.passmanager import PassManager
from torch_frontend.dialects.builtin import ModuleOp
from torch_frontend._mlir_libs._stablehlo import serialize_portable_artifact

_CUSTOM_OPS_IN_TORCH = [
"aten._softmax",
Expand Down Expand Up @@ -91,22 +92,28 @@ def compile(
) -> ModuleOp:
"""
Args:
output_type: str type
`raw`
`torch`
`stablehlo`
`stablehlo+0.16.2`(stablehlo version)
debug: int type, one of
`0: no debug message`,
`1: print after failure`,
`2: print after pass only on change`
"""
if output_type not in ["raw", "torch", "stablehlo"]:
if output_type not in ["raw", "torch"] and "stablehlo" not in output_type:
raise NotImplementedError(f"unsupported output type {output_type}")
if debug not in [0, 1, 2]:
raise NotImplementedError(f"unsupported debug option {debug}")
if backend_legal_ops is None:
backend_legal_ops = _CUSTOM_OPS_IN_TORCH

module = torch_mlir.compile(
from torch_mlir import torchscript
module = torchscript.compile(
model,
example_inputs,
output_type=torch_mlir.OutputType.RAW,
output_type=torchscript.OutputType.RAW,
use_tracing=False,
verbose=False,
)
Expand All @@ -126,7 +133,7 @@ def compile(
if debug:
module.context.enable_multithreading(False)

extra_library_file_name = torch_mlir._canon_extra_library(extra_library)
extra_library_file_name = torchscript._canon_extra_library(extra_library)
if verbose:
cmdline_option_string = (
"backend-legal-ops=" + ",".join(backend_legal_ops) + " extra-library=" + extra_library_file_name
Expand Down Expand Up @@ -178,7 +185,9 @@ def compile(
large_elements_limit=10,
)
pm.run(module.operation)
return module
if output_type == "stablehlo":
return module
return serialize_portable_artifact(module.operation.get_asm(), output_type.split('+')[1])


def compile_dynamo_model(
Expand Down Expand Up @@ -208,7 +217,7 @@ def compile_dynamo_model(
fx_importer = FxImporter(context=torch_mlir_context)
# for torch.export
if isinstance(model, torch.export.ExportedProgram):
fx_importer.import_frozen_exported_program(model)
fx_importer.import_frozen_program(model)
# for torch.compile
elif isinstance(model, torch.fx.GraphModule):
fx_importer.import_graph_module(model)
Expand Down Expand Up @@ -290,10 +299,11 @@ def convert_to_mhlo_via_torch_mlir(
if backend_legal_ops is None:
backend_legal_ops = _CUSTOM_OPS_IN_TORCH
# torch_mlir.BACKEND_LEGAL_OPS[torch_mlir.OutputType.TORCH] = backend_legal_ops
module = torch_mlir.compile(
from torch_mlir import torchscript
module = torchscript.compile(
model,
example_inputs,
output_type=torch_mlir.OutputType.RAW,
output_type=torchscript.OutputType.RAW,
use_tracing=use_tracing,
verbose=False,
)
Expand Down

0 comments on commit 9651131

Please sign in to comment.