Skip to content

Commit

Permalink
Merge pull request #266 from MacPaw/bug/261-streaming-does-not-honor-…
Browse files Browse the repository at this point in the history
…http-status-codes

Bug: Fix error parsing in StreamInterpreter
  • Loading branch information
nezhyborets authored Feb 17, 2025
2 parents 762d8ea + 3e64d91 commit 190c355
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 52 deletions.
134 changes: 82 additions & 52 deletions Demo/DemoChat/Sources/ChatStore.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Combine
import OpenAI
import SwiftUI

@MainActor
public final class ChatStore: ObservableObject {
public var openAIClient: OpenAIProtocol
let idProvider: () -> String
Expand Down Expand Up @@ -61,8 +62,7 @@ public final class ChatStore: ObservableObject {
func deleteConversation(_ conversationId: Conversation.ID) {
conversations.removeAll(where: { $0.id == conversationId })
}

@MainActor

func sendMessage(
_ message: Message,
conversationId: Conversation.ID,
Expand All @@ -78,7 +78,8 @@ public final class ChatStore: ObservableObject {

await completeChat(
conversationId: conversationId,
model: model
model: model,
stream: true
)
// For assistant case we send chats to thread and then poll, polling will receive sent chat + new assistant messages.
case .assistant:
Expand Down Expand Up @@ -139,11 +140,11 @@ public final class ChatStore: ObservableObject {
}
}
}

@MainActor

func completeChat(
conversationId: Conversation.ID,
model: Model
model: Model,
stream: Bool
) async {
guard let conversation = conversations.first(where: { $0.id == conversationId }) else {
return
Expand All @@ -169,59 +170,88 @@ public final class ChatStore: ObservableObject {
))

let functions = [weatherFunction]

let chatsStream: AsyncThrowingStream<ChatStreamResult, Error> = openAIClient.chatsStream(
query: ChatQuery(
messages: conversation.messages.map { message in
ChatQuery.ChatCompletionMessageParam(role: message.role, content: message.content)!
}, model: model,
tools: functions
)

let chatQuery = ChatQuery(
messages: conversation.messages.map { message in
ChatQuery.ChatCompletionMessageParam(role: message.role, content: message.content)!
}, model: model,
tools: functions
)

var functionCalls = [(name: String, argument: String?)]()
for try await partialChatResult in chatsStream {
for choice in partialChatResult.choices {
let existingMessages = conversations[conversationIndex].messages
// Function calls are also streamed, so we need to accumulate.
choice.delta.toolCalls?.forEach { toolCallDelta in
if let functionCallDelta = toolCallDelta.function {
if let nameDelta = functionCallDelta.name {
functionCalls.append((nameDelta, functionCallDelta.arguments))
}

if stream {
try await completeConversationStreaming(
conversationIndex: conversationIndex,
model: model,
query: chatQuery
)
} else {
try await completeConversation(conversationIndex: conversationIndex, model: model, query: chatQuery)
}
} catch {
conversationErrors[conversationId] = error
}
}

// Performs a single non-streaming chat completion request and appends every
// returned choice to the conversation at `conversationIndex` as a new message.
// - Throws: rethrows any networking/decoding error from `openAIClient.chats`.
private func completeConversation(conversationIndex: Int, model: Model, query: ChatQuery) async throws {
    let chatResult: ChatResult = try await openAIClient.chats(query: query)
    // All choices of one result share the same id and creation timestamp.
    let createdAt = Date(timeIntervalSince1970: TimeInterval(chatResult.created))
    let receivedMessages = chatResult.choices.map { choice in
        Message(
            id: chatResult.id,
            role: choice.message.role,
            content: choice.message.content?.string ?? "",
            createdAt: createdAt
        )
    }
    conversations[conversationIndex].messages.append(contentsOf: receivedMessages)
}

private func completeConversationStreaming(conversationIndex: Int, model: Model, query: ChatQuery) async throws {
let chatsStream: AsyncThrowingStream<ChatStreamResult, Error> = openAIClient.chatsStream(
query: query
)

var functionCalls = [(name: String, argument: String?)]()
for try await partialChatResult in chatsStream {
for choice in partialChatResult.choices {
let existingMessages = conversations[conversationIndex].messages
// Function calls are also streamed, so we need to accumulate.
choice.delta.toolCalls?.forEach { toolCallDelta in
if let functionCallDelta = toolCallDelta.function {
if let nameDelta = functionCallDelta.name {
functionCalls.append((nameDelta, functionCallDelta.arguments))
}
}
var messageText = choice.delta.content ?? ""
if let finishReason = choice.finishReason,
finishReason == .toolCalls
{
functionCalls.forEach { (name: String, argument: String?) in
messageText += "Function call: name=\(name) arguments=\(argument ?? "")\n"
}
}
var messageText = choice.delta.content ?? ""
if let finishReason = choice.finishReason,
finishReason == .toolCalls
{
functionCalls.forEach { (name: String, argument: String?) in
messageText += "Function call: name=\(name) arguments=\(argument ?? "")\n"
}
let message = Message(
id: partialChatResult.id,
role: choice.delta.role ?? .assistant,
content: messageText,
createdAt: Date(timeIntervalSince1970: TimeInterval(partialChatResult.created))
}
let message = Message(
id: partialChatResult.id,
role: choice.delta.role ?? .assistant,
content: messageText,
createdAt: Date(timeIntervalSince1970: TimeInterval(partialChatResult.created))
)
if let existingMessageIndex = existingMessages.firstIndex(where: { $0.id == partialChatResult.id }) {
// Meld into previous message
let previousMessage = existingMessages[existingMessageIndex]
let combinedMessage = Message(
id: message.id, // id stays the same for different deltas
role: message.role,
content: previousMessage.content + message.content,
createdAt: message.createdAt
)
if let existingMessageIndex = existingMessages.firstIndex(where: { $0.id == partialChatResult.id }) {
// Meld into previous message
let previousMessage = existingMessages[existingMessageIndex]
let combinedMessage = Message(
id: message.id, // id stays the same for different deltas
role: message.role,
content: previousMessage.content + message.content,
createdAt: message.createdAt
)
conversations[conversationIndex].messages[existingMessageIndex] = combinedMessage
} else {
conversations[conversationIndex].messages.append(message)
}
conversations[conversationIndex].messages[existingMessageIndex] = combinedMessage
} else {
conversations[conversationIndex].messages.append(message)
}
}
} catch {
conversationErrors[conversationId] = error
}
}

Expand Down
5 changes: 5 additions & 0 deletions Sources/OpenAI/Private/StreamInterpreter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ class StreamInterpreter<ResultType: Codable> {
var onEventDispatched: ((ResultType) -> Void)?

func processData(_ data: Data) throws {
let decoder = JSONDecoder()
if let decoded = try? decoder.decode(APIErrorResponse.self, from: data) {
throw decoded
}

guard let stringContent = String(data: data, encoding: .utf8) else {
throw StreamingError.unknownContent
}
Expand Down
13 changes: 13 additions & 0 deletions Tests/OpenAITests/StreamInterpreterTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ struct StreamInterpreterTests {
#expect(chatStreamResults.count == 1)
}

// Verifies that a JSON body carrying an OpenAI error payload is surfaced by
// the interpreter as an `APIErrorResponse` (regression test for PR #266:
// streaming must honor HTTP error responses instead of ignoring them).
@Test func parseApiError() throws {
    // `#expect(throws:)` also FAILS the test when no error is thrown at all,
    // unlike the previous do/catch shape which passed silently if
    // `processData` stopped throwing.
    #expect(throws: APIErrorResponse.self) {
        try interpreter.processData(chatCompletionError())
    }
}

// Chunk with 3 objects. I captured it from a real response. It's a very short response that contains just "Hi"
private func chatCompletionChunk() -> Data {
"data: {\"id\":\"chatcmpl-AwnboO5ZnaUyii9xxC5ZVmM5vGark\",\"object\":\"chat.completion.chunk\",\"created\":1738577084,\"model\":\"gpt-4-0613\",\"service_tier\":\"default\",\"system_fingerprint\":null,\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\",\"refusal\":null},\"logprobs\":null,\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-AwnboO5ZnaUyii9xxC5ZVmM5vGark\",\"object\":\"chat.completion.chunk\",\"created\":1738577084,\"model\":\"gpt-4-0613\",\"service_tier\":\"default\",\"system_fingerprint\":null,\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hi\"},\"logprobs\":null,\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-AwnboO5ZnaUyii9xxC5ZVmM5vGark\",\"object\":\"chat.completion.chunk\",\"created\":1738577084,\"model\":\"gpt-4-0613\",\"service_tier\":\"default\",\"system_fingerprint\":null,\"choices\":[{\"index\":0,\"delta\":{},\"logprobs\":null,\"finish_reason\":\"stop\"}]}\n\n".data(using: .utf8)!
Expand All @@ -44,4 +52,9 @@ struct StreamInterpreterTests {
// The SSE terminator event that marks the end of a streamed completion.
private func chatCompletionChunkTermination() -> Data {
    Data("data: [DONE]\n\n".utf8)
}

// Copied from an actual response that was an input to the interpreter.
// Note: this is a plain JSON error body, not an SSE `data:`-prefixed chunk.
private func chatCompletionError() -> Data {
"{\n \"error\": {\n \"message\": \"The model `o3-mini` does not exist or you do not have access to it.\",\n \"type\": \"invalid_request_error\",\n \"param\": null,\n \"code\": \"model_not_found\"\n }\n}\n".data(using: .utf8)!
}
}

0 comments on commit 190c355

Please sign in to comment.