Skip to content

Commit

Permalink
Refactor FileDataPart
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Oct 4, 2024
1 parent d5fe328 commit 97d9d46
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 64 deletions.
44 changes: 13 additions & 31 deletions FirebaseVertexAI/Sources/ModelContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ public struct ModelContent: Equatable, Sendable {
case let .inlineData(mimetype, data):
convertedParts.append(InlineDataPart(data: data, mimeType: mimetype))
case let .fileData(mimetype, uri):
convertedParts.append(FileDataPart(fileData: FileData(mimeType: mimetype, uri: uri)))
convertedParts.append(FileDataPart(uri: uri, mimeType: mimetype))
case let .functionCall(functionCall):
convertedParts.append(FunctionCallPart(functionCall: functionCall))
convertedParts.append(FunctionCallPart(functionCall))
case let .functionResponse(functionResponse):
convertedParts.append(FunctionResponsePart(functionResponse: functionResponse))
}
Expand All @@ -120,7 +120,7 @@ public struct ModelContent: Equatable, Sendable {
convertedParts.append(.inlineData(mimetype: inlineData.mimeType, inlineData.data))
case let fileDataPart as FileDataPart:
let fileData = fileDataPart.fileData
convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.uri))
convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.fileURI))
case let functionCallPart as FunctionCallPart:
convertedParts.append(.functionCall(functionCallPart.functionCall))
case let functionResponsePart as FunctionResponsePart:
Expand All @@ -145,7 +145,7 @@ public struct ModelContent: Equatable, Sendable {
convertedParts.append(.inlineData(mimetype: inlineData.mimeType, inlineData.data))
case let fileDataPart as FileDataPart:
let fileData = fileDataPart.fileData
convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.uri))
convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.fileURI))
case let functionCallPart as FunctionCallPart:
convertedParts.append(.functionCall(functionCallPart.functionCall))
case let functionResponsePart as FunctionResponsePart:
Expand Down Expand Up @@ -192,31 +192,15 @@ extension ModelContent.InternalPart: Codable {
case functionResponse
}

