From 4707d3bdc6d7e1bb12b3e44dcf23455a7d445725 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 4 Jan 2024 15:12:51 +0000 Subject: [PATCH] [MLIR][ONNX] Add OnnxToTorch support for Bernoulli and CastLike op Signed-Off By: Vivek Khandelwal --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 116 +++++++++++++++--- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 59 +++++++++ 2 files changed, 160 insertions(+), 15 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 87df83101718..8b6fddecfd56 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -15,6 +15,28 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::onnx_c; +static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) { + int64_t dtypeIntTorch; + // TODO: Add complete mapping. + switch (dtypeIntOnnx) { + case 1: + dtypeIntTorch = 6; // float + break; + case 10: + dtypeIntTorch = 5; // half + break; + case 11: + dtypeIntTorch = 7; // double + break; + case 16: + dtypeIntTorch = 15; // bfloat16 + break; + default: + dtypeIntTorch = -1; // No dtype + } + return dtypeIntTorch; +} + // Simple rewrites for the default domain. // See: https://onnx.ai/onnx/operators/ // For operators that are effectively version invariant, we register with @@ -311,6 +333,53 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } return failure(); }); + patterns.onOp( + "Bernoulli", 15, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + int64_t dtypeIntOnnx, dtypeIntTorch; + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(dtypeIntOnnx, "dtype", -1) || + binder.tensorResultType(resultType)) + return failure(); + + SmallString<64> name("torch.onnx."); + name.append("seed"); + auto attr = binder.op->getAttr(name); + if (attr) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for seed attribute"); + } + + Value none = rewriter.create(binder.getLoc()); + Value bernoulli = rewriter.create( + binder.getLoc(), input.getType(), input, /*generator=*/none); + + if (dtypeIntOnnx == -1) { + // True, if dtype attribute value is not present. + rewriter.replaceOp(binder.op, bernoulli); + return success(); + } + dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (dtypeIntTorch == -1) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value constDtype = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + dtypeIntTorch)); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + rewriter.replaceOpWithNewOp( + binder.op, resultType, bernoulli, constDtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); + return success(); + }); patterns.onOp( "BitShift", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -386,21 +455,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.tensorResultType(resultType)) return failure(); - // TODO: Add complete mapping. - switch (dtypeIntOnnx) { - case 1: - dtypeIntTorch = 6; // float - break; - case 10: - dtypeIntTorch = 5; // half - break; - case 11: - dtypeIntTorch = 7; // double - break; - case 16: - dtypeIntTorch = 15; // bfloat16 - break; - default: + dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (dtypeIntTorch == -1) { return rewriter.notifyMatchFailure( binder.op, "unimplemented support for the given dtype conversion"); @@ -418,6 +474,36 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( /*memory_format=*/none); return success(); }); + patterns.onOp( + "CastLike", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input, target; + if (binder.tensorOperands(input, target) || + binder.tensorResultType(resultType)) + return failure(); + + // TODO: Add support to handle the `saturate` attribute. + // Ignoring it right now, since it's only using during the float8 + // conversions which are not supported in Torch-MLIR right now. + + Torch::ValueTensorType targetTy = + target.getType().cast(); + if (!targetTy.hasDtype()) { + return rewriter.notifyMatchFailure(binder.op, + "target tensor must have a dtype"); + } + Type targetDtype = targetTy.getDtype(); + Value constDtype = Torch::getDtypeIntValueForType( + rewriter, binder.getLoc(), targetDtype); + Value none = rewriter.create(binder.getLoc()); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, constDtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); + return success(); + }); patterns.onOp("Ceil", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 8ba88d22c256..08d6e4ea4e91 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -110,6 +110,25 @@ func.func @test_acos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, return %0 : !torch.vtensor<[3,4,5],f32> } +// CHECK-LABEL: @test_bernoulli +func.func @test_bernoulli(%arg0: !torch.vtensor<[10],f64>) -> !torch.vtensor<[10],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %0 = torch.aten.bernoulli %arg0, %[[NONE]] : !torch.vtensor<[10],f64>, !torch.none -> !torch.vtensor<[10],f64> + %0 = torch.operator "onnx.Bernoulli"(%arg0) : (!torch.vtensor<[10],f64>) -> !torch.vtensor<[10],f64> + return %0 : !torch.vtensor<[10],f64> +} + +// CHECK-LABEL: @test_bernoulli_double +func.func @test_bernoulli_double(%arg0: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[BERNOULLI:.*]] = torch.aten.bernoulli %arg0, %[[NONE]] : !torch.vtensor<[10],f32>, !torch.none -> !torch.vtensor<[10],f32> + // CHECK: %[[DTYPE:.*]] = torch.constant.int 7 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %[[BERNOULLI]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f64> + %0 = torch.operator "onnx.Bernoulli"(%arg0) {torch.onnx.dtype = 11 : si64} : (!torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f64> + return %0 : !torch.vtensor<[10],f64> +} + // CHECK-LABEL: @test_bitshift_left_uint8 func.func @test_bitshift_left_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8> @@ -323,6 +342,46 @@ func.func @test_cast_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>) -> !torc return %0 : !torch.vtensor<[3,4],f32> } +// CHECK-LABEL: @test_castlike_BFLOAT16_to_FLOAT +func.func @test_castlike_BFLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],bf16>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 6 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],bf16>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// CHECK-LABEL: @test_castlike_DOUBLE_to_FLOAT +func.func @test_castlike_DOUBLE_to_FLOAT(%arg0: !torch.vtensor<[3,4],f64>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 6 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],f64>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// CHECK-LABEL: @test_castlike_FLOAT_to_DOUBLE +func.func @test_castlike_FLOAT_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],f64>) -> !torch.vtensor<[3,4],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 7 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f64> + %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],f64>) -> !torch.vtensor<[3,4],f64> + return %0 : !torch.vtensor<[3,4],f64> +} + +// CHECK-LABEL: @test_castlike_FLOAT16_to_FLOAT +func.func @test_castlike_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 6 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],f16>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + // CHECK-LABEL: @test_ceil_example func.func @test_ceil_example(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.ceil %arg0 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32>