Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add playground, support images in input #43

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
node_modules
package-lock.json
playground
dist
test/*.html
src/tools/repos
__*
Expand Down
46 changes: 43 additions & 3 deletions bin/cli.js
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
#!/usr/bin/env node
const fs = require('fs')
const gpt4 = require('gpt-tokenizer/cjs/model/gpt-4')
const { CompletionService, tools } = require('langxlang')

/**
 * Count GPT-4 tokens in a string using the local gpt-tokenizer.
 * The literal '<|endoftext|>' marker is rewritten first — presumably to keep
 * gpt-tokenizer from treating it as a special token (TODO confirm against
 * gpt-tokenizer's special-token handling).
 * @param {string} text - Text to tokenize.
 * @returns {number} Number of GPT-4 tokens.
 */
function countTokens (text) {
  return gpt4.encode(text.replaceAll('<|endoftext|>', '<|EndOfText|>')).length
}

/**
 * Print an optional error message followed by the CLI usage/help text,
 * all to stderr so it never pollutes piped stdout output.
 * @param {string} [msg] - Optional error message printed before the usage lines.
 */
function raise (msg) {
  if (msg) console.error(msg)
  console.error('Usage: langxlang <command> ...args')
  console.error('Usage: langxlang count <tokenizer> <file>')
  console.error('Usage: langxlang githubRepoToMarkdown <repo> <branch or ref> [output file] [comma separated extensions]')
  console.error('Usage: langxlang folderToMarkdown <path> [output file] [comma separated extensions]')
  // Fixed typo: "seperated" -> "separated", matching the sibling usage lines.
  console.error('Usage (alias): langxlang repo2md <repo> <branch or ref> [output file] [comma separated extensions]')
  console.error('Usage (alias): langxlang folder2md <path> [output file] [comma separated extensions]')
  console.error('Example: langxlang count gpt4 myfile.js')
  console.error('Example: langxlang count gemini1.5pro myfile.txt')
  console.error('Example: langxlang githubRepoToMarkdown PrismarineJS/vec3 master vec3.md')
  console.error('Example: langxlang folderToMarkdown ./src output.md .js,.ts')
}

if (process.argv.length < 3) {
Expand All @@ -24,16 +33,47 @@ const commands = {
process.exit(1)
}
console.log('Counting tokens in', file, 'using', tokenizer)
const text = require('fs').readFileSync(file, 'utf8')
if (tokenizer === 'gpt4') {
const text = require('fs').readFileSync(file, 'utf8')
console.log('Tokens:', countTokens(text).toLocaleString())
} else if (tokenizer === 'gemini1.5pro' || tokenizer === 'g15pro') {
const service = new CompletionService()
service.countTokens('gemini-1.5-pro-latest', text).then((tokens) => {
console.log('Tokens:', tokens.toLocaleString())
})
} else {
console.error('Unknown tokenizer', tokenizer)
process.exit(1)
}
},
githubRepoToMarkdown (repo, branch, outFile = 'repo.md', extensions) {
const files = tools.collectGithubRepoFiles(repo, {
branch,
extension: extensions ? extensions.split(',') : undefined,
truncateLargeFiles: 16_000 // 16k
})
const md = tools.concatFilesToMarkdown(files)
fs.writeFileSync(outFile, md)
},
folderToMarkdown (path, outFile = 'folder.md', extensions) {
const files = tools.collectFolderFiles(path, {
extension: extensions ? extensions.split(',') : undefined,
truncateLargeFiles: 16_000 // 16k
})
const md = tools.concatFilesToMarkdown(files)
fs.writeFileSync(outFile, md)
}
}

// Short aliases for the longer command names.
commands.repo2md = commands.githubRepoToMarkdown
commands.folder2md = commands.folderToMarkdown

