Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vertex AI] Replace ModelContent.Part enum with protocol/structs #13767

Merged
merged 16 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions FirebaseVertexAI.podspec
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Firebase SDK.
]
unit_tests.resources = [
unit_tests_dir + 'vertexai-sdk-test-data/mock-responses/**/*.{txt,json}',
unit_tests_dir + 'Resources/**/*',
]
end
end
6 changes: 6 additions & 0 deletions FirebaseVertexAI/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@
- [changed] **Breaking Change**: The `CountTokensError` enum has been removed;
errors occurring in `GenerativeModel.countTokens(...)` are now thrown directly
instead of being wrapped in a `CountTokensError.internalError`. (#13736)
- [changed] **Breaking Change**: The enum `ModelContent.Part` has been replaced
with a protocol named `Part` to avoid future breaking changes with new part
types. The new types `TextPart` and `FunctionCallPart` may be received when
generating content the types `TextPart`; additionally the types
`InlineDataPart`, `FileDataPart` and `FunctionResponsePart` may be provided
as input. (#13767)
- [changed] The default request timeout is now 180 seconds instead of the
platform-default value of 60 seconds for a `URLRequest`; this timeout may
still be customized in `RequestOptions`. (#13722)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class FunctionCallingViewModel: ObservableObject {
}

/// Function calls pending processing
private var functionCalls = [FunctionCall]()
private var functionCalls = [FunctionCallPart]()

private var model: GenerativeModel
private var chat: Chat
Expand Down Expand Up @@ -144,26 +144,26 @@ class FunctionCallingViewModel: ObservableObject {

for part in candidate.content.parts {
switch part {
case let .text(text):
case let textPart as TextPart:
// replace pending message with backend response
messages[messages.count - 1].message += text
messages[messages.count - 1].message += textPart.text
messages[messages.count - 1].pending = false
case let .functionCall(functionCall):
messages.insert(functionCall.chatMessage(), at: messages.count - 1)
functionCalls.append(functionCall)
case .inlineData, .fileData, .functionResponse:
fatalError("Unsupported response content.")
case let functionCallPart as FunctionCallPart:
messages.insert(functionCallPart.chatMessage(), at: messages.count - 1)
functionCalls.append(functionCallPart)
default:
fatalError("Unsupported response part: \(part)")
}
}
}

func processFunctionCalls() async throws -> [FunctionResponse] {
var functionResponses = [FunctionResponse]()
func processFunctionCalls() async throws -> [FunctionResponsePart] {
var functionResponses = [FunctionResponsePart]()
for functionCall in functionCalls {
switch functionCall.name {
case "get_exchange_rate":
let exchangeRates = getExchangeRate(args: functionCall.args)
functionResponses.append(FunctionResponse(
functionResponses.append(FunctionResponsePart(
name: "get_exchange_rate",
response: exchangeRates
))
Expand Down Expand Up @@ -208,7 +208,7 @@ class FunctionCallingViewModel: ObservableObject {
}
}

private extension FunctionCall {
private extension FunctionCallPart {
func chatMessage() -> ChatMessage {
let encoder = JSONEncoder()
encoder.outputFormatting = .prettyPrinted
Expand All @@ -228,7 +228,7 @@ private extension FunctionCall {
}
}

private extension FunctionResponse {
private extension FunctionResponsePart {
func chatMessage() -> ChatMessage {
let encoder = JSONEncoder()
encoder.outputFormatting = .prettyPrinted
Expand All @@ -248,12 +248,8 @@ private extension FunctionResponse {
}
}

private extension [FunctionResponse] {
private extension [FunctionResponsePart] {
func modelContent() -> [ModelContent] {
return self.map { ModelContent(
role: "function",
parts: [ModelContent.Part.functionResponse($0)]
)
}
return self.map { ModelContent(role: "function", parts: [$0]) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class PhotoReasoningViewModel: ObservableObject {

let prompt = "Look at the image(s), and then answer the following question: \(userInput)"

var images = [any ThrowingPartsRepresentable]()
var images = [any PartsRepresentable]()
for item in selectedItems {
if let data = try? await item.loadTransferable(type: Data.self) {
guard let image = UIImage(data: data) else {
Expand Down
43 changes: 12 additions & 31 deletions FirebaseVertexAI/Sources/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class Chat {
/// - Parameter parts: The new content to send as a single chat message.
/// - Returns: The model's response if no error occurred.
/// - Throws: A ``GenerateContentError`` if an error occurred.
public func sendMessage(_ parts: any ThrowingPartsRepresentable...) async throws
public func sendMessage(_ parts: any PartsRepresentable...) async throws
-> GenerateContentResponse {
return try await sendMessage([ModelContent(parts: parts)])
}
Expand All @@ -45,19 +45,10 @@ public class Chat {
/// - Parameter content: The new content to send as a single chat message.
/// - Returns: The model's response if no error occurred.
/// - Throws: A ``GenerateContentError`` if an error occurred.
public func sendMessage(_ content: @autoclosure () throws -> [ModelContent]) async throws
public func sendMessage(_ content: [ModelContent]) async throws
-> GenerateContentResponse {
// Ensure that the new content has the role set.
let newContent: [ModelContent]
do {
newContent = try content().map(populateContentRole(_:))
} catch let underlying {
if let contentError = underlying as? ImageConversionError {
throw GenerateContentError.promptImageContentError(underlying: contentError)
} else {
throw GenerateContentError.internalError(underlying: underlying)
}
}
let newContent = content.map(populateContentRole(_:))

// Send the history alongside the new message as context.
let request = history + newContent
Expand Down Expand Up @@ -85,7 +76,7 @@ public class Chat {
/// - Parameter parts: The new content to send as a single chat message.
/// - Returns: A stream containing the model's response or an error if an error occurred.
@available(macOS 12.0, *)
public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...) throws
public func sendMessageStream(_ parts: any PartsRepresentable...) throws
-> AsyncThrowingStream<GenerateContentResponse, Error> {
return try sendMessageStream([ModelContent(parts: parts)])
}
Expand All @@ -95,24 +86,14 @@ public class Chat {
/// - Parameter content: The new content to send as a single chat message.
/// - Returns: A stream containing the model's response or an error if an error occurred.
@available(macOS 12.0, *)
public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent]) throws
public func sendMessageStream(_ content: [ModelContent]) throws
-> AsyncThrowingStream<GenerateContentResponse, Error> {
let resolvedContent: [ModelContent]
do {
resolvedContent = try content()
} catch let underlying {
if let contentError = underlying as? ImageConversionError {
throw GenerateContentError.promptImageContentError(underlying: contentError)
}
throw GenerateContentError.internalError(underlying: underlying)
}

return AsyncThrowingStream { continuation in
Task {
var aggregatedContent: [ModelContent] = []

// Ensure that the new content has the role set.
let newContent: [ModelContent] = resolvedContent.map(populateContentRole(_:))
let newContent: [ModelContent] = content.map(populateContentRole(_:))

// Send the history alongside the new message as context.
let request = history + newContent
Expand Down Expand Up @@ -146,20 +127,20 @@ public class Chat {
}

private func aggregatedChunks(_ chunks: [ModelContent]) -> ModelContent {
var parts: [ModelContent.Part] = []
var parts: [any Part] = []
var combinedText = ""
for aggregate in chunks {
// Loop through all the parts, aggregating the text and adding the images.
for part in aggregate.parts {
switch part {
case let .text(str):
combinedText += str
case let textPart as TextPart:
combinedText += textPart.text

case .inlineData, .fileData, .functionCall, .functionResponse:
default:
// Don't combine it, just add to the content. If there's any text pending, add that as
// a part.
if !combinedText.isEmpty {
parts.append(.text(combinedText))
parts.append(TextPart(combinedText))
combinedText = ""
}

Expand All @@ -169,7 +150,7 @@ public class Chat {
}

if !combinedText.isEmpty {
parts.append(.text(combinedText))
parts.append(TextPart(combinedText))
}

return ModelContent(role: "model", parts: parts)
Expand Down
65 changes: 0 additions & 65 deletions FirebaseVertexAI/Sources/FunctionCalling.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,6 @@

import Foundation

/// A predicted function call returned from the model.
public struct FunctionCall: Equatable, Sendable {
/// The name of the function to call.
public let name: String

/// The function parameters and values.
public let args: JSONObject

/// Constructs a new function call.
///
/// > Note: A `FunctionCall` is typically received from the model, rather than created manually.
///
/// - Parameters:
/// - name: The name of the function to call.
/// - args: The function parameters and values.
public init(name: String, args: JSONObject) {
self.name = name
self.args = args
}
}

/// Structured representation of a function declaration.
///
/// This `FunctionDeclaration` is a representation of a block of code that can be used as a ``Tool``
Expand Down Expand Up @@ -136,50 +115,8 @@ public struct ToolConfig {
}
}

/// Result output from a ``FunctionCall``.
///
/// Contains a string representing the `FunctionDeclaration.name` and a structured JSON object
/// containing any output from the function is used as context to the model. This should contain the
/// result of a ``FunctionCall`` made based on model prediction.
public struct FunctionResponse: Equatable, Sendable {
/// The name of the function that was called.
let name: String

/// The function's response.
let response: JSONObject

/// Constructs a new `FunctionResponse`.
///
/// - Parameters:
/// - name: The name of the function that was called.
/// - response: The function's response.
public init(name: String, response: JSONObject) {
self.name = name
self.response = response
}
}

// MARK: - Codable Conformance

extension FunctionCall: Decodable {
enum CodingKeys: CodingKey {
case name
case args
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
name = try container.decode(String.self, forKey: .name)
if let args = try container.decodeIfPresent(JSONObject.self, forKey: .args) {
self.args = args
} else {
args = JSONObject()
}
}
}

extension FunctionCall: Encodable {}

extension FunctionDeclaration: Encodable {
enum CodingKeys: String, CodingKey {
case name
Expand All @@ -202,5 +139,3 @@ extension FunctionCallingConfig: Encodable {}
extension FunctionCallingConfig.Mode: Encodable {}

extension ToolConfig: Encodable {}

extension FunctionResponse: Codable {}
14 changes: 9 additions & 5 deletions FirebaseVertexAI/Sources/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ public struct GenerateContentResponse: Sendable {
return nil
}
let textValues: [String] = candidate.content.parts.compactMap { part in
guard case let .text(text) = part else {
switch part {
case let textPart as TextPart:
return textPart.text
default:
return nil
}
return text
}
guard textValues.count > 0 else {
VertexLog.error(
Expand All @@ -65,15 +67,17 @@ public struct GenerateContentResponse: Sendable {
}

/// Returns function calls found in any `Part`s of the first candidate of the response, if any.
public var functionCalls: [FunctionCall] {
public var functionCalls: [FunctionCallPart] {
guard let candidate = candidates.first else {
return []
}
return candidate.content.parts.compactMap { part in
guard case let .functionCall(functionCall) = part else {
switch part {
case let functionCallPart as FunctionCallPart:
return functionCallPart
default:
return nil
}
return functionCall
}
}

Expand Down
Loading
Loading