Skip to content

Commit

Permalink
Merge branch 'release/0.2.0'
Browse files Browse the repository at this point in the history
  • Loading branch information
intitni committed Mar 4, 2024
2 parents 0fa0b1e + 95bd1e4 commit 8034ce4
Show file tree
Hide file tree
Showing 32 changed files with 1,244 additions and 204 deletions.
2 changes: 1 addition & 1 deletion Core/Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ let package = Package(
dependencies: [
.package(
url: "https://github.com/intitni/CopilotForXcodeKit",
from: "0.4.0"
from: "0.5.0"
),
.package(
url: "https://github.com/pointfreeco/swift-dependencies",
Expand Down
26 changes: 23 additions & 3 deletions Core/Sources/CodeCompletionService/AzureOpenAIService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,38 @@ public actor AzureOpenAIService {
}

extension AzureOpenAIService: CodeCompletionServiceType {
func getCompletion(_ request: PromptStrategy) async throws -> String {
func getCompletion(_ request: PromptStrategy) async throws -> AsyncStream<String> {
switch endpoint {
case .chatCompletion:
let messages = createMessages(from: request)
CodeCompletionLogger.logger.logPrompt(messages.map {
($0.content, $0.role.rawValue)
})
return try await sendMessages(messages)
return AsyncStream<String> { continuation in
let task = Task {
let result = try await sendMessages(messages)
try Task.checkCancellation()
continuation.yield(result)
continuation.finish()
}
continuation.onTermination = { _ in
task.cancel()
}
}
case .completion:
let prompt = createPrompt(from: request)
CodeCompletionLogger.logger.logPrompt([(prompt, "user")])
return try await sendPrompt(prompt)
return AsyncStream<String> { continuation in
let task = Task {
let result = try await sendPrompt(prompt)
try Task.checkCancellation()
continuation.yield(result)
continuation.finish()
}
continuation.onTermination = { _ in
task.cancel()
}
}
}
}
}
Expand Down
37 changes: 33 additions & 4 deletions Core/Sources/CodeCompletionService/CodeCompletionService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ import Fundamental
import Storage

protocol CodeCompletionServiceType {
func getCompletion(
_ request: PromptStrategy
) async throws -> String
associatedtype CompletionSequence: AsyncSequence where CompletionSequence.Element == String

func getCompletion(_ request: PromptStrategy) async throws -> CompletionSequence
}

extension CodeCompletionServiceType {
Expand All @@ -16,7 +16,12 @@ extension CodeCompletionServiceType {
try await withThrowingTaskGroup(of: String.self) { group in
for _ in 0..<max(1, count) {
_ = group.addTaskUnlessCancelled {
try await getCompletion(request)
var result = ""
let stream = try await getCompletion(request)
for try await response in stream {
result.append(response)
}
return result
}
}

Expand Down Expand Up @@ -110,6 +115,18 @@ public struct CodeCompletionService {
let result = try await service.getCompletions(prompt, count: count)
try Task.checkCancellation()
return result
case .ollama:
let service = OllamaService(
url: model.endpoint,
endpoint: .chatCompletion,
modelName: model.info.modelName,
stopWords: prompt.stopWords,
keepAlive: model.info.ollamaInfo.keepAlive,
format: .none
)
let result = try await service.getCompletions(prompt, count: count)
try Task.checkCancellation()
return result
case .unknown:
throw Error.unknownFormat
}
Expand Down Expand Up @@ -145,6 +162,18 @@ public struct CodeCompletionService {
let result = try await service.getCompletions(prompt, count: count)
try Task.checkCancellation()
return result
case .ollama:
let service = OllamaService(
url: model.endpoint,
endpoint: .completion,
modelName: model.info.modelName,
stopWords: prompt.stopWords,
keepAlive: model.info.ollamaInfo.keepAlive,
format: .none
)
let result = try await service.getCompletions(prompt, count: count)
try Task.checkCancellation()
return result
case .unknown:
throw Error.unknownFormat
}
Expand Down
14 changes: 12 additions & 2 deletions Core/Sources/CodeCompletionService/GoogleGeminiService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,22 @@ public struct GoogleGeminiService {
}

extension GoogleGeminiService: CodeCompletionServiceType {
func getCompletion(_ request: PromptStrategy) async throws -> String {
func getCompletion(_ request: PromptStrategy) async throws -> AsyncStream<String> {
let messages = createMessages(from: request)
CodeCompletionLogger.logger.logPrompt(messages.map {
($0.parts.first?.text ?? "N/A", $0.role ?? "N/A")
})
return try await sendMessages(messages)
return AsyncStream<String> { continuation in
let task = Task {
let result = try await sendMessages(messages)
try Task.checkCancellation()
continuation.yield(result)
continuation.finish()
}
continuation.onTermination = { _ in
task.cancel()
}
}
}
}

Expand Down
Loading

0 comments on commit 8034ce4

Please sign in to comment.