Working Quart to tfhe-rs lowering + change tfhe-rs and CGGI shift ops
WoutLegiest committed Jan 21, 2025
1 parent 4a11ab7 commit d83d267
Showing 20 changed files with 271 additions and 154 deletions.
3 changes: 1 addition & 2 deletions docs/content/en/docs/pipelines.md
@@ -97,8 +97,7 @@ Example input:
 func.func @test_apply_lookup_table(%sks : !sks, %lut: !lut, %input : !eui3) -> !eui3 {
   %v1 = tfhe_rust.apply_lookup_table %sks, %input, %lut : (!sks, !eui3, !lut) -> !eui3
   %v2 = tfhe_rust.add %sks, %input, %v1 : (!sks, !eui3, !eui3) -> !eui3
-  %c1 = arith.constant 1 : i8
-  %v3 = tfhe_rust.scalar_left_shift %sks, %v2, %c1 : (!sks, !eui3, i8) -> !eui3
+  %v3 = tfhe_rust.scalar_left_shift %sks, %v2 {shiftAmount = 1 : index} : (!sks, !eui3) -> !eui3
   %v4 = tfhe_rust.apply_lookup_table %sks, %v3, %lut : (!sks, !eui3, !lut) -> !eui3
   return %v4 : !eui3
 }
23 changes: 8 additions & 15 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp
@@ -139,12 +139,10 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
               .getSExtValue();
 
       auto inputValue =
-          mlir::IntegerAttr::get(rewriter.getI8Type(), (int8_t)shiftAmount);
-      auto cteOp = rewriter.create<mlir::arith::ConstantOp>(
-          op.getLoc(), rewriter.getI8Type(), inputValue);
+          mlir::IntegerAttr::get(rewriter.getIndexType(), (int8_t)shiftAmount);
 
-      auto shiftOp =
-          b.create<cggi::ShiftRightOp>(outputType, adaptor.getLhs(), cteOp);
+      auto shiftOp = b.create<cggi::ScalarShiftRightOp>(
+          outputType, adaptor.getLhs(), inputValue);
       rewriter.replaceOp(op, shiftOp);
 
       return success();
@@ -157,14 +155,12 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
     auto shiftAmount =
         cast<IntegerAttr>(cteShiftSizeOp.getValue()).getValue().getSExtValue();
 
-    auto inputValue = mlir::IntegerAttr::get(rewriter.getI8Type(), shiftAmount);
-    auto cteOp = rewriter.create<mlir::arith::ConstantOp>(
-        op.getLoc(), rewriter.getI8Type(), inputValue);
+    auto inputValue =
+        mlir::IntegerAttr::get(rewriter.getIndexType(), shiftAmount);
 
-    auto shiftOp =
-        b.create<cggi::ShiftRightOp>(outputType, adaptor.getLhs(), cteOp);
+    auto shiftOp = b.create<cggi::ScalarShiftRightOp>(
+        outputType, adaptor.getLhs(), inputValue);
     rewriter.replaceOp(op, shiftOp);
-    rewriter.replaceOp(op.getLhs().getDefiningOp(), cteOp);
 
     return success();
   }
@@ -184,10 +180,7 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
     target.addDynamicallyLegalOp<mlir::arith::ConstantOp>(
         [](mlir::arith::ConstantOp op) {
          // Allow use of constant if it is used to denote the size of a shift
-          bool usedByShift = llvm::any_of(op->getUsers(), [&](Operation *user) {
-            return isa<cggi::ShiftRightOp>(user);
-          });
-          return (isa<IndexType>(op.getValue().getType()) || (usedByShift));
+          return (isa<IndexType>(op.getValue().getType()));
        });
 
     target.addDynamicallyLegalOp<
221 changes: 174 additions & 47 deletions lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp
@@ -1,9 +1,5 @@
 #include "lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.h"
 
-#include <mlir/IR/MLIRContext.h>
-
-#include <cstdint>
-
 #include "lib/Dialect/CGGI/IR/CGGIDialect.h"
 #include "lib/Dialect/CGGI/IR/CGGIOps.h"
 #include "lib/Dialect/LWE/IR/LWEOps.h"
@@ -15,7 +11,9 @@
 #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
