-
Notifications
You must be signed in to change notification settings - Fork 336
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: OpenAIGeneration model for embedder #9474
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,6 @@ export interface GenerationModelConfig extends ModelConfig {} | |
|
||
export abstract class AbstractModel { | ||
public readonly url?: string | ||
public modelInstance: any | ||
|
||
constructor(config: ModelConfig) { | ||
this.url = this.normalizeUrl(config.url) | ||
|
@@ -57,7 +56,6 @@ export interface GenerationOptions { | |
temperature?: number | ||
topK?: number | ||
topP?: number | ||
truncate?: boolean | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is not supported by OpenAI; I believe it truncates by default. I changed the … |
||
} | ||
|
||
export abstract class AbstractGenerationModel extends AbstractModel { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import OpenAI from 'openai' | ||
import { | ||
AbstractGenerationModel, | ||
GenerationModelConfig, | ||
GenerationModelParams, | ||
GenerationOptions | ||
} from './AbstractModel' | ||
|
||
// Request time limit; evaluates to 180.
// NOTE(review): presumably seconds, going by the `_S` suffix — confirm.
const MAX_REQUEST_TIME_S = 3 * 60 | ||
|
||
// Model ids accepted by this module; they are the keys of `modelIdDefinitions`
// and the values `isValidModelId` narrows to.
export type ModelId = 'gpt-3.5-turbo-0125' | 'gpt-4-turbo-preview' | ||
|
||
// Generation options without `topK`; `summarize()` never reads or forwards it.
type OpenAIGenerationOptions = Omit<GenerationOptions, 'topK'> | ||
|
||
// Per-model parameters, looked up by model id in `constructModelParams()`.
// NOTE(review): verify the token limits below against OpenAI's published model limits.
const modelIdDefinitions: Record<ModelId, GenerationModelParams> = { | ||
'gpt-3.5-turbo-0125': { | ||
maxInputTokens: 4096 | ||
}, | ||
'gpt-4-turbo-preview': { | ||
maxInputTokens: 128000 | ||
} | ||
} | ||
|
||
/**
 * Type guard reporting whether `object` is one of the model ids declared as a
 * key of `modelIdDefinitions`.
 *
 * @param object - any value, typically the model-id segment of a model string
 * @returns `true` only for strings that are own keys of `modelIdDefinitions`
 */
function isValidModelId(object: unknown): object is ModelId {
  // Non-strings can never be a key; checking first lets the parameter be `unknown` instead of `any`.
  return (
    typeof object === 'string' &&
    Object.prototype.hasOwnProperty.call(modelIdDefinitions, object)
  )
}
|
||
/**
 * Generation model that summarizes text through the OpenAI chat completions API.
 *
 * The OpenAI client is only created when the `OPEN_AI_API_KEY` environment
 * variable is set; without it the instance is still constructed, but
 * `summarize()` throws.
 */
export class OpenAIGeneration extends AbstractGenerationModel {
  private openAIApi: OpenAI | null
  // Assigned in constructModelParams(), not in this constructor.
  // NOTE(review): presumably the base-class constructor calls constructModelParams() — confirm.
  private modelId: ModelId

  /**
   * @param config - model configuration; `config.model` is parsed by `constructModelParams()`
   */
  constructor(config: GenerationModelConfig) {
    super(config)
    const apiKey = process.env.OPEN_AI_API_KEY
    if (!apiKey) {
      this.openAIApi = null
      return
    }
    this.openAIApi = new OpenAI({
      apiKey,
      organization: process.env.OPEN_AI_ORG_ID,
      // The client option is expressed in milliseconds.
      timeout: MAX_REQUEST_TIME_S * 1000
    })
  }

  /**
   * Asks the configured chat model for a one-paragraph summary of `content`.
   *
   * @param content - text to summarize; interpolated verbatim into the prompt
   * @param options - generation options; `maxNewTokens` defaults to 512 and
   *   `temperature` to 0.8, while `seed`, `stop` and `topP` are forwarded as given
   * @returns the trimmed text of the first completion choice
   * @throws Error when OpenAI is not configured, when the first choice is
   *   missing or empty, or when the API call itself rejects (logged, then rethrown)
   */
  async summarize(content: string, options: OpenAIGenerationOptions): Promise<string> {
    if (!this.openAIApi) {
      const eMsg = 'OpenAI is not configured'
      console.log('OpenAIGeneration.summarize(): ', eMsg)
      throw new Error(eMsg)
    }
    const {maxNewTokens: max_tokens = 512, seed, stop, temperature = 0.8, topP: top_p} = options
    // NOTE(review): from observed differences between Zephyr and gpt-3.5-turbo, this prompt may
    // need tuning to control which features end up in the summary; revisit after some reporting.
    const prompt = `Create a brief, one-paragraph summary of the following: ${content}`

    try {
      const response = await this.openAIApi.chat.completions.create({
        frequency_penalty: 0,
        max_tokens,
        messages: [
          {
            role: 'user',
            content: prompt
          }
        ],
        model: this.modelId,
        presence_penalty: 0,
        temperature,
        seed,
        stop,
        top_p
      })
      const maybeSummary = response.choices[0]?.message?.content?.trim()
      if (!maybeSummary) throw new Error('OpenAI returned empty summary')
      return maybeSummary
    } catch (e) {
      console.log('OpenAIGeneration.summarize(): ', e)
      throw e
    }
  }

  /**
   * Parses `config.model`, records the model id on the instance and returns
   * that model's parameters from `modelIdDefinitions`.
   *
   * @param config - configuration whose `model` string has exactly two
   *   colon-delimited parts, the second being a known model id
   * @returns the parameters registered for the parsed model id
   * @throws Error when the string does not have exactly two parts or the id is unknown
   */
  protected constructModelParams(config: GenerationModelConfig): GenerationModelParams {
    const modelConfigStringSplit = config.model.split(':')
    if (modelConfigStringSplit.length !== 2) {
      throw new Error('OpenAIGeneration model string must be colon-delimited and len 2')
    }

    const maybeModelId = modelConfigStringSplit[1]
    if (!isValidModelId(maybeModelId))
      throw new Error(`OpenAIGeneration model id unknown: ${maybeModelId}`)

    this.modelId = maybeModelId

    return modelIdDefinitions[maybeModelId]
  }
}
|
||
export default OpenAIGeneration |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -162,9 +162,8 @@ const dequeueAndEmbedUntilEmpty = async (modelManager: ModelManager) => { | |
try { | ||
const generator = modelManager.generationModels[0] // use 1st generator | ||
if (!generator) throw new Error(`Generator unavailable`) | ||
const summarizeOptions = {maxInputTokens, truncate: true} | ||
console.log(`embedder: ...summarizing ${itemKey} for ${modelTable}`) | ||
embedText = await generator.summarize(fullText, summarizeOptions) | ||
embedText = await generator.summarize(fullText, {maxNewTokens: maxInputTokens}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a bug here! |
||
} catch (e) { | ||
await updateJobState(jobQueueId, 'failed', { | ||
stateMessage: `unable to summarize long embed text: ${e}` | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,7 +52,7 @@ export async function selectMetaToQueue( | |
.where(({eb, not, or, and, exists, selectFrom}) => | ||
and([ | ||
or([ | ||
not(eb('em.models', '<@', sql`ARRAY[${sql.ref('model')}]::varchar[]` as any) as any), | ||
not(eb('em.models', '@>', sql`ARRAY[${sql.ref('model')}]::varchar[]` as any) as any), | ||
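There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There was a bug here! We want to check if `em.models` contains `model`. The `<@` operator tests whether the left-hand array is contained by the right-hand array.
The `@>` operator tests whether the left-hand array contains the right-hand array.
The better fix is to just stop using the models array...but we've already got that on our list. This just makes the current state work as designed. |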
||
eb('em.models' as any, 'is', null) | ||
]), | ||
not( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was unused. Good cruft to remove