Add chat completion support for Mistral.ai #73

Merged
merged 3 commits on Nov 24, 2024
73 changes: 73 additions & 0 deletions README.md
@@ -1953,6 +1953,79 @@ See `FalFluxLoRAInputSchema.swift` for the full range of inference controls

***

## Mistral

### How to create a chat completion with Mistral

Use `api.mistral.ai` as the proxy domain when creating your AIProxy service in the developer dashboard.

import AIProxy

let mistralService = AIProxy.mistralService(
partialKey: "partial-key-from-your-developer-dashboard",
serviceURL: "service-url-from-your-developer-dashboard"
)

do {
let response = try await mistralService.chatCompletionRequest(body: .init(
messages: [.user(content: "Hello world")],
model: "mistral-small-latest"
))
print(response.choices.first?.message.content ?? "")
if let usage = response.usage {
print(
"""
Used:
\(usage.promptTokens ?? 0) prompt tokens
\(usage.completionTokens ?? 0) completion tokens
\(usage.totalTokens ?? 0) total tokens
"""
)
}
} catch AIProxyError.unsuccessfulRequest(let statusCode, let responseBody) {
print("Received non-200 status code: \(statusCode) with response body: \(responseBody)")
} catch {
print("Could not create mistral chat completion: \(error.localizedDescription)")
}


### How to create a streaming chat completion with Mistral

Use `api.mistral.ai` as the proxy domain when creating your AIProxy service in the developer dashboard.

import AIProxy

let mistralService = AIProxy.mistralService(
partialKey: "partial-key-from-your-developer-dashboard",
serviceURL: "service-url-from-your-developer-dashboard"
)

do {
let stream = try await mistralService.streamingChatCompletionRequest(body: .init(
messages: [.user(content: "Hello world")],
model: "mistral-small-latest"
))
for try await chunk in stream {
print(chunk.choices.first?.delta.content ?? "")
if let usage = chunk.usage {
print(
"""
Used:
\(usage.promptTokens ?? 0) prompt tokens
\(usage.completionTokens ?? 0) completion tokens
\(usage.totalTokens ?? 0) total tokens
"""
)
}
}
} catch AIProxyError.unsuccessfulRequest(let statusCode, let responseBody) {
print("Received non-200 status code: \(statusCode) with response body: \(responseBody)")
} catch {
print("Could not create mistral streaming chat completion: \(error.localizedDescription)")
}

***


## OpenMeteo

30 changes: 30 additions & 0 deletions Sources/AIProxy/AIProxy.swift
@@ -404,6 +404,36 @@ public struct AIProxy {
)
}

/// AIProxy's Mistral service
///
/// - Parameters:
/// - partialKey: Your partial key is displayed in the AIProxy dashboard when you submit your Mistral key.
/// AIProxy takes your Mistral key, encrypts it, and stores part of the result on our servers. The part that you include
/// here is the other part. Both pieces are needed to decrypt your key and fulfill the request to Mistral.
///
/// - serviceURL: The service URL is displayed in the AIProxy dashboard when you submit your Mistral key.
///
/// - clientID: An optional clientID to attribute requests to specific users or devices. It is OK to leave this blank for
/// most applications. You would set this if you already have an analytics system, and you'd like to annotate AIProxy
/// requests with IDs that are known to other parts of your system.
///
/// If you do not supply your own clientID, the internals of this lib will generate UUIDs for you. The default UUIDs are
/// persistent on macOS and can be accurately used to attribute all requests to the same device. The default UUIDs
/// on iOS are persistent until the end user chooses to rotate their vendor identification number.
///
/// - Returns: An instance of MistralService configured and ready to make requests
public static func mistralService(
partialKey: String,
serviceURL: String,
clientID: String? = nil
) -> MistralService {
return MistralService(
partialKey: partialKey,
serviceURL: serviceURL,
clientID: clientID
)
}


#if canImport(AppKit)
public static func encodeImageAsJpeg(
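For illustration, here is a minimal sketch of calling the factory above with the optional `clientID` attribution described in its doc comment. The identifier value is a hypothetical placeholder, and the partial key and service URL come from the AIProxy dashboard, as in the README examples above.

import AIProxy

// Pass your own identifier to attribute requests to a known user or device.
// Omit clientID and the library generates and persists a UUID on your behalf.
let mistralService = AIProxy.mistralService(
    partialKey: "partial-key-from-your-developer-dashboard",
    serviceURL: "service-url-from-your-developer-dashboard",
    clientID: "user-1234" // hypothetical ID from your own analytics system
)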
207 changes: 207 additions & 0 deletions Sources/AIProxy/Mistral/MistralChatCompletionRequestBody.swift
@@ -0,0 +1,207 @@
//
// MistralChatCompletionRequestBody.swift
//
//
// Created by Lou Zell on 11/24/24.
//

import Foundation

/// Docstrings from: https://docs.mistral.ai/api/#tag/chat
public struct MistralChatCompletionRequestBody: Encodable {

// Required

/// The prompt(s) to generate completions for
public let messages: [Message]

/// ID of the model to use. You can use the List Available Models API to see all of your
/// available models, or see our Model overview for model descriptions.
///
/// List Available Models API: https://docs.mistral.ai/api/#tag/models/operation/list_models_v1_models_get
/// Model overview: https://docs.mistral.ai/models
public let model: String

// Optional

/// frequency_penalty penalizes the repetition of words based on their frequency in the
/// generated text. A higher frequency penalty discourages the model from repeating words that
/// have already appeared frequently in the output, promoting diversity and reducing
/// repetition.
///
/// Acceptable range: [-2..2]
/// Default: 0
public let frequencyPenalty: Double?

/// The maximum number of tokens to generate in the completion. The token count of your prompt
/// plus max_tokens cannot exceed the model's context length.
public let maxTokens: Int?

/// Number of completions to return for each request; input tokens are only billed once.
public let n: Int?

/// presence_penalty determines how much the model penalizes the repetition of words or
/// phrases. A higher presence penalty encourages the model to use a wider variety of words and
/// phrases, making the output more diverse and creative.
///
/// Acceptable range: [-2..2]
/// Default: 0
public let presencePenalty: Double?

/// An object specifying the format that the model must output. Setting to `.jsonObject` enables
/// JSON mode, which guarantees the message the model generates is in JSON. When using JSON
/// mode you MUST also instruct the model to produce JSON yourself with a system or a user
/// message.
public let responseFormat: ResponseFormat?

/// Whether to inject a safety prompt before all conversations.
/// Default: false
public let safePrompt: Bool?

/// The seed to use for random sampling. If set, different calls will generate deterministic results.
public let seed: Int?

/// Stop generation if one of these tokens is detected
public let stop: [String]?

/// Whether to stream back partial progress.
/// Default: false
public var stream: Bool?

/// What sampling temperature to use, we recommend between 0.0 and 0.7. Higher values like 0.7
/// will make the output more random, while lower values like 0.2 will make it more focused and
/// deterministic. We generally recommend altering this or `topP` but not both. The default
/// value varies depending on the model you are targeting. Call the /models endpoint to
/// retrieve the appropriate value.
///
/// Acceptable range: [0..1]
/// Default: 1
public let temperature: Double?

/// Tool calls are not implemented
/// public let tools: [Tool]?

/// Tool calls are not implemented
/// public let toolChoice: ToolChoice?

/// Nucleus sampling, where the model considers the results of the tokens with `topP`
/// probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are
/// considered. We generally recommend altering this or `temperature` but not both.
///
/// Acceptable range: [0..1]
/// Default: 1
public let topP: Double?

private enum CodingKeys: String, CodingKey {
case messages
case model

case frequencyPenalty = "frequency_penalty"
case maxTokens = "max_tokens"
case n
case presencePenalty = "presence_penalty"
case responseFormat = "response_format"
case safePrompt = "safe_prompt"
case seed
case stop
case stream
case temperature
// case tools
// case toolChoice = "tool_choice"
case topP = "top_p"
}

// This memberwise initializer is autogenerated.
// To regenerate, use `cmd-shift-a` > Generate Memberwise Initializer
// To format, place the cursor in the initializer's parameter list and use `ctrl-m`
public init(
messages: [MistralChatCompletionRequestBody.Message],
model: String,
frequencyPenalty: Double? = nil,
maxTokens: Int? = nil,
n: Int? = nil,
presencePenalty: Double? = nil,
responseFormat: MistralChatCompletionRequestBody.ResponseFormat? = nil,
safePrompt: Bool? = nil,
seed: Int? = nil,
stop: [String]? = nil,
stream: Bool? = nil,
temperature: Double? = nil,
topP: Double? = nil
) {
self.messages = messages
self.model = model
self.frequencyPenalty = frequencyPenalty
self.maxTokens = maxTokens
self.n = n
self.presencePenalty = presencePenalty
self.responseFormat = responseFormat
self.safePrompt = safePrompt
self.seed = seed
self.stop = stop
self.stream = stream
self.temperature = temperature
self.topP = topP
}
}


// MARK: - RequestBody.Message
extension MistralChatCompletionRequestBody {
public enum Message: Encodable {
case assistant(content: String)
case system(content: String)
case user(content: String)

private enum RootKey: String, CodingKey {
case content
case role
}

public func encode(to encoder: any Encoder) throws {
var container = encoder.container(keyedBy: RootKey.self)
switch self {
case .assistant(let content):
try container.encode(content, forKey: .content)
try container.encode("assistant", forKey: .role)
case .system(let content):
try container.encode(content, forKey: .content)
try container.encode("system", forKey: .role)
case .user(let content):
try container.encode(content, forKey: .content)
try container.encode("user", forKey: .role)
}
}
}
}

extension MistralChatCompletionRequestBody {
/// An object specifying the format that the model must output.
/// Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.
/// Important: when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user message.
public enum ResponseFormat: Encodable {

/// Enables JSON mode, which ensures the message the model generates is valid JSON.
/// Important: when using JSON mode, you must also instruct the model to produce JSON yourself via a
/// system or user message.
case jsonObject

/// Instructs the model to produce text only.
case text

private enum RootKey: String, CodingKey {
case type
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: RootKey.self)
switch self {
case .jsonObject:
try container.encode("json_object", forKey: .type)
case .text:
try container.encode("text", forKey: .type)
}
}
}
}
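Because JSON mode only guarantees well-formed output when the model is also told to produce JSON, a request that sets `responseFormat: .jsonObject` should pair it with an instructing system or user message. The sketch below assumes the `mistralService` instance from the README examples; the model name and prompts are illustrative.

do {
    let response = try await mistralService.chatCompletionRequest(body: .init(
        messages: [
            // Instructing the model to emit JSON is required when using JSON mode.
            .system(content: "Respond with a JSON object containing a single key named 'answer'."),
            .user(content: "What is the capital of France?")
        ],
        model: "mistral-small-latest",
        responseFormat: .jsonObject
    ))
    print(response.choices.first?.message.content ?? "")
} catch {
    print("Could not create mistral chat completion: \(error.localizedDescription)")
}

Each message case encodes to the `{"content": ..., "role": ...}` shape defined by the `Message` enum above, so `.system`, `.user`, and `.assistant` map directly to the roles the Mistral API expects.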
64 changes: 64 additions & 0 deletions Sources/AIProxy/Mistral/MistralChatCompletionResponseBody.swift
@@ -0,0 +1,64 @@
//
// MistralChatCompletionResponseBody.swift
//
//
// Created by Lou Zell on 9/30/24.
//

import Foundation

/// Docstrings from: https://platform.openai.com/docs/api-reference/chat/object
public struct MistralChatCompletionResponseBody: Decodable {
/// A list of chat completion choices.
/// Can be more than one if `n` on `MistralChatCompletionRequestBody` is greater than 1.
public let choices: [Choice]

/// The Unix timestamp (in seconds) of when the chat completion was created.
public let created: Int

/// The model used for the chat completion.
public let model: String

/// Usage statistics for the completion request.
public let usage: MistralChatUsage?
}

// MARK: - ResponseBody.Choice
extension MistralChatCompletionResponseBody {
public struct Choice: Decodable {
/// The reason the model stopped generating tokens. This will be `stop` if the model hit a
/// natural stop point or a provided stop sequence, `length` if the maximum number of
/// tokens specified in the request was reached, `content_filter` if content was omitted
/// due to a flag from our content filters, `tool_calls` if the model called a tool, or
/// `function_call` (deprecated) if the model called a function.
public let finishReason: String?

/// A chat completion message generated by the model.
public let message: Message

private enum CodingKeys: String, CodingKey {
case finishReason = "finish_reason"
case message
}
}
}

// MARK: - ResponseBody.Choice.Message
extension MistralChatCompletionResponseBody.Choice {
public struct Message: Decodable {
/// The contents of the message.
public let content: String

/// The role of the author of this message.
public let role: String

/// The tool calls generated by the model, such as function calls.
// public let toolCalls: [ToolCall]?

private enum CodingKeys: String, CodingKey {
case content
case role
// case toolCalls = "tool_calls"
}
}
}
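As a rough sketch of consuming the decoded type above, the loop below inspects each choice's finish reason and role. It assumes a `response` value returned by `mistralService.chatCompletionRequest(body:)` as in the README examples, and that `MistralChatUsage` exposes the optional token counts shown there.

for choice in response.choices {
    // A finish reason of "length" means the completion stopped because the
    // requested maxTokens limit was reached rather than a natural stop.
    if choice.finishReason == "length" {
        print("[truncated] \(choice.message.role): \(choice.message.content)")
    } else {
        print("\(choice.message.role): \(choice.message.content)")
    }
}
if let usage = response.usage {
    print("Total tokens used: \(usage.totalTokens ?? 0)")
}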