Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ptxt-ctxt support to arith-to-cggi and cggi ops #1285

Merged
merged 1 commit into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 103 additions & 14 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ static lwe::LWECiphertextType convertArithToCGGIType(IntegerType type,
lwe::UnspecifiedBitFieldEncodingAttr::get(
ctx, type.getIntOrFloatBitWidth()),
lwe::LWEParamsAttr());
;
}

static Type convertArithLikeToCGGIType(ShapedType type, MLIRContext *ctx) {
Expand Down Expand Up @@ -186,6 +185,43 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
}
};

template <typename SourceArithOp, typename TargetModArithOp>
struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
ConvertArithBinOp(mlir::MLIRContext *context)
: OpConversionPattern<SourceArithOp>(context) {}

using OpConversionPattern<SourceArithOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
SourceArithOp op, typename SourceArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

if (auto lhsDefOp = op.getLhs().getDefiningOp()) {
if (isa<mlir::arith::ConstantOp>(lhsDefOp)) {
WoutLegiest marked this conversation as resolved.
Show resolved Hide resolved
auto result = b.create<TargetModArithOp>(adaptor.getRhs().getType(),
adaptor.getRhs(), op.getLhs());
rewriter.replaceOp(op, result);
return success();
}
}

if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
if (isa<mlir::arith::ConstantOp>(rhsDefOp)) {
auto result = b.create<TargetModArithOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), op.getRhs());
rewriter.replaceOp(op, result);
return success();
}
}

auto result = b.create<TargetModArithOp>(
adaptor.getLhs().getType(), adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOp(op, result);
return success();
}
};

struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
void runOnOperation() override {
MLIRContext *context = &getContext();
Expand All @@ -196,29 +232,82 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
ConversionTarget target(*context);
target.addLegalDialect<cggi::CGGIDialect>();
target.addIllegalDialect<mlir::arith::ArithDialect>();
target.addLegalOp<mlir::arith::ConstantOp>();

target.addDynamicallyLegalOp<mlir::arith::ExtSIOp>([&](Operation *op) {
if (auto *defOp =
cast<mlir::arith::ExtSIOp>(op).getOperand().getDefiningOp()) {
return isa<mlir::arith::ConstantOp>(defOp);
}
return false;
});

target.addDynamicallyLegalOp<mlir::arith::ConstantOp>(
[](mlir::arith::ConstantOp op) {
// Allow use of constant if it is used to denote the size of a shift
return (isa<IndexType>(op.getValue().getType()));
target.addDynamicallyLegalOp<memref::SubViewOp, memref::CopyOp,
tensor::FromElementsOp, tensor::ExtractOp,
affine::AffineStoreOp, affine::AffineLoadOp>(
[&](Operation *op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
});

target.addDynamicallyLegalOp<
memref::AllocOp, memref::DeallocOp, memref::StoreOp, memref::SubViewOp,
memref::CopyOp, tensor::FromElementsOp, tensor::ExtractOp,
affine::AffineStoreOp, affine::AffineLoadOp>([&](Operation *op) {
target.addDynamicallyLegalOp<memref::AllocOp>([&](Operation *op) {
// Check if all Store ops are constants, if not store op, accepts
// Check if there is at least one Store op that is a constants
return (llvm::all_of(op->getUses(),
[&](OpOperand &op) {
auto defOp =
dyn_cast<memref::StoreOp>(op.getOwner());
if (defOp) {
return isa<mlir::arith::ConstantOp>(
defOp.getValue().getDefiningOp());
}
return true;
}) &&
llvm::any_of(op->getUses(),
[&](OpOperand &op) {
auto defOp =
dyn_cast<memref::StoreOp>(op.getOwner());
if (defOp) {
return isa<mlir::arith::ConstantOp>(
defOp.getValue().getDefiningOp());
}
return false;
})) ||
// The other case: Memref need to be in LWE format
(typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes()));
});

target.addDynamicallyLegalOp<memref::StoreOp>([&](Operation *op) {
if (auto *defOp = cast<memref::StoreOp>(op).getValue().getDefiningOp()) {
if (isa<mlir::arith::ConstantOp>(defOp)) {
return true;
}
}

return typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
});

// Convert LoadOp if memref comes from an argument
target.addDynamicallyLegalOp<memref::LoadOp>([&](Operation *op) {
if (typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes())) {
return true;
}
auto loadOp = dyn_cast<memref::LoadOp>(op);

return loadOp.getMemRef().getDefiningOp() != nullptr;
});

patterns.add<
ConvertConstantOp, ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp,
ConvertShRUIOp, ConvertBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertBinOp<mlir::arith::MulIOp, cggi::MulOp>,
ConvertBinOp<mlir::arith::SubIOp, cggi::SubOp>,
ConvertShRUIOp, ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertArithBinOp<mlir::arith::MulIOp, cggi::MulOp>,
ConvertArithBinOp<mlir::arith::SubIOp, cggi::SubOp>,
ConvertAny<memref::LoadOp>, ConvertAny<memref::AllocOp>,
ConvertAny<memref::DeallocOp>, ConvertAny<memref::StoreOp>,
ConvertAny<memref::SubViewOp>, ConvertAny<memref::CopyOp>,
ConvertAny<memref::DeallocOp>, ConvertAny<memref::SubViewOp>,
ConvertAny<memref::CopyOp>, ConvertAny<memref::StoreOp>,
ConvertAny<tensor::FromElementsOp>, ConvertAny<tensor::ExtractOp>,
ConvertAny<affine::AffineStoreOp>, ConvertAny<affine::AffineLoadOp> >(
typeConverter, context);
Expand Down
49 changes: 25 additions & 24 deletions lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ struct ConvertQuartAddI final : OpConversionPattern<mlir::arith::AddIOp> {
SmallVector<Value> outputs;

for (int i = 0; i < splitLhs.size(); ++i) {
auto lowSum = b.create<cggi::AddOp>(splitLhs[i], splitRhs[i]);
auto lowSum = b.create<cggi::AddOp>(elemType, splitLhs[i], splitRhs[i]);
auto outputLsb = b.create<cggi::CastOp>(op.getLoc(), realTy, lowSum);
auto outputLsbHigh =
b.create<cggi::CastOp>(op.getLoc(), elemType, outputLsb);
Expand All @@ -365,7 +365,8 @@ struct ConvertQuartAddI final : OpConversionPattern<mlir::arith::AddIOp> {
if (i == 0) {
outputs.push_back(outputLsbHigh);
} else {
auto high = b.create<cggi::AddOp>(outputLsbHigh, carries[i - 1]);
auto high =
b.create<cggi::AddOp>(elemType, outputLsbHigh, carries[i - 1]);
outputs.push_back(high);
}
}
Expand Down Expand Up @@ -413,35 +414,35 @@ struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {

// TODO: Implement the real Karatsuba algorithm for 4x4 multiplication.
// First part of Karatsuba algorithm
auto z00 = b.create<cggi::MulOp>(splitLhs[0], splitRhs[0]);
auto z02 = b.create<cggi::MulOp>(splitLhs[1], splitRhs[1]);
auto z01_p1 = b.create<cggi::AddOp>(splitLhs[0], splitLhs[1]);
auto z01_p2 = b.create<cggi::AddOp>(splitRhs[0], splitRhs[1]);
auto z01_m = b.create<cggi::MulOp>(z01_p1, z01_p2);
auto z00 = b.create<cggi::MulOp>(elemTy, splitLhs[0], splitRhs[0]);
auto z02 = b.create<cggi::MulOp>(elemTy, splitLhs[1], splitRhs[1]);
auto z01_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[0], splitLhs[1]);
auto z01_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[0], splitRhs[1]);
auto z01_m = b.create<cggi::MulOp>(elemTy, z01_p1, z01_p2);
auto z01_s = b.create<cggi::SubOp>(z01_m, z00);
auto z01 = b.create<cggi::SubOp>(z01_s, z02);

// Second part I of Karatsuba algorithm
auto z1a0 = b.create<cggi::MulOp>(splitLhs[0], splitRhs[2]);
auto z1a2 = b.create<cggi::MulOp>(splitLhs[1], splitRhs[3]);
auto z1a1_p1 = b.create<cggi::AddOp>(splitLhs[0], splitLhs[1]);
auto z1a1_p2 = b.create<cggi::AddOp>(splitRhs[2], splitRhs[3]);
auto z1a1_m = b.create<cggi::MulOp>(z1a1_p1, z1a1_p2);
auto z1a0 = b.create<cggi::MulOp>(elemTy, splitLhs[0], splitRhs[2]);
auto z1a2 = b.create<cggi::MulOp>(elemTy, splitLhs[1], splitRhs[3]);
auto z1a1_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[0], splitLhs[1]);
auto z1a1_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[2], splitRhs[3]);
auto z1a1_m = b.create<cggi::MulOp>(elemTy, z1a1_p1, z1a1_p2);
auto z1a1_s = b.create<cggi::SubOp>(z1a1_m, z1a0);
auto z1a1 = b.create<cggi::SubOp>(z1a1_s, z1a2);

// Second part II of Karatsuba algorithm
auto z1b0 = b.create<cggi::MulOp>(splitLhs[2], splitRhs[0]);
auto z1b2 = b.create<cggi::MulOp>(splitLhs[3], splitRhs[1]);
auto z1b1_p1 = b.create<cggi::AddOp>(splitLhs[2], splitLhs[3]);
auto z1b1_p2 = b.create<cggi::AddOp>(splitRhs[0], splitRhs[1]);
auto z1b1_m = b.create<cggi::MulOp>(z1b1_p1, z1b1_p2);
auto z1b1_s = b.create<cggi::SubOp>(z1b1_m, z1b0);
auto z1b0 = b.create<cggi::MulOp>(elemTy, splitLhs[2], splitRhs[0]);
auto z1b2 = b.create<cggi::MulOp>(elemTy, splitLhs[3], splitRhs[1]);
auto z1b1_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[2], splitLhs[3]);
auto z1b1_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[0], splitRhs[1]);
auto z1b1_m = b.create<cggi::MulOp>(elemTy, z1b1_p1, z1b1_p2);
auto z1b1_s = b.create<cggi::SubOp>(elemTy, z1b1_m, z1b0);
auto z1b1 = b.create<cggi::SubOp>(z1b1_s, z1b2);

