From eb4e59e1899d4f3ed61e7ed3956e4fd9e1cc9aae Mon Sep 17 00:00:00 2001 From: yyp0 Date: Sun, 29 Sep 2024 17:41:20 +0800 Subject: [PATCH] [Torch] support binary_cross_entropy_with_logits decomposition (#3741) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 27 +++++++ .../Transforms/AbstractInterpLibrary.cpp | 16 ++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 73 +++++++++++++++++++ .../build_tools/abstract_interp_lib_gen.py | 12 +++ .../build_tools/torch_ods_gen.py | 3 + .../test_suite/reduction.py | 23 ++++++ 6 files changed, 154 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c9329ccb895d..6f02a94768d0 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -9224,6 +9224,33 @@ def Torch_AtenBinaryCrossEntropyBackwardOp : Torch_Op<"aten.binary_cross_entropy }]; } +def Torch_AtenBinaryCrossEntropyWithLogitsOp : Torch_Op<"aten.binary_cross_entropy_with_logits", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + AnyTorchOptionalTensorType:$weight, + AnyTorchOptionalTensorType:$pos_weight, + Torch_IntType:$reduction + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBinaryCrossEntropyWithLogitsOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenBinaryCrossEntropyWithLogitsOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenLogSigmoidForwardOp : Torch_Op<"aten.log_sigmoid_forward", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 995a7df283fd..445d4e459013 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10289,6 +10289,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.optional>, !torch.int, !torch.int, !torch.float) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.int) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = torch.aten.eq.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %3 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list) -> !torch.list\n" +" torch.prim.If.yield %3 : !torch.list\n" +" } else {\n" +" torch.prim.If.yield %0 : !torch.list\n" +" }\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float) -> !torch.tuple, list, list> {\n" " %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.tuple, list, list>\n" " return %0 : !torch.tuple, list, list>\n" @@ -14634,6 +14646,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 1ee57b60f248..29c176f96afd 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -8799,6 +8799,77 @@ class DecomposeAtenCrossEntropyLossOp }; } // namespace +namespace { +class DecomposeAtenBinaryCrossEntropyWithLogitsOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenBinaryCrossEntropyWithLogitsOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto self = op.getSelf(); + auto target = op.getTarget(); + auto posWeight = op.getPosWeight(); + auto weight = op.getWeight(); + auto reduction = op.getReduction(); + + Value loss; + auto one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto _one = + rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + + auto _target = + rewriter.create(loc, target.getType(), target, _one); + auto _target_1 = rewriter.create(loc, _target.getType(), + _target, one, one); + Value mm = + rewriter.create(loc, self.getType(), _target_1, self); + Value logSigm = + rewriter.create(loc, self.getType(), self); + + if (!isa(posWeight.getType())) { + auto logWeight = rewriter.create( + loc, posWeight.getType(), + rewriter.create(loc, posWeight.getType(), posWeight, + one, one), + one, one); + loss = rewriter.create( + loc, mm.getType(), mm, + rewriter.create(loc, logWeight.getType(), logWeight, + logSigm), + one); + } else { + loss = + rewriter.create(loc, mm.getType(), mm, logSigm, one); + } + + if (!isa(weight.getType())) { + loss = + rewriter.create(loc, loss.getType(), loss, weight); + } + + // apply loss reduction. + int64_t reductionInt; + if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) { + return rewriter.notifyMatchFailure(op, "no reduction type is appointed!"); + } + + auto none = rewriter.create(loc); + Value res; + if (reductionInt == 1) { + res = rewriter.create(loc, op.getType(), loss, none); + } else if (reductionInt == 2) { + res = rewriter.create(loc, op.getType(), loss, none); + } else { + res = loss; + } + + rewriter.replaceOp(op, res); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenOneHotOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -9936,6 +10007,8 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 22fe8e299f07..d3ec25bcea70 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1993,6 +1993,14 @@ def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int = def aten〇cross_entropy_loss〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]: return upstream_shape_functions.cross_entropy_loss(self, target, weight, reduction, ignore_index, label_smoothing) +def aten〇binary_cross_entropy_with_logits〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, pos_weight: Optional[List[int]] = None, reduction: int = 1) -> List[int]: + scalar_shape: List[int] = [] + if reduction == 0: + result_shape = upstream_shape_functions._copy(self) + else: + result_shape = scalar_shape + return result_shape + @check_shape_function([ Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case. ]) @@ -4958,6 +4966,10 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U return dtype return aten〇std〡dtype(self_rank_dtype) +def aten〇binary_cross_entropy_with_logits〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]] = None, pos_weight_rank_dtype: Optional[Tuple[int, int]] = None, reduction: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype( tensor_shapes=[(3,3)], diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index f3227f29b5ce..ea5c504284eb 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -743,6 +743,9 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)" ) + emit( + "aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)" + ) emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)") emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)") emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 9a683e3c6219..e9b84ea0652c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -2294,6 +2294,29 @@ def CrossEntropyLossNoReductionModule_basic(module, tu: TestUtils): module.forward(tu.rand(8, 2), tu.randint(8, high=2)) +class BinaryCrossEntropyWithLogitsStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([8, 2], torch.float32, True), + ([8, 2], torch.float32, True), + ] + ) + def forward(self, input, target): + return torch.ops.aten.binary_cross_entropy_with_logits( + input, target, reduction=0 + ) + + +@register_test_case(module_factory=lambda: BinaryCrossEntropyWithLogitsStaticModule()) +def BinaryCrossEntropyWithLogitsStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(8, 2), tu.rand(8, 2)) + + # ==============================================================================