Skip to content

Commit

Permalink
Merge pull request #263 from MacPaw/feat/custom-headers
Browse files Browse the repository at this point in the history
Add customHeaders to configuration
  • Loading branch information
nezhyborets authored Feb 17, 2025
2 parents 22e28b0 + 4b6cf46 commit 762d8ea
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 40 deletions.
5 changes: 2 additions & 3 deletions Sources/OpenAI/OpenAI+OpenAIAsync.swift
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,8 @@ extension OpenAI: OpenAIAsync {
}

func performRequestAsync<ResultType: Codable>(request: any URLRequestBuildable) async throws -> ResultType {
let urlRequest = try request.build(token: configuration.token,
organizationIdentifier: configuration.organizationIdentifier,
timeoutInterval: configuration.timeoutInterval)
let urlRequest = try request.build(configuration: configuration)

if #available(iOS 15.0, macOS 12.0, tvOS 15.0, watchOS 8.0, *) {
let (data, _) = try await session.data(for: urlRequest, delegate: nil)
let decoder = JSONDecoder()
Expand Down
5 changes: 2 additions & 3 deletions Sources/OpenAI/OpenAI+OpenAICombine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,8 @@ extension OpenAI: OpenAICombine {

func performRequestCombine<ResultType: Codable>(request: any URLRequestBuildable) -> AnyPublisher<ResultType, Error> {
do {
let request = try request.build(token: configuration.token,
organizationIdentifier: configuration.organizationIdentifier,
timeoutInterval: configuration.timeoutInterval)
let request = try request.build(configuration: configuration)

return session
.dataTaskPublisher(for: request)
.tryMap { (data, response) in
Expand Down
24 changes: 14 additions & 10 deletions Sources/OpenAI/OpenAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,24 @@ final public class OpenAI {
/// Default request timeout
public let timeoutInterval: TimeInterval

public init(token: String, organizationIdentifier: String? = nil, host: String = "api.openai.com", port: Int = 443, scheme: String = "https", basePath: String = "", timeoutInterval: TimeInterval = 60.0) {
/// Headers to set on a request.
///
/// Value from this dict would set on any request sent by SDK.
///
/// These values are applied after all the default headers are set, so if names collide, values from this dict would override default values.
///
/// Currently SDK sets such fields: Authorization, Content-Type, OpenAI-Organization.
public let customHeaders: [String: String]

public init(token: String, organizationIdentifier: String? = nil, host: String = "api.openai.com", port: Int = 443, scheme: String = "https", basePath: String = "", timeoutInterval: TimeInterval = 60.0, customHeaders: [String: String] = [:]) {
self.token = token
self.organizationIdentifier = organizationIdentifier
self.host = host
self.port = port
self.scheme = scheme
self.basePath = basePath
self.timeoutInterval = timeoutInterval
self.customHeaders = customHeaders
}
}

Expand Down Expand Up @@ -212,9 +222,7 @@ extension OpenAI {
func performRequest<ResultType: Codable>(request: any URLRequestBuildable, completion: @escaping (Result<ResultType, Error>) -> Void) -> CancellableRequest {
var cancellable = cancellablesFactory.makeTaskCanceller()
do {
let request = try request.build(token: configuration.token,
organizationIdentifier: configuration.organizationIdentifier,
timeoutInterval: configuration.timeoutInterval)
let request = try request.build(configuration: configuration)
let task = makeDataTask(forRequest: request, completion: completion)
cancellable.task = task
task.resume()
Expand All @@ -227,9 +235,7 @@ extension OpenAI {
func performStreamingRequest<ResultType: Codable>(request: any URLRequestBuildable, onResult: @escaping (Result<ResultType, Error>) -> Void, completion: ((Error?) -> Void)?) -> CancellableRequest {
var cancellable = cancellablesFactory.makeSessionCanceller()
do {
let request = try request.build(token: configuration.token,
organizationIdentifier: configuration.organizationIdentifier,
timeoutInterval: configuration.timeoutInterval)
let request = try request.build(configuration: configuration)
let session = StreamingSession<ResultType>(urlRequest: request)
cancellable.session = session
session.onReceiveContent = {_, object in
Expand All @@ -253,9 +259,7 @@ extension OpenAI {
func performSpeechRequest(request: any URLRequestBuildable, completion: @escaping (Result<AudioSpeechResult, Error>) -> Void) -> CancellableRequest {
var cancellable = cancellablesFactory.makeTaskCanceller()
do {
let request = try request.build(token: configuration.token,
organizationIdentifier: configuration.organizationIdentifier,
timeoutInterval: configuration.timeoutInterval)
let request = try request.build(configuration: configuration)

let task = session.dataTask(with: request) { data, _, error in
if let error = error {
Expand Down
14 changes: 8 additions & 6 deletions Sources/OpenAI/Private/AssistantsRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ struct AssistantsRequest<ResultType>: URLRequestBuildable {
self.method = method
}

func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest {
let customHeaders = ["OpenAI-Beta": "assistants=v2"]
func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval, customHeaders: [String: String]) throws -> URLRequest {
let customHeaders = customHeaders
.merging(["OpenAI-Beta": "assistants=v2"], uniquingKeysWith: { first, _ in first })

switch body {
case .json(let codable):
Expand All @@ -49,18 +50,19 @@ struct AssistantsRequest<ResultType>: URLRequestBuildable {
return try jsonRequest.build(
token: token,
organizationIdentifier: organizationIdentifier,
timeoutInterval: timeoutInterval
timeoutInterval: timeoutInterval,
customHeaders: customHeaders
)
case .multipartFormData(let encodable):
let request = MultipartFormDataRequest<ResultType>(
body: encodable,
url: urlBuilder.buildURL(),
customHeaders: customHeaders
url: urlBuilder.buildURL()
)
return try request.build(
token: token,
organizationIdentifier: organizationIdentifier,
timeoutInterval: timeoutInterval
timeoutInterval: timeoutInterval,
customHeaders: customHeaders
)
}
}
Expand Down
17 changes: 7 additions & 10 deletions Sources/OpenAI/Private/JSONRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,37 @@ final class JSONRequest<ResultType> {
let body: Codable?
let url: URL
let method: String
let customHeaders: [String: String]

init(body: Codable? = nil, url: URL, method: String = "POST", customHeaders: [String: String] = [:]) {
self.body = body
self.url = url
self.method = method
self.customHeaders = customHeaders
}
}

extension JSONRequest: URLRequestBuildable {

func build(
token: String,
organizationIdentifier: String?,
timeoutInterval: TimeInterval
timeoutInterval: TimeInterval,
customHeaders: [String: String]
) throws -> URLRequest {
var request = URLRequest(url: url, timeoutInterval: timeoutInterval)
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")

for (headerField, value) in customHeaders {
request.setValue(value, forHTTPHeaderField: headerField)
}

if let organizationIdentifier {
request.setValue(organizationIdentifier, forHTTPHeaderField: "OpenAI-Organization")
}

for (headerField, value) in customHeaders {
request.setValue(value, forHTTPHeaderField: headerField)
}

request.httpMethod = method
if let body = body {
request.httpBody = try JSONEncoder().encode(body)
}
return request
}
}


6 changes: 2 additions & 4 deletions Sources/OpenAI/Private/MultipartFormDataRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,17 @@ final class MultipartFormDataRequest<ResultType> {
let body: MultipartFormDataBodyEncodable
let url: URL
let method: String
let customHeaders: [String: String]

init(body: MultipartFormDataBodyEncodable, url: URL, method: String = "POST", customHeaders: [String: String] = [:]) {
init(body: MultipartFormDataBodyEncodable, url: URL, method: String = "POST") {
self.body = body
self.url = url
self.method = method
self.customHeaders = customHeaders
}
}

extension MultipartFormDataRequest: URLRequestBuildable {

func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest {
func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval, customHeaders: [String: String]) throws -> URLRequest {
var request = URLRequest(url: url)
let boundary: String = UUID().uuidString
request.timeoutInterval = timeoutInterval
Expand Down
20 changes: 19 additions & 1 deletion Sources/OpenAI/Private/URLRequestBuildable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,23 @@ import FoundationNetworking
#endif

protocol URLRequestBuildable {
func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest
func build(
token: String,
organizationIdentifier: String?,
timeoutInterval: TimeInterval,
customHeaders: [String: String]
) throws -> URLRequest
}

extension URLRequestBuildable {
func build(
configuration: OpenAI.Configuration
) throws -> URLRequest {
try build(
token: configuration.token,
organizationIdentifier: configuration.organizationIdentifier,
timeoutInterval: configuration.timeoutInterval,
customHeaders: configuration.customHeaders
)
}
}
27 changes: 24 additions & 3 deletions Tests/OpenAITests/OpenAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ class OpenAITests: XCTestCase {
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14)
let completionQuery = ChatQuery(messages: [.user(.init(content: .string("how are you?")))], model: .gpt3_5Turbo_16k)
let jsonRequest = JSONRequest<ChatResult>(body: completionQuery, url: URL(string: "http://google.com")!)
let urlRequest = try jsonRequest.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval)
let urlRequest = try jsonRequest.build(configuration: configuration)

XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "Authorization"), "Bearer \(configuration.token)")
XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "Content-Type"), "application/json")
Expand All @@ -404,14 +404,35 @@ class OpenAITests: XCTestCase {
func testMultipartRequestCreation() throws {
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14)
let completionQuery = AudioTranslationQuery(file: Data(), fileType: .mp3, model: .whisper_1)
let jsonRequest = MultipartFormDataRequest<ChatResult>(body: completionQuery, url: URL(string: "http://google.com")!)
let urlRequest = try jsonRequest.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval)
let multipartFormDataRequest = MultipartFormDataRequest<ChatResult>(body: completionQuery, url: URL(string: "http://google.com")!)
let urlRequest = try multipartFormDataRequest.build(configuration: configuration)

XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "Authorization"), "Bearer \(configuration.token)")
XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "OpenAI-Organization"), configuration.organizationIdentifier)
XCTAssertEqual(urlRequest.timeoutInterval, configuration.timeoutInterval)
}

func testAssistantRequestCreationSetsHeader() throws {
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14)
let jsonRequest = AssistantsRequest<AssistantResult>.jsonRequest(
urlBuilder: DefaultURLBuilder(configuration: configuration, path: .Assistants.assistants.stringValue),
body: assistantsQuery()
)
let urlRequest = try jsonRequest.build(configuration: configuration)
XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "OpenAI-Beta"), "assistants=v2")
}

func testCustomHeadersOverrideDefault() throws {
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14, customHeaders: ["Authorization": "auth", "Content-Type": "ctype", "OpenAI-Organization": "org"])
let completionQuery = ChatQuery(messages: [.user(.init(content: .string("how are you?")))], model: .gpt3_5Turbo_16k)
let jsonRequest = JSONRequest<ChatResult>(body: completionQuery, url: URL(string: "http://google.com")!)
let urlRequest = try jsonRequest.build(configuration: configuration)

XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "Authorization"), "auth")
XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "Content-Type"), "ctype")
XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "OpenAI-Organization"), "org")
}

func testDefaultHostURLBuilt() {
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14)
let openAI = OpenAI(configuration: configuration, session: self.urlSession)
Expand Down

0 comments on commit 762d8ea

Please sign in to comment.