Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Stablehlo] enable Stablehlo refbackend with Interpreter #3292

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,119 @@
}

STABLEHLO_PASS_SET = {
"AddCDivModule_basic",
"AddCMulModule_basic",
"Add_MixPModule_basic",
"Add_Module_basic",
"AvgPool1dFloatModule_basic",
"AvgPool1dIntModule_basic",
"AvgPool2dCeilModeTrueModule_basic",
"AvgPool2dFloatModule_basic",
"AvgPool2dIntModule_basic",
"AvgPool2dWithoutPadModule_basic",
"BmmFloatModule_basic",
"BmmIntModule_basic",
"BroadcastDynamicDimModule_basic",
"BroadcastToModule_basic",
"CollapseFullDynamicModule_basic",
"CollapsePartialDynamicModule_basic",
"CollapseRank1DynamicModule_basic",
"CopyModule_basic",
"CopyWithDifferentDTypesAndSizesModule_basic",
"CopyWithDifferentDTypesModule_basic",
"CopyWithDifferentSizesModule_basic",
"ElementwiseAddModule_basic",
"ElementwiseAddScalarInt8Module_basic",
"ElementwiseAndScalarModule_basic",
"ElementwiseAtan2FloatIntModule_basic",
"ElementwiseAtan2TensorFloatModule_basic",
"ElementwiseAtan2TensorIntModule_basic",
"ElementwiseAtenFloorDivideBroadcastModule_basic",
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
"ElementwiseAtenFloorDivideTensorPositiveModule_basic",
"ElementwiseBinaryModule_basic",
"ElementwiseBitwiseAndModule_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"ElementwiseBitwiseOrModule_basic",
"ElementwiseBitwiseXorModule_basic",
"ElementwiseDivTensorFloatModule_basic",
"ElementwiseDivTensorIntegerModule_basic",
"ElementwiseDivTensorRoundingModeFloorModule_basic",
"ElementwiseDivTensorRoundingModeTruncModule_basic",
"ElementwiseFlattenBroadcastModule_basic",
"ElementwiseMaxOtherIntModule_basic",
"ElementwiseMaxOtherModule_basic",
"ElementwiseMaximumIntModule_basic",
"ElementwiseMaximumModule_basic",
"ElementwiseMinOtherIntModule_basic",
"ElementwiseMinOtherModule_basic",
"ElementwiseMinimumIntModule_basic",
"ElementwiseMinimumModule_basic",
"ElementwiseMulScalarModule_int",
"ElementwiseMulTensorFloatModule_basic",
"ElementwiseMulTensorIntModule_basic",
"ElementwiseOrTensorModule_basic",
"ElementwisePowTensorBroadcastModule_basic",
"ElementwisePowTensorModule_basic",
"ElementwiseRelu6Module_basic",
"ElementwiseRemainderScalarModule_Int_basic",
"ElementwiseSignModule_basic",
"ElementwiseSubTensorInt8Module_basic",
"ElementwiseTernaryModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic",
"ElementwiseUnsqueezeNegDimsModule_basic",
"ExpandAsFloatModule_basic",
"ExpandModule_basic",
"FlipModule_basic",
"HardTanhIntModule_basic",
"HardTanhModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
"IndexPutImplIndexWithNoneModule_basic",
"LogSoftmaxBackwardModule_basic",
"LogSoftmaxIntModule_basic",
"MatmulBroadcastBatchDim_basic",
"MatmulSingleDynamicBatchDim_basic",
"Matmul_3d",
"Matmul_4d",
"MaxPool1dModule_basic",
"MaxPool2dModule_basic",
"MaxPool3dLargeDatadModule_basic",
"MaxPool3dModuleRandomSimple_basic",
"MaxPool3dModule_basic",
"MseLossNoReductionModule_basic",
"MseLossSumReductionWithDifferentElemTypeModule_basic",
"OneHotModule_basic",
"PixelShuffleModuleSpatiallyDynamic_basic",
"ReduceAmaxKeepDim_basic",
"ReduceFrobeniusNormKeepDimModule_basic",
"ReduceSumDimIntListKeepDimFloatModule_basic",
"ReduceSumDimIntListKeepDimIntModule_basic",
"RsubIntModule_basic",
"ScatterSrcStaticModule_basic",
"SiluModule_basic",
"SliceCopy_Module_basic",
"SoftmaxBackwardModule_basic",
"SoftmaxIntArgTypeF64Module_basic",
"SoftmaxIntModule_basic",
"SoftmaxIntNegDimModule_basic",
"SoftmaxIntNonNoneDtypeModule_basic",
"SquareModule_basic",
"SqueezeDimModule_dynamic",
"SqueezeDimModule_negDim",
"SqueezeModule_broadcast",
"TanhBackward_basic",
"TensorsStackModule_basic",
"TensorsStackNegativeDimModule_basic",
"TensorsStackSingleElementListModule_basic",
"ToCopyModule_basic",
"ToCopyWithDTypeFalsePinMemoryModule_basic",
"ToCopyWithDTypeModule_basic",
"TypePromotionSameCategoryDifferentWidthModule_basic",
"_LogSoftmaxModuleStable_basic",
"_LogSoftmaxModule_basic",
"_SoftmaxModule_basic",
"ReduceAminmaxSingleDim_basic",
"ReduceAminmaxAllDims_basic",
"ReduceAmaxEmptyDim_basic",
Expand Down Expand Up @@ -1513,6 +1626,11 @@
"IndexPutWithNoneAndBroadcastModule_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
# stablehlo intrepreter crash
"ElementwiseDivTensorUnsignedIntegerModule_basic",
"ScatterValueFloatModule_basic",
"ScatterValueIntModule_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
}

