Skip to content

Commit

Permalink
fix: obtain unsearchable views before searching for embeddings (#1140)
Browse files Browse the repository at this point in the history
  • Loading branch information
khorshuheng authored Jan 8, 2025
1 parent b47a635 commit 0ca22f7
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 42 deletions.
18 changes: 9 additions & 9 deletions libs/database/src/index/search_ops.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use std::ops::DerefMut;

use chrono::{DateTime, Utc};
use pgvector::Vector;
use sqlx::Transaction;
use sqlx::{Executor, Postgres};
use uuid::Uuid;

/// Logs each search request to track usage by workspace. It either inserts a new record or updates
Expand All @@ -12,8 +10,8 @@ use uuid::Uuid;
/// Searches and retrieves documents based on their similarity to a given search embedding.
/// It filters by workspace, user access, and document status, and returns a limited number
/// of the most relevant documents, sorted by similarity score.
pub async fn search_documents(
tx: &mut Transaction<'_, sqlx::Postgres>,
pub async fn search_documents<'a, E: Executor<'a, Database = Postgres>>(
executor: E,
params: SearchDocumentParams,
tokens_used: u32,
) -> Result<Vec<SearchDocumentItem>, sqlx::Error> {
Expand All @@ -38,9 +36,8 @@ pub async fn search_documents(
em.embedding <=> $3 AS score
FROM af_collab_embeddings em
JOIN af_collab collab ON em.oid = collab.oid AND em.partition_key = collab.partition_key
JOIN af_workspace_member member ON collab.workspace_id = member.workspace_id
JOIN af_user u ON collab.owner_uid = u.uid
WHERE member.uid = $1 AND collab.workspace_id = $2 AND collab.deleted_at IS NULL
WHERE collab.workspace_id = $2 AND NOT(collab.oid = ANY($7::text[]))
ORDER BY em.embedding <=> $3
LIMIT $5
"#,
Expand All @@ -50,8 +47,9 @@ pub async fn search_documents(
.bind(Vector::from(params.embedding))
.bind(params.preview)
.bind(params.limit)
.bind(tokens_used as i64);
let rows = query.fetch_all(tx.deref_mut()).await?;
.bind(tokens_used as i64)
.bind(params.non_viewable_view_ids);
let rows = query.fetch_all(executor).await?;
Ok(rows)
}

Expand All @@ -67,6 +65,8 @@ pub struct SearchDocumentParams {
pub preview: i32,
/// Embedding of the query - generated by OpenAI embedder.
pub embedding: Vec<f32>,
/// List of view ids which is not supposed to be returned in the search results.
pub non_viewable_view_ids: Vec<String>,
}

#[derive(Debug, Clone, sqlx::FromRow)]
Expand Down
14 changes: 13 additions & 1 deletion src/biz/collab/folder_view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,23 @@ pub fn private_and_nonviewable_view_ids(folder: &Folder) -> PrivateAndNonviewabl
if check_if_view_is_space(&private_view) && !my_private_view_ids.contains(&private_section.id)
{
nonviewable_view_ids.insert(private_section.id);
let private_view_ids_in_space: HashSet<String> = folder
.get_views_belong_to(&private_view.id)
.iter()
.map(|v| v.id.clone())
.collect();
nonviewable_view_ids.extend(private_view_ids_in_space);
}
}
}
for trash_view in folder.get_all_trash_sections() {
nonviewable_view_ids.insert(trash_view.id);
nonviewable_view_ids.insert(trash_view.id.clone());
let child_views_for_trash: HashSet<String> = folder
.get_views_belong_to(&trash_view.id)
.iter()
.map(|v| v.id.clone())
.collect();
nonviewable_view_ids.extend(child_views_for_trash);
}
PrivateAndNonviewableViews {
my_private_view_ids,
Expand Down
64 changes: 32 additions & 32 deletions src/biz/search/ops.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
use crate::api::metrics::RequestMetrics;
use crate::biz::collab::folder_view::{
check_if_view_ancestors_fulfil_condition, private_and_nonviewable_view_ids,
};
use crate::biz::collab::folder_view::private_and_nonviewable_view_ids;
use crate::biz::collab::utils::get_latest_collab_folder;
use app_error::ErrorCode;
use app_error::AppError;
use appflowy_ai_client::dto::{
EmbeddingEncodingFormat, EmbeddingInput, EmbeddingModel, EmbeddingOutput, EmbeddingRequest,
};
use appflowy_collaborate::collab::storage::CollabAccessControlStorage;
use database::collab::GetCollabOrigin;
use itertools::Itertools;
use std::collections::HashSet;
use std::sync::Arc;

use database::index::{search_documents, SearchDocumentParams};
use shared_entity::dto::search_dto::{
SearchContentType, SearchDocumentRequest, SearchDocumentResponseItem,
};
use shared_entity::response::AppResponseError;
use sqlx::PgPool;

use indexer::scheduler::IndexerScheduler;
Expand All @@ -29,7 +28,7 @@ pub async fn search_document(
workspace_id: Uuid,
request: SearchDocumentRequest,
metrics: &RequestMetrics,
) -> Result<Vec<SearchDocumentResponseItem>, AppResponseError> {
) -> Result<Vec<SearchDocumentResponseItem>, AppError> {
let embeddings = indexer_scheduler
.create_search_embeddings(EmbeddingRequest {
input: EmbeddingInput::String(request.query.clone()),
Expand All @@ -49,34 +48,49 @@ pub async fn search_document(
let embedding = embeddings
.data
.first()
.ok_or_else(|| AppResponseError::new(ErrorCode::Internal, "OpenAI returned no embeddings"))?;
.ok_or_else(|| AppError::Internal(anyhow::anyhow!("OpenAI returned no embeddings")))?;
let embedding = match &embedding.embedding {
EmbeddingOutput::Float(vector) => vector.iter().map(|&v| v as f32).collect(),
EmbeddingOutput::Base64(_) => {
return Err(AppResponseError::new(
ErrorCode::Internal,
"OpenAI returned embeddings in unsupported format",
))
return Err(AppError::Internal(anyhow::anyhow!(
"OpenAI returned embeddings in unsupported format"
)))
},
};

let mut tx = pg_pool
.begin()
.await
.map_err(|e| AppResponseError::new(ErrorCode::Internal, e.to_string()))?;
let folder = get_latest_collab_folder(
collab_storage,
GetCollabOrigin::User { uid },
&workspace_id.to_string(),
)
.await?;
let private_and_nonviewable_views = private_and_nonviewable_view_ids(&folder);
let space_ids: HashSet<String> = folder
.get_view(&workspace_id.to_string())
.ok_or_else(|| AppError::Internal(anyhow::anyhow!("Workspace view not found in folder")))?
.children
.iter()
.map(|c| c.id.clone())
.collect();

let mut non_searchable_view_ids = private_and_nonviewable_views.nonviewable_view_ids;
non_searchable_view_ids.extend(space_ids);
let results = search_documents(
&mut tx,
pg_pool,
SearchDocumentParams {
user_id: uid,
workspace_id,
limit: request.limit.unwrap_or(10) as i32,
preview: request.preview_size.unwrap_or(500) as i32,
embedding,
non_viewable_view_ids: non_searchable_view_ids
.iter()
.map(|uuid| uuid.to_string())
.collect_vec(),
},
total_tokens,
)
.await?;
tx.commit().await?;
tracing::trace!(
"user {} search request in workspace {} returned {} results for query: `{}`",
uid,
Expand All @@ -85,22 +99,8 @@ pub async fn search_document(
request.query
);

let folder = get_latest_collab_folder(
collab_storage,
GetCollabOrigin::User { uid },
&workspace_id.to_string(),
)
.await?;
let private_and_nonviewable_views = private_and_nonviewable_view_ids(&folder);
let non_searchable_view_ids = private_and_nonviewable_views.nonviewable_view_ids;
let filtered_results = results.into_iter().filter(|item| {
!check_if_view_ancestors_fulfil_condition(&item.object_id, &folder, |view| {
non_searchable_view_ids.contains(&view.id)
})
});

Ok(
filtered_results
results
.into_iter()
.map(|item| SearchDocumentResponseItem {
object_id: item.object_id,
Expand Down

0 comments on commit 0ca22f7

Please sign in to comment.