diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp index b43733418..63c8aa4f7 100644 --- a/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp +++ b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp @@ -22,7 +22,6 @@ static lwe::LWECiphertextType convertArithToCGGIType(IntegerType type, lwe::UnspecifiedBitFieldEncodingAttr::get( ctx, type.getIntOrFloatBitWidth()), lwe::LWEParamsAttr()); - ; } static Type convertArithLikeToCGGIType(ShapedType type, MLIRContext *ctx) { @@ -186,6 +185,43 @@ struct ConvertShRUIOp : public OpConversionPattern { } }; +template +struct ConvertArithBinOp : public OpConversionPattern { + ConvertArithBinOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + SourceArithOp op, typename SourceArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + if (auto lhsDefOp = op.getLhs().getDefiningOp()) { + if (isa(lhsDefOp)) { + auto result = b.create(adaptor.getRhs().getType(), + adaptor.getRhs(), op.getLhs()); + rewriter.replaceOp(op, result); + return success(); + } + } + + if (auto rhsDefOp = op.getRhs().getDefiningOp()) { + if (isa(rhsDefOp)) { + auto result = b.create(adaptor.getLhs().getType(), + adaptor.getLhs(), op.getRhs()); + rewriter.replaceOp(op, result); + return success(); + } + } + + auto result = b.create( + adaptor.getLhs().getType(), adaptor.getLhs(), adaptor.getRhs()); + rewriter.replaceOp(op, result); + return success(); + } +}; + struct ArithToCGGI : public impl::ArithToCGGIBase { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -196,29 +232,82 @@ struct ArithToCGGI : public impl::ArithToCGGIBase { ConversionTarget target(*context); target.addLegalDialect(); target.addIllegalDialect(); + target.addLegalOp(); + + target.addDynamicallyLegalOp([&](Operation *op) { + if (auto *defOp = + cast(op).getOperand().getDefiningOp()) { + return isa(defOp); + } + return false; + }); - target.addDynamicallyLegalOp( - [](mlir::arith::ConstantOp op) { - // Allow use of constant if it is used to denote the size of a shift - return (isa(op.getValue().getType())); + target.addDynamicallyLegalOp( + [&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); }); - target.addDynamicallyLegalOp< - memref::AllocOp, memref::DeallocOp, memref::StoreOp, memref::SubViewOp, - memref::CopyOp, tensor::FromElementsOp, tensor::ExtractOp, - affine::AffineStoreOp, affine::AffineLoadOp>([&](Operation *op) { + target.addDynamicallyLegalOp([&](Operation *op) { + // Check if all Store ops are constants, if not store op, accepts + // Check if there is at least one Store op that is a constants + return (llvm::all_of(op->getUses(), + [&](OpOperand &op) { + auto defOp = + dyn_cast(op.getOwner()); + if (defOp) { + return isa( + defOp.getValue().getDefiningOp()); + } + return true; + }) && + llvm::any_of(op->getUses(), + [&](OpOperand &op) { + auto defOp = + dyn_cast(op.getOwner()); + if (defOp) { + return isa( + defOp.getValue().getDefiningOp()); + } + return false; + })) || + // The other case: Memref need to be in LWE format + (typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes())); + }); + + target.addDynamicallyLegalOp([&](Operation *op) { + if (auto *defOp = cast(op).getValue().getDefiningOp()) { + if (isa(defOp)) { + return true; + } + } + return typeConverter.isLegal(op->getOperandTypes()) && typeConverter.isLegal(op->getResultTypes()); }); + // Convert LoadOp if memref comes from an argument + target.addDynamicallyLegalOp([&](Operation *op) { + if (typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes())) { + return true; + } + auto loadOp = dyn_cast(op); + + return loadOp.getMemRef().getDefiningOp() != nullptr; + }); + patterns.add< ConvertConstantOp, ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, - ConvertShRUIOp, ConvertBinOp, - ConvertBinOp, - ConvertBinOp, + ConvertShRUIOp, ConvertArithBinOp, + ConvertArithBinOp, + ConvertArithBinOp, ConvertAny, ConvertAny, - ConvertAny, ConvertAny, - ConvertAny, ConvertAny, + ConvertAny, ConvertAny, + ConvertAny, ConvertAny, ConvertAny, ConvertAny, ConvertAny, ConvertAny >( typeConverter, context); diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp b/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp index 20b0e14d4..3efd0fae2 100644 --- a/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp +++ b/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp @@ -350,7 +350,7 @@ struct ConvertQuartAddI final : OpConversionPattern { SmallVector outputs; for (int i = 0; i < splitLhs.size(); ++i) { - auto lowSum = b.create(splitLhs[i], splitRhs[i]); + auto lowSum = b.create(elemType, splitLhs[i], splitRhs[i]); auto outputLsb = b.create(op.getLoc(), realTy, lowSum); auto outputLsbHigh = b.create(op.getLoc(), elemType, outputLsb); @@ -365,7 +365,8 @@ struct ConvertQuartAddI final : OpConversionPattern { if (i == 0) { outputs.push_back(outputLsbHigh); } else { - auto high = b.create(outputLsbHigh, carries[i - 1]); + auto high = + b.create(elemType, outputLsbHigh, carries[i - 1]); outputs.push_back(high); } } @@ -413,35 +414,35 @@ struct ConvertQuartMulI final : OpConversionPattern { // TODO: Implement the real Karatsuba algorithm for 4x4 multiplication. // First part of Karatsuba algorithm - auto z00 = b.create(splitLhs[0], splitRhs[0]); - auto z02 = b.create(splitLhs[1], splitRhs[1]); - auto z01_p1 = b.create(splitLhs[0], splitLhs[1]); - auto z01_p2 = b.create(splitRhs[0], splitRhs[1]); - auto z01_m = b.create(z01_p1, z01_p2); + auto z00 = b.create(elemTy, splitLhs[0], splitRhs[0]); + auto z02 = b.create(elemTy, splitLhs[1], splitRhs[1]); + auto z01_p1 = b.create(elemTy, splitLhs[0], splitLhs[1]); + auto z01_p2 = b.create(elemTy, splitRhs[0], splitRhs[1]); + auto z01_m = b.create(elemTy, z01_p1, z01_p2); auto z01_s = b.create(z01_m, z00); auto z01 = b.create(z01_s, z02); // Second part I of Karatsuba algorithm - auto z1a0 = b.create(splitLhs[0], splitRhs[2]); - auto z1a2 = b.create(splitLhs[1], splitRhs[3]); - auto z1a1_p1 = b.create(splitLhs[0], splitLhs[1]); - auto z1a1_p2 = b.create(splitRhs[2], splitRhs[3]); - auto z1a1_m = b.create(z1a1_p1, z1a1_p2); + auto z1a0 = b.create(elemTy, splitLhs[0], splitRhs[2]); + auto z1a2 = b.create(elemTy, splitLhs[1], splitRhs[3]); + auto z1a1_p1 = b.create(elemTy, splitLhs[0], splitLhs[1]); + auto z1a1_p2 = b.create(elemTy, splitRhs[2], splitRhs[3]); + auto z1a1_m = b.create(elemTy, z1a1_p1, z1a1_p2); auto z1a1_s = b.create(z1a1_m, z1a0); auto z1a1 = b.create(z1a1_s, z1a2); // Second part II of Karatsuba algorithm - auto z1b0 = b.create(splitLhs[2], splitRhs[0]); - auto z1b2 = b.create(splitLhs[3], splitRhs[1]); - auto z1b1_p1 = b.create(splitLhs[2], splitLhs[3]); - auto z1b1_p2 = b.create(splitRhs[0], splitRhs[1]); - auto z1b1_m = b.create(z1b1_p1, z1b1_p2); - auto z1b1_s = b.create(z1b1_m, z1b0); + auto z1b0 = b.create(elemTy, splitLhs[2], splitRhs[0]); + auto z1b2 = b.create(elemTy, splitLhs[3], splitRhs[1]); + auto z1b1_p1 = b.create(elemTy, splitLhs[2], splitLhs[3]); + auto z1b1_p2 = b.create(elemTy, splitRhs[0], splitRhs[1]); + auto z1b1_m = b.create(elemTy, z1b1_p1, z1b1_p2); + auto z1b1_s = b.create(elemTy, z1b1_m, z1b0); auto z1b1 = b.create(z1b1_s, z1b2); - auto out2Kara = b.create(z1a0, z1b0); - auto out2Carry = b.create(out2Kara, z02); - auto out3Carry = b.create(z1a1, z1b1); + auto out2Kara = b.create(elemTy, z1a0, z1b0); + auto out2Carry = b.create(elemTy, out2Kara, z02); + auto out3Carry = b.create(elemTy, z1a1, z1b1); // Output are now all 16b elements, wants presentation of 4x8b auto output0Lsb = b.create(realTy, z00); @@ -462,9 +463,9 @@ struct ConvertQuartMulI final : OpConversionPattern { auto output3Lsb = b.create(realTy, out3Carry); auto output3LsbHigh = b.create(elemTy, output3Lsb); - auto output1 = b.create(output1LsbHigh, output0Msb); - auto output2 = b.create(output2LsbHigh, output1Msb); - auto output3 = b.create(output3LsbHigh, output2Msb); + auto output1 = b.create(elemTy, output1LsbHigh, output0Msb); + auto output2 = b.create(elemTy, output2LsbHigh, output1Msb); + auto output3 = b.create(elemTy, output3LsbHigh, output2Msb); Value resultVec = constructResultTensor( rewriter, loc, newTy, {output0LsbHigh, output1, output2, output3}); diff --git a/lib/Dialect/CGGI/IR/CGGIOps.td b/lib/Dialect/CGGI/IR/CGGIOps.td index f1ccc04b9..b30859684 100644 --- a/lib/Dialect/CGGI/IR/CGGIOps.td +++ b/lib/Dialect/CGGI/IR/CGGIOps.td @@ -10,6 +10,7 @@ include "lib/Dialect/LWE/IR/LWETypes.td" include "mlir/IR/OpBase.td" include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/BuiltinTypes.td" include "mlir/IR/CommonAttrConstraints.td" include "mlir/IR/CommonTypeConstraints.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -40,15 +41,25 @@ class CGGI_BinaryOp let assemblyFormat = "operands attr-dict `:` qualified(type($output))" ; } +class CGGI_ScalarBinaryOp + : CGGI_Op { + let arguments = (ins LWECiphertext:$lhs, AnyTypeOf<[Builtin_Integer, LWECiphertext]>:$rhs); + let results = (outs LWECiphertext:$output); +} + def CGGI_AndOp : CGGI_BinaryOp<"and"> { let summary = "Logical AND of two ciphertexts."; } def CGGI_NandOp : CGGI_BinaryOp<"nand"> { let summary = "Logical NAND of two ciphertexts."; } def CGGI_NorOp : CGGI_BinaryOp<"nor"> { let summary = "Logical NOR of two ciphertexts."; } def CGGI_OrOp : CGGI_BinaryOp<"or"> { let summary = "Logical OR of two ciphertexts."; } def CGGI_XorOp : CGGI_BinaryOp<"xor"> { let summary = "Logical XOR of two ciphertexts."; } def CGGI_XNorOp : CGGI_BinaryOp<"xnor"> { let summary = "Logical XNOR of two ciphertexts."; } -def CGGI_AddOp : CGGI_BinaryOp<"add"> { let summary = "Arithmetic addition of two ciphertexts."; } -def CGGI_MulOp : CGGI_BinaryOp<"mul"> { - let summary = "Arithmetic multiplication of two ciphertexts."; + +def CGGI_AddOp : CGGI_ScalarBinaryOp<"add"> { let summary = "Arithmetic addition of two ciphertexts. One of the two ciphertext is allowed to be a scalar, this will result in the scalar addition to a ciphertext."; } +def CGGI_MulOp : CGGI_ScalarBinaryOp<"mul"> { + let summary = "Arithmetic multiplication of two ciphertexts. One of the two ciphertext is allowed to be a scalar, this will result in the scalar multiplication to a ciphertext."; let description = [{ While CGGI does not have a native multiplication operation, some backend targets provide a multiplication @@ -59,6 +70,18 @@ def CGGI_MulOp : CGGI_BinaryOp<"mul"> { }]; } +def CGGI_SubOp : CGGI_Op<"sub", [ + Pure, + SameOperandsAndResultType, + ElementwiseMappable, + Scalarizable +]> { + let arguments = (ins LWECiphertext:$lhs, LWECiphertext:$rhs); + let results = (outs LWECiphertext:$output); + let summary = "Subtraction of two ciphertexts."; +} + + def CGGI_NotOp : CGGI_Op<"not", [ Pure, Involution, @@ -282,17 +305,6 @@ def CGGI_MultiLutLinCombOp : CGGI_Op<"multi_lut_lincomb", [ let hasVerifier = 1; } -def CGGI_SubOp : CGGI_Op<"sub", [ - Pure, - SameOperandsAndResultType, - ElementwiseMappable, - Scalarizable -]> { - let arguments = (ins LWECiphertextLike:$lhs, LWECiphertextLike:$rhs); - let results = (outs LWECiphertextLike:$output); - let assemblyFormat = "operands attr-dict `:` qualified(type($output))"; - let summary = "Subtraction of two ciphertexts."; -} def CGGI_ScalarShiftRightOp : CGGI_Op<"sshr", [ diff --git a/lib/Utils/ConversionUtils.h b/lib/Utils/ConversionUtils.h index 90c9bcec0..dd5890d7e 100644 --- a/lib/Utils/ConversionUtils.h +++ b/lib/Utils/ConversionUtils.h @@ -16,6 +16,7 @@ #include "lib/Dialect/TfheRust/IR/TfheRustTypes.h" #include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project #include "llvm/include/llvm/Support/Casting.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project @@ -93,8 +94,8 @@ struct ConvertBinOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { ImplicitLocOpBuilder b(op.getLoc(), rewriter); - auto result = - b.create(adaptor.getLhs(), adaptor.getRhs()); + auto result = b.create( + adaptor.getLhs().getType(), adaptor.getLhs(), adaptor.getRhs()); rewriter.replaceOp(op, result); return success(); } diff --git a/tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir b/tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir index 01fb3fe10..8a6df25c9 100644 --- a/tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir +++ b/tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir @@ -3,34 +3,28 @@ // CHECK-LABEL: @test_lower_add // CHECK-SAME: (%[[LHS:.*]]: !lwe.lwe_ciphertext, %[[RHS:.*]]: !lwe.lwe_ciphertext) -> [[T:.*]] { func.func @test_lower_add(%lhs : i32, %rhs : i32) -> i32 { - // CHECK: %[[ADD:.*]] = cggi.add %[[LHS]], %[[RHS]] : [[T]] + // CHECK: %[[ADD:.*]] = cggi.add %[[LHS]], %[[RHS]] : ([[T]], [[T]]) -> [[T]] // CHECK: return %[[ADD:.*]] : [[T]] %res = arith.addi %lhs, %rhs : i32 return %res : i32 } -// CHECK-LABEL: @test_lower_add_vec -// CHECK-SAME: (%[[LHS:.*]]: tensor<4x!lwe.lwe_ciphertext>, %[[RHS:.*]]: tensor<4x!lwe.lwe_ciphertext>) -> [[T:.*]] { -func.func @test_lower_add_vec(%lhs : tensor<4xi32>, %rhs : tensor<4xi32>) -> tensor<4xi32> { - // CHECK: %[[ADD:.*]] = cggi.add %[[LHS]], %[[RHS]] : [[T]] +// CHECK-LABEL: @test_lower_cte_add +// CHECK-SAME: (%[[LHS:.*]]: !lwe.lwe_ciphertext) -> [[T:.*]] { +func.func @test_lower_cte_add(%in : i32) -> i32 { + // CHECK: %[[CTE:.*]] = arith.constant 7 : i32 + // CHECK: %[[ADD:.*]] = cggi.add %[[LHS]], %[[CTE]] : ([[T]], i32) -> [[T]] // CHECK: return %[[ADD:.*]] : [[T]] - %res = arith.addi %lhs, %rhs : tensor<4xi32> - return %res : tensor<4xi32> + %c7 = arith.constant 7 : i32 + %res = arith.addi %in, %c7 : i32 + return %res : i32 } -// CHECK-LABEL: @test_lower_sub_vec -// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] { -func.func @test_lower_sub_vec(%lhs : tensor<4xi32>, %rhs : tensor<4xi32>) -> tensor<4xi32> { - // CHECK: %[[ADD:.*]] = cggi.sub %[[LHS]], %[[RHS]] : [[T]] - // CHECK: return %[[ADD:.*]] : [[T]] - %res = arith.subi %lhs, %rhs : tensor<4xi32> - return %res : tensor<4xi32> -} // CHECK-LABEL: @test_lower_sub // CHECK-SAME: (%[[LHS:.*]]: !lwe.lwe_ciphertext, %[[RHS:.*]]: !lwe.lwe_ciphertext) -> [[T:.*]] { func.func @test_lower_sub(%lhs : i16, %rhs : i16) -> i16 { - // CHECK: %[[ADD:.*]] = cggi.sub %[[LHS]], %[[RHS]] : [[T]] + // CHECK: %[[ADD:.*]] = cggi.sub %[[LHS]], %[[RHS]] : ([[T]], [[T]]) -> [[T]] // CHECK: return %[[ADD:.*]] : [[T]] %res = arith.subi %lhs, %rhs : i16 return %res : i16 @@ -39,22 +33,12 @@ func.func @test_lower_sub(%lhs : i16, %rhs : i16) -> i16 { // CHECK-LABEL: @test_lower_mul // CHECK-SAME: (%[[LHS:.*]]: !lwe.lwe_ciphertext, %[[RHS:.*]]: !lwe.lwe_ciphertext) -> [[T:.*]] { func.func @test_lower_mul(%lhs : i8, %rhs : i8) -> i8 { - // CHECK: %[[ADD:.*]] = cggi.mul %[[LHS]], %[[RHS]] : [[T]] + // CHECK: %[[ADD:.*]] = cggi.mul %[[LHS]], %[[RHS]] : ([[T]], [[T]]) -> [[T]] // CHECK: return %[[ADD:.*]] : [[T]] %res = arith.muli %lhs, %rhs : i8 return %res : i8 } -// CHECK-LABEL: @test_lower_mul_vec -// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] { -func.func @test_lower_mul_vec(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> { - // CHECK: %[[ADD:.*]] = cggi.mul %[[LHS]], %[[RHS]] : [[T]] - // CHECK: return %[[ADD:.*]] : [[T]] - %res = arith.muli %lhs, %rhs : tensor<4xi8> - return %res : tensor<4xi8> -} - - // CHECK-LABEL: @test_affine // CHECK-SAME: (%[[ARG:.*]]: memref<1x1x!lwe.lwe_ciphertext>) -> [[T:.*]] { func.func @test_affine(%arg0: memref<1x1xi32>) -> memref<1x1xi32> { @@ -62,7 +46,6 @@ func.func @test_affine(%arg0: memref<1x1xi32>) -> memref<1x1xi32> { %c429_i32 = arith.constant 429 : i32 %c33_i8 = arith.constant 33 : i32 %0 = affine.load %arg0[0, 0] : memref<1x1xi32> - %c0 = arith.constant 0 : index %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1xi32> %25 = arith.muli %0, %c33_i8 : i32 %26 = arith.addi %c429_i32, %25 : i32 diff --git a/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/arith.mlir b/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/arith.mlir index 13c8cbe78..5c78448ff 100644 --- a/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/arith.mlir +++ b/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/arith.mlir @@ -12,8 +12,8 @@ func.func @test_affine(%arg0: memref<1x1x!ct_ty>) -> memref<1x1x!ct_ty> { %2 = affine.load %arg0[0, 0] : memref<1x1x!ct_ty> %c0 = arith.constant 0 : index %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x!ct_ty> - %3 = cggi.mul %2, %1 : !ct_ty - %4 = cggi.add %3, %0 : !ct_ty + %3 = cggi.mul %2, %1 : (!ct_ty, !ct_ty) -> !ct_ty + %4 = cggi.add %3, %0 : (!ct_ty, !ct_ty) -> !ct_ty %5 = cggi.sshr %4 {shiftAmount = 2 : index} : (!ct_ty) -> !ct_ty affine.store %5, %alloc[0, 0] : memref<1x1x!ct_ty> return %alloc : memref<1x1x!ct_ty>