Skip to content

Commit

Permalink
[Vertex AI] Replace ModelContent.Part enum with protocol/structs (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Oct 7, 2024
1 parent a3e7a20 commit f27e34d
Show file tree
Hide file tree
Showing 22 changed files with 697 additions and 495 deletions.
1 change: 1 addition & 0 deletions FirebaseVertexAI.podspec
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Firebase SDK.
]
unit_tests.resources = [
unit_tests_dir + 'vertexai-sdk-test-data/mock-responses/**/*.{txt,json}',
unit_tests_dir + 'Resources/**/*',
]
end
end
6 changes: 6 additions & 0 deletions FirebaseVertexAI/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@
- [changed] **Breaking Change**: The `CountTokensError` enum has been removed;
errors occurring in `GenerativeModel.countTokens(...)` are now thrown directly
instead of being wrapped in a `CountTokensError.internalError`. (#13736)
- [changed] **Breaking Change**: The enum `ModelContent.Part` has been replaced
with a protocol named `Part` to avoid future breaking changes with new part
types. The new types `TextPart` and `FunctionCallPart` may be received when
generating content; additionally the types
`InlineDataPart`, `FileDataPart` and `FunctionResponsePart` may be provided
as input. (#13767)
- [changed] The default request timeout is now 180 seconds instead of the
platform-default value of 60 seconds for a `URLRequest`; this timeout may
still be customized in `RequestOptions`. (#13722)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class FunctionCallingViewModel: ObservableObject {
}

/// Function calls pending processing
private var functionCalls = [FunctionCall]()
private var functionCalls = [FunctionCallPart]()

private var model: GenerativeModel
private var chat: Chat
Expand Down Expand Up @@ -144,26 +144,26 @@ class FunctionCallingViewModel: ObservableObject {

for part in candidate.content.parts {
switch part {
case let .text(text):
case let textPart as TextPart:
// replace pending message with backend response
messages[messages.count - 1].message += text
messages[messages.count - 1].message += textPart.text
messages[messages.count - 1].pending = false
case let .functionCall(functionCall):
messages.insert(functionCall.chatMessage(), at: messages.count - 1)
functionCalls.append(functionCall)
case .inlineData, .fileData, .functionResponse:
fatalError("Unsupported response content.")
case let functionCallPart as FunctionCallPart:
messages.insert(functionCallPart.chatMessage(), at: messages.count - 1)
functionCalls.append(functionCallPart)
default:
fatalError("Unsupported response part: \(part)")
}
}
}

func processFunctionCalls() async throws -> [FunctionResponse] {
var functionResponses = [FunctionResponse]()
func processFunctionCalls() async throws -> [FunctionResponsePart] {
var functionResponses = [FunctionResponsePart]()
for functionCall in functionCalls {
switch functionCall.name {
case "get_exchange_rate":
let exchangeRates = getExchangeRate(args: functionCall.args)
functionResponses.append(FunctionResponse(
functionResponses.append(FunctionResponsePart(
name: "get_exchange_rate",
response: exchangeRates
))
Expand Down Expand Up @@ -208,7 +208,7 @@ class FunctionCallingViewModel: ObservableObject {
}
}

private extension FunctionCall {
private extension FunctionCallPart {
func chatMessage() -> ChatMessage {
let encoder = JSONEncoder()
encoder.outputFormatting = .prettyPrinted
Expand All @@ -228,7 +228,7 @@ private extension FunctionCall {
}
}

private extension FunctionResponse {
private extension FunctionResponsePart {
func chatMessage() -> ChatMessage {
let encoder = JSONEncoder()
encoder.outputFormatting = .prettyPrinted
Expand All @@ -248,12 +248,8 @@ private extension FunctionResponse {
}
}

private extension [FunctionResponse] {
private extension [FunctionResponsePart] {
func modelContent() -> [ModelContent] {
return self.map { ModelContent(
role: "function",
parts: [ModelContent.Part.functionResponse($0)]
)
}
return self.map { ModelContent(role: "function", parts: [$0]) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class PhotoReasoningViewModel: ObservableObject {

let prompt = "Look at the image(s), and then answer the following question: \(userInput)"

var images = [any ThrowingPartsRepresentable]()
var images = [any PartsRepresentable]()
for item in selectedItems {
if let data = try? await item.loadTransferable(type: Data.self) {
guard let image = UIImage(data: data) else {
Expand Down
43 changes: 12 additions & 31 deletions FirebaseVertexAI/Sources/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class Chat {
/// - Parameter parts: The new content to send as a single chat message.
/// - Returns: The model's response if no error occurred.
/// - Throws: A ``GenerateContentError`` if an error occurred.
public func sendMessage(_ parts: any ThrowingPartsRepresentable...) async throws
public func sendMessage(_ parts: any PartsRepresentable...) async throws
-> GenerateContentResponse {
return try await sendMessage([ModelContent(parts: parts)])
}
Expand All @@ -45,19 +45,10 @@ public class Chat {
/// - Parameter content: The new content to send as a single chat message.
/// - Returns: The model's response if no error occurred.
/// - Throws: A ``GenerateContentError`` if an error occurred.
public func sendMessage(_ content: @autoclosure () throws -> [ModelContent]) async throws
public func sendMessage(_ content: [ModelContent]) async throws
-> GenerateContentResponse {
// Ensure that the new content has the role set.
let newContent: [ModelContent]
do {
newContent = try content().map(populateContentRole(_:))
} catch let underlying {
if let contentError = underlying as? ImageConversionError {
throw GenerateContentError.promptImageContentError(underlying: contentError)
} else {
throw GenerateContentError.internalError(underlying: underlying)
}
}
let newContent = content.map(populateContentRole(_:))

// Send the history alongside the new message as context.
let request = history + newContent
Expand Down Expand Up @@ -85,7 +76,7 @@ public class Chat {
/// - Parameter parts: The new content to send as a single chat message.
/// - Returns: A stream containing the model's response or an error if an error occurred.
@available(macOS 12.0, *)
public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...) throws
public func sendMessageStream(_ parts: any PartsRepresentable...) throws
-> AsyncThrowingStream<GenerateContentResponse, Error> {
return try sendMessageStream([ModelContent(parts: parts)])
}
Expand All @@ -95,24 +86,14 @@ public class Chat {
/// - Parameter content: The new content to send as a single chat message.
/// - Returns: A stream containing the model's response or an error if an error occurred.
@available(macOS 12.0, *)
public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent]) throws
public func sendMessageStream(_ content: [ModelContent]) throws
-> AsyncThrowingStream<GenerateContentResponse, Error> {
let resolvedContent: [ModelContent]
do {
resolvedContent = try content()
} catch let underlying {
if let contentError = underlying as? ImageConversionError {
throw GenerateContentError.promptImageContentError(underlying: contentError)
}
throw GenerateContentError.internalError(underlying: underlying)
}

return AsyncThrowingStream { continuation in
Task {
var aggregatedContent: [ModelContent] = []

// Ensure that the new content has the role set.
let newContent: [ModelContent] = resolvedContent.map(populateContentRole(_:))
let newContent: [ModelContent] = content.map(populateContentRole(_:))

// Send the history alongside the new message as context.
let request = history + newContent
Expand Down Expand Up @@ -146,20 +127,20 @@ public class Chat {
}

private func aggregatedChunks(_ chunks: [ModelContent]) -> ModelContent {
var parts: [ModelContent.Part] = []
var parts: [any Part] = []
var combinedText = ""
for aggregate in chunks {
// Loop through all the parts, aggregating the text and adding the images.
for part in aggregate.parts {
switch part {
case let .text(str):
combinedText += str
case let textPart as TextPart:
combinedText += textPart.text

case .inlineData, .fileData, .functionCall, .functionResponse:
default:
// Don't combine it, just add to the content. If there's any text pending, add that as
// a part.
if !combinedText.isEmpty {
parts.append(.text(combinedText))
parts.append(TextPart(combinedText))
combinedText = ""
}

Expand All @@ -169,7 +150,7 @@ public class Chat {
}

if !combinedText.isEmpty {
parts.append(.text(combinedText))
parts.append(TextPart(combinedText))
}

return ModelContent(role: "model", parts: parts)
Expand Down
65 changes: 0 additions & 65 deletions FirebaseVertexAI/Sources/FunctionCalling.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,6 @@

import Foundation

/// A predicted function call returned from the model.
public struct FunctionCall: Equatable, Sendable {
/// The name of the function to call.
public let name: String

/// The function parameters and values.
public let args: JSONObject

/// Constructs a new function call.
///
/// > Note: A `FunctionCall` is typically received from the model, rather than created manually.
///
/// - Parameters:
/// - name: The name of the function to call.
/// - args: The function parameters and values.
public init(name: String, args: JSONObject) {
self.name = name
self.args = args
}
}

/// Structured representation of a function declaration.
///
/// This `FunctionDeclaration` is a representation of a block of code that can be used as a ``Tool``
Expand Down Expand Up @@ -136,50 +115,8 @@ public struct ToolConfig {
}
}

/// Result output from a ``FunctionCall``.
///
/// Contains a string representing the `FunctionDeclaration.name` and a structured JSON object
/// containing any output from the function, which is used as context to the model. This should
/// contain the result of a ``FunctionCall`` made based on model prediction.
public struct FunctionResponse: Equatable, Sendable {
/// The name of the function that was called.
let name: String

/// The function's response.
let response: JSONObject

/// Constructs a new `FunctionResponse`.
///
/// - Parameters:
/// - name: The name of the function that was called.
/// - response: The function's response.
public init(name: String, response: JSONObject) {
self.name = name
self.response = response
}
}

// MARK: - Codable Conformance

extension FunctionCall: Decodable {
enum CodingKeys: CodingKey {
case name
case args
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
name = try container.decode(String.self, forKey: .name)
if let args = try container.decodeIfPresent(JSONObject.self, forKey: .args) {
self.args = args
} else {
args = JSONObject()
}
}
}

extension FunctionCall: Encodable {}

extension FunctionDeclaration: Encodable {
enum CodingKeys: String, CodingKey {
case name
Expand All @@ -202,5 +139,3 @@ extension FunctionCallingConfig: Encodable {}
extension FunctionCallingConfig.Mode: Encodable {}

extension ToolConfig: Encodable {}

extension FunctionResponse: Codable {}
14 changes: 9 additions & 5 deletions FirebaseVertexAI/Sources/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ public struct GenerateContentResponse: Sendable {
return nil
}
let textValues: [String] = candidate.content.parts.compactMap { part in
guard case let .text(text) = part else {
switch part {
case let textPart as TextPart:
return textPart.text
default:
return nil
}
return text
}
guard textValues.count > 0 else {
VertexLog.error(
Expand All @@ -65,15 +67,17 @@ public struct GenerateContentResponse: Sendable {
}

/// Returns function calls found in any `Part`s of the first candidate of the response, if any.
public var functionCalls: [FunctionCall] {
public var functionCalls: [FunctionCallPart] {
guard let candidate = candidates.first else {
return []
}
return candidate.content.parts.compactMap { part in
guard case let .functionCall(functionCall) = part else {
switch part {
case let functionCallPart as FunctionCallPart:
return functionCallPart
default:
return nil
}
return functionCall
}
}

Expand Down
Loading

0 comments on commit f27e34d

Please sign in to comment.