diff --git a/.sqlx/query-ad216288cbbe83aba35b5d04705ee5964f1da4f3839c4725a6784c13f2245379.json b/.sqlx/query-ad216288cbbe83aba35b5d04705ee5964f1da4f3839c4725a6784c13f2245379.json deleted file mode 100644 index 9bc39f3e9..000000000 --- a/.sqlx/query-ad216288cbbe83aba35b5d04705ee5964f1da4f3839c4725a6784c13f2245379.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n select c.workspace_id, c.oid, c.partition_key\n from af_collab c\n join af_workspace w on c.workspace_id = w.workspace_id\n where not coalesce(w.settings['disable_search_indexding']::boolean, false)\n and not exists (\n select 1 from af_collab_embeddings em\n where em.oid = c.oid and em.partition_key = 0\n )\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "workspace_id", - "type_info": "Uuid" - }, - { - "ordinal": 1, - "name": "oid", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "partition_key", - "type_info": "Int4" - } - ], - "parameters": { - "Left": [] - }, - "nullable": [ - false, - false, - false - ] - }, - "hash": "ad216288cbbe83aba35b5d04705ee5964f1da4f3839c4725a6784c13f2245379" -} diff --git a/.sqlx/query-d0e5f5097b35a15f19e9e7faf2c62336d5f130e939331e84c7d834f6028ea673.json b/.sqlx/query-d0e5f5097b35a15f19e9e7faf2c62336d5f130e939331e84c7d834f6028ea673.json new file mode 100644 index 000000000..088448334 --- /dev/null +++ b/.sqlx/query-d0e5f5097b35a15f19e9e7faf2c62336d5f130e939331e84c7d834f6028ea673.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE af_collab\n SET indexed_at = $1\n WHERE oid = $2 AND partition_key = $3\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Timestamptz", + "Text", + "Int4" + ] + }, + "nullable": [] + }, + "hash": "d0e5f5097b35a15f19e9e7faf2c62336d5f130e939331e84c7d834f6028ea673" +} diff --git a/.sqlx/query-f68cc2042d6aa78feeb33640e9ef13f46c5e10ee269ea0bd965b0e57dee6cf94.json b/.sqlx/query-f68cc2042d6aa78feeb33640e9ef13f46c5e10ee269ea0bd965b0e57dee6cf94.json new file mode 100644 index 000000000..ce90f37ba --- /dev/null +++ b/.sqlx/query-f68cc2042d6aa78feeb33640e9ef13f46c5e10ee269ea0bd965b0e57dee6cf94.json @@ -0,0 +1,35 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT c.workspace_id, c.oid, c.partition_key\n FROM af_collab c\n JOIN af_workspace w ON c.workspace_id = w.workspace_id\n WHERE c.workspace_id = $1\n AND NOT COALESCE(w.settings['disable_search_indexing']::boolean, false)\n AND c.indexed_at IS NULL\n ORDER BY c.updated_at DESC\n LIMIT $2\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "workspace_id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "oid", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "partition_key", + "type_info": "Int4" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Int8" + ] + }, + "nullable": [ + false, + false, + false + ] + }, + "hash": "f68cc2042d6aa78feeb33640e9ef13f46c5e10ee269ea0bd965b0e57dee6cf94" +} diff --git a/.sqlx/query-f8c909517885cb30e3f7d573edf47138f90ea9c5fa73eb927cc5487c3d9ad0be.json b/.sqlx/query-f8c909517885cb30e3f7d573edf47138f90ea9c5fa73eb927cc5487c3d9ad0be.json new file mode 100644 index 000000000..35e948e0e --- /dev/null +++ b/.sqlx/query-f8c909517885cb30e3f7d573edf47138f90ea9c5fa73eb927cc5487c3d9ad0be.json @@ -0,0 +1,29 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT oid, indexed_at\n FROM af_collab\n WHERE (oid, partition_key) = ANY (\n SELECT UNNEST($1::text[]), UNNEST($2::int[])\n )\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "oid", + "type_info": "Text" + }, + { + 
"ordinal": 1, + "name": "indexed_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "TextArray", + "Int4Array" + ] + }, + "nullable": [ + false, + true + ] + }, + "hash": "f8c909517885cb30e3f7d573edf47138f90ea9c5fa73eb927cc5487c3d9ad0be" +} diff --git a/Cargo.lock b/Cargo.lock index cee803e31..2a6ad2168 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -649,6 +649,7 @@ dependencies = [ "hex", "http 0.2.12", "image", + "indexer", "infra", "itertools 0.11.0", "lazy_static", @@ -738,6 +739,7 @@ dependencies = [ "futures", "futures-util", "governor", + "indexer", "indexmap 2.3.0", "itertools 0.12.1", "lazy_static", @@ -777,6 +779,8 @@ name = "appflowy-worker" version = "0.1.0" dependencies = [ "anyhow", + "app-error", + "appflowy-collaborate", "async_zip", "aws-config", "aws-sdk-s3", @@ -792,11 +796,13 @@ dependencies = [ "database-entity", "dotenvy", "futures", + "indexer", "infra", "mailer", "md5", "mime_guess", "prometheus-client", + "rayon", "redis 0.25.4", "reqwest", "secrecy", @@ -2856,9 +2862,9 @@ dependencies = [ [[package]] name = "dashmap" -version = "6.0.1" +version = "6.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "804c8821570c3f8b70230c2ba75ffa5c0f9a4189b9a432b6656c536712acae28" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" dependencies = [ "cfg-if", "crossbeam-utils", @@ -4202,6 +4208,42 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "206ca75c9c03ba3d4ace2460e57b189f39f43de612c2f85836e65c929701bb2d" +[[package]] +name = "indexer" +version = "0.1.0" +dependencies = [ + "anyhow", + "app-error", + "appflowy-ai-client", + "async-trait", + "bytes", + "chrono", + "collab", + "collab-document", + "collab-entity", + "collab-folder", + "collab-stream", + "dashmap 6.1.0", + "database", + "database-entity", + "futures-util", + "infra", + "prometheus-client", + "rayon", + "redis 0.25.4", + "serde", + "serde_json", + "sqlx", + "thiserror 1.0.63", + "tiktoken-rs", + "tokio", + "tokio-util", + "tracing", + "unicode-segmentation", + "ureq", + "uuid", +] + [[package]] name = "indexmap" version = "1.9.3" @@ -7310,11 +7352,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.8" +version = "2.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f5383f3e0071702bf93ab5ee99b52d26936be9dedd9413067cbdcddcb6141a" +checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc" dependencies = [ - "thiserror-impl 2.0.8", + "thiserror-impl 2.0.9", ] [[package]] @@ -7330,9 +7372,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.8" +version = "2.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2f357fcec90b3caef6623a099691be676d033b40a058ac95d2a6ade6fa0c943" +checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4" dependencies = [ "proc-macro2", "quote", @@ -7899,7 +7941,7 @@ dependencies = [ "native-tls", "rand 0.8.5", "sha1", - "thiserror 2.0.8", + "thiserror 2.0.9", "utf-8", ] @@ -7959,9 +8001,9 @@ checksum = "e4259d9d4425d9f0661581b804cb85fe66a4c631cadd8f490d1c13a35d5d9291" [[package]] name = "unicode-segmentation" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "unicode-width" @@ -8696,7 +8738,7 @@ 
dependencies = [ "arc-swap", "async-lock", "async-trait", - "dashmap 6.0.1", + "dashmap 6.1.0", "fastrand", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 2f47d8131..69941a014 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -156,6 +156,7 @@ base64.workspace = true md5.workspace = true nanoid = "0.4.0" http.workspace = true +indexer.workspace = true [dev-dependencies] flate2 = "1.0" @@ -177,7 +178,6 @@ collab-rt-entity = { path = "libs/collab-rt-entity" } hex = "0.4.3" unicode-normalization = "0.1.24" - [[bin]] name = "appflowy_cloud" path = "src/main.rs" @@ -221,9 +221,11 @@ members = [ "xtask", "libs/tonic-proto", "libs/mailer", + "libs/indexer", ] [workspace.dependencies] +indexer = { path = "libs/indexer" } collab-rt-entity = { path = "libs/collab-rt-entity" } collab-rt-protocol = { path = "libs/collab-rt-protocol" } database = { path = "libs/database" } diff --git a/deploy.env b/deploy.env index e273beca3..fe4667740 100644 --- a/deploy.env +++ b/deploy.env @@ -162,6 +162,7 @@ APPFLOWY_LOCAL_AI_TEST_ENABLED=false APPFLOWY_INDEXER_ENABLED=true APPFLOWY_INDEXER_DATABASE_URL=postgres://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB} APPFLOWY_INDEXER_REDIS_URL=redis://${REDIS_HOST}:${REDIS_PORT} +APPFLOWY_INDEXER_EMBEDDING_BUFFER_SIZE=5000 # AppFlowy Collaborate APPFLOWY_COLLABORATE_MULTI_THREAD=false diff --git a/dev.env b/dev.env index 4c060b9a2..34737c958 100644 --- a/dev.env +++ b/dev.env @@ -124,6 +124,7 @@ APPFLOWY_LOCAL_AI_TEST_ENABLED=false APPFLOWY_INDEXER_ENABLED=true APPFLOWY_INDEXER_DATABASE_URL=postgres://postgres:password@postgres:5432/postgres APPFLOWY_INDEXER_REDIS_URL=redis://redis:6379 +APPFLOWY_INDEXER_EMBEDDING_BUFFER_SIZE=5000 # AppFlowy Collaborate APPFLOWY_COLLABORATE_MULTI_THREAD=false diff --git a/libs/app-error/src/lib.rs b/libs/app-error/src/lib.rs index 0ba39eefc..d0c2613e9 100644 --- a/libs/app-error/src/lib.rs +++ b/libs/app-error/src/lib.rs @@ -182,6 +182,9 @@ pub enum AppError { #[error("Decode update error: {0}")] DecodeUpdateError(String), + #[error("{0}")] + ActionTimeout(String), + #[error("Apply update error:{0}")] ApplyUpdateError(String), } @@ -263,6 +266,7 @@ impl AppError { AppError::ServiceTemporaryUnavailable(_) => ErrorCode::ServiceTemporaryUnavailable, AppError::DecodeUpdateError(_) => ErrorCode::DecodeUpdateError, AppError::ApplyUpdateError(_) => ErrorCode::ApplyUpdateError, + AppError::ActionTimeout(_) => ErrorCode::ActionTimeout, } } } @@ -316,6 +320,7 @@ impl From for AppError { sqlx::Error::RowNotFound => { AppError::RecordNotFound(format!("Record not exist in db. 
{})", msg)) }, + sqlx::Error::PoolTimedOut => AppError::ActionTimeout(value.to_string()), _ => AppError::SqlxError(msg), } } @@ -424,6 +429,7 @@ pub enum ErrorCode { ServiceTemporaryUnavailable = 1054, DecodeUpdateError = 1055, ApplyUpdateError = 1056, + ActionTimeout = 1057, } impl ErrorCode { diff --git a/libs/collab-rt-protocol/src/data_validation.rs b/libs/collab-rt-protocol/src/data_validation.rs index b84f6fd60..365fc724d 100644 --- a/libs/collab-rt-protocol/src/data_validation.rs +++ b/libs/collab-rt-protocol/src/data_validation.rs @@ -6,14 +6,8 @@ use collab::preclude::Collab; use collab_entity::CollabType; use tracing::instrument; -#[instrument(level = "trace", skip(data), fields(len = %data.len()))] #[inline] -pub async fn spawn_blocking_validate_encode_collab( - object_id: &str, - data: &[u8], - collab_type: &CollabType, -) -> Result<(), Error> { - let collab_type = collab_type.clone(); +pub async fn collab_from_encode_collab(object_id: &str, data: &[u8]) -> Result { let object_id = object_id.to_string(); let data = data.to_vec(); @@ -27,28 +21,19 @@ pub async fn spawn_blocking_validate_encode_collab( false, )?; - collab_type.validate_require_data(&collab)?; - Ok::<(), Error>(()) + Ok::<_, Error>(collab) }) .await? } #[instrument(level = "trace", skip(data), fields(len = %data.len()))] #[inline] -pub fn validate_encode_collab( +pub async fn validate_encode_collab( object_id: &str, data: &[u8], collab_type: &CollabType, ) -> Result<(), Error> { - let encoded_collab = EncodedCollab::decode_from_bytes(data)?; - let collab = Collab::new_with_source( - CollabOrigin::Empty, - object_id, - DataSource::DocStateV1(encoded_collab.doc_state.to_vec()), - vec![], - false, - )?; - + let collab = collab_from_encode_collab(object_id, data).await?; collab_type.validate_require_data(&collab)?; Ok::<(), Error>(()) } diff --git a/libs/collab-stream/tests/collab_stream_test/stream_group_test.rs b/libs/collab-stream/tests/collab_stream_test/stream_group_test.rs index 1d2b0d23d..18030ee07 100644 --- a/libs/collab-stream/tests/collab_stream_test/stream_group_test.rs +++ b/libs/collab-stream/tests/collab_stream_test/stream_group_test.rs @@ -76,7 +76,7 @@ async fn single_group_async_read_message_test() { } #[tokio::test] -async fn different_group_read_message_test() { +async fn different_group_read_undelivered_message_test() { let oid = format!("o{}", random_i64()); let client = stream_client().await; let mut group_1 = client @@ -110,6 +110,40 @@ async fn different_group_read_message_test() { assert_eq!(group_2_messages[0].data, vec![1, 2, 3, 4, 5]); } +#[tokio::test] +async fn different_group_read_message_test() { + let oid = format!("o{}", random_i64()); + let client = stream_client().await; + let mut group_1 = client.collab_update_stream("w1", &oid, "g1").await.unwrap(); + let mut group_2 = client.collab_update_stream("w1", &oid, "g2").await.unwrap(); + + let msg = StreamBinary(vec![1, 2, 3, 4, 5]); + + { + let client = stream_client().await; + let mut group = client.collab_update_stream("w1", &oid, "g2").await.unwrap(); + group.insert_binary(msg).await.unwrap(); + } + let msg = group_1 + .consumer_messages("consumer1", ReadOption::Count(1)) + .await + .unwrap(); + group_1.ack_messages(&msg).await.unwrap(); + + let (result1, result2) = join( + group_1.consumer_messages("consumer1", ReadOption::Count(1)), + group_2.consumer_messages("consumer1", ReadOption::Count(1)), + ) + .await; + + let group_1_messages = result1.unwrap(); + let group_2_messages = result2.unwrap(); + + // consumer1 already acked the 
message before. so it should not be available + assert!(group_1_messages.is_empty()); + assert_eq!(group_2_messages[0].data, vec![1, 2, 3, 4, 5]); +} + #[tokio::test] async fn read_specific_num_of_message_test() { let object_id = format!("o{}", random_i64()); diff --git a/libs/database-entity/src/dto.rs b/libs/database-entity/src/dto.rs index 7065d24ca..bb5e2db1b 100644 --- a/libs/database-entity/src/dto.rs +++ b/libs/database-entity/src/dto.rs @@ -742,7 +742,6 @@ impl From for AFWorkspaceInvitationStatus { pub struct AFCollabEmbeddedChunk { pub fragment_id: String, pub object_id: String, - pub collab_type: CollabType, pub content_type: EmbeddingContentType, pub content: String, pub embedding: Option>, diff --git a/libs/database/src/index/collab_embeddings_ops.rs b/libs/database/src/index/collab_embeddings_ops.rs index c1db97390..e97b6d6ba 100644 --- a/libs/database/src/index/collab_embeddings_ops.rs +++ b/libs/database/src/index/collab_embeddings_ops.rs @@ -1,14 +1,17 @@ +use crate::collab::partition_key_from_collab_type; +use chrono::{DateTime, Utc}; use collab_entity::CollabType; +use database_entity::dto::{AFCollabEmbeddedChunk, IndexingStatus, QueryCollab, QueryCollabParams}; use futures_util::stream::BoxStream; use futures_util::StreamExt; use pgvector::Vector; +use sqlx::pool::PoolConnection; use sqlx::postgres::{PgHasArrayType, PgTypeInfo}; -use sqlx::{Error, Executor, PgPool, Postgres, Transaction}; +use sqlx::{Error, Executor, Postgres, Transaction}; +use std::collections::HashMap; use std::ops::DerefMut; use uuid::Uuid; -use database_entity::dto::{AFCollabEmbeddedChunk, IndexingStatus, QueryCollab, QueryCollabParams}; - pub async fn get_index_status<'a, E>( tx: E, workspace_id: &Uuid, @@ -89,17 +92,17 @@ impl PgHasArrayType for Fragment { pub async fn upsert_collab_embeddings( transaction: &mut Transaction<'_, Postgres>, workspace_id: &Uuid, + object_id: &str, + collab_type: CollabType, tokens_used: u32, records: Vec, ) -> Result<(), sqlx::Error> { - if records.is_empty() { - return Ok(()); - } - - let object_id = records[0].object_id.clone(); - let collab_type = records[0].collab_type.clone(); let fragments = records.into_iter().map(Fragment::from).collect::>(); - + tracing::trace!( + "[Embedding] upsert {} {} fragments", + object_id, + fragments.len() + ); sqlx::query(r#"CALL af_collab_embeddings_upsert($1, $2, $3, $4, $5::af_fragment_v2[])"#) .bind(*workspace_id) .bind(object_id) @@ -111,21 +114,26 @@ pub async fn upsert_collab_embeddings( Ok(()) } -pub fn get_collabs_without_embeddings(pg_pool: &PgPool) -> BoxStream> { - // atm. 
get only documents +pub async fn stream_collabs_without_embeddings( + conn: &mut PoolConnection, + workspace_id: Uuid, + limit: i64, +) -> BoxStream> { sqlx::query!( r#" - select c.workspace_id, c.oid, c.partition_key - from af_collab c - join af_workspace w on c.workspace_id = w.workspace_id - where not coalesce(w.settings['disable_search_indexding']::boolean, false) - and not exists ( - select 1 from af_collab_embeddings em - where em.oid = c.oid and em.partition_key = 0 - ) - "# + SELECT c.workspace_id, c.oid, c.partition_key + FROM af_collab c + JOIN af_workspace w ON c.workspace_id = w.workspace_id + WHERE c.workspace_id = $1 + AND NOT COALESCE(w.settings['disable_search_indexing']::boolean, false) + AND c.indexed_at IS NULL + ORDER BY c.updated_at DESC + LIMIT $2 + "#, + workspace_id, + limit ) - .fetch(pg_pool) + .fetch(conn.deref_mut()) .map(|row| { row.map(|r| CollabId { collab_type: CollabType::from(r.partition_key), @@ -136,6 +144,71 @@ pub fn get_collabs_without_embeddings(pg_pool: &PgPool) -> BoxStream( + tx: E, + object_id: &str, + collab_type: &CollabType, + indexed_at: DateTime, +) -> Result<(), Error> +where + E: Executor<'a, Database = Postgres>, +{ + let partition_key = partition_key_from_collab_type(collab_type); + sqlx::query!( + r#" + UPDATE af_collab + SET indexed_at = $1 + WHERE oid = $2 AND partition_key = $3 + "#, + indexed_at, + object_id, + partition_key + ) + .execute(tx) + .await?; + + Ok(()) +} + +pub async fn get_collabs_indexed_at<'a, E>( + executor: E, + collab_ids: Vec<(String, CollabType)>, +) -> Result>, Error> +where + E: Executor<'a, Database = Postgres>, +{ + let (oids, partition_keys): (Vec, Vec) = collab_ids + .into_iter() + .map(|(object_id, collab_type)| (object_id, partition_key_from_collab_type(&collab_type))) + .unzip(); + + let result = sqlx::query!( + r#" + SELECT oid, indexed_at + FROM af_collab + WHERE (oid, partition_key) = ANY ( + SELECT UNNEST($1::text[]), UNNEST($2::int[]) + ) + "#, + &oids, + &partition_keys + ) + .fetch_all(executor) + .await?; + + let map = result + .into_iter() + .filter_map(|r| { + if let Some(indexed_at) = r.indexed_at { + Some((r.oid, indexed_at)) + } else { + None + } + }) + .collect::>>(); + Ok(map) +} + #[derive(Debug, Clone)] pub struct CollabId { pub collab_type: CollabType, diff --git a/libs/indexer/Cargo.toml b/libs/indexer/Cargo.toml new file mode 100644 index 000000000..2602a9f44 --- /dev/null +++ b/libs/indexer/Cargo.toml @@ -0,0 +1,40 @@ +[package] +name = "indexer" +version = "0.1.0" +edition = "2021" + +[dependencies] +rayon.workspace = true +tiktoken-rs = "0.6.0" +app-error = { workspace = true } +appflowy-ai-client = { workspace = true, features = ["client-api"] } +unicode-segmentation = "1.12.0" +collab = { workspace = true } +collab-entity = { workspace = true } +collab-folder = { workspace = true } +collab-document = { workspace = true } +collab-stream = { workspace = true } +database-entity.workspace = true +database.workspace = true +futures-util.workspace = true +sqlx.workspace = true +tokio.workspace = true +tracing.workspace = true +thiserror = "1.0.56" +uuid.workspace = true +async-trait.workspace = true +serde_json.workspace = true +anyhow.workspace = true +infra.workspace = true +prometheus-client = "0.22.3" +bytes.workspace = true +dashmap = "6.1.0" +chrono = "0.4.39" +ureq = "2.12.1" +serde.workspace = true +redis = { workspace = true, features = [ + "aio", + "tokio-comp", + "connection-manager", +] } +tokio-util = "0.7.12" \ No newline at end of file diff --git 
a/services/appflowy-collaborate/src/indexer/document_indexer.rs b/libs/indexer/src/collab_indexer/document_indexer.rs similarity index 85% rename from services/appflowy-collaborate/src/indexer/document_indexer.rs rename to libs/indexer/src/collab_indexer/document_indexer.rs index fec2d845f..b0409864b 100644 --- a/services/appflowy-collaborate/src/indexer/document_indexer.rs +++ b/libs/indexer/src/collab_indexer/document_indexer.rs @@ -1,6 +1,6 @@ -use crate::indexer::open_ai::split_text_by_max_content_len; -use crate::indexer::vector::embedder::Embedder; -use crate::indexer::Indexer; +use crate::collab_indexer::Indexer; +use crate::vector::embedder::Embedder; +use crate::vector::open_ai::split_text_by_max_content_len; use anyhow::anyhow; use app_error::AppError; use appflowy_ai_client::dto::{ @@ -20,7 +20,7 @@ pub struct DocumentIndexer; #[async_trait] impl Indexer for DocumentIndexer { - fn create_embedded_chunks( + fn create_embedded_chunks_from_collab( &self, collab: &Collab, embedding_model: EmbeddingModel, @@ -35,9 +35,7 @@ impl Indexer for DocumentIndexer { let result = document.to_plain_text(collab.transact(), false, true); match result { - Ok(content) => { - split_text_into_chunks(object_id, content, CollabType::Document, &embedding_model) - }, + Ok(content) => self.create_embedded_chunks_from_text(object_id, content, embedding_model), Err(err) => { if matches!(err, DocumentError::NoRequiredData) { Ok(vec![]) @@ -48,6 +46,15 @@ impl Indexer for DocumentIndexer { } } + fn create_embedded_chunks_from_text( + &self, + object_id: String, + text: String, + model: EmbeddingModel, + ) -> Result, AppError> { + split_text_into_chunks(object_id, text, CollabType::Document, &model) + } + fn embed( &self, embedder: &Embedder, @@ -104,6 +111,10 @@ fn split_text_into_chunks( embedding_model, EmbeddingModel::TextEmbedding3Small )); + + if content.is_empty() { + return Ok(vec![]); + } // We assume that every token is ~4 bytes. We're going to split document content into fragments // of ~2000 tokens each. 
let split_contents = split_text_by_max_content_len(content, 8000)?; @@ -115,7 +126,6 @@ fn split_text_into_chunks( .map(|content| AFCollabEmbeddedChunk { fragment_id: Uuid::new_v4().to_string(), object_id: object_id.clone(), - collab_type: collab_type.clone(), content_type: EmbeddingContentType::PlainText, content, embedding: None, diff --git a/libs/indexer/src/collab_indexer/mod.rs b/libs/indexer/src/collab_indexer/mod.rs new file mode 100644 index 000000000..4810b352d --- /dev/null +++ b/libs/indexer/src/collab_indexer/mod.rs @@ -0,0 +1,5 @@ +mod document_indexer; +mod provider; + +pub use document_indexer::*; +pub use provider::*; diff --git a/services/appflowy-collaborate/src/indexer/provider.rs b/libs/indexer/src/collab_indexer/provider.rs similarity index 83% rename from services/appflowy-collaborate/src/indexer/provider.rs rename to libs/indexer/src/collab_indexer/provider.rs index ed70a710d..6968a234a 100644 --- a/services/appflowy-collaborate/src/indexer/provider.rs +++ b/libs/indexer/src/collab_indexer/provider.rs @@ -1,22 +1,29 @@ -use crate::config::get_env_var; -use crate::indexer::vector::embedder::Embedder; -use crate::indexer::DocumentIndexer; +use crate::collab_indexer::DocumentIndexer; +use crate::vector::embedder::Embedder; use app_error::AppError; use appflowy_ai_client::dto::EmbeddingModel; use collab::preclude::Collab; use collab_entity::CollabType; use database_entity::dto::{AFCollabEmbeddedChunk, AFCollabEmbeddings}; +use infra::env_util::get_env_var; use std::collections::HashMap; use std::sync::Arc; use tracing::info; pub trait Indexer: Send + Sync { - fn create_embedded_chunks( + fn create_embedded_chunks_from_collab( &self, collab: &Collab, model: EmbeddingModel, ) -> Result, AppError>; + fn create_embedded_chunks_from_text( + &self, + object_id: String, + text: String, + model: EmbeddingModel, + ) -> Result, AppError>; + fn embed( &self, embedder: &Embedder, diff --git a/libs/indexer/src/entity.rs b/libs/indexer/src/entity.rs new file mode 100644 index 000000000..f20ce2ce9 --- /dev/null +++ b/libs/indexer/src/entity.rs @@ -0,0 +1,31 @@ +use collab::entity::EncodedCollab; +use collab_entity::CollabType; +use database_entity::dto::AFCollabEmbeddedChunk; +use uuid::Uuid; + +pub struct UnindexedCollab { + pub workspace_id: Uuid, + pub object_id: String, + pub collab_type: CollabType, + pub collab: EncodedCollab, +} + +pub struct EmbeddingRecord { + pub workspace_id: Uuid, + pub object_id: String, + pub collab_type: CollabType, + pub tokens_used: u32, + pub contents: Vec, +} + +impl EmbeddingRecord { + pub fn empty(workspace_id: Uuid, object_id: String, collab_type: CollabType) -> Self { + Self { + workspace_id, + object_id, + collab_type, + tokens_used: 0, + contents: vec![], + } + } +} diff --git a/libs/indexer/src/error.rs b/libs/indexer/src/error.rs new file mode 100644 index 000000000..ea5a77be8 --- /dev/null +++ b/libs/indexer/src/error.rs @@ -0,0 +1,8 @@ +#[derive(thiserror::Error, Debug)] +pub enum IndexerError { + #[error("Redis stream group not exist: {0}")] + StreamGroupNotExist(String), + + #[error(transparent)] + Internal(#[from] anyhow::Error), +} diff --git a/libs/indexer/src/lib.rs b/libs/indexer/src/lib.rs new file mode 100644 index 000000000..72c4147e7 --- /dev/null +++ b/libs/indexer/src/lib.rs @@ -0,0 +1,9 @@ +pub mod collab_indexer; +pub mod entity; +pub mod error; +pub mod metrics; +pub mod queue; +pub mod scheduler; +pub mod thread_pool; +mod unindexed_workspace; +pub mod vector; diff --git 
a/services/appflowy-collaborate/src/indexer/metrics.rs b/libs/indexer/src/metrics.rs similarity index 64% rename from services/appflowy-collaborate/src/indexer/metrics.rs rename to libs/indexer/src/metrics.rs index 6825bdb8d..d75b76025 100644 --- a/services/appflowy-collaborate/src/indexer/metrics.rs +++ b/libs/indexer/src/metrics.rs @@ -4,8 +4,9 @@ use prometheus_client::registry::Registry; pub struct EmbeddingMetrics { total_embed_count: Counter, failed_embed_count: Counter, - processing_time_histogram: Histogram, write_embedding_time_histogram: Histogram, + gen_embeddings_time_histogram: Histogram, + fallback_background_tasks: Counter, } impl EmbeddingMetrics { @@ -13,8 +14,9 @@ impl EmbeddingMetrics { Self { total_embed_count: Counter::default(), failed_embed_count: Counter::default(), - processing_time_histogram: Histogram::new([500.0, 1000.0, 5000.0, 8000.0].into_iter()), write_embedding_time_histogram: Histogram::new([500.0, 1000.0, 5000.0, 8000.0].into_iter()), + gen_embeddings_time_histogram: Histogram::new([1000.0, 3000.0, 5000.0, 8000.0].into_iter()), + fallback_background_tasks: Counter::default(), } } @@ -33,17 +35,24 @@ impl EmbeddingMetrics { "Total count of failed embeddings", metrics.failed_embed_count.clone(), ); - realtime_registry.register( - "processing_time_seconds", - "Histogram of embedding processing times", - metrics.processing_time_histogram.clone(), - ); realtime_registry.register( "write_embedding_time_seconds", "Histogram of embedding write times", metrics.write_embedding_time_histogram.clone(), ); + realtime_registry.register( + "gen_embeddings_time_histogram", + "Histogram of embedding generation times", + metrics.gen_embeddings_time_histogram.clone(), + ); + + realtime_registry.register( + "fallback_background_tasks", + "Total count of fallback background tasks", + metrics.fallback_background_tasks.clone(), + ); + metrics } @@ -55,13 +64,16 @@ impl EmbeddingMetrics { self.failed_embed_count.inc_by(count); } - pub fn record_generate_embedding_time(&self, millis: u128) { - tracing::trace!("[Embedding]: generate embeddings cost: {}ms", millis); - self.processing_time_histogram.observe(millis as f64); + pub fn record_fallback_background_tasks(&self, count: u64) { + self.fallback_background_tasks.inc_by(count); } pub fn record_write_embedding_time(&self, millis: u128) { - tracing::trace!("[Embedding]: write embedding time cost: {}ms", millis); self.write_embedding_time_histogram.observe(millis as f64); } + + pub fn record_gen_embedding_time(&self, num: u32, millis: u128) { + tracing::info!("[Embedding]: index {} collabs cost: {}ms", num, millis); + self.gen_embeddings_time_histogram.observe(millis as f64); + } } diff --git a/libs/indexer/src/queue.rs b/libs/indexer/src/queue.rs new file mode 100644 index 000000000..5595a03a0 --- /dev/null +++ b/libs/indexer/src/queue.rs @@ -0,0 +1,172 @@ +use crate::error::IndexerError; +use crate::scheduler::UnindexedCollabTask; +use anyhow::anyhow; +use app_error::AppError; +use redis::aio::ConnectionManager; +use redis::streams::{StreamId, StreamReadOptions, StreamReadReply}; +use redis::{AsyncCommands, RedisResult, Value}; +use serde_json::from_str; +use tracing::error; + +pub const INDEX_TASK_STREAM_NAME: &str = "index_collab_task_stream"; +const INDEXER_WORKER_GROUP_NAME: &str = "indexer_worker_group"; +const INDEXER_CONSUMER_NAME: &str = "appflowy_worker"; + +impl TryFrom<&StreamId> for UnindexedCollabTask { + type Error = IndexerError; + + fn try_from(stream_id: &StreamId) -> Result { + let task_str = match 
stream_id.map.get("task") { + Some(value) => match value { + Value::Data(data) => String::from_utf8_lossy(data).to_string(), + _ => { + error!("Unexpected value type for task field: {:?}", value); + return Err(IndexerError::Internal(anyhow!( + "Unexpected value type for task field: {:?}", + value + ))); + }, + }, + None => { + error!("Task field not found in Redis stream entry"); + return Err(IndexerError::Internal(anyhow!( + "Task field not found in Redis stream entry" + ))); + }, + }; + + from_str::(&task_str).map_err(|err| IndexerError::Internal(err.into())) + } +} + +/// Adds a list of tasks to the Redis stream. +/// +/// This function pushes a batch of `EmbedderTask` items into the Redis stream for processing. +/// The tasks are serialized into JSON format before being added to the stream. +/// +pub async fn add_background_embed_task( + redis_client: ConnectionManager, + tasks: Vec, +) -> Result<(), AppError> { + let items = tasks + .into_iter() + .flat_map(|task| { + let task = serde_json::to_string(&task).ok()?; + Some(("task", task)) + }) + .collect::>(); + + let _: () = redis_client + .clone() + .xadd(INDEX_TASK_STREAM_NAME, "*", &items) + .await + .map_err(|err| { + AppError::Internal(anyhow!( + "Failed to push embedder task to Redis stream: {}", + err + )) + })?; + Ok(()) +} + +/// Reads tasks from the Redis stream for processing by a consumer group. +pub async fn read_background_embed_tasks( + redis_client: &mut ConnectionManager, + options: &StreamReadOptions, +) -> Result { + let tasks: StreamReadReply = match redis_client + .xread_options(&[INDEX_TASK_STREAM_NAME], &[">"], options) + .await + { + Ok(tasks) => tasks, + Err(err) => { + error!("Failed to read tasks from Redis stream: {:?}", err); + if let Some(code) = err.code() { + if code == "NOGROUP" { + return Err(IndexerError::StreamGroupNotExist( + INDEXER_WORKER_GROUP_NAME.to_string(), + )); + } + } + return Err(IndexerError::Internal(err.into())); + }, + }; + Ok(tasks) +} + +/// Acknowledges a task in a Redis stream and optionally removes it from the stream. +/// +/// It is used to acknowledge the processing of a task in a Redis stream +/// within a specific consumer group. Once a task is acknowledged, it is removed from +/// the **Pending Entries List (PEL)** for the consumer group. If the `delete_task` +/// flag is set to `true`, the task will also be removed from the Redis stream entirely. +/// +/// # Parameters: +/// - `redis_client`: A mutable reference to the Redis `ConnectionManager`, used to +/// interact with the Redis server. +/// - `stream_entity_id`: The unique identifier (ID) of the task in the stream. +/// - `delete_task`: A boolean flag that indicates whether the task should be removed +/// from the stream after it is acknowledged. If `true`, the task is deleted from the stream. +/// If `false`, the task remains in the stream after acknowledgment. 
+pub async fn ack_task( + redis_client: &mut ConnectionManager, + stream_entity_ids: Vec, + delete_task: bool, +) -> Result<(), IndexerError> { + let _: () = redis_client + .xack( + INDEX_TASK_STREAM_NAME, + INDEXER_WORKER_GROUP_NAME, + &stream_entity_ids, + ) + .await + .map_err(|err| { + error!("Failed to ack task: {:?}", err); + IndexerError::Internal(err.into()) + })?; + + if delete_task { + let _: () = redis_client + .xdel(INDEX_TASK_STREAM_NAME, &stream_entity_ids) + .await + .map_err(|err| { + error!("Failed to delete task: {:?}", err); + IndexerError::Internal(err.into()) + })?; + } + + Ok(()) +} + +pub fn default_indexer_group_option(limit: usize) -> StreamReadOptions { + StreamReadOptions::default() + .group(INDEXER_WORKER_GROUP_NAME, INDEXER_CONSUMER_NAME) + .count(limit) +} + +/// Ensure the consumer group exists, if not, create it. +pub async fn ensure_indexer_consumer_group( + redis_client: &mut ConnectionManager, +) -> Result<(), IndexerError> { + let result: RedisResult<()> = redis_client + .xgroup_create_mkstream(INDEX_TASK_STREAM_NAME, INDEXER_WORKER_GROUP_NAME, "0") + .await; + + if let Err(redis_error) = result { + if let Some(code) = redis_error.code() { + if code == "BUSYGROUP" { + return Ok(()); + } + + if code == "NOGROUP" { + return Err(IndexerError::StreamGroupNotExist( + INDEXER_WORKER_GROUP_NAME.to_string(), + )); + } + } + error!("Error when creating consumer group: {:?}", redis_error); + return Err(IndexerError::Internal(redis_error.into())); + } + + Ok(()) +} diff --git a/libs/indexer/src/scheduler.rs b/libs/indexer/src/scheduler.rs new file mode 100644 index 000000000..384f3fd41 --- /dev/null +++ b/libs/indexer/src/scheduler.rs @@ -0,0 +1,549 @@ +use crate::collab_indexer::{Indexer, IndexerProvider}; +use crate::entity::EmbeddingRecord; +use crate::error::IndexerError; +use crate::metrics::EmbeddingMetrics; +use crate::queue::add_background_embed_task; +use crate::thread_pool::{ThreadPoolNoAbort, ThreadPoolNoAbortBuilder}; +use crate::vector::embedder::Embedder; +use crate::vector::open_ai; +use app_error::AppError; +use appflowy_ai_client::dto::{EmbeddingRequest, OpenAIEmbeddingResponse}; +use collab::preclude::Collab; +use collab_document::document::DocumentBody; +use collab_entity::CollabType; +use database::collab::CollabStorage; +use database::index::{update_collab_indexed_at, upsert_collab_embeddings}; +use database::workspace::select_workspace_settings; +use database_entity::dto::AFCollabEmbeddedChunk; +use infra::env_util::get_env_var; +use rayon::prelude::*; +use redis::aio::ConnectionManager; +use serde::{Deserialize, Serialize}; +use sqlx::PgPool; +use std::cmp::max; +use std::collections::HashSet; +use std::ops::DerefMut; +use std::sync::{Arc, Weak}; +use std::time::{Duration, Instant}; +use tokio::sync::mpsc; +use tokio::sync::mpsc::error::TrySendError; +use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; +use tokio::sync::RwLock as TokioRwLock; +use tokio::time::timeout; +use tracing::{debug, error, info, instrument, trace, warn}; +use uuid::Uuid; + +pub struct IndexerScheduler { + pub(crate) indexer_provider: Arc, + pub(crate) pg_pool: PgPool, + pub(crate) storage: Arc, + pub(crate) threads: Arc, + #[allow(dead_code)] + pub(crate) metrics: Arc, + write_embedding_tx: UnboundedSender, + gen_embedding_tx: mpsc::Sender, + config: IndexerConfiguration, + redis_client: ConnectionManager, +} + +#[derive(Debug)] +pub struct IndexerConfiguration { + pub enable: bool, + pub openai_api_key: String, + /// High watermark for the 
number of embeddings that can be buffered before being written to the database. + pub embedding_buffer_size: usize, +} + +impl IndexerScheduler { + pub fn new( + indexer_provider: Arc, + pg_pool: PgPool, + storage: Arc, + metrics: Arc, + config: IndexerConfiguration, + redis_client: ConnectionManager, + ) -> Arc { + // Since threads often block while waiting for I/O, you can use more threads than CPU cores to improve concurrency. + // A good rule of thumb is 2x to 10x the number of CPU cores + let num_thread = max( + get_env_var("APPFLOWY_INDEXER_SCHEDULER_NUM_THREAD", "50") + .parse::() + .unwrap_or(50), + 5, + ); + + info!("Indexer scheduler config: {:?}", config); + let (write_embedding_tx, write_embedding_rx) = unbounded_channel::(); + let (gen_embedding_tx, gen_embedding_rx) = + mpsc::channel::(config.embedding_buffer_size); + let threads = Arc::new( + ThreadPoolNoAbortBuilder::new() + .num_threads(num_thread) + .thread_name(|index| format!("create-embedding-thread-{index}")) + .build() + .unwrap(), + ); + + let this = Arc::new(Self { + indexer_provider, + pg_pool, + storage, + threads, + metrics, + write_embedding_tx, + gen_embedding_tx, + config, + redis_client, + }); + + info!( + "Indexer scheduler is enabled: {}, num threads: {}", + this.index_enabled(), + num_thread + ); + + let latest_write_embedding_err = Arc::new(TokioRwLock::new(None)); + if this.index_enabled() { + tokio::spawn(spawn_rayon_generate_embeddings( + gen_embedding_rx, + Arc::downgrade(&this), + num_thread, + latest_write_embedding_err.clone(), + )); + + tokio::spawn(spawn_pg_write_embeddings( + write_embedding_rx, + this.pg_pool.clone(), + this.metrics.clone(), + latest_write_embedding_err, + )); + } + + this + } + + fn index_enabled(&self) -> bool { + // if indexing is disabled, return false + if !self.config.enable { + return false; + } + + // if openai api key is empty, return false + if self.config.openai_api_key.is_empty() { + return false; + } + + true + } + + pub fn is_indexing_enabled(&self, collab_type: &CollabType) -> bool { + self.indexer_provider.is_indexing_enabled(collab_type) + } + + pub(crate) fn create_embedder(&self) -> Result { + if self.config.openai_api_key.is_empty() { + return Err(AppError::AIServiceUnavailable( + "OpenAI API key is empty".to_string(), + )); + } + + Ok(Embedder::OpenAI(open_ai::Embedder::new( + self.config.openai_api_key.clone(), + ))) + } + + pub fn create_search_embeddings( + &self, + request: EmbeddingRequest, + ) -> Result { + let embedder = self.create_embedder()?; + let embeddings = embedder.embed(request)?; + Ok(embeddings) + } + + pub fn embed_in_background( + &self, + pending_collabs: Vec, + ) -> Result<(), AppError> { + if !self.index_enabled() { + return Ok(()); + } + + let redis_client = self.redis_client.clone(); + tokio::spawn(add_background_embed_task(redis_client, pending_collabs)); + Ok(()) + } + + pub fn embed_immediately(&self, pending_collab: UnindexedCollabTask) -> Result<(), AppError> { + if !self.index_enabled() { + return Ok(()); + } + if let Err(err) = self.gen_embedding_tx.try_send(pending_collab) { + match err { + TrySendError::Full(pending) => { + warn!("[Embedding] Embedding queue is full, embedding in background"); + self.embed_in_background(vec![pending])?; + self.metrics.record_failed_embed_count(1); + }, + TrySendError::Closed(_) => { + error!("Failed to send embedding record: channel closed"); + }, + } + } + + Ok(()) + } + + pub fn index_pending_collab_one( + &self, + pending_collab: UnindexedCollabTask, + background: bool, + ) -> 
Result<(), AppError> { + if !self.index_enabled() { + return Ok(()); + } + + let indexer = self + .indexer_provider + .indexer_for(&pending_collab.collab_type); + if indexer.is_none() { + return Ok(()); + } + + if background { + let _ = self.embed_in_background(vec![pending_collab]); + } else { + let _ = self.embed_immediately(pending_collab); + } + Ok(()) + } + + /// Index all pending collabs in the background + pub fn index_pending_collabs( + &self, + mut pending_collabs: Vec, + ) -> Result<(), AppError> { + if !self.index_enabled() { + return Ok(()); + } + + pending_collabs.retain(|collab| self.is_indexing_enabled(&collab.collab_type)); + if pending_collabs.is_empty() { + return Ok(()); + } + + info!("indexing {} collabs in background", pending_collabs.len()); + let _ = self.embed_in_background(pending_collabs); + + Ok(()) + } + + pub async fn index_collab_immediately( + &self, + workspace_id: &str, + object_id: &str, + collab: &Collab, + collab_type: &CollabType, + ) -> Result<(), AppError> { + if !self.index_enabled() { + return Ok(()); + } + + if !self.is_indexing_enabled(collab_type) { + return Ok(()); + } + + match collab_type { + CollabType::Document => { + let txn = collab.transact(); + let text = DocumentBody::from_collab(&collab) + .and_then(|body| body.to_plain_text(txn, false, true).ok()); + + if let Some(text) = text { + if !text.is_empty() { + let pending = UnindexedCollabTask::new( + Uuid::parse_str(workspace_id)?, + object_id.to_string(), + collab_type.clone(), + UnindexedData::Text(text), + ); + self.embed_immediately(pending)?; + } + } + }, + _ => { + // TODO(nathan): support other collab types + }, + } + + Ok(()) + } + + pub async fn can_index_workspace(&self, workspace_id: &str) -> Result { + if !self.index_enabled() { + return Ok(false); + } + + let uuid = Uuid::parse_str(workspace_id)?; + let settings = select_workspace_settings(&self.pg_pool, &uuid).await?; + match settings { + None => Ok(true), + Some(settings) => Ok(!settings.disable_search_indexing), + } + } +} + +async fn spawn_rayon_generate_embeddings( + mut rx: mpsc::Receiver, + scheduler: Weak, + buffer_size: usize, + latest_write_embedding_err: Arc>>, +) { + let mut buf = Vec::with_capacity(buffer_size); + loop { + let latest_error = latest_write_embedding_err.write().await.take(); + if let Some(err) = latest_error { + if matches!(err, AppError::ActionTimeout(_)) { + info!( + "[Embedding] last write embedding task failed with timeout, waiting for 30s before retrying..." 
+ ); + tokio::time::sleep(Duration::from_secs(30)).await; + } + } + + let n = rx.recv_many(&mut buf, buffer_size).await; + let scheduler = match scheduler.upgrade() { + Some(scheduler) => scheduler, + None => { + error!("[Embedding] Failed to upgrade scheduler"); + break; + }, + }; + + if n == 0 { + info!("[Embedding] Stop generating embeddings"); + break; + } + + let start = Instant::now(); + let records = buf.drain(..n).collect::>(); + trace!( + "[Embedding] received {} embeddings to generate", + records.len() + ); + let metrics = scheduler.metrics.clone(); + let threads = scheduler.threads.clone(); + let indexer_provider = scheduler.indexer_provider.clone(); + let write_embedding_tx = scheduler.write_embedding_tx.clone(); + let embedder = scheduler.create_embedder(); + let result = tokio::task::spawn_blocking(move || { + match embedder { + Ok(embedder) => { + records.into_par_iter().for_each(|record| { + let result = threads.install(|| { + let indexer = indexer_provider.indexer_for(&record.collab_type); + match process_collab(&embedder, indexer, &record.object_id, record.data, &metrics) { + Ok(Some((tokens_used, contents))) => { + if let Err(err) = write_embedding_tx.send(EmbeddingRecord { + workspace_id: record.workspace_id, + object_id: record.object_id, + collab_type: record.collab_type, + tokens_used, + contents, + }) { + error!("Failed to send embedding record: {}", err); + } + }, + Ok(None) => { + debug!("No embedding for collab:{}", record.object_id); + }, + Err(err) => { + warn!( + "Failed to create embeddings content for collab:{}, error:{}", + record.object_id, err + ); + }, + } + }); + + if let Err(err) = result { + error!("Failed to install a task to rayon thread pool: {}", err); + } + }); + }, + Err(err) => error!("[Embedding] Failed to create embedder: {}", err), + } + Ok::<_, IndexerError>(()) + }) + .await; + + match result { + Ok(Ok(_)) => { + scheduler + .metrics + .record_gen_embedding_time(n as u32, start.elapsed().as_millis()); + trace!("Successfully generated embeddings"); + }, + Ok(Err(err)) => error!("Failed to generate embeddings: {}", err), + Err(err) => error!("Failed to spawn a task to generate embeddings: {}", err), + } + } +} + +const EMBEDDING_RECORD_BUFFER_SIZE: usize = 10; +pub async fn spawn_pg_write_embeddings( + mut rx: UnboundedReceiver, + pg_pool: PgPool, + metrics: Arc, + latest_write_embedding_error: Arc>>, +) { + let mut buf = Vec::with_capacity(EMBEDDING_RECORD_BUFFER_SIZE); + loop { + let n = rx.recv_many(&mut buf, EMBEDDING_RECORD_BUFFER_SIZE).await; + if n == 0 { + info!("Stop writing embeddings"); + break; + } + + trace!("[Embedding] received {} embeddings to write", n); + let start = Instant::now(); + let records = buf.drain(..n).collect::>(); + for record in records.iter() { + info!( + "[Embedding] generate collab:{} embeddings, tokens used: {}", + record.object_id, record.tokens_used + ); + } + + let result = timeout( + Duration::from_secs(20), + batch_insert_records(&pg_pool, records), + ) + .await + .unwrap_or_else(|_| { + Err(AppError::ActionTimeout( + "timeout when writing embeddings".to_string(), + )) + }); + + match result { + Ok(_) => { + trace!("[Embedding] save {} embeddings to disk", n); + metrics.record_write_embedding_time(start.elapsed().as_millis()); + }, + Err(err) => { + error!("Failed to write collab embedding to disk:{}", err); + latest_write_embedding_error.write().await.replace(err); + }, + } + } +} + +#[instrument(level = "trace", skip_all)] +pub(crate) async fn batch_insert_records( + pg_pool: &PgPool, + records: 
Vec, +) -> Result<(), AppError> { + let mut seen = HashSet::new(); + let records = records + .into_iter() + .filter(|record| seen.insert(record.object_id.clone())) + .collect::>(); + + let mut txn = pg_pool.begin().await?; + for record in records { + update_collab_indexed_at( + txn.deref_mut(), + &record.object_id, + &record.collab_type, + chrono::Utc::now(), + ) + .await?; + + upsert_collab_embeddings( + &mut txn, + &record.workspace_id, + &record.object_id, + record.collab_type, + record.tokens_used, + record.contents, + ) + .await?; + } + txn.commit().await.map_err(|e| { + error!("[Embedding] Failed to commit transaction: {:?}", e); + e + })?; + + Ok(()) +} + +/// This function must be called within the rayon thread pool. +fn process_collab( + embedder: &Embedder, + indexer: Option>, + object_id: &str, + data: UnindexedData, + metrics: &EmbeddingMetrics, +) -> Result)>, AppError> { + if let Some(indexer) = indexer { + metrics.record_embed_count(1); + + let chunks = match data { + UnindexedData::Text(text) => { + indexer.create_embedded_chunks_from_text(object_id.to_string(), text, embedder.model())? + }, + }; + + let result = indexer.embed(embedder, chunks); + match result { + Ok(Some(embeddings)) => Ok(Some((embeddings.tokens_consumed, embeddings.params))), + Ok(None) => Ok(None), + Err(err) => { + metrics.record_failed_embed_count(1); + Err(err) + }, + } + } else { + Ok(None) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct UnindexedCollabTask { + pub workspace_id: Uuid, + pub object_id: String, + pub collab_type: CollabType, + pub data: UnindexedData, + pub created_at: i64, +} + +impl UnindexedCollabTask { + pub fn new( + workspace_id: Uuid, + object_id: String, + collab_type: CollabType, + data: UnindexedData, + ) -> Self { + Self { + workspace_id, + object_id, + collab_type, + data, + created_at: chrono::Utc::now().timestamp(), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum UnindexedData { + Text(String), +} + +impl UnindexedData { + pub fn is_empty(&self) -> bool { + match self { + UnindexedData::Text(text) => text.is_empty(), + } + } +} diff --git a/libs/indexer/src/thread_pool.rs b/libs/indexer/src/thread_pool.rs new file mode 100644 index 000000000..401fdc173 --- /dev/null +++ b/libs/indexer/src/thread_pool.rs @@ -0,0 +1,167 @@ +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +use rayon::{ThreadPool, ThreadPoolBuilder}; +use thiserror::Error; + +/// A thread pool that does not abort on panics. +/// +/// This custom thread pool wraps Rayon’s `ThreadPool` and ensures that the thread pool +/// can recover from panics gracefully. It detects any panics in worker threads and +/// prevents the entire application from aborting. +#[derive(Debug)] +pub struct ThreadPoolNoAbort { + /// Internal Rayon thread pool. + thread_pool: ThreadPool, + /// Atomic flag to detect if a panic occurred in the thread pool. + catched_panic: Arc, +} + +impl ThreadPoolNoAbort { + /// Executes a closure within the thread pool. + /// + /// This method runs the provided closure (`op`) inside the thread pool. If a panic + /// occurs during the execution, it is detected and returned as an error. + /// + /// # Arguments + /// * `op` - A closure that will be executed within the thread pool. + /// + /// # Returns + /// * `Ok(R)` - The result of the closure if execution was successful. + /// * `Err(PanicCatched)` - An error indicating that a panic occurred during execution. 
+ /// + pub fn install(&self, op: OP) -> Result + where + OP: FnOnce() -> R + Send, + R: Send, + { + let output = self.thread_pool.install(op); + // Reset the panic flag and return an error if a panic was detected. + if self.catched_panic.swap(false, Ordering::SeqCst) { + Err(CatchedPanic) + } else { + Ok(output) + } + } + + /// Returns the current number of threads in the thread pool. + /// + /// # Returns + /// The number of threads being used by the thread pool. + pub fn current_num_threads(&self) -> usize { + self.thread_pool.current_num_threads() + } +} + +/// Error indicating that a panic occurred during thread pool execution. +/// +/// This error is returned when a closure executed in the thread pool panics. +#[derive(Error, Debug)] +#[error("A panic occurred happened in the thread pool. Check the logs for more information")] +pub struct CatchedPanic; + +/// A builder for creating a `ThreadPoolNoAbort` instance. +/// +/// This builder wraps Rayon’s `ThreadPoolBuilder` and customizes the panic handling behavior. +#[derive(Default)] +pub struct ThreadPoolNoAbortBuilder(ThreadPoolBuilder); + +impl ThreadPoolNoAbortBuilder { + pub fn new() -> ThreadPoolNoAbortBuilder { + ThreadPoolNoAbortBuilder::default() + } + + /// Sets a custom naming function for threads in the pool. + /// + /// # Arguments + /// * `closure` - A function that takes a thread index and returns a thread name. + /// + pub fn thread_name(mut self, closure: F) -> Self + where + F: FnMut(usize) -> String + 'static, + { + self.0 = self.0.thread_name(closure); + self + } + + /// Sets the number of threads for the thread pool. + /// + /// # Arguments + /// * `num_threads` - The number of threads to create in the thread pool. + pub fn num_threads(mut self, num_threads: usize) -> ThreadPoolNoAbortBuilder { + self.0 = self.0.num_threads(num_threads); + self + } + + /// Builds the `ThreadPoolNoAbort` instance. + /// + /// This method creates a `ThreadPoolNoAbort` with the specified configurations, + /// including custom panic handling behavior. + /// + /// # Returns + /// * `Ok(ThreadPoolNoAbort)` - The constructed thread pool. + /// * `Err(ThreadPoolBuildError)` - If the thread pool failed to build. + /// + pub fn build(mut self) -> Result { + let catched_panic = Arc::new(AtomicBool::new(false)); + self.0 = self.0.panic_handler({ + let catched_panic = catched_panic.clone(); + move |_result| catched_panic.store(true, Ordering::SeqCst) + }); + Ok(ThreadPoolNoAbort { + thread_pool: self.0.build()?, + catched_panic, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + + #[test] + fn test_install_closure_success() { + // Create a thread pool with 4 threads. + let pool = ThreadPoolNoAbortBuilder::new() + .num_threads(4) + .build() + .expect("Failed to build thread pool"); + + // Run a closure that executes successfully. + let result = pool.install(|| 42); + + // Ensure the result is correct. + assert!(result.is_ok()); + assert_eq!(result.unwrap(), 42); + } + + #[test] + fn test_multiple_threads_execution() { + // Create a thread pool with multiple threads. + let pool = ThreadPoolNoAbortBuilder::new() + .num_threads(8) + .build() + .expect("Failed to build thread pool"); + + // Shared atomic counter to verify parallel execution. 
+ let counter = Arc::new(AtomicUsize::new(0)); + + let handles: Vec<_> = (0..100) + .map(|_| { + let counter_clone = counter.clone(); + pool.install(move || { + counter_clone.fetch_add(1, Ordering::SeqCst); + }) + }) + .collect(); + + // Ensure all tasks completed successfully. + for handle in handles { + assert!(handle.is_ok()); + } + + // Verify that the counter equals the number of tasks executed. + assert_eq!(counter.load(Ordering::SeqCst), 100); + } +} diff --git a/libs/indexer/src/unindexed_workspace.rs b/libs/indexer/src/unindexed_workspace.rs new file mode 100644 index 000000000..c2d0752a1 --- /dev/null +++ b/libs/indexer/src/unindexed_workspace.rs @@ -0,0 +1,223 @@ +use crate::collab_indexer::IndexerProvider; +use crate::entity::{EmbeddingRecord, UnindexedCollab}; +use crate::scheduler::{batch_insert_records, IndexerScheduler}; +use crate::thread_pool::ThreadPoolNoAbort; +use crate::vector::embedder::Embedder; +use collab::core::collab::DataSource; +use collab::core::origin::CollabOrigin; +use collab::preclude::Collab; +use collab_entity::CollabType; +use database::collab::{CollabStorage, GetCollabOrigin}; +use database::index::stream_collabs_without_embeddings; +use futures_util::stream::BoxStream; +use futures_util::StreamExt; +use rayon::iter::ParallelIterator; +use rayon::prelude::IntoParallelIterator; +use sqlx::pool::PoolConnection; +use sqlx::Postgres; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tracing::{error, info, trace}; +use uuid::Uuid; + +#[allow(dead_code)] +pub(crate) async fn index_workspace(scheduler: Arc, workspace_id: Uuid) { + let weak_threads = Arc::downgrade(&scheduler.threads); + let mut retry_delay = Duration::from_secs(2); + loop { + let threads = match weak_threads.upgrade() { + Some(threads) => threads, + None => { + info!("[Embedding] thread pool is dropped, stop indexing"); + break; + }, + }; + + let conn = scheduler.pg_pool.try_acquire(); + if conn.is_none() { + tokio::time::sleep(retry_delay).await; + // 4s, 8s, 16s, 32s, 60s + retry_delay = retry_delay.saturating_mul(2); + if retry_delay > Duration::from_secs(60) { + error!("[Embedding] failed to acquire db connection for 1 minute, stop indexing"); + break; + } + continue; + } + + retry_delay = Duration::from_secs(2); + let mut conn = conn.unwrap(); + let mut stream = + stream_unindexed_collabs(&mut conn, workspace_id, scheduler.storage.clone(), 50).await; + + let batch_size = 5; + let mut unindexed_collabs = Vec::with_capacity(batch_size); + while let Some(Ok(collab)) = stream.next().await { + if unindexed_collabs.len() < batch_size { + unindexed_collabs.push(collab); + continue; + } + + index_then_write_embedding_to_disk( + &scheduler, + threads.clone(), + std::mem::take(&mut unindexed_collabs), + ) + .await; + } + + if !unindexed_collabs.is_empty() { + index_then_write_embedding_to_disk(&scheduler, threads.clone(), unindexed_collabs).await; + } + } +} + +async fn index_then_write_embedding_to_disk( + scheduler: &Arc, + threads: Arc, + unindexed_collabs: Vec, +) { + info!( + "[Embedding] process batch {:?} embeddings", + unindexed_collabs + .iter() + .map(|v| v.object_id.clone()) + .collect::>() + ); + + if let Ok(embedder) = scheduler.create_embedder() { + let start = Instant::now(); + let embeddings = create_embeddings( + embedder, + &scheduler.indexer_provider, + threads.clone(), + unindexed_collabs, + ) + .await; + scheduler + .metrics + .record_gen_embedding_time(embeddings.len() as u32, start.elapsed().as_millis()); + + let write_start = Instant::now(); + let n = 
embeddings.len(); + match batch_insert_records(&scheduler.pg_pool, embeddings).await { + Ok(_) => trace!( + "[Embedding] upsert {} embeddings success, cost:{}ms", + n, + write_start.elapsed().as_millis() + ), + Err(err) => error!("{}", err), + } + + scheduler + .metrics + .record_write_embedding_time(write_start.elapsed().as_millis()); + tokio::time::sleep(Duration::from_secs(5)).await; + } else { + trace!("[Embedding] no embeddings to process in this batch"); + } +} + +async fn stream_unindexed_collabs( + conn: &mut PoolConnection, + workspace_id: Uuid, + storage: Arc, + limit: i64, +) -> BoxStream> { + let cloned_storage = storage.clone(); + stream_collabs_without_embeddings(conn, workspace_id, limit) + .await + .map(move |result| { + let storage = cloned_storage.clone(); + async move { + match result { + Ok(cid) => match cid.collab_type { + CollabType::Document => { + let collab = storage + .get_encode_collab(GetCollabOrigin::Server, cid.clone().into(), false) + .await?; + + Ok(Some(UnindexedCollab { + workspace_id: cid.workspace_id, + object_id: cid.object_id, + collab_type: cid.collab_type, + collab, + })) + }, + // TODO(nathan): support other collab types + _ => Ok::<_, anyhow::Error>(None), + }, + Err(e) => Err(e.into()), + } + } + }) + .filter_map(|future| async { + match future.await { + Ok(Some(unindexed_collab)) => Some(Ok(unindexed_collab)), + Ok(None) => None, + Err(e) => Some(Err(e)), + } + }) + .boxed() +} + +async fn create_embeddings( + embedder: Embedder, + indexer_provider: &Arc, + threads: Arc, + unindexed_records: Vec, +) -> Vec { + unindexed_records + .into_par_iter() + .flat_map(|unindexed| { + let indexer = indexer_provider.indexer_for(&unindexed.collab_type)?; + let collab = Collab::new_with_source( + CollabOrigin::Empty, + &unindexed.object_id, + DataSource::DocStateV1(unindexed.collab.doc_state.into()), + vec![], + false, + ) + .ok()?; + + let chunks = indexer + .create_embedded_chunks_from_collab(&collab, embedder.model()) + .ok()?; + if chunks.is_empty() { + trace!("[Embedding] {} has no embeddings", unindexed.object_id,); + return Some(EmbeddingRecord::empty( + unindexed.workspace_id, + unindexed.object_id, + unindexed.collab_type, + )); + } + + let result = threads.install(|| match indexer.embed(&embedder, chunks) { + Ok(embeddings) => embeddings.map(|embeddings| EmbeddingRecord { + workspace_id: unindexed.workspace_id, + object_id: unindexed.object_id, + collab_type: unindexed.collab_type, + tokens_used: embeddings.tokens_consumed, + contents: embeddings.params, + }), + Err(err) => { + error!("Failed to embed collab: {}", err); + None + }, + }); + + if let Ok(Some(record)) = &result { + trace!( + "[Embedding] generate collab:{} embeddings, tokens used: {}", + record.object_id, + record.tokens_used + ); + } + + result.unwrap_or_else(|err| { + error!("Failed to spawn a task to index collab: {}", err); + None + }) + }) + .collect::>() +} diff --git a/services/appflowy-collaborate/src/indexer/vector/embedder.rs b/libs/indexer/src/vector/embedder.rs similarity index 92% rename from services/appflowy-collaborate/src/indexer/vector/embedder.rs rename to libs/indexer/src/vector/embedder.rs index 96550e1eb..64e42f385 100644 --- a/services/appflowy-collaborate/src/indexer/vector/embedder.rs +++ b/libs/indexer/src/vector/embedder.rs @@ -1,4 +1,4 @@ -use crate::indexer::vector::open_ai; +use crate::vector::open_ai; use app_error::AppError; use appflowy_ai_client::dto::{EmbeddingModel, EmbeddingRequest, OpenAIEmbeddingResponse}; diff --git 
a/services/appflowy-collaborate/src/indexer/vector/mod.rs b/libs/indexer/src/vector/mod.rs similarity index 100% rename from services/appflowy-collaborate/src/indexer/vector/mod.rs rename to libs/indexer/src/vector/mod.rs diff --git a/services/appflowy-collaborate/src/indexer/open_ai.rs b/libs/indexer/src/vector/open_ai.rs similarity index 86% rename from services/appflowy-collaborate/src/indexer/open_ai.rs rename to libs/indexer/src/vector/open_ai.rs index ca1c6289e..c04ab8ba1 100644 --- a/services/appflowy-collaborate/src/indexer/open_ai.rs +++ b/libs/indexer/src/vector/open_ai.rs @@ -1,7 +1,74 @@ +use crate::vector::rest::check_response; +use anyhow::anyhow; use app_error::AppError; +use appflowy_ai_client::dto::{EmbeddingRequest, OpenAIEmbeddingResponse}; +use serde::de::DeserializeOwned; +use std::time::Duration; use tiktoken_rs::CoreBPE; use unicode_segmentation::UnicodeSegmentation; +pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; + +pub const REQUEST_PARALLELISM: usize = 40; + +#[derive(Debug, Clone)] +pub struct Embedder { + bearer: String, + client: ureq::Agent, +} + +impl Embedder { + pub fn new(api_key: String) -> Self { + let bearer = format!("Bearer {api_key}"); + let client = ureq::AgentBuilder::new() + .max_idle_connections(REQUEST_PARALLELISM * 2) + .max_idle_connections_per_host(REQUEST_PARALLELISM * 2) + .build(); + + Self { bearer, client } + } + + pub fn embed(&self, params: EmbeddingRequest) -> Result { + for attempt in 0..3 { + let request = self + .client + .post(OPENAI_EMBEDDINGS_URL) + .set("Authorization", &self.bearer) + .set("Content-Type", "application/json"); + + let result = check_response(request.send_json(¶ms)); + let retry_duration = match result { + Ok(response) => { + let data = from_response::(response)?; + return Ok(data); + }, + Err(retry) => retry.into_duration(attempt), + } + .map_err(|err| AppError::Internal(err.into()))?; + let retry_duration = retry_duration.min(Duration::from_secs(10)); + std::thread::sleep(retry_duration); + } + + Err(AppError::Internal(anyhow!( + "Failed to generate embeddings after 3 attempts" + ))) + } +} + +pub fn from_response(resp: ureq::Response) -> Result +where + T: DeserializeOwned, +{ + let status_code = resp.status(); + if status_code != 200 { + let body = resp.into_string()?; + anyhow::bail!("error code: {}, {}", status_code, body) + } + + let resp = resp.into_json()?; + Ok(resp) +} + /// ## Execution Time Comparison Results /// /// The following results were observed when running `execution_time_comparison_tests`: @@ -128,7 +195,7 @@ pub fn split_text_by_max_content_len( #[cfg(test)] mod tests { - use crate::indexer::open_ai::{split_text_by_max_content_len, split_text_by_max_tokens}; + use crate::vector::open_ai::{split_text_by_max_content_len, split_text_by_max_tokens}; use tiktoken_rs::cl100k_base; #[test] diff --git a/services/appflowy-collaborate/src/indexer/vector/rest.rs b/libs/indexer/src/vector/rest.rs similarity index 99% rename from services/appflowy-collaborate/src/indexer/vector/rest.rs rename to libs/indexer/src/vector/rest.rs index d4182ec55..310f4f84d 100644 --- a/services/appflowy-collaborate/src/indexer/vector/rest.rs +++ b/libs/indexer/src/vector/rest.rs @@ -1,4 +1,4 @@ -use crate::thread_pool_no_abort::CatchedPanic; +use crate::thread_pool::CatchedPanic; #[derive(Debug, thiserror::Error)] #[error("{fault}: {kind}")] diff --git a/migrations/20241222152427_collab_add_indexed_at.sql b/migrations/20241222152427_collab_add_indexed_at.sql new file mode 100644 index 
000000000..6fc634e8d --- /dev/null +++ b/migrations/20241222152427_collab_add_indexed_at.sql @@ -0,0 +1,3 @@ +-- Add migration script here +ALTER TABLE af_collab +ADD COLUMN indexed_at TIMESTAMP WITH TIME ZONE DEFAULT NULL; \ No newline at end of file diff --git a/services/appflowy-collaborate/Cargo.toml b/services/appflowy-collaborate/Cargo.toml index cfcec86cf..11661d828 100644 --- a/services/appflowy-collaborate/Cargo.toml +++ b/services/appflowy-collaborate/Cargo.toml @@ -96,6 +96,7 @@ aws-sdk-s3 = { version = "1.36.0", features = [ "rt-tokio", ] } zstd.workspace = true +indexer.workspace = true [dev-dependencies] rand = "0.8.5" diff --git a/services/appflowy-collaborate/src/application.rs b/services/appflowy-collaborate/src/application.rs index 8f31663e0..8643600b0 100644 --- a/services/appflowy-collaborate/src/application.rs +++ b/services/appflowy-collaborate/src/application.rs @@ -31,11 +31,12 @@ use crate::collab::cache::CollabCache; use crate::collab::storage::CollabStorageImpl; use crate::command::{CLCommandReceiver, CLCommandSender}; use crate::config::{get_env_var, Config, DatabaseSetting, S3Setting}; -use crate::indexer::{IndexerConfiguration, IndexerProvider, IndexerScheduler}; use crate::pg_listener::PgListeners; use crate::snapshot::SnapshotControl; use crate::state::{AppMetrics, AppState, UserCache}; use crate::CollaborationServer; +use indexer::collab_indexer::IndexerProvider; +use indexer::scheduler::{IndexerConfiguration, IndexerScheduler}; pub struct Application { actix_server: Server, @@ -154,10 +155,13 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result() .unwrap_or(true), openai_api_key: get_env_var("APPFLOWY_AI_OPENAI_API_KEY", ""), + embedding_buffer_size: get_env_var("APPFLOWY_INDEXER_EMBEDDING_BUFFER_SIZE", "2000") + .parse::() + .unwrap_or(2000), }; let indexer_scheduler = IndexerScheduler::new( IndexerProvider::new(), @@ -165,6 +169,7 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result Result<(), AppError> { - spawn_blocking_validate_encode_collab( - &self.object_id, - &self.encoded_collab_v1, - &self.collab_type, - ) - .await - .map_err(|err| AppError::NoRequiredData(err.to_string())) + validate_encode_collab(&self.object_id, &self.encoded_collab_v1, &self.collab_type) + .await + .map_err(|err| AppError::NoRequiredData(err.to_string())) } } diff --git a/services/appflowy-collaborate/src/group/group_init.rs b/services/appflowy-collaborate/src/group/group_init.rs index a3d23039a..617ef20ec 100644 --- a/services/appflowy-collaborate/src/group/group_init.rs +++ b/services/appflowy-collaborate/src/group/group_init.rs @@ -1,5 +1,4 @@ use crate::error::RealtimeError; -use crate::indexer::IndexedCollab; use anyhow::anyhow; use app_error::AppError; use arc_swap::ArcSwap; @@ -18,9 +17,9 @@ use collab_rt_protocol::{Message, MessageReader, RTProtocolError, SyncMessage}; use collab_stream::client::CollabRedisStream; use collab_stream::collab_update_sink::{AwarenessUpdateSink, CollabUpdateSink}; -use crate::indexer::IndexerScheduler; use crate::metrics::CollabRealtimeMetrics; use bytes::Bytes; +use collab_document::document::DocumentBody; use collab_stream::error::StreamError; use collab_stream::model::{AwarenessStreamUpdate, CollabStreamUpdate, MessageId, UpdateFlags}; use dashmap::DashMap; @@ -28,12 +27,14 @@ use database::collab::{CollabStorage, GetCollabOrigin}; use database_entity::dto::{CollabParams, QueryCollabParams}; use futures::{pin_mut, Sink, Stream}; use futures_util::{SinkExt, StreamExt}; +use 
indexer::scheduler::{IndexerScheduler, UnindexedCollabTask, UnindexedData}; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant, SystemTime}; use tokio::time::MissedTickBehavior; use tokio_util::sync::CancellationToken; use tracing::{error, info, trace}; +use uuid::Uuid; use yrs::updates::decoder::{Decode, DecoderV1}; use yrs::updates::encoder::{Encode, Encoder, EncoderV1}; use yrs::{ReadTxn, StateVector, Update}; @@ -366,7 +367,7 @@ impl CollabGroup { .state .persister .indexer_scheduler - .index_collab(workspace_id, object_id, &collab, collab_type) + .index_collab_immediately(&workspace_id, &object_id, &collab, &collab_type) .await } @@ -1142,6 +1143,20 @@ impl CollabPersister { let light_len = doc_state_light.len(); self.write_collab(doc_state_light).await?; + match self.collab_type { + CollabType::Document => { + let txn = collab.transact(); + if let Some(text) = DocumentBody::from_collab(collab) + .and_then(|body| body.to_plain_text(txn, false, true).ok()) + { + self.index_collab_content(text); + } + }, + _ => { + // TODO(nathan): support other collab type + }, + } + tracing::debug!( "persisted collab {} snapshot at {}: {} bytes", self.object_id, @@ -1176,36 +1191,34 @@ impl CollabPersister { .metrics .collab_size .observe(encoded_collab.len() as f64); - let params = CollabParams::new( - &self.object_id, - self.collab_type.clone(), - encoded_collab.clone(), - ); + let params = CollabParams::new(&self.object_id, self.collab_type.clone(), encoded_collab); self .storage .queue_insert_or_update_collab(&self.workspace_id, &self.uid, params, true) .await .map_err(|err| RealtimeError::Internal(err.into()))?; - self.index_encoded_collab(encoded_collab); Ok(()) } - fn index_encoded_collab(&self, encoded_collab: Bytes) { - let indexed_collab = IndexedCollab { - object_id: self.object_id.clone(), - collab_type: self.collab_type.clone(), - encoded_collab, - }; - if let Err(err) = self - .indexer_scheduler - .index_encoded_collab_one(&self.workspace_id, indexed_collab) - { - tracing::warn!( - "failed to index collab `{}/{}`: {}", - self.workspace_id, - self.object_id, - err + fn index_collab_content(&self, text: String) { + if let Ok(workspace_id) = Uuid::parse_str(&self.workspace_id) { + let indexed_collab = UnindexedCollabTask::new( + workspace_id, + self.object_id.clone(), + self.collab_type.clone(), + UnindexedData::Text(text), ); + if let Err(err) = self + .indexer_scheduler + .index_pending_collab_one(indexed_collab, false) + { + tracing::warn!( + "failed to index collab `{}/{}`: {}", + self.workspace_id, + self.object_id, + err + ); + } } } diff --git a/services/appflowy-collaborate/src/group/manager.rs b/services/appflowy-collaborate/src/group/manager.rs index 9aa46840b..485caa0b0 100644 --- a/services/appflowy-collaborate/src/group/manager.rs +++ b/services/appflowy-collaborate/src/group/manager.rs @@ -20,8 +20,8 @@ use crate::client::client_msg_router::ClientMessageRouter; use crate::error::RealtimeError; use crate::group::group_init::CollabGroup; use crate::group::state::GroupManagementState; -use crate::indexer::IndexerScheduler; use crate::metrics::CollabRealtimeMetrics; +use indexer::scheduler::IndexerScheduler; pub struct GroupManager { state: GroupManagementState, diff --git a/services/appflowy-collaborate/src/group/persistence.rs b/services/appflowy-collaborate/src/group/persistence.rs new file mode 100644 index 000000000..2bce0308c --- /dev/null +++ b/services/appflowy-collaborate/src/group/persistence.rs @@ -0,0 +1,212 @@ +use 
std::sync::Arc;
+use std::time::Duration;
+
+use crate::group::group_init::EditState;
+use anyhow::anyhow;
+use app_error::AppError;
+use collab::lock::RwLock;
+use collab::preclude::Collab;
+use collab_document::document::DocumentBody;
+use collab_entity::{validate_data_for_folder, CollabType};
+use database::collab::CollabStorage;
+use database_entity::dto::CollabParams;
+use indexer::scheduler::{IndexerScheduler, UnindexedCollabTask, UnindexedData};
+use tokio::time::interval;
+use tokio_util::sync::CancellationToken;
+use tracing::{trace, warn};
+use uuid::Uuid;
+
+pub(crate) struct GroupPersistence<S> {
+  workspace_id: String,
+  object_id: String,
+  storage: Arc<S>,
+  uid: i64,
+  edit_state: Arc<EditState>,
+  /// Use Arc<RwLock<Collab>> instead of Weak<RwLock<Collab>> to make sure the collab is not dropped
+  /// while saving collab data to disk
+  collab: Arc<RwLock<Collab>>,
+  collab_type: CollabType,
+  persistence_interval: Duration,
+  indexer_scheduler: Arc<IndexerScheduler>,
+  cancel: CancellationToken,
+}
+
+impl<S> GroupPersistence<S>
+where
+  S: CollabStorage,
+{
+  #[allow(clippy::too_many_arguments)]
+  pub fn new(
+    workspace_id: String,
+    object_id: String,
+    uid: i64,
+    storage: Arc<S>,
+    edit_state: Arc<EditState>,
+    collab: Arc<RwLock<Collab>>,
+    collab_type: CollabType,
+    persistence_interval: Duration,
+    indexer_scheduler: Arc<IndexerScheduler>,
+    cancel: CancellationToken,
+  ) -> Self {
+    Self {
+      workspace_id,
+      object_id,
+      uid,
+      storage,
+      edit_state,
+      collab,
+      collab_type,
+      persistence_interval,
+      indexer_scheduler,
+      cancel,
+    }
+  }
+
+  pub async fn run(self) {
+    let mut interval = interval(self.persistence_interval);
+    loop {
+      // Wait 30 seconds before each save attempt; we don't want to save immediately after the collab is created.
+      tokio::time::sleep(Duration::from_secs(30)).await;
+      tokio::select! {
+        _ = interval.tick() => {
+          if self.attempt_save().await.is_err() {
+            break;
+          }
+        },
+        _ = self.cancel.cancelled() => {
+          self.force_save().await;
+          break;
+        }
+      }
+    }
+  }
+
+  async fn force_save(&self) {
+    if self.edit_state.is_new_create() && self.save(true).await.is_ok() {
+      self.edit_state.set_is_new_create(false);
+      return;
+    }
+
+    if !self.edit_state.is_edit() {
+      trace!("skip force save collab to disk: {}", self.object_id);
+      return;
+    }
+
+    if let Err(err) = self.save(true).await {
+      warn!("fail to force save: {}:{:?}", self.object_id, err);
+    }
+  }
+
+  /// Attempt to save the collab when its edit state indicates there are pending changes.
+  async fn attempt_save(&self) -> Result<(), AppError> {
+    trace!("collab:{} edit state: {}", self.object_id, self.edit_state);
+
+    // Only save when the edit state indicates the collab should be flushed to disk.
+    let is_new = self.edit_state.is_new_create();
+    if self.edit_state.should_save_to_disk() {
+      match self.save(is_new).await {
+        Ok(_) => {
+          if is_new {
+            self.edit_state.set_is_new_create(false);
+          }
+        },
+        Err(err) => {
+          warn!("fail to write: {}:{}", self.object_id, err);
+        },
+      }
+    }
+    Ok(())
+  }
+
+  async fn save(&self, flush_to_disk: bool) -> Result<(), AppError> {
+    let object_id = self.object_id.clone();
+    let workspace_id = self.workspace_id.clone();
+    let collab_type = self.collab_type.clone();
+
+    let cloned_collab = self.collab.clone();
+    let indexer_scheduler = self.indexer_scheduler.clone();
+
+    let params = tokio::task::spawn_blocking(move || {
+      let collab = cloned_collab.blocking_read();
+      let params = get_encode_collab(&workspace_id, &object_id, &collab, &collab_type)?;
+      match collab_type {
+        CollabType::Document => {
+          let txn = collab.transact();
+          let text = DocumentBody::from_collab(&collab)
+            .and_then(|doc| doc.to_plain_text(txn, false, true).ok());
+
+          if let Some(text) = text {
+            let pending = UnindexedCollabTask::new(
+              Uuid::parse_str(&workspace_id)?,
+              object_id.clone(),
+              collab_type,
+              UnindexedData::Text(text),
+            );
+            if let Err(err) = indexer_scheduler.index_pending_collab_one(pending, true) {
+              warn!("fail to index collab: {}:{}", object_id, err);
+            }
+          }
+        },
+        _ => {
+          // TODO(nathan): support other collab types
+        },
+      }
+
+      Ok::<_, AppError>(params)
+    })
+    .await??;
+
+    self
+      .storage
+      .queue_insert_or_update_collab(&self.workspace_id, &self.uid, params, flush_to_disk)
+      .await?;
+    // Update the edit state on successful save
+    self.edit_state.tick();
+    Ok(())
+  }
+}
+
+/// Encodes collaboration parameters for a given workspace and object.
+///
+/// This function attempts to encode collaboration details into a byte format based on the collaboration type.
+/// It validates required data for the collaboration type before encoding.
+/// If the collaboration type is `Folder`, it additionally checks for a workspace ID match.
+#[inline]
+fn get_encode_collab(
+  workspace_id: &str,
+  object_id: &str,
+  collab: &Collab,
+  collab_type: &CollabType,
+) -> Result<CollabParams, AppError> {
+  // Attempt to encode collaboration data to version 1 bytes and validate required data.
+  let encoded_collab = collab
+    .encode_collab_v1(|c| collab_type.validate_require_data(c))
+    .map_err(|err| {
+      AppError::Internal(anyhow!(
+        "Failed to encode collaboration to bytes: {:?}",
+        err
+      ))
+    })?
+    .encode_to_bytes()
+    .map_err(|err| {
+      AppError::Internal(anyhow!(
+        "Failed to serialize encoded collaboration to bytes: {:?}",
+        err
+      ))
+    })?;
+
+  // Specific check for collaboration type 'Folder' to ensure workspace ID consistency.
+  if let CollabType::Folder = collab_type {
+    validate_data_for_folder(collab, workspace_id)
+      .map_err(|err| AppError::OverrideWithIncorrectData(err.to_string()))?;
+  }
+
+  // Construct and return collaboration parameters.
+ let params = CollabParams { + object_id: object_id.to_string(), + encoded_collab_v1: encoded_collab.into(), + collab_type: collab_type.clone(), + }; + Ok(params) +} diff --git a/services/appflowy-collaborate/src/indexer/mod.rs b/services/appflowy-collaborate/src/indexer/mod.rs deleted file mode 100644 index 88b49af1d..000000000 --- a/services/appflowy-collaborate/src/indexer/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -mod document_indexer; -mod indexer_scheduler; -pub mod metrics; -mod open_ai; -mod provider; -mod vector; - -pub use document_indexer::DocumentIndexer; -pub use indexer_scheduler::*; -pub use provider::*; diff --git a/services/appflowy-collaborate/src/indexer/vector/open_ai.rs b/services/appflowy-collaborate/src/indexer/vector/open_ai.rs deleted file mode 100644 index 4f3efd6e2..000000000 --- a/services/appflowy-collaborate/src/indexer/vector/open_ai.rs +++ /dev/null @@ -1,68 +0,0 @@ -use crate::indexer::vector::rest::check_response; -use anyhow::anyhow; -use app_error::AppError; -use appflowy_ai_client::dto::{EmbeddingRequest, OpenAIEmbeddingResponse}; -use serde::de::DeserializeOwned; -use std::time::Duration; - -pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; - -pub const REQUEST_PARALLELISM: usize = 40; - -#[derive(Debug, Clone)] -pub struct Embedder { - bearer: String, - client: ureq::Agent, -} - -impl Embedder { - pub fn new(api_key: String) -> Self { - let bearer = format!("Bearer {api_key}"); - let client = ureq::AgentBuilder::new() - .max_idle_connections(REQUEST_PARALLELISM * 2) - .max_idle_connections_per_host(REQUEST_PARALLELISM * 2) - .build(); - - Self { bearer, client } - } - - pub fn embed(&self, params: EmbeddingRequest) -> Result { - for attempt in 0..3 { - let request = self - .client - .post(OPENAI_EMBEDDINGS_URL) - .set("Authorization", &self.bearer) - .set("Content-Type", "application/json"); - - let result = check_response(request.send_json(¶ms)); - let retry_duration = match result { - Ok(response) => { - let data = from_response::(response)?; - return Ok(data); - }, - Err(retry) => retry.into_duration(attempt), - } - .map_err(|err| AppError::Internal(err.into()))?; - let retry_duration = retry_duration.min(Duration::from_secs(10)); - std::thread::sleep(retry_duration); - } - - Err(AppError::Internal(anyhow!( - "Failed to generate embeddings after 3 attempts" - ))) - } -} - -pub fn from_response(resp: ureq::Response) -> Result -where - T: DeserializeOwned, -{ - let status_code = resp.status(); - if status_code != 200 { - let body = resp.into_string()?; - anyhow::bail!("error code: {}, {}", status_code, body) - } - - let resp = resp.into_json()?; - Ok(resp) -} diff --git a/services/appflowy-collaborate/src/lib.rs b/services/appflowy-collaborate/src/lib.rs index f11dead2d..972fc075c 100644 --- a/services/appflowy-collaborate/src/lib.rs +++ b/services/appflowy-collaborate/src/lib.rs @@ -9,7 +9,6 @@ pub mod config; pub mod connect_state; pub mod error; pub mod group; -pub mod indexer; pub mod metrics; mod permission; mod pg_listener; diff --git a/services/appflowy-collaborate/src/rt_server.rs b/services/appflowy-collaborate/src/rt_server.rs index 0f4db262f..59f5917ad 100644 --- a/services/appflowy-collaborate/src/rt_server.rs +++ b/services/appflowy-collaborate/src/rt_server.rs @@ -18,8 +18,6 @@ use tracing::{error, info, trace, warn}; use yrs::updates::decoder::Decode; use yrs::StateVector; -use database::collab::CollabStorage; - use crate::client::client_msg_router::ClientMessageRouter; use 
crate::command::{spawn_collaboration_command, CLCommandReceiver}; use crate::config::get_env_var; @@ -27,8 +25,9 @@ use crate::connect_state::ConnectState; use crate::error::{CreateGroupFailedReason, RealtimeError}; use crate::group::cmd::{GroupCommand, GroupCommandRunner, GroupCommandSender}; use crate::group::manager::GroupManager; -use crate::indexer::IndexerScheduler; use crate::rt_server::collaboration_runtime::COLLAB_RUNTIME; +use database::collab::CollabStorage; +use indexer::scheduler::IndexerScheduler; use crate::actix_ws::entities::{ClientGenerateEmbeddingMessage, ClientHttpUpdateMessage}; use crate::{CollabRealtimeMetrics, RealtimeClientWebsocketSink}; diff --git a/services/appflowy-collaborate/src/state.rs b/services/appflowy-collaborate/src/state.rs index 400d44d6d..58eca5d22 100644 --- a/services/appflowy-collaborate/src/state.rs +++ b/services/appflowy-collaborate/src/state.rs @@ -6,18 +6,17 @@ use futures_util::StreamExt; use sqlx::PgPool; use uuid::Uuid; -use access_control::metrics::AccessControlMetrics; -use app_error::AppError; -use collab_stream::stream_router::StreamRouter; -use database::user::{select_all_uid_uuid, select_uid_from_uuid}; - use crate::collab::storage::CollabAccessControlStorage; use crate::config::Config; -use crate::indexer::metrics::EmbeddingMetrics; -use crate::indexer::IndexerScheduler; use crate::metrics::CollabMetrics; use crate::pg_listener::PgListeners; use crate::CollabRealtimeMetrics; +use access_control::metrics::AccessControlMetrics; +use app_error::AppError; +use collab_stream::stream_router::StreamRouter; +use database::user::{select_all_uid_uuid, select_uid_from_uuid}; +use indexer::metrics::EmbeddingMetrics; +use indexer::scheduler::IndexerScheduler; pub type RedisConnectionManager = redis::aio::ConnectionManager; diff --git a/services/appflowy-collaborate/src/telemetry.rs b/services/appflowy-collaborate/src/telemetry.rs index b1b632f98..0551bbe64 100644 --- a/services/appflowy-collaborate/src/telemetry.rs +++ b/services/appflowy-collaborate/src/telemetry.rs @@ -9,10 +9,18 @@ pub fn init_subscriber(app_env: &Environment) { START.call_once(|| { let level = std::env::var("RUST_LOG").unwrap_or("info".to_string()); let mut filters = vec![]; - filters.push(format!("appflowy_collaborate={}", level)); + filters.push(format!("actix_web={}", level)); filters.push(format!("collab={}", level)); + filters.push(format!("collab_sync={}", level)); + filters.push(format!("appflowy_cloud={}", level)); filters.push(format!("collab_plugins={}", level)); + filters.push(format!("realtime={}", level)); filters.push(format!("database={}", level)); + filters.push(format!("storage={}", level)); + filters.push(format!("gotrue={}", level)); + filters.push(format!("appflowy_collaborate={}", level)); + filters.push(format!("appflowy_ai_client={}", level)); + filters.push(format!("indexer={}", level)); let env_filter = EnvFilter::new(filters.join(",")); let builder = tracing_subscriber::fmt() diff --git a/services/appflowy-worker/Cargo.toml b/services/appflowy-worker/Cargo.toml index 63bbf8088..a448000b0 100644 --- a/services/appflowy-worker/Cargo.toml +++ b/services/appflowy-worker/Cargo.toml @@ -63,3 +63,9 @@ base64.workspace = true prometheus-client = "0.22.3" reqwest.workspace = true zstd.workspace = true +indexer.workspace = true +appflowy-collaborate = { path = "../appflowy-collaborate" } +rayon = "1.10.0" +app-error = { workspace = true, features = [ + "sqlx_error", +] } diff --git a/services/appflowy-worker/src/application.rs 
b/services/appflowy-worker/src/application.rs index b4f2642cc..923826656 100644 --- a/services/appflowy-worker/src/application.rs +++ b/services/appflowy-worker/src/application.rs @@ -15,10 +15,13 @@ use secrecy::ExposeSecret; use crate::mailer::AFWorkerMailer; use crate::metric::ImportMetrics; +use appflowy_worker::indexer_worker::{run_background_indexer, BackgroundIndexerConfig}; use axum::extract::State; use axum::http::StatusCode; use axum::response::IntoResponse; use axum::routing::get; +use indexer::metrics::EmbeddingMetrics; +use indexer::thread_pool::ThreadPoolNoAbortBuilder; use infra::env_util::get_env_var; use mailer::sender::Mailer; use std::sync::{Arc, Once}; @@ -124,6 +127,28 @@ pub async fn create_app(listener: TcpListener, config: Config) -> Result<(), Err maximum_import_file_size, )); + let threads = Arc::new( + ThreadPoolNoAbortBuilder::new() + .num_threads(20) + .thread_name(|index| format!("background-embedding-thread-{index}")) + .build() + .unwrap(), + ); + + tokio::spawn(run_background_indexer( + state.pg_pool.clone(), + state.redis_client.clone(), + state.metrics.embedder_metrics.clone(), + threads.clone(), + BackgroundIndexerConfig { + enable: appflowy_collaborate::config::get_env_var("APPFLOWY_INDEXER_ENABLED", "true") + .parse::() + .unwrap_or(true), + open_api_key: appflowy_collaborate::config::get_env_var("APPFLOWY_AI_OPENAI_API_KEY", ""), + tick_interval_secs: 10, + }, + )); + let app = Router::new() .route("/metrics", get(metrics_handler)) .with_state(Arc::new(state)); @@ -212,15 +237,18 @@ pub struct AppMetrics { #[allow(dead_code)] registry: Arc, import_metrics: Arc, + embedder_metrics: Arc, } impl AppMetrics { pub fn new() -> Self { let mut registry = prometheus_client::registry::Registry::default(); let import_metrics = Arc::new(ImportMetrics::register(&mut registry)); + let embedder_metrics = Arc::new(EmbeddingMetrics::register(&mut registry)); Self { registry: Arc::new(registry), import_metrics, + embedder_metrics, } } } diff --git a/services/appflowy-worker/src/import_worker/worker.rs b/services/appflowy-worker/src/import_worker/worker.rs index b93a99e6d..3af4f7d68 100644 --- a/services/appflowy-worker/src/import_worker/worker.rs +++ b/services/appflowy-worker/src/import_worker/worker.rs @@ -59,7 +59,7 @@ use std::sync::Arc; use std::time::Duration; use tokio::fs; use tokio::task::spawn_local; -use tokio::time::interval; +use tokio::time::{interval, MissedTickBehavior}; use tokio_util::compat::TokioAsyncReadCompatExt; use tracing::{error, info, trace, warn}; use uuid::Uuid; @@ -177,6 +177,7 @@ async fn process_upcoming_tasks( .group(group_name, consumer_name) .count(10); let mut interval = interval(Duration::from_secs(interval_secs)); + interval.set_missed_tick_behavior(MissedTickBehavior::Skip); interval.tick().await; loop { diff --git a/services/appflowy-worker/src/indexer_worker/mod.rs b/services/appflowy-worker/src/indexer_worker/mod.rs new file mode 100644 index 000000000..85d58e6b6 --- /dev/null +++ b/services/appflowy-worker/src/indexer_worker/mod.rs @@ -0,0 +1,2 @@ +mod worker; +pub use worker::*; diff --git a/services/appflowy-worker/src/indexer_worker/worker.rs b/services/appflowy-worker/src/indexer_worker/worker.rs new file mode 100644 index 000000000..fd0d20253 --- /dev/null +++ b/services/appflowy-worker/src/indexer_worker/worker.rs @@ -0,0 +1,247 @@ +use app_error::AppError; +use collab_entity::CollabType; +use database::index::get_collabs_indexed_at; +use indexer::collab_indexer::{Indexer, IndexerProvider}; +use 
indexer::entity::EmbeddingRecord;
+use indexer::error::IndexerError;
+use indexer::metrics::EmbeddingMetrics;
+use indexer::queue::{
+  ack_task, default_indexer_group_option, ensure_indexer_consumer_group,
+  read_background_embed_tasks,
+};
+use indexer::scheduler::{spawn_pg_write_embeddings, UnindexedCollabTask, UnindexedData};
+use indexer::thread_pool::ThreadPoolNoAbort;
+use indexer::vector::embedder::Embedder;
+use indexer::vector::open_ai;
+use rayon::prelude::*;
+use redis::aio::ConnectionManager;
+use sqlx::PgPool;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
+use tokio::sync::RwLock;
+use tokio::time::{interval, MissedTickBehavior};
+use tracing::{error, info, trace};
+
+#[derive(Debug)]
+pub struct BackgroundIndexerConfig {
+  pub enable: bool,
+  pub open_api_key: String,
+  pub tick_interval_secs: u64,
+}
+
+pub async fn run_background_indexer(
+  pg_pool: PgPool,
+  mut redis_client: ConnectionManager,
+  embed_metrics: Arc<EmbeddingMetrics>,
+  threads: Arc<ThreadPoolNoAbort>,
+  config: BackgroundIndexerConfig,
+) {
+  if !config.enable {
+    info!("Background indexer is disabled. Stop background indexer");
+    return;
+  }
+
+  if config.open_api_key.is_empty() {
+    error!("OpenAI API key is not set. Stop background indexer");
+    return;
+  }
+
+  let indexer_provider = IndexerProvider::new();
+  info!("Starting background indexer...");
+  if let Err(err) = ensure_indexer_consumer_group(&mut redis_client).await {
+    error!("Failed to ensure indexer consumer group: {:?}", err);
+  }
+
+  let latest_write_embedding_err = Arc::new(RwLock::new(None));
+  let (write_embedding_tx, write_embedding_rx) = unbounded_channel::<EmbeddingRecord>();
+  let write_embedding_task_fut = spawn_pg_write_embeddings(
+    write_embedding_rx,
+    pg_pool.clone(),
+    embed_metrics.clone(),
+    latest_write_embedding_err.clone(),
+  );
+
+  let process_tasks_task_fut = process_upcoming_tasks(
+    pg_pool,
+    &mut redis_client,
+    embed_metrics,
+    indexer_provider,
+    threads,
+    config,
+    write_embedding_tx,
+    latest_write_embedding_err,
+  );
+
+  tokio::select! {
+    _ = write_embedding_task_fut => {
+      error!("[Background Embedding] Write embedding task stopped");
+    },
+    _ = process_tasks_task_fut => {
+      error!("[Background Embedding] Process tasks task stopped");
+    },
+  }
+}
+
+#[allow(clippy::too_many_arguments)]
+async fn process_upcoming_tasks(
+  pg_pool: PgPool,
+  redis_client: &mut ConnectionManager,
+  metrics: Arc<EmbeddingMetrics>,
+  indexer_provider: Arc<IndexerProvider>,
+  threads: Arc<ThreadPoolNoAbort>,
+  config: BackgroundIndexerConfig,
+  sender: UnboundedSender<EmbeddingRecord>,
+  latest_write_embedding_err: Arc<RwLock<Option<AppError>>>,
+) {
+  let options = default_indexer_group_option(threads.current_num_threads());
+  let mut interval = interval(Duration::from_secs(config.tick_interval_secs));
+  interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
+  interval.tick().await;
+
+  loop {
+    interval.tick().await;
+
+    let latest_error = latest_write_embedding_err.write().await.take();
+    if let Some(err) = latest_error {
+      if matches!(err, AppError::ActionTimeout(_)) {
+        info!(
+          "[Background Embedding] last write embedding task failed with timeout, waiting 15s before retrying..."
+        );
+        tokio::time::sleep(Duration::from_secs(15)).await;
+      }
+    }
+
+    match read_background_embed_tasks(redis_client, &options).await {
+      Ok(replay) => {
+        let all_keys: Vec<String> = replay
+          .keys
+          .iter()
+          .flat_map(|key| key.ids.iter().map(|stream_id| stream_id.id.clone()))
+          .collect();
+
+        for key in replay.keys {
+          info!(
+            "[Background Embedding] processing {} embedding tasks",
+            key.ids.len()
+          );
+
+          let mut tasks: Vec<UnindexedCollabTask> = key
+            .ids
+            .into_iter()
+            .filter_map(|stream_id| UnindexedCollabTask::try_from(&stream_id).ok())
+            .collect();
+          tasks.retain(|task| !task.data.is_empty());
+
+          let collab_ids: Vec<(String, CollabType)> = tasks
+            .iter()
+            .map(|task| (task.object_id.clone(), task.collab_type.clone()))
+            .collect();
+
+          let indexed_collabs = get_collabs_indexed_at(&pg_pool, collab_ids)
+            .await
+            .unwrap_or_default();
+
+          let all_tasks_len = tasks.len();
+          if !indexed_collabs.is_empty() {
+            // Filter out tasks whose `created_at` is earlier than `indexed_at`
+            tasks.retain(|task| {
+              indexed_collabs
+                .get(&task.object_id)
+                .map_or(true, |indexed_at| task.created_at > indexed_at.timestamp())
+            });
+          }
+
+          if all_tasks_len != tasks.len() {
+            info!("[Background Embedding] filtered out {} tasks whose `created_at` is earlier than `indexed_at`", all_tasks_len - tasks.len());
+          }
+
+          let start = Instant::now();
+          let num_tasks = tasks.len();
+          tasks.into_par_iter().for_each(|task| {
+            let result = threads.install(|| {
+              if let Some(indexer) = indexer_provider.indexer_for(&task.collab_type) {
+                let embedder = create_embedder(&config);
+                let result = handle_task(embedder, indexer, task);
+                match result {
+                  None => metrics.record_failed_embed_count(1),
+                  Some(record) => {
+                    metrics.record_embed_count(1);
+                    trace!(
+                      "[Background Embedding] send {} embedding record to write task",
+                      record.object_id
+                    );
+                    if let Err(err) = sender.send(record) {
+                      trace!(
+                        "[Background Embedding] failed to send embedding record to write task: {:?}",
+                        err
+                      );
+                    }
+                  },
+                }
+              }
+            });
+            if let Err(err) = result {
+              error!(
+                "[Background Embedding] Failed to process embedder task: {:?}",
+                err
+              );
+            }
+          });
+          let cost = start.elapsed().as_millis();
+          metrics.record_gen_embedding_time(num_tasks as u32, cost);
+        }
+
+        if !all_keys.is_empty() {
+          match ack_task(redis_client, all_keys, true).await {
+            Ok(_) => trace!("[Background Embedding] deleted tasks from stream"),
+            Err(err) => {
+              error!("[Background Embedding] Failed to ack tasks: {:?}", err);
+            },
+          }
+        }
+      },
+      Err(err) => {
+        error!("[Background Embedding] Failed to read tasks: {:?}", err);
+        if matches!(err, IndexerError::StreamGroupNotExist(_)) {
+          if let Err(err) = ensure_indexer_consumer_group(redis_client).await {
+            error!(
+              "[Background Embedding] Failed to ensure indexer consumer group: {:?}",
+              err
+            );
+          }
+        }
+      },
+    }
+  }
+}
+
+fn handle_task(
+  embedder: Embedder,
+  indexer: Arc<dyn Indexer>,
+  task: UnindexedCollabTask,
+) -> Option<EmbeddingRecord> {
+  trace!(
+    "[Background Embedding] processing task: {}, content:{:?}, collab_type: {}",
+    task.object_id,
+    task.data,
+    task.collab_type
+  );
+  let chunks = match task.data {
+    UnindexedData::Text(text) => indexer
+      .create_embedded_chunks_from_text(task.object_id.clone(), text, embedder.model())
+      .ok()?,
+  };
+  let embeddings = indexer.embed(&embedder, chunks).ok()?;
+  embeddings.map(|embeddings| EmbeddingRecord {
+    workspace_id: task.workspace_id,
+    object_id: task.object_id,
+    collab_type: task.collab_type,
+    tokens_used: embeddings.tokens_consumed,
+    contents: embeddings.params,
+  })
+}
+
+fn create_embedder(config:
&BackgroundIndexerConfig) -> Embedder { + Embedder::OpenAI(open_ai::Embedder::new(config.open_api_key.clone())) +} diff --git a/services/appflowy-worker/src/lib.rs b/services/appflowy-worker/src/lib.rs index b3f384e56..2dddf99e5 100644 --- a/services/appflowy-worker/src/lib.rs +++ b/services/appflowy-worker/src/lib.rs @@ -1,5 +1,6 @@ pub mod error; pub mod import_worker; +pub mod indexer_worker; mod mailer; pub mod metric; pub mod s3_client; diff --git a/src/api/util.rs b/src/api/util.rs index 3d48b1590..64bb6ea70 100644 --- a/src/api/util.rs +++ b/src/api/util.rs @@ -9,7 +9,7 @@ use async_trait::async_trait; use byteorder::{ByteOrder, LittleEndian}; use chrono::Utc; use collab_rt_entity::user::RealtimeUser; -use collab_rt_protocol::spawn_blocking_validate_encode_collab; +use collab_rt_protocol::validate_encode_collab; use database_entity::dto::CollabParams; use std::str::FromStr; use tokio_stream::StreamExt; @@ -119,13 +119,9 @@ pub trait CollabValidator { #[async_trait] impl CollabValidator for CollabParams { async fn check_encode_collab(&self) -> Result<(), AppError> { - spawn_blocking_validate_encode_collab( - &self.object_id, - &self.encoded_collab_v1, - &self.collab_type, - ) - .await - .map_err(|err| AppError::NoRequiredData(err.to_string())) + validate_encode_collab(&self.object_id, &self.encoded_collab_v1, &self.collab_type) + .await + .map_err(|err| AppError::NoRequiredData(err.to_string())) } } diff --git a/src/api/workspace.rs b/src/api/workspace.rs index 6d6f8e776..a9193d52a 100644 --- a/src/api/workspace.rs +++ b/src/api/workspace.rs @@ -1,5 +1,5 @@ use crate::api::util::{client_version_from_headers, realtime_user_for_web_request, PayloadReader}; -use crate::api::util::{compress_type_from_header_value, device_id_from_headers, CollabValidator}; +use crate::api::util::{compress_type_from_header_value, device_id_from_headers}; use crate::api::ws::RealtimeServerAddr; use crate::biz; use crate::biz::collab::ops::{ @@ -32,23 +32,28 @@ use actix_web::{HttpRequest, Result}; use anyhow::{anyhow, Context}; use app_error::AppError; use appflowy_collaborate::actix_ws::entities::{ClientHttpStreamMessage, ClientHttpUpdateMessage}; -use appflowy_collaborate::indexer::IndexedCollab; use authentication::jwt::{Authorization, OptionalUserUuid, UserUuid}; use bytes::BytesMut; use chrono::{DateTime, Duration, Utc}; +use collab::core::collab::DataSource; +use collab::core::origin::CollabOrigin; +use collab::entity::EncodedCollab; +use collab::preclude::Collab; use collab_database::entity::FieldType; +use collab_document::document::Document; use collab_entity::CollabType; use collab_folder::timestamp; use collab_rt_entity::collab_proto::{CollabDocStateParams, PayloadCompressionType}; use collab_rt_entity::realtime_proto::HttpRealtimeMessage; use collab_rt_entity::user::RealtimeUser; use collab_rt_entity::RealtimeMessage; -use collab_rt_protocol::validate_encode_collab; +use collab_rt_protocol::collab_from_encode_collab; use database::collab::{CollabStorage, GetCollabOrigin}; use database::user::select_uid_from_email; use database_entity::dto::PublishCollabItem; use database_entity::dto::PublishInfo; use database_entity::dto::*; +use indexer::scheduler::{UnindexedCollabTask, UnindexedData}; use prost::Message as ProstMessage; use rayon::prelude::*; use sha2::{Digest, Sha256}; @@ -63,6 +68,7 @@ use tokio_tungstenite::tungstenite::Message; use tracing::{error, event, instrument, trace}; use uuid::Uuid; use validator::Validate; + pub const WORKSPACE_ID_PATH: &str = "workspace_id"; pub const 
COLLAB_OBJECT_ID_PATH: &str = "object_id"; @@ -706,7 +712,16 @@ async fn create_collab_handler( ); } - if let Err(err) = params.check_encode_collab().await { + let collab = collab_from_encode_collab(¶ms.object_id, ¶ms.encoded_collab_v1) + .await + .map_err(|err| { + AppError::NoRequiredData(format!( + "Failed to create collab from encoded collab: {}", + err + )) + })?; + + if let Err(err) = params.collab_type.validate_require_data(&collab) { return Err( AppError::NoRequiredData(format!( "collab doc state is not correct:{},{}", @@ -721,9 +736,19 @@ async fn create_collab_handler( .can_index_workspace(&workspace_id) .await? { - state - .indexer_scheduler - .index_encoded_collab_one(&workspace_id, IndexedCollab::from(¶ms))?; + if let Ok(text) = Document::open(collab).and_then(|doc| doc.to_plain_text(false, true)) { + let workspace_id_uuid = + Uuid::parse_str(&workspace_id).map_err(|err| AppError::Internal(err.into()))?; + let pending = UnindexedCollabTask::new( + workspace_id_uuid, + params.object_id.clone(), + params.collab_type.clone(), + UnindexedData::Text(text), + ); + state + .indexer_scheduler + .index_pending_collab_one(pending, false)?; + } } let mut transaction = state @@ -759,7 +784,8 @@ async fn batch_create_collab_handler( req: HttpRequest, ) -> Result>> { let uid = state.user_cache.get_user_uid(&user_uuid).await?; - let workspace_id = workspace_id.into_inner().to_string(); + let workspace_id_uuid = workspace_id.into_inner(); + let workspace_id = workspace_id_uuid.to_string(); let compress_type = compress_type_from_header_value(req.headers())?; event!(tracing::Level::DEBUG, "start decompressing collab list"); @@ -791,7 +817,7 @@ async fn batch_create_collab_handler( } } // Perform decompression and processing in a Rayon thread pool - let collab_params_list = tokio::task::spawn_blocking(move || match compress_type { + let mut collab_params_list = tokio::task::spawn_blocking(move || match compress_type { CompressionType::Brotli { buffer_size } => offset_len_list .into_par_iter() .filter_map(|(offset, len)| { @@ -800,12 +826,31 @@ async fn batch_create_collab_handler( Ok(decompressed_data) => { let params = CollabParams::from_bytes(&decompressed_data).ok()?; if params.validate().is_ok() { - match validate_encode_collab( + let encoded_collab = + EncodedCollab::decode_from_bytes(¶ms.encoded_collab_v1).ok()?; + let collab = Collab::new_with_source( + CollabOrigin::Empty, ¶ms.object_id, - ¶ms.encoded_collab_v1, - ¶ms.collab_type, - ) { - Ok(_) => Some(params), + DataSource::DocStateV1(encoded_collab.doc_state.to_vec()), + vec![], + false, + ) + .ok()?; + + match params.collab_type.validate_require_data(&collab) { + Ok(_) => { + match params.collab_type { + CollabType::Document => { + let index_text = + Document::open(collab).and_then(|doc| doc.to_plain_text(false, true)); + Some((Some(index_text), params)) + }, + _ => { + // TODO(nathan): support other types + Some((None, params)) + }, + } + }, Err(_) => None, } } else { @@ -829,30 +874,46 @@ async fn batch_create_collab_handler( let total_size = collab_params_list .iter() - .fold(0, |acc, x| acc + x.encoded_collab_v1.len()); + .fold(0, |acc, x| acc + x.1.encoded_collab_v1.len()); tracing::info!( "decompressed {} collab objects in {:?}", collab_params_list.len(), start.elapsed() ); - // if state - // .indexer_scheduler - // .can_index_workspace(&workspace_id) - // .await? 
- // { - // let indexed_collabs: Vec<_> = collab_params_list - // .iter() - // .filter(|p| state.indexer_scheduler.is_indexing_enabled(&p.collab_type)) - // .map(IndexedCollab::from) - // .collect(); - // - // if !indexed_collabs.is_empty() { - // state - // .indexer_scheduler - // .index_encoded_collabs(&workspace_id, indexed_collabs)?; - // } - // } + let mut pending_undexed_collabs = vec![]; + if state + .indexer_scheduler + .can_index_workspace(&workspace_id) + .await? + { + pending_undexed_collabs = collab_params_list + .iter_mut() + .filter(|p| { + state + .indexer_scheduler + .is_indexing_enabled(&p.1.collab_type) + }) + .flat_map(|value| match std::mem::take(&mut value.0) { + None => None, + Some(text) => text + .map(|text| { + UnindexedCollabTask::new( + workspace_id_uuid, + value.1.object_id.clone(), + value.1.collab_type.clone(), + UnindexedData::Text(text), + ) + }) + .ok(), + }) + .collect::>(); + } + + let collab_params_list = collab_params_list + .into_iter() + .map(|(_, params)| params) + .collect::>(); let start = Instant::now(); state @@ -866,6 +927,13 @@ async fn batch_create_collab_handler( total_size ); + // Must after batch_insert_new_collab + if !pending_undexed_collabs.is_empty() { + state + .indexer_scheduler + .index_pending_collabs(pending_undexed_collabs)?; + } + Ok(Json(AppResponse::Ok())) } @@ -1364,9 +1432,45 @@ async fn update_collab_handler( .can_index_workspace(&workspace_id) .await? { - state - .indexer_scheduler - .index_encoded_collab_one(&workspace_id, IndexedCollab::from(¶ms))?; + let workspace_id_uuid = + Uuid::parse_str(&workspace_id).map_err(|err| AppError::Internal(err.into()))?; + + match params.collab_type { + CollabType::Document => { + let collab = collab_from_encode_collab(¶ms.object_id, ¶ms.encoded_collab_v1) + .await + .map_err(|err| { + AppError::InvalidRequest(format!( + "Failed to create collab from encoded collab: {}", + err + )) + })?; + params + .collab_type + .validate_require_data(&collab) + .map_err(|err| { + AppError::NoRequiredData(format!( + "collab doc state is not correct:{},{}", + params.object_id, err + )) + })?; + + if let Ok(text) = Document::open(collab).and_then(|doc| doc.to_plain_text(false, true)) { + let pending = UnindexedCollabTask::new( + workspace_id_uuid, + params.object_id.clone(), + params.collab_type.clone(), + UnindexedData::Text(text), + ); + state + .indexer_scheduler + .index_pending_collab_one(pending, false)?; + } + }, + _ => { + // TODO(nathan): support other collab type + }, + } } state diff --git a/src/application.rs b/src/application.rs index 2f30a12f4..8f76e2ec5 100644 --- a/src/application.rs +++ b/src/application.rs @@ -40,11 +40,12 @@ use appflowy_collaborate::actix_ws::server::RealtimeServerActor; use appflowy_collaborate::collab::cache::CollabCache; use appflowy_collaborate::collab::storage::CollabStorageImpl; use appflowy_collaborate::command::{CLCommandReceiver, CLCommandSender}; -use appflowy_collaborate::indexer::{IndexerConfiguration, IndexerProvider, IndexerScheduler}; use appflowy_collaborate::snapshot::SnapshotControl; use appflowy_collaborate::CollaborationServer; use collab_stream::stream_router::{StreamRouter, StreamRouterOptions}; use database::file::s3_client_impl::{AwsS3BucketClientImpl, S3BucketStorage}; +use indexer::collab_indexer::IndexerProvider; +use indexer::scheduler::{IndexerConfiguration, IndexerScheduler}; use infra::env_util::get_env_var; use mailer::sender::Mailer; use snowflake::Snowflake; @@ -323,10 +324,16 @@ pub async fn init_state(config: &Config, rt_cmd_tx: 
CLCommandSender) -> Result() .unwrap_or(true), openai_api_key: get_env_var("APPFLOWY_AI_OPENAI_API_KEY", ""), + embedding_buffer_size: appflowy_collaborate::config::get_env_var( + "APPFLOWY_INDEXER_EMBEDDING_BUFFER_SIZE", + "5000", + ) + .parse::() + .unwrap_or(5000), }; let indexer_scheduler = IndexerScheduler::new( IndexerProvider::new(), @@ -334,6 +341,7 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result Result, AppResponseError> { - let embeddings = indexer_scheduler.embeddings(EmbeddingRequest { + let embeddings = indexer_scheduler.create_search_embeddings(EmbeddingRequest { input: EmbeddingInput::String(request.query.clone()), model: EmbeddingModel::TextEmbedding3Small.to_string(), encoding_format: EmbeddingEncodingFormat::Float, diff --git a/src/main.rs b/src/main.rs index 097647131..5b01c391a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,7 @@ async fn main() -> anyhow::Result<()> { filters.push(format!("gotrue={}", level)); filters.push(format!("appflowy_collaborate={}", level)); filters.push(format!("appflowy_ai_client={}", level)); + filters.push(format!("indexer={}", level)); // Load environment variables from .env file dotenvy::dotenv().ok(); diff --git a/src/state.rs b/src/state.rs index ee365cca2..1659d43de 100644 --- a/src/state.rs +++ b/src/state.rs @@ -15,15 +15,14 @@ use app_error::AppError; use appflowy_ai_client::client::AppFlowyAIClient; use appflowy_collaborate::collab::cache::CollabCache; use appflowy_collaborate::collab::storage::CollabAccessControlStorage; -use appflowy_collaborate::indexer::metrics::EmbeddingMetrics; -use appflowy_collaborate::indexer::IndexerScheduler; use appflowy_collaborate::metrics::CollabMetrics; use appflowy_collaborate::CollabRealtimeMetrics; use collab_stream::stream_router::StreamRouter; use database::file::s3_client_impl::{AwsS3BucketClientImpl, S3BucketStorage}; use database::user::{select_all_uid_uuid, select_uid_from_uuid}; use gotrue::grant::{Grant, PasswordGrant}; - +use indexer::metrics::EmbeddingMetrics; +use indexer::scheduler::IndexerScheduler; use snowflake::Snowflake; use tonic_proto::history::history_client::HistoryClient; diff --git a/tests/collab/collab_curd_test.rs b/tests/collab/collab_curd_test.rs index fd0435bcc..a0d334097 100644 --- a/tests/collab/collab_curd_test.rs +++ b/tests/collab/collab_curd_test.rs @@ -11,7 +11,7 @@ use reqwest::Method; use serde::Serialize; use serde_json::json; -use crate::collab::util::{generate_random_string, test_encode_collab_v1}; +use crate::collab::util::{empty_document_editor, generate_random_string, test_encode_collab_v1}; use client_api_test::TestClient; use shared_entity::response::AppResponse; use uuid::Uuid; @@ -50,77 +50,6 @@ async fn batch_insert_collab_with_empty_payload_test() { assert_eq!(error.code, ErrorCode::InvalidRequest); } -#[tokio::test] -async fn batch_insert_collab_success_test() { - let mut test_client = TestClient::new_user().await; - let workspace_id = test_client.workspace_id().await; - - let mut mock_encoded_collab = vec![]; - for _ in 0..200 { - let object_id = Uuid::new_v4().to_string(); - let encoded_collab_v1 = - test_encode_collab_v1(&object_id, "title", &generate_random_string(2 * 1024)); - mock_encoded_collab.push(encoded_collab_v1); - } - - for _ in 0..30 { - let object_id = Uuid::new_v4().to_string(); - let encoded_collab_v1 = - test_encode_collab_v1(&object_id, "title", &generate_random_string(10 * 1024)); - mock_encoded_collab.push(encoded_collab_v1); - } - - for _ in 0..10 { - let object_id = 
Uuid::new_v4().to_string(); - let encoded_collab_v1 = - test_encode_collab_v1(&object_id, "title", &generate_random_string(800 * 1024)); - mock_encoded_collab.push(encoded_collab_v1); - } - - let params_list = mock_encoded_collab - .iter() - .map(|encoded_collab_v1| CollabParams { - object_id: Uuid::new_v4().to_string(), - encoded_collab_v1: encoded_collab_v1.encode_to_bytes().unwrap().into(), - collab_type: CollabType::Unknown, - }) - .collect::>(); - - test_client - .create_collab_list(&workspace_id, params_list.clone()) - .await - .unwrap(); - - let params = params_list - .iter() - .map(|params| QueryCollab { - object_id: params.object_id.clone(), - collab_type: params.collab_type.clone(), - }) - .collect::>(); - - let result = test_client - .batch_get_collab(&workspace_id, params) - .await - .unwrap(); - - for params in params_list { - let encoded_collab = result.0.get(¶ms.object_id).unwrap(); - match encoded_collab { - QueryCollabResult::Success { encode_collab_v1 } => { - let actual = EncodedCollab::decode_from_bytes(encode_collab_v1.as_ref()).unwrap(); - let expected = EncodedCollab::decode_from_bytes(params.encoded_collab_v1.as_ref()).unwrap(); - assert_eq!(actual.doc_state, expected.doc_state); - }, - QueryCollabResult::Failed { error } => { - panic!("Failed to get collab: {:?}", error); - }, - } - } - - assert_eq!(result.0.values().len(), 240); -} - #[tokio::test] async fn create_collab_params_compatibility_serde_test() { // This test is to make sure that the CreateCollabParams is compatible with the old InsertCollabParams @@ -218,6 +147,69 @@ async fn create_collab_compatibility_with_json_params_test() { assert_eq!(encoded_collab, encoded_collab_from_server); } +#[tokio::test] +async fn batch_insert_document_collab_test() { + let mut test_client = TestClient::new_user().await; + let workspace_id = test_client.workspace_id().await; + + let num_collabs = 100; + let mut list = vec![]; + for _ in 0..num_collabs { + let object_id = Uuid::new_v4().to_string(); + let mut editor = empty_document_editor(&object_id); + let paragraphs = vec![ + generate_random_string(1), + generate_random_string(2), + generate_random_string(5), + ]; + editor.insert_paragraphs(paragraphs); + list.push((object_id, editor.encode_collab())); + } + + let params_list = list + .iter() + .map(|(object_id, encoded_collab_v1)| CollabParams { + object_id: object_id.clone(), + encoded_collab_v1: encoded_collab_v1.encode_to_bytes().unwrap().into(), + collab_type: CollabType::Document, + }) + .collect::>(); + + test_client + .create_collab_list(&workspace_id, params_list.clone()) + .await + .unwrap(); + + let params = params_list + .iter() + .map(|params| QueryCollab { + object_id: params.object_id.clone(), + collab_type: params.collab_type.clone(), + }) + .collect::>(); + + let result = test_client + .batch_get_collab(&workspace_id, params) + .await + .unwrap(); + + for params in params_list { + let encoded_collab = result.0.get(¶ms.object_id).unwrap(); + match encoded_collab { + QueryCollabResult::Success { encode_collab_v1 } => { + let actual = EncodedCollab::decode_from_bytes(encode_collab_v1.as_ref()).unwrap(); + let expected = EncodedCollab::decode_from_bytes(params.encoded_collab_v1.as_ref()).unwrap(); + assert_eq!(actual.doc_state, expected.doc_state); + }, + QueryCollabResult::Failed { error } => { + panic!("Failed to get collab: {:?}", error); + }, + } + } + + assert_eq!(result.0.values().len(), num_collabs); +} + #[derive(Debug, Clone, Serialize)] pub struct OldCreateCollabParams { #[serde(flatten)] diff 
--git a/tests/collab/storage_test.rs b/tests/collab/storage_test.rs index 56afea326..1bafc2160 100644 --- a/tests/collab/storage_test.rs +++ b/tests/collab/storage_test.rs @@ -199,7 +199,7 @@ async fn fail_insert_collab_with_empty_payload_test() { .create_collab(CreateCollabParams { object_id: Uuid::new_v4().to_string(), encoded_collab_v1: vec![], - collab_type: CollabType::Unknown, + collab_type: CollabType::Document, workspace_id: workspace_id.clone(), }) .await