Skip to content

Commit 15757ce

Browse files
Enable poly-approx based lowering for arith.cmpi
This commit enables the existing polynomial approximation for the default arithmetization pipeline, adds a lowering from arith.cmpi/cmpf to an arithmetic expression using sign(x), which is then lowered using the polynomial approximation framework. Note that this currently only supports a < b and a > b comparions, but not ==,<= or >= (since equality requires a different strategy rather than "just" sign(x)). Also, the lowering currently only makes sense for CKKS, as the sign(x)-based approach requires a (scalar/plaintext) division by half / multiplication with 0.5. Towards the goal of enabling arith.cmpi/cmpf, specifically those arising from non-data-oblivious high-level code, this commit includes a variety of "enablement" changes/fixes: * Improves `--convert-secret-for-to-static-for` pass This now handles various edge cases correctly, such as dynamic (but not secret) lower/upper bounds, scf.for bounds that are signless integers rather than index type, and correctly refuses to translate an scf.for with dynamic step value (which scf.for allows but affine.for does not). The pass now also by default converts all scf.for (even non-secret ones) to affine.for, as the rest of the piepline cannot handle nested scf.for even if the inner loop is not secret-dependent. This can be toggled off using the `convert-all-scf-for` flag on the pass (default = true). * `--add-client-interface` now also works for multiple functions Specifically, it adjusts the insertion point logic so that __encrypt/etc helpers are emitted directly after the function in question. Apparently, this is already enough to avoid the pass trying to process one of the helper functions it added itself, though it is unclear whether this is a stable/guaranteed behavior. * Adds support for `IndexType` in `ConvertToCiphertextSemantics` * In Layout assignment, assigns layout for index type operands of arith.cmpi * Creates a new `math_ext` dialect and adds `math_ext.sign` op * Adds the lowerings for arith.cmpi/cmpf -> math.sign * Adds polynomial-approximation passes to the arithmetization Pipelines
1 parent 61fbc2d commit 15757ce

File tree

33 files changed

+766
-79
lines changed

33 files changed

+766
-79
lines changed

