diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index d1393bc7c8..e5b07493b0 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -56,6 +56,9 @@ def __new__(cls, *args: Any, **kwargs: Any): deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false())) search_vector: Optional[str] = Field(sa_column=sa.Column(pg.TSVECTOR(), nullable=True)) + search_vector_update_date: Optional[datetime] = Field( + sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, index=True) + ) review_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) review_result: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=True)) diff --git a/backend/oasst_backend/models/message_revision.py b/backend/oasst_backend/models/message_revision.py index 18edc9f28c..fbd1656676 100644 --- a/backend/oasst_backend/models/message_revision.py +++ b/backend/oasst_backend/models/message_revision.py @@ -22,7 +22,9 @@ class MessageRevision(SQLModel, table=True): message_id: UUID = Field(sa_column=sa.Column(sa.ForeignKey("message.id"), nullable=False, index=True)) user_id: Optional[UUID] = Field(sa_column=sa.Column(sa.ForeignKey("user.id"), nullable=True)) created_date: Optional[datetime] = Field( - sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp()) + sa_column=sa.Column( + sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp(), index=True + ) ) _user_is_author: Optional[bool] = PrivateAttr(default=None) diff --git a/backend/oasst_backend/scheduled_tasks.py b/backend/oasst_backend/scheduled_tasks.py index 4ae1ea1264..2fe2c44528 100644 --- a/backend/oasst_backend/scheduled_tasks.py +++ b/backend/oasst_backend/scheduled_tasks.py @@ -7,14 +7,14 @@ from celery import shared_task from loguru import logger from oasst_backend.celery_worker import app -from oasst_backend.models import ApiClient, Message, User +from oasst_backend.models import ApiClient, Message, MessageRevision, User from oasst_backend.models.db_payload import MessagePayload from oasst_backend.prompt_repository import PromptRepository from oasst_backend.utils.database_utils import db_lang_to_postgres_ts_lang, default_session_factory from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI from oasst_shared.utils import log_timing, utcnow from sqlalchemy import func -from sqlmodel import update +from sqlmodel import and_, or_, update async def useHFApi(text, url, model_name): @@ -73,7 +73,29 @@ def update_search_vectors(batch_size: int) -> None: with default_session_factory() as session: while True: to_update: list[Message] = ( - session.query(Message).filter(Message.search_vector.is_(None)).limit(batch_size).all() + session.query(Message) + .outerjoin( + MessageRevision, + and_( + Message.id == MessageRevision.message_id, + MessageRevision.created_date + == session.query(func.max(MessageRevision.created_date)) + .filter(MessageRevision.message_id == Message.id) + .as_scalar(), + ), + ) + .filter( + or_( + Message.search_vector.is_(None), + MessageRevision.created_date > Message.search_vector_update_date, + and_( + Message.search_vector_update_date.is_(None), + MessageRevision.created_date.isnot(None), + ), + ) + ) + .limit(batch_size) + .all() ) if not to_update: @@ -83,6 +105,7 @@ def update_search_vectors(batch_size: int) -> None: message_payload: MessagePayload = message.payload.payload message_lang: str = db_lang_to_postgres_ts_lang(message.lang) message.search_vector = func.to_tsvector(message_lang, message_payload.text) + message.search_vector_update_date = utcnow() session.commit() except Exception as e: