From fd862796c3f7070d54180eef33c74e5516372274 Mon Sep 17 00:00:00 2001 From: WoutLegiest Date: Wed, 25 Dec 2024 01:30:54 +0000 Subject: [PATCH] Working Quart to Tfhe rs + Change tfhe-rs and cggi shift ops Co-authored-by: asraa --- docs/content/en/docs/pipelines.md | 3 +- .../Conversions/ArithToCGGI/ArithToCGGI.cpp | 23 +- .../ArithToCGGIQuart/ArithToCGGIQuart.cpp | 213 ++++++++++++++---- .../CGGIToTfheRust/CGGIToTfheRust.cpp | 42 ++-- lib/Dialect/CGGI/IR/CGGIOps.td | 8 +- lib/Dialect/TfheRust/IR/TfheRustOps.td | 12 +- lib/Dialect/TfheRust/IR/TfheRustTypes.td | 1 + lib/Target/TfheRust/TfheRustEmitter.cpp | 20 +- lib/Target/TfheRustHL/TfheRustHLEmitter.cpp | 8 +- .../ArithToCGGI/arith-to-cggi.mlir | 2 + .../ArithToCGGIQuart/quarter_wide.mlir | 12 +- .../Conversions/cggi_to_tfhe_rust/arith.mlir | 3 +- .../cggi_to_tfhe_rust/binary_gates.mlir | 6 +- .../TfheRust/Emitters/emit_levelled_ops.mlir | 5 +- .../TfheRust/Emitters/emit_tfhe_rust.mlir | 6 +- tests/Dialect/TfheRust/IR/ops.mlir | 3 +- .../TfheRust/Transforms/canonicalize.mlir | 10 +- tests/Examples/tfhe_rust/test_simple_lut.mlir | 3 +- .../forward_add_one.mlir | 32 +-- .../loop_unroll/full_loop_unroll.mlir | 3 +- 20 files changed, 259 insertions(+), 156 deletions(-) diff --git a/docs/content/en/docs/pipelines.md b/docs/content/en/docs/pipelines.md index af256c738b..ab7850f1c1 100644 --- a/docs/content/en/docs/pipelines.md +++ b/docs/content/en/docs/pipelines.md @@ -97,8 +97,7 @@ Example input: func.func @test_apply_lookup_table(%sks : !sks, %lut: !lut, %input : !eui3) -> !eui3 { %v1 = tfhe_rust.apply_lookup_table %sks, %input, %lut : (!sks, !eui3, !lut) -> !eui3 %v2 = tfhe_rust.add %sks, %input, %v1 : (!sks, !eui3, !eui3) -> !eui3 - %c1 = arith.constant 1 : i8 - %v3 = tfhe_rust.scalar_left_shift %sks, %v2, %c1 : (!sks, !eui3, i8) -> !eui3 + %v3 = tfhe_rust.scalar_left_shift %sks, %v2 {shiftAmount = 1 : index} : (!sks, !eui3) -> !eui3 %v4 = tfhe_rust.apply_lookup_table %sks, %v3, %lut : (!sks, !eui3, !lut) -> !eui3 return %v4 : !eui3 } diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp index 8f728ee0e8..d1509799ca 100644 --- a/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp +++ b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp @@ -139,12 +139,10 @@ struct ConvertShRUIOp : public OpConversionPattern { .getSExtValue(); auto inputValue = - mlir::IntegerAttr::get(rewriter.getI8Type(), (int8_t)shiftAmount); - auto cteOp = rewriter.create( - op.getLoc(), rewriter.getI8Type(), inputValue); + mlir::IntegerAttr::get(rewriter.getIndexType(), (int8_t)shiftAmount); - auto shiftOp = - b.create(outputType, adaptor.getLhs(), cteOp); + auto shiftOp = b.create( + outputType, adaptor.getLhs(), inputValue); rewriter.replaceOp(op, shiftOp); return success(); @@ -157,14 +155,12 @@ struct ConvertShRUIOp : public OpConversionPattern { auto shiftAmount = cast(cteShiftSizeOp.getValue()).getValue().getSExtValue(); - auto inputValue = mlir::IntegerAttr::get(rewriter.getI8Type(), shiftAmount); - auto cteOp = rewriter.create( - op.getLoc(), rewriter.getI8Type(), inputValue); + auto inputValue = + mlir::IntegerAttr::get(rewriter.getIndexType(), shiftAmount); - auto shiftOp = - b.create(outputType, adaptor.getLhs(), cteOp); + auto shiftOp = b.create( + outputType, adaptor.getLhs(), inputValue); rewriter.replaceOp(op, shiftOp); - rewriter.replaceOp(op.getLhs().getDefiningOp(), cteOp); return success(); } @@ -184,10 +180,7 @@ struct ArithToCGGI : public impl::ArithToCGGIBase { target.addDynamicallyLegalOp( [](mlir::arith::ConstantOp op) { // Allow use of constant if it is used to denote the size of a shift - bool usedByShift = llvm::any_of(op->getUsers(), [&](Operation *user) { - return isa(user); - }); - return (isa(op.getValue().getType()) || (usedByShift)); + return (isa(op.getValue().getType())); }); target.addDynamicallyLegalOp< diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp b/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp index 343bde13b4..fe8b4035a5 100644 --- a/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp +++ b/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp @@ -1,9 +1,5 @@ #include "lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.h" -#include - -#include - #include "lib/Dialect/CGGI/IR/CGGIDialect.h" #include "lib/Dialect/CGGI/IR/CGGIOps.h" #include "lib/Dialect/LWE/IR/LWEOps.h" @@ -15,7 +11,9 @@ #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project namespace mlir::heir::arith { @@ -94,7 +92,7 @@ class ArithToCGGIQuartTypeConverter : public TypeConverter { }; static Value createTrivialOpMaxWidth(ImplicitLocOpBuilder b, int value) { - auto maxWideIntType = IntegerType::get(b.getContext(), maxIntWidth >> 1); + auto maxWideIntType = IntegerType::get(b.getContext(), maxIntWidth); auto intAttr = b.getIntegerAttr(maxWideIntType, value); auto encoding = @@ -153,19 +151,8 @@ static SmallVector extractLastDimHalves( static Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, int64_t value) { - unsigned elementBitWidth = 0; - if (auto lweTy = dyn_cast(type)) - elementBitWidth = - cast(lweTy.getEncoding()) - .getCleartextBitwidth(); - else - elementBitWidth = maxIntWidth; - - auto apValue = APInt(elementBitWidth, value); - - auto maxWideIntType = - IntegerType::get(builder.getContext(), maxIntWidth >> 1); - auto intAttr = builder.getIntegerAttr(maxWideIntType, value); + auto intAttr = builder.getIntegerAttr( + IntegerType::get(builder.getContext(), maxIntWidth), value); return builder.create(loc, type, intAttr); } @@ -249,6 +236,38 @@ struct ConvertQuartConstantOp } }; +struct ConvertQuartTruncIOp + : public OpConversionPattern { + ConvertQuartTruncIOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::arith::TruncIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto newResultTy = getTypeConverter()->convertType( + op.getResult().getType()); + auto newInTy = cast(adaptor.getIn().getType()); + + SmallVector offsets(newResultTy.getShape().size(), + rewriter.getIndexAttr(0)); + offsets.back() = rewriter.getIndexAttr(newInTy.getShape().back() - + newResultTy.getShape().back()); + SmallVector sizes(newResultTy.getShape().size()); + sizes.back() = rewriter.getIndexAttr(1); + SmallVector strides(newResultTy.getShape().size(), + rewriter.getIndexAttr(1)); + + auto resOp = rewriter.replaceOpWithNewOp( + adaptor.getIn(), offsets, sizes, strides); + + return success(); + } +}; + template struct ConvertQuartExt final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -274,23 +293,21 @@ struct ConvertQuartExt final : OpConversionPattern { auto resultChunks = newResultTy.getShape().back(); auto inChunks = newInTy.getShape().back(); - if (resultChunks > inChunks) { - auto paddingFactor = resultChunks - inChunks; + // Through definition of ExtOp, paddingFactor is always positive + auto paddingFactor = resultChunks - inChunks; - SmallVector low, high; - low.push_back(rewriter.getIndexAttr(0)); - high.push_back(rewriter.getIndexAttr(paddingFactor)); + SmallVector low, high; + low.push_back(rewriter.getIndexAttr(0)); + high.push_back(rewriter.getIndexAttr(paddingFactor)); - auto padValue = createTrivialOpMaxWidth(b, 0); + auto padValue = createTrivialOpMaxWidth(b, 0); - auto resultVec = b.create(newResultTy, adaptor.getIn(), - low, high, padValue, - /*nofold=*/true); + auto resultVec = b.create(newResultTy, adaptor.getIn(), low, + high, padValue, + /*nofold=*/true); - rewriter.replaceOp(op, resultVec); - return success(); - } - return failure(); + rewriter.replaceOp(op, resultVec); + return success(); } }; @@ -318,15 +335,14 @@ struct ConvertQuartAddI final : OpConversionPattern { // Actual type of the underlying elements; we use half the width. // Create Constant - auto intAttr = IntegerAttr::get(rewriter.getI8Type(), maxIntWidth >> 1); + auto shiftAttr = + IntegerAttr::get(rewriter.getIndexType(), maxIntWidth >> 1); auto elemType = convertArithToCGGIType( IntegerType::get(op->getContext(), maxIntWidth), op->getContext()); auto realTy = convertArithToCGGIType( IntegerType::get(op->getContext(), maxIntWidth >> 1), op->getContext()); - auto constantOp = b.create(intAttr); - SmallVector carries; SmallVector outputs; @@ -338,7 +354,8 @@ struct ConvertQuartAddI final : OpConversionPattern { // Now all the outputs are 16b elements, wants presentation of 4x8b if (i != splitLhs.size() - 1) { - auto carry = b.create(elemType, lowSum, constantOp); + auto carry = + b.create(elemType, lowSum, shiftAttr); carries.push_back(carry); } @@ -356,6 +373,103 @@ struct ConvertQuartAddI final : OpConversionPattern { } }; +// Implemented using the Karatsuba algorithm +// https://en.wikipedia.org/wiki/Karatsuba_algorithm#Algorithm +struct ConvertQuartMulI final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::arith::MulIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + + auto newTy = + getTypeConverter()->convertType(op.getType()); + if (!newTy) + return rewriter.notifyMatchFailure( + loc, llvm::formatv("unsupported type: {0}", op.getType())); + if (newTy.getShape().back() != 4) + return rewriter.notifyMatchFailure( + loc, llvm::formatv("Mul only support 4 split elements. Shape: {0}", + newTy)); + + auto elemTy = convertArithToCGGIType( + IntegerType::get(op->getContext(), maxIntWidth), op->getContext()); + auto realTy = convertArithToCGGIType( + IntegerType::get(op->getContext(), maxIntWidth >> 1), op->getContext()); + + // Create Constant + auto shiftAttr = + rewriter.getIntegerAttr(b.getIndexType(), maxIntWidth >> 1); + + SmallVector splitLhs = + extractLastDimHalves(rewriter, loc, adaptor.getLhs()); + SmallVector splitRhs = + extractLastDimHalves(rewriter, loc, adaptor.getRhs()); + + // TODO: Implement the real Karatsuba algorithm for 4x4 multiplication. + // First part of Karatsuba algorithm + auto z00 = b.create(splitLhs[0], splitRhs[0]); + auto z02 = b.create(splitLhs[1], splitRhs[1]); + auto z01_p1 = b.create(splitLhs[0], splitLhs[1]); + auto z01_p2 = b.create(splitRhs[0], splitRhs[1]); + auto z01_m = b.create(z01_p1, z01_p2); + auto z01_s = b.create(z01_m, z00); + auto z01 = b.create(z01_s, z02); + + // Second part I of Karatsuba algorithm + auto z1a0 = b.create(splitLhs[0], splitRhs[2]); + auto z1a2 = b.create(splitLhs[1], splitRhs[3]); + auto z1a1_p1 = b.create(splitLhs[0], splitLhs[1]); + auto z1a1_p2 = b.create(splitRhs[2], splitRhs[3]); + auto z1a1_m = b.create(z1a1_p1, z1a1_p2); + auto z1a1_s = b.create(z1a1_m, z1a0); + auto z1a1 = b.create(z1a1_s, z1a2); + + // Second part II of Karatsuba algorithm + auto z1b0 = b.create(splitLhs[2], splitRhs[0]); + auto z1b2 = b.create(splitLhs[3], splitRhs[1]); + auto z1b1_p1 = b.create(splitLhs[2], splitLhs[3]); + auto z1b1_p2 = b.create(splitRhs[0], splitRhs[1]); + auto z1b1_m = b.create(z1b1_p1, z1b1_p2); + auto z1b1_s = b.create(z1b1_m, z1b0); + auto z1b1 = b.create(z1b1_s, z1b2); + + auto out2Kara = b.create(z1a0, z1b0); + auto out2Carry = b.create(out2Kara, z02); + auto out3Carry = b.create(z1a1, z1b1); + + // Output are now all 16b elements, wants presentation of 4x8b + auto output0Lsb = b.create(realTy, z00); + auto output0LsbHigh = b.create(elemTy, output0Lsb); + auto output0Msb = + b.create(elemTy, z00, shiftAttr); + + auto output1Lsb = b.create(realTy, z01); + auto output1LsbHigh = b.create(elemTy, output1Lsb); + auto output1Msb = + b.create(elemTy, z01, shiftAttr); + + auto output2Lsb = b.create(realTy, out2Carry); + auto output2LsbHigh = b.create(elemTy, output2Lsb); + auto output2Msb = + b.create(elemTy, out2Carry, shiftAttr); + + auto output3Lsb = b.create(realTy, out3Carry); + auto output3LsbHigh = b.create(elemTy, output3Lsb); + + auto output1 = b.create(output1LsbHigh, output0Msb); + auto output2 = b.create(output2LsbHigh, output1Msb); + auto output3 = b.create(output3LsbHigh, output2Msb); + + Value resultVec = constructResultTensor( + rewriter, loc, newTy, {output0LsbHigh, output1, output2, output3}); + rewriter.replaceOp(op, resultVec); + return success(); + } +}; + struct ArithToCGGIQuart : public impl::ArithToCGGIQuartBase { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -386,28 +500,29 @@ struct ArithToCGGIQuart : public impl::ArithToCGGIQuartBase { target.addDynamicallyLegalOp( [](mlir::arith::ConstantOp op) { - // Allow use of constant if it is used to denote the size of a shift - bool usedByShift = llvm::any_of(op->getUsers(), [&](Operation *user) { - return isa(user); - }); - return (isa(op.getValue().getType()) || (usedByShift)); + return isa(op.getValue().getType()); }); - patterns.add< - ConvertQuartConstantOp, ConvertQuartExt, - ConvertQuartExt, ConvertQuartAddI, - ConvertAny, ConvertAny, - ConvertAny, ConvertAny, - ConvertAny, ConvertAny, - ConvertAny, ConvertAny, - ConvertAny, ConvertAny>( - typeConverter, context); + patterns + .add, + ConvertQuartExt, ConvertQuartAddI, + ConvertQuartMulI, ConvertAny, + ConvertAny, ConvertAny, + ConvertAny, ConvertAny, + ConvertAny, ConvertAny, + ConvertAny, ConvertAny, + ConvertAny>(typeConverter, context); addStructuralConversionPatterns(typeConverter, patterns, target); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { return signalPassFailure(); } + + // Remove the unnecessary tensor ops between each converted arith operation. + OpPassManager pipeline("builtin.module"); + pipeline.addPass(createCSEPass()); + (void)runPipeline(pipeline, getOperation()); } }; diff --git a/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp b/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp index a43d3a2376..523e9d09eb 100644 --- a/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp +++ b/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp @@ -215,13 +215,9 @@ struct ConvertLut3Op : public OpConversionPattern { serverKey, adaptor.getLookupTable()); // Construct input = c << 2 + b << 1 + a auto shiftedC = b.create( - serverKey, adaptor.getC(), - b.create(b.getI8Type(), b.getI8IntegerAttr(2)) - .getResult()); + serverKey, adaptor.getC(), b.getIndexAttr(2)); auto shiftedB = b.create( - serverKey, adaptor.getB(), - b.create(b.getI8Type(), b.getI8IntegerAttr(1)) - .getResult()); + serverKey, adaptor.getB(), b.getIndexAttr(1)); auto summedBC = b.create(serverKey, shiftedC, shiftedB); auto summedABC = b.create(serverKey, summedBC, adaptor.getA()); @@ -251,9 +247,7 @@ struct ConvertLut2Op : public OpConversionPattern { serverKey, adaptor.getLookupTable()); // Construct input = b << 1 + a auto shiftedB = b.create( - serverKey, adaptor.getB(), - b.create(b.getI8Type(), b.getI8IntegerAttr(1)) - .getResult()); + serverKey, adaptor.getB(), b.getIndexAttr(1)); auto summedBA = b.create(serverKey, shiftedB, adaptor.getA()); @@ -277,10 +271,8 @@ static LogicalResult replaceBinaryGate(Operation *op, Value lhs, Value rhs, auto lutOp = b.create(serverKey, lookupTable); // Construct input = rhs << 1 + lhs - auto shiftedRhs = b.create( - serverKey, rhs, - b.create(b.getI8Type(), b.getI8IntegerAttr(1)) - .getResult()); + auto shiftedRhs = + b.create(serverKey, rhs, b.getIndexAttr(1)); auto input = b.create(serverKey, shiftedRhs, lhs); rewriter.replaceOp( op, b.create(serverKey, input, lutOp)); @@ -348,14 +340,14 @@ struct ConvertXorOp : public OpConversionPattern { } }; -struct ConvertShROp : public OpConversionPattern { +struct ConvertShROp : public OpConversionPattern { ConvertShROp(mlir::MLIRContext *context) - : OpConversionPattern(context) {} + : OpConversionPattern(context) {} using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - cggi::ShiftRightOp op, OpAdaptor adaptor, + cggi::ScalarShiftRightOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { ImplicitLocOpBuilder b(op.getLoc(), rewriter); FailureOr result = getContextualServerKey(op); @@ -479,10 +471,15 @@ struct ConvertTrivialOp : public OpConversionPattern { auto constantWidth = op.getValue().getValue().getBitWidth(); auto cteOp = rewriter.create( - op.getLoc(), rewriter.getIntegerType(constantWidth), inputValue); + op.getLoc(), op.getValue().getType(), inputValue); auto outputType = encrytpedUIntTypeFromWidth(getContext(), constantWidth); + if (auto rankedTensorTy = dyn_cast(op.getResult().getType())) { + auto shape = rankedTensorTy.getShape(); + outputType = RankedTensorType::get(shape, outputType); + } + auto createTrivialOp = rewriter.create( op.getLoc(), outputType, serverKey, cteOp); rewriter.replaceOp(op, createTrivialOp); @@ -538,11 +535,11 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase { target.addDynamicallyLegalOp< memref::AllocOp, memref::DeallocOp, memref::StoreOp, memref::LoadOp, memref::SubViewOp, memref::CopyOp, affine::AffineLoadOp, - affine::AffineStoreOp, tensor::FromElementsOp, tensor::ExtractOp>( - [&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()) && - typeConverter.isLegal(op->getResultTypes()); - }); + tensor::InsertOp, tensor::InsertSliceOp, affine::AffineStoreOp, + tensor::FromElementsOp, tensor::ExtractOp>([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); // FIXME: still need to update callers to insert the new server key arg, if // needed and possible. @@ -556,6 +553,7 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase { ConvertAny, ConvertAny, ConvertAny, ConvertAny, ConvertAny, ConvertAny, + ConvertAny, ConvertAny, ConvertAny, ConvertAny, ConvertAny, ConvertAny>( typeConverter, context); diff --git a/lib/Dialect/CGGI/IR/CGGIOps.td b/lib/Dialect/CGGI/IR/CGGIOps.td index c9c7e38d84..f1ccc04b9c 100644 --- a/lib/Dialect/CGGI/IR/CGGIOps.td +++ b/lib/Dialect/CGGI/IR/CGGIOps.td @@ -295,18 +295,18 @@ def CGGI_SubOp : CGGI_Op<"sub", [ } -def CGGI_ShiftRightOp : CGGI_Op<"shr", [ +def CGGI_ScalarShiftRightOp : CGGI_Op<"sshr", [ Pure, ]> { - let arguments = (ins LWECiphertextLike:$lhs, AnyI8:$shiftAmount); + let arguments = (ins LWECiphertextLike:$lhs, IndexAttr:$shiftAmount); let results = (outs LWECiphertextLike:$output); let summary = "Arithmetic shift to the right of a ciphertext by an integer. Note this operations to mirror the TFHE-rs implmementation."; } -def CGGI_ShiftLeftOp : CGGI_Op<"shl", [ +def CGGI_ScalarShiftLeftOp : CGGI_Op<"sshl", [ Pure ]> { - let arguments = (ins LWECiphertextLike:$lhs, AnyI8:$shiftAmount); + let arguments = (ins LWECiphertextLike:$lhs, IndexAttr:$shiftAmount); let results = (outs LWECiphertextLike:$output); let summary = "Arithmetic shift to left of a ciphertext by an integer. Note this operations to mirror the TFHE-rs implmementation."; } diff --git a/lib/Dialect/TfheRust/IR/TfheRustOps.td b/lib/Dialect/TfheRust/IR/TfheRustOps.td index 43c2714626..9c28875bf3 100644 --- a/lib/Dialect/TfheRust/IR/TfheRustOps.td +++ b/lib/Dialect/TfheRust/IR/TfheRustOps.td @@ -24,15 +24,15 @@ class TfheRust_BinaryOp ]> { let arguments = (ins TfheRust_ServerKey:$serverKey, - TfheRust_CiphertextType:$lhs, - TfheRust_CiphertextType:$rhs + TfheRust_CiphertextLikeType:$lhs, + TfheRust_CiphertextLikeType:$rhs ); - let results = (outs TfheRust_CiphertextType:$output); + let results = (outs TfheRust_CiphertextLikeType:$output); } def TfheRust_CreateTrivialOp : TfheRust_Op<"create_trivial", [Pure]> { let arguments = (ins TfheRust_ServerKey:$serverKey, AnyInteger:$value); - let results = (outs TfheRust_CiphertextType:$output); + let results = (outs TfheRust_CiphertextLikeType:$output); let hasCanonicalizer = 1; } @@ -49,7 +49,7 @@ def TfheRust_ScalarLeftShiftOp : TfheRust_Op<"scalar_left_shift", [ let arguments = (ins TfheRust_ServerKey:$serverKey, TfheRust_CiphertextType:$ciphertext, - AnyI8:$shiftAmount + IndexAttr:$shiftAmount ); let results = (outs TfheRust_CiphertextType:$output); } @@ -61,7 +61,7 @@ def TfheRust_ScalarRightShiftOp : TfheRust_Op<"scalar_right_shift", [ let arguments = (ins TfheRust_ServerKey:$serverKey, TfheRust_CiphertextType:$ciphertext, - AnyI8:$shiftAmount + IndexAttr:$shiftAmount ); let results = (outs TfheRust_CiphertextType:$output); } diff --git a/lib/Dialect/TfheRust/IR/TfheRustTypes.td b/lib/Dialect/TfheRust/IR/TfheRustTypes.td index 6ab2f6a6b6..53bec97dd0 100644 --- a/lib/Dialect/TfheRust/IR/TfheRustTypes.td +++ b/lib/Dialect/TfheRust/IR/TfheRustTypes.td @@ -65,6 +65,7 @@ def TfheRust_CiphertextType : TfheRust_EncryptedInt256, ]>; +def TfheRust_CiphertextLikeType : TypeOrContainer; def TfheRust_ServerKey : TfheRust_Type<"ServerKey", "server_key", [PassByReference]> { let summary = "The short int server key required to perform homomorphic operations."; diff --git a/lib/Target/TfheRust/TfheRustEmitter.cpp b/lib/Target/TfheRust/TfheRustEmitter.cpp index 6fd1495f06..bb346bb1e8 100644 --- a/lib/Target/TfheRust/TfheRustEmitter.cpp +++ b/lib/Target/TfheRust/TfheRustEmitter.cpp @@ -492,11 +492,7 @@ std::string TfheRustEmitter::operationType(Operation *op) { "\")"; }) .Case([&](ScalarLeftShiftOp op) { - auto constantShift = - cast(op.getShiftAmount().getDefiningOp()); - return "LSH(" + - std::to_string( - cast(constantShift.getValue()).getInt()) + + return "LSH(" + std::to_string(op.getShiftAmount().getSExtValue()) + ")"; }) .Case([&](Operation *) { return "ADD"; }); @@ -518,9 +514,17 @@ LogicalResult TfheRustEmitter::printOperation(affine::AffineForOp forOp) { } LogicalResult TfheRustEmitter::printOperation(ScalarLeftShiftOp op) { - return printSksMethod(op.getResult(), op.getServerKey(), - {op.getCiphertext(), op.getShiftAmount()}, - "scalar_left_shift", {"", "u8"}); + emitAssignPrefix(op.getResult()); + os << variableNames->getNameForValue(op.getServerKey()) + << ".scalar_left_shift("; + + auto valueStr = variableNames->getNameForValue(op.getCiphertext()); + std::string prefix = + op.getCiphertext().getType().hasTrait() ? "&" : ""; + auto cipherString = prefix + valueStr; + + os << cipherString << ", " << op.getShiftAmount() << " as u8);\n"; + return success(); } LogicalResult TfheRustEmitter::printOperation(CreateTrivialOp op) { diff --git a/lib/Target/TfheRustHL/TfheRustHLEmitter.cpp b/lib/Target/TfheRustHL/TfheRustHLEmitter.cpp index 9c9ea73539..de8e3d67b7 100644 --- a/lib/Target/TfheRustHL/TfheRustHLEmitter.cpp +++ b/lib/Target/TfheRustHL/TfheRustHLEmitter.cpp @@ -561,8 +561,12 @@ LogicalResult TfheRustHLEmitter::printOperation(SubOp op) { } LogicalResult TfheRustHLEmitter::printOperation(ScalarRightShiftOp op) { - return printBinaryOp(op.getResult(), op.getCiphertext(), op.getShiftAmount(), - ">>"); + emitAssignPrefix(op.getResult()); + + os << checkOrigin(op.getCiphertext()) + << variableNames->getNameForValue(op.getCiphertext()) << " >> " + << op.getShiftAmount() << "u8;\n"; + return success(); } LogicalResult TfheRustHLEmitter::printOperation(CastOp op) { diff --git a/tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir b/tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir index dfc88b26ae..01fb3fe105 100644 --- a/tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir +++ b/tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir @@ -66,6 +66,8 @@ func.func @test_affine(%arg0: memref<1x1xi32>) -> memref<1x1xi32> { %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1xi32> %25 = arith.muli %0, %c33_i8 : i32 %26 = arith.addi %c429_i32, %25 : i32 + %c2 = arith.constant 2 : i32 + %27 = arith.shrui %26, %c2 : i32 affine.store %26, %alloc[0, 0] : memref<1x1xi32> return %alloc : memref<1x1xi32> } diff --git a/tests/Dialect/Arith/Conversions/ArithToCGGIQuart/quarter_wide.mlir b/tests/Dialect/Arith/Conversions/ArithToCGGIQuart/quarter_wide.mlir index 363f50da95..f72e9243af 100644 --- a/tests/Dialect/Arith/Conversions/ArithToCGGIQuart/quarter_wide.mlir +++ b/tests/Dialect/Arith/Conversions/ArithToCGGIQuart/quarter_wide.mlir @@ -1,10 +1,10 @@ // RUN: heir-opt --arith-to-cggi-quart %s | FileCheck %s // CHECK: return %[[RET:.*]] tensor<4x!lwe.lwe_ciphertext> -func.func @test_simple_split2(%arg0: i32, %arg1: i16) -> i32 { - %2 = arith.constant 31 : i16 - %5 = arith.addi %arg1, %2 : i16 - %6 = arith.extui %5 : i16 to i32 - %7 = arith.addi %arg0, %6 : i32 - return %6 : i32 +func.func @test_simple_split2(%arg0: i32, %arg1: i32) -> i32 { + %2 = arith.constant 31 : i8 + %1 = arith.extui %2 : i8 to i32 + %5 = arith.addi %arg1, %1 : i32 + %7 = arith.muli %arg0, %5 : i32 + return %7 : i32 } diff --git a/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/arith.mlir b/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/arith.mlir index 226ce55e5c..13c8cbe781 100644 --- a/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/arith.mlir +++ b/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/arith.mlir @@ -14,6 +14,7 @@ func.func @test_affine(%arg0: memref<1x1x!ct_ty>) -> memref<1x1x!ct_ty> { %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x!ct_ty> %3 = cggi.mul %2, %1 : !ct_ty %4 = cggi.add %3, %0 : !ct_ty - affine.store %4, %alloc[0, 0] : memref<1x1x!ct_ty> + %5 = cggi.sshr %4 {shiftAmount = 2 : index} : (!ct_ty) -> !ct_ty + affine.store %5, %alloc[0, 0] : memref<1x1x!ct_ty> return %alloc : memref<1x1x!ct_ty> } diff --git a/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/binary_gates.mlir b/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/binary_gates.mlir index 89c5e274ff..b16c931124 100644 --- a/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/binary_gates.mlir +++ b/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/binary_gates.mlir @@ -7,9 +7,7 @@ // CHECK-SAME: %[[sks:.*]]: [[sks_ty:!tfhe_rust.server_key]], %[[arg1:.*]]: [[ct_ty:!tfhe_rust.eui3]], %[[arg2:.*]]: [[ct_ty]] func.func @binary_gates(%arg1: !ct_ty, %arg2: !ct_ty) -> (!ct_ty) { // CHECK: %[[v0:.*]] = tfhe_rust.generate_lookup_table %[[sks]] {truthTable = 8 : ui4} - // CHECK: %[[shiftAmount:.*]] = arith.constant 1 : i8 - - // CHECK: %[[v1:.*]] = tfhe_rust.scalar_left_shift %[[sks]], %[[arg2]], %[[shiftAmount]] + // CHECK: %[[v1:.*]] = tfhe_rust.scalar_left_shift %[[sks]], %[[arg2]] {shiftAmount = 1 : index} // CHECK: %[[v2:.*]] = tfhe_rust.add %[[sks]], %[[v1]], %[[arg1]] // CHECK: %[[v3:.*]] = tfhe_rust.apply_lookup_table %[[sks]], %[[v2]], %[[v0]] %0 = cggi.and %arg1, %arg2 : !ct_ty @@ -25,7 +23,7 @@ func.func @binary_gates(%arg1: !ct_ty, %arg2: !ct_ty) -> (!ct_ty) { %2 = cggi.not %1 : !ct_ty // CHECK: %[[v8:.*]] = tfhe_rust.generate_lookup_table %[[sks]] {truthTable = 6 : ui4} - // CHECK: %[[v9:.*]] = tfhe_rust.scalar_left_shift %[[sks]], %[[v3]], %[[shiftAmount]] + // CHECK: %[[v9:.*]] = tfhe_rust.scalar_left_shift %[[sks]], %[[v3]] {shiftAmount = 1 : index} // CHECK: %[[v10:.*]] = tfhe_rust.add %[[sks]], %[[v9]], %[[v7]] // CHECK: %[[v11:.*]] = tfhe_rust.apply_lookup_table %[[sks]], %[[v10]], %[[v8]] %3 = cggi.xor %2, %0 : !ct_ty diff --git a/tests/Dialect/TfheRust/Emitters/emit_levelled_ops.mlir b/tests/Dialect/TfheRust/Emitters/emit_levelled_ops.mlir index 62093f8609..ca3895a3b4 100644 --- a/tests/Dialect/TfheRust/Emitters/emit_levelled_ops.mlir +++ b/tests/Dialect/TfheRust/Emitters/emit_levelled_ops.mlir @@ -14,11 +14,10 @@ // CHECK: temp_nodes[ // CHECK-NEXT: } func.func @test_levelled_op(%sks : !sks, %lut: !lut, %input1 : !eui3, %input2 : !eui3) -> !eui3 { - %c1 = arith.constant 1 : i8 %v0 = tfhe_rust.apply_lookup_table %sks, %input1, %lut : (!sks, !eui3, !lut) -> !eui3 %v1 = tfhe_rust.apply_lookup_table %sks, %input2, %lut : (!sks, !eui3, !lut) -> !eui3 %v2 = tfhe_rust.add %sks, %v0, %v1 : (!sks, !eui3, !eui3) -> !eui3 - %v3 = tfhe_rust.scalar_left_shift %sks, %v2, %c1 : (!sks, !eui3, i8) -> !eui3 + %v3 = tfhe_rust.scalar_left_shift %sks, %v2 {shiftAmount = 1 : index} : (!sks, !eui3) -> !eui3 %v4 = tfhe_rust.apply_lookup_table %sks, %v3, %lut : (!sks, !eui3, !lut) -> !eui3 return %v4 : !eui3 } @@ -44,7 +43,7 @@ func.func @test_levelled_op_break(%sks : !sks, %lut: !lut, %input1 : !eui3, %inp %v1 = tfhe_rust.apply_lookup_table %sks, %input2, %lut : (!sks, !eui3, !lut) -> !eui3 %v2 = tfhe_rust.add %sks, %v0, %v1 : (!sks, !eui3, !eui3) -> !eui3 %c1 = arith.constant 1 : i8 - %v3 = tfhe_rust.scalar_left_shift %sks, %v2, %c1 : (!sks, !eui3, i8) -> !eui3 + %v3 = tfhe_rust.scalar_left_shift %sks, %v2 {shiftAmount = 1 : index} : (!sks, !eui3) -> !eui3 %v4 = tfhe_rust.apply_lookup_table %sks, %v3, %lut : (!sks, !eui3, !lut) -> !eui3 return %v4 : !eui3 } diff --git a/tests/Dialect/TfheRust/Emitters/emit_tfhe_rust.mlir b/tests/Dialect/TfheRust/Emitters/emit_tfhe_rust.mlir index 52b786c039..a9793af95a 100644 --- a/tests/Dialect/TfheRust/Emitters/emit_tfhe_rust.mlir +++ b/tests/Dialect/TfheRust/Emitters/emit_tfhe_rust.mlir @@ -38,16 +38,14 @@ func.func @test_apply_lookup_table(%sks : !sks, %lut: !lut, %input : !eui3) -> ! // CHECK-NEXT: ) -> Ciphertext { // CHECK: let [[v1:.*]] = [[sks]].apply_lookup_table(&[[input]], &[[lut]]); // CHECK: let [[v2:.*]] = [[sks]].unchecked_add(&[[input]], &[[v1]]); -// CHECK: let [[c1:.*]] = 1; -// CHECK: let [[v3:.*]] = [[sks]].scalar_left_shift(&[[v2]], [[c1]] as u8); +// CHECK: let [[v3:.*]] = [[sks]].scalar_left_shift(&[[v2]], [[c1:.*]] as u8); // CHECK: let [[v4:.*]] = [[sks]].apply_lookup_table(&[[v3]], &[[lut]]); // CHECK-NEXT: [[v4]] // CHECK-NEXT: } func.func @test_apply_lookup_table2(%sks : !sks, %lut: !lut, %input : !eui3) -> !eui3 { %v1 = tfhe_rust.apply_lookup_table %sks, %input, %lut : (!sks, !eui3, !lut) -> !eui3 %v2 = tfhe_rust.add %sks, %input, %v1 : (!sks, !eui3, !eui3) -> !eui3 - %c1 = arith.constant 1 : i8 - %v3 = tfhe_rust.scalar_left_shift %sks, %v2, %c1 : (!sks, !eui3, i8) -> !eui3 + %v3 = tfhe_rust.scalar_left_shift %sks, %v2 {shiftAmount = 1 : index} : (!sks, !eui3) -> !eui3 %v4 = tfhe_rust.apply_lookup_table %sks, %v3, %lut : (!sks, !eui3, !lut) -> !eui3 return %v4 : !eui3 } diff --git a/tests/Dialect/TfheRust/IR/ops.mlir b/tests/Dialect/TfheRust/IR/ops.mlir index a7284621bf..9283e6ca21 100644 --- a/tests/Dialect/TfheRust/IR/ops.mlir +++ b/tests/Dialect/TfheRust/IR/ops.mlir @@ -38,8 +38,7 @@ module { %e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 %e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 - %shiftAmount = arith.constant 1 : i8 - %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2, %shiftAmount : (!sks, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2 {shiftAmount = 1 : index} : (!sks, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %eCombined = tfhe_rust.add %sks, %e1, %e2Shifted : (!sks, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %out = tfhe_rust.apply_lookup_table %sks, %eCombined, %lut : (!sks, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 diff --git a/tests/Dialect/TfheRust/Transforms/canonicalize.mlir b/tests/Dialect/TfheRust/Transforms/canonicalize.mlir index 27941ea526..d9434b8de1 100644 --- a/tests/Dialect/TfheRust/Transforms/canonicalize.mlir +++ b/tests/Dialect/TfheRust/Transforms/canonicalize.mlir @@ -6,14 +6,11 @@ module { // CHECK-LABEL: func @test_move_create_trivial func.func @test_move_create_trivial(%sks : !sks, %lut: !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 { // CHECK: arith.constant - // CHECK-NEXT: arith.constant // CHECK-NEXT: tfhe_rust.create_trivial // CHECK-NEXT: tfhe_rust.create_trivial %0 = arith.constant 1 : i3 - %1 = arith.constant 2 : i3 %e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 - %shiftAmount = arith.constant 1 : i8 - %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2, %shiftAmount : (!sks, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2 {shiftAmount = 1 : index} : (!sks, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 %eCombined = tfhe_rust.add %sks, %e1, %e2Shifted : (!sks, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %out = tfhe_rust.apply_lookup_table %sks, %eCombined, %lut : (!sks, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 @@ -23,19 +20,16 @@ module { // CHECK-LABEL: func @test_move_out_of_loop func.func @test_move_out_of_loop(%sks : !sks, %lut: !tfhe_rust.lookup_table) -> memref<10x!tfhe_rust.eui3> { // CHECK: arith.constant - // CHECK-NEXT: arith.constant // CHECK-NEXT: tfhe_rust.create_trivial // CHECK-NEXT: tfhe_rust.create_trivial // CHECK-NEXT: memref.alloc // CHECK-NEXT: affine.for %0 = arith.constant 1 : i3 - %1 = arith.constant 2 : i3 %memref = memref.alloca() : memref<10x!tfhe_rust.eui3> affine.for %i = 0 to 10 { %e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 - %shiftAmount = arith.constant 1 : i8 - %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2, %shiftAmount : (!sks, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2 {shiftAmount = 1 : index} : (!sks, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 %eCombined = tfhe_rust.add %sks, %e1, %e2Shifted : (!sks, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %out = tfhe_rust.apply_lookup_table %sks, %eCombined, %lut : (!sks, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 diff --git a/tests/Examples/tfhe_rust/test_simple_lut.mlir b/tests/Examples/tfhe_rust/test_simple_lut.mlir index e8fc8fe3bf..22aa7a113c 100644 --- a/tests/Examples/tfhe_rust/test_simple_lut.mlir +++ b/tests/Examples/tfhe_rust/test_simple_lut.mlir @@ -11,8 +11,7 @@ // CHECK: 1 func.func @fn_under_test(%sks : !sks, %a: !eui3, %b: !eui3) -> !eui3 { %lut = tfhe_rust.generate_lookup_table %sks {truthTable = 7 : ui8} : (!sks) -> !lut - %c1 = arith.constant 1 : i8 - %0 = tfhe_rust.scalar_left_shift %sks, %a, %c1 : (!sks, !eui3, i8) -> !eui3 + %0 = tfhe_rust.scalar_left_shift %sks, %a {shiftAmount = 1 : index} : (!sks, !eui3) -> !eui3 %1 = tfhe_rust.add %sks, %0, %b : (!sks, !eui3, !eui3) -> !eui3 %2 = tfhe_rust.apply_lookup_table %sks, %1, %lut : (!sks, !eui3, !lut) -> !eui3 return %2 : !eui3 diff --git a/tests/Transforms/forward_store_to_load/forward_add_one.mlir b/tests/Transforms/forward_store_to_load/forward_add_one.mlir index 0b0912f5be..3bfaca09ff 100644 --- a/tests/Transforms/forward_store_to_load/forward_add_one.mlir +++ b/tests/Transforms/forward_store_to_load/forward_add_one.mlir @@ -32,8 +32,8 @@ module { %2 = tfhe_rust.create_trivial %arg0, %false : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 %3 = tfhe_rust.create_trivial %arg0, %0 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 %4 = tfhe_rust.generate_lookup_table %arg0 {truthTable = 8 : ui8} : (!tfhe_rust.server_key) -> !tfhe_rust.lookup_table - %5 = tfhe_rust.scalar_left_shift %arg0, %2, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %6 = tfhe_rust.scalar_left_shift %arg0, %3, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %5 = tfhe_rust.scalar_left_shift %arg0, %2 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %6 = tfhe_rust.scalar_left_shift %arg0, %3 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %7 = tfhe_rust.add %arg0, %5, %6 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %8 = tfhe_rust.add %arg0, %7, %1 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %9 = tfhe_rust.apply_lookup_table %arg0, %8, %4 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 @@ -41,8 +41,8 @@ module { %11 = memref.load %arg1[%c1] : memref<8x!tfhe_rust.eui3> %12 = tfhe_rust.create_trivial %arg0, %10 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 %13 = tfhe_rust.generate_lookup_table %arg0 {truthTable = 150 : ui8} : (!tfhe_rust.server_key) -> !tfhe_rust.lookup_table - %14 = tfhe_rust.scalar_left_shift %arg0, %12, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %15 = tfhe_rust.scalar_left_shift %arg0, %11, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %14 = tfhe_rust.scalar_left_shift %arg0, %12 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %15 = tfhe_rust.scalar_left_shift %arg0, %11 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %16 = tfhe_rust.add %arg0, %14, %15 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %17 = tfhe_rust.add %arg0, %16, %9 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %18 = tfhe_rust.apply_lookup_table %arg0, %17, %13 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 @@ -52,32 +52,32 @@ module { %22 = memref.load %arg1[%c2] : memref<8x!tfhe_rust.eui3> %23 = tfhe_rust.create_trivial %arg0, %21 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 %24 = tfhe_rust.generate_lookup_table %arg0 {truthTable = 43 : ui8} : (!tfhe_rust.server_key) -> !tfhe_rust.lookup_table - %25 = tfhe_rust.scalar_left_shift %arg0, %23, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %26 = tfhe_rust.scalar_left_shift %arg0, %22, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %25 = tfhe_rust.scalar_left_shift %arg0, %23 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %26 = tfhe_rust.scalar_left_shift %arg0, %22 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %27 = tfhe_rust.add %arg0, %25, %26 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %28 = tfhe_rust.add %arg0, %27, %20 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %29 = tfhe_rust.apply_lookup_table %arg0, %28, %24 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 %30 = memref.load %alloc[%c3] : memref<8xi1> %31 = memref.load %arg1[%c3] : memref<8x!tfhe_rust.eui3> %32 = tfhe_rust.create_trivial %arg0, %30 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 - %33 = tfhe_rust.scalar_left_shift %arg0, %32, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %34 = tfhe_rust.scalar_left_shift %arg0, %31, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %33 = tfhe_rust.scalar_left_shift %arg0, %32 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %34 = tfhe_rust.scalar_left_shift %arg0, %31 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %35 = tfhe_rust.add %arg0, %33, %34 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %36 = tfhe_rust.add %arg0, %35, %29 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %37 = tfhe_rust.apply_lookup_table %arg0, %36, %24 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 %38 = memref.load %alloc[%c4] : memref<8xi1> %39 = memref.load %arg1[%c4] : memref<8x!tfhe_rust.eui3> %40 = tfhe_rust.create_trivial %arg0, %38 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 - %41 = tfhe_rust.scalar_left_shift %arg0, %40, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %42 = tfhe_rust.scalar_left_shift %arg0, %39, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %41 = tfhe_rust.scalar_left_shift %arg0, %40 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %42 = tfhe_rust.scalar_left_shift %arg0, %39 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %43 = tfhe_rust.add %arg0, %41, %42 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %44 = tfhe_rust.add %arg0, %43, %37 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %45 = tfhe_rust.apply_lookup_table %arg0, %44, %24 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 %46 = memref.load %alloc[%c5] : memref<8xi1> %47 = memref.load %arg1[%c5] : memref<8x!tfhe_rust.eui3> %48 = tfhe_rust.create_trivial %arg0, %46 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 - %49 = tfhe_rust.scalar_left_shift %arg0, %48, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %50 = tfhe_rust.scalar_left_shift %arg0, %47, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %49 = tfhe_rust.scalar_left_shift %arg0, %48 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %50 = tfhe_rust.scalar_left_shift %arg0, %47 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %51 = tfhe_rust.add %arg0, %49, %50 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %52 = tfhe_rust.add %arg0, %51, %45 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %53 = tfhe_rust.apply_lookup_table %arg0, %52, %24 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 @@ -85,8 +85,8 @@ module { %55 = memref.load %arg1[%c6] : memref<8x!tfhe_rust.eui3> %56 = tfhe_rust.create_trivial %arg0, %54 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 %57 = tfhe_rust.generate_lookup_table %arg0 {truthTable = 105 : ui8} : (!tfhe_rust.server_key) -> !tfhe_rust.lookup_table - %58 = tfhe_rust.scalar_left_shift %arg0, %56, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %59 = tfhe_rust.scalar_left_shift %arg0, %55, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %58 = tfhe_rust.scalar_left_shift %arg0, %56 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %59 = tfhe_rust.scalar_left_shift %arg0, %55 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %60 = tfhe_rust.add %arg0, %58, %59 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %61 = tfhe_rust.add %arg0, %60, %53 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %62 = tfhe_rust.apply_lookup_table %arg0, %61, %57 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 @@ -94,8 +94,8 @@ module { %64 = memref.load %alloc[%c7] : memref<8xi1> %65 = memref.load %arg1[%c7] : memref<8x!tfhe_rust.eui3> %66 = tfhe_rust.create_trivial %arg0, %64 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 - %67 = tfhe_rust.scalar_left_shift %arg0, %66, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %68 = tfhe_rust.scalar_left_shift %arg0, %65, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %67 = tfhe_rust.scalar_left_shift %arg0, %66 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %68 = tfhe_rust.scalar_left_shift %arg0, %65 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %69 = tfhe_rust.add %arg0, %67, %68 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %70 = tfhe_rust.add %arg0, %69, %63 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %71 = tfhe_rust.apply_lookup_table %arg0, %70, %57 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 diff --git a/tests/Transforms/loop_unroll/full_loop_unroll.mlir b/tests/Transforms/loop_unroll/full_loop_unroll.mlir index 4170b2322e..ad4eaa5f30 100644 --- a/tests/Transforms/loop_unroll/full_loop_unroll.mlir +++ b/tests/Transforms/loop_unroll/full_loop_unroll.mlir @@ -11,8 +11,7 @@ func.func @test_move_out_of_loop(%sks : !sks, %lut: !tfhe_rust.lookup_table) -> affine.for %i = 0 to 10 { %e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 - %shiftAmount = arith.constant 1 : i8 - %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2, %shiftAmount : (!sks, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2 {shiftAmount = 1 : index} : (!sks, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 %eCombined = tfhe_rust.add %sks, %e1, %e2Shifted : (!sks, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %out = tfhe_rust.apply_lookup_table %sks, %eCombined, %lut : (!sks, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3