diff --git a/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift b/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift index 9340f07b..d49701fe 100644 --- a/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift +++ b/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift @@ -26,7 +26,7 @@ extension Bfv { @inlinable // swiftlint:disable:next missing_docs attributes - public static func encode(context: Context>, values: [some ScalarType], + public static func encode(context: Context>, values: some Collection, format: EncodeFormat) throws -> CoeffPlaintext { try context.encode(values: values, format: format) @@ -34,7 +34,7 @@ extension Bfv { @inlinable // swiftlint:disable:next missing_docs attributes - public static func encode(context: Context>, signedValues: [some SignedScalarType], + public static func encode(context: Context>, signedValues: some Collection, format: EncodeFormat) throws -> CoeffPlaintext { try context.encode(signedValues: signedValues, format: format) @@ -42,7 +42,7 @@ extension Bfv { @inlinable // swiftlint:disable:next missing_docs attributes - public static func encode(context: Context>, values: [some ScalarType], format: EncodeFormat, + public static func encode(context: Context>, values: some Collection, format: EncodeFormat, moduliCount: Int?) throws -> EvalPlaintext { let coeffPlaintext = try Self.encode(context: context, values: values, format: format) @@ -53,7 +53,7 @@ extension Bfv { // swiftlint:disable:next missing_docs attributes public static func encode( context: Context>, - signedValues: [some SignedScalarType], + signedValues: some Collection, format: EncodeFormat, moduliCount: Int?) throws -> EvalPlaintext { diff --git a/Sources/HomomorphicEncryption/Encoding.swift b/Sources/HomomorphicEncryption/Encoding.swift index 738e8a6f..0e5055cc 100644 --- a/Sources/HomomorphicEncryption/Encoding.swift +++ b/Sources/HomomorphicEncryption/Encoding.swift @@ -25,7 +25,9 @@ extension Context { /// - Returns: The plaintext encoding `values`. /// - Throws: Error upon failure to encode. @inlinable - public func encode(values: [some ScalarType], format: EncodeFormat) throws -> Plaintext { + public func encode(values: some Collection, + format: EncodeFormat) throws -> Plaintext + { try validDataForEncoding(values: values) switch format { case .coefficient: @@ -44,7 +46,9 @@ extension Context { /// - Returns: The plaintext encoding `signedValues`. /// - Throws: Error upon failure to encode. @inlinable - public func encode(signedValues: [some SignedScalarType], format: EncodeFormat) throws -> Plaintext { + public func encode(signedValues: some Collection, + format: EncodeFormat) throws -> Plaintext + { let signedModulus = Scheme.Scalar.SignedScalar(plaintextModulus) let bounds = -(signedModulus >> 1)...((signedModulus - 1) >> 1) let centeredValues = try signedValues.map { value in @@ -60,13 +64,14 @@ extension Context { /// - Parameters: /// - values: Values to encode. /// - format: Encoding format. - /// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext with all the + /// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext context with + /// all the /// moduli. /// - Returns: The plaintext encoding `values`. /// - Throws: Error upon failure to encode. @inlinable public func encode( - values: [some ScalarType], + values: some Collection, format: EncodeFormat, moduliCount: Int? = nil) throws -> Plaintext { @@ -77,12 +82,13 @@ extension Context { /// - Parameters: /// - signedValues: Signed values to encode. /// - format: Encoding format. - /// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext with all the + /// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext context with + /// all the /// moduli. /// - Returns: The plaintext encoding `signedValues`. /// - Throws: Error upon failure to encode. @inlinable - public func encode(signedValues: [some SignedScalarType], format: EncodeFormat, + public func encode(signedValues: some Collection, format: EncodeFormat, moduliCount: Int? = nil) throws -> Plaintext { try Scheme.encode(context: self, signedValues: signedValues, format: format, moduliCount: moduliCount) @@ -135,7 +141,7 @@ extension Context { } @inlinable - func validDataForEncoding(values: [some ScalarType]) throws { + func validDataForEncoding(values: some Collection) throws { guard values.count <= encryptionParameters.polyDegree else { throw HeError.encodingDataCountExceedsLimit(count: values.count, limit: encryptionParameters.polyDegree) } @@ -155,7 +161,7 @@ extension Context { /// `f(x) = values_0 + values_1 x + ... values_{N_1} x^{N-1}`, padding /// with 0 coefficients if fewer than `N` values are provided. @inlinable - func encodeCoefficient(values: [some ScalarType]) throws + func encodeCoefficient(values: some Collection) throws -> Plaintext { if values.isEmpty { @@ -163,7 +169,7 @@ extension Context { } let polyDegree = plaintextContext.degree var array: Array2d = Array2d(array: Array2d( - data: values, + data: Array(values), rowCount: 1, columnCount: values.count)) array.resizeColumn(newColumnCount: polyDegree, defaultValue: Scheme.Scalar(0)) @@ -214,12 +220,12 @@ extension Context { } @inlinable - func encodeSimd(values: [some ScalarType]) throws -> Plaintext { + func encodeSimd(values: some Collection) throws -> Plaintext { guard !simdEncodingMatrix.isEmpty else { throw HeError.simdEncodingNotSupported(for: encryptionParameters) } let polyDegree = encryptionParameters.polyDegree var array = Array2d.zero(rowCount: 1, columnCount: polyDegree) - for index in 0..(context: plaintextContext, data: array) let coeffPoly = try poly.inverseNtt() diff --git a/Sources/HomomorphicEncryption/HeScheme.swift b/Sources/HomomorphicEncryption/HeScheme.swift index 84c11b40..175598a4 100644 --- a/Sources/HomomorphicEncryption/HeScheme.swift +++ b/Sources/HomomorphicEncryption/HeScheme.swift @@ -172,7 +172,8 @@ public protocol HeScheme { /// - Throws: Error upon failure to encode. /// - seealso: ``Context/encode(values:format:)`` for an alternative API. /// - seealso: ``HeScheme/encode(context:signedValues:format:)`` to encode signed values. - static func encode(context: Context, values: [some ScalarType], format: EncodeFormat) throws -> CoeffPlaintext + static func encode(context: Context, values: some Collection, format: EncodeFormat) throws + -> CoeffPlaintext /// Encodes signed values into a plaintext with coefficient format. /// @@ -184,7 +185,7 @@ public protocol HeScheme { /// - Throws: Error upon failure to encode. /// - seealso: ``Context/encode(signedValues:format:)`` for an alternative API. /// - seealso: ``HeScheme/encode(context:values:format)`` to encode unsigned values. - static func encode(context: Context, signedValues: [some SignedScalarType], format: EncodeFormat) throws + static func encode(context: Context, signedValues: some Collection, format: EncodeFormat) throws -> CoeffPlaintext /// Encodes values into a plaintext with evaluation format. @@ -194,13 +195,14 @@ public protocol HeScheme { /// - context: Context for HE computation. /// - values: Values to encode. /// - format: Encoding format. - /// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext with all the + /// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext context with + /// all the /// moduli. /// - Returns: A plaintext encoding `values`. /// - Throws: Error upon failure to encode. /// - seealso: ``Context/encode(values:format:moduliCount:)`` for an alternative API. /// - seealso: ``HeScheme/encode(context:signedValues:format:moduliCount:)`` to encode signed values. - static func encode(context: Context, values: [some ScalarType], format: EncodeFormat, + static func encode(context: Context, values: some Collection, format: EncodeFormat, moduliCount: Int?) throws -> EvalPlaintext /// Encodes signed values into a plaintext with evaluation format. @@ -210,13 +212,14 @@ public protocol HeScheme { /// - context: Context for HE computation. /// - signedValues: Signed values to encode. /// - format: Encoding format. - /// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext with all the + /// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext context with + /// all the /// moduli. /// - Returns: A plaintext encoding `signedValues`. /// - Throws: Error upon failure to encode. /// - seealso: ``Context/encode(signedValues:format:moduliCount:)`` for an alternative API. /// - seealso: ``HeScheme/encode(context:values:format:moduliCount:)`` to encode unsigned values. - static func encode(context: Context, signedValues: [some SignedScalarType], format: EncodeFormat, + static func encode(context: Context, signedValues: some Collection, format: EncodeFormat, moduliCount: Int?) throws -> EvalPlaintext /// Decodes a plaintext in ``Coeff`` format. diff --git a/Sources/HomomorphicEncryption/NoOpScheme.swift b/Sources/HomomorphicEncryption/NoOpScheme.swift index bc4ee300..15ecf5cb 100644 --- a/Sources/HomomorphicEncryption/NoOpScheme.swift +++ b/Sources/HomomorphicEncryption/NoOpScheme.swift @@ -56,19 +56,19 @@ public enum NoOpScheme: HeScheme { return SimdEncodingDimensions(rowCount: 2, columnCount: parameters.polyDegree / 2) } - public static func encode(context: Context, values: [some ScalarType], + public static func encode(context: Context, values: some Collection, format: EncodeFormat) throws -> CoeffPlaintext { try context.encode(values: values, format: format) } - public static func encode(context: Context, signedValues: [some SignedScalarType], + public static func encode(context: Context, signedValues: some Collection, format: EncodeFormat) throws -> CoeffPlaintext { try context.encode(signedValues: signedValues, format: format) } - public static func encode(context: Context, values: [some ScalarType], + public static func encode(context: Context, values: some Collection, format: EncodeFormat, moduliCount _: Int?) throws -> EvalPlaintext { let coeffPlaintext = try Self.encode(context: context, values: values, format: format) @@ -77,7 +77,7 @@ public enum NoOpScheme: HeScheme { public static func encode( context: Context, - signedValues: [some SignedScalarType], + signedValues: some Collection, format: EncodeFormat, moduliCount _: Int?) throws -> EvalPlaintext { diff --git a/Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift b/Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift index 9212993b..92d06fe3 100644 --- a/Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift +++ b/Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift @@ -270,8 +270,8 @@ public struct PlaintextMatrix: Equatable, /// - Returns: The plaintexts for `denseColumn` packing. /// - Throws: Error upon plaintext to compute the plaintexts. @inlinable - static func denseColumnPlaintexts(context: Context, dimensions: MatrixDimensions, - values: [V]) throws -> [Scheme.CoeffPlaintext] + static func denseColumnPlaintexts(context: Context, dimensions: MatrixDimensions, + values: [Scheme.Scalar]) throws -> [Scheme.CoeffPlaintext] { let degree = context.degree guard let simdColumnCount = context.simdDimensions?.columnCount else { @@ -286,7 +286,7 @@ public struct PlaintextMatrix: Equatable, var plaintexts: [Scheme.CoeffPlaintext] = [] plaintexts.reserveCapacity(expectedPlaintextCount) - var packedValues: [V] = [] + var packedValues: [Scheme.Scalar] = [] packedValues.reserveCapacity(degree) for colIndex in 0..: Equatable, if packedValues.count < simdColumnCount, (simdColumnCount + 1...degree).contains(nextColumnCount) { // Next data column fits in next SIMD row; pad 0s to this SIMD row let padCount = (context.degree - packedValues.count) % simdColumnCount - packedValues += [V](repeating: 0, count: padCount) + packedValues += [Scheme.Scalar](repeating: 0, count: padCount) } else if nextColumnCount > degree { // Next data column requires new plaintext try plaintexts.append(context.encode(values: packedValues, format: .simd)) @@ -326,10 +326,10 @@ public struct PlaintextMatrix: Equatable, /// - Returns: The plaintexts for `denseRow` packing. /// - Throws: Error upon failure to compute the plaintexts. @inlinable - static func denseRowPlaintexts( + static func denseRowPlaintexts( context: Context, dimensions: MatrixDimensions, - values: [V]) throws -> [Plaintext] + values: [Scheme.Scalar]) throws -> [Plaintext] { let encryptionParameters = context.encryptionParameters guard let simdDimensions = context.simdDimensions else { @@ -350,9 +350,9 @@ public struct PlaintextMatrix: Equatable, // Pad number of columns to next power of two let padColCount = dimensions.columnCount.nextPowerOfTwo - dimensions.columnCount - let padValues = [V](repeating: 0, count: padColCount) + let padValues = [Scheme.Scalar](repeating: 0, count: padColCount) - var packedValues: [V] = [] + var packedValues: [Scheme.Scalar] = [] packedValues.reserveCapacity(context.degree) var valuesIdx = 0 for _ in 0..: Equatable, /// - Returns: The plaintexts for diagonal packing. /// - Throws: Error upon failure to compute the plaintexts. @inlinable - static func diagonalPlaintexts( + static func diagonalPlaintexts( context: Context, dimensions: MatrixDimensions, packing: MatrixPacking, - values: [V]) throws -> [Scheme.CoeffPlaintext] + values: [Scheme.Scalar]) throws -> [Scheme.CoeffPlaintext] { let encryptionParameters = context.encryptionParameters guard let simdDimensions = context.simdDimensions else { @@ -424,7 +424,7 @@ public struct PlaintextMatrix: Equatable, let data = Array2d(data: values, rowCount: dimensions.rowCount, columnCount: dimensions.columnCount) // Transposed from original shape, with extra zero columns. // Encode diagonals - var packedValues = Array2d.zero( + var packedValues = Array2d.zero( rowCount: dimensions.columnCount.nextPowerOfTwo, columnCount: dimensions.rowCount) for rowIndex in 0..: Equatable, chunk[chunk.startIndex.. = try TestUtils.getTestContext() let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0.. = try TestUtils.getTestContext() let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0.. = try context.encode(values: data, - format: .coefficient) + let plaintext: Plaintext = try context.encode(values: data, format: .coefficient) let secretKey = try context.generateSecretKey() let expandedQueryCount = degree