+#include "mlir/include/mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/include/mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "mlir/include/mlir/Transforms/Passes.h"  // from @llvm-project
 
 namespace mlir::heir::arith {
 
@@ -94,7 +92,7 @@ class ArithToCGGIQuartTypeConverter : public TypeConverter {
 };
 
 static Value createTrivialOpMaxWidth(ImplicitLocOpBuilder b, int value) {
-  auto maxWideIntType = IntegerType::get(b.getContext(), maxIntWidth >> 1);
+  auto maxWideIntType = IntegerType::get(b.getContext(), maxIntWidth);
   auto intAttr = b.getIntegerAttr(maxWideIntType, value);
 
   auto encoding =
@@ -153,19 +151,16 @@ static SmallVector<Value> extractLastDimHalves(
 
 static Value createScalarOrSplatConstant(OpBuilder &builder, Location loc,
                                          Type type, int64_t value) {
-  unsigned elementBitWidth = 0;
-  if (auto lweTy = dyn_cast<lwe::LWECiphertextType>(type))
-    elementBitWidth =
-        cast<lwe::UnspecifiedBitFieldEncodingAttr>(lweTy.getEncoding())
-            .getCleartextBitwidth();
-  else
-    elementBitWidth = maxIntWidth;
+  // unsigned elementBitWidth = 0;
+  // if (auto lweTy = dyn_cast<lwe::LWECiphertextType>(type))
+  //   elementBitWidth =
+  //       cast<lwe::UnspecifiedBitFieldEncodingAttr>(lweTy.getEncoding())
+  //           .getCleartextBitwidth();
+  // else
+  //   elementBitWidth = maxIntWidth;
 
-  auto apValue = APInt(elementBitWidth, value);
-
-  auto maxWideIntType =
-      IntegerType::get(builder.getContext(), maxIntWidth >> 1);
-  auto intAttr = builder.getIntegerAttr(maxWideIntType, value);
+  auto intAttr = builder.getIntegerAttr(
+      IntegerType::get(builder.getContext(), maxIntWidth), value);
 
   return builder.create<cggi::CreateTrivialOp>(loc, type, intAttr);
 }
@@ -249,6 +244,40 @@ struct ConvertQuartConstantOp
   }
 };
 
+struct ConvertQuartTruncIOp
+    : public OpConversionPattern<mlir::arith::TruncIOp> {
+  ConvertQuartTruncIOp(mlir::MLIRContext *context)
+      : OpConversionPattern<mlir::arith::TruncIOp>(context) {}
+
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mlir::arith::TruncIOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+    auto newResultTy = getTypeConverter()->convertType<RankedTensorType>(
+        op.getResult().getType());
+    auto newInTy =
+        getTypeConverter()->convertType<RankedTensorType>(op.getIn().getType());
+
+    SmallVector<OpFoldResult> offsets(newResultTy.getShape().size(),
+                                      rewriter.getIndexAttr(0));
+    offsets.back() = rewriter.getIndexAttr(newInTy.getShape().back() -
+                                           newResultTy.getShape().back());
+    SmallVector<OpFoldResult> sizes(newResultTy.getShape().size());
+    sizes.back() = rewriter.getIndexAttr(1);
+    SmallVector<OpFoldResult> strides(newResultTy.getShape().size(),
+                                      rewriter.getIndexAttr(1));
+
+    auto resOp = rewriter.create<tensor::ExtractSliceOp>(
+        op->getLoc(), adaptor.getIn(), offsets, sizes, strides);
+    rewriter.replaceOp(op, resOp);
+
+    return success();
+  }
+};
+
 template <typename ArithExtOp>
 struct ConvertQuartExt final : OpConversionPattern<ArithExtOp> {
   using OpConversionPattern<ArithExtOp>::OpConversionPattern;
@@ -274,23 +303,21 @@ struct ConvertQuartExt final : OpConversionPattern<ArithExtOp> {
     auto resultChunks = newResultTy.getShape().back();
     auto inChunks = newInTy.getShape().back();
 
-    if (resultChunks > inChunks) {
-      auto paddingFactor = resultChunks - inChunks;
+    // Through definition of ExtOp, paddingFactor is always positive
+    auto paddingFactor = resultChunks - inChunks;
 
-      SmallVector<OpFoldResult, 1> low, high;
-      low.push_back(rewriter.getIndexAttr(0));
-      high.push_back(rewriter.getIndexAttr(paddingFactor));
+    SmallVector<OpFoldResult, 1> low, high;
+    low.push_back(rewriter.getIndexAttr(0));
+    high.push_back(rewriter.getIndexAttr(paddingFactor));
 
-      auto padValue = createTrivialOpMaxWidth(b, 0);
+    auto padValue = createTrivialOpMaxWidth(b, 0);
 
-      auto resultVec = b.create<tensor::PadOp>(newResultTy, adaptor.getIn(),
-                                               low, high, padValue,
-                                               /*nofold=*/true);
+    auto resultVec = b.create<tensor::PadOp>(newResultTy, adaptor.getIn(), low,
+                                             high, padValue,
+                                             /*nofold=*/true);
 
-      rewriter.replaceOp(op, resultVec);
-      return success();
-    }
-    return failure();
+    rewriter.replaceOp(op, resultVec);
+    return success();
   }
 };
 
