diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h index 53afd32e13..ca661ef0b2 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h @@ -31,9 +31,10 @@ uint64_t sim_neg_lwe_u64(uint64_t plaintext); /// /// \param lhs left operand /// \param rhs right operand -/// \param loc +/// \param loc location of the operation +/// \param is_signed tell if operands are known to be signed /// \return uint64_t -uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc); +uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, bool is_signed); /// \brief simulate the multiplication of a noisy plaintext with an integer /// @@ -41,9 +42,10 @@ uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc); /// /// \param lhs left operand /// \param rhs right operand -/// \param loc +/// \param loc location of the operation +/// \param is_signed tell if operands are known to be signed /// \return uint64_t -uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc); +uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, bool is_signed); /// \brief simulate a keyswitch on a noisy plaintext /// diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index ae88c00c32..34a21ae383 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -57,6 +57,15 @@ inline void forwardOptimizerID(mlir::Operation *source, destination->setAttr("TFHE.OId", optimizerIdAttr); } +// Set the `signed` attribute to true if the type is signed +inline void markOpIfSigned(mlir::Operation *op, + FHE::FheIntegerInterface resultType) { + auto isSigned = resultType.isSigned(); + if (isSigned) { + op->setAttr("signed", mlir::BoolAttr::get(op->getContext(), true)); + } +} + inline void forwardLinearlyOptimizerIDS(mlir::Operation &source, std::vector &destinations) { @@ -185,6 +194,29 @@ struct AddEintIntOpPattern : public ScalarOpPattern { op, getTypeConverter()->convertType(op.getType()), adaptor.getA(), encodedInt); forwardOptimizerID(op, newOp); + markOpIfSigned(newOp, op.getType().cast()); + + return mlir::success(); + } +}; + +/// Rewriter for the `FHE::add_eint` operation. +struct AddEintOpPattern : public mlir::OpConversionPattern { + AddEintOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpConversionPattern(converter, context, benefit) { + } + + mlir::LogicalResult + matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + // Write the new op + auto newOp = rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), + adaptor.getOperands()); + forwardOptimizerID(op, newOp); + markOpIfSigned(newOp, op.getType().cast()); return mlir::success(); } @@ -225,6 +257,7 @@ struct SubEintIntOpPattern : public ScalarOpPattern { op, getTypeConverter()->convertType(op.getType()), adaptor.getA(), encodedInt); forwardOptimizerID(op, newOp); + markOpIfSigned(newOp, op.getType().cast()); return mlir::success(); }; @@ -252,6 +285,7 @@ struct SubIntEintOpPattern : public ScalarOpPattern { op, getTypeConverter()->convertType(op.getType()), encodedInt, adaptor.getB()); forwardOptimizerID(op, newOp); + markOpIfSigned(newOp, op.getType().cast()); return mlir::success(); }; @@ -281,6 +315,7 @@ struct SubEintOpPattern : public ScalarOpPattern { op, getTypeConverter()->convertType(op.getType()), lhsOperand, negative.getResult()); forwardOptimizerID(op, newOp); + markOpIfSigned(newOp, op.getType().cast()); return mlir::success(); }; @@ -310,6 +345,7 @@ struct MulEintIntOpPattern : public ScalarOpPattern { op, getTypeConverter()->convertType(op.getType()), eintOperand, castedCleartext); forwardOptimizerID(op, newOp); + markOpIfSigned(newOp, op.getType().cast()); return mlir::success(); } @@ -804,12 +840,11 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { FHE::NegEintOp, TFHE::NegGLWEOp, true>, // |_ `FHE::not` mlir::concretelang::GenericOneToOneOpConversionPattern< - FHE::BoolNotOp, TFHE::NegGLWEOp, true>, - // |_ `FHE::add_eint` - mlir::concretelang::GenericOneToOneOpConversionPattern< - FHE::AddEintOp, TFHE::AddGLWEOp, true>>(&getContext(), converter); + FHE::BoolNotOp, TFHE::NegGLWEOp, true>>(&getContext(), converter); // |_ `FHE::add_eint_int` patterns.add { } }; +int locationStringCtr = 0; mlir::Value globalStringValueFromLoc(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc) { std::string locString; auto ros = llvm::raw_string_ostream(locString); loc.print(ros); - - std::string msgName; - std::stringstream stream; - stream << "loc_" << rand(); - stream >> msgName; - return mlir::LLVM::createGlobalString(loc, rewriter, msgName, locString, - mlir::LLVM::linkage::Linkage::Linkonce, - false); + locString.append("\0"); + auto locStrWithNullByte = + llvm::StringRef(locString.c_str(), locString.size() + 1); + + std::stringstream msgName; + msgName << "str_loc_" << locationStringCtr++; + return mlir::LLVM::createGlobalString( + loc, rewriter, msgName.str(), locStrWithNullByte, + mlir::LLVM::linkage::Linkage::Linkonce, false); } template @@ -122,12 +124,21 @@ struct AddOpPattern : public mlir::OpConversionPattern { const std::string funcName = "sim_add_lwe_u64"; auto locString = globalStringValueFromLoc(rewriter, addOp.getLoc()); + // check if operation has been tagged as signed + auto isSigned = false; + mlir::Attribute signedAttr = adaptor.getAttributes().get("signed"); + if (signedAttr && signedAttr.cast().getValue()) { + isSigned = true; + } + mlir::Value isSignedCst = rewriter.create( + addOp.getLoc(), isSigned, 1); if (insertForwardDeclaration( addOp, rewriter, funcName, rewriter.getFunctionType( {rewriter.getIntegerType(64), rewriter.getIntegerType(64), - mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())}, + mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()), + rewriter.getIntegerType(1)}, {rewriter.getIntegerType(64)})) .failed()) { return mlir::failure(); @@ -135,7 +146,8 @@ struct AddOpPattern : public mlir::OpConversionPattern { rewriter.replaceOpWithNewOp( addOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)}, - mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString})); + mlir::ValueRange( + {adaptor.getA(), adaptor.getB(), locString, isSignedCst})); return mlir::success(); } @@ -156,11 +168,21 @@ struct MulOpPattern : public mlir::OpConversionPattern { auto locString = globalStringValueFromLoc(rewriter, mulOp.getLoc()); + // check if operation has been tagged as signed + auto isSigned = false; + mlir::Attribute signedAttr = adaptor.getAttributes().get("signed"); + if (signedAttr && signedAttr.cast().getValue()) { + isSigned = true; + } + mlir::Value isSignedCst = rewriter.create( + mulOp.getLoc(), isSigned, 1); + if (insertForwardDeclaration( mulOp, rewriter, funcName, rewriter.getFunctionType( {rewriter.getIntegerType(64), rewriter.getIntegerType(64), - mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())}, + mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()), + rewriter.getIntegerType(1)}, {rewriter.getIntegerType(64)})) .failed()) { return mlir::failure(); @@ -168,7 +190,8 @@ struct MulOpPattern : public mlir::OpConversionPattern { rewriter.replaceOpWithNewOp( mulOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)}, - mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString})); + mlir::ValueRange( + {adaptor.getA(), adaptor.getB(), locString, isSignedCst})); return mlir::success(); } @@ -186,8 +209,10 @@ struct SubIntGLWEOpPattern : public mlir::OpRewritePattern { mlir::Value negated = rewriter.create( subOp.getLoc(), subOp.getB().getType(), subOp.getB()); - rewriter.replaceOpWithNewOp(subOp, subOp.getType(), - negated, subOp.getA()); + rewriter.replaceOpWithNewOp( + subOp, subOp.getType(), mlir::ValueRange({negated, subOp.getA()}), + // to forward the signed attr if set + subOp.getOperation()->getAttrs()); return mlir::success(); } diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp index f537483141..6dc43c7334 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp @@ -92,14 +92,30 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated, else out = -tlu[mod_switched % poly_size]; + // get encoded info from lsb + bool is_signed = (out >> 1) & 1; + bool is_overflow = out & 1; + // discard info bits (2 lsb) + out = out & 18446744073709551612U; + + if (!is_signed && out > UINT63_MAX) { + printf("WARNING at %s: overflow (padding bit) happened during LUT in " + "simulation\n", + loc); + } + if (is_overflow) { + printf("WARNING at %s: overflow (original value didn't fit, so a modulus " + "was applied) happened " + "during LUT in " + "simulation\n", + loc); + } + double variance_bsk = security_curve()->getVariance(glwe_dim, poly_size, 64); double variance = concrete_cpu_variance_blind_rotate( input_lwe_dim, glwe_dim, poly_size, base_log, level, 64, mlir::concretelang::optimizer::DEFAULT_FFT_PRECISION, variance_bsk); out = out + gaussian_noise(0, variance); - if (out > UINT63_MAX) { - printf("WARNING at %s: overflow happened during LUT in simulation\n", loc); - } return out; } @@ -189,33 +205,145 @@ void sim_wop_pbs_crt( uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; } -uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc) { - if (lhs > UINT63_MAX - rhs) { - printf("WARNING at %s: overflow happened during addition in simulation\n", - loc); +uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, + bool is_signed) { + const char msg_f[] = + "WARNING at %s: overflow happened during addition in simulation\n"; + + uint64_t result = lhs + rhs; + + if (is_signed) { + // We shift left to discard the padding bit and only consider the message + // for easier overflow checking + int64_t lhs_signed = (int64_t)lhs << 1; + int64_t rhs_signed = (int64_t)rhs << 1; + if (lhs_signed > 0 && rhs_signed > INT64_MAX - lhs_signed) + printf(msg_f, loc); + else if (lhs_signed < 0 && rhs_signed < INT64_MIN - lhs_signed) + printf(msg_f, loc); + } else if (lhs > UINT63_MAX - rhs || result > UINT63_MAX) { + printf(msg_f, loc); } - return lhs + rhs; + return result; } -uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc) { - if (rhs != 0 && lhs > UINT63_MAX / rhs) { - printf("WARNING at %s: overflow happened during multiplication in " - "simulation\n", - loc); +uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, + bool is_signed) { + const char msg_f[] = + "WARNING at %s: overflow happened during multiplication in simulation\n"; + + uint64_t result = lhs * rhs; + + if (is_signed) { + // We shift left to discard the padding bit and only consider the message + // for easier overflow checking + int64_t lhs_signed = (int64_t)lhs << 1; + int64_t rhs_signed = (int64_t)rhs << 1; + if (lhs_signed != 0 && rhs_signed > INT64_MAX / lhs_signed) + printf(msg_f, loc); + else if (lhs_signed != 0 && rhs_signed < INT64_MIN / lhs_signed) + printf(msg_f, loc); + } else if (rhs != 0 && lhs > UINT63_MAX / rhs) { + printf(msg_f, loc); } - return lhs * rhs; + return result; } +// a copy of memref_encode_expand_lut_for_bootstrap but which encodes overflow +// and sign info into the LUT. Those information should later be discarder by +// the LUT function void sim_encode_expand_lut_for_boostrap( - uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, - uint64_t out_size, uint64_t out_stride, uint64_t *in_allocated, - uint64_t *in_aligned, uint64_t in_offset, uint64_t in_size, - uint64_t in_stride, uint32_t poly_size, uint32_t output_bits, - bool is_signed) { - return memref_encode_expand_lut_for_bootstrap( - out_allocated, out_aligned, out_offset, out_size, out_stride, - in_allocated, in_aligned, in_offset, in_size, in_stride, poly_size, - output_bits, is_signed); + uint64_t *output_lut_allocated, uint64_t *output_lut_aligned, + uint64_t output_lut_offset, uint64_t output_lut_size, + uint64_t output_lut_stride, uint64_t *input_lut_allocated, + uint64_t *input_lut_aligned, uint64_t input_lut_offset, + uint64_t input_lut_size, uint64_t input_lut_stride, uint32_t poly_size, + uint32_t out_MESSAGE_BITS, bool is_signed) { + + assert(input_lut_stride == 1 && "Runtime: stride not equal to 1, check " + "memref_encode_expand_lut_bootstrap"); + + assert(output_lut_stride == 1 && "Runtime: stride not equal to 1, check " + "memref_encode_expand_lut_bootstrap"); + + size_t mega_case_size = output_lut_size / input_lut_size; + + assert((mega_case_size % 2) == 0); + + // compute overflow bit + std::vector overflow_info(output_lut_size, false); + uint64_t upper_bound = uint64_t(1) + << (out_MESSAGE_BITS + (is_signed ? 1 : 0)); + for (size_t i = 0; i < input_lut_size; i++) { + if (input_lut_aligned[input_lut_offset + i] >= upper_bound) { + overflow_info[i] = true; + } else { + overflow_info[i] = false; + } + } + // used to set the sign bit or not + uint64_t sign_bit_setter = 0; + if (is_signed) { + sign_bit_setter = 2; + } + + // When the bootstrap is executed on encrypted signed integers, the lut must + // be half-rotated. This map takes care about properly indexing into the input + // lut depending on what bootstrap gets executed. + std::function indexMap; + if (is_signed) { + size_t halfInputSize = input_lut_size / 2; + indexMap = [=](size_t idx) { + if (idx < halfInputSize) { + return idx + halfInputSize; + } else { + return idx - halfInputSize; + } + }; + } else { + indexMap = [=](size_t idx) { return idx; }; + } + + // The first lut value should be centered over zero. This means that half of + // it should appear at the beginning of the output lut, and half of it at the + // end (but negated). + for (size_t idx = 0; idx < mega_case_size / 2; ++idx) { + output_lut_aligned[output_lut_offset + idx] = + input_lut_aligned[input_lut_offset + indexMap(0)] + << (64 - out_MESSAGE_BITS - 1); + // set the sign bit + output_lut_aligned[output_lut_offset + idx] |= sign_bit_setter; + // set the overflow bit + output_lut_aligned[output_lut_offset + idx] |= (uint64_t)overflow_info[0]; + } + for (size_t idx = (input_lut_size - 1) * mega_case_size + mega_case_size / 2; + idx < output_lut_size; ++idx) { + output_lut_aligned[output_lut_offset + idx] = + -(input_lut_aligned[input_lut_offset + indexMap(0)] + << (64 - out_MESSAGE_BITS - 1)); + // set the sign bit + output_lut_aligned[output_lut_offset + idx] |= sign_bit_setter; + // set the overflow bit + output_lut_aligned[output_lut_offset + idx] |= + (uint64_t)overflow_info[indexMap(0)]; + } + + // Treats the other ut values. + for (size_t lut_idx = 1; lut_idx < input_lut_size; ++lut_idx) { + uint64_t lut_value = input_lut_aligned[input_lut_offset + indexMap(lut_idx)] + << (64 - out_MESSAGE_BITS - 1); + // set the sign bit + lut_value |= sign_bit_setter; + // set the overflow bit + lut_value |= (uint64_t)overflow_info[indexMap(lut_idx)]; + size_t start = mega_case_size * (lut_idx - 1) + mega_case_size / 2; + for (size_t output_idx = start; output_idx < start + mega_case_size; + ++output_idx) { + output_lut_aligned[output_lut_offset + output_idx] = lut_value; + } + } + + return; } void sim_encode_plaintext_with_crt(uint64_t *output_allocated, diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index 79306b76f4..980d4db62e 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -436,8 +436,10 @@ mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context, if (fheContext) { auto solution = fheContext.value().solution; auto optCrt = getCrtDecompositionFromSolution(solution); - if (optCrt) + if (optCrt) { enableOverflowDetection = false; + log_verbose() << "WARNING: overflow detection disabled since using CRT"; + } } pipelinePrinting("TFHESimulation", pm, context); diff --git a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py index c083d94d2d..4755b7a4f9 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py @@ -21,7 +21,7 @@ def assert_result(result, expected_result): """ assert type(expected_result) == type(result) if isinstance(expected_result, int): - assert result == expected_result + assert result == expected_result, f"{result} != {expected_result}" else: assert np.all(result == expected_result) @@ -258,6 +258,186 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', id="add_eint_int", ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-1, -2), + -3, + b"", + id="add_eint_int_signed", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-60, -20), + -80, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="add_eint_int_signed_underflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (60, 20), + -48, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="add_eint_int_signed_overflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (81, 73), + 154, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="add_eint", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: !FHE.esint<7>) -> !FHE.esint<7> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.esint<7>, !FHE.esint<7>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-81, 73), + -8, + b"", + id="add_eint_signed", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: !FHE.esint<7>) -> !FHE.esint<7> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.esint<7>, !FHE.esint<7>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-60, -20), + -80, # undefined behavior + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="add_eint_signed_underflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: !FHE.esint<7>) -> !FHE.esint<7> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.esint<7>, !FHE.esint<7>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (81, 73), + -102, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="add_eint_signed_overflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { + %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (4, 7), + 256 - 3, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="sub_eint_int", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (4, 7), + -3, + b"", + id="sub_eint_int_signed", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-37, 40), + -77, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="sub_eint_int_signed_underflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (33, -40), + -55, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="sub_eint_int_signed_overflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (11, 18), + 256 - 7, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="sub_eint", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: !FHE.esint<7>) -> !FHE.esint<7> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.esint<7>, !FHE.esint<7>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (11, 18), + -7, + b"", + id="sub_eint_signed", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: !FHE.esint<7>) -> !FHE.esint<7> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.esint<7>, !FHE.esint<7>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-44, 32), + -76, # undefined behavior + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="sub_eint_signed_underflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: !FHE.esint<7>) -> !FHE.esint<7> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.esint<7>, !FHE.esint<7>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (61, -25), + -42, # undefined behavior + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="sub_eint_signed_overflow", + ), pytest.param( """ func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { @@ -270,18 +450,93 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): b'WARNING at loc("-":3:22): overflow happened during multiplication in simulation\n', id="mul_eint_int", ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { + %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + %2 = "FHE.mul_eint_int"(%1, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + return %2: !FHE.eint<7> + } + """, + (5, 10), + 256 - 50, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\nWARNING at loc("-":4:22): overflow happened during multiplication in simulation\n', + id="sub_mul_eint_int", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.mul_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (5, -2), + -10, + b"", + id="mul_eint_int_signed", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.mul_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-33, 5), + -37, # undefined behavior + b'WARNING at loc("-":3:22): overflow happened during multiplication in simulation\n', + id="mul_eint_int_signed_underflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.mul_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-33, -5), + -91, + b'WARNING at loc("-":3:22): overflow happened during multiplication in simulation\n', + id="mul_eint_int_signed_overflow", + ), pytest.param( """ func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { - %tlu = arith.constant dense<[0, 140, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> + %tlu = arith.constant dense<[0, 1420, -2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> %1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>) return %1: !FHE.eint<7> } """, (1,), 140, - b'WARNING at loc("-":4:22): overflow happened during LUT in simulation\n', - id="apply_lookup_table", + b'WARNING at loc("-":4:22): overflow (padding bit) happened during LUT in simulation\nWARNING at loc("-":4:22): overflow (original value didn\'t fit, so a modulus was applied) happened during LUT in simulation\n', + id="apply_lookup_table_big_value", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>) -> !FHE.esint<7> { + %tlu = arith.constant dense<[0, 1400, 254, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> + %1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.esint<7>, tensor<128xi64>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (2,), + -2, + b"", + id="apply_lookup_table_signed", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>) -> !FHE.esint<7> { + %tlu = arith.constant dense<[0, 1400, -2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> + %1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.esint<7>, tensor<128xi64>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (1,), + -8, + b'WARNING at loc("-":4:22): overflow (original value didn\'t fit, so a modulus was applied) happened during LUT in simulation\n', + id="apply_lookup_table_signed_big_value", ), ]