From 97d9d46eca48844591ae239c47ddcf68e0ceb2ee Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Fri, 4 Oct 2024 15:48:49 -0400 Subject: [PATCH] Refactor `FileDataPart` --- FirebaseVertexAI/Sources/ModelContent.swift | 44 ++++--------- .../Sources/Types/Internal/InternalPart.swift | 11 ++++ .../Sources/Types/Public/Part.swift | 27 +++----- FirebaseVertexAI/Tests/Unit/PartTests.swift | 65 ++++++++++++++----- .../Tests/Unit/VertexAIAPITests.swift | 2 +- 5 files changed, 85 insertions(+), 64 deletions(-) diff --git a/FirebaseVertexAI/Sources/ModelContent.swift b/FirebaseVertexAI/Sources/ModelContent.swift index daed456f6cf3..766712af15f6 100644 --- a/FirebaseVertexAI/Sources/ModelContent.swift +++ b/FirebaseVertexAI/Sources/ModelContent.swift @@ -93,9 +93,9 @@ public struct ModelContent: Equatable, Sendable { case let .inlineData(mimetype, data): convertedParts.append(InlineDataPart(data: data, mimeType: mimetype)) case let .fileData(mimetype, uri): - convertedParts.append(FileDataPart(fileData: FileData(mimeType: mimetype, uri: uri))) + convertedParts.append(FileDataPart(uri: uri, mimeType: mimetype)) case let .functionCall(functionCall): - convertedParts.append(FunctionCallPart(functionCall: functionCall)) + convertedParts.append(FunctionCallPart(functionCall)) case let .functionResponse(functionResponse): convertedParts.append(FunctionResponsePart(functionResponse: functionResponse)) } @@ -120,7 +120,7 @@ public struct ModelContent: Equatable, Sendable { convertedParts.append(.inlineData(mimetype: inlineData.mimeType, inlineData.data)) case let fileDataPart as FileDataPart: let fileData = fileDataPart.fileData - convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.uri)) + convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.fileURI)) case let functionCallPart as FunctionCallPart: convertedParts.append(.functionCall(functionCallPart.functionCall)) case let functionResponsePart as FunctionResponsePart: @@ -145,7 +145,7 @@ public struct ModelContent: Equatable, Sendable { convertedParts.append(.inlineData(mimetype: inlineData.mimeType, inlineData.data)) case let fileDataPart as FileDataPart: let fileData = fileDataPart.fileData - convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.uri)) + convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.fileURI)) case let functionCallPart as FunctionCallPart: convertedParts.append(.functionCall(functionCallPart.functionCall)) case let functionResponsePart as FunctionResponsePart: @@ -192,31 +192,15 @@ extension ModelContent.InternalPart: Codable { case functionResponse } - enum InlineDataKeys: String, CodingKey { - case mimeType = "mime_type" - case bytes = "data" - } - public func encode(to encoder: Encoder) throws { var container = encoder.container(keyedBy: CodingKeys.self) switch self { - case let .text(a0): - try container.encode(a0, forKey: .text) + case let .text(text): + try container.encode(text, forKey: .text) case let .inlineData(mimetype, bytes): - var inlineDataContainer = container.nestedContainer( - keyedBy: InlineDataKeys.self, - forKey: .inlineData - ) - try inlineDataContainer.encode(mimetype, forKey: .mimeType) - try inlineDataContainer.encode(bytes, forKey: .bytes) + try container.encode(InlineData(data: bytes, mimeType: mimetype), forKey: .inlineData) case let .fileData(mimetype: mimetype, url): -// var fileDataContainer = container.nestedContainer( -// keyedBy: FileDataKeys.self, -// forKey: .fileData -// ) - try container.encode(FileData(mimeType: mimetype, uri: url), forKey: .fileData) -// try fileDataContainer.encode(mimetype, forKey: .mimeType) -// try fileDataContainer.encode(url, forKey: .uri) + try container.encode(FileData(fileURI: url, mimeType: mimetype), forKey: .fileData) case let .functionCall(functionCall): try container.encode(functionCall, forKey: .functionCall) case let .functionResponse(functionResponse): @@ -229,13 +213,11 @@ extension ModelContent.InternalPart: Codable { if values.contains(.text) { self = try .text(values.decode(String.self, forKey: .text)) } else if values.contains(.inlineData) { - let dataContainer = try values.nestedContainer( - keyedBy: InlineDataKeys.self, - forKey: .inlineData - ) - let mimetype = try dataContainer.decode(String.self, forKey: .mimeType) - let bytes = try dataContainer.decode(Data.self, forKey: .bytes) - self = .inlineData(mimetype: mimetype, bytes) + let inlineData = try values.decode(InlineData.self, forKey: .inlineData) + self = .inlineData(mimetype: inlineData.mimeType, inlineData.data) + } else if values.contains(.fileData) { + let fileData = try values.decode(FileData.self, forKey: .fileData) + self = .fileData(mimetype: fileData.mimeType, uri: fileData.fileURI) } else if values.contains(.functionCall) { self = try .functionCall(values.decode(FunctionCall.self, forKey: .functionCall)) } else if values.contains(.functionResponse) { diff --git a/FirebaseVertexAI/Sources/Types/Internal/InternalPart.swift b/FirebaseVertexAI/Sources/Types/Internal/InternalPart.swift index 57c33d7e27cb..34c21eb67c06 100644 --- a/FirebaseVertexAI/Sources/Types/Internal/InternalPart.swift +++ b/FirebaseVertexAI/Sources/Types/Internal/InternalPart.swift @@ -25,6 +25,17 @@ struct InlineData: Codable, Equatable, Sendable { } } +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +struct FileData: Codable, Equatable, Sendable { + let fileURI: String + let mimeType: String + + init(fileURI: String, mimeType: String) { + self.fileURI = fileURI + self.mimeType = mimeType + } +} + @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) struct ErrorPart: Part, Error { let error: Error diff --git a/FirebaseVertexAI/Sources/Types/Public/Part.swift b/FirebaseVertexAI/Sources/Types/Public/Part.swift index 8601116a88fb..69d083c4339c 100644 --- a/FirebaseVertexAI/Sources/Types/Public/Part.swift +++ b/FirebaseVertexAI/Sources/Types/Public/Part.swift @@ -43,41 +43,34 @@ public struct InlineDataPart: Part { } @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -public struct FileData: Codable, Equatable, Sendable { - enum CodingKeys: String, CodingKey { - case mimeType = "mime_type" - case uri = "file_uri" - } +public struct FileDataPart: Part { + let fileData: FileData - public let mimeType: String - public let uri: String + public var uri: String { fileData.fileURI } + public var mimeType: String { fileData.mimeType } - public init(mimeType: String, uri: String) { - self.mimeType = mimeType - self.uri = uri + public init(uri: String, mimeType: String) { + self.init(FileData(fileURI: uri, mimeType: mimeType)) } -} - -@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -public struct FileDataPart: Part { - public let fileData: FileData - public init(fileData: FileData) { + init(_ fileData: FileData) { self.fileData = fileData } } @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) public struct FunctionCallPart: Part { + // TODO: Consider making FunctionCall internal and exposing params on FunctionCallPart instead. public let functionCall: FunctionCall - public init(functionCall: FunctionCall) { + public init(_ functionCall: FunctionCall) { self.functionCall = functionCall } } @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) public struct FunctionResponsePart: Part { + // TODO: Consider making FunctionResponsePart internal and exposing params here instead. public let functionResponse: FunctionResponse public init(functionResponse: FunctionResponse) { diff --git a/FirebaseVertexAI/Tests/Unit/PartTests.swift b/FirebaseVertexAI/Tests/Unit/PartTests.swift index 6cc75dcfb8cb..97b6164fde9b 100644 --- a/FirebaseVertexAI/Tests/Unit/PartTests.swift +++ b/FirebaseVertexAI/Tests/Unit/PartTests.swift @@ -30,15 +30,28 @@ final class PartTests: XCTestCase { // MARK: - Part Decoding + func testDecodeTextPart() throws { + let expectedText = "Hello, world!" + let json = """ + { + "text" : "\(expectedText)" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let part = try decoder.decode(TextPart.self, from: jsonData) + + XCTAssertEqual(part.text, expectedText) + } + func testDecodeInlineDataPart() throws { - let imageBase64 = try blueSquareImage() + let imageBase64 = try PartTests.blueSquareImage() let mimeType = "image/png" - let json = """ { - "inlineData": { - "data": "\(imageBase64)", - "mimeType": "\(mimeType)" + "inlineData" : { + "data" : "\(imageBase64)", + "mimeType" : "\(mimeType)" } } """ @@ -75,9 +88,23 @@ final class PartTests: XCTestCase { // MARK: - Part Encoding + func testEncodeTextPart() throws { + let expectedText = "Hello, world!" + let textPart = TextPart(expectedText) + + let jsonData = try encoder.encode(textPart) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "text" : "\(expectedText)" + } + """) + } + func testEncodeInlineDataPart() throws { let mimeType = "image/png" - let imageBase64 = try blueSquareImage() + let imageBase64 = try PartTests.blueSquareImage() let imageBase64Data = Data(base64Encoded: imageBase64) let inlineDataPart = InlineDataPart(data: imageBase64Data!, mimeType: mimeType) @@ -97,7 +124,7 @@ final class PartTests: XCTestCase { func testEncodeFileDataPart() throws { let mimeType = "image/jpeg" let fileURI = "gs://test-bucket/image.jpg" - let fileDataPart = FileDataPart(fileData: FileData(mimeType: mimeType, uri: fileURI)) + let fileDataPart = FileDataPart(uri: fileURI, mimeType: mimeType) let jsonData = try encoder.encode(fileDataPart) @@ -105,18 +132,26 @@ final class PartTests: XCTestCase { XCTAssertEqual(json, """ { "fileData" : { - "file_uri" : "\(fileURI)", - "mime_type" : "\(mimeType)" + "fileURI" : "\(fileURI)", + "mimeType" : "\(mimeType)" } } """) } -} -// MARK: - Helpers + // MARK: - Helpers -func blueSquareImage() throws -> String { - let imageURL = Bundle.module.url(forResource: "blue", withExtension: "png")! - let imageData = try Data(contentsOf: imageURL) - return imageData.base64EncodedString() + private static func bundle() -> Bundle { + #if SWIFT_PACKAGE + return Bundle.module + #else // SWIFT_PACKAGE + return Bundle(for: Self.self) + #endif // SWIFT_PACKAGE + } + + private static func blueSquareImage() throws -> String { + let imageURL = Bundle.module.url(forResource: "blue", withExtension: "png")! + let imageData = try Data(contentsOf: imageURL) + return imageData.base64EncodedString() + } } diff --git a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift index 9e0976ab6048..d978df1b6809 100644 --- a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift +++ b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift @@ -98,7 +98,7 @@ final class VertexAIAPITests: XCTestCase { let _ = try await genAI.generateContent(str, "abc", "def") let _ = try await genAI.generateContent( str, - FileDataPart(fileData: FileData(mimeType: "image/jpeg", uri: "gs://test-bucket/image.jpg")) + FileDataPart(uri: "gs://test-bucket/image.jpg", mimeType: "image/jpeg") ) #if canImport(UIKit) _ = try await genAI.generateContent(UIImage())