Skip to content

Commit

Permalink
Add modular reduction and signed encoding to PlaintextMatrix. (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
fboemer authored Aug 22, 2024
1 parent 8458194 commit 7affb01
Show file tree
Hide file tree
Showing 9 changed files with 269 additions and 47 deletions.
2 changes: 1 addition & 1 deletion Sources/HomomorphicEncryption/Array2d.swift
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ extension Array2d {
throw HeError.invalidRotationParameter(range: range, columnCount: data.count)
}

let effectiveStep = step.toRemainder(range)
let effectiveStep = step.toRemainder(range, variableTime: true)
for index in stride(from: 0, to: data.count, by: range) {
let replacement = data[index + effectiveStep..<index + range] + data[index..<index + effectiveStep]
data.replaceSubrange(index..<index + range, with: replacement)
Expand Down
2 changes: 1 addition & 1 deletion Sources/HomomorphicEncryption/Encoding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ extension Context {
guard bounds.contains(Scheme.Scalar.SignedScalar(value)) else {
throw HeError.encodingDataOutOfBounds(for: bounds)
}
return try Scheme.Scalar(value.centeredToRemainder(modulus: plaintextModulus))
return Scheme.Scalar(value.centeredToRemainder(modulus: plaintextModulus))
}
return try encode(values: centeredValues, format: format)
}
Expand Down
104 changes: 87 additions & 17 deletions Sources/HomomorphicEncryption/Modulus.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
/// Stores pre-computed data for efficient modular operations.
/// - Warning: The operations may leak the modulus through timing or other side channels. So this struct should only be
/// used for public moduli.
@usableFromInline
struct Modulus<T: ScalarType>: Equatable, Sendable {
public struct Modulus<T: ScalarType>: Equatable, Sendable {
/// The maximum valid modulus value.
@usableFromInline static var max: T {
public static var max: T {
ReduceModulus.max
}

Expand All @@ -31,15 +30,16 @@ struct Modulus<T: ScalarType>: Equatable, Sendable {
/// `ceil(2^k / modulus) - 2^(2 * T.bitWidth)` for
/// `k = 2 * T.bitWidth + ceil(log2(modulus)`.
@usableFromInline let divisionModulus: DivisionModulus<T>
@usableFromInline let modulus: T
/// The modulus, `p`.
public let modulus: T

/// Initializes a ``Modulus``.
/// - Parameters:
/// - modulus: Modulus.
/// - modulus: Modulus. Must be less than ``Modulus/max``.
/// - variableTime: Must be `true`, indicating `modulus` is leaked through timing.
/// - Warning: Leaks `modulus` through timing.
@inlinable
init(modulus: T, variableTime: Bool) {
public init(modulus: T, variableTime: Bool) {
precondition(variableTime)
self.singleWordModulus = ReduceModulus(
modulus: modulus,
Expand All @@ -57,21 +57,35 @@ struct Modulus<T: ScalarType>: Equatable, Sendable {
self.modulus = modulus
}

/// Performs modular reduction with modulus `p`.
/// - Parameter x: Value to reduce.
/// - Returns: `x mod p` in `[0, p).`
@inlinable
func reduce(_ x: T) -> T {
public func reduce(_ x: T) -> T {
singleWordModulus.reduce(x)
}

/// Performs modular reduction with modulus `p`.
/// - Parameter x: Value to reduce.
/// - Returns: `x mod p` in `[0, p).`
@inlinable
func reduce(_ x: T.DoubleWidth) -> T {
public func reduce(_ x: T.SignedScalar) -> T {
singleWordModulus.reduce(x)
}

/// Performs modular reduction with modulus `p`.
/// - Parameter x: Value to reduce.
/// - Returns: `x mod p` in `[0, p).`
@inlinable
public func reduce(_ x: T.DoubleWidth) -> T {
doubleWordModulus.reduce(x)
}

/// Performs modular reduction with modulus `p`.
/// - Parameter x: Must be `< p^2`.
/// - Returns: `x mod p` for `p`.
/// - Returns: `x mod p` in `[0, p).`
@inlinable
func reduceProduct(_ x: T.DoubleWidth) -> T {
public func reduceProduct(_ x: T.DoubleWidth) -> T {
reduceProductModulus.reduceProduct(x)
}

Expand All @@ -81,7 +95,7 @@ struct Modulus<T: ScalarType>: Equatable, Sendable {
/// - y: Must be `< p`.
/// - Returns: `x * y mod p`.
@inlinable
func multiplyMod(_ x: T, _ y: T) -> T {
public func multiplyMod(_ x: T, _ y: T) -> T {
precondition(x < modulus)
precondition(y < modulus)
let product = x.multipliedFullWidth(by: y)
Expand All @@ -92,7 +106,7 @@ struct Modulus<T: ScalarType>: Equatable, Sendable {
/// - Parameter dividend: Number to divide.
/// - Returns: `dividend / modulus`, rounded down to the next integer.
@inlinable
func dividingFloor(by dividend: T.DoubleWidth) -> T.DoubleWidth {
public func dividingFloor(by dividend: T.DoubleWidth) -> T.DoubleWidth {
divisionModulus.dividingFloor(by: dividend)
}
}
Expand Down Expand Up @@ -153,15 +167,20 @@ struct ReduceModulus<T: ScalarType>: Equatable, Sendable {

/// The maximum valid modulus value.
@usableFromInline static var max: T {
// Constrained by `reduceProduct`
// Constrained by `reduceProduct` and `reduce(_ x: T.SignedScalar)`
(T(1) << (T.bitWidth - 2)) - 1
}

/// Power used in computed Barrett factor.
@usableFromInline let shift: Int
/// Barrett factor.
@usableFromInline let factor: T.DoubleWidth
/// The modulus, `p`.
@usableFromInline let modulus: T
/// `modulus.previousPowerOfTwo`.
@usableFromInline let modulusPreviousPowerOfTwo: T
/// `round(2^{log2(p) - 1) * 2^{T.bitWidth} / p)`.
@usableFromInline let signedFactor: T.SignedScalar

/// Performs pre-computation for fast modular reduction.
/// - Parameters:
Expand All @@ -174,12 +193,25 @@ struct ReduceModulus<T: ScalarType>: Equatable, Sendable {
precondition(variableTime)
precondition(modulus <= Self.max)
self.modulus = modulus
self.modulusPreviousPowerOfTwo = modulus.previousPowerOfTwo
switch bound {
case .SingleWord:
self.shift = T.bitWidth
let numerator = T.DoubleWidth(1) << shift
// 2^T.bitwidth // p
self.factor = numerator / T.DoubleWidth(modulus)
// floor(2^T.bitwidth / p)
self.factor = T.DoubleWidth((high: 1, low: 0)) / T.DoubleWidth(modulus)
if modulus.isPowerOfTwo {
// This should actually be `T.SignedScalar.max + 1`, but this works too.
// See `reduce(_ x: T.SignedScalar)` for more information.
self.signedFactor = T.SignedScalar.max
} else {
// We compute `round(2^{log2(p) - 1} * 2^{T.bitWidth} / p)` by noting
// `2^{log2(p)} = q.previousPowerOfTwo`, and `round(x/p) = floor(x + floor(p/2) / p)`.
let numerator = T.DoubleWidth((high: modulus.previousPowerOfTwo >> 1, low: T.Magnitude(modulus) >> 1))
// Guaranteed to fit into single word, since `2^{log2(p) - 1) / p < 1/2` for `p` not a power of 2,
// which implies `signedFactor < 2^{T.bitWidth} / 2`
self.signedFactor = T.SignedScalar((numerator / T.DoubleWidth(modulus)).low)
}

case .DoubleWord:
self.shift = 2 * T.bitWidth
self.factor = if modulus.isPowerOfTwo {
Expand All @@ -188,11 +220,14 @@ struct ReduceModulus<T: ScalarType>: Equatable, Sendable {
// floor(2^{2 * t} / p) == floor((2^{2 * t} - 1) / p) for p not a power of two
T.DoubleWidth.max / T.DoubleWidth(modulus)
}
self.signedFactor = 0 // Unused

case .ModulusSquared:
let reduceModulusAlpha = T.bitWidth - 2
self.shift = modulus.significantBitCount + reduceModulusAlpha
let numerator = T.DoubleWidth(1) << shift
self.factor = numerator / T.DoubleWidth(modulus)
self.signedFactor = 0 // Unused
}
}

Expand All @@ -217,8 +252,43 @@ struct ReduceModulus<T: ScalarType>: Equatable, Sendable {
return z.subtractIfExceeds(modulus)
}

/// Returns `x mod p` in `[0, p)` for signed integer `x`.
///
/// Requires the modulus `p` to satisfy `p < 2^{T.bitWidth - 2}`.
/// See Algorithm 5 from <https://eprint.iacr.org/2018/039.pdf>.
/// The proof of Lemma 4 still goes through for odd moduli `q < 2^{T.bitWidth - 2}`, by using the bound
/// `floor(2^k \beta / q) >= 2^k \beta / q - 1`, rather than
/// `floor(2^k \beta / q) >= 2^k \beta / q - 1/2`.
/// For a `q` a power of two, the `signedFactor` is off by one (`2^{T.bitWidth} - 1` instead of `2^{T.bitWidth}`),
/// so we provide a quick proof of correctness in this case.
/// Using notation from the proof of Lemma 4 of <https://eprint.iacr.org/2018/039.pdf>, and assuming `a >= 0`,
/// we have `2^k = q / 2`, so `v = floor(2^k β / q) = β / 2`. Since we are using `v - 1` instead of `v`, we have
/// `r = a - q * floor(a * (v - 1) / (2^k β))`. Using `floor(x) >= x - 1`, we have
/// `<= a - q * (a * (v - 1) / (2^k β)) + q`. Using `v = β / 2` and `2^k = q / 2`, we have
/// `= a - q * (a β / 2 - a) / (β q / 2) + q`
/// `= a - a + q a / (β q / 2) + q`
/// `= a / (β / 2) + q`
/// `< 1 + q` for `a < β / 2`.
/// Since we use `v - 1` instead of `v`, the result can only be larger than as Algorithm 5 is written.
/// Hence, the lower bound `r > -1` from the proof of Lemma 4 still holds.
/// Since `r < q + 1`, `r > -1`, and `r` is integral, we have `r in [0, q]`.
/// The final `subtractIfExceeds` ensures `r in [0, q - 1]`.
///
/// The proof follows analagously for `a < 0`.
///
/// - Parameter x: Value to reduce.
/// - Returns: `x mod p` in `[0, p)`.
@inlinable
func reduce(_ x: T.SignedScalar) -> T {
assert(shift == T.bitWidth)
var t = x.multiplyHigh(signedFactor) >> (modulus.log2 - 1)
t = t &* T.SignedScalar(modulus)
return T(x &- t).subtractIfExceeds(modulus)
}

/// Returns `x mod p`.
///
/// Requires modulus `p < 2^{T.bitWidth - 1}`.
/// Useful when `x >= p^2`, otherwise use `` reduceProduct``.
/// Proof of correctness:
/// Let `t = T.bitWidth`
Expand All @@ -234,7 +304,7 @@ struct ReduceModulus<T: ScalarType>: Equatable, Sendable {
/// Adding (3) and (4) yields
/// `0 <= x - q * p < x * p / 2^{2 * t} + p < 2 * p`.
///
/// Note, the bound on `p < 2^63` comes from `2 * p < T.max`
/// Note, the bound on `p < 2^{t - 1}` comes from `2 * p < 2^t`
@inlinable
func reduce(_ x: T.DoubleWidth) -> T {
assert(shift == x.bitWidth)
Expand Down
29 changes: 23 additions & 6 deletions Sources/HomomorphicEncryption/Scalar.swift
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,22 @@ extension SignedScalarType {
return Self(bitPattern: result)
}

/// Computes the high `Self.bitWidth` bits of `self * rhs`.
/// - Parameter rhs: Multiplicand.
/// - Returns: the high `Self.bitWidth` bits of `self * rhs`.
@inlinable
public func multiplyHigh(_ rhs: Self) -> Self {
multipliedFullWidth(by: rhs).high
}

/// Constant-time centered-to-remainder conversion.
/// - Parameter modulus: Modulus.
/// - Returns: Given `self` in `[-floor(modulus/2), floor(modulus-1)/2]`, returns `self % modulus` in `[0,
/// modulus)`.
/// - Throws: Error upon failure to encode.
/// - Returns: Given `self` in `[-floor(modulus/2), floor((modulus-1)/2)]`,
/// returns `self % modulus` in `[0, modulus)`.
@inlinable
public func centeredToRemainder(modulus: some ScalarType) throws -> Self.UnsignedScalar {
public func centeredToRemainder(modulus: some ScalarType) -> Self.UnsignedScalar {
assert(self <= (Self(modulus) - 1) / 2)
assert(self >= -Self(modulus) / 2)
let condition = Self.UnsignedScalar(bitPattern: self >> (bitWidth - 1))
let thenValue = Self.UnsignedScalar(bitPattern: self &+ Self(bitPattern: Self.UnsignedScalar(modulus)))
let elseValue = Self.UnsignedScalar(bitPattern: self)
Expand Down Expand Up @@ -198,7 +207,7 @@ extension FixedWidthInteger {
}

extension ScalarType {
/// Computes the high bits `Self.bitWidth` of `self * rhs`.
/// Computes the high `Self.bitWidth` bits of `self * rhs`.
/// - Parameter rhs: Multiplicand.
/// - Returns: the high `Self.bitWidth` bits of `self * rhs`.
@inlinable
Expand Down Expand Up @@ -390,6 +399,14 @@ extension FixedWidthInteger {
return 1 &<< ((self &- 1).log2 &+ 1)
}

/// The next power of two greater than or equal to this value.
///
/// This value must be positive.
@inlinable public var previousPowerOfTwo: Self {
precondition(self > 0)
return 1 &<< (Self.bitWidth &- 1 - leadingZeroBitCount)
}

/// Computes a modular multiplication.
///
/// Is not constant time. Use `ReduceModulus` for a constant-time alternative, which is also faster when the modulus
Expand Down Expand Up @@ -629,7 +646,7 @@ extension ScalarType {
/// - Parameter modulus: Modulus.
/// - Returns: Given `self` in `[0,modulus)`, returns `self % modulus` in `[-floor(modulus/2), floor(modulus-1)/2]`.
@inlinable
func remainderToCentered(modulus: Self) -> Self.SignedScalar {
public func remainderToCentered(modulus: Self) -> Self.SignedScalar {
let condition = constantTimeGreaterThan((modulus - 1) >> 1)
let thenValue = Self.SignedScalar(self) - Self.SignedScalar(bitPattern: modulus)
let elseValue = Self.SignedScalar(bitPattern: self)
Expand Down
3 changes: 2 additions & 1 deletion Sources/HomomorphicEncryption/Util.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ extension Sequence {
extension FixedWidthInteger {
// not a constant time operation
@inlinable
func toRemainder(_ mod: Self) -> Self {
func toRemainder(_ mod: Self, variableTime: Bool) -> Self {
precondition(variableTime)
precondition(mod > 0)
var result = self % mod
if result < 0 {
Expand Down
59 changes: 56 additions & 3 deletions Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift
Original file line number Diff line number Diff line change
Expand Up @@ -137,19 +137,61 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
/// - context: Parameter context to encode the data with.
/// - dimensions: Plaintext matrix dimensions.
/// - packing: The packing with which the data is stored.
/// - values: The data values to store in the plaintext matrix; stored in row-major format.
/// - signedValues: The signed data values to store in the plaintext matrix; stored in row-major format.
/// - reduce: If true, values are reduced into the correct range before encoding.
/// - Throws: Error upon failure to create the plaitnext matrix.
@inlinable
public init(
context: Context<Scheme>,
dimensions: MatrixDimensions,
packing: MatrixPacking,
values: [some ScalarType]) throws
signedValues: [Scheme.SignedScalar],
reduce: Bool = false) throws where Format == Coeff
{
let modulus = Modulus(modulus: context.plaintextModulus, variableTime: true)
let centeredValues = if reduce {
signedValues.map { value in
Scheme.Scalar(modulus.reduce(value))
}
} else {
signedValues.map { value in
Scheme.Scalar(value.centeredToRemainder(modulus: modulus.modulus))
}
}
try self.init(
context: context,
dimensions: dimensions,
packing: packing,
values: centeredValues,
reduce: false)
}

/// Creates a new plaintext matrix.
/// - Parameters:
/// - context: Parameter context to encode the data with.
/// - dimensions: Plaintext matrix dimensions.
/// - packing: The packing with which the data is stored.
/// - values: The data values to store in the plaintext matrix; stored in row-major format.
/// - reduce: If true, values are reduced into the correct range before encoding.
/// - Throws: Error upon failure to create the plaitnext matrix.
@inlinable
init(
context: Context<Scheme>,
dimensions: MatrixDimensions,
packing: MatrixPacking,
values: [Scheme.Scalar],
reduce: Bool = false) throws
where Format == Coeff
{
guard values.count == dimensions.count, !values.isEmpty else {
throw PnnsError.wrongEncodingValuesCount(got: values.count, expected: values.count)
}
var values = values
if reduce {
let modulus = Modulus(modulus: context.plaintextModulus, variableTime: true)
values = values.map { value in modulus.reduce(value) }
}

switch packing {
case .denseColumn:
let plaintexts = try PlaintextMatrix.denseColumnPlaintexts(
Expand Down Expand Up @@ -421,7 +463,7 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
/// - Returns: The stored data values in row-major format.
/// - Throws: Error upon failure to unpack the matrix.
@inlinable
func unpack<V: ScalarType>() throws -> [V] where Format == Coeff {
func unpack() throws -> [Scheme.Scalar] where Format == Coeff {
switch packing {
case .denseColumn:
return try unpackDenseColumn()
Expand All @@ -433,6 +475,17 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
}
}

/// Unpacks the plaintext matrix into signed values.
/// - Returns: The stored data values in row-major format.
/// - Throws: Error upon failure to unpack the matrix.
@inlinable
func unpack() throws -> [Scheme.SignedScalar] where Format == Coeff {
let unsigned: [Scheme.Scalar] = try unpack()
return unsigned.map { unsigned in
unsigned.remainderToCentered(modulus: context.plaintextModulus)
}
}

/// Unpacks a plaintext matrix with `denseColumn` packing.
/// - Returns: The stored data values in row-major format.
/// - Throws: Error upon failure to unpack the matrix.
Expand Down
Loading

0 comments on commit 7affb01

Please sign in to comment.