-
Notifications
You must be signed in to change notification settings - Fork 331
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
30 changed files
with
2,136 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
module.exports = { | ||
extends: [ | ||
'../../.eslintrc.js' | ||
], | ||
parserOptions: { | ||
project: './tsconfig.json', | ||
ecmaVersion: 2020, | ||
sourceType: 'module' | ||
}, | ||
"ignorePatterns": ["**/lib", "*.js"] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# `Embedder` | ||
|
||
This service builds embedding vectors for semantic search and for other AI/ML | ||
use cases. It does so by: | ||
|
||
1. Updating a list of all possible items to create embedding vectors for and | ||
storing that list in the `EmbeddingsMetadata` table | ||
2. Adding these items in batches to the `EmbeddingsJobQueue` table and a redis | ||
priority queue called `embedder:queue` | ||
3. Allowing one or more parallel embedding services to calculate embedding
   vectors (`EmbeddingsJobQueue` states transition from `queued` -> `embedding`,
   then `embedding` -> [deleting the `EmbeddingsJobQueue` row])
|
||
In addition to deleting the `EmbeddingsJobQueue` row, when a job completes
successfully:
|
||
- A row is added to the model table with the embedding vector; the
  `EmbeddingMetadataId` field on this row points to the appropriate
  metadata row on `EmbeddingsMetadata`
- The `EmbeddingsMetadata.models` array is updated with the name of the | ||
table that the embedding has been generated for | ||
|
||
4. This process repeats forever using a silly polling loop | ||
|
||
In the future, it would be wonderful to enhance this service such that it were | ||
event driven. | ||
|
||
## Prerequisites | ||
|
||
The Embedder service depends on pgvector being available in Postgres. | ||
|
||
The predeploy script checks for an environment variable | ||
`POSTGRES_USE_PGVECTOR=true` to enable this extension in production. | ||
|
||
## Configuration | ||
|
||
The Embedder service takes no arguments and is controlled by the following | ||
environment variables, here given with example configuration: | ||
|
||
- `AI_EMBEDDER_ENABLED`: enables or disables the embedder service; when
  disabled, the service sleeps indefinitely instead of performing work
|
||
`AI_EMBEDDER_ENABLED='true'` | ||
|
||
- `AI_EMBEDDING_MODELS`: JSON configuration for which embedding models | ||
are enabled. Each model in the array will be instantiated by | ||
`ai_models/ModelManager`. Each model instance will have its own | ||
database table created for it (if it does not exist already) used | ||
to store calculated vectors. See `ai_models/ModelManager` for | ||
which configurations are supported. | ||
|
||
Example: | ||
|
||
`AI_EMBEDDING_MODELS='[{"model": "text-embeddings-inference:llmrails/ember-v1", "url": "http://localhost:3040/"}]'` | ||
|
||
- `AI_GENERATION_MODELS`: JSON configuration for which AI generation
  models (i.e. GPTs) are enabled. These models are used to summarize
  text to be embedded by an embedding model if the text length would be
greater than the context window of the embedding model. Each model in | ||
the array will be instantiated by `ai_models/ModelManager`. | ||
See `ai_models/ModelManager` for which configurations are supported. | ||
|
||
Example: | ||
|
||
`AI_GENERATION_MODELS='[{"model": "text-generation-interface:TheBloke/zephyr-7b-beta", "url": "http://localhost:3050/"}]'` | ||
|
||
## Usage | ||
|
||
The Embedder service is stateless and takes no arguments. Multiple instances | ||
of the service may be started in order to match embedding load, or to | ||
catch up on history more quickly. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import {Kysely, sql} from 'kysely' | ||
|
||
import { | ||
AbstractEmbeddingsModel, | ||
AbstractGenerationModel, | ||
EmbeddingModelConfig, | ||
GenerationModelConfig, | ||
ModelConfig | ||
} from './AbstractModel' | ||
import TextEmbeddingsInterface from './TextEmbeddingsInterface' | ||
import TextGenerationInterface from './TextGenerationInterface' | ||
|
||
// Configuration accepted by ModelManager: the lists of embedding and
// generation model definitions (parsed from the AI_EMBEDDING_MODELS and
// AI_GENERATION_MODELS environment variables in getModelManager()).
interface ModelManagerConfig {
  embeddingModels: EmbeddingModelConfig[]
  generationModels: GenerationModelConfig[]
}
|
||
export enum EmbeddingsModelTypes { | ||
TextEmbeddingsInterface = 'text-embeddings-inference' | ||
} | ||
export function isValidEmbeddingsModelType(type: any): type is EmbeddingsModelTypes { | ||
return Object.values(EmbeddingsModelTypes).includes(type) | ||
} | ||
|
||
export enum SummarizationModelTypes { | ||
TextGenerationInterface = 'text-generation-interface' | ||
} | ||
export function isValidSummarizationModelType(type: any): type is SummarizationModelTypes { | ||
return Object.values(SummarizationModelTypes).includes(type) | ||
} | ||
|
||
export class ModelManager { | ||
private embeddingModels: AbstractEmbeddingsModel[] | ||
private generationModels: AbstractGenerationModel[] | ||
|
||
private isValidConfig( | ||
maybeConfig: Partial<ModelManagerConfig> | ||
): maybeConfig is ModelManagerConfig { | ||
if (!maybeConfig.embeddingModels || !Array.isArray(maybeConfig.embeddingModels)) { | ||
throw new Error('Invalid configuration: embedding_models is missing or not an array') | ||
} | ||
if (!maybeConfig.generationModels || !Array.isArray(maybeConfig.generationModels)) { | ||
throw new Error('Invalid configuration: summarization_models is missing or not an array') | ||
} | ||
|
||
maybeConfig.embeddingModels.forEach((model: ModelConfig) => { | ||
this.isValidModelConfig(model) | ||
}) | ||
|
||
maybeConfig.generationModels.forEach((model: ModelConfig) => { | ||
this.isValidModelConfig(model) | ||
}) | ||
|
||
return true | ||
} | ||
|
||
private isValidModelConfig(model: ModelConfig): model is ModelConfig { | ||
if (typeof model.model !== 'string') { | ||
throw new Error('Invalid ModelConfig: model field should be a string') | ||
} | ||
if (model.url !== undefined && typeof model.url !== 'string') { | ||
throw new Error('Invalid ModelConfig: url field should be a string') | ||
} | ||
|
||
return true | ||
} | ||
|
||
constructor(config: ModelManagerConfig) { | ||
// Validate configuration | ||
this.isValidConfig(config) | ||
// Initialize embeddings models | ||
this.embeddingModels = [] | ||
config.embeddingModels.forEach(async (modelConfig) => { | ||
const modelType = modelConfig.model.split(':')[0] | ||
|
||
if (!isValidEmbeddingsModelType(modelType)) | ||
throw new Error(`unsupported embeddings model '${modelType}'`) | ||
|
||
switch (modelType) { | ||
case 'text-embeddings-inference': | ||
const embeddingsModel = new TextEmbeddingsInterface(modelConfig) | ||
this.embeddingModels.push(embeddingsModel) | ||
break | ||
} | ||
}) | ||
|
||
// Initialize summarization models | ||
this.generationModels = [] | ||
config.generationModels.forEach(async (modelConfig) => { | ||
const modelType = modelConfig.model.split(':')[0] | ||
|
||
if (!isValidSummarizationModelType(modelType)) | ||
throw new Error(`unsupported summarization model '${modelType}'`) | ||
|
||
switch (modelType) { | ||
case 'text-generation-interface': | ||
const generator = new TextGenerationInterface(modelConfig) | ||
this.generationModels.push(generator) | ||
break | ||
} | ||
}) | ||
} | ||
|
||
async maybeCreateTables(pg: Kysely<any>) { | ||
for (const embeddingsModel of this.getEmbeddingModelsIter()) { | ||
const tableName = embeddingsModel.getTableName() | ||
const hasTable = await (async () => { | ||
const query = sql<number[]>`SELECT 1 FROM ${sql.id( | ||
'pg_catalog', | ||
'pg_tables' | ||
)} WHERE ${sql.id('tablename')} = ${tableName}` | ||
const result = await query.execute(pg) | ||
return result.rows.length > 0 | ||
})() | ||
if (hasTable) continue | ||
const vectorDimensions = embeddingsModel.getModelParams().embeddingDimensions | ||
console.log(`ModelManager: creating ${tableName} with ${vectorDimensions} dimensions`) | ||
const query = sql` | ||
DO $$ | ||
BEGIN | ||
CREATE TABLE IF NOT EXISTS ${sql.id(tableName)} ( | ||
"id" SERIAL PRIMARY KEY, | ||
"embedText" TEXT, | ||
"embedding" vector(${sql.raw(vectorDimensions.toString())}), | ||
"embeddingsMetadataId" INTEGER NOT NULL, | ||
FOREIGN KEY ("embeddingsMetadataId") | ||
REFERENCES "EmbeddingsMetadata"("id") | ||
ON DELETE CASCADE | ||
); | ||
CREATE INDEX IF NOT EXISTS "idx_${sql.raw(tableName)}_embedding_vector_cosign_ops" | ||
ON ${sql.id(tableName)} | ||
USING hnsw ("embedding" vector_cosine_ops); | ||
END $$; | ||
` | ||
await query.execute(pg) | ||
} | ||
} | ||
|
||
// returns the highest priority summarizer instance | ||
getFirstGenerator() { | ||
if (!this.generationModels.length) throw new Error('no generator model initialzed') | ||
return this.generationModels[0] | ||
} | ||
|
||
getFirstEmbedder() { | ||
if (!this.embeddingModels.length) throw new Error('no embedder model initialzed') | ||
return this.embeddingModels[0] | ||
} | ||
|
||
getEmbeddingModelsIter() { | ||
return this.embeddingModels[Symbol.iterator]() | ||
} | ||
} | ||
|
||
let modelManager: ModelManager | undefined | ||
export function getModelManager() { | ||
if (modelManager) return modelManager | ||
const {AI_EMBEDDING_MODELS, AI_GENERATION_MODELS} = process.env | ||
let config: ModelManagerConfig = { | ||
embeddingModels: [], | ||
generationModels: [] | ||
} | ||
try { | ||
config.embeddingModels = AI_EMBEDDING_MODELS && JSON.parse(AI_EMBEDDING_MODELS) | ||
} catch (e) { | ||
throw new Error(`Invalid AI_EMBEDDING_MODELS .env JSON: ${e}`) | ||
} | ||
try { | ||
config.generationModels = AI_GENERATION_MODELS && JSON.parse(AI_GENERATION_MODELS) | ||
} catch (e) { | ||
throw new Error(`Invalid AI_EMBEDDING_MODELS .env JSON: ${e}`) | ||
} | ||
|
||
modelManager = new ModelManager(config) | ||
|
||
return modelManager | ||
} | ||
|
||
export default getModelManager |
Oops, something went wrong.