lib/Dialect/MathExt/IR/BUILD

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# MathExt dialect implementation
2+
3+
load("@heir//lib/Dialect:dialect.bzl", "add_heir_dialect_library")
4+
load("@llvm-project//mlir:tblgen.bzl", "td_library")
5+
6+
package(
7+
default_applicable_licenses = ["@heir//:license"],
8+
default_visibility = ["//visibility:public"],
9+
)
10+
11+
cc_library(
12+
name = "Dialect",
13+
srcs = [
14+
"MathExtDialect.cpp",
15+
],
16+
hdrs = [
17+
"MathExtDialect.h",
18+
"MathExtOps.h",
19+
],
20+
deps = [
21+
"dialect_inc_gen",
22+
"ops_inc_gen",
23+
":MathExtOps",
24+
"@llvm-project//llvm:Support",
25+
"@llvm-project//mlir:IR",
26+
],
27+
)
28+
29+
cc_library(
30+
name = "MathExtOps",
31+
srcs = [
32+
"MathExtOps.cpp",
33+
],
34+
hdrs = [
35+
"MathExtDialect.h",
36+
"MathExtOps.h",
37+
],
38+
deps = [
39+
":dialect_inc_gen",
40+
":ops_inc_gen",
41+
"@llvm-project//mlir:IR",
42+
"@llvm-project//mlir:InferTypeOpInterface",
43+
],
44+
)
45+
46+
td_library(
47+
name = "td_files",
48+
srcs = [
49+
"MathExtDialect.td",
50+
"MathExtOps.td",
51+
],
52+
# include from the heir-root to enable fully-qualified include-paths
53+
includes = ["../../../.."],
54+
deps = [
55+
"@llvm-project//mlir:BuiltinDialectTdFiles",
56+
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
57+
"@llvm-project//mlir:OpBaseTdFiles",
58+
],
59+
)
60+
61+
add_heir_dialect_library(
62+
name = "dialect_inc_gen",
63+
dialect = "MathExt",
64+
kind = "dialect",
65+
td_file = "MathExtDialect.td",
66+
deps = [
67+
":td_files",
68+
],
69+
)
70+
71+
add_heir_dialect_library(
72+
name = "ops_inc_gen",
73+
dialect = "MathExt",
74+
kind = "op",
75+
td_file = "MathExtOps.td",
76+
deps = [
77+
":td_files",
78+
],
79+
)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#include "lib/Dialect/MathExt/IR/MathExtDialect.h"
2+
3+
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project
4+
5+
// NOLINTNEXTLINE(misc-include-cleaner): Required to define MathExtOps
6+
7+
#include "lib/Dialect/MathExt/IR/MathExtOps.h"
8+
9+
// Generated definitions
10+
#include "lib/Dialect/MathExt/IR/MathExtDialect.cpp.inc"
11+
12+
#define GET_OP_CLASSES
13+
#include "lib/Dialect/MathExt/IR/MathExtOps.cpp.inc"
14+
15+
namespace mlir {
16+
namespace heir {
17+
namespace math_ext {
18+
19+
void MathExtDialect::initialize() {
20+
addOperations<
21+
#define GET_OP_LIST
22+
#include "lib/Dialect/MathExt/IR/MathExtOps.cpp.inc"
23+
>();
24+
}
25+
26+
} // namespace math_ext
27+
} // namespace heir
28+
} // namespace mlir
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#ifndef LIB_DIALECT_MATHEXT_IR_MATHEXTDIALECT_H_
2+
#define LIB_DIALECT_MATHEXT_IR_MATHEXTDIALECT_H_
3+
4+
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
5+
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
6+
7+
// Generated headers (block clang-format from messing up order)
8+
#include "lib/Dialect/MathExt/IR/MathExtDialect.h.inc"
9+
10+
#endif // LIB_DIALECT_MATHEXT_IR_MATHEXTDIALECT_H_
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#ifndef LIB_DIALECT_MATHEXT_IR_MATHEXTDIALECT_TD_
2+
#define LIB_DIALECT_MATHEXT_IR_MATHEXTDIALECT_TD_
3+
4+
include "mlir/IR/DialectBase.td"
5+
include "mlir/IR/OpBase.td"
6+
7+
def MathExt_Dialect : Dialect {
8+
let name = "math_ext";
9+
let description = [{
10+
Math-related operations we require for HEIR
11+
which do not (yet) exist in upstream `math`.
12+
}];
13+
14+
let cppNamespace = "::mlir::heir::math_ext";
15+
}
16+
17+
#endif // LIB_DIALECT_MATHEXT_IR_MATHEXTDIALECT_TD_
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#include "lib/Dialect/MathExt/IR/MathExtOps.h"
2+
3+
namespace mlir {
4+
namespace heir {
5+
namespace math_ext {} // namespace math_ext
6+
} // namespace heir
7+
} // namespace mlir
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#ifndef LIB_DIALECT_MATHEXT_IR_MATHEXTOPS_H_
2+
#define LIB_DIALECT_MATHEXT_IR_MATHEXTOPS_H_
3+
4+
#include "lib/Dialect/MathExt/IR/MathExtDialect.h"
5+
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
6+
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project-project
7+
8+
#define GET_OP_CLASSES
9+
#include "lib/Dialect/MathExt/IR/MathExtOps.h.inc"
10+
11+
#endif // LIB_DIALECT_MATHEXT_IR_MATHEXTOPS_H_
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#ifndef LIB_DIALECT_MATHEXT_IR_MATHEXTOPS_TD_
2+
#define LIB_DIALECT_MATHEXT_IR_MATHEXTOPS_TD_
3+
4+
include "lib/Dialect/MathExt/IR/MathExtDialect.td"
5+
include "mlir/Interfaces/SideEffectInterfaces.td"
6+
include "mlir/Interfaces/InferTypeOpInterface.td"
7+
include "mlir/IR/OpBase.td"
8+
9+
class MathExt_Op<string mnemonic, list<Trait> traits = []> :
10+
Op<MathExt_Dialect, mnemonic, traits> {
11+
let cppNamespace = "::mlir::heir::math_ext";
12+
}
13+
14+
def MathExt_SignOp : MathExt_Op<"sign", [Pure, ElementwiseMappable,SameOperandsAndResultType]> {
15+
let summary = "Returns the sign of the input value";
16+
let description = [{
17+
Returns -1 if the input is negative, 0 if it is zero, and 1 if it is positive.
18+
The behavior is undefined for NaN inputs.
19+
}];
20+
let arguments = (ins SignlessIntegerOrFloatLike:$value);
21+
let results = (outs SignlessIntegerOrFloatLike:$result);
22+
let assemblyFormat = "$value attr-dict `:` type($result)";
23+
}
24+
#endif // LIB_DIALECT_MATHEXT_IR_MATHEXTOPS_TD_

lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,14 @@ struct SecretToCKKS : public impl::SecretToCKKSBase<SecretToCKKS> {
374374
lwe::ReinterpretApplicationDataOp>,
375375
SecretGenericOpConversion<arith::ExtUIOp,
376376
lwe::ReinterpretApplicationDataOp>,
377+
SecretGenericOpConversion<arith::FPToSIOp,
378+
lwe::ReinterpretApplicationDataOp>,
379+
SecretGenericOpConversion<arith::FPToUIOp,
380+
lwe::ReinterpretApplicationDataOp>,
381+
SecretGenericOpConversion<arith::SIToFPOp,
382+
lwe::ReinterpretApplicationDataOp>,
383+
SecretGenericOpConversion<arith::UIToFPOp,
384+
lwe::ReinterpretApplicationDataOp>,
377385
SecretGenericOpConversion<arith::MulFOp, ckks::MulOp>,
378386
SecretGenericOpConversion<arith::MulIOp, ckks::MulOp>,
379387
SecretGenericOpConversion<arith::SubFOp, ckks::SubOp>,

lib/Pipelines/ArithmeticPipelineRegistration.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,18 @@
2727
#include "lib/Pipelines/PipelineRegistration.h"
2828
#include "lib/Transforms/AddClientInterface/AddClientInterface.h"
2929
#include "lib/Transforms/ApplyFolders/ApplyFolders.h"
30+
#include "lib/Transforms/CompareToSignRewrite/CompareToSignRewrite.h"
3031
#include "lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.h"
3132
#include "lib/Transforms/DropUnitDims/DropUnitDims.h"
3233
#include "lib/Transforms/FullLoopUnroll/FullLoopUnroll.h"
3334
#include "lib/Transforms/GenerateParam/GenerateParam.h"
3435
#include "lib/Transforms/LayoutOptimization/LayoutOptimization.h"
3536
#include "lib/Transforms/LayoutPropagation/LayoutPropagation.h"
3637
#include "lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.h"
38+
#include "lib/Transforms/LowerPolynomialEval/LowerPolynomialEval.h"
3739
#include "lib/Transforms/OperationBalancer/OperationBalancer.h"
3840
#include "lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.h"
41+
#include "lib/Transforms/PolynomialApproximation/PolynomialApproximation.h"
3942
#include "lib/Transforms/PopulateScale/PopulateScale.h"
4043
#include "lib/Transforms/PropagateAnnotation/PropagateAnnotation.h"
4144
#include "lib/Transforms/SecretInsertMgmt/Passes.h"
@@ -123,6 +126,9 @@ void mlirToSecretArithmeticPipelineBuilder(
123126
pm.addPass(createWrapGeneric());
124127
convertToDataObliviousPipelineBuilder(pm);
125128
pm.addPass(createSelectRewrite());
129+
pm.addPass(createCompareToSignRewrite());
130+
pm.addPass(createPolynomialApproximation());
131+
pm.addPass(createLowerPolynomialEval());
126132
pm.addPass(createCanonicalizerPass());
127133
pm.addPass(createCSEPass());
128134

lib/Pipelines/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@ cc_library(
1616
"@heir//lib/Dialect/Polynomial/Conversions/PolynomialToModArith",
1717
"@heir//lib/Dialect/Secret/Conversions/SecretToCGGI",
1818
"@heir//lib/Dialect/TOSA/Conversions/TosaToSecretArith",
19+
"@heir//lib/Transforms/CompareToSignRewrite",
1920
"@heir//lib/Transforms/ConvertIfToSelect",
2021
"@heir//lib/Transforms/ConvertSecretExtractToStaticExtract",
2122
"@heir//lib/Transforms/ConvertSecretForToStaticFor",
2223
"@heir//lib/Transforms/ConvertSecretInsertToStaticInsert",
2324
"@heir//lib/Transforms/ConvertSecretWhileToStaticFor",
2425
"@heir//lib/Transforms/ElementwiseToAffine",
26+
"@heir//lib/Transforms/LowerPolynomialEval",
2527
"@heir//lib/Transforms/MemrefToArith:ExpandCopy",
2628
"@heir//lib/Transforms/MemrefToArith:MemrefToArithRegistration",
29+
"@heir//lib/Transforms/PolynomialApproximation",
2730
"@llvm-project//mlir:AffineToStandard",
2831
"@llvm-project//mlir:AffineTransforms",
2932
"@llvm-project//mlir:ArithTransforms",

0 commit comments

Comments
 (0)