TogetherAI Completion API #43

Open · wants to merge 4 commits into main
14 changes: 7 additions & 7 deletions Package.resolved
@@ -1,13 +1,13 @@
 {
-  "originHash" : "864ef9201dffd6ebf57da3ab4413cc1001edb70cd0c6d264b91e19b3967fb7ba",
+  "originHash" : "43bc324f7aba4797f38e1f72595973383d186932faefc9a4b3698c64d5b9cb44",
   "pins" : [
     {
       "identity" : "corepersistence",
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/vmanot/CorePersistence.git",
       "state" : {
         "branch" : "main",
-        "revision" : "38fd5271fa906a2d8395e4b42724142886a3c763"
+        "revision" : "cfbee4e123a18cb893613cdd536391bb7dec2203"
       }
     },
     {
@@ -16,7 +16,7 @@
       "location" : "https://github.com/vmanot/Merge.git",
       "state" : {
         "branch" : "master",
-        "revision" : "e8bc37c8dc203cab481efedd71237c151882c007"
+        "revision" : "925ca4baa33f8462d0ecd0757e7efabdac0a27ea"
       }
     },
     {
@@ -34,7 +34,7 @@
       "location" : "https://github.com/vmanot/Swallow.git",
       "state" : {
         "branch" : "master",
-        "revision" : "6227a1114e341daf54e90df61e173599b187a9b1"
+        "revision" : "85d690b23077728eff23e9fc0297e724995e9d89"
       }
     },
     {
@@ -51,8 +51,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/apple/swift-syntax.git",
       "state" : {
-        "revision" : "2bc86522d115234d1f588efe2bcb4ce4be8f8b82",
-        "version" : "510.0.3"
+        "revision" : "0687f71944021d616d34d922343dcef086855920",
+        "version" : "600.0.1"
       }
     },
     {
@@ -70,7 +70,7 @@
       "location" : "https://github.com/SwiftUIX/SwiftUIX.git",
       "state" : {
         "branch" : "master",
-        "revision" : "836fc284a9bb07fc9ab6d2dce6ebd0e32aabde26"
+        "revision" : "a68663989c8aaae013c4104c6a4aa2f35afe1000"
       }
     }
   ],
11 changes: 11 additions & 0 deletions Package.swift
@@ -507,6 +507,17 @@ let package = Package(
             swiftSettings: [
                 .enableExperimentalFeature("AccessLevelOnImport")
             ]
+        ),
+        .testTarget(
+            name: "TogetherAITests",
+            dependencies: [
+                "AI",
+                "Swallow"
+            ],
+            path: "Tests/TogetherAI",
+            swiftSettings: [
+                .enableExperimentalFeature("AccessLevelOnImport")
+            ]
         )
     ]
 )
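For a sense of how the new test target might be exercised, here is a minimal sketch of a test. The createCompletion(for:prompt:maxTokens:stop:temperature:topP:) call and the .mixtral8x7b constant mirror usage elsewhere in this PR; the TogetherAI.Client(apiKey:) initializer and the module import are assumptions, not confirmed by this diff.

import XCTest
@testable import TogetherAI

final class TogetherAICompletionTests: XCTestCase {
    // Hypothetical client setup; the real initializer and key handling may differ.
    let client = TogetherAI.Client(apiKey: "YOUR_API_KEY")

    func testCreateCompletion() async throws {
        // Mirrors the call made by the LLMRequestHandling conformance in this PR.
        let completion = try await client.createCompletion(
            for: .mixtral8x7b,
            prompt: "Say hello.",
            maxTokens: 16,
            stop: nil,
            temperature: nil,
            topP: nil
        )

        XCTAssertFalse(completion.choices.isEmpty)
    }
}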
94 changes: 92 additions & 2 deletions Sources/TogetherAI/Intramodular/TogetherAI.APISpecification.swift
@@ -41,6 +41,10 @@ extension TogetherAI {
         @POST
         @Path("embeddings")
         public var createEmbeddings = Endpoint<RequestBodies.CreateEmbedding, TogetherAI.Embeddings, Void>()
+
+        @POST
+        @Path("completions")
+        public var createCompletion = Endpoint<RequestBodies.CreateCompletion, TogetherAI.Completion, Void>()
     }
 }

@@ -122,7 +126,93 @@ extension TogetherAI.APISpecification {

 extension TogetherAI.APISpecification.RequestBodies {
     public struct CreateEmbedding: Codable, Hashable {
-        public let model: TogetherAI.Model
-        public let input: String
+        public let model: TogetherAI.Model.Embedding
+        public let input: [String]
     }

+    public struct CreateCompletion: Codable, Hashable {
+        private enum CodingKeys: String, CodingKey {
+            case model
+            case prompt
+            case maxTokens = "max_tokens"
+            case stream
+            case stop
+            case temperature
+            case topP = "top_p"
+            case topK = "top_k"
+            case repetitionPenalty = "repetition_penalty"
+            case logprobs
+            case echo
+            case choices = "n"
+            case safetyModel = "safety_model"
+        }
+
+        public let model: TogetherAI.Model.Completion
+        public let prompt: String
+
+        // The maximum number of tokens to generate.
+        // Defaults to 200.
+        public let maxTokens: Int?
+
+        // If true, stream tokens as Server-Sent Events as the model generates them instead of waiting for the full model response. If false, return a single JSON object containing the results.
+        public let stream: Bool?
+
+        // A list of string sequences that will truncate (stop) inference text output; generation stops as soon as the model emits one of the given sequences.
+        public let stop: [String]?
+
+        // A decimal number that determines the degree of randomness in the response. A value of 0 will always yield the same output. A temperature less than 1 favors more correctness and is appropriate for question answering or summarization. A value greater than 1 introduces more randomness in the output.
+        public let temperature: Double?
+
+        // The top_p (nucleus) parameter dynamically adjusts the number of choices for each predicted token based on cumulative probabilities. It specifies a probability threshold below which all less likely tokens are filtered out. This helps maintain diversity and generate fluent, natural-sounding text.
+        public let topP: Double?
+
+        // The top_k parameter limits the number of choices for the next predicted token by considering only the most probable tokens at each step. This speeds up generation and can improve quality by focusing on the most likely options.
+        public let topK: Double?
+
+        // A number that controls the diversity of generated text by reducing the likelihood of repeated sequences. Higher values decrease repetition.
+        public let repetitionPenalty: Double?
+
+        // Number of top-k logprobs to return.
+        public let logprobs: Int?
+
+        // Echo the prompt in the output. Can be used with logprobs to return prompt logprobs.
+        public let echo: Bool?
+
+        // How many completions to generate for each prompt (the API's "n" parameter).
+        public let choices: Int?
+
+        // A moderation model used to validate tokens. Choose from the available moderation models listed at: https://docs.together.ai/docs/inference-models#moderation-models
+        public let safetyModel: String?
+
+        public init(
+            model: TogetherAI.Model.Completion,
+            prompt: String,
+            maxTokens: Int?,
+            stream: Bool? = nil,
+            stop: [String]? = nil,
+            temperature: Double? = nil,
+            topP: Double? = nil,
+            topK: Double? = nil,
+            repetitionPenalty: Double? = nil,
+            logprobs: Int? = nil,
+            echo: Bool? = nil,
+            choices: Int? = nil,
+            safetyModel: String? = nil
+        ) {
+            self.model = model
+            self.prompt = prompt
+            self.maxTokens = maxTokens ?? 200
+            self.stream = stream
+            self.stop = stop
+            self.temperature = temperature
+            self.topP = topP
+            self.topK = topK
+            self.repetitionPenalty = repetitionPenalty
+            self.logprobs = logprobs
+            self.echo = echo
+            self.choices = choices
+            self.safetyModel = safetyModel
+        }
+    }
 }
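As a usage sketch for the request body above: maxTokens is the only optional parameter without a default value in the initializer, and the CodingKeys enum maps the Swift property names to the API's snake_case fields, so encoding with a stock JSONEncoder should produce the wire format directly (the exact rawValue emitted for the model is an assumption here, not confirmed by this diff).

import Foundation

let body = TogetherAI.APISpecification.RequestBodies.CreateCompletion(
    model: .mixtral8x7b,
    prompt: "Write a haiku about Swift.",
    maxTokens: 64,
    temperature: 0.7
)

let data = try JSONEncoder().encode(body)
// Produces JSON along the lines of:
// {"model":"<model raw value>","prompt":"Write a haiku about Swift.","max_tokens":64,"temperature":0.7}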
@@ -0,0 +1,128 @@
//
// Copyright (c) Vatsal Manot
//

import CoreMI
import CorePersistence
import Diagnostics
@_spi(Internal) import LargeLanguageModels
import Merge
import Swallow

extension TogetherAI.Client: _TaskDependenciesExporting {
    public var _exportedTaskDependencies: TaskDependencies {
        var result = TaskDependencies()

        result[\.llm] = self
        result[\.embedding] = self

        return result
    }
}

extension TogetherAI.Client: LLMRequestHandling {
    private var _debugPrintCompletions: Bool {
        false
    }

    public var _availableModels: [ModelIdentifier]? {
        return TogetherAI.Model.allCases.map({ $0.__conversion() })
    }

    public func complete<Prompt: AbstractLLM.Prompt>(
        prompt: Prompt,
        parameters: Prompt.CompletionParameters
    ) async throws -> Prompt.Completion {
        let _completion: Any

        switch prompt {
        case let prompt as AbstractLLM.TextPrompt:
            _completion = try await _complete(
                prompt: prompt,
                parameters: try cast(parameters)
            )
        default:
            throw LLMRequestHandlingError.unsupportedPromptType(Prompt.self)
        }

        return try cast(_completion)
    }

    private func _complete(
        prompt: AbstractLLM.TextPrompt,
        parameters: AbstractLLM.TextCompletionParameters
    ) async throws -> AbstractLLM.TextCompletion {
        let parameters = try cast(parameters, to: AbstractLLM.TextCompletionParameters.self)

        let model = TogetherAI.Model.Completion.mixtral8x7b

        let promptText = try prompt.prefix.promptLiteral
        let completion = try await self.createCompletion(
            for: model,
            prompt: promptText._stripToText(),
            maxTokens: parameters.tokenLimit.fixedValue,
            stop: parameters.stops,
            temperature: parameters.temperatureOrTopP?.temperature,
            topP: parameters.temperatureOrTopP?.topProbabilityMass
        )

        let text = try completion.choices.toCollectionOfOne().first.text

        _debugPrint(
            prompt: prompt.debugDescription
                .delimited(by: .quotationMark)
                .delimited(by: "\n"),
            completion: text
                .delimited(by: .quotationMark)
                .delimited(by: "\n")
        )

        return .init(prefix: promptText, text: text)
    }
}

extension TogetherAI.Client {
    private func _debugPrint(prompt: String, completion: String) {
        guard _debugPrintCompletions else {
            return
        }

        guard _isDebugAssertConfiguration else {
            return
        }

        let description = String.concatenate(separator: "\n") {
            "=== [PROMPT START] ==="
            prompt.debugDescription
                .delimited(by: .quotationMark)
                .delimited(by: "\n")
            "==== [COMPLETION] ===="
            completion
                .delimited(by: .quotationMark)
                .delimited(by: "\n")
            "==== [PROMPT END] ===="
        }

        print(description)
    }
}

// MARK: - Auxiliary

extension ModelIdentifier {
    public init(
        from model: TogetherAI.Model.Completion
    ) {
        self.init(provider: .togetherAI, name: model.rawValue, revision: nil)
    }

    public init(
        from model: TogetherAI.Model.Embedding
    ) {
        self.init(provider: .togetherAI, name: model.rawValue, revision: nil)
    }
}
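To illustrate the conformance above end to end, a hedged sketch of a call through LLMRequestHandling. The AbstractLLM.TextPrompt and AbstractLLM.TextCompletionParameters initializers (and the .fixed token-limit case) are assumptions about the LargeLanguageModels package, not something this diff defines.

// Hypothetical client setup; the real initializer may differ.
let client = TogetherAI.Client(apiKey: "YOUR_API_KEY")

// Assumed initializers; the real ones live in LargeLanguageModels.
let prompt = AbstractLLM.TextPrompt("Q: What is the capital of France?\nA:")
let parameters = AbstractLLM.TextCompletionParameters(tokenLimit: .fixed(32))

let completion: AbstractLLM.TextCompletion = try await client.complete(
    prompt: prompt,
    parameters: parameters
)

// AbstractLLM.TextCompletion carries the generated text, per the conformance above.
print(completion.text)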
@@ -0,0 +1,40 @@
//
// Copyright (c) Vatsal Manot
//

import CoreMI
import CorePersistence

extension TogetherAI.Client: TextEmbeddingsRequestHandling {
    public func fulfill(
        _ request: TextEmbeddingsRequest
    ) async throws -> TextEmbeddings {
        guard !request.input.isEmpty else {
            return TextEmbeddings(
                model: .init(from: TogetherAI.Model.Embedding.togetherM2Bert80M2KRetrieval),
                data: []
            )
        }

        let model: ModelIdentifier = request.model ?? ModelIdentifier(from: TogetherAI.Model.Embedding.togetherM2Bert80M2KRetrieval)
        let embeddingModel = try TogetherAI.Model.Embedding(rawValue: model.name).unwrap()

        let embeddings = try await createEmbeddings(
            for: embeddingModel,
            input: request.input
        ).data

        try _tryAssert(request.input.count == embeddings.count)

        return TextEmbeddings(
            model: model,
            data: request.input.zip(embeddings).map {
                TextEmbeddings.Element(
                    text: $0,
                    embedding: $1.embedding,
                    model: model
                )
            }
        )
    }
}
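A matching sketch for the embeddings path. The TextEmbeddingsRequest memberwise initializer is an assumption about CoreMI's API; ModelIdentifier(from:) and the shape of the returned TextEmbeddings come from this PR.

// Assumed memberwise initializer on CoreMI's TextEmbeddingsRequest.
let request = TextEmbeddingsRequest(
    input: ["Together AI makes open models easy to query."],
    model: ModelIdentifier(from: TogetherAI.Model.Embedding.togetherM2Bert80M2KRetrieval)
)

let embeddings = try await client.fulfill(request)

// One element per input string, in order, as asserted by the conformance above.
print(embeddings.data.count)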