Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for GPT-4 Vision #17

Merged
merged 7 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 70 additions & 5 deletions Sources/CleverBird/chat/ChatMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public struct ChatMessage: Codable, Identifiable {
public let role: Role

/// The contents of the message. `content` is required for all messages except assistant messages with function calls.
public let content: String?
public let content: Content?

/// The name and arguments of a function that should be called, as generated by the model.
public let functionCall: FunctionCall?
Expand All @@ -36,14 +36,21 @@ public struct ChatMessage: Codable, Identifiable {
content: String? = nil,
id: String? = nil,
functionCall: FunctionCall? = nil) throws {
try self.init(role: role, media: content != nil ? .text(content!) : nil, id: id, functionCall: functionCall)
}

public init(role: Role,
media: ChatMessage.Content?,
id: String? = nil,
functionCall: FunctionCall? = nil) throws {

// Validation: Content is required for all messages except assistant messages with function calls.
if content == nil && !(role == .assistant && functionCall != nil) {
if media == nil && !(role == .assistant && functionCall != nil) {
throw CleverBirdError.invalidMessageContent
}

self.role = role
self.content = content
self.content = media
self.name = functionCall?.name
if role == .function {
// If the role is "function" I need to set functionCall to nil, otherwise this will
Expand All @@ -58,7 +65,9 @@ public struct ChatMessage: Codable, Identifiable {
} else {
var hasher = Hasher()
hasher.combine(self.role)
hasher.combine(self.content ?? "")
if let content {
hasher.combine(content)
}
let hashValue = abs(hasher.finalize())
let timestamp = Int(Date.now.timeIntervalSince1970*10000)

Expand All @@ -69,7 +78,7 @@ public struct ChatMessage: Codable, Identifiable {
public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
self.role = try container.decode(Role.self, forKey: .role)
self.content = try container.decodeIfPresent(String.self, forKey: .content)
self.content = try container.decodeIfPresent(Content.self, forKey: .content)
self.functionCall = try container.decodeIfPresent(FunctionCall.self, forKey: .functionCall)
self.name = try container.decodeIfPresent(String.self, forKey: .name)
self.id = "pending"
Expand All @@ -92,3 +101,59 @@ extension ChatMessage: Equatable {
&& lhs.content == rhs.content
}
}

extension ChatMessage {

public enum Content: Codable, Equatable, CustomStringConvertible, Hashable {

case text(String)
case media([MessageContent])

public init(from decoder: Decoder) throws {
let container = try decoder.singleValueContainer()
if let textContent = try? container.decode(String.self) {
self = .text(textContent)
} else if let chatContents = try? container.decode([MessageContent].self) {
self = .media(chatContents)
} else {
throw DecodingError.typeMismatch(MessageContent.self, DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Unsupported type for Content"))
}
}

public func encode(to encoder: Encoder) throws {
var container = encoder.singleValueContainer()
switch self {
case .text(let text):
try container.encode(text)
case .media(let contents):
try container.encode(contents)
}
}

public static func == (lhs: Content, rhs: Content) -> Bool {
switch (lhs, rhs) {
case (.text(let leftText), .text(let rightText)):
return leftText == rightText
case (.media(let leftMedia), .media(let rightMedia)):
return leftMedia == rightMedia
default:
return false
}
}

public var description: String {
switch self {
case .media(let messageContents):
for messageContent in messageContents {
if case .text(let textValue) = messageContent {
return textValue
}
}
return ""
case .text(let text):
return text
}
}
}
}

27 changes: 26 additions & 1 deletion Sources/CleverBird/chat/ChatThread+tokenCount.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,32 @@ extension ChatThread {
let roleTokens = try tokenEncoder.encode(text: message.role.rawValue).count
let contentTokens: Int
if let content = message.content {
contentTokens = try tokenEncoder.encode(text: content).count
switch content {
case .text(let text):
contentTokens = try tokenEncoder.encode(text: text).count
case .media(let media):
var count = 0
for medium in media {
switch medium {
case .text(let text):
count += try tokenEncoder.encode(text: text).count
case .imageUrl(let url):
// See https://platform.openai.com/docs/guides/vision/calculating-costs
switch url.detail {
// TODO: calculate real values for auto and high
case .auto:
count += 1105
case .high:
count += 1105
case .low:
count += 85
case .none:
count += 1105
}
}
}
contentTokens = count
}
} else if let functionCall = message.functionCall {
let jsonEncoder = JSONEncoder()
let jsonData = try jsonEncoder.encode(functionCall)
Expand Down
10 changes: 10 additions & 0 deletions Sources/CleverBird/chat/ChatThread.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ public class ChatThread: Codable {
}
return self
}

@discardableResult
public func addUserMessage(_ media: [MessageContent]) -> Self {
do {
try addMessage(ChatMessage(role: .user, media: .media(media)))
} catch {
print(error.localizedDescription)
}
return self
}

@discardableResult
public func addAssistantMessage(_ content: String) -> Self {
Expand Down
90 changes: 90 additions & 0 deletions Sources/CleverBird/chat/MessageContent.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
//
// ChatContent.swift
//
//
// Created by Ronald Mannak on 4/12/24.
//

import Foundation

public enum MessageContent: Hashable {
case text(String)
case imageUrl(URLDetail)
}

extension MessageContent {
public enum ContentType: String, Codable, Hashable {
case text
case imageUrl = "image_url"
}

public struct URLDetail: Codable, Equatable, Hashable {

public enum Detail: String, Codable {
case low, high, auto
}

let url: String
let detail: Detail?

public init(url: String, detail: Detail? = nil) {
self.url = url
self.detail = detail
}

public init(url: URL, detail: Detail? = nil) {
self.init(url: url.absoluteString, detail: detail)
}

public init(imageData: Data, detail: Detail? = nil) {
let base64 = imageData.base64EncodedString()
self.init(url: "data:image/jpeg;base64,\(base64)", detail: detail)
}
}
}

extension MessageContent: Codable {

private enum CodingKeys: String, CodingKey {
case type, text, imageUrl
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let type = try container.decode(ContentType.self, forKey: .type)

switch type {
case .text:
let text = try container.decode(String.self, forKey: .text)
self = .text(text)
case .imageUrl:
let imageUrl = try container.decode(URLDetail.self, forKey: .imageUrl)
self = .imageUrl(imageUrl)
}
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
switch self {
case .text(let text):
try container.encode(ContentType.text.rawValue, forKey: .type)
try container.encode(text, forKey: .text)
case .imageUrl(let urlDetail):
try container.encode(ContentType.imageUrl.rawValue, forKey: .type)
try container.encode(urlDetail, forKey: .imageUrl)
}
}
}

extension MessageContent: Equatable {
public static func == (lhs: MessageContent, rhs: MessageContent) -> Bool {
switch (lhs, rhs) {
case (.text(let lhsText), .text(let rhsText)):
return lhsText == rhsText
case (.imageUrl(let lhsUrlDetail), .imageUrl(let rhsUrlDetail)):
return lhsUrlDetail == rhsUrlDetail
default:
return false
}
}
}
24 changes: 24 additions & 0 deletions Tests/CleverBirdTests/MessageContentTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//
// MessageContentTests.swift
//
//
// Created by Ronald Mannak on 4/12/24.
//

import Foundation
import XCTest
@testable import CleverBird

class MessageContentTests: XCTestCase {

func testURL() {
let content = MessageContent.URLDetail(url: URL(string: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg")!)
XCTAssertEqual(content.url, "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg")
}

func testBase64() {
let data = "Hello, world".data(using: .utf8)!
let content = MessageContent.URLDetail(imageData: data)
XCTAssertEqual(content.url, "data:image/jpeg;base64,SGVsbG8sIHdvcmxk")
}
}
Loading
Loading