auto out2Kara = b.create<cggi::AddOp>(z1a0, z1b0);
auto out2Carry = b.create<cggi::AddOp>(out2Kara, z02);
auto out3Carry = b.create<cggi::AddOp>(z1a1, z1b1);
auto out2Kara = b.create<cggi::AddOp>(elemTy, z1a0, z1b0);
auto out2Carry = b.create<cggi::AddOp>(elemTy, out2Kara, z02);
auto out3Carry = b.create<cggi::AddOp>(elemTy, z1a1, z1b1);

// Output are now all 16b elements, wants presentation of 4x8b
auto output0Lsb = b.create<cggi::CastOp>(realTy, z00);
Expand All @@ -462,9 +463,9 @@ struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {
auto output3Lsb = b.create<cggi::CastOp>(realTy, out3Carry);
auto output3LsbHigh = b.create<cggi::CastOp>(elemTy, output3Lsb);

auto output1 = b.create<cggi::AddOp>(output1LsbHigh, output0Msb);
auto output2 = b.create<cggi::AddOp>(output2LsbHigh, output1Msb);
auto output3 = b.create<cggi::AddOp>(output3LsbHigh, output2Msb);
auto output1 = b.create<cggi::AddOp>(elemTy, output1LsbHigh, output0Msb);
auto output2 = b.create<cggi::AddOp>(elemTy, output2LsbHigh, output1Msb);
auto output3 = b.create<cggi::AddOp>(elemTy, output3LsbHigh, output2Msb);

Value resultVec = constructResultTensor(
rewriter, loc, newTy, {output0LsbHigh, output1, output2, output3});
Expand Down
40 changes: 26 additions & 14 deletions lib/Dialect/CGGI/IR/CGGIOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ include "lib/Dialect/LWE/IR/LWETypes.td"

include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
Expand Down Expand Up @@ -40,15 +41,25 @@ class CGGI_BinaryOp<string mnemonic>
let assemblyFormat = "operands attr-dict `:` qualified(type($output))" ;
}