@@ -318,14 +345,15 @@ struct ConvertQuartAddI final : OpConversionPattern<mlir::arith::AddIOp> {
 
     // Actual type of the underlying elements; we use half the width.
     // Create Constant
-    auto intAttr = IntegerAttr::get(rewriter.getI8Type(), maxIntWidth >> 1);
+    auto shiftAttr =
+        IntegerAttr::get(rewriter.getIndexType(), maxIntWidth >> 1);
 
     auto elemType = convertArithToCGGIType(
         IntegerType::get(op->getContext(), maxIntWidth), op->getContext());
     auto realTy = convertArithToCGGIType(
         IntegerType::get(op->getContext(), maxIntWidth >> 1), op->getContext());
 
-    auto constantOp = b.create<mlir::arith::ConstantOp>(intAttr);
+    // auto constantOp = b.create<mlir::arith::ConstantOp>(intAttr);
 
     SmallVector<Value> carries;
     SmallVector<Value> outputs;
@@ -338,7 +366,8 @@ struct ConvertQuartAddI final : OpConversionPattern<mlir::arith::AddIOp> {
 
       // Now all the outputs are 16b elements, wants presentation of 4x8b
       if (i != splitLhs.size() - 1) {
-        auto carry = b.create<cggi::ShiftRightOp>(elemType, lowSum, constantOp);
+        auto carry =
+            b.create<cggi::ScalarShiftRightOp>(elemType, lowSum, shiftAttr);
         carries.push_back(carry);
       }
 
@@ -356,6 +385,103 @@ struct ConvertQuartAddI final : OpConversionPattern<mlir::arith::AddIOp> {
   }
 };
 
+// Implemented using the Karatsuba algorithm
+// https://en.wikipedia.org/wiki/Karatsuba_algorithm#Algorithm
+struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mlir::arith::MulIOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
+
+    auto newTy =
+        getTypeConverter()->convertType<RankedTensorType>(op.getType());
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", op.getType()));
+    if (newTy.getShape().back() != 4)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("Mul only support 4 split elements. Shape: {0}",
+                             newTy));
+
+    auto elemTy = convertArithToCGGIType(
+        IntegerType::get(op->getContext(), maxIntWidth), op->getContext());
+    auto realTy = convertArithToCGGIType(
+        IntegerType::get(op->getContext(), maxIntWidth >> 1), op->getContext());
+
+    // Create Constant
+    auto shiftAttr =
+        rewriter.getIntegerAttr(b.getIndexType(), maxIntWidth >> 1);
+
+    SmallVector<Value> splitLhs =
+        extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+    SmallVector<Value> splitRhs =
+        extractLastDimHalves(rewriter, loc, adaptor.getRhs());
+
+    // TODO: Implement the real Karatsuba algorithm for 4x4 multiplication.
+    // First part of Karatsuba algorithm
+    auto z00 = b.create<cggi::MulOp>(splitLhs[0], splitRhs[0]);
+    auto z02 = b.create<cggi::MulOp>(splitLhs[1], splitRhs[1]);
+    auto z01_p1 = b.create<cggi::AddOp>(splitLhs[0], splitLhs[1]);
+    auto z01_p2 = b.create<cggi::AddOp>(splitRhs[0], splitRhs[1]);
+    auto z01_m = b.create<cggi::MulOp>(z01_p1, z01_p2);
+    auto z01_s = b.create<cggi::SubOp>(z01_m, z00);
+    auto z01 = b.create<cggi::SubOp>(z01_s, z02);
+
+    // Second part I of Karatsuba algorithm
+    auto z1a0 = b.create<cggi::MulOp>(splitLhs[0], splitRhs[2]);
+    auto z1a2 = b.create<cggi::MulOp>(splitLhs[1], splitRhs[3]);
+    auto z1a1_p1 = b.create<cggi::AddOp>(splitLhs[0], splitLhs[1]);
+    auto z1a1_p2 = b.create<cggi::AddOp>(splitRhs[2], splitRhs[3]);
+    auto z1a1_m = b.create<cggi::MulOp>(z1a1_p1, z1a1_p2);
+    auto z1a1_s = b.create<cggi::SubOp>(z1a1_m, z1a0);
+    auto z1a1 = b.create<cggi::SubOp>(z1a1_s, z1a2);
+
+    // Second part II of Karatsuba algorithm
+    auto z1b0 = b.create<cggi::MulOp>(splitLhs[2], splitRhs[0]);
+    auto z1b2 = b.create<cggi::MulOp>(splitLhs[3], splitRhs[1]);
+    auto z1b1_p1 = b.create<cggi::AddOp>(splitLhs[2], splitLhs[3]);
+    auto z1b1_p2 = b.create<cggi::AddOp>(splitRhs[0], splitRhs[1]);
+    auto z1b1_m = b.create<cggi::MulOp>(z1b1_p1, z1b1_p2);
+    auto z1b1_s = b.create<cggi::SubOp>(z1b1_m, z1b0);
+    auto z1b1 = b.create<cggi::SubOp>(z1b1_s, z1b2);
+
+    auto out2Kara = b.create<cggi::AddOp>(z1a0, z1b0);
+    auto out2Carry = b.create<cggi::AddOp>(out2Kara, z02);
+    auto out3Carry = b.create<cggi::AddOp>(z1a1, z1b1);
+
+    // Output are now all 16b elements, wants presentation of 4x8b
+    auto output0Lsb = b.create<cggi::CastOp>(realTy, z00);
+    auto output0LsbHigh = b.create<cggi::CastOp>(elemTy, output0Lsb);
+    auto output0Msb =
+        b.create<cggi::ScalarShiftRightOp>(elemTy, z00, shiftAttr);
+
+    auto output1Lsb = b.create<cggi::CastOp>(realTy, z01);
+    auto output1LsbHigh = b.create<cggi::CastOp>(elemTy, output1Lsb);
+    auto output1Msb =
+        b.create<cggi::ScalarShiftRightOp>(elemTy, z01, shiftAttr);
+
+    auto output2Lsb = b.create<cggi::CastOp>(realTy, out2Carry);
+    auto output2LsbHigh = b.create<cggi::CastOp>(elemTy, output2Lsb);
+    auto output2Msb =
+        b.create<cggi::ScalarShiftRightOp>(elemTy, out2Carry, shiftAttr);
+
+    auto output3Lsb = b.create<cggi::CastOp>(realTy, out3Carry);
+    auto output3LsbHigh = b.create<cggi::CastOp>(elemTy, output3Lsb);
+
+    auto output1 = b.create<cggi::AddOp>(output1LsbHigh, output0Msb);
+    auto output2 = b.create<cggi::AddOp>(output2LsbHigh, output1Msb);
+    auto output3 = b.create<cggi::AddOp>(output3LsbHigh, output2Msb);
+
+    Value resultVec = constructResultTensor(
+        rewriter, loc, newTy, {output0LsbHigh, output1, output2, output3});
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
+};
+
 struct ArithToCGGIQuart : public impl::ArithToCGGIQuartBase<ArithToCGGIQuart> {
   void runOnOperation() override {
     MLIRContext *context = &getContext();
@@ -386,28 +512,29 @@ struct ArithToCGGIQuart : public impl::ArithToCGGIQuartBase<ArithToCGGIQuart> {
 
     target.addDynamicallyLegalOp<mlir::arith::ConstantOp>(
         [](mlir::arith::ConstantOp op) {
-          // Allow use of constant if it is used to denote the size of a shift
-          bool usedByShift = llvm::any_of(op->getUsers(), [&](Operation *user) {
-            return isa<cggi::ShiftRightOp>(user);
-          });
-          return (isa<IndexType>(op.getValue().getType()) || (usedByShift));
+          return isa<IndexType>(op.getValue().getType());
        });
 
-    patterns.add<
-        ConvertQuartConstantOp, ConvertQuartExt<mlir::arith::ExtUIOp>,
-        ConvertQuartExt<mlir::arith::ExtSIOp>, ConvertQuartAddI,
-        ConvertAny<memref::LoadOp>, ConvertAny<memref::AllocOp>,
-        ConvertAny<memref::DeallocOp>, ConvertAny<memref::StoreOp>,
-        ConvertAny<memref::SubViewOp>, ConvertAny<memref::CopyOp>,
-        ConvertAny<tensor::FromElementsOp>, ConvertAny<tensor::ExtractOp>,
-        ConvertAny<affine::AffineStoreOp>, ConvertAny<affine::AffineLoadOp>>(
-        typeConverter, context);
+    patterns
+        .add<ConvertQuartConstantOp, ConvertQuartExt<mlir::arith::ExtUIOp>,
+             ConvertQuartExt<mlir::arith::ExtSIOp>, ConvertQuartAddI,
+             ConvertQuartMulI, ConvertAny<memref::LoadOp>,
+             ConvertAny<memref::AllocOp>, ConvertAny<memref::DeallocOp>,
+             ConvertAny<memref::StoreOp>, ConvertAny<memref::SubViewOp>,
+             ConvertAny<memref::CopyOp>, ConvertAny<tensor::FromElementsOp>,
+             ConvertAny<tensor::ExtractOp>, ConvertAny<affine::AffineStoreOp>,
+             ConvertAny<affine::AffineLoadOp>>(typeConverter, context);
 
     addStructuralConversionPatterns(typeConverter, patterns, target);
 
     if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
       return signalPassFailure();
     }
+
+    // Remove the unnecessary tensor ops between each converted arith operation.
+    OpPassManager pipeline("builtin.module");
+    pipeline.addPass(createCSEPass());
+    (void)runPipeline(pipeline, getOperation());
   }
 };
 
(Diff truncated: the remaining changed files were not loaded in this view.)