# Write the TOSA set as a "passing" set as it is very early in development
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from torch_mlir import ir
from torch_mlir.ir import *
from torch_mlir.dialects.func import FuncOp
from torch_mlir.passmanager import *
from torch_mlir.compiler_utils import run_pipeline_with_repro_report

Expand All @@ -13,10 +15,60 @@

from .abc import StablehloBackend

from torch_mlir._mlir_libs._stablehlo import eval_module
import numpy as np

__all__ = [
"LinalgOnTensorsStablehloBackend",
]

element_type_to_np_dtype = {
"i1": np.bool_,
"i8": np.int8,
"ui8": np.uint8,
"i16": np.int16,
"i32": np.int32,
"i64": np.int64,
"f16": np.float16,
"f32": np.float32,
"f64": np.float64,
}


def convert_dense_elements_attr_to_numpy(attr):
assert isinstance(attr, ir.DenseElementsAttr)
dense_attr = ir.DenseElementsAttr(attr)
for DenseElementsAttrCls in [ir.DenseIntElementsAttr, ir.DenseFPElementsAttr]:
if DenseElementsAttrCls.isinstance(attr):
dense_attr = DenseElementsAttrCls(attr)
assert ir.ShapedType.isinstance(dense_attr.type)
dense_attr_type = ir.ShapedType(dense_attr.type)
return np.array(
[i for i in dense_attr],
dtype=element_type_to_np_dtype[str(dense_attr_type.element_type)],
).reshape(dense_attr_type.shape)
raise NotImplementedError("unsupported attribute {}".format(attr))


class RefBackendInvoker:
def __init__(self, module):
self.module = module

def __getattr__(self, function_name: str):
def invoke(*args):
mlir_args = [
ir.DenseElementsAttr.get(arg, context=self.module.context)
for arg in args
]
rets = eval_module(self.module, mlir_args)
rets = [convert_dense_elements_attr_to_numpy(i) for i in rets]
if len(rets) == 1:
return rets[0]
return rets

return invoke


# The pipeline of func.func passes that lower the STABLEHLO backend contract to the
# Linalg-on-Tensors backend contract accepted by RefBackend.
STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join(
Expand All @@ -28,6 +80,39 @@
]
)

SHAPE_LEGALIZE_TO_STABLEHLO_PIPELINE = ",".join(
[
"func.func(remove-shape-constraints)",
"canonicalize",
"func.func(shape-legalize-to-stablehlo)",
"canonicalize",
]
)


def raise_if_not_supported_by_interpreter(module: Module):
for func in module.body.operations:
assert isinstance(func, FuncOp)
for arg in func.arguments:
assert isinstance(arg.type, ir.ShapedType)
if str(ir.ShapedType(arg.type).element_type) == "i1":
raise RuntimeError("i1")
for ret in list(func.entry_block.operations)[-1].operands:
assert isinstance(ret.type, ir.ShapedType)
if str(ir.ShapedType(ret.type).element_type) == "i1":
raise RuntimeError("i1")
for op in func.entry_block.operations:
if op.operation.name == "func.return":
continue
if not op.operation.name.startswith("stablehlo."):
raise RuntimeError(
f"stablehlo interpreter doesn't support {op.operation.name}"
)
if op.operation.name == "stablehlo.batch_norm_inference":
raise RuntimeError(
f"stablehlo interpreter doesn't support {op.operation.name}"
)


class LinalgOnTensorsStablehloBackend(StablehloBackend):
"""Main entry-point for the linalg-on-tensors based Stablehlo backend.
Expand All @@ -48,15 +133,29 @@ def compile(self, imported_module: Module):
An opaque, backend specific compiled artifact object that can be
passed to `load`.
"""
copied_module = Module.parse(imported_module.operation.get_asm(), imported_module.context)
try:
run_pipeline_with_repro_report(
imported_module,
f"builtin.module({STABLEHLO_TO_LINALG_FUNC_PIPELINE})",
"Lowering STABLEHLO backend contract to Linalg-on-Tensors backend contract",
)
result = self.refbackend.compile(imported_module)
return (result, "linalg")
except:
pass

run_pipeline_with_repro_report(
imported_module,
f"builtin.module({STABLEHLO_TO_LINALG_FUNC_PIPELINE})",
"Lowering STABLEHLO backend contract to Linalg-on-Tensors backend contract",
copied_module,
f"builtin.module({SHAPE_LEGALIZE_TO_STABLEHLO_PIPELINE})",
"Shape legalize to stablehlo",
)

return self.refbackend.compile(imported_module)
raise_if_not_supported_by_interpreter(copied_module)
return (copied_module, "stablehlo")

def load(self, module):
"""Loads a compiled artifact into the runtime."""
return self.refbackend.load(module)
if module[1] == "linalg":
return self.refbackend.load(module[0])
else:
return RefBackendInvoker(module[0])
Loading