class CGGI_ScalarBinaryOp<string mnemonic>
: CGGI_Op<mnemonic, [
Pure,
Commutative
]> {
let arguments = (ins LWECiphertext:$lhs, AnyTypeOf<[Builtin_Integer, LWECiphertext]>:$rhs);
let results = (outs LWECiphertext:$output);
}

def CGGI_AndOp : CGGI_BinaryOp<"and"> { let summary = "Logical AND of two ciphertexts."; }
def CGGI_NandOp : CGGI_BinaryOp<"nand"> { let summary = "Logical NAND of two ciphertexts."; }
def CGGI_NorOp : CGGI_BinaryOp<"nor"> { let summary = "Logical NOR of two ciphertexts."; }
def CGGI_OrOp : CGGI_BinaryOp<"or"> { let summary = "Logical OR of two ciphertexts."; }
def CGGI_XorOp : CGGI_BinaryOp<"xor"> { let summary = "Logical XOR of two ciphertexts."; }
def CGGI_XNorOp : CGGI_BinaryOp<"xnor"> { let summary = "Logical XNOR of two ciphertexts."; }
def CGGI_AddOp : CGGI_BinaryOp<"add"> { let summary = "Arithmetic addition of two ciphertexts."; }
def CGGI_MulOp : CGGI_BinaryOp<"mul"> {
let summary = "Arithmetic multiplication of two ciphertexts.";

def CGGI_AddOp : CGGI_ScalarBinaryOp<"add"> { let summary = "Arithmetic addition of two ciphertexts. One of the two ciphertext is allowed to be a scalar, this will result in the scalar addition to a ciphertext."; }
def CGGI_MulOp : CGGI_ScalarBinaryOp<"mul"> {
let summary = "Arithmetic multiplication of two ciphertexts. One of the two ciphertext is allowed to be a scalar, this will result in the scalar multiplication to a ciphertext.";
let description = [{
While CGGI does not have a native multiplication operation,
some backend targets provide a multiplication
Expand All @@ -59,6 +70,18 @@ def CGGI_MulOp : CGGI_BinaryOp<"mul"> {
}];
}

def CGGI_SubOp : CGGI_Op<"sub", [
Pure,
SameOperandsAndResultType,
ElementwiseMappable,
Scalarizable
]> {
let arguments = (ins LWECiphertext:$lhs, LWECiphertext:$rhs);
let results = (outs LWECiphertext:$output);
let summary = "Subtraction of two ciphertexts.";
}


def CGGI_NotOp : CGGI_Op<"not", [
Pure,
Involution,
Expand Down Expand Up @@ -282,17 +305,6 @@ def CGGI_MultiLutLinCombOp : CGGI_Op<"multi_lut_lincomb", [
let hasVerifier = 1;
}

def CGGI_SubOp : CGGI_Op<"sub", [
Pure,
SameOperandsAndResultType,
ElementwiseMappable,
Scalarizable
]> {
let arguments = (ins LWECiphertextLike:$lhs, LWECiphertextLike:$rhs);
let results = (outs LWECiphertextLike:$output);
let assemblyFormat = "operands attr-dict `:` qualified(type($output))";
let summary = "Subtraction of two ciphertexts.";
}


def CGGI_ScalarShiftRightOp : CGGI_Op<"sshr", [
Expand Down
5 changes: 3 additions & 2 deletions lib/Utils/ConversionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "lib/Dialect/TfheRust/IR/TfheRustTypes.h"
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
Expand Down Expand Up @@ -93,8 +94,8 @@ struct ConvertBinOp : public OpConversionPattern<SourceArithOp> {
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto result =
b.create<TargetModArithOp>(adaptor.getLhs(), adaptor.getRhs());
auto result = b.create<TargetModArithOp>(
adaptor.getLhs().getType(), adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOp(op, result);
return success();
}
Expand Down
Loading
Loading