diff --git a/Sources/HomomorphicEncryption/Array2d.swift b/Sources/HomomorphicEncryption/Array2d.swift index 51d479b5..f3f31f7b 100644 --- a/Sources/HomomorphicEncryption/Array2d.swift +++ b/Sources/HomomorphicEncryption/Array2d.swift @@ -119,7 +119,7 @@ extension Array2d { throw HeError.invalidRotationParameter(range: range, columnCount: data.count) } - let effectiveStep = step.toRemainder(range) + let effectiveStep = step.toRemainder(range, variableTime: true) for index in stride(from: 0, to: data.count, by: range) { let replacement = data[index + effectiveStep..: Equatable, Sendable { +public struct Modulus: Equatable, Sendable { /// The maximum valid modulus value. - @usableFromInline static var max: T { + public static var max: T { ReduceModulus.max } @@ -31,15 +30,16 @@ struct Modulus: Equatable, Sendable { /// `ceil(2^k / modulus) - 2^(2 * T.bitWidth)` for /// `k = 2 * T.bitWidth + ceil(log2(modulus)`. @usableFromInline let divisionModulus: DivisionModulus - @usableFromInline let modulus: T + /// The modulus, `p`. + public let modulus: T /// Initializes a ``Modulus``. /// - Parameters: - /// - modulus: Modulus. + /// - modulus: Modulus. Must be less than ``Modulus/max``. /// - variableTime: Must be `true`, indicating `modulus` is leaked through timing. /// - Warning: Leaks `modulus` through timing. @inlinable - init(modulus: T, variableTime: Bool) { + public init(modulus: T, variableTime: Bool) { precondition(variableTime) self.singleWordModulus = ReduceModulus( modulus: modulus, @@ -57,21 +57,35 @@ struct Modulus: Equatable, Sendable { self.modulus = modulus } + /// Performs modular reduction with modulus `p`. + /// - Parameter x: Value to reduce. + /// - Returns: `x mod p` in `[0, p).` @inlinable - func reduce(_ x: T) -> T { + public func reduce(_ x: T) -> T { singleWordModulus.reduce(x) } + /// Performs modular reduction with modulus `p`. + /// - Parameter x: Value to reduce. + /// - Returns: `x mod p` in `[0, p).` @inlinable - func reduce(_ x: T.DoubleWidth) -> T { + public func reduce(_ x: T.SignedScalar) -> T { + singleWordModulus.reduce(x) + } + + /// Performs modular reduction with modulus `p`. + /// - Parameter x: Value to reduce. + /// - Returns: `x mod p` in `[0, p).` + @inlinable + public func reduce(_ x: T.DoubleWidth) -> T { doubleWordModulus.reduce(x) } /// Performs modular reduction with modulus `p`. /// - Parameter x: Must be `< p^2`. - /// - Returns: `x mod p` for `p`. + /// - Returns: `x mod p` in `[0, p).` @inlinable - func reduceProduct(_ x: T.DoubleWidth) -> T { + public func reduceProduct(_ x: T.DoubleWidth) -> T { reduceProductModulus.reduceProduct(x) } @@ -81,7 +95,7 @@ struct Modulus: Equatable, Sendable { /// - y: Must be `< p`. /// - Returns: `x * y mod p`. @inlinable - func multiplyMod(_ x: T, _ y: T) -> T { + public func multiplyMod(_ x: T, _ y: T) -> T { precondition(x < modulus) precondition(y < modulus) let product = x.multipliedFullWidth(by: y) @@ -92,7 +106,7 @@ struct Modulus: Equatable, Sendable { /// - Parameter dividend: Number to divide. /// - Returns: `dividend / modulus`, rounded down to the next integer. @inlinable - func dividingFloor(by dividend: T.DoubleWidth) -> T.DoubleWidth { + public func dividingFloor(by dividend: T.DoubleWidth) -> T.DoubleWidth { divisionModulus.dividingFloor(by: dividend) } } @@ -153,7 +167,7 @@ struct ReduceModulus: Equatable, Sendable { /// The maximum valid modulus value. @usableFromInline static var max: T { - // Constrained by `reduceProduct` + // Constrained by `reduceProduct` and `reduce(_ x: T.SignedScalar)` (T(1) << (T.bitWidth - 2)) - 1 } @@ -161,7 +175,12 @@ struct ReduceModulus: Equatable, Sendable { @usableFromInline let shift: Int /// Barrett factor. @usableFromInline let factor: T.DoubleWidth + /// The modulus, `p`. @usableFromInline let modulus: T + /// `modulus.previousPowerOfTwo`. + @usableFromInline let modulusPreviousPowerOfTwo: T + /// `round(2^{log2(p) - 1) * 2^{T.bitWidth} / p)`. + @usableFromInline let signedFactor: T.SignedScalar /// Performs pre-computation for fast modular reduction. /// - Parameters: @@ -174,12 +193,25 @@ struct ReduceModulus: Equatable, Sendable { precondition(variableTime) precondition(modulus <= Self.max) self.modulus = modulus + self.modulusPreviousPowerOfTwo = modulus.previousPowerOfTwo switch bound { case .SingleWord: self.shift = T.bitWidth - let numerator = T.DoubleWidth(1) << shift - // 2^T.bitwidth // p - self.factor = numerator / T.DoubleWidth(modulus) + // floor(2^T.bitwidth / p) + self.factor = T.DoubleWidth((high: 1, low: 0)) / T.DoubleWidth(modulus) + if modulus.isPowerOfTwo { + // This should actually be `T.SignedScalar.max + 1`, but this works too. + // See `reduce(_ x: T.SignedScalar)` for more information. + self.signedFactor = T.SignedScalar.max + } else { + // We compute `round(2^{log2(p) - 1} * 2^{T.bitWidth} / p)` by noting + // `2^{log2(p)} = q.previousPowerOfTwo`, and `round(x/p) = floor(x + floor(p/2) / p)`. + let numerator = T.DoubleWidth((high: modulus.previousPowerOfTwo >> 1, low: T.Magnitude(modulus) >> 1)) + // Guaranteed to fit into single word, since `2^{log2(p) - 1) / p < 1/2` for `p` not a power of 2, + // which implies `signedFactor < 2^{T.bitWidth} / 2` + self.signedFactor = T.SignedScalar((numerator / T.DoubleWidth(modulus)).low) + } + case .DoubleWord: self.shift = 2 * T.bitWidth self.factor = if modulus.isPowerOfTwo { @@ -188,11 +220,14 @@ struct ReduceModulus: Equatable, Sendable { // floor(2^{2 * t} / p) == floor((2^{2 * t} - 1) / p) for p not a power of two T.DoubleWidth.max / T.DoubleWidth(modulus) } + self.signedFactor = 0 // Unused + case .ModulusSquared: let reduceModulusAlpha = T.bitWidth - 2 self.shift = modulus.significantBitCount + reduceModulusAlpha let numerator = T.DoubleWidth(1) << shift self.factor = numerator / T.DoubleWidth(modulus) + self.signedFactor = 0 // Unused } } @@ -217,8 +252,43 @@ struct ReduceModulus: Equatable, Sendable { return z.subtractIfExceeds(modulus) } + /// Returns `x mod p` in `[0, p)` for signed integer `x`. + /// + /// Requires the modulus `p` to satisfy `p < 2^{T.bitWidth - 2}`. + /// See Algorithm 5 from . + /// The proof of Lemma 4 still goes through for odd moduli `q < 2^{T.bitWidth - 2}`, by using the bound + /// `floor(2^k \beta / q) >= 2^k \beta / q - 1`, rather than + /// `floor(2^k \beta / q) >= 2^k \beta / q - 1/2`. + /// For a `q` a power of two, the `signedFactor` is off by one (`2^{T.bitWidth} - 1` instead of `2^{T.bitWidth}`), + /// so we provide a quick proof of correctness in this case. + /// Using notation from the proof of Lemma 4 of , and assuming `a >= 0`, + /// we have `2^k = q / 2`, so `v = floor(2^k β / q) = β / 2`. Since we are using `v - 1` instead of `v`, we have + /// `r = a - q * floor(a * (v - 1) / (2^k β))`. Using `floor(x) >= x - 1`, we have + /// `<= a - q * (a * (v - 1) / (2^k β)) + q`. Using `v = β / 2` and `2^k = q / 2`, we have + /// `= a - q * (a β / 2 - a) / (β q / 2) + q` + /// `= a - a + q a / (β q / 2) + q` + /// `= a / (β / 2) + q` + /// `< 1 + q` for `a < β / 2`. + /// Since we use `v - 1` instead of `v`, the result can only be larger than as Algorithm 5 is written. + /// Hence, the lower bound `r > -1` from the proof of Lemma 4 still holds. + /// Since `r < q + 1`, `r > -1`, and `r` is integral, we have `r in [0, q]`. + /// The final `subtractIfExceeds` ensures `r in [0, q - 1]`. + /// + /// The proof follows analagously for `a < 0`. + /// + /// - Parameter x: Value to reduce. + /// - Returns: `x mod p` in `[0, p)`. + @inlinable + func reduce(_ x: T.SignedScalar) -> T { + assert(shift == T.bitWidth) + var t = x.multiplyHigh(signedFactor) >> (modulus.log2 - 1) + t = t &* T.SignedScalar(modulus) + return T(x &- t).subtractIfExceeds(modulus) + } + /// Returns `x mod p`. /// + /// Requires modulus `p < 2^{T.bitWidth - 1}`. /// Useful when `x >= p^2`, otherwise use `` reduceProduct``. /// Proof of correctness: /// Let `t = T.bitWidth` @@ -234,7 +304,7 @@ struct ReduceModulus: Equatable, Sendable { /// Adding (3) and (4) yields /// `0 <= x - q * p < x * p / 2^{2 * t} + p < 2 * p`. /// - /// Note, the bound on `p < 2^63` comes from `2 * p < T.max` + /// Note, the bound on `p < 2^{t - 1}` comes from `2 * p < 2^t` @inlinable func reduce(_ x: T.DoubleWidth) -> T { assert(shift == x.bitWidth) diff --git a/Sources/HomomorphicEncryption/Scalar.swift b/Sources/HomomorphicEncryption/Scalar.swift index bf712fd2..468c7fe0 100644 --- a/Sources/HomomorphicEncryption/Scalar.swift +++ b/Sources/HomomorphicEncryption/Scalar.swift @@ -61,13 +61,22 @@ extension SignedScalarType { return Self(bitPattern: result) } + /// Computes the high `Self.bitWidth` bits of `self * rhs`. + /// - Parameter rhs: Multiplicand. + /// - Returns: the high `Self.bitWidth` bits of `self * rhs`. + @inlinable + public func multiplyHigh(_ rhs: Self) -> Self { + multipliedFullWidth(by: rhs).high + } + /// Constant-time centered-to-remainder conversion. /// - Parameter modulus: Modulus. - /// - Returns: Given `self` in `[-floor(modulus/2), floor(modulus-1)/2]`, returns `self % modulus` in `[0, - /// modulus)`. - /// - Throws: Error upon failure to encode. + /// - Returns: Given `self` in `[-floor(modulus/2), floor((modulus-1)/2)]`, + /// returns `self % modulus` in `[0, modulus)`. @inlinable - public func centeredToRemainder(modulus: some ScalarType) throws -> Self.UnsignedScalar { + public func centeredToRemainder(modulus: some ScalarType) -> Self.UnsignedScalar { + assert(self <= (Self(modulus) - 1) / 2) + assert(self >= -Self(modulus) / 2) let condition = Self.UnsignedScalar(bitPattern: self >> (bitWidth - 1)) let thenValue = Self.UnsignedScalar(bitPattern: self &+ Self(bitPattern: Self.UnsignedScalar(modulus))) let elseValue = Self.UnsignedScalar(bitPattern: self) @@ -198,7 +207,7 @@ extension FixedWidthInteger { } extension ScalarType { - /// Computes the high bits `Self.bitWidth` of `self * rhs`. + /// Computes the high `Self.bitWidth` bits of `self * rhs`. /// - Parameter rhs: Multiplicand. /// - Returns: the high `Self.bitWidth` bits of `self * rhs`. @inlinable @@ -390,6 +399,14 @@ extension FixedWidthInteger { return 1 &<< ((self &- 1).log2 &+ 1) } + /// The next power of two greater than or equal to this value. + /// + /// This value must be positive. + @inlinable public var previousPowerOfTwo: Self { + precondition(self > 0) + return 1 &<< (Self.bitWidth &- 1 - leadingZeroBitCount) + } + /// Computes a modular multiplication. /// /// Is not constant time. Use `ReduceModulus` for a constant-time alternative, which is also faster when the modulus @@ -629,7 +646,7 @@ extension ScalarType { /// - Parameter modulus: Modulus. /// - Returns: Given `self` in `[0,modulus)`, returns `self % modulus` in `[-floor(modulus/2), floor(modulus-1)/2]`. @inlinable - func remainderToCentered(modulus: Self) -> Self.SignedScalar { + public func remainderToCentered(modulus: Self) -> Self.SignedScalar { let condition = constantTimeGreaterThan((modulus - 1) >> 1) let thenValue = Self.SignedScalar(self) - Self.SignedScalar(bitPattern: modulus) let elseValue = Self.SignedScalar(bitPattern: self) diff --git a/Sources/HomomorphicEncryption/Util.swift b/Sources/HomomorphicEncryption/Util.swift index d8f29958..9355ab2a 100644 --- a/Sources/HomomorphicEncryption/Util.swift +++ b/Sources/HomomorphicEncryption/Util.swift @@ -43,7 +43,8 @@ extension Sequence { extension FixedWidthInteger { // not a constant time operation @inlinable - func toRemainder(_ mod: Self) -> Self { + func toRemainder(_ mod: Self, variableTime: Bool) -> Self { + precondition(variableTime) precondition(mod > 0) var result = self % mod if result < 0 { diff --git a/Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift b/Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift index 26504e29..441ae3ee 100644 --- a/Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift +++ b/Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift @@ -137,19 +137,61 @@ public struct PlaintextMatrix: Equatable, /// - context: Parameter context to encode the data with. /// - dimensions: Plaintext matrix dimensions. /// - packing: The packing with which the data is stored. - /// - values: The data values to store in the plaintext matrix; stored in row-major format. + /// - signedValues: The signed data values to store in the plaintext matrix; stored in row-major format. + /// - reduce: If true, values are reduced into the correct range before encoding. /// - Throws: Error upon failure to create the plaitnext matrix. @inlinable public init( context: Context, dimensions: MatrixDimensions, packing: MatrixPacking, - values: [some ScalarType]) throws + signedValues: [Scheme.SignedScalar], + reduce: Bool = false) throws where Format == Coeff + { + let modulus = Modulus(modulus: context.plaintextModulus, variableTime: true) + let centeredValues = if reduce { + signedValues.map { value in + Scheme.Scalar(modulus.reduce(value)) + } + } else { + signedValues.map { value in + Scheme.Scalar(value.centeredToRemainder(modulus: modulus.modulus)) + } + } + try self.init( + context: context, + dimensions: dimensions, + packing: packing, + values: centeredValues, + reduce: false) + } + + /// Creates a new plaintext matrix. + /// - Parameters: + /// - context: Parameter context to encode the data with. + /// - dimensions: Plaintext matrix dimensions. + /// - packing: The packing with which the data is stored. + /// - values: The data values to store in the plaintext matrix; stored in row-major format. + /// - reduce: If true, values are reduced into the correct range before encoding. + /// - Throws: Error upon failure to create the plaitnext matrix. + @inlinable + init( + context: Context, + dimensions: MatrixDimensions, + packing: MatrixPacking, + values: [Scheme.Scalar], + reduce: Bool = false) throws where Format == Coeff { guard values.count == dimensions.count, !values.isEmpty else { throw PnnsError.wrongEncodingValuesCount(got: values.count, expected: values.count) } + var values = values + if reduce { + let modulus = Modulus(modulus: context.plaintextModulus, variableTime: true) + values = values.map { value in modulus.reduce(value) } + } + switch packing { case .denseColumn: let plaintexts = try PlaintextMatrix.denseColumnPlaintexts( @@ -421,7 +463,7 @@ public struct PlaintextMatrix: Equatable, /// - Returns: The stored data values in row-major format. /// - Throws: Error upon failure to unpack the matrix. @inlinable - func unpack() throws -> [V] where Format == Coeff { + func unpack() throws -> [Scheme.Scalar] where Format == Coeff { switch packing { case .denseColumn: return try unpackDenseColumn() @@ -433,6 +475,17 @@ public struct PlaintextMatrix: Equatable, } } + /// Unpacks the plaintext matrix into signed values. + /// - Returns: The stored data values in row-major format. + /// - Throws: Error upon failure to unpack the matrix. + @inlinable + func unpack() throws -> [Scheme.SignedScalar] where Format == Coeff { + let unsigned: [Scheme.Scalar] = try unpack() + return unsigned.map { unsigned in + unsigned.remainderToCentered(modulus: context.plaintextModulus) + } + } + /// Unpacks a plaintext matrix with `denseColumn` packing. /// - Returns: The stored data values in row-major format. /// - Throws: Error upon failure to unpack the matrix. diff --git a/Tests/HomomorphicEncryptionTests/ScalarTests.swift b/Tests/HomomorphicEncryptionTests/ScalarTests.swift index c535f2fc..ae72a422 100644 --- a/Tests/HomomorphicEncryptionTests/ScalarTests.swift +++ b/Tests/HomomorphicEncryptionTests/ScalarTests.swift @@ -115,6 +115,16 @@ class ScalarTests: XCTestCase { XCTAssertEqual(4.nextPowerOfTwo, 4) } + func testPreviousPowerOfTwo() { + XCTAssertEqual(1.previousPowerOfTwo, 1) + XCTAssertEqual(2.previousPowerOfTwo, 2) + XCTAssertEqual(3.previousPowerOfTwo, 2) + XCTAssertEqual(4.previousPowerOfTwo, 4) + XCTAssertEqual(63.previousPowerOfTwo, 32) + XCTAssertEqual(64.previousPowerOfTwo, 64) + XCTAssertEqual(65.previousPowerOfTwo, 64) + } + func testNextMultiple() { XCTAssertEqual(0.nextMultiple(of: 0, variableTime: true), 0) XCTAssertEqual(0.nextMultiple(of: 7, variableTime: true), 0) @@ -266,6 +276,36 @@ class ScalarTests: XCTestCase { runReduceSingleWordTest(UInt64.self) } + func testReduceSignedSingleWord() { + func runReduceSingleWordTest(_: T.Type) { + func slowSignedReduce(of x: T.SignedScalar, mod modulus: T) -> T { + let remainder = x.quotientAndRemainder(dividingBy: T.SignedScalar(modulus)).remainder + return T(remainder < 0 ? remainder + T.SignedScalar(modulus) : remainder) + } + + for shift in 2..( + modulus: p, + bound: ReduceModulus.InputBound.SingleWord, + variableTime: true) + let pSigned = T.SignedScalar(p) + let x = T.SignedScalar.random(in: -pSigned / 2..(_: T.Type) { for shift in 2..(modulus: T) throws { - var remainders = try (-modulus / 2...((modulus - 1) / 2)).map { v in - let remainder = try v.centeredToRemainder(modulus: T.UnsignedScalar(modulus)) + func runTest(modulus: T) { + var remainders = (-modulus / 2...((modulus - 1) / 2)).map { v in + let remainder = v.centeredToRemainder(modulus: T.UnsignedScalar(modulus)) let centeredRoundtrip = remainder.remainderToCentered(modulus: T.UnsignedScalar(modulus)) XCTAssertEqual(centeredRoundtrip, v) return remainder @@ -458,18 +498,18 @@ class ScalarTests: XCTestCase { let expected: [T.UnsignedScalar] = Array(0..(modulus: T) throws { + func runTest(modulus: T) { let unsignedModulus = T.UnsignedScalar(modulus) let low: T = -modulus / 2 let high: T = (modulus - 1) / 2 let signedValues: [T] = [low, low + 1, low + 2, -1, 0, 1, high - 2, high - 1, high] - let signedRoundTrip = try signedValues.map { value in - try value.centeredToRemainder(modulus: unsignedModulus) + let signedRoundTrip = signedValues.map { value in + value.centeredToRemainder(modulus: unsignedModulus) }.map { value in value.remainderToCentered(modulus: unsignedModulus) } @@ -477,14 +517,14 @@ class ScalarTests: XCTestCase { let mid: T.UnsignedScalar = (unsignedModulus - 1) / 2 let values: [T.UnsignedScalar] = [0, 1, 2, mid - 1, mid, mid + 1, unsignedModulus - 2, unsignedModulus - 1] - let roundTrip = try values.map { value in + let roundTrip = values.map { value in value.remainderToCentered(modulus: unsignedModulus) }.map { value in - try value.centeredToRemainder(modulus: unsignedModulus) + value.centeredToRemainder(modulus: unsignedModulus) } XCTAssertEqual(values, roundTrip) } - try runTest(modulus: Int32(1 << 31 - 63)) - try runTest(modulus: Int64(1 << 62)) + runTest(modulus: Int32(1 << 31 - 63)) + runTest(modulus: Int64(1 << 62)) } } diff --git a/Tests/HomomorphicEncryptionTests/UtilTests.swift b/Tests/HomomorphicEncryptionTests/UtilTests.swift index cba70d44..600fa84e 100644 --- a/Tests/HomomorphicEncryptionTests/UtilTests.swift +++ b/Tests/HomomorphicEncryptionTests/UtilTests.swift @@ -38,12 +38,12 @@ class UtilTests: XCTestCase { } func testToRemainder() { - XCTAssertEqual((-8).toRemainder(7), 6) - XCTAssertEqual((-7).toRemainder(7), 0) - XCTAssertEqual((-6).toRemainder(7), 1) - XCTAssertEqual(6.toRemainder(7), 6) - XCTAssertEqual(7.toRemainder(7), 0) - XCTAssertEqual(8.toRemainder(7), 1) + XCTAssertEqual((-8).toRemainder(7, variableTime: true), 6) + XCTAssertEqual((-7).toRemainder(7, variableTime: true), 0) + XCTAssertEqual((-6).toRemainder(7, variableTime: true), 1) + XCTAssertEqual(6.toRemainder(7, variableTime: true), 6) + XCTAssertEqual(7.toRemainder(7, variableTime: true), 0) + XCTAssertEqual(8.toRemainder(7, variableTime: true), 1) } func testProduct() { diff --git a/Tests/PrivateNearestNeighborsSearchTests/PlaintextMatrixTests.swift b/Tests/PrivateNearestNeighborsSearchTests/PlaintextMatrixTests.swift index b6c1b8b7..77acea8e 100644 --- a/Tests/PrivateNearestNeighborsSearchTests/PlaintextMatrixTests.swift +++ b/Tests/PrivateNearestNeighborsSearchTests/PlaintextMatrixTests.swift @@ -155,6 +155,47 @@ final class PlaintextMatrixTests: XCTestCase { let decoded: [Scheme.Scalar] = try plaintext.decode(format: .simd) XCTAssertEqual(decoded, expected) } + + // Test signed encoding/decoding + switch packing { + case .diagonal: // TODO: test .diagonal once implemented + break + default: + let signedValues: [Scheme.SignedScalar] = try plaintextMatrix.unpack() + let signedMatrix = try PlaintextMatrix( + context: context, + dimensions: dimensions, + packing: packing, + signedValues: signedValues) + let signedRoundtrip: [Scheme.SignedScalar] = try signedMatrix.unpack() + XCTAssertEqual(signedRoundtrip, signedValues) + + // Test modular reduction + let largerValues = encodeValues.flatMap { $0 }.map { $0 + t } + let largerSignedValues = signedValues.enumerated().map { index, value in + if index.isMultiple(of: 2) { + value + Scheme.SignedScalar(t) + } else { + value - Scheme.SignedScalar(t) + } + } + + let largerPlaintextMatrix = try PlaintextMatrix( + context: context, + dimensions: dimensions, + packing: packing, + values: largerValues, + reduce: true) + XCTAssertEqual(largerPlaintextMatrix, plaintextMatrix) + + let largerSignedMatrix = try PlaintextMatrix( + context: context, + dimensions: dimensions, + packing: packing, + signedValues: largerSignedValues, + reduce: true) + XCTAssertEqual(largerSignedMatrix, signedMatrix) + } } func testPlaintextMatrixDenseColumn() throws {