Skip to content

Commit

Permalink
Make encoding more generic (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
fboemer authored Sep 3, 2024
1 parent 8119bc3 commit 9a7b256
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 50 deletions.
8 changes: 4 additions & 4 deletions Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,23 @@ extension Bfv {

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func encode(context: Context<Bfv<T>>, values: [some ScalarType],
public static func encode(context: Context<Bfv<T>>, values: some Collection<Scalar>,
format: EncodeFormat) throws -> CoeffPlaintext
{
try context.encode(values: values, format: format)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func encode(context: Context<Bfv<T>>, signedValues: [some SignedScalarType],
public static func encode(context: Context<Bfv<T>>, signedValues: some Collection<Scalar.SignedScalar>,
format: EncodeFormat) throws -> CoeffPlaintext
{
try context.encode(signedValues: signedValues, format: format)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func encode(context: Context<Bfv<T>>, values: [some ScalarType], format: EncodeFormat,
public static func encode(context: Context<Bfv<T>>, values: some Collection<Scalar>, format: EncodeFormat,
moduliCount: Int?) throws -> EvalPlaintext
{
let coeffPlaintext = try Self.encode(context: context, values: values, format: format)
Expand All @@ -53,7 +53,7 @@ extension Bfv {
// swiftlint:disable:next missing_docs attributes
public static func encode(
context: Context<Bfv<T>>,
signedValues: [some SignedScalarType],
signedValues: some Collection<Scalar.SignedScalar>,
format: EncodeFormat,
moduliCount: Int?) throws -> EvalPlaintext
{
Expand Down
30 changes: 18 additions & 12 deletions Sources/HomomorphicEncryption/Encoding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ extension Context {
/// - Returns: The plaintext encoding `values`.
/// - Throws: Error upon failure to encode.
@inlinable
public func encode(values: [some ScalarType], format: EncodeFormat) throws -> Plaintext<Scheme, Coeff> {
public func encode(values: some Collection<Scheme.Scalar>,
format: EncodeFormat) throws -> Plaintext<Scheme, Coeff>
{
try validDataForEncoding(values: values)
switch format {
case .coefficient:
Expand All @@ -44,7 +46,9 @@ extension Context {
/// - Returns: The plaintext encoding `signedValues`.
/// - Throws: Error upon failure to encode.
@inlinable
public func encode(signedValues: [some SignedScalarType], format: EncodeFormat) throws -> Plaintext<Scheme, Coeff> {
public func encode(signedValues: some Collection<Scheme.SignedScalar>,
format: EncodeFormat) throws -> Plaintext<Scheme, Coeff>
{
let signedModulus = Scheme.Scalar.SignedScalar(plaintextModulus)
let bounds = -(signedModulus >> 1)...((signedModulus - 1) >> 1)
let centeredValues = try signedValues.map { value in
Expand All @@ -60,13 +64,14 @@ extension Context {
/// - Parameters:
/// - values: Values to encode.
/// - format: Encoding format.
/// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext with all the
/// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext context with
/// all the
/// moduli.
/// - Returns: The plaintext encoding `values`.
/// - Throws: Error upon failure to encode.
@inlinable
public func encode(
values: [some ScalarType],
values: some Collection<Scheme.Scalar>,
format: EncodeFormat,
moduliCount: Int? = nil) throws -> Plaintext<Scheme, Eval>
{
Expand All @@ -77,12 +82,13 @@ extension Context {
/// - Parameters:
/// - signedValues: Signed values to encode.
/// - format: Encoding format.
/// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext with all the
/// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext context with
/// all the
/// moduli.
/// - Returns: The plaintext encoding `signedValues`.
/// - Throws: Error upon failure to encode.
@inlinable
public func encode(signedValues: [some SignedScalarType], format: EncodeFormat,
public func encode(signedValues: some Collection<Scheme.SignedScalar>, format: EncodeFormat,
moduliCount: Int? = nil) throws -> Plaintext<Scheme, Eval>
{
try Scheme.encode(context: self, signedValues: signedValues, format: format, moduliCount: moduliCount)
Expand Down Expand Up @@ -135,7 +141,7 @@ extension Context {
}

@inlinable
func validDataForEncoding(values: [some ScalarType]) throws {
func validDataForEncoding(values: some Collection<Scheme.Scalar>) throws {
guard values.count <= encryptionParameters.polyDegree else {
throw HeError.encodingDataCountExceedsLimit(count: values.count, limit: encryptionParameters.polyDegree)
}
Expand All @@ -155,15 +161,15 @@ extension Context {
/// `f(x) = values_0 + values_1 x + ... values_{N_1} x^{N-1}`, padding
/// with 0 coefficients if fewer than `N` values are provided.
@inlinable
func encodeCoefficient(values: [some ScalarType]) throws
func encodeCoefficient(values: some Collection<Scheme.Scalar>) throws
-> Plaintext<Scheme, Coeff>
{
if values.isEmpty {
return Plaintext<Scheme, Coeff>(context: self, poly: PolyRq.zero(context: plaintextContext))
}
let polyDegree = plaintextContext.degree
var array: Array2d<Scheme.Scalar> = Array2d(array: Array2d(
data: values,
data: Array(values),
rowCount: 1,
columnCount: values.count))
array.resizeColumn(newColumnCount: polyDegree, defaultValue: Scheme.Scalar(0))
Expand Down Expand Up @@ -214,12 +220,12 @@ extension Context {
}

@inlinable
func encodeSimd(values: [some ScalarType]) throws -> Plaintext<Scheme, Coeff> {
func encodeSimd(values: some Collection<Scheme.Scalar>) throws -> Plaintext<Scheme, Coeff> {
guard !simdEncodingMatrix.isEmpty else { throw HeError.simdEncodingNotSupported(for: encryptionParameters) }
let polyDegree = encryptionParameters.polyDegree
var array = Array2d<Scheme.Scalar>.zero(rowCount: 1, columnCount: polyDegree)
for index in 0..<values.count {
array[0, simdEncodingMatrix[index]] = Scheme.Scalar(values[index])
for (index, value) in values.enumerated() {
array[0, simdEncodingMatrix[index]] = Scheme.Scalar(value)
}
let poly = PolyRq<_, Eval>(context: plaintextContext, data: array)
let coeffPoly = try poly.inverseNtt()
Expand Down
15 changes: 9 additions & 6 deletions Sources/HomomorphicEncryption/HeScheme.swift
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ public protocol HeScheme {
/// - Throws: Error upon failure to encode.
/// - seealso: ``Context/encode(values:format:)`` for an alternative API.
/// - seealso: ``HeScheme/encode(context:signedValues:format:)`` to encode signed values.
static func encode(context: Context<Self>, values: [some ScalarType], format: EncodeFormat) throws -> CoeffPlaintext
static func encode(context: Context<Self>, values: some Collection<Scalar>, format: EncodeFormat) throws
-> CoeffPlaintext

/// Encodes signed values into a plaintext with coefficient format.
///
Expand All @@ -184,7 +185,7 @@ public protocol HeScheme {
/// - Throws: Error upon failure to encode.
/// - seealso: ``Context/encode(signedValues:format:)`` for an alternative API.
/// - seealso: ``HeScheme/encode(context:values:format)`` to encode unsigned values.
static func encode(context: Context<Self>, signedValues: [some SignedScalarType], format: EncodeFormat) throws
static func encode(context: Context<Self>, signedValues: some Collection<SignedScalar>, format: EncodeFormat) throws
-> CoeffPlaintext

/// Encodes values into a plaintext with evaluation format.
Expand All @@ -194,13 +195,14 @@ public protocol HeScheme {
/// - context: Context for HE computation.
/// - values: Values to encode.
/// - format: Encoding format.
/// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext with all the
/// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext context with
/// all the
/// moduli.
/// - Returns: A plaintext encoding `values`.
/// - Throws: Error upon failure to encode.
/// - seealso: ``Context/encode(values:format:moduliCount:)`` for an alternative API.
/// - seealso: ``HeScheme/encode(context:signedValues:format:moduliCount:)`` to encode signed values.
static func encode(context: Context<Self>, values: [some ScalarType], format: EncodeFormat,
static func encode(context: Context<Self>, values: some Collection<Scalar>, format: EncodeFormat,
moduliCount: Int?) throws -> EvalPlaintext

/// Encodes signed values into a plaintext with evaluation format.
Expand All @@ -210,13 +212,14 @@ public protocol HeScheme {
/// - context: Context for HE computation.
/// - signedValues: Signed values to encode.
/// - format: Encoding format.
/// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext with all the
/// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext context with
/// all the
/// moduli.
/// - Returns: A plaintext encoding `signedValues`.
/// - Throws: Error upon failure to encode.
/// - seealso: ``Context/encode(signedValues:format:moduliCount:)`` for an alternative API.
/// - seealso: ``HeScheme/encode(context:values:format:moduliCount:)`` to encode unsigned values.
static func encode(context: Context<Self>, signedValues: [some SignedScalarType], format: EncodeFormat,
static func encode(context: Context<Self>, signedValues: some Collection<Scalar.SignedScalar>, format: EncodeFormat,
moduliCount: Int?) throws -> EvalPlaintext

/// Decodes a plaintext in ``Coeff`` format.
Expand Down
8 changes: 4 additions & 4 deletions Sources/HomomorphicEncryption/NoOpScheme.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,19 @@ public enum NoOpScheme: HeScheme {
return SimdEncodingDimensions(rowCount: 2, columnCount: parameters.polyDegree / 2)
}

public static func encode(context: Context<NoOpScheme>, values: [some ScalarType],
public static func encode(context: Context<NoOpScheme>, values: some Collection<Scalar>,
format: EncodeFormat) throws -> CoeffPlaintext
{
try context.encode(values: values, format: format)
}

public static func encode(context: Context<NoOpScheme>, signedValues: [some SignedScalarType],
public static func encode(context: Context<NoOpScheme>, signedValues: some Collection<SignedScalar>,
format: EncodeFormat) throws -> CoeffPlaintext
{
try context.encode(signedValues: signedValues, format: format)
}

public static func encode(context: Context<NoOpScheme>, values: [some ScalarType],
public static func encode(context: Context<NoOpScheme>, values: some Collection<Scalar>,
format: EncodeFormat, moduliCount _: Int?) throws -> EvalPlaintext
{
let coeffPlaintext = try Self.encode(context: context, values: values, format: format)
Expand All @@ -77,7 +77,7 @@ public enum NoOpScheme: HeScheme {

public static func encode(
context: Context<NoOpScheme>,
signedValues: [some SignedScalarType],
signedValues: some Collection<SignedScalar>,
format: EncodeFormat,
moduliCount _: Int?) throws -> EvalPlaintext
{
Expand Down
24 changes: 12 additions & 12 deletions Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
/// - Returns: The plaintexts for `denseColumn` packing.
/// - Throws: Error upon plaintext to compute the plaintexts.
@inlinable
static func denseColumnPlaintexts<V: ScalarType>(context: Context<Scheme>, dimensions: MatrixDimensions,
values: [V]) throws -> [Scheme.CoeffPlaintext]
static func denseColumnPlaintexts(context: Context<Scheme>, dimensions: MatrixDimensions,
values: [Scheme.Scalar]) throws -> [Scheme.CoeffPlaintext]
{
let degree = context.degree
guard let simdColumnCount = context.simdDimensions?.columnCount else {
Expand All @@ -286,7 +286,7 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
var plaintexts: [Scheme.CoeffPlaintext] = []
plaintexts.reserveCapacity(expectedPlaintextCount)

var packedValues: [V] = []
var packedValues: [Scheme.Scalar] = []
packedValues.reserveCapacity(degree)
for colIndex in 0..<dimensions.columnCount {
for rowIndex in 0..<dimensions.rowCount {
Expand All @@ -303,7 +303,7 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
if packedValues.count < simdColumnCount, (simdColumnCount + 1...degree).contains(nextColumnCount) {
// Next data column fits in next SIMD row; pad 0s to this SIMD row
let padCount = (context.degree - packedValues.count) % simdColumnCount
packedValues += [V](repeating: 0, count: padCount)
packedValues += [Scheme.Scalar](repeating: 0, count: padCount)
} else if nextColumnCount > degree {
// Next data column requires new plaintext
try plaintexts.append(context.encode(values: packedValues, format: .simd))
Expand All @@ -326,10 +326,10 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
/// - Returns: The plaintexts for `denseRow` packing.
/// - Throws: Error upon failure to compute the plaintexts.
@inlinable
static func denseRowPlaintexts<V: ScalarType>(
static func denseRowPlaintexts(
context: Context<Scheme>,
dimensions: MatrixDimensions,
values: [V]) throws -> [Plaintext<Scheme, Coeff>]
values: [Scheme.Scalar]) throws -> [Plaintext<Scheme, Coeff>]
{
let encryptionParameters = context.encryptionParameters
guard let simdDimensions = context.simdDimensions else {
Expand All @@ -350,9 +350,9 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,

// Pad number of columns to next power of two
let padColCount = dimensions.columnCount.nextPowerOfTwo - dimensions.columnCount
let padValues = [V](repeating: 0, count: padColCount)
let padValues = [Scheme.Scalar](repeating: 0, count: padColCount)

var packedValues: [V] = []
var packedValues: [Scheme.Scalar] = []
packedValues.reserveCapacity(context.degree)
var valuesIdx = 0
for _ in 0..<dimensions.rowCount {
Expand Down Expand Up @@ -400,11 +400,11 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
/// - Returns: The plaintexts for diagonal packing.
/// - Throws: Error upon failure to compute the plaintexts.
@inlinable
static func diagonalPlaintexts<V: ScalarType>(
static func diagonalPlaintexts(
context: Context<Scheme>,
dimensions: MatrixDimensions,
packing: MatrixPacking,
values: [V]) throws -> [Scheme.CoeffPlaintext]
values: [Scheme.Scalar]) throws -> [Scheme.CoeffPlaintext]
{
let encryptionParameters = context.encryptionParameters
guard let simdDimensions = context.simdDimensions else {
Expand All @@ -424,7 +424,7 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
let data = Array2d(data: values, rowCount: dimensions.rowCount, columnCount: dimensions.columnCount)
// Transposed from original shape, with extra zero columns.
// Encode diagonals
var packedValues = Array2d<V>.zero(
var packedValues = Array2d<Scheme.Scalar>.zero(
rowCount: dimensions.columnCount.nextPowerOfTwo,
columnCount: dimensions.rowCount)
for rowIndex in 0..<packedValues.rowCount {
Expand Down Expand Up @@ -459,7 +459,7 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
chunk[chunk.startIndex..<middle].rotate(toStartAt: middle - rotationStep)
chunk[middle...].rotate(toStartAt: chunk.endIndex - rotationStep)
}
let plaintext = try context.encode(values: Array(chunk), format: .simd)
let plaintext = try context.encode(values: chunk, format: .simd)
plaintexts.append(plaintext)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,13 @@ class ConversionTests: XCTestCase {
let context: Context<Scheme> = try TestUtils.getTestContext()
let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0..<context.plaintextModulus)
do { // CoeffPlaintext
let plaintext: Scheme.CoeffPlaintext = try context.encode(values: values,
format: format)
let plaintext: Scheme.CoeffPlaintext = try context.encode(values: values, format: format)
let proto = plaintext.serialize().proto()
let deserialized: Scheme.CoeffPlaintext = try Plaintext(deserialize: proto.native(), context: context)
XCTAssertEqual(deserialized, plaintext)
}
do { // EvalPlaintext
let plaintext: Scheme.EvalPlaintext = try context.encode(values: values,
format: format)
let plaintext: Scheme.EvalPlaintext = try context.encode(values: values, format: format)
let proto = plaintext.serialize().proto()
let deserialized: Scheme.EvalPlaintext = try Plaintext(deserialize: proto.native(), context: context)
XCTAssertEqual(deserialized, plaintext)
Expand Down
3 changes: 1 addition & 2 deletions Tests/HomomorphicEncryptionTests/HeAPITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,7 @@ class HeAPITests: XCTestCase {
XCTAssert(evalCiphertext.isTransparent())
XCTAssert(canonicalCiphertext.isTransparent())

let zeroPlaintext: Scheme.CoeffPlaintext = try context.encode(values: zeros,
format: .coefficient)
let zeroPlaintext: Scheme.CoeffPlaintext = try context.encode(values: zeros, format: .coefficient)
let nonTransparentZero = try zeroPlaintext.encrypt(using: testEnv.secretKey)
if Scheme.self != NoOpScheme.self {
XCTAssertFalse(nonTransparentZero.isTransparent())
Expand Down
6 changes: 2 additions & 4 deletions Tests/HomomorphicEncryptionTests/SerializationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,13 @@ class SerializationTests: XCTestCase {
let context: Context<Scheme> = try TestUtils.getTestContext()
let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0..<context.plaintextModulus)
do { // CoeffPlaintext
let plaintext: Scheme.CoeffPlaintext = try context.encode(values: values,
format: format)
let plaintext: Scheme.CoeffPlaintext = try context.encode(values: values, format: format)
let serialized = plaintext.serialize()
let deserialized: Scheme.CoeffPlaintext = try Plaintext(deserialize: serialized, context: context)
XCTAssertEqual(deserialized, plaintext)
}
do { // EvalPlaintext
let plaintext: Scheme.EvalPlaintext = try context.encode(values: values,
format: format)
let plaintext: Scheme.EvalPlaintext = try context.encode(values: values, format: format)
let serialized = plaintext.serialize()
let deserialized: Scheme.EvalPlaintext = try Plaintext(deserialize: serialized, context: context)
XCTAssertEqual(deserialized, plaintext)
Expand Down
Loading

0 comments on commit 9a7b256

Please sign in to comment.