feat: add embedder service
jordanh committed Feb 13, 2024
1 parent 012ca77 commit 192a4d0
Showing 30 changed files with 2,136 additions and 68 deletions.
5 changes: 5 additions & 0 deletions .env.example
@@ -8,6 +8,11 @@ SERVER_ID='1'
# Websocket port for the websocket server, only used in development (yarn dev)
SOCKET_PORT='3001'

# AI MODELS
AI_EMBEDDING_MODELS='[{"model": "text-embeddings-inference:llmrails/ember-v1", "url": "http://localhost:3040/"}]'
AI_GENERATION_MODELS='[{"model": "text-generation-interface:TheBloke/zephyr-7b-beta", "url": "http://localhost:3050/"}]'
AI_EMBEDDER_ENABLED='true'

# APPLICATION
# AMPLITUDE_WRITE_KEY='key_AMPLITUDE_WRITE_KEY'
# Enter a short url redirect service for invitations, it needs to redirect to /invitation-link
17 changes: 15 additions & 2 deletions docker/dev.yml
@@ -13,8 +13,6 @@ services:
      - /var/run/docker.sock:/var/run/docker.sock
      - /proc/:/host/proc/:ro
      - /sys/fs/cgroup:/host/sys/fs/cgroup:ro
      - "./dd-conf.d:/etc/datadog-agent/conf.d/local.d/"
      - "../dev/logs:/var/log/datadog/logs"
  db:
    image: rethinkdb:2.4.2
    restart: unless-stopped
@@ -72,10 +70,25 @@ services:
- "8082:8081"
networks:
parabol-network:
text-embeddings-inference:
container_name: text-embeddings-inference
image: ghcr.io/huggingface/text-embeddings-inference:cpu-0.6
command:
- "--model-id=llmrails/ember-v1"
platform: linux/x86_64
hostname: text-embeddings-inference
restart: unless-stopped
ports:
- "3040:80"
volumes:
- text-embeddings-inference-data:/data
networks:
parabol-network:
networks:
parabol-network:
volumes:
redis-data: {}
rethink-data: {}
postgres-data: {}
pgadmin-data: {}
text-embeddings-inference-data: {}
4 changes: 2 additions & 2 deletions package.json
@@ -103,8 +103,8 @@
"html-webpack-plugin": "^5.5.0",
"husky": "^7.0.4",
"jscodeshift": "^0.14.0",
"kysely": "^0.26.3",
"kysely-codegen": "^0.10.0",
"kysely": "^0.27.2",
"kysely-codegen": "^0.11.0",
"lerna": "^6.4.1",
"mini-css-extract-plugin": "^2.7.2",
"minimist": "^1.2.5",
11 changes: 11 additions & 0 deletions packages/embedder/.eslintrc.js
@@ -0,0 +1,11 @@
module.exports = {
  extends: [
    '../../.eslintrc.js'
  ],
  parserOptions: {
    project: './tsconfig.json',
    ecmaVersion: 2020,
    sourceType: 'module'
  },
  ignorePatterns: ['**/lib', '*.js']
}
71 changes: 71 additions & 0 deletions packages/embedder/README.md
@@ -0,0 +1,71 @@
# `Embedder`

This service builds embedding vectors for semantic search and for other AI/ML
use cases. It does so by:

1. Updating a list of all possible items to create embedding vectors for and
   storing that list in the `EmbeddingsMetadata` table
2. Adding these items in batches to the `EmbeddingsJobQueue` table and a redis
   priority queue called `embedder:queue`
3. Allowing one or more parallel embedding services to calculate embedding
   vectors (`EmbeddingsJobQueue` states transition from `queued` -> `embedding`,
   then `embedding` -> [deleting the `EmbeddingsJobQueue` row])

In addition to deleting the `EmbeddingsJobQueue` row, when a job completes
successfully:

- A row is added to the model table with the embedding vector; the
  `embeddingsMetadataId` field on this row points to the appropriate
  metadata row on `EmbeddingsMetadata`
- The `EmbeddingsMetadata.models` array is updated with the name of the
  table that the embedding has been generated for

4. This process repeats forever using a silly polling loop (one worker
   iteration is sketched below)

In the future, it would be wonderful to enhance this service so that it is
event-driven.
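
Until then, here is a minimal sketch of one iteration of that polling loop.
The table and queue names come from the steps above; the `state` column, the
claim query, and the use of a redis sorted set are illustrative assumptions,
not the actual implementation:

```ts
import Redis from 'ioredis'
import {Kysely} from 'kysely'

// One iteration of the polling loop (column names are hypothetical)
async function processNextJob(pg: Kysely<any>, redis: Redis) {
  // Pop the highest-priority job id from the `embedder:queue` priority queue
  const popped = await redis.zpopmin('embedder:queue')
  if (popped.length === 0) return // nothing queued; caller sleeps and retries

  const jobId = Number(popped[0])
  // Claim the job: queued -> embedding
  const job = await pg
    .updateTable('EmbeddingsJobQueue')
    .set({state: 'embedding'})
    .where('id', '=', jobId)
    .where('state', '=', 'queued')
    .returningAll()
    .executeTakeFirst()
  if (!job) return // another worker already claimed this job

  // ...calculate the vector, insert a row into the model table, and update
  // EmbeddingsMetadata.models (the steps described above)...

  // On success, the job row is simply deleted
  await pg.deleteFrom('EmbeddingsJobQueue').where('id', '=', jobId).execute()
}
```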

## Prerequisites

The Embedder service depends on pgvector being available in Postgres.

The predeploy script checks for an environment variable
`POSTGRES_USE_PGVECTOR=true` to enable this extension in production.
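
For local development you may need to enable the extension yourself. A minimal
sketch of that gate, written here with kysely's `sql` tag (the actual predeploy
script may differ):

```ts
import {Kysely, sql} from 'kysely'

// Hypothetical helper: pgvector must be installed before the vector(...)
// tables created by ModelManager can exist
async function maybeEnablePgvector(pg: Kysely<any>) {
  if (process.env.POSTGRES_USE_PGVECTOR !== 'true') return
  await sql`CREATE EXTENSION IF NOT EXISTS vector`.execute(pg)
}
```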

## Configuration

The Embedder service takes no arguments and is controlled by the following
environment variables, here given with example configuration:

- `AI_EMBEDDER_ENABLED`: enables or disables the embedder service; when
  disabled, the service sleeps indefinitely instead of performing work

  Example:

  `AI_EMBEDDER_ENABLED='true'`

- `AI_EMBEDDING_MODELS`: JSON configuration for which embedding models
  are enabled. Each model in the array will be instantiated by
  `ai_models/ModelManager`. Each model instance will have its own
  database table created for it (if it does not already exist) to
  store calculated vectors. See `ai_models/ModelManager` for
  which configurations are supported.

  Example:

  `AI_EMBEDDING_MODELS='[{"model": "text-embeddings-inference:llmrails/ember-v1", "url": "http://localhost:3040/"}]'`

- `AI_GENERATION_MODELS`: JSON configuration for which AI generation
  models (i.e. GPTs) are enabled. These models are used to summarize
  text to be embedded by an embedding model when the text length would be
  greater than the context window of the embedding model. Each model in
  the array will be instantiated by `ai_models/ModelManager`.
  See `ai_models/ModelManager` for which configurations are supported.

  Example:

  `AI_GENERATION_MODELS='[{"model": "text-generation-interface:TheBloke/zephyr-7b-beta", "url": "http://localhost:3050/"}]'`
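
In both variables, the `model` string encodes `<model-type>:<model-id>`.
`ai_models/ModelManager` (added below in this commit) splits on the first `:`
to choose an implementation; for illustration (the `modelId` extraction here
is hypothetical):

```ts
// How a configured model string is interpreted (mirrors ModelManager.ts below)
const entry = {model: 'text-embeddings-inference:llmrails/ember-v1', url: 'http://localhost:3040/'}
const modelType = entry.model.split(':')[0] // 'text-embeddings-inference'
const modelId = entry.model.slice(modelType.length + 1) // 'llmrails/ember-v1'
```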

## Usage

The Embedder service is stateless and takes no arguments. Multiple instances
of the service may be started in order to match embedding load, or to
catch up on history more quickly.
179 changes: 179 additions & 0 deletions packages/embedder/ai_models/ModelManager.ts
@@ -0,0 +1,179 @@
import {Kysely, sql} from 'kysely'

import {
  AbstractEmbeddingsModel,
  AbstractGenerationModel,
  EmbeddingModelConfig,
  GenerationModelConfig,
  ModelConfig
} from './AbstractModel'
import TextEmbeddingsInterface from './TextEmbeddingsInterface'
import TextGenerationInterface from './TextGenerationInterface'

interface ModelManagerConfig {
  embeddingModels: EmbeddingModelConfig[]
  generationModels: GenerationModelConfig[]
}

export enum EmbeddingsModelTypes {
  TextEmbeddingsInterface = 'text-embeddings-inference'
}
export function isValidEmbeddingsModelType(type: any): type is EmbeddingsModelTypes {
  return Object.values(EmbeddingsModelTypes).includes(type)
}

export enum SummarizationModelTypes {
  TextGenerationInterface = 'text-generation-interface'
}
export function isValidSummarizationModelType(type: any): type is SummarizationModelTypes {
  return Object.values(SummarizationModelTypes).includes(type)
}

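// Validates the parsed model configuration, instantiates one client per
// configured model, and manages the per-model vector tables in Postgres.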
export class ModelManager {
  private embeddingModels: AbstractEmbeddingsModel[]
  private generationModels: AbstractGenerationModel[]

  private isValidConfig(
    maybeConfig: Partial<ModelManagerConfig>
  ): maybeConfig is ModelManagerConfig {
    if (!maybeConfig.embeddingModels || !Array.isArray(maybeConfig.embeddingModels)) {
      throw new Error('Invalid configuration: embedding_models is missing or not an array')
    }
    if (!maybeConfig.generationModels || !Array.isArray(maybeConfig.generationModels)) {
      throw new Error('Invalid configuration: summarization_models is missing or not an array')
    }

    maybeConfig.embeddingModels.forEach((model: ModelConfig) => {
      this.isValidModelConfig(model)
    })

    maybeConfig.generationModels.forEach((model: ModelConfig) => {
      this.isValidModelConfig(model)
    })

    return true
  }

  private isValidModelConfig(model: ModelConfig): model is ModelConfig {
    if (typeof model.model !== 'string') {
      throw new Error('Invalid ModelConfig: model field should be a string')
    }
    if (model.url !== undefined && typeof model.url !== 'string') {
      throw new Error('Invalid ModelConfig: url field should be a string')
    }

    return true
  }

  constructor(config: ModelManagerConfig) {
    // Validate configuration
    this.isValidConfig(config)
    // Initialize embeddings models
    this.embeddingModels = []
    config.embeddingModels.forEach((modelConfig) => {
      const modelType = modelConfig.model.split(':')[0]

      if (!isValidEmbeddingsModelType(modelType))
        throw new Error(`unsupported embeddings model '${modelType}'`)

      switch (modelType) {
        case 'text-embeddings-inference': {
          const embeddingsModel = new TextEmbeddingsInterface(modelConfig)
          this.embeddingModels.push(embeddingsModel)
          break
        }
      }
    })

    // Initialize summarization models
    this.generationModels = []
    config.generationModels.forEach((modelConfig) => {
      const modelType = modelConfig.model.split(':')[0]

      if (!isValidSummarizationModelType(modelType))
        throw new Error(`unsupported summarization model '${modelType}'`)

      switch (modelType) {
        case 'text-generation-interface': {
          const generator = new TextGenerationInterface(modelConfig)
          this.generationModels.push(generator)
          break
        }
      }
    })
  }

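  // For each configured embedding model, create its vector table (sized to the
  // model's embedding dimensions) plus an HNSW cosine-distance index, if they
  // don't already exist.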
  async maybeCreateTables(pg: Kysely<any>) {
    for (const embeddingsModel of this.getEmbeddingModelsIter()) {
      const tableName = embeddingsModel.getTableName()
      const hasTable = await (async () => {
        const query = sql<number[]>`SELECT 1 FROM ${sql.id(
          'pg_catalog',
          'pg_tables'
        )} WHERE ${sql.id('tablename')} = ${tableName}`
        const result = await query.execute(pg)
        return result.rows.length > 0
      })()
      if (hasTable) continue
      const vectorDimensions = embeddingsModel.getModelParams().embeddingDimensions
      console.log(`ModelManager: creating ${tableName} with ${vectorDimensions} dimensions`)
      const query = sql`
        DO $$
        BEGIN
          CREATE TABLE IF NOT EXISTS ${sql.id(tableName)} (
            "id" SERIAL PRIMARY KEY,
            "embedText" TEXT,
            "embedding" vector(${sql.raw(vectorDimensions.toString())}),
            "embeddingsMetadataId" INTEGER NOT NULL,
            FOREIGN KEY ("embeddingsMetadataId")
              REFERENCES "EmbeddingsMetadata"("id")
              ON DELETE CASCADE
          );
          CREATE INDEX IF NOT EXISTS "idx_${sql.raw(tableName)}_embedding_vector_cosine_ops"
            ON ${sql.id(tableName)}
            USING hnsw ("embedding" vector_cosine_ops);
        END $$;
      `
      await query.execute(pg)
    }
  }

  // returns the highest priority summarizer instance
  getFirstGenerator() {
    if (!this.generationModels.length) throw new Error('no generator model initialized')
    return this.generationModels[0]
  }

  getFirstEmbedder() {
    if (!this.embeddingModels.length) throw new Error('no embedder model initialized')
    return this.embeddingModels[0]
  }

  getEmbeddingModelsIter() {
    return this.embeddingModels[Symbol.iterator]()
  }
}

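// Lazy singleton: parse the model env vars once and reuse the same manager.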
let modelManager: ModelManager | undefined
export function getModelManager() {
  if (modelManager) return modelManager
  const {AI_EMBEDDING_MODELS, AI_GENERATION_MODELS} = process.env
  const config: ModelManagerConfig = {
    embeddingModels: [],
    generationModels: []
  }
  try {
    if (AI_EMBEDDING_MODELS) config.embeddingModels = JSON.parse(AI_EMBEDDING_MODELS)
  } catch (e) {
    throw new Error(`Invalid AI_EMBEDDING_MODELS .env JSON: ${e}`)
  }
  try {
    if (AI_GENERATION_MODELS) config.generationModels = JSON.parse(AI_GENERATION_MODELS)
  } catch (e) {
    throw new Error(`Invalid AI_GENERATION_MODELS .env JSON: ${e}`)
  }

  modelManager = new ModelManager(config)

  return modelManager
}

export default getModelManager