enum InlineDataKeys: String, CodingKey {
case mimeType = "mime_type"
case bytes = "data"
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
switch self {
case let .text(a0):
try container.encode(a0, forKey: .text)
case let .text(text):
try container.encode(text, forKey: .text)
case let .inlineData(mimetype, bytes):
var inlineDataContainer = container.nestedContainer(
keyedBy: InlineDataKeys.self,
forKey: .inlineData
)
try inlineDataContainer.encode(mimetype, forKey: .mimeType)
try inlineDataContainer.encode(bytes, forKey: .bytes)
try container.encode(InlineData(data: bytes, mimeType: mimetype), forKey: .inlineData)
case let .fileData(mimetype: mimetype, url):
// var fileDataContainer = container.nestedContainer(
// keyedBy: FileDataKeys.self,
// forKey: .fileData
// )
try container.encode(FileData(mimeType: mimetype, uri: url), forKey: .fileData)
// try fileDataContainer.encode(mimetype, forKey: .mimeType)
// try fileDataContainer.encode(url, forKey: .uri)
try container.encode(FileData(fileURI: url, mimeType: mimetype), forKey: .fileData)
case let .functionCall(functionCall):
try container.encode(functionCall, forKey: .functionCall)
case let .functionResponse(functionResponse):
Expand All @@ -229,13 +213,11 @@ extension ModelContent.InternalPart: Codable {
if values.contains(.text) {
self = try .text(values.decode(String.self, forKey: .text))
} else if values.contains(.inlineData) {
let dataContainer = try values.nestedContainer(
keyedBy: InlineDataKeys.self,
forKey: .inlineData
)
let mimetype = try dataContainer.decode(String.self, forKey: .mimeType)
let bytes = try dataContainer.decode(Data.self, forKey: .bytes)
self = .inlineData(mimetype: mimetype, bytes)
let inlineData = try values.decode(InlineData.self, forKey: .inlineData)
self = .inlineData(mimetype: inlineData.mimeType, inlineData.data)
} else if values.contains(.fileData) {
let fileData = try values.decode(FileData.self, forKey: .fileData)
self = .fileData(mimetype: fileData.mimeType, uri: fileData.fileURI)
} else if values.contains(.functionCall) {
self = try .functionCall(values.decode(FunctionCall.self, forKey: .functionCall))
} else if values.contains(.functionResponse) {
Expand Down
11 changes: 11 additions & 0 deletions FirebaseVertexAI/Sources/Types/Internal/InternalPart.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ struct InlineData: Codable, Equatable, Sendable {
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct FileData: Codable, Equatable, Sendable {
let fileURI: String
let mimeType: String

init(fileURI: String, mimeType: String) {
self.fileURI = fileURI
self.mimeType = mimeType
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct ErrorPart: Part, Error {
let error: Error
Expand Down
27 changes: 10 additions & 17 deletions FirebaseVertexAI/Sources/Types/Public/Part.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,41 +43,34 @@ public struct InlineDataPart: Part {
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct FileData: Codable, Equatable, Sendable {
enum CodingKeys: String, CodingKey {
case mimeType = "mime_type"
case uri = "file_uri"
}
public struct FileDataPart: Part {
let fileData: FileData

public let mimeType: String
public let uri: String
public var uri: String { fileData.fileURI }
public var mimeType: String { fileData.mimeType }

public init(mimeType: String, uri: String) {
self.mimeType = mimeType
self.uri = uri
public init(uri: String, mimeType: String) {
self.init(FileData(fileURI: uri, mimeType: mimeType))
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct FileDataPart: Part {
public let fileData: FileData

public init(fileData: FileData) {
init(_ fileData: FileData) {
self.fileData = fileData
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct FunctionCallPart: Part {
// TODO: Consider making FunctionCall internal and exposing params on FunctionCallPart instead.
public let functionCall: FunctionCall

public init(functionCall: FunctionCall) {
public init(_ functionCall: FunctionCall) {
self.functionCall = functionCall
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct FunctionResponsePart: Part {
// TODO: Consider making FunctionResponsePart internal and exposing params here instead.
public let functionResponse: FunctionResponse

public init(functionResponse: FunctionResponse) {
Expand Down
65 changes: 50 additions & 15 deletions FirebaseVertexAI/Tests/Unit/PartTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,28 @@ final class PartTests: XCTestCase {

// MARK: - Part Decoding

func testDecodeTextPart() throws {
let expectedText = "Hello, world!"
let json = """
{
"text" : "\(expectedText)"
}
"""
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let part = try decoder.decode(TextPart.self, from: jsonData)

XCTAssertEqual(part.text, expectedText)
}

func testDecodeInlineDataPart() throws {
let imageBase64 = try blueSquareImage()
let imageBase64 = try PartTests.blueSquareImage()
let mimeType = "image/png"

let json = """
{
"inlineData": {
"data": "\(imageBase64)",
"mimeType": "\(mimeType)"
"inlineData" : {
"data" : "\(imageBase64)",
"mimeType" : "\(mimeType)"
}
}
"""
Expand Down Expand Up @@ -75,9 +88,23 @@ final class PartTests: XCTestCase {

// MARK: - Part Encoding

func testEncodeTextPart() throws {
let expectedText = "Hello, world!"
let textPart = TextPart(expectedText)

let jsonData = try encoder.encode(textPart)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"text" : "\(expectedText)"
}
""")
}

func testEncodeInlineDataPart() throws {
let mimeType = "image/png"
let imageBase64 = try blueSquareImage()
let imageBase64 = try PartTests.blueSquareImage()
let imageBase64Data = Data(base64Encoded: imageBase64)
let inlineDataPart = InlineDataPart(data: imageBase64Data!, mimeType: mimeType)

Expand All @@ -97,26 +124,34 @@ final class PartTests: XCTestCase {
func testEncodeFileDataPart() throws {
let mimeType = "image/jpeg"
let fileURI = "gs://test-bucket/image.jpg"
let fileDataPart = FileDataPart(fileData: FileData(mimeType: mimeType, uri: fileURI))
let fileDataPart = FileDataPart(uri: fileURI, mimeType: mimeType)

let jsonData = try encoder.encode(fileDataPart)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"fileData" : {
"file_uri" : "\(fileURI)",
"mime_type" : "\(mimeType)"
"fileURI" : "\(fileURI)",
"mimeType" : "\(mimeType)"
}
}
""")
}
}

// MARK: - Helpers
// MARK: - Helpers

func blueSquareImage() throws -> String {
let imageURL = Bundle.module.url(forResource: "blue", withExtension: "png")!
let imageData = try Data(contentsOf: imageURL)
return imageData.base64EncodedString()
private static func bundle() -> Bundle {
#if SWIFT_PACKAGE
return Bundle.module
#else // SWIFT_PACKAGE
return Bundle(for: Self.self)
#endif // SWIFT_PACKAGE
}

private static func blueSquareImage() throws -> String {
let imageURL = Bundle.module.url(forResource: "blue", withExtension: "png")!
let imageData = try Data(contentsOf: imageURL)
return imageData.base64EncodedString()
}
}
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ final class VertexAIAPITests: XCTestCase {
let _ = try await genAI.generateContent(str, "abc", "def")
let _ = try await genAI.generateContent(
str,
FileDataPart(fileData: FileData(mimeType: "image/jpeg", uri: "gs://test-bucket/image.jpg"))
FileDataPart(uri: "gs://test-bucket/image.jpg", mimeType: "image/jpeg")
)
#if canImport(UIKit)
_ = try await genAI.generateContent(UIImage())
Expand Down

0 comments on commit 97d9d46

Please sign in to comment.