Skip to content

Commit

Permalink
feat: adds answer session to switch
Browse files Browse the repository at this point in the history
  • Loading branch information
micheleriva committed Sep 26, 2024
1 parent d7c878b commit 62f83e2
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 18 deletions.
24 changes: 12 additions & 12 deletions packages/orama/src/methods/answer-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@ import type { AnyDocument, AnyOrama, Nullable, OramaPluginSync, SearchParams, Re
import { createError } from "../errors.js"
import { search } from "./search.js"

type GenericContext =
export type GenericContext =
| string
| object

type MessageRole =
export type MessageRole =
| 'system'
| 'user'
| 'assistant'

type Message = {
export type Message = {
role: MessageRole
content: string
}

type Interaction<SourceT = AnyDocument> = {
export type Interaction<SourceT = AnyDocument> = {
interactionId: string
query: string
response: string
Expand All @@ -29,21 +29,21 @@ type Interaction<SourceT = AnyDocument> = {
errorMessage: Nullable<string>
}

type AnswerSessionEvents<SourceT = AnyDocument> = {
export type AnswerSessionEvents<SourceT = AnyDocument> = {
onStateChange?: (state: Interaction<SourceT>[]) => void
}

type IAnswerSessionConfig<SourceT = AnyDocument> = {
export type IAnswerSessionConfig<SourceT = AnyDocument> = {
conversationID?: string
systemPrompt?: string
userContext?: GenericContext
initialMessages?: Message[]
events?: AnswerSessionEvents<SourceT>
}

type AskParams = SearchParams<AnyDocument>
export type AskParams = SearchParams<AnyDocument>

type RegenerateLastParams = {
export type RegenerateLastParams = {
stream: boolean
}

Expand All @@ -60,8 +60,8 @@ export class AnswerSession<SourceT = AnyDocument> {
private conversationID: string
private messages: Message[] = []
private events: AnswerSessionEvents<SourceT>
private state: Interaction<SourceT>[] = []
private initPromise?: Promise<true | null>
public state: Interaction<SourceT>[] = []

constructor(db: AnyOrama, config: IAnswerSessionConfig<SourceT>) {
this.db = db
Expand Down Expand Up @@ -215,16 +215,16 @@ export class AnswerSession<SourceT = AnyDocument> {
throw createError('PLUGIN_SECURE_PROXY_NOT_FOUND')
}

const pluginExtras = plugin.extra as { proxy: OramaProxy, pluginParams: { models: { chat: ChatModel } } }
const pluginExtras = plugin.extra as { proxy: OramaProxy, pluginParams: { chat: { model: ChatModel } } }

this.proxy = pluginExtras.proxy

if (this.config.systemPrompt) {
this.messages.push({ role: 'system', content: this.config.systemPrompt })
}

if (pluginExtras?.pluginParams?.models?.chat) {
this.chatModel = pluginExtras.pluginParams.models.chat
if (pluginExtras?.pluginParams?.chat?.model) {
this.chatModel = pluginExtras.pluginParams.chat.model
} else {
throw createError('PLUGIN_SECURE_PROXY_MISSING_CHAT_MODEL')
}
Expand Down
12 changes: 12 additions & 0 deletions packages/orama/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@ import { Sorter } from './components/sorter.js'
import { Language } from './components/tokenizer/languages.js'
import { Point } from './trees/bkd.js'

export type {
IAnswerSessionConfig,
AnswerSession,
AnswerSessionEvents,
AskParams,
GenericContext,
Interaction,
Message,
MessageRole,
RegenerateLastParams
} from './methods/answer-session.js'

export { MODE_FULLTEXT_SEARCH, MODE_HYBRID_SEARCH, MODE_VECTOR_SEARCH } from './constants.js'

export type { DefaultTokenizer } from './components/tokenizer/index.js'
Expand Down
3 changes: 2 additions & 1 deletion packages/switch/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"files": ["dist"],
"scripts": {
"build": "tsup",
"test": "tsx --test tests/search.test.ts"
"test": "tsx --test tests/answer-session.test.ts && tsx --test tests/search.test.ts"
},
"keywords": ["orama", "orama cloud"],
"type": "module",
Expand All @@ -31,6 +31,7 @@
},
"license": "Apache-2.0",
"devDependencies": {
"@orama/plugin-secure-proxy": "workspace:*",
"tsup": "^7.2.0",
"tsx": "^4.19.0"
}
Expand Down
31 changes: 26 additions & 5 deletions packages/switch/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { AnyOrama, Results, SearchParams, Nullable } from '@orama/orama'
import { search } from '@orama/orama'
import { OramaClient, ClientSearchParams } from '@oramacloud/client'
import type { AnyOrama, Results, SearchParams, Nullable, IAnswerSessionConfig as OSSAnswerSessionConfig } from '@orama/orama'
import { search, AnswerSession as OSSAnswerSession } from '@orama/orama'
import { OramaClient, ClientSearchParams, AnswerSessionParams as CloudAnswerSessionConfig, AnswerSession as CloudAnswerSession } from '@oramacloud/client'

export type OramaSwitchClient = AnyOrama | OramaClient

Expand All @@ -13,6 +13,7 @@ export type SearchConfig = {
}

export class Switch<T = OramaSwitchClient> {
private invalidClientError = 'Invalid client. Expected either an OramaClient or an Orama OSS database.'
client: OramaSwitchClient
clientType: ClientType
isCloud: boolean = false
Expand All @@ -28,7 +29,7 @@ export class Switch<T = OramaSwitchClient> {
this.clientType = 'oss'
this.isOSS = true
} else {
throw new Error('Invalid client. Expected either an OramaClient or an Orama OSS database.')
throw new Error(this.invalidClientError)
}
}

Expand All @@ -42,4 +43,24 @@ export class Switch<T = OramaSwitchClient> {
return search(this.client as AnyOrama, params as SearchParams<AnyOrama>) as Promise<Nullable<Results<R>>>
}
}
}

createAnswerSession(params: T extends OramaClient ? CloudAnswerSessionConfig : OSSAnswerSessionConfig): T extends OramaClient ? CloudAnswerSession : OSSAnswerSession {
if (this.isCloud) {
const p = params as CloudAnswerSessionConfig
return (this.client as OramaClient).createAnswerSession(p) as unknown as T extends OramaClient ? CloudAnswerSession : OSSAnswerSession
}

if (this.isOSS) {
const p = params as OSSAnswerSessionConfig
return new OSSAnswerSession(this.client as AnyOrama, {
conversationID: p.conversationID,
initialMessages: p.initialMessages,
events: p.events,
userContext: p.userContext,
systemPrompt: p.systemPrompt,
}) as unknown as T extends OramaClient ? CloudAnswerSession : OSSAnswerSession
}

throw new Error(this.invalidClientError)
}
}
71 changes: 71 additions & 0 deletions packages/switch/tests/answer-session.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import { test } from 'node:test'
import assert from 'node:assert/strict'
import { OramaClient } from '@oramacloud/client'
import { pluginSecureProxy } from '@orama/plugin-secure-proxy'
import { create, insertMultiple } from '@orama/orama'
import { Switch } from '../src/index.js'

const CLOUD_URL = process.env.ORAMA_CLOUD_E2E_URL
const CLOUD_API_KEY = process.env.ORAMA_CLOUD_E2E_API_KEY
const SECURE_PROXY_API_KEY = process.env.ORAMA_SECURE_PROXY_API_KEY

if (!CLOUD_URL || !CLOUD_API_KEY) {
console.log(
'Skipping Orama Switch remote client test since ORAMA_CLOUD_E2E_URL and ORAMA_CLOUD_E2E_API_KEY are not set'
)
process.exit(0)
}

test('local client', async () => {
const db = await create({
schema: {
name: 'string',
age: 'number',
embeddings: 'vector[384]'
} as const,
plugins: [
pluginSecureProxy({
apiKey: SECURE_PROXY_API_KEY!,
chat: {
model: 'openai/gpt-4'
},
embeddings: {
defaultProperty: 'embeddings',
model: 'openai/text-embedding-3-small',
},
})
]
})

await insertMultiple(db, [
{ name: 'Alice', age: 30 },
{ name: 'Bob', age: 40 },
{ name: 'Charlie', age: 50 }
])

const answerSession = new Switch(db).createAnswerSession({
systemPrompt: `You're an AI agent used to greet people. You will receive a name and will have to proceed generating a greeting message for that name. Use your fantasy.`,
})

await answerSession.ask({ term: 'Bob', where: { age: { eq: 40 } } })

const state = answerSession.state

assert(state.length === 1)
assert(state[0].query === 'Bob')
})

test('remote client', async () => {
const cloudClient = new OramaClient({
api_key: CLOUD_API_KEY,
endpoint: CLOUD_URL
})

const answerSession = new Switch(cloudClient).createAnswerSession({})
await answerSession.ask({ term: 'What is Orama?' })

const state = answerSession.state

assert(state.length === 1)
assert(state[0].query === 'What is Orama?')
})
3 changes: 3 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 62f83e2

Please sign in to comment.