Skip to content

Commit

Permalink
feat: add support to windowai.io
Browse files Browse the repository at this point in the history
Closes #1
  • Loading branch information
henrycunh committed Jul 20, 2023
1 parent 161cbf3 commit 3ddde2d
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 108 deletions.
5 changes: 3 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "cursive-gpt",
"type": "module",
"version": "0.9.0",
"version": "0.10.0",
"packageManager": "pnpm@8.6.0",
"description": "",
"author": "Henrique Cunha <henrycunh@gmail.com>",
Expand Down Expand Up @@ -54,7 +54,8 @@
"typescript": "^5.0.4",
"unbuild": "^1.2.1",
"vite": "^4.3.9",
"vitest": "^0.31.3"
"vitest": "^0.31.3",
"window.ai": "^0.2.4"
},
"simple-git-hooks": {
"pre-commit": "pnpm lint-staged"
Expand Down
7 changes: 7 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

199 changes: 94 additions & 105 deletions src/cursive.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,76 @@ import type { ChatCompletionRequestMessage, CreateChatCompletionRequest, CreateC
import { resguard } from 'resguard'
import type { Hookable } from 'hookable'
import { createDebugger, createHooks } from 'hookable'
import { ofetch } from 'ofetch'
import type { FetchInstance } from 'openai-edge/types/base'
import type { CreateEmbeddingRequest } from 'openai-edge-fns'
import { encode } from 'gpt-tokenizer'
import { type ChatMessage, type CompletionOptions, type MessageOutput } from 'window.ai'
import type { CursiveAnswerResult, CursiveAskCost, CursiveAskOnToken, CursiveAskOptions, CursiveAskOptionsWithPrompt, CursiveAskUsage, CursiveHook, CursiveHooks, CursiveSetupOptions } from './types'
import { CursiveError, CursiveErrorCode } from './types'
import { getStream } from './stream'
import { getTokenCountFromFunctions, getUsage } from './usage'
import type { IfNull } from './util'
import { sleep, toSnake } from './util'
import { randomId, sleep, toSnake } from './util'
import { resolveOpenAIPricing } from './pricing'
import { createOpenAIClient, processOpenAIStream } from './vendor/openai'
import { getUsage } from './usage'
import { resolveVendorFromModel } from './vendor'

export class Cursive {
public _hooks: Hookable<CursiveHooks>
public _vendor: {
openai: ReturnType<typeof createOpenAIClient>
}

private _debugger: { close: () => void }
public options: CursiveSetupOptions

constructor(options: CursiveSetupOptions) {
private _usingWindowAI = false
private _debugger: { close: () => void }
private _ready = false

constructor(options: CursiveSetupOptions = {}) {
this._hooks = createHooks<CursiveHooks>()
this._vendor = {
openai: createOpenAIClient({ apiKey: options.openAI.apiKey }),
openai: createOpenAIClient({ apiKey: options?.openAI?.apiKey }),
}
this.options = options

if (options.debug)
this._debugger = createDebugger(this._hooks, { tag: 'cursive' })

if (options.allowWindowAI === undefined || options.allowWindowAI === true) {
if (typeof window !== 'undefined') {
// Wait for window.ai to become available, polling every 100ms for a maximum of ~200ms
const start = Date.now()
const interval = setInterval(() => {
if (window.ai) {
clearInterval(interval)
this._usingWindowAI = true
this._ready = true
if (options.debug)
console.log('[cursive] Using WindowAI')
}
else if (Date.now() - start > 200) {
clearInterval(interval)
this._ready = true
}
}, 100)
}
else {
this._ready = true
}
}
else {
this._ready = true
}
}

private _readyCheck() {
    // Resolve once the client has finished its startup detection
    // (`this._ready`), polling every 10ms. Gives up after 80 attempts
    // (~800ms) so callers awaiting this never hang indefinitely.
    return new Promise((resolve) => {
        let attempts = 0
        const poll = setInterval(() => {
            if (this._ready || ++attempts > 80) {
                clearInterval(poll)
                resolve(null)
            }
        }, 10)
    })
}

on<H extends CursiveHook>(event: H, callback: CursiveHooks[H]) {
Expand All @@ -41,8 +81,8 @@ export class Cursive {
async ask(
options: CursiveAskOptions,
): Promise<CursiveAnswerResult> {
await this._readyCheck()
const result = await buildAnswer(options, this)

if (result.error) {
return new CursiveAnswer<CursiveError>({
result: null,
Expand All @@ -64,6 +104,7 @@ export class Cursive {
}

async embed(content: string) {
await this._readyCheck()
const options = {
model: 'text-embedding-ada-002',
input: content,
Expand Down Expand Up @@ -185,36 +226,6 @@ export function useCursive(options: CursiveSetupOptions) {
return new Cursive(options)
}

/**
 * Builds a minimal OpenAI REST client bound to the given API key.
 * Exposes `createChatCompletion` and `createEmbedding`, both of which
 * return the raw fetch response (streaming is handled by the caller).
 */
function createOpenAIClient(options: { apiKey: string }) {
    const resolvedFetch: FetchInstance = ofetch.native

    // Shared JSON POST helper — both endpoints differ only by URL and payload.
    function postJson(url: string, payload: unknown, abortSignal?: AbortSignal) {
        return resolvedFetch(url, {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
                'Authorization': `Bearer ${options.apiKey}`,
            },
            body: JSON.stringify(payload),
            signal: abortSignal,
        })
    }

    async function createChatCompletion(payload: CreateChatCompletionRequest, abortSignal?: AbortSignal) {
        return postJson('https://api.openai.com/v1/chat/completions', payload, abortSignal)
    }

    async function createEmbedding(payload: CreateEmbeddingRequest, abortSignal?: AbortSignal) {
        return postJson('https://api.openai.com/v1/embeddings', payload, abortSignal)
    }

    return { createChatCompletion, createEmbedding }
}

function resolveOptions(options: CursiveAskOptions) {
const {
functions: _ = [],
Expand Down Expand Up @@ -264,79 +275,56 @@ async function createCompletion(context: {
const { payload, abortSignal } = context
await context.cursive._hooks.callHook('completion:before', payload)
const start = Date.now()
const response = await context.cursive._vendor.openai.createChatCompletion({ ...payload }, abortSignal)
let data: any

if (payload.stream) {
const reader = getStream(response).getReader()
data = {
choices: [],
usage: {
completion_tokens: 0,
prompt_tokens: getUsage(payload.messages, payload.model),
},
model: payload.model,
}

if (payload.functions)
data.usage.prompt_tokens += getTokenCountFromFunctions(payload.functions)

while (true) {
const { done, value } = await reader.read()
if (done)
break
let data: CreateChatCompletionResponse & { cost: CursiveAskCost; error: any }

// TODO: Improve the completion creation based on model to vendor matching
// For now this will do
// @ts-expect-error - We're using a private property here
if (context.cursive._usingWindowAI) {
const vendor = resolveVendorFromModel(payload.model)
const resolvedModel = vendor ? `${vendor}/${payload.model}` : payload.model

const options: CompletionOptions<string> = {
maxTokens: payload.max_tokens,
model: resolvedModel,
numOutputs: payload.n,
stopSequences: payload.stop as string[],
temperature: payload.temperature,
}

data = {
...data,
id: value.id,
if (payload.stream && context.onToken) {
options.onStreamResult = (result) => {
const resultResolved = result as MessageOutput
context.onToken({ content: resultResolved.message.content, functionCall: null })
}
value.choices.forEach((choice: any, i: number) => {
const { delta } = choice

if (!data.choices[i]) {
data.choices[i] = {
message: {
function_call: null,
role: 'assistant',
content: '',
},
}
}

if (delta?.function_call?.name)
data.choices[i].message.function_call = delta.function_call

if (delta?.function_call?.arguments)
data.choices[i].message.function_call.arguments += delta.function_call.arguments

if (delta?.content)
data.choices[i].message.content += delta.content

if (context.onToken) {
let chunk: Record<string, any> | null = null
if (delta?.function_call) {
chunk = {
functionCall: delta.function_call,
}
}

if (delta?.content) {
chunk = {
content: delta.content,
}
}

if (chunk)
context.onToken(chunk as any)
}
})
}
const content = data.choices[0].message.content

const response = await window.ai.generateText({
messages: payload.messages as ChatMessage[],
}, options) as MessageOutput[]
data.choices = response.map(choice => ({
message: choice.message,
}))
data.model = payload.model
data.id = randomId()
data.usage.prompt_tokens = getUsage(context.payload.messages, context.payload.model)
const content = data.choices.map(choice => choice.message.content).join('')
data.usage.completion_tokens = encode(content).length
data.usage.total_tokens = data.usage.completion_tokens + data.usage.prompt_tokens
// If the model is from OpenAI, we can get the usage and costs
}
else {
data = await response.json()
const response = await context.cursive._vendor.openai.createChatCompletion({ ...payload }, abortSignal)
if (payload.stream) {
data = await processOpenAIStream({ ...context, response })
const content = data.choices.map(choice => choice.message.content).join('')
data.usage.completion_tokens = encode(content).length
data.usage.total_tokens = data.usage.completion_tokens + data.usage.prompt_tokens
}
else {
data = await response.json()
}
}

const end = Date.now()
Expand Down Expand Up @@ -388,6 +376,7 @@ async function askModel(
}), CursiveError)

if (completion.error) {
console.log(completion.error)
if (!completion.error?.details)
throw new CursiveError('Unknown error', completion.error, CursiveErrorCode.UnknownError)

Expand Down
6 changes: 5 additions & 1 deletion src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import type { JsonSchema7Type } from 'zod-to-json-schema/src/parseDef'
import type { HookResult, ObjectWithNullValues, Override } from './util'
import type { CursiveAnswer } from './cursive'

type CursiveAvailableModels = 'gpt-3.5-turbo' | 'gpt-4' | (string & {})
export type CursiveAvailableModels = 'gpt-3.5-turbo' | 'gpt-4' | (string & {})
// | 'claude-instant' | 'claude-2'

export type InferredFunctionParameters<T extends ZodRawShape> = {
Expand Down Expand Up @@ -46,6 +46,10 @@ export interface CursiveSetupOptions {
modelMapping?: Record<string, string>
}
debug?: boolean
/**
* Allows for the usage of WindowAI
*/
allowWindowAI?: boolean
}

export enum CursiveErrorCode {
Expand Down
4 changes: 4 additions & 0 deletions src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,7 @@ export type ObjectWithNullValues<T extends Record<string, any>> = {

// Override keys of T with keys of U
export type IfNull<T, U> = T extends null ? U : null

/**
 * Generates a random 8-character lowercase alphanumeric (base-36) id.
 *
 * The previous `Math.random().toString(36).substring(7)` produced
 * variable-length ids and could even return an empty string (e.g.
 * `(0.5).toString(36)` is `"0.i"`, whose `substring(7)` is `""`).
 * Accumulating base-36 digits until 8 are available makes the length
 * deterministic while keeping the same stdlib-only approach.
 *
 * @returns an 8-character string matching /^[a-z0-9]{8}$/
 */
export function randomId() {
    let id = ''
    while (id.length < 8) {
        // substring(2) drops the leading "0." from the fractional base-36 form
        id += Math.random().toString(36).substring(2)
    }
    return id.substring(0, 8)
}
8 changes: 8 additions & 0 deletions src/vendor/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import type { CursiveAvailableModels } from '../types'

/**
 * Maps a model name to its vendor identifier.
 * Currently only OpenAI models are recognized (by name prefix);
 * any other model resolves to the empty string.
 */
export function resolveVendorFromModel(model: CursiveAvailableModels) {
    const openAIPrefixes = ['gpt-3.5', 'gpt-4']
    return openAIPrefixes.some(prefix => model.startsWith(prefix)) ? 'openai' : ''
}
Loading

0 comments on commit 3ddde2d

Please sign in to comment.