Skip to content

Commit

Permalink
adds events
Browse files Browse the repository at this point in the history
  • Loading branch information
micheleriva committed Jul 19, 2024
1 parent 7a61094 commit 9ae9e87
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 7 deletions.
36 changes: 35 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ dependencies: [

## Usage

Performing full-text, vector, or hybrid search:

```swift
import OramaClient

Expand All @@ -51,7 +53,7 @@ struct MyDoc: Encodable & Decodable {
let clientParams = OramaClientParams(endpoint: "<ORAMA CLOUD URL>", apiKey: "<ORAMA CLOUD API KEY>")
let orama = OramaClient(params: clientParams)

let searchParams = ClientSearchParams.builder(term: "What is Orama?", mode: .fulltext)
let searchParams = ClientSearchParams.builder(term: "What is Orama?", mode: .fulltext) // Mode can be .vector or .hybrid too
.limit(10) // optional
.offset(0) // optional
.returning(["title", "description"]) // optional
Expand All @@ -62,6 +64,38 @@ let searchResults: SearchResults<MyDoc> = try await orama.search(query: searchPa
print("\(searchResults.count) total results.")
```

Performing an answer session:

```swift
import OramaClient

struct MyDoc: Encodable & Decodable {
let title: String
let description: String
}

let clientParams = OramaClientParams(endpoint: "<ORAMA CLOUD URL>", apiKey: "<ORAMA CLOUD API KEY>")
let orama = OramaClient(params: clientParams)
let answerSessionParams = AnswerParams<E2EDoc>(
initialMessages: [],
inferenceType: .documentation,
oramaClient: orama,
userContext: nil,
events: nil
)

let answerSession = AnswerSession(params: answerSessionParams)
.on(event: .stateChange, callback: { state in print(state) })
.on(event: .relatedQueries, callback: { related in print(related) })

let askParams = AnswerParams<E2EDoc>.AskParams(
query: "What's the best movie to watch with the family?",
userData: nil,
related: nil
)
let answer = try await answerSession.ask(params: askParams)
```

## License

[Apache 2.0](/LICENSE.md)
63 changes: 57 additions & 6 deletions Sources/oramacloud-client/answer-session.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ struct AnswerParams<Doc: Encodable & Decodable> {
let initialMessages: [Message]
let inferenceType: InferenceType
let oramaClient: OramaClient
let userContext: UserSpecs
let userContext: UserSpecs?
let events: Events?

enum InferenceType: Encodable, Decodable {
Expand Down Expand Up @@ -55,6 +55,8 @@ struct AnswerParams<Doc: Encodable & Decodable> {
}

struct Events {
typealias Callback = (Any) -> Void

var onMessageChange: (([Message]) -> Void)?
var onMessageLoading: ((Bool) -> Void)?
var onAnswerAborted: ((Bool) -> Void)?
Expand All @@ -64,6 +66,17 @@ struct AnswerParams<Doc: Encodable & Decodable> {
var onNewInteractionStarted: ((String) -> Void)?
var onStateChange: (([Interaction<Doc>]) -> Void)?
}

enum Event {
case messageChange
case messageLoading
case answerAborted
case sourceChange
case queryTranslated
case relatedQueries
case newInteractionStarted
case stateChange
}
}

@available(macOS 13.0, *)
Expand All @@ -78,7 +91,7 @@ class AnswerSession<Doc: Encodable & Decodable> {
private var abortController: Task<Void, Error>?
private var endpoint: String
private let userContext: AnswerParams<Doc>.UserSpecs
private let events: AnswerParams<Doc>.Events?
private var events: AnswerParams<Doc>.Events?
private let searchEndpoint: String
private let conversationID = Cuid.generateId()
private let userID = User.init().getUserID()
Expand All @@ -87,21 +100,55 @@ class AnswerSession<Doc: Encodable & Decodable> {
private var state: [AnswerParams<Doc>.Interaction<Doc>] = []

init(params: AnswerParams<Doc>) {
self.userContext = params.userContext
self.userContext = params.userContext ?? .string
self.events = params.events
self.messages = params.initialMessages
self.inferenceType = params.inferenceType
self.endpoint = "\(self.endpointBaseURL)/v1/answer?api-key=\(params.oramaClient.apiKey)"
self.searchEndpoint = params.oramaClient.endpoint
}

func fetchAnswer(params: AnswerParams<Doc>.AskParams) async throws -> AsyncThrowingStream<String, Error> {
public func on(event: AnswerParams<Doc>.Event, callback: @escaping AnswerParams<Doc>.Events.Callback) -> AnswerSession<Doc> {
switch event {
case .messageChange:
self.events?.onMessageChange = { message in callback(message) }
case .messageLoading:
self.events?.onMessageLoading = { loading in callback(loading) }
case .answerAborted:
self.events?.onAnswerAborted = { answerAborted in callback(answerAborted) }
case .sourceChange:
self.events?.onSourceChange = { sources in callback(sources) }
case .queryTranslated:
self.events?.onQueryTranslated = { query in callback(query) }
case .relatedQueries:
self.events?.onRelatedQueries = { relatedQueries in callback(relatedQueries) }
case .newInteractionStarted:
self.events?.onNewInteractionStarted = { interactionID in callback(interactionID) }
case .stateChange:
self.events?.onStateChange = { state in callback(state) }
}

return self
}

public func askStream(params: AnswerParams<Doc>.AskParams) async throws -> String {
return try await ask(params: params)
}

public func ask(params: AnswerParams<Doc>.AskParams) async throws -> String {
let stream = try await fetchAnswer(params: params)
var response = ""
for try await message in stream {
response += message
}
return response
}

private func fetchAnswer(params: AnswerParams<Doc>.AskParams) async throws -> AsyncThrowingStream<String, Error> {
AsyncThrowingStream { continuation in
let interactionId = Cuid.generateId()
self.abortController = Task {
do {


self.state.append(AnswerParams<Doc>.Interaction(
interactionId: interactionId,
query: params.query,
Expand Down Expand Up @@ -197,6 +244,7 @@ class AnswerSession<Doc: Encodable & Decodable> {
self.state[self.state.firstIndex(where: { $0.interactionId == interactionId })!].aborted = true
self.events?.onAnswerAborted?(true)
self.events?.onStateChange?(self.state)
continuation.finish()
} else {
continuation.finish(throwing: error)
}
Expand All @@ -205,7 +253,10 @@ class AnswerSession<Doc: Encodable & Decodable> {
self.state[self.state.firstIndex(where: { $0.interactionId == interactionId })!].loading = false
self.events?.onStateChange?(self.state)
self.events?.onMessageLoading?(false)
continuation.finish()
}

continuation.finish()
}
}

Expand Down
27 changes: 27 additions & 0 deletions Tests/oramacloud-clientTests/oramacloud_clientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,40 @@ final class oramacloud_clientTests: XCTestCase {
expectation.fulfill()
} catch {
print("Search failed with error: \(error)")
fflush(stdout)
XCTFail("Search failed with error: \(error)")
expectation.fulfill()
}
}

wait(for: [expectation], timeout: 10.0)
}

func testE2EAnswerSession() async throws {
struct E2EDoc: Encodable & Decodable {
let breed: String
}

let expectation = XCTestExpectation(description: "Async answer session completes")

let clientParams = OramaClientParams(endpoint: e2eEndpoint, apiKey: e2eApiKey)
let orama = OramaClient(params: clientParams)
let answerSessionParams = AnswerParams<E2EDoc>(
initialMessages: [],
inferenceType: .documentation,
oramaClient: orama,
userContext: nil,
events: nil
)

let answerSession = AnswerSession(params: answerSessionParams)

let askParams = AnswerParams<E2EDoc>.AskParams(query: "german", userData: nil, related: nil)
let response = try await answerSession.ask(params: askParams)

XCTAssertNotNil(response)
wait(for: [expectation], timeout: 60.0)
}
}

@available(macOS 12.0, *)
Expand Down

0 comments on commit 9ae9e87

Please sign in to comment.