const [, , command, ...args] = process.argv
// Debug echo goes to stderr so stdout stays clean for piped output.
console.error(`command: ${command}`, args)
// Removed the stale unconditional `commands[command](...args)` call that
// preceded this guard — it would have thrown on unknown commands before the
// handler check below could run.
const handler = commands[command]
if (handler) {
  handler(...args)
} else {
  raise('Unknown command')
  process.exit(1)
}
4 changes: 4 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"pretest": "npm run lint",
"mocha": "mocha --bail test/*.test.js",
"lint": "standard",
"buildWeb": "browserify src/browser.js -o dist/langxlang.js",
"fix": "standard --fix"
},
"repository": {
Expand All @@ -34,6 +35,8 @@
},
"homepage": "https://github.com/extremeheat/LXL#readme",
"devDependencies": {
"basic-ipc": "^0.1.1",
"browserify": "^17.0.0",
"langxlang": "file:.",
"mocha": "^10.0.0",
"standard": "^17.0.0"
Expand All @@ -45,6 +48,7 @@
"debug": "^4.3.4",
"fast-xml-parser": "^4.3.6",
"gpt-tokenizer": "^2.1.2",
"js-yaml": "^4.1.0",
"openai": "^4.28.0",
"ws": "^8.16.0"
}
Expand Down
137 changes: 129 additions & 8 deletions src/CompletionService.js
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ class CompletionService {
}
if (this.geminiApiKey) {
const geminiList = await gemini.listModels(this.geminiApiKey)
Object.assign(geminiModels, Object.fromEntries(geminiList.map((e) => ([e.name, e]))))
Object.assign(geminiModels, Object.fromEntries(geminiList
.filter((e) => e.name.startsWith('models/'))
.map((e) => ([e.name.replace('models/', ''), e]))))
}
return { openai: openaiModels, google: geminiModels }
}
Expand Down Expand Up @@ -107,6 +109,31 @@ class CompletionService {
const msg = structuredClone(entry)
if (msg.role === 'model') msg.role = 'assistant'
if (msg.role === 'guidance') msg.role = 'assistant'
if (msg.text != null) {
delete msg.text
msg.content = entry.text
}
if (typeof msg.content === 'object') {
const updated = []
for (const key in msg.content) {
const value = msg.content[key]
if (value.text) {
updated.push({ type: 'text', text: value.text })
} else if (value.imageURL) {
updated.push({ type: 'image_url', image_url: { url: value.imageURL, detail: value.imageDetail } })
} else if (value.imageB64) {
let dataURL = value.imageB64
if (!dataURL.startsWith('data:')) {
if (!value.mimeType) throw new Error('Missing accompanying `mimeType` for imageB64 that is not a data URL')
dataURL = `data:${value.mimeType};base64,${dataURL}`
}
updated.push({ type: 'image_url', image_url: { url: dataURL, detail: value.imageDetail } })
} else if (value.image_url) {
updated.push({ type: 'image_url', image_url: value.image_url })
}
}
msg.content = updated
}
return msg
}).filter((msg) => msg.content),
{
Expand All @@ -130,26 +157,86 @@ class CompletionService {
tool_calls: 'function'
}[choice.finishReason] ?? 'unknown'
const content = guidance ? guidance + choice.content : choice.content
return { type: choiceType, isTruncated: choice.finishReason === 'length', ...choice, content, text: content }
return {
type: choiceType,
isTruncated: choice.finishReason === 'length',
// ...choice,
content,
text: content
}
})
}

