feat: OpenAIGeneration model for embedder #9474

Merged (1 commit) on Feb 29, 2024
2 changes: 0 additions & 2 deletions packages/embedder/ai_models/AbstractModel.ts
@@ -11,7 +11,6 @@ export interface GenerationModelConfig extends ModelConfig {}

export abstract class AbstractModel {
public readonly url?: string
public modelInstance: any
Review comment from the PR author:
This was unused. Good cruft to remove


constructor(config: ModelConfig) {
this.url = this.normalizeUrl(config.url)
@@ -57,7 +56,6 @@ export interface GenerationOptions {
temperature?: number
topK?: number
topP?: number
truncate?: boolean
Review comment from @jordanh (PR author), Feb 29, 2024:
This is not supported by OpenAI; I believe it truncates by default. I changed the TextGenerationInference implementation to always truncate and removed this option.

}

export abstract class AbstractGenerationModel extends AbstractModel {
9 changes: 6 additions & 3 deletions packages/embedder/ai_models/ModelManager.ts
@@ -7,6 +7,7 @@ import {
GenerationModelConfig,
ModelConfig
} from './AbstractModel'
import OpenAIGeneration from './OpenAIGeneration'
import TextEmbeddingsInference from './TextEmbeddingsInference'
import TextGenerationInference from './TextGenerationInference'

@@ -16,7 +17,7 @@ interface ModelManagerConfig {
}

export type EmbeddingsModelType = 'text-embeddings-inference'
export type GenerationModelType = 'text-generation-inference'
export type GenerationModelType = 'openai' | 'text-generation-inference'

export class ModelManager {
embeddingModels: AbstractEmbeddingsModel[]
@@ -80,9 +81,11 @@ export class ModelManager {
const [modelType, _] = modelConfig.model.split(':') as [GenerationModelType, string]

switch (modelType) {
case 'openai': {
return new OpenAIGeneration(modelConfig)
}
case 'text-generation-inference': {
const generator = new TextGenerationInference(modelConfig)
return generator
return new TextGenerationInference(modelConfig)
}
default:
throw new Error(`unsupported summarization model '${modelType}'`)
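For illustration only (not code in this PR): the model string is split on ':' so the prefix selects the implementation, and any other prefix falls through to the error case. The object below is hypothetical, with other GenerationModelConfig fields omitted.

// 'openai:gpt-3.5-turbo-0125'     -> OpenAIGeneration (model id 'gpt-3.5-turbo-0125')
// 'text-generation-inference:...' -> TextGenerationInference
// any other prefix                -> throws "unsupported summarization model '...'"
const openAIConfig = {model: 'openai:gpt-3.5-turbo-0125'} // hypothetical; other config fields omitted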
94 changes: 94 additions & 0 deletions packages/embedder/ai_models/OpenAIGeneration.ts
@@ -0,0 +1,94 @@
import OpenAI from 'openai'
import {
AbstractGenerationModel,
GenerationModelConfig,
GenerationModelParams,
GenerationOptions
} from './AbstractModel'

const MAX_REQUEST_TIME_S = 3 * 60

export type ModelId = 'gpt-3.5-turbo-0125' | 'gpt-4-turbo-preview'

type OpenAIGenerationOptions = Omit<GenerationOptions, 'topK'>

const modelIdDefinitions: Record<ModelId, GenerationModelParams> = {
'gpt-3.5-turbo-0125': {
maxInputTokens: 4096
},
'gpt-4-turbo-preview': {
maxInputTokens: 128000
}
}

function isValidModelId(object: any): object is ModelId {
return Object.keys(modelIdDefinitions).includes(object)
}

export class OpenAIGeneration extends AbstractGenerationModel {
private openAIApi: OpenAI | null
private modelId: ModelId

constructor(config: GenerationModelConfig) {
super(config)
if (!process.env.OPEN_AI_API_KEY) {
this.openAIApi = null
return
}
this.openAIApi = new OpenAI({
apiKey: process.env.OPEN_AI_API_KEY,
organization: process.env.OPEN_AI_ORG_ID
})
}

async summarize(content: string, options: OpenAIGenerationOptions) {
if (!this.openAIApi) {
const eMsg = 'OpenAI is not configured'
console.log('OpenAIGenerationSummarizer.summarize(): ', eMsg)
throw new Error(eMsg)
}
const {maxNewTokens: max_tokens = 512, seed, stop, temperature = 0.8, topP: top_p} = options
const prompt = `Create a brief, one-paragraph summary of the following: ${content}`
Review comment from the PR author:
Observing the differences between Zephyr and gpt-3.5-turbo, I have a feeling we may need to play with this prompt a bit to influence which features end up in the summary. There are choices we may want to consider:

  • Preserving names vs. using pronouns if they are known
  • Attempting to summarize each thread into a chunk of text, rather than summarizing the entire block (see the sketch below)

I'm not sure yet, though, so I want to forge ahead and try to do some reporting before iterating on this pipeline.

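A purely hypothetical sketch of the second idea above (the ThreadChunk shape and buildThreadPrompt name are invented for illustration; this is not code in the PR):

// Hypothetical: summarize each discussion thread separately and ask the model to keep
// participant names rather than pronouns.
interface ThreadChunk {
  participants: string[]
  text: string
}

const buildThreadPrompt = (thread: ThreadChunk) =>
  `Create a brief, one-paragraph summary of the following discussion. ` +
  `Refer to participants by name (${thread.participants.join(', ')}) rather than by pronoun: ` +
  thread.text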

try {
const response = await this.openAIApi.chat.completions.create({
frequency_penalty: 0,
max_tokens,
messages: [
{
role: 'user',
content: prompt
}
],
model: this.modelId,
presence_penalty: 0,
temperature,
seed,
stop,
top_p
})
const maybeSummary = response.choices[0]?.message?.content?.trim()
if (!maybeSummary) throw new Error('OpenAI returned empty summary')
return maybeSummary
} catch (e) {
console.log('OpenAIGenerationSummarizer.summarize(): ', e)
throw e
}
}
protected constructModelParams(config: GenerationModelConfig): GenerationModelParams {
const modelConfigStringSplit = config.model.split(':')
if (modelConfigStringSplit.length != 2) {
throw new Error('OpenAIGeneration model string must be colon-delimited and len 2')
}

const maybeModelId = modelConfigStringSplit[1]
if (!isValidModelId(maybeModelId))
throw new Error(`OpenAIGeneration model id unknown: ${maybeModelId}`)

this.modelId = maybeModelId

return modelIdDefinitions[maybeModelId]
}
}

export default OpenAIGeneration
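An illustrative usage sketch (not from this PR), assuming OPEN_AI_API_KEY is set in the environment; otherwise summarize() throws 'OpenAI is not configured':

import OpenAIGeneration from './OpenAIGeneration'
import {GenerationModelConfig} from './AbstractModel'

const summarizeExample = async (longText: string) => {
  // The cast only glosses over GenerationModelConfig fields not shown in this diff
  const summarizer = new OpenAIGeneration({model: 'openai:gpt-3.5-turbo-0125'} as GenerationModelConfig)
  // topK is omitted from OpenAIGenerationOptions; the remaining options are optional
  return summarizer.summarize(longText, {maxNewTokens: 256, temperature: 0.5})
}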
2 changes: 1 addition & 1 deletion packages/embedder/ai_models/TextEmbeddingsInference.ts
@@ -63,7 +63,7 @@ export class TextEmbeddingsInference extends AbstractEmbeddingsModel {
if (!this.url) throw new Error('TextGenerationInferenceSummarizer model requires url')
const maybeModelId = modelConfigStringSplit[1]
if (!isValidModelId(maybeModelId))
throw new Error(`TextGenerationInference model subtype unknown: ${maybeModelId}`)
throw new Error(`TextGenerationInference model id unknown: ${maybeModelId}`)
return modelIdDefinitions[maybeModelId]
}
}
26 changes: 9 additions & 17 deletions packages/embedder/ai_models/TextGenerationInference.ts
Expand Up @@ -25,24 +25,16 @@ export class TextGenerationInference extends AbstractGenerationModel {
super(config)
}

public async summarize(content: string, options: GenerationOptions) {
const {
maxNewTokens: max_new_tokens = 512,
seed,
stop,
temperature = 0.8,
topP,
topK,
truncate
} = options
async summarize(content: string, options: GenerationOptions) {
const {maxNewTokens: max_new_tokens = 512, seed, stop, temperature = 0.8, topP, topK} = options
const parameters = {
max_new_tokens,
seed,
stop,
temperature,
topP,
topK,
truncate
truncate: true
}
const prompt = `Create a brief, one-paragraph summary of the following: ${content}`
const fetchOptions = {
@@ -59,27 +51,27 @@
}

try {
// console.log(`TextGenerationInterface.summarize(): summarizing from ${this.url}/generate`)
// console.log(`TextGenerationInference.summarize(): summarizing from ${this.url}/generate`)
const res = await fetchWithRetry(`${this.url}/generate`, fetchOptions)
const json = await res.json()
if (!json || !json.generated_text)
throw new Error('TextGenerationInterface.summarize(): malformed response')
throw new Error('TextGenerationInference.summarize(): malformed response')
return json.generated_text as string
} catch (e) {
console.log('TextGenerationInterfaceSummarizer.summarize(): timeout')
console.log('TextGenerationInferenceSummarizer.summarize(): timeout')
throw e
}
}
protected constructModelParams(config: GenerationModelConfig): GenerationModelParams {
const modelConfigStringSplit = config.model.split(':')
if (modelConfigStringSplit.length != 2) {
throw new Error('TextGenerationInterface model string must be colon-delimited and len 2')
throw new Error('TextGenerationInference model string must be colon-delimited and len 2')
}

if (!this.url) throw new Error('TextGenerationInterfaceSummarizer model requires url')
if (!this.url) throw new Error('TextGenerationInferenceSummarizer model requires url')
const maybeModelId = modelConfigStringSplit[1]
if (!isValidModelId(maybeModelId))
throw new Error(`TextGenerationInterface model subtype unknown: ${maybeModelId}`)
throw new Error(`TextGenerationInference model id unknown: ${maybeModelId}`)
return modelIdDefinitions[maybeModelId]
}
}
3 changes: 1 addition & 2 deletions packages/embedder/embedder.ts
@@ -162,9 +162,8 @@ const dequeueAndEmbedUntilEmpty = async (modelManager: ModelManager) => {
try {
const generator = modelManager.generationModels[0] // use 1st generator
if (!generator) throw new Error(`Generator unavailable`)
const summarizeOptions = {maxInputTokens, truncate: true}
console.log(`embedder: ...summarizing ${itemKey} for ${modelTable}`)
embedText = await generator.summarize(fullText, summarizeOptions)
embedText = await generator.summarize(fullText, {maxNewTokens: maxInputTokens})
Review comment from the PR author:
There was a bug here! maxInputTokens should have been maxNewTokens

} catch (e) {
await updateJobState(jobQueueId, 'failed', {
stateMessage: `unable to summarize long embed text: ${e}`
2 changes: 1 addition & 1 deletion packages/embedder/indexing/embeddingsTablesOps.ts
@@ -52,7 +52,7 @@ export async function selectMetaToQueue(
.where(({eb, not, or, and, exists, selectFrom}) =>
and([
or([
not(eb('em.models', '<@', sql`ARRAY[${sql.ref('model')}]::varchar[]` as any) as any),
not(eb('em.models', '@>', sql`ARRAY[${sql.ref('model')}]::varchar[]` as any) as any),
Review comment from the PR author:
There was a bug here! We want to check if em.models is missing any configured models. Before this change, we were checking if any configured models were missing any elements in em.models. The only reason this produced queued jobs before was because em.models defaults to NULL.

The <@ operator is the "is contained by" operator. It checks whether the left-hand operand is contained within the right-hand operand: it returns true if the set on the right includes every element of the set on the left.

ARRAY[2,3] <@ ARRAY[1,2,3] returns true
ARRAY[2] <@ ARRAY[1,2,3] returns true (not what we want!)

The @> operator is the "contains" operator. It checks whether the left-hand operand contains the right-hand operand: it returns true if the set on the left includes every element of the set on the right.

ARRAY[1,2,3] @> ARRAY[2,3] returns true
ARRAY[1,2] @> ARRAY[2,3] returns false (what we want)

The better fix is to just stop using the models array...but we've already got that on our list. This just makes the current state work as designed.

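For illustration only, a plain-TypeScript analogue of the containment check the query now relies on (the model ids below are placeholders):

// Mirrors Postgres's @> ("contains"): does `left` include every element of `right`?
const contains = (left: string[], right: string[]) => right.every((x) => left.includes(x))

// em.models @> ARRAY[model]: has this row already been embedded with the configured model?
contains(['model-a', 'model-b'], ['model-a']) // true  (already embedded, not queued)
contains(['model-b'], ['model-a'])            // false (missing, so the row is queued)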
eb('em.models' as any, 'is', null)
]),
not(