Skip to content

Commit

Permalink
fix: code review comments take 3
Browse files Browse the repository at this point in the history
  • Loading branch information
jordanh committed Feb 22, 2024
1 parent 51a19f3 commit d9dbe82
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 21 deletions.
16 changes: 10 additions & 6 deletions packages/embedder/ai_models/AbstractModel.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export interface ModelConfig {
model: string
url?: string
url: string
}

export interface EmbeddingModelConfig extends ModelConfig {
Expand Down Expand Up @@ -32,12 +32,15 @@ export interface EmbeddingModelParams {
}

export abstract class AbstractEmbeddingsModel extends AbstractModel {
readonly modelParams!: EmbeddingModelParams
readonly embeddingDimensions: number
readonly maxInputTokens: number
readonly tableName: string
constructor(config: EmbeddingModelConfig) {
super(config)
this.modelParams = this.constructModelParams(config)
this.tableName = `Embeddings_${this.modelParams.tableSuffix}`
const modelParams = this.constructModelParams(config)
this.embeddingDimensions = modelParams.embeddingDimensions
this.maxInputTokens = modelParams.maxInputTokens
this.tableName = `Embeddings_${modelParams.tableSuffix}`
}
protected abstract constructModelParams(config: EmbeddingModelConfig): EmbeddingModelParams
abstract getEmbedding(content: string): Promise<number[]>
Expand All @@ -58,10 +61,11 @@ export interface GenerationOptions {
}

export abstract class AbstractGenerationModel extends AbstractModel {
readonly modelParams!: GenerationModelParams
readonly maxInputTokens: number
constructor(config: GenerationModelConfig) {
super(config)
this.modelParams = this.constructModelParams(config)
const modelParams = this.constructModelParams(config)
this.maxInputTokens = modelParams.maxInputTokens
}

protected abstract constructModelParams(config: GenerationModelConfig): GenerationModelParams
Expand Down
6 changes: 3 additions & 3 deletions packages/embedder/ai_models/ModelManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ export class ModelManager {
// Initialize embeddings models
this.embeddingModelsMapByTable = {}
this.embeddingModels = config.embeddingModels.map((modelConfig) => {
const [modelType, _] = modelConfig.model.split(':') as [EmbeddingsModelType, string]
const [modelType] = modelConfig.model.split(':') as [EmbeddingsModelType, string]

switch (modelType) {
case 'text-embeddings-inference': {
Expand Down Expand Up @@ -100,7 +100,7 @@ export class ModelManager {
)} = ${tableName}`.execute(pg)
).rows.length > 0
if (hasTable) return undefined
const vectorDimensions = embeddingsModel.modelParams.embeddingDimensions
const vectorDimensions = embeddingsModel.embeddingDimensions
console.log(`ModelManager: creating ${tableName} with ${vectorDimensions} dimensions`)
const query = sql`
DO $$
Expand Down Expand Up @@ -130,7 +130,7 @@ let modelManager: ModelManager | undefined
export function getModelManager() {
if (modelManager) return modelManager
const {AI_EMBEDDING_MODELS, AI_GENERATION_MODELS} = process.env
let config: ModelManagerConfig = {
const config: ModelManagerConfig = {
embeddingModels: [],
generationModels: []
}
Expand Down
12 changes: 8 additions & 4 deletions packages/embedder/ai_models/helpers/fetchWithRetry.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import fetch, {RequestInfo, RequestInit, Response} from 'node-fetch'

interface FetchWithRetryOptions extends RequestInit {
deadline: Date // Deadline for the request to complete
debug?: boolean // Enable debug tracing
Expand All @@ -20,8 +18,9 @@ export default async (url: RequestInfo, options: FetchWithRetryOptions): Promise
const timeoutId = setTimeout(() => controller.abort(), timeout)

try {
while (true) {
while (Date.now() < deadline.getTime()) {
attempt++

if (debug) {
console.log(`Attempt ${attempt}: Fetching ${url}`)
}
Expand All @@ -37,16 +36,21 @@ export default async (url: RequestInfo, options: FetchWithRetryOptions): Promise
// if Retry-After specified, use it; else fallback to exponential backoff
let waitTime = retryAfter ? parseInt(retryAfter, 10) * 1000 : Math.pow(2, attempt) * 1000

// cap waitTime to prevent exceeding the deadline
waitTime = Math.min(waitTime, deadline.getTime() - Date.now())

if (debug) {
console.log(
`Waiting ${waitTime / 1000} seconds before retrying due to status ${response.status}...`
)
}
await new Promise((resolve) => setTimeout(resolve, waitTime))
}

throw new Error('Deadline exceeded')
} catch (error) {
clearTimeout(timeoutId)
if (error.name === 'AbortError') {
if (error instanceof Error && error.name === 'AbortError') {
throw new Error('Request aborted due to deadline')
}
if (debug) {
Expand Down
13 changes: 5 additions & 8 deletions packages/embedder/embedder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@ import Redlock, {RedlockAbortSignal} from 'redlock'

import 'parabol-server/initSentry'
import getKysely from 'parabol-server/postgres/getKysely'
import RedisInstance from 'parabol-server/utils/RedisInstance'
import {DB} from 'parabol-server/postgres/pg'
import getDataLoader from 'parabol-server/graphql/getDataLoader'
import {refreshRetroDiscussionTopicsMeta as refreshRetroDiscussionTopicsMeta} from './indexing/retrospectiveDiscussionTopic'
import {orgIdsWithFeatureFlag} from './indexing/orgIdsWithFeatureFlag'
import getModelManager, {ModelManager} from './ai_models/ModelManager'
import {countWords} from './indexing/countWords'
import {createEmbeddingTextFrom} from './indexing/createEmbeddingTextFrom'
import {DataLoaderWorker} from 'parabol-server/graphql/graphql'
import {
selectJobQueueItemById,
selectMetadataByJobQueueId,
Expand All @@ -21,6 +18,8 @@ import {
import {selectMetaToQueue} from './indexing/embeddingsTablesOps'
import {insertNewJobs} from './indexing/embeddingsTablesOps'
import {completeJobTxn} from './indexing/embeddingsTablesOps'
import {getRootDataLoader} from './indexing/getRootDataLoader'
import {getRedisClient} from './indexing/getRedisClient'

/*
* TODO List
Expand All @@ -47,10 +46,8 @@ tracer.init({
})
tracer.use('pg')

const getRedisClient = () => new RedisInstance(`embedder-${SERVER_ID}`)

const refreshMetadata = async () => {
const dataLoader = getDataLoader() as DataLoaderWorker
const dataLoader = getRootDataLoader()
await refreshRetroDiscussionTopicsMeta(dataLoader)
// In the future, other sorts of objects to index could be added here...
}
Expand Down Expand Up @@ -108,7 +105,7 @@ const maybeQueueMetadataItems = async (modelManager: ModelManager) => {
}

const dequeueAndEmbedUntilEmpty = async (modelManager: ModelManager) => {
const dataLoader = getDataLoader() as DataLoaderWorker
const dataLoader = getRootDataLoader()
const redisClient = getRedisClient()
while (true) {
const maybeRedisQItem = await redisClient.zpopmax('embedder:queue', 1)
Expand Down Expand Up @@ -159,7 +156,7 @@ const dequeueAndEmbedUntilEmpty = async (modelManager: ModelManager) => {
const modelTable = embeddingModel.tableName

let embedText = fullText
const {maxInputTokens} = embeddingModel.modelParams
const maxInputTokens = embeddingModel.maxInputTokens
// we're using word count as an approximation of tokens
if (wordCount * WORD_COUNT_TO_TOKEN_RATIO > maxInputTokens) {
try {
Expand Down
11 changes: 11 additions & 0 deletions packages/embedder/indexing/getRedisClient.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import RedisInstance from 'parabol-server/utils/RedisInstance'

const {SERVER_ID} = process.env

// Module-level cache: holds the single RedisInstance once it has been created.
// Explicitly typed so the binding (and the getter's return value) is not implicitly `any`.
let redisClient: RedisInstance | undefined

/**
 * Returns the shared Redis client for this module, creating it on the first call.
 *
 * The instance is constructed with the identifier `embedder-${SERVER_ID}` and is
 * reused on every subsequent call, so callers never open more than one client
 * through this helper.
 *
 * @returns the lazily created, cached RedisInstance
 */
export const getRedisClient = (): RedisInstance => {
  if (!redisClient) {
    redisClient = new RedisInstance(`embedder-${SERVER_ID}`)
  }
  return redisClient
}
10 changes: 10 additions & 0 deletions packages/embedder/indexing/getRootDataLoader.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import getDataLoader from 'parabol-server/graphql/getDataLoader'
import {DataLoaderWorker} from 'parabol-server/graphql/graphql'

// Module-level cache: holds the single data loader once it has been created.
// Explicitly typed so the `as DataLoaderWorker` cast below is not lost to an implicit `any`.
let rootDataLoader: DataLoaderWorker | undefined

/**
 * Returns the shared root data loader for this module, creating it on the first call.
 *
 * The loader is obtained from `getDataLoader()` and cast to `DataLoaderWorker`;
 * the same instance is returned on every subsequent call.
 *
 * @returns the lazily created, cached DataLoaderWorker
 */
export const getRootDataLoader = (): DataLoaderWorker => {
  if (!rootDataLoader) {
    rootDataLoader = getDataLoader() as DataLoaderWorker
  }
  return rootDataLoader
}

0 comments on commit d9dbe82

Please sign in to comment.