async _requestChatCompleteGemini (model, messages, { maxTokens, stopSequences, temperature, topP, topK }, functions, chunkCb) {
if (!this.geminiApiKey) throw new Error('Gemini API key not set')
// Converts LXL-style chat messages into Gemini-format entries (role + `parts`),
// resolving any remote image URLs into inline base64 data along the way.
async _processGeminiMessages (model, messages) {
// Google Gemini doesn't support data URLs, or even remote ones, so we need to fetch them, extract data URLs then split
async function resolveImage (url) {
// fetch the URL contents to a data URL (node.js)
const req = await fetch(url)
const buffer = await req.arrayBuffer()
const dataURL = `data:${req.headers.get('content-type')};base64,${Buffer.from(buffer).toString('base64')}`
return dataURL
}

// Split a data URL string into Gemini's { inlineData: { mimeType, data } } shape.
function splitDataURL (entry) {
// gemini doesn't support data URLs
const mimeType = entry.slice(5, entry.indexOf(';'))
const data = entry.slice(entry.indexOf(',') + 1)
return { inlineData: { mimeType, data } }
}

// April 2024 - Only Gemini 1.5 supports instructions
const supportsSystemInstruction = checkDoesGoogleModelSupportInstructions(model)
// NOTE(review): `chunkCb` is not a parameter of this function and `guidance` is
// unused below — this line looks like stale diff residue from
// _requestChatCompleteGemini (which computes guidance itself); verify it was removed.
const guidance = checkGuidance(messages, chunkCb)
const imagesForResolve = []
const geminiMessages = messages.map((msg) => {
const m = structuredClone(msg)
if (msg.role === 'assistant') m.role = 'model'
if (msg.role === 'system') m.role = supportsSystemInstruction ? 'system' : 'user'
if (msg.role === 'guidance') m.role = 'model'
// NOTE(review): this outer `if (msg.content != null)` wrapper appears to leave the
// braces unbalanced against the `} else if` below and the map-callback close —
// likely stale diff residue; verify the brace structure against the merged file.
if (msg.content != null) {
if (typeof msg.content === 'object') {
// Array/object content: translate each part into a Gemini `parts` entry.
const updated = []
for (const entry of msg.content) {
if (entry.text) {
updated.push({ text: entry.text })
} else if (entry.imageURL) {
// Remote image: push a placeholder now, resolve to inline data after the map.
const val = { imageURL: entry.imageURL }
imagesForResolve.push(val)
updated.push(val)
} else if (entry.imageB64) {
if (entry.imageB64.startsWith('data:')) {
updated.push(splitDataURL(entry.imageB64))
} else if (entry.mimeType) {
// Raw base64 plus an explicit mimeType supplied by the caller.
updated.push({
inlineData: {
mimeType: entry.mimeType,
data: entry.imageB64
}
})
}
}
}
delete m.content
m.parts = updated
} else if (msg.content != null) {
// Plain string content becomes a single text part.
delete m.content
m.parts = [{ text: msg.content }]
}
return m
}).filter((msg) => msg.parts && (msg.parts.length > 0))

// Second pass: fetch each deferred remote image and replace the placeholder
// in-place with Gemini inlineData.
for (const entry of imagesForResolve) {
const dataURL = await resolveImage(entry.imageURL)
Object.assign(entry, splitDataURL(dataURL))
delete entry.imageURL
}

return geminiMessages
}

async _requestChatCompleteGemini (model, messages, { maxTokens, stopSequences, temperature, topP, topK }, functions, chunkCb) {
if (!this.geminiApiKey) throw new Error('Gemini API key not set')
const guidance = checkGuidance(messages, chunkCb)
const geminiMessages = await this._processGeminiMessages(model, messages)

const response = await gemini.generateChatCompletionEx(model, geminiMessages, {
apiKey: this.geminiApiKey,
functions,
Expand All @@ -165,7 +252,13 @@ class CompletionService {
const answer = response.text()
chunkCb?.({ done: true, delta: '' })
const content = guidance ? guidance + answer : answer
const result = { type: 'text', content, text: content }
const result = {
type: 'text',
isTruncated: response.finishReason === 'MAX_TOKENS',
content,
safetyRatings: response.safetyRatings,
text: content
}
return [result]
} else if (response.functionCalls()) {
const calls = response.functionCalls()
Expand All @@ -178,7 +271,7 @@ class CompletionService {
args: call.args
}
}
const result = { type: 'function', fnCalls }
const result = { type: 'function', fnCalls, safetyRatings: response.safetyRatings }
return [result]
} else {
throw new Error('Unknown response from Gemini')
Expand Down Expand Up @@ -214,6 +307,34 @@ class CompletionService {
}
}

async countTokens (model, content) {
const { family } = getModelInfo(model)
switch (family) {
case 'openai':
return require('./tools/tokens').countTokens('gpt-4', content)
case 'gemini':
return gemini.countTokens(this.geminiApiKey, model, Array.isArray(content)
? (await this._processGeminiMessages(model, [{ role: 'user', content }]))[0].parts
: content)
default:
throw new Error(`Model '${model}' not supported for token counting, available models: ${knownModels.join(', ')}`)
}
}

async countTokensInMessages (model, messages) {
const { family } = getModelInfo(model)
switch (family) {
case 'openai':
return messages.reduce((cumLen, entry) => {
return cumLen + this.countTokens(model, entry.content)
}, 0)
case 'gemini':
return gemini.countTokens(this.geminiApiKey, model, this._processGeminiMessages(model, messages))
default:
throw new Error(`Model '${model}' not supported for token counting, available models: ${knownModels.join(', ')}`)
}
}

stop () {}
close () {}
}
Expand Down
11 changes: 10 additions & 1 deletion src/backends/gemini.js
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ async function generateChatCompletionIn (model, messages, options, chunkCb) {
// Function response
resultCandidates.push({
type: 'function',
finishReason: candidate.finishReason,
fnCalls: candidate.content.functionCalls,
raw: data,
safetyRatings: candidate.safetyRatings
Expand All @@ -97,6 +98,7 @@ async function generateChatCompletionIn (model, messages, options, chunkCb) {
// Text response
resultCandidates.push({
type: 'text',
finishReason: candidate.finishReason,
text: () => candidate.content.parts.reduce((acc, part) => acc + part.text, ''),
raw: data,
safetyRatings: candidate.safetyRatings
Expand Down Expand Up @@ -131,6 +133,13 @@ async function listModels (apiKey) {
return response.models
}

/**
 * Ask the Gemini API how many tokens the given content uses.
 * @param {string} apiKey - Google API key.
 * @param {string} model - Gemini model name.
 * @param {string|object[]} content - Text or Gemini-format content parts.
 * @returns {Promise<number>} Total token count reported by the API.
 */
async function countTokens (apiKey, model, content) {
  const client = new GoogleGenerativeAI(apiKey)
  // v1beta is required here, matching the API version used elsewhere in this backend.
  const generativeModel = client.getGenerativeModel({ model }, { apiVersion: 'v1beta' })
  const { totalTokens } = await generativeModel.countTokens(content)
  return totalTokens
}

function mergeDuplicatedRoleMessages (messages) {
// if there are 2 messages with the same role, merge them with a newline.
// Not doing this can return `GoogleGenerativeAIError: [400 Bad Request] Please ensure that multiturn requests ends with a user role or a function response.`
Expand All @@ -146,7 +155,7 @@ function mergeDuplicatedRoleMessages (messages) {
return mergedMessages
}

module.exports = { generateChatCompletionEx, generateChatCompletionIn, generateCompletion, listModels }
module.exports = { generateChatCompletionEx, generateChatCompletionIn, generateCompletion, listModels, countTokens }

/*
{
Expand Down
7 changes: 6 additions & 1 deletion src/backends/openai.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ function createChunkProcessor (chunkCb, resultChoices) {
}
for (const choiceId in chunk.choices) {
const choice = chunk.choices[choiceId]
const resultChoice = resultChoices[choiceId] ??= { content: '', fnCalls: [], finishReason: '', safetyRatings: {} }
const resultChoice = resultChoices[choiceId] ??= {
content: '',
fnCalls: [],
finishReason: '',
safetyRatings: {}
}
if (choice.finish_reason) {
resultChoice.finishReason = choice.finish_reason
}
Expand Down
15 changes: 15 additions & 0 deletions src/browser.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Exports for browser bundles
const mdp = require('./tools/mdp')
const stripping = require('./tools/stripping')
const tokenizer = require('./tools/tokens')
const misc = require('./tools/misc')

window.lxl = {
tools: {
stripping,
tokenizer,
loadPrompt: mdp.loadPrompt,
_segmentPromptByRoles: mdp.segmentByRoles,
...misc
}
}
Loading
Loading