diff --git a/service/app/core/chat/langchain.py b/service/app/core/chat/langchain.py index 7df80af1..3b2d93b1 100644 --- a/service/app/core/chat/langchain.py +++ b/service/app/core/chat/langchain.py @@ -26,6 +26,7 @@ GeneratedFileHandler, StreamContext, StreamingEventHandler, + ThinkingEventHandler, TokenStreamProcessor, ToolEventHandler, ) @@ -301,7 +302,7 @@ async def _handle_updates_mode(data: Any, ctx: StreamContext) -> AsyncGenerator[ async def _handle_messages_mode(data: Any, ctx: StreamContext) -> AsyncGenerator[StreamingEvent, None]: - """Handle 'messages' mode events (token streaming).""" + """Handle 'messages' mode events (token streaming and thinking content).""" if not isinstance(data, tuple): return @@ -331,7 +332,27 @@ async def _handle_messages_mode(data: Any, ctx: StreamContext) -> AsyncGenerator if node and node not in ("model", "agent"): return - # Extract and emit token + # Check for thinking content first (from reasoning models like Claude, DeepSeek R1, Gemini 3) + thinking_content = ThinkingEventHandler.extract_thinking_content(message_chunk) + + if thinking_content: + # Start thinking if not already + if not ctx.is_thinking: + logger.debug("Emitting thinking_start for stream_id=%s", ctx.stream_id) + ctx.is_thinking = True + yield ThinkingEventHandler.create_thinking_start(ctx.stream_id) + + ctx.thinking_buffer.append(thinking_content) + yield ThinkingEventHandler.create_thinking_chunk(ctx.stream_id, thinking_content) + return + + # If we were thinking but now have regular content, end thinking first + if ctx.is_thinking: + logger.debug("Emitting thinking_end for stream_id=%s", ctx.stream_id) + ctx.is_thinking = False + yield ThinkingEventHandler.create_thinking_end(ctx.stream_id) + + # Extract and emit token for regular streaming token_text = TokenStreamProcessor.extract_token_text(message_chunk) if not token_text: return @@ -347,6 +368,12 @@ async def _handle_messages_mode(data: Any, ctx: StreamContext) -> AsyncGenerator async def _finalize_streaming(ctx: StreamContext) -> AsyncGenerator[StreamingEvent, None]: """Finalize the streaming session.""" + # If still thinking when finalizing, emit thinking_end + if ctx.is_thinking: + logger.debug("Emitting thinking_end (in finalize) for stream_id=%s", ctx.stream_id) + ctx.is_thinking = False + yield ThinkingEventHandler.create_thinking_end(ctx.stream_id) + if ctx.is_streaming: logger.debug( "Emitting streaming_end for stream_id=%s (total tokens: %d)", diff --git a/service/app/core/chat/stream_handlers.py b/service/app/core/chat/stream_handlers.py index 304cf583..44e93762 100644 --- a/service/app/core/chat/stream_handlers.py +++ b/service/app/core/chat/stream_handlers.py @@ -25,6 +25,9 @@ StreamingEndData, StreamingEvent, StreamingStartData, + ThinkingChunkData, + ThinkingEndData, + ThinkingStartData, TokenUsageData, ToolCallRequestData, ToolCallResponseData, @@ -55,6 +58,9 @@ class StreamContext: total_input_tokens: int = 0 total_output_tokens: int = 0 total_tokens: int = 0 + # Thinking/reasoning content state + is_thinking: bool = False + thinking_buffer: list[str] = field(default_factory=list) class ToolEventHandler: @@ -79,7 +85,7 @@ def create_tool_request_event(tool_call: dict[str, Any]) -> StreamingEvent: "status": ToolCallStatus.EXECUTING, "timestamp": asyncio.get_event_loop().time(), } - return {"type": ChatEventType.TOOL_CALL_REQUEST, "data": data} # type: ignore[return-value] + return {"type": ChatEventType.TOOL_CALL_REQUEST, "data": data} @staticmethod def create_tool_response_event( @@ -101,7 +107,7 @@ def create_tool_response_event( "status": status, "result": result, } - return {"type": ChatEventType.TOOL_CALL_RESPONSE, "data": data} # type: ignore[return-value] + return {"type": ChatEventType.TOOL_CALL_RESPONSE, "data": data} class StreamingEventHandler: @@ -111,13 +117,13 @@ class StreamingEventHandler: def create_streaming_start(stream_id: str) -> StreamingEvent: """Create streaming start event.""" data: StreamingStartData = {"id": stream_id} - return {"type": ChatEventType.STREAMING_START, "data": data} # type: ignore[return-value] + return {"type": ChatEventType.STREAMING_START, "data": data} @staticmethod def create_streaming_chunk(stream_id: str, content: str) -> StreamingEvent: """Create streaming chunk event.""" data: StreamingChunkData = {"id": stream_id, "content": content} - return {"type": ChatEventType.STREAMING_CHUNK, "data": data} # type: ignore[return-value] + return {"type": ChatEventType.STREAMING_CHUNK, "data": data} @staticmethod def create_streaming_end(stream_id: str) -> StreamingEvent: @@ -126,7 +132,7 @@ def create_streaming_end(stream_id: str) -> StreamingEvent: "id": stream_id, "created_at": asyncio.get_event_loop().time(), } - return {"type": ChatEventType.STREAMING_END, "data": data} # type: ignore[return-value] + return {"type": ChatEventType.STREAMING_END, "data": data} @staticmethod def create_token_usage_event(input_tokens: int, output_tokens: int, total_tokens: int) -> StreamingEvent: @@ -136,17 +142,102 @@ def create_token_usage_event(input_tokens: int, output_tokens: int, total_tokens "output_tokens": output_tokens, "total_tokens": total_tokens, } - return {"type": ChatEventType.TOKEN_USAGE, "data": data} # type: ignore[return-value] + return {"type": ChatEventType.TOKEN_USAGE, "data": data} @staticmethod def create_processing_event(status: str = ProcessingStatus.PREPARING_REQUEST) -> StreamingEvent: """Create processing status event.""" - return {"type": ChatEventType.PROCESSING, "data": {"status": status}} # type: ignore[return-value] + return {"type": ChatEventType.PROCESSING, "data": {"status": status}} @staticmethod def create_error_event(error: str) -> StreamingEvent: """Create error event.""" - return {"type": ChatEventType.ERROR, "data": {"error": error}} # type: ignore[return-value] + return {"type": ChatEventType.ERROR, "data": {"error": error}} + + +class ThinkingEventHandler: + """Handle thinking/reasoning content streaming events.""" + + @staticmethod + def create_thinking_start(stream_id: str) -> StreamingEvent: + """Create thinking start event.""" + data: ThinkingStartData = {"id": stream_id} + return {"type": ChatEventType.THINKING_START, "data": data} + + @staticmethod + def create_thinking_chunk(stream_id: str, content: str) -> StreamingEvent: + """Create thinking chunk event.""" + data: ThinkingChunkData = {"id": stream_id, "content": content} + return {"type": ChatEventType.THINKING_CHUNK, "data": data} + + @staticmethod + def create_thinking_end(stream_id: str) -> StreamingEvent: + """Create thinking end event.""" + data: ThinkingEndData = {"id": stream_id} + return {"type": ChatEventType.THINKING_END, "data": data} + + @staticmethod + def extract_thinking_content(message_chunk: Any) -> str | None: + """ + Extract thinking/reasoning content from message chunk. + + Checks various provider-specific locations: + - Anthropic Claude: content blocks with type="thinking" + - DeepSeek R1: additional_kwargs.reasoning_content + - Gemini 3: content blocks with type="thought" or response_metadata.reasoning + - Generic: response_metadata.reasoning_content or thinking + + Args: + message_chunk: Message chunk from LLM streaming + + Returns: + Extracted thinking content or None + """ + # Check for DeepSeek/OpenAI style reasoning_content in additional_kwargs + if hasattr(message_chunk, "additional_kwargs"): + additional_kwargs = message_chunk.additional_kwargs + if isinstance(additional_kwargs, dict): + reasoning = additional_kwargs.get("reasoning_content") + if reasoning: + logger.debug("Found thinking in additional_kwargs.reasoning_content") + return reasoning + + # Check for thinking/thought blocks in content (Anthropic, Gemini 3) + if hasattr(message_chunk, "content"): + content = message_chunk.content + if isinstance(content, list): + for block in content: + if isinstance(block, dict): + block_type = block.get("type", "") + # Anthropic Claude uses "thinking" type + if block_type == "thinking": + thinking_text = block.get("thinking", "") + if thinking_text: + logger.debug("Found thinking in content block type='thinking'") + return thinking_text + # Gemini 3 uses "thought" type + elif block_type == "thought": + thought_text = block.get("thought", "") or block.get("text", "") + if thought_text: + logger.debug("Found thinking in content block type='thought'") + return thought_text + + # Check response_metadata for thinking content + if hasattr(message_chunk, "response_metadata"): + metadata = message_chunk.response_metadata + if isinstance(metadata, dict): + # Gemini 3 uses "reasoning" key + thinking = ( + metadata.get("thinking") + or metadata.get("reasoning_content") + or metadata.get("reasoning") + or metadata.get("thoughts") + ) + if thinking: + logger.debug("Found thinking in response_metadata: %s", list(metadata.keys())) + return thinking + + return None class CitationExtractor: @@ -264,7 +355,7 @@ def _deduplicate_citations(citations: list[CitationData]) -> list[CitationData]: def create_citations_event(citations: list[CitationData]) -> StreamingEvent: """Create search citations event.""" data: SearchCitationsData = {"citations": citations} - return {"type": ChatEventType.SEARCH_CITATIONS, "data": data} # type: ignore[return-value] + return {"type": ChatEventType.SEARCH_CITATIONS, "data": data} class GeneratedFileHandler: @@ -397,7 +488,7 @@ async def process_generated_content( def create_generated_files_event(files: list[GeneratedFileInfo]) -> StreamingEvent: """Create generated files event.""" data: GeneratedFilesData = {"files": files} - return {"type": ChatEventType.GENERATED_FILES, "data": data} # type: ignore[return-value] + return {"type": ChatEventType.GENERATED_FILES, "data": data} class TokenStreamProcessor: diff --git a/service/app/core/providers/factory.py b/service/app/core/providers/factory.py index 4802647d..eab6bbf8 100644 --- a/service/app/core/providers/factory.py +++ b/service/app/core/providers/factory.py @@ -104,7 +104,6 @@ def _create_google(self, model: str, credentials: LLMCredentials, runtime_kwargs # Extract google_search_enabled from runtime_kwargs google_search_enabled = runtime_kwargs.pop("google_search_enabled", False) - # Create the base model llm = ChatGoogleGenerativeAI( model=model, google_api_key=credentials["api_key"], diff --git a/service/app/models/message.py b/service/app/models/message.py index e3a8f62e..0a6be71d 100644 --- a/service/app/models/message.py +++ b/service/app/models/message.py @@ -16,6 +16,8 @@ class MessageBase(SQLModel): role: str content: str topic_id: UUID = Field(index=True) + # Thinking/reasoning content from models like Claude, DeepSeek R1, Gemini 3 + thinking_content: str | None = None class Message(MessageBase, table=True): @@ -63,3 +65,4 @@ class MessageReadWithFilesAndCitations(MessageBase): class MessageUpdate(SQLModel): role: str | None = None content: str | None = None + thinking_content: str | None = None diff --git a/service/app/repos/message.py b/service/app/repos/message.py index 4a706e0a..f8350e74 100644 --- a/service/app/repos/message.py +++ b/service/app/repos/message.py @@ -395,6 +395,7 @@ async def get_messages_with_files_and_citations( created_at=message.created_at, attachments=file_reads_with_urls, citations=citations, + thinking_content=message.thinking_content, ) messages_with_files_and_citations.append(message_with_files_and_citations) diff --git a/service/app/schemas/chat_event_types.py b/service/app/schemas/chat_event_types.py index a0ceb2f0..2023d889 100644 --- a/service/app/schemas/chat_event_types.py +++ b/service/app/schemas/chat_event_types.py @@ -150,6 +150,25 @@ class InsufficientBalanceData(TypedDict): action_required: str +class ThinkingStartData(TypedDict): + """Data payload for THINKING_START event.""" + + id: str + + +class ThinkingChunkData(TypedDict): + """Data payload for THINKING_CHUNK event.""" + + id: str + content: str + + +class ThinkingEndData(TypedDict): + """Data payload for THINKING_END event.""" + + id: str + + # ============================================================================= # Full Event Structures (type + data) # ============================================================================= @@ -253,6 +272,27 @@ class InsufficientBalanceEvent(TypedDict): data: InsufficientBalanceData +class ThinkingStartEvent(TypedDict): + """Full event structure for thinking start.""" + + type: Literal[ChatEventType.THINKING_START] + data: ThinkingStartData + + +class ThinkingChunkEvent(TypedDict): + """Full event structure for thinking chunk.""" + + type: Literal[ChatEventType.THINKING_CHUNK] + data: ThinkingChunkData + + +class ThinkingEndEvent(TypedDict): + """Full event structure for thinking end.""" + + type: Literal[ChatEventType.THINKING_END] + data: ThinkingEndData + + # ============================================================================= # Union type for generic event handling # ============================================================================= @@ -273,6 +313,9 @@ class InsufficientBalanceEvent(TypedDict): | MessageSavedEvent | MessageEvent | InsufficientBalanceEvent + | ThinkingStartEvent + | ThinkingChunkEvent + | ThinkingEndEvent ) @@ -294,6 +337,9 @@ class InsufficientBalanceEvent(TypedDict): "MessageSavedData", "MessageData", "InsufficientBalanceData", + "ThinkingStartData", + "ThinkingChunkData", + "ThinkingEndData", # Event types "StreamingStartEvent", "StreamingChunkEvent", @@ -309,6 +355,9 @@ class InsufficientBalanceEvent(TypedDict): "MessageSavedEvent", "MessageEvent", "InsufficientBalanceEvent", + "ThinkingStartEvent", + "ThinkingChunkEvent", + "ThinkingEndEvent", # Union "StreamingEvent", ] diff --git a/service/app/schemas/chat_events.py b/service/app/schemas/chat_events.py index ad5f3adf..5207ee45 100644 --- a/service/app/schemas/chat_events.py +++ b/service/app/schemas/chat_events.py @@ -47,6 +47,11 @@ class ChatEventType(StrEnum): # Balance/billing events INSUFFICIENT_BALANCE = "insufficient_balance" + # Thinking/reasoning content (for models like Claude, DeepSeek R1, OpenAI o1) + THINKING_START = "thinking_start" + THINKING_CHUNK = "thinking_chunk" + THINKING_END = "thinking_end" + class ChatClientEventType(StrEnum): """Client -> Server event types (messages coming from the frontend).""" diff --git a/service/app/tasks/chat.py b/service/app/tasks/chat.py index a5be161f..7daf5677 100644 --- a/service/app/tasks/chat.py +++ b/service/app/tasks/chat.py @@ -145,6 +145,7 @@ async def _process_chat_message_async( ai_message_id = None ai_message_obj: Message | None = None full_content = "" + full_thinking_content = "" # Track thinking content for persistence citations_data: List[CitationData] = [] generated_files_count = 0 @@ -286,6 +287,23 @@ async def _process_chat_message_async( elif stream_event["type"] == ChatEventType.ERROR: await publisher.publish(json.dumps(stream_event)) break + + # Handle thinking events + elif stream_event["type"] == ChatEventType.THINKING_START: + # Create message object if not exists + if not ai_message_obj: + ai_message_create = MessageCreate(role="assistant", content="", topic_id=topic_id) + ai_message_obj = await message_repo.create_message(ai_message_create) + await publisher.publish(json.dumps(stream_event)) + + elif stream_event["type"] == ChatEventType.THINKING_CHUNK: + chunk_content = stream_event["data"].get("content", "") + full_thinking_content += chunk_content + await publisher.publish(json.dumps(stream_event)) + + elif stream_event["type"] == ChatEventType.THINKING_END: + await publisher.publish(json.dumps(stream_event)) + else: await publisher.publish(json.dumps(stream_event)) @@ -296,6 +314,11 @@ async def _process_chat_message_async( ai_message_obj.content = full_content db.add(ai_message_obj) + # Update thinking content + if full_thinking_content: + ai_message_obj.thinking_content = full_thinking_content + db.add(ai_message_obj) + # Save citations if citations_data: try: diff --git a/service/migrations/versions/03630403f8c2_add_thinking_content_to_message.py b/service/migrations/versions/03630403f8c2_add_thinking_content_to_message.py new file mode 100644 index 00000000..6a9b0356 --- /dev/null +++ b/service/migrations/versions/03630403f8c2_add_thinking_content_to_message.py @@ -0,0 +1,34 @@ +"""add_thinking_content_to_message + +Revision ID: 03630403f8c2 +Revises: 70ee7fb4d40b +Create Date: 2026-01-04 15:46:02.555769 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision: str = "03630403f8c2" +down_revision: Union[str, Sequence[str], None] = "70ee7fb4d40b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("message", sa.Column("thinking_content", sqlmodel.sql.sqltypes.AutoString(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("message", "thinking_content") + # ### end Alembic commands ### diff --git a/service/tests/factories/file.py b/service/tests/factories/file.py new file mode 100644 index 00000000..1b586d15 --- /dev/null +++ b/service/tests/factories/file.py @@ -0,0 +1,24 @@ +from polyfactory.factories.pydantic_factory import ModelFactory + +from app.models.file import File, FileCreate + + +class FileFactory(ModelFactory[File]): + """Factory for File model.""" + + __model__ = File + + +class FileCreateFactory(ModelFactory[FileCreate]): + """Factory for FileCreate schema.""" + + __model__ = FileCreate + + scope = "private" + category = "documents" + is_deleted = False + status = "pending" + message_id = None + folder_id = None + metainfo = None + file_hash = None diff --git a/service/tests/factories/message.py b/service/tests/factories/message.py new file mode 100644 index 00000000..84613278 --- /dev/null +++ b/service/tests/factories/message.py @@ -0,0 +1,22 @@ +from uuid import uuid4 + +from polyfactory import Use +from polyfactory.factories.pydantic_factory import ModelFactory + +from app.models.message import Message, MessageCreate + + +class MessageFactory(ModelFactory[Message]): + """Factory for Message model.""" + + __model__ = Message + + +class MessageCreateFactory(ModelFactory[MessageCreate]): + """Factory for MessageCreate schema.""" + + __model__ = MessageCreate + + role = "user" + topic_id = Use(uuid4) + thinking_content = None diff --git a/service/tests/factories/session.py b/service/tests/factories/session.py new file mode 100644 index 00000000..08b5f30d --- /dev/null +++ b/service/tests/factories/session.py @@ -0,0 +1,21 @@ +from polyfactory.factories.pydantic_factory import ModelFactory + +from app.models.sessions import Session, SessionCreate + + +class SessionFactory(ModelFactory[Session]): + """Factory for Session model.""" + + __model__ = Session + + +class SessionCreateFactory(ModelFactory[SessionCreate]): + """Factory for SessionCreate schema.""" + + __model__ = SessionCreate + + is_active = True + agent_id = None + provider_id = None + model = None + google_search_enabled = False diff --git a/service/tests/factories/topic.py b/service/tests/factories/topic.py new file mode 100644 index 00000000..3a9a9769 --- /dev/null +++ b/service/tests/factories/topic.py @@ -0,0 +1,21 @@ +from uuid import uuid4 + +from polyfactory import Use +from polyfactory.factories.pydantic_factory import ModelFactory + +from app.models.topic import Topic, TopicCreate + + +class TopicFactory(ModelFactory[Topic]): + """Factory for Topic model.""" + + __model__ = Topic + + +class TopicCreateFactory(ModelFactory[TopicCreate]): + """Factory for TopicCreate schema.""" + + __model__ = TopicCreate + + is_active = True + session_id = Use(uuid4) diff --git a/service/tests/integration/test_repo/test_file_repo.py b/service/tests/integration/test_repo/test_file_repo.py new file mode 100644 index 00000000..1c145032 --- /dev/null +++ b/service/tests/integration/test_repo/test_file_repo.py @@ -0,0 +1,301 @@ +from uuid import uuid4 + +import pytest +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.repos.file import FileRepository +from tests.factories.file import FileCreateFactory + + +@pytest.mark.integration +class TestFileRepository: + """Integration tests for FileRepository.""" + + @pytest.fixture + def file_repo(self, db_session: AsyncSession) -> FileRepository: + return FileRepository(db_session) + + def _make_unique_storage_key(self, prefix: str = "test") -> str: + """Generate a unique storage key for tests.""" + return f"{prefix}/{uuid4().hex[:8]}/file.txt" + + async def test_create_and_get_file(self, file_repo: FileRepository): + """Test creating a file and retrieving it.""" + user_id = "test-user-file-create" + storage_key = self._make_unique_storage_key() + file_create = FileCreateFactory.build( + user_id=user_id, + storage_key=storage_key, + original_filename="test.txt", + content_type="text/plain", + file_size=1024, + ) + + # Create + created_file = await file_repo.create_file(file_create) + assert created_file.id is not None + assert created_file.user_id == user_id + assert created_file.original_filename == "test.txt" + assert created_file.storage_key == storage_key + + # Get by ID + fetched_file = await file_repo.get_file_by_id(created_file.id) + assert fetched_file is not None + assert fetched_file.id == created_file.id + + async def test_get_file_by_storage_key(self, file_repo: FileRepository): + """Test retrieving file by storage key.""" + user_id = "test-user-file-key" + storage_key = self._make_unique_storage_key("key-test") + file_create = FileCreateFactory.build( + user_id=user_id, + storage_key=storage_key, + ) + + await file_repo.create_file(file_create) + + fetched = await file_repo.get_file_by_storage_key(storage_key) + assert fetched is not None + assert fetched.storage_key == storage_key + + # Non-existent key + not_found = await file_repo.get_file_by_storage_key("non/existent/key") + assert not_found is None + + async def test_get_files_by_user(self, file_repo: FileRepository): + """Test listing files for a user.""" + user_id = "test-user-file-list" + + # Create 3 files + for i in range(3): + await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key(f"list-{i}"), + ) + ) + + # Create file for another user + await file_repo.create_file( + FileCreateFactory.build( + user_id="other-user", + storage_key=self._make_unique_storage_key("other"), + ) + ) + + files = await file_repo.get_files_by_user(user_id) + assert len(files) == 3 + for f in files: + assert f.user_id == user_id + + async def test_get_files_by_user_with_scope_filter(self, file_repo: FileRepository): + """Test filtering files by scope.""" + user_id = "test-user-file-scope" + + await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key("public"), + scope="public", + ) + ) + await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key("private"), + scope="private", + ) + ) + + public_files = await file_repo.get_files_by_user(user_id, scope="public") + assert len(public_files) == 1 + assert public_files[0].scope == "public" + + async def test_update_file(self, file_repo: FileRepository): + """Test updating a file.""" + user_id = "test-user-file-update" + created = await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key("update"), + original_filename="old_name.txt", + ) + ) + + from app.models.file import FileUpdate + + update_data = FileUpdate(original_filename="new_name.txt") + updated = await file_repo.update_file(created.id, update_data) + + assert updated is not None + assert updated.original_filename == "new_name.txt" + + # Verify persistence + fetched = await file_repo.get_file_by_id(created.id) + assert fetched is not None + assert fetched.original_filename == "new_name.txt" + + async def test_soft_delete_and_restore_file(self, file_repo: FileRepository): + """Test soft delete and restore functionality.""" + user_id = "test-user-file-soft-delete" + created = await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key("soft-del"), + ) + ) + + # Soft delete + success = await file_repo.soft_delete_file(created.id) + assert success is True + + fetched = await file_repo.get_file_by_id(created.id) + assert fetched is not None + assert fetched.is_deleted is True + assert fetched.deleted_at is not None + + # Restore + restored = await file_repo.restore_file(created.id) + assert restored is True + + fetched = await file_repo.get_file_by_id(created.id) + assert fetched is not None + assert fetched.is_deleted is False + assert fetched.deleted_at is None + + async def test_hard_delete_file(self, file_repo: FileRepository): + """Test permanent file deletion.""" + user_id = "test-user-file-hard-delete" + created = await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key("hard-del"), + ) + ) + + success = await file_repo.hard_delete_file(created.id) + assert success is True + + fetched = await file_repo.get_file_by_id(created.id) + assert fetched is None + + async def test_get_files_by_hash(self, file_repo: FileRepository): + """Test deduplication lookup by file hash.""" + user_id = "test-user-file-hash" + file_hash = "abc123def456" + + # Create 2 files with same hash + await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key("hash1"), + file_hash=file_hash, + ) + ) + await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key("hash2"), + file_hash=file_hash, + ) + ) + + files = await file_repo.get_files_by_hash(file_hash, user_id) + assert len(files) == 2 + + async def test_get_total_size_by_user(self, file_repo: FileRepository): + """Test calculating total file size for a user.""" + user_id = "test-user-file-size" + + await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key("size1"), + file_size=1000, + ) + ) + await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key("size2"), + file_size=2000, + ) + ) + + total_size = await file_repo.get_total_size_by_user(user_id) + assert total_size == 3000 + + async def test_get_file_count_by_user(self, file_repo: FileRepository): + """Test counting files for a user.""" + user_id = "test-user-file-count" + + for i in range(4): + await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key(f"count-{i}"), + ) + ) + + count = await file_repo.get_file_count_by_user(user_id) + assert count == 4 + + async def test_bulk_soft_delete_by_user(self, file_repo: FileRepository): + """Test bulk soft delete with user validation.""" + user_id = "test-user-file-bulk-del" + + file1 = await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key("bulk1"), + ) + ) + file2 = await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key("bulk2"), + ) + ) + file3 = await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key("bulk3"), + ) + ) + + count = await file_repo.bulk_soft_delete_by_user(user_id, [file1.id, file2.id]) + assert count == 2 + + # file3 should not be deleted + fetched3 = await file_repo.get_file_by_id(file3.id) + assert fetched3 is not None + assert fetched3.is_deleted is False + + async def test_update_files_message_id(self, file_repo: FileRepository): + """Test linking files to a message.""" + user_id = "test-user-file-msg-link" + message_id = uuid4() + + file1 = await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key("msg1"), + status="pending", + ) + ) + file2 = await file_repo.create_file( + FileCreateFactory.build( + user_id=user_id, + storage_key=self._make_unique_storage_key("msg2"), + status="pending", + ) + ) + + count = await file_repo.update_files_message_id([file1.id, file2.id], message_id, user_id) + assert count == 2 + + # Verify files are linked and confirmed + fetched1 = await file_repo.get_file_by_id(file1.id) + assert fetched1 is not None + assert fetched1.message_id == message_id + assert fetched1.status == "confirmed" diff --git a/service/tests/integration/test_repo/test_message_repo.py b/service/tests/integration/test_repo/test_message_repo.py new file mode 100644 index 00000000..ee08161e --- /dev/null +++ b/service/tests/integration/test_repo/test_message_repo.py @@ -0,0 +1,146 @@ +import pytest +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.models.topic import Topic +from app.repos.message import MessageRepository +from app.repos.session import SessionRepository +from app.repos.topic import TopicRepository +from tests.factories.message import MessageCreateFactory +from tests.factories.session import SessionCreateFactory +from tests.factories.topic import TopicCreateFactory + + +@pytest.mark.integration +class TestMessageRepository: + """Integration tests for MessageRepository.""" + + @pytest.fixture + def message_repo(self, db_session: AsyncSession) -> MessageRepository: + return MessageRepository(db_session) + + @pytest.fixture + def session_repo(self, db_session: AsyncSession) -> SessionRepository: + return SessionRepository(db_session) + + @pytest.fixture + def topic_repo(self, db_session: AsyncSession) -> TopicRepository: + return TopicRepository(db_session) + + @pytest.fixture + async def test_topic(self, session_repo: SessionRepository, topic_repo: TopicRepository): + """Create a test session and topic for message tests.""" + session = await session_repo.create_session(SessionCreateFactory.build(), "test-user-message") + topic = await topic_repo.create_topic(TopicCreateFactory.build(session_id=session.id)) + return topic + + async def test_create_and_get_message(self, message_repo: MessageRepository, test_topic: Topic): + """Test creating a message and retrieving it.""" + message_create = MessageCreateFactory.build(topic_id=test_topic.id, role="user", content="Hello!") + + # Create + created_message = await message_repo.create_message(message_create) + assert created_message.id is not None + assert created_message.content == "Hello!" + assert created_message.role == "user" + assert created_message.topic_id == test_topic.id + + # Get by ID + fetched_message = await message_repo.get_message_by_id(created_message.id) + assert fetched_message is not None + assert fetched_message.id == created_message.id + + async def test_get_messages_by_topic(self, message_repo: MessageRepository, test_topic: Topic): + """Test listing messages for a topic.""" + # Create 3 messages + await message_repo.create_message( + MessageCreateFactory.build(topic_id=test_topic.id, role="user", content="First") + ) + await message_repo.create_message( + MessageCreateFactory.build(topic_id=test_topic.id, role="assistant", content="Second") + ) + await message_repo.create_message( + MessageCreateFactory.build(topic_id=test_topic.id, role="user", content="Third") + ) + + messages = await message_repo.get_messages_by_topic(test_topic.id) + assert len(messages) == 3 + for msg in messages: + assert msg.topic_id == test_topic.id + + async def test_get_messages_by_topic_ordered(self, message_repo: MessageRepository, test_topic: Topic): + """Test messages are ordered by created_at ascending.""" + msg1 = await message_repo.create_message(MessageCreateFactory.build(topic_id=test_topic.id, content="First")) + msg2 = await message_repo.create_message(MessageCreateFactory.build(topic_id=test_topic.id, content="Second")) + + messages = await message_repo.get_messages_by_topic(test_topic.id, order_by_created=True) + assert len(messages) == 2 + assert messages[0].id == msg1.id + assert messages[1].id == msg2.id + + async def test_get_messages_by_topic_with_limit(self, message_repo: MessageRepository, test_topic: Topic): + """Test limiting number of messages returned.""" + for i in range(5): + await message_repo.create_message( + MessageCreateFactory.build(topic_id=test_topic.id, content=f"Message {i}") + ) + + messages = await message_repo.get_messages_by_topic(test_topic.id, limit=3) + assert len(messages) == 3 + + async def test_delete_message(self, message_repo: MessageRepository, test_topic: Topic): + """Test deleting a single message.""" + created = await message_repo.create_message(MessageCreateFactory.build(topic_id=test_topic.id)) + + # Delete without cascade (no files to delete) + success = await message_repo.delete_message(created.id, cascade_files=False) + assert success is True + + fetched = await message_repo.get_message_by_id(created.id) + assert fetched is None + + async def test_delete_messages_by_topic(self, message_repo: MessageRepository, test_topic: Topic): + """Test deleting all messages for a topic.""" + # Create 3 messages + for _ in range(3): + await message_repo.create_message(MessageCreateFactory.build(topic_id=test_topic.id)) + + count = await message_repo.delete_messages_by_topic(test_topic.id, cascade_files=False) + assert count == 3 + + messages = await message_repo.get_messages_by_topic(test_topic.id) + assert len(messages) == 0 + + async def test_bulk_delete_messages(self, message_repo: MessageRepository, test_topic: Topic): + """Test deleting multiple messages by ID.""" + msg1 = await message_repo.create_message(MessageCreateFactory.build(topic_id=test_topic.id)) + msg2 = await message_repo.create_message(MessageCreateFactory.build(topic_id=test_topic.id)) + msg3 = await message_repo.create_message(MessageCreateFactory.build(topic_id=test_topic.id)) + + count = await message_repo.bulk_delete_messages([msg1.id, msg2.id], cascade_files=False) + assert count == 2 + + # msg3 should still exist + assert await message_repo.get_message_by_id(msg3.id) is not None + assert await message_repo.get_message_by_id(msg1.id) is None + + @pytest.mark.parametrize("role", ["user", "assistant", "system", "tool"]) + async def test_create_message_with_different_roles( + self, message_repo: MessageRepository, test_topic: Topic, role: str + ): + """Test creating messages with different roles.""" + created = await message_repo.create_message( + MessageCreateFactory.build(topic_id=test_topic.id, role=role, content="Test") + ) + assert created.role == role + + async def test_create_message_with_thinking_content(self, message_repo: MessageRepository, test_topic: Topic): + """Test creating a message with thinking content (AI reasoning).""" + created = await message_repo.create_message( + MessageCreateFactory.build( + topic_id=test_topic.id, + role="assistant", + content="Final answer", + thinking_content="Let me think about this...", + ) + ) + assert created.thinking_content == "Let me think about this..." diff --git a/service/tests/integration/test_repo/test_session_repo.py b/service/tests/integration/test_repo/test_session_repo.py new file mode 100644 index 00000000..eb5769a3 --- /dev/null +++ b/service/tests/integration/test_repo/test_session_repo.py @@ -0,0 +1,130 @@ +import pytest +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.repos.session import SessionRepository +from tests.factories.session import SessionCreateFactory + + +@pytest.mark.integration +class TestSessionRepository: + """Integration tests for SessionRepository.""" + + @pytest.fixture + def session_repo(self, db_session: AsyncSession) -> SessionRepository: + return SessionRepository(db_session) + + async def test_create_and_get_session(self, session_repo: SessionRepository): + """Test creating a session and retrieving it.""" + user_id = "test-user-session-create" + session_create = SessionCreateFactory.build() + + # Create + created_session = await session_repo.create_session(session_create, user_id) + assert created_session.id is not None + assert created_session.name == session_create.name + assert created_session.user_id == user_id + + # Get by ID + fetched_session = await session_repo.get_session_by_id(created_session.id) + assert fetched_session is not None + assert fetched_session.id == created_session.id + assert fetched_session.name == created_session.name + + async def test_get_sessions_by_user(self, session_repo: SessionRepository): + """Test listing sessions for a user.""" + user_id = "test-user-session-list" + + # Create 2 sessions for the user + await session_repo.create_session(SessionCreateFactory.build(), user_id) + await session_repo.create_session(SessionCreateFactory.build(), user_id) + + # Create session for another user + await session_repo.create_session(SessionCreateFactory.build(), "other-user") + + sessions = await session_repo.get_sessions_by_user(user_id) + assert len(sessions) == 2 + for session in sessions: + assert session.user_id == user_id + + async def test_get_session_by_user_and_agent(self, session_repo: SessionRepository): + """Test fetching session by user and agent combination.""" + user_id = "test-user-session-agent" + session_create = SessionCreateFactory.build(agent_id=None) + + await session_repo.create_session(session_create, user_id) + + # Find session with no agent + found = await session_repo.get_session_by_user_and_agent(user_id, None) + assert found is not None + assert found.user_id == user_id + assert found.agent_id is None + + async def test_update_session(self, session_repo: SessionRepository): + """Test updating a session.""" + user_id = "test-user-session-update" + created = await session_repo.create_session(SessionCreateFactory.build(), user_id) + + from app.models.sessions import SessionUpdate + + update_data = SessionUpdate(name="Updated Session Name", is_active=False) + updated = await session_repo.update_session(created.id, update_data) + + assert updated is not None + assert updated.name == "Updated Session Name" + assert updated.is_active is False + + # Verify persistence + fetched = await session_repo.get_session_by_id(created.id) + assert fetched is not None + assert fetched.name == "Updated Session Name" + + async def test_delete_session(self, session_repo: SessionRepository): + """Test deleting a session.""" + user_id = "test-user-session-delete" + created = await session_repo.create_session(SessionCreateFactory.build(), user_id) + + success = await session_repo.delete_session(created.id) + assert success is True + + fetched = await session_repo.get_session_by_id(created.id) + assert fetched is None + + async def test_delete_session_not_found(self, session_repo: SessionRepository): + """Test deleting a non-existent session.""" + from uuid import uuid4 + + success = await session_repo.delete_session(uuid4()) + assert success is False + + async def test_get_sessions_ordered_by_activity(self, session_repo: SessionRepository, db_session: AsyncSession): + """Test fetching sessions ordered by recent topic activity.""" + import asyncio + + from app.repos.topic import TopicRepository + from tests.factories.topic import TopicCreateFactory + + user_id = "test-user-session-ordered" + topic_repo = TopicRepository(db_session) + + # Create 3 sessions + session1 = await session_repo.create_session(SessionCreateFactory.build(name="Session 1"), user_id) + session2 = await session_repo.create_session(SessionCreateFactory.build(name="Session 2"), user_id) + session3 = await session_repo.create_session(SessionCreateFactory.build(name="Session 3"), user_id) + + topic1 = await topic_repo.create_topic(TopicCreateFactory.build(session_id=session1.id, name="Topic 1")) + await asyncio.sleep(0.01) + + _topic2 = await topic_repo.create_topic(TopicCreateFactory.build(session_id=session2.id, name="Topic 2")) + await asyncio.sleep(0.01) + + _topic3 = await topic_repo.create_topic(TopicCreateFactory.build(session_id=session3.id, name="Topic 3")) + + await asyncio.sleep(0.01) + await topic_repo.update_topic_timestamp(topic1.id) + + sessions = await session_repo.get_sessions_by_user_ordered_by_activity(user_id) + + assert len(sessions) == 3 + assert sessions[0].id == session1.id + assert sessions[1].id == session3.id + assert sessions[2].id == session2.id diff --git a/service/tests/integration/test_repo/test_topic_repo.py b/service/tests/integration/test_repo/test_topic_repo.py new file mode 100644 index 00000000..5ab316af --- /dev/null +++ b/service/tests/integration/test_repo/test_topic_repo.py @@ -0,0 +1,136 @@ +import pytest +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.models.sessions import Session +from app.repos.session import SessionRepository +from app.repos.topic import TopicRepository +from tests.factories.session import SessionCreateFactory +from tests.factories.topic import TopicCreateFactory + + +@pytest.mark.integration +class TestTopicRepository: + """Integration tests for TopicRepository.""" + + @pytest.fixture + def topic_repo(self, db_session: AsyncSession) -> TopicRepository: + return TopicRepository(db_session) + + @pytest.fixture + def session_repo(self, db_session: AsyncSession) -> SessionRepository: + return SessionRepository(db_session) + + @pytest.fixture + async def test_session(self, session_repo: SessionRepository): + """Create a test session for topic tests.""" + return await session_repo.create_session(SessionCreateFactory.build(), "test-user-topic") + + async def test_create_and_get_topic(self, topic_repo: TopicRepository, test_session: Session): + """Test creating a topic and retrieving it.""" + topic_create = TopicCreateFactory.build(session_id=test_session.id) + + # Create + created_topic = await topic_repo.create_topic(topic_create) + assert created_topic.id is not None + assert created_topic.name == topic_create.name + assert created_topic.session_id == test_session.id + + # Get by ID + fetched_topic = await topic_repo.get_topic_by_id(created_topic.id) + assert fetched_topic is not None + assert fetched_topic.id == created_topic.id + + async def test_get_topics_by_session(self, topic_repo: TopicRepository, test_session: Session): + """Test listing topics for a session.""" + # Create 2 topics + await topic_repo.create_topic(TopicCreateFactory.build(session_id=test_session.id)) + await topic_repo.create_topic(TopicCreateFactory.build(session_id=test_session.id)) + + topics = await topic_repo.get_topics_by_session(test_session.id) + assert len(topics) == 2 + for topic in topics: + assert topic.session_id == test_session.id + + async def test_get_topics_by_session_ordered(self, topic_repo: TopicRepository, test_session: Session): + """Test listing topics ordered by updated_at.""" + await topic_repo.create_topic(TopicCreateFactory.build(session_id=test_session.id, name="Topic 1")) + await topic_repo.create_topic(TopicCreateFactory.build(session_id=test_session.id, name="Topic 2")) + + topics = await topic_repo.get_topics_by_session(test_session.id, order_by_updated=True) + assert len(topics) == 2 + # Most recently updated should be first (descending order) + assert topics[0].name == "Topic 2" + + async def test_update_topic(self, topic_repo: TopicRepository, test_session: Session): + """Test updating a topic.""" + created = await topic_repo.create_topic(TopicCreateFactory.build(session_id=test_session.id)) + + from app.models.topic import TopicUpdate + + update_data = TopicUpdate(name="Updated Topic Name", is_active=False) + updated = await topic_repo.update_topic(created.id, update_data) + + assert updated is not None + assert updated.name == "Updated Topic Name" + assert updated.is_active is False + + # Verify persistence + fetched = await topic_repo.get_topic_by_id(created.id) + assert fetched is not None + assert fetched.name == "Updated Topic Name" + + async def test_delete_topic(self, topic_repo: TopicRepository, test_session: Session): + """Test deleting a topic.""" + created = await topic_repo.create_topic(TopicCreateFactory.build(session_id=test_session.id)) + + success = await topic_repo.delete_topic(created.id) + assert success is True + + fetched = await topic_repo.get_topic_by_id(created.id) + assert fetched is None + + async def test_delete_topic_not_found(self, topic_repo: TopicRepository): + """Test deleting a non-existent topic.""" + from uuid import uuid4 + + success = await topic_repo.delete_topic(uuid4()) + assert success is False + + async def test_bulk_delete_topics(self, topic_repo: TopicRepository, test_session: Session): + """Test deleting multiple topics at once.""" + topic1 = await topic_repo.create_topic(TopicCreateFactory.build(session_id=test_session.id)) + topic2 = await topic_repo.create_topic(TopicCreateFactory.build(session_id=test_session.id)) + topic3 = await topic_repo.create_topic(TopicCreateFactory.build(session_id=test_session.id)) + + count = await topic_repo.bulk_delete_topics([topic1.id, topic2.id]) + assert count == 2 + + # topic3 should still exist + assert await topic_repo.get_topic_by_id(topic3.id) is not None + assert await topic_repo.get_topic_by_id(topic1.id) is None + assert await topic_repo.get_topic_by_id(topic2.id) is None + + async def test_update_topic_timestamp(self, topic_repo: TopicRepository, test_session: Session): + """Test updating a topic's timestamp.""" + created = await topic_repo.create_topic(TopicCreateFactory.build(session_id=test_session.id)) + original_updated_at = created.updated_at + + # Small delay to ensure timestamp difference + import asyncio + + await asyncio.sleep(0.01) + + updated = await topic_repo.update_topic_timestamp(created.id) + assert updated is not None + # Verify the updated_at was changed (comparing without timezone awareness) + assert updated.updated_at is not None + # The timestamp should have been updated (either same or later) + assert updated.updated_at.replace(tzinfo=None) >= original_updated_at.replace(tzinfo=None) + + async def test_get_topic_with_details(self, topic_repo: TopicRepository, test_session: Session): + """Test get_topic_with_details (alias for get_topic_by_id in no-FK architecture).""" + created = await topic_repo.create_topic(TopicCreateFactory.build(session_id=test_session.id)) + + fetched = await topic_repo.get_topic_with_details(created.id) + assert fetched is not None + assert fetched.id == created.id diff --git a/service/tests/unit/test_core/test_thinking_events.py b/service/tests/unit/test_core/test_thinking_events.py new file mode 100644 index 00000000..a4b5bb51 --- /dev/null +++ b/service/tests/unit/test_core/test_thinking_events.py @@ -0,0 +1,146 @@ +""" +Unit tests for thinking event handling. + +Tests the ThinkingEventHandler class and thinking content extraction +from various provider formats (Anthropic, DeepSeek, etc.). +""" + +from app.core.chat.stream_handlers import ThinkingEventHandler +from app.schemas.chat_events import ChatEventType + + +class MockMessageChunk: + """Mock message chunk for testing thinking content extraction.""" + + def __init__( + self, + content: str | list = "", + additional_kwargs: dict | None = None, + response_metadata: dict | None = None, + ): + self.content = content + self.additional_kwargs = additional_kwargs or {} + self.response_metadata = response_metadata or {} + + +class TestThinkingEventHandler: + """Tests for ThinkingEventHandler event creation methods.""" + + def test_create_thinking_start_event(self) -> None: + """Verify thinking_start event has correct structure.""" + event = ThinkingEventHandler.create_thinking_start("stream_123") + + assert event["type"] == ChatEventType.THINKING_START + assert event["data"]["id"] == "stream_123" + + def test_create_thinking_chunk_event(self) -> None: + """Verify thinking_chunk event has correct structure.""" + event = ThinkingEventHandler.create_thinking_chunk("stream_123", "Let me think...") + + assert event["type"] == ChatEventType.THINKING_CHUNK + assert event["data"]["id"] == "stream_123" + assert event["data"]["content"] == "Let me think..." + + def test_create_thinking_end_event(self) -> None: + """Verify thinking_end event has correct structure.""" + event = ThinkingEventHandler.create_thinking_end("stream_123") + + assert event["type"] == ChatEventType.THINKING_END + assert event["data"]["id"] == "stream_123" + + +class TestExtractThinkingContent: + """Tests for ThinkingEventHandler.extract_thinking_content method.""" + + def test_extract_deepseek_reasoning_content(self) -> None: + """Extract thinking from DeepSeek R1 style additional_kwargs.reasoning_content.""" + chunk = MockMessageChunk( + content="", + additional_kwargs={"reasoning_content": "Step 1: Analyze the problem..."}, + ) + + result = ThinkingEventHandler.extract_thinking_content(chunk) + + assert result == "Step 1: Analyze the problem..." + + def test_extract_anthropic_thinking_block(self) -> None: + """Extract thinking from Anthropic Claude style content blocks.""" + chunk = MockMessageChunk( + content=[ + {"type": "thinking", "thinking": "Let me reason through this..."}, + {"type": "text", "text": "The answer is 42."}, + ] + ) + + result = ThinkingEventHandler.extract_thinking_content(chunk) + + assert result == "Let me reason through this..." + + def test_extract_thinking_from_response_metadata(self) -> None: + """Extract thinking from response_metadata.thinking.""" + chunk = MockMessageChunk( + content="", + response_metadata={"thinking": "I need to consider all factors..."}, + ) + + result = ThinkingEventHandler.extract_thinking_content(chunk) + + assert result == "I need to consider all factors..." + + def test_extract_reasoning_content_from_response_metadata(self) -> None: + """Extract from response_metadata.reasoning_content (alternative key).""" + chunk = MockMessageChunk( + content="", + response_metadata={"reasoning_content": "Analyzing the data..."}, + ) + + result = ThinkingEventHandler.extract_thinking_content(chunk) + + assert result == "Analyzing the data..." + + def test_no_thinking_content_returns_none(self) -> None: + """Return None when no thinking content is present.""" + chunk = MockMessageChunk( + content="Hello, world!", + additional_kwargs={}, + response_metadata={}, + ) + + result = ThinkingEventHandler.extract_thinking_content(chunk) + + assert result is None + + def test_empty_thinking_returns_none(self) -> None: + """Return None for empty reasoning_content string.""" + chunk = MockMessageChunk( + content="", + additional_kwargs={"reasoning_content": ""}, + ) + + result = ThinkingEventHandler.extract_thinking_content(chunk) + + assert result is None + + def test_deepseek_takes_priority_over_response_metadata(self) -> None: + """DeepSeek additional_kwargs should be checked first.""" + chunk = MockMessageChunk( + content="", + additional_kwargs={"reasoning_content": "From additional_kwargs"}, + response_metadata={"thinking": "From response_metadata"}, + ) + + result = ThinkingEventHandler.extract_thinking_content(chunk) + + assert result == "From additional_kwargs" + + def test_handles_missing_attributes_gracefully(self) -> None: + """Handle objects without expected attributes.""" + + class MinimalChunk: + pass + + chunk = MinimalChunk() + + result = ThinkingEventHandler.extract_thinking_content(chunk) + + assert result is None diff --git a/web/package.json b/web/package.json index f8fdd58b..8702f7a8 100644 --- a/web/package.json +++ b/web/package.json @@ -152,14 +152,12 @@ "react-dnd": "^16.0.1", "react-dnd-html5-backend": "^16.0.1", "react-i18next": "^16.5.0", - "react-image-crop": "^11.0.10", "react-lite-youtube-embed": "^3.3.3", "react-markdown": "^10.1.0", "react-player": "3.3.1", "react-textarea-autosize": "^8.5.9", "react-tweet": "^3.3.0", "react-use-measure": "^2.1.7", - "react-use-websocket": "^4.13.0", "rehype-katex": "^7.0.1", "remark-gfm": "^4.0.1", "remark-math": "^6.0.0", diff --git a/web/src/app/marketplace/AgentMarketplaceDetail.tsx b/web/src/app/marketplace/AgentMarketplaceDetail.tsx index 6521a2e2..08f83baa 100644 --- a/web/src/app/marketplace/AgentMarketplaceDetail.tsx +++ b/web/src/app/marketplace/AgentMarketplaceDetail.tsx @@ -6,6 +6,7 @@ import { useMarketplaceRequirements, useToggleLike, } from "@/hooks/useMarketplace"; +import Markdown from "@/lib/Markdown"; import { useIsMarketplaceOwner } from "@/utils/marketplace"; import { ArrowLeftIcon, @@ -20,8 +21,7 @@ import { } from "@heroicons/react/24/outline"; import { HeartIcon as HeartSolidIcon } from "@heroicons/react/24/solid"; import { useState } from "react"; -import ReactMarkdown from "react-markdown"; -import remarkGfm from "remark-gfm"; +import { useTranslation } from "react-i18next"; interface AgentMarketplaceDetailProps { marketplaceId: string; @@ -39,6 +39,7 @@ export default function AgentMarketplaceDetail({ onBack, onManage, }: AgentMarketplaceDetailProps) { + const { t } = useTranslation(); const [showForkModal, setShowForkModal] = useState(false); const [activeTab, setActiveTab] = useState< "readme" | "config" | "requirements" @@ -89,7 +90,7 @@ export default function AgentMarketplaceDetail({ {/* Loading Icon */}

- Loading agent details... + {t("marketplace.detail.loading")}

@@ -103,9 +104,7 @@ export default function AgentMarketplaceDetail({
-
- Failed to load agent details. Please try again. -
+
{t("marketplace.detail.error")}
@@ -113,7 +112,7 @@ export default function AgentMarketplaceDetail({ } return ( -
+
{/* Header with back button */}
@@ -122,14 +121,14 @@ export default function AgentMarketplaceDetail({ className="group mb-4 flex items-center gap-2 rounded-lg border border-neutral-200 bg-white px-4 py-2 text-sm font-medium text-neutral-700 shadow-sm transition-all hover:border-neutral-300 hover:shadow dark:border-neutral-800 dark:bg-neutral-900 dark:text-neutral-300 dark:hover:border-neutral-700" > - Back to Marketplace + {t("marketplace.detail.back")}
{/* Main Content */}
{/* Left Column - Agent Info */} -
+
{/* Agent Header */}
{/* Gradient background */} @@ -153,13 +152,14 @@ export default function AgentMarketplaceDetail({ {listing.name}

- Published by{" "} + {t("marketplace.detail.publishedBy")}{" "} {listing.user_id.split("@")[0] || listing.user_id}

- {listing.description || "No description provided"} + {listing.description || + t("marketplace.detail.noDescription")}

{listing.tags.map((tag, index) => ( @@ -185,7 +185,9 @@ export default function AgentMarketplaceDetail({
{listing.likes_count}
-
Likes
+
+ {t("marketplace.detail.stats.likes")} +
@@ -208,7 +210,9 @@ export default function AgentMarketplaceDetail({
{listing.forks_count}
-
Forks
+
+ {t("marketplace.detail.stats.forks")} +
@@ -219,7 +223,9 @@ export default function AgentMarketplaceDetail({
{listing.views_count}
-
Views
+
+ {t("marketplace.detail.stats.views")} +
@@ -229,39 +235,39 @@ export default function AgentMarketplaceDetail({ {/* Tabbed Content Section */}
{/* Tab Bar */} -
+
@@ -269,21 +275,22 @@ export default function AgentMarketplaceDetail({
{/* README Tab */} {activeTab === "readme" && ( -
+
{listing.readme ? ( - - {listing.readme} - + ) : (
-

No README provided for this agent.

+

{t("marketplace.detail.readme.empty")}

{isOwner && onManage && ( )}
@@ -309,7 +316,7 @@ export default function AgentMarketplaceDetail({ {listing.snapshot.configuration.model && (

- Model + {t("marketplace.detail.config.model")}

{listing.snapshot.configuration.model} @@ -321,7 +328,7 @@ export default function AgentMarketplaceDetail({ {listing.snapshot.configuration.prompt && (

- System Prompt + {t("marketplace.detail.config.systemPrompt")}

@@ -336,8 +343,10 @@ export default function AgentMarketplaceDetail({
                           listing.snapshot.mcp_server_configs.length > 0 && (
                             

- MCP Servers ( - {listing.snapshot.mcp_server_configs.length}) + {t("marketplace.detail.config.mcpServers", { + count: + listing.snapshot.mcp_server_configs.length, + })}

{listing.snapshot.mcp_server_configs.map( @@ -357,7 +366,7 @@ export default function AgentMarketplaceDetail({ ) : (
-

No configuration available.

+

{t("marketplace.detail.config.empty")}

)}
@@ -374,9 +383,14 @@ export default function AgentMarketplaceDetail({
- LLM Provider Required: You'll - need to configure an AI provider (OpenAI, - Anthropic, etc.) to use this agent. + + {t( + "marketplace.detail.requirements.provider.title", + )} + {" "} + {t( + "marketplace.detail.requirements.provider.description", + )}
@@ -386,7 +400,9 @@ export default function AgentMarketplaceDetail({ {requirements.mcp_servers.length > 0 && (

- MCP Servers ({requirements.mcp_servers.length}) + {t("marketplace.detail.requirements.mcpServers", { + count: requirements.mcp_servers.length, + })}

{requirements.mcp_servers.map((mcp, index) => ( @@ -400,7 +416,10 @@ export default function AgentMarketplaceDetail({ {mcp.name} - ✅ Auto-configured + ✅{" "} + {t( + "marketplace.detail.requirements.autoConfigured", + )}
{mcp.description && ( @@ -421,11 +440,18 @@ export default function AgentMarketplaceDetail({
- Knowledge Base: The original - agent uses{" "} - {requirements.knowledge_base.file_count} files. - These files will be copied to your workspace - when you fork this agent. + + {t( + "marketplace.detail.requirements.knowledgeBase.title", + )} + {" "} + {t( + "marketplace.detail.requirements.knowledgeBase.description", + { + count: + requirements.knowledge_base.file_count, + }, + )}
@@ -439,8 +465,7 @@ export default function AgentMarketplaceDetail({
- No special requirements! This agent is ready - to use after forking. + {t("marketplace.detail.requirements.none")}
@@ -449,7 +474,7 @@ export default function AgentMarketplaceDetail({ ) : (
-

Loading requirements...

+

{t("marketplace.detail.requirements.loading")}

)}
@@ -459,12 +484,12 @@ export default function AgentMarketplaceDetail({
{/* Right Column - Actions */} -
+
{/* Action Card */} -
+

- Actions + {t("marketplace.detail.actions.title")}

@@ -518,7 +545,7 @@ export default function AgentMarketplaceDetail({ >
- Manage Agent + {t("marketplace.detail.actions.manage")}
)} @@ -529,7 +556,7 @@ export default function AgentMarketplaceDetail({ {/* Author Info */}

- Published By + {t("marketplace.detail.meta.publishedBy")}

{listing.user_id} @@ -543,14 +570,18 @@ export default function AgentMarketplaceDetail({

{listing.first_published_at && (
- First Published:{" "} + + {t("marketplace.detail.meta.firstPublished")} + {" "} {new Date( listing.first_published_at, ).toLocaleDateString()}
)}
- Last Updated:{" "} + + {t("marketplace.detail.meta.lastUpdated")} + {" "} {new Date(listing.updated_at).toLocaleDateString()}
@@ -565,11 +596,10 @@ export default function AgentMarketplaceDetail({

- About Forking + {t("marketplace.detail.aboutForking.title")}

- Forking creates your own independent copy. Changes won't - affect the original agent. + {t("marketplace.detail.aboutForking.description")}

diff --git a/web/src/components/features/CheckInCalendar.tsx b/web/src/components/features/CheckInCalendar.tsx index 7b6de1d8..7cd0b104 100644 --- a/web/src/components/features/CheckInCalendar.tsx +++ b/web/src/components/features/CheckInCalendar.tsx @@ -552,7 +552,7 @@ export function CheckInCalendar({ onCheckInSuccess }: CheckInCalendarProps) {
{checkInRecord && (
-
+
@@ -577,7 +577,7 @@ export function CheckInCalendar({ onCheckInSuccess }: CheckInCalendarProps) { {consumption && (
-
+
使用统计 diff --git a/web/src/components/layouts/components/ChatBubble.tsx b/web/src/components/layouts/components/ChatBubble.tsx index 9e6c20c3..c102215a 100644 --- a/web/src/components/layouts/components/ChatBubble.tsx +++ b/web/src/components/layouts/components/ChatBubble.tsx @@ -9,6 +9,7 @@ import { useEffect, useMemo, useRef } from "react"; import LoadingMessage from "./LoadingMessage"; import MessageAttachments from "./MessageAttachments"; import { SearchCitations } from "./SearchCitations"; +import ThinkingBubble from "./ThinkingBubble"; import ToolCallCard from "./ToolCallCard"; interface ChatBubbleProps { @@ -30,6 +31,8 @@ function ChatBubble({ message }: ChatBubbleProps) { toolCalls, attachments, citations, + isThinking, + thinkingContent, } = message; // 流式消息打字效果 @@ -204,6 +207,14 @@ function ChatBubble({ message }: ChatBubbleProps) { : "text-sm text-neutral-700 dark:text-neutral-300" }`} > + {/* Thinking content - shown before main response for assistant messages */} + {!isUserMessage && thinkingContent && ( + + )} + {isLoading ? ( ) : ( diff --git a/web/src/components/layouts/components/MessageAttachments.tsx b/web/src/components/layouts/components/MessageAttachments.tsx index 3b006c3a..18878bc5 100644 --- a/web/src/components/layouts/components/MessageAttachments.tsx +++ b/web/src/components/layouts/components/MessageAttachments.tsx @@ -494,7 +494,7 @@ export default function MessageAttachments({ onClick={(e) => e.stopPropagation()} > {/* Header */} -
+
(null); + + // Split content into lines for display + const lines = useMemo(() => { + return content.split("\n").filter((line) => line.trim()); + }, [content]); + + // Get last 5 lines for active thinking display + const visibleLines = useMemo(() => { + return lines.slice(-5); + }, [lines]); + + // Auto-scroll to bottom during active thinking + useEffect(() => { + if (isThinking && scrollRef.current) { + scrollRef.current.scrollTop = scrollRef.current.scrollHeight; + } + }, [content, isThinking]); + + // Don't render if no content + if (!content) { + return null; + } + + return ( +
+ + {isThinking ? ( + // Active thinking state - animated scrolling view + + {/* Subtle shimmer effect */} + + + {/* Header with animated icon */} +
+ + + + + {t("app.chat.thinking.label")} + + + + +
+ + {/* Scrolling content - max 5 lines visible */} +
+ {visibleLines.map((line, index) => ( + + {line} + + ))} + {/* Blinking cursor */} + +
+ + {/* Fade overlay at top when more content */} + {lines.length > 5 && ( +
+ )} + + ) : ( + // Collapsed state - expandable accordion + + {/* Collapsible header */} + + + {/* Expanded content */} + + {isExpanded && ( + +
+
+ +
+
+
+ )} +
+
+ )} + +
+ ); +} diff --git a/web/src/components/ui/3d-pin.tsx b/web/src/components/ui/3d-pin.tsx index 9de190f1..fe3736bc 100644 --- a/web/src/components/ui/3d-pin.tsx +++ b/web/src/components/ui/3d-pin.tsx @@ -164,8 +164,8 @@ export const PinPerspective = ({
<> - - + + diff --git a/web/src/i18n/locales/en/translation.json b/web/src/i18n/locales/en/translation.json index 3396b8ee..276abf1b 100644 --- a/web/src/i18n/locales/en/translation.json +++ b/web/src/i18n/locales/en/translation.json @@ -8,7 +8,12 @@ "chat": { "assistantsTitle": "Assistants", "chooseAgentHint": "Choose an agent to start", - "chatLabel": "Chat" + "chatLabel": "Chat", + "thinking": { + "label": "Thinking...", + "showThinking": "Show thinking", + "hideThinking": "Hide thinking" + } } }, "common": { @@ -284,6 +289,63 @@ "by": "by {{author}}", "noDescription": "No description provided", "tagsMore": "+{{count}} more" + }, + "detail": { + "loading": "Loading agent details...", + "error": "Failed to load agent details. Please try again.", + "back": "Back to Marketplace", + "publishedBy": "Published by", + "noDescription": "No description provided", + "stats": { + "likes": "Likes", + "forks": "Forks", + "views": "Views" + }, + "tabs": { + "readme": "README", + "config": "Configuration", + "requirements": "Requirements" + }, + "readme": { + "empty": "No README provided for this agent.", + "manage": "Manage to add a README" + }, + "config": { + "model": "Model", + "systemPrompt": "System Prompt", + "mcpServers": "MCP Servers ({{count}})", + "empty": "No configuration available." + }, + "requirements": { + "provider": { + "title": "LLM Provider Required:", + "description": "You'll need to configure an AI provider (OpenAI, Anthropic, etc.) to use this agent." + }, + "mcpServers": "MCP Servers ({{count}})", + "autoConfigured": "Auto-configured", + "knowledgeBase": { + "title": "Knowledge Base:", + "description": "The original agent uses {{count}} files. These files will be copied to your workspace when you fork this agent." + }, + "none": "No special requirements! This agent is ready to use after forking.", + "loading": "Loading requirements..." + }, + "actions": { + "title": "Actions", + "fork": "Fork This Agent", + "liked": "Liked", + "like": "Like This Agent", + "manage": "Manage Agent" + }, + "meta": { + "publishedBy": "Published By", + "firstPublished": "First Published:", + "lastUpdated": "Last Updated:" + }, + "aboutForking": { + "title": "About Forking", + "description": "Forking creates your own independent copy. Changes won't affect the original agent." + } } }, "knowledge": { diff --git a/web/src/i18n/locales/zh/translation.json b/web/src/i18n/locales/zh/translation.json index 1cf6c44c..3c9317af 100644 --- a/web/src/i18n/locales/zh/translation.json +++ b/web/src/i18n/locales/zh/translation.json @@ -8,7 +8,12 @@ "chat": { "assistantsTitle": "助手", "chooseAgentHint": "选择一个助手开始", - "chatLabel": "聊天" + "chatLabel": "聊天", + "thinking": { + "label": "思考中...", + "showThinking": "显示思考过程", + "hideThinking": "隐藏思考过程" + } } }, "common": { @@ -263,205 +268,63 @@ "by": "来自 {{author}}", "noDescription": "暂无描述", "tagsMore": "+{{count}} 更多" - } - }, - "knowledge": { - "titles": { - "recents": "最近", - "allFiles": "全部文件", - "myKnowledge": "我的知识库", - "knowledgeBase": "知识库", - "trash": "回收站" - }, - "a11y": { - "navTitle": "导航菜单", - "navDescription": "在你的知识库中进行导航" - }, - "prompts": { - "folderName": "请输入文件夹名称:", - "knowledgeSetName": "请输入知识集名称:", - "knowledgeSetDescription": "请输入描述(可选):" - }, - "errors": { - "createFolderFailed": "创建文件夹失败", - "createKnowledgeSetFailed": "创建知识集失败" - }, - "upload": { - "errors": { - "fileTooLarge": "文件“{{name}}”({{fileSizeMB}}MB)超过最大大小限制 {{maxSizeMB}}MB", - "notEnoughStorage": "存储空间不足。文件大小:{{fileSizeMB}}MB,可用:{{availableMB}}MB。请先删除一些文件。", - "uploadFailed": "上传失败" - } }, - "toolbar": { - "home": "主页", - "searchFilesPlaceholder": "搜索文件...", - "searchPlaceholder": "搜索", - "listView": "列表视图", - "gridView": "网格视图", - "newFolder": "新建文件夹", - "emptyTrash": "清空回收站", - "empty": "清空", - "uploadFile": "上传文件", - "upload": "上传", - "refresh": "刷新" - }, - "sidebar": { - "groups": { - "favorites": "常用", - "media": "媒体", - "locations": "位置" - }, - "items": { - "images": "图片", - "documents": "文档" - }, - "newKnowledgeSet": "新建知识集", - "noKnowledgeSets": "暂无知识集" - }, - "status": { - "items": "{{count}} 项", - "usedOfTotal": "已用 {{used}} / {{total}}", - "available": "可用 {{available}}" - }, - "contextMenu": { - "download": "下载", - "rename": "重命名", - "moveTo": "移动到...", - "addToKnowledgeSet": "添加到知识集", - "removeFromKnowledgeSet": "从知识集中移除", - "delete": "删除" - }, - "moveModal": { - "title": "移动 \"{{name}}\"", - "home": "主页", - "noSubfolders": "没有子文件夹", - "moveHere": "移动到这里" - }, - "fileList": { - "itemTypes": { - "file": "文件", - "folder": "文件夹" - }, - "columns": { - "name": "名称", - "size": "大小", - "dateModified": "修改时间" - }, - "empty": { - "trash": "回收站为空", - "noItems": "暂无内容" - }, - "actions": { - "preview": "预览", - "download": "下载", - "restore": "还原", - "delete": "删除", - "deleteForever": "永久删除", - "moveToTrash": "移到回收站", - "deleteImmediately": "立即删除", - "deleteFailed": "删除失败", - "restoreFailed": "还原失败", - "downloadFailed": "下载失败" - }, - "deleteItem": { - "title": "删除{{itemType}}", - "message": "确定要删除这个{{itemType}}吗?" - }, - "moveToTrash": { - "message": "确定要将此文件移到回收站吗?" - }, - "deleteForever": { - "message": "确定要永久删除此文件吗?此操作无法撤销。" + "detail": { + "loading": "正在加载助手详情...", + "error": "加载助手详情失败,请重试。", + "back": "返回市场", + "publishedBy": "发布者", + "noDescription": "暂无描述", + "stats": { + "likes": "点赞", + "forks": "复刻", + "views": "浏览" }, - "emptyTrash": { - "title": "清空回收站", - "message": "确定要永久删除 {{count}} 个项目吗?此操作无法撤销。", - "confirm": "清空回收站", - "failed": "清空回收站失败,可能仍有部分项目未删除。" + "tabs": { + "readme": "说明文档", + "config": "配置信息", + "requirements": "运行要求" }, - "rename": { - "titleFile": "重命名文件", - "titleFolder": "重命名文件夹", - "placeholder": "输入新名称", - "confirm": "重命名", - "failed": "重命名失败" + "readme": { + "empty": "该助手暂无说明文档。", + "manage": "去添加说明文档" }, - "move": { - "failed": "移动失败" + "config": { + "model": "模型", + "systemPrompt": "系统提示词", + "mcpServers": "MCP 服务 ({{count}})", + "empty": "暂无配置信息。" }, - "knowledgeSet": { - "add": { - "title": "添加到知识集", - "subtitle": "选择一个知识集来添加 \"{{name}}\"" + "requirements": { + "provider": { + "title": "需要 LLM 服务商:", + "description": "你需要配置 AI 服务商(OpenAI, Anthropic 等)才能使用此助手。" }, - "none": "暂无可用知识集,请先创建一个。", - "fileCount": "{{count}} 个文件", - "added": "已成功添加到知识集", - "alreadyInSet": "该文件已在此知识集中。", - "addFailed": "添加到知识集失败", - "remove": { - "title": "从知识集中移除", - "message": "确定要将 \"{{name}}\" 从此知识集中移除吗?", - "confirm": "移除", - "failed": "从知识集中移除失败" - } - }, - "notifications": { - "successTitle": "成功", - "noticeTitle": "提示", - "errorTitle": "错误" - } - }, - "createKnowledgeSetModal": { - "title": "创建知识集", - "fields": { - "name": { - "label": "名称", - "placeholder": "请输入名称" + "mcpServers": "MCP 服务 ({{count}})", + "autoConfigured": "自动配置", + "knowledgeBase": { + "title": "知识库:", + "description": "原助手使用了 {{count}} 个文件。复刻此助手时,这些文件将被复制到你的工作区。" }, - "description": { - "label": "描述(可选)", - "placeholder": "添加简短描述" - } + "none": "无特殊要求!复刻后即可直接使用。", + "loading": "正在加载运行要求..." }, "actions": { - "create": "创建", - "creating": "创建中..." + "title": "操作", + "fork": "复刻此助手", + "liked": "已点赞", + "like": "点赞", + "manage": "管理助手" }, - "validation": { - "nameRequired": "名称不能为空" + "meta": { + "publishedBy": "发布者", + "firstPublished": "首次发布:", + "lastUpdated": "最后更新:" }, - "errors": { - "createFailed": "创建知识集失败" + "aboutForking": { + "title": "关于复刻", + "description": "复刻会创建一个独立的副本。你的修改不会影响原始助手。" } } - }, - "mcp": { - "title": "我的服务", - "subtitle": "浏览市场并管理你的服务", - "refresh": "刷新", - "addCustom": "添加自定义", - "market": { - "title": "MCP 市场", - "quickAdd": "快速添加", - "close": "关闭" - }, - "added": { - "title": "已添加服务", - "online": "在线", - "offline": "离线", - "tools": "工具", - "noDescription": "暂无描述", - "edit": "编辑服务", - "remove": "移除服务", - "test": "测试工具", - "empty": { - "title": "暂无已添加服务", - "description": "在左侧浏览市场添加服务,或创建自定义连接。", - "button": "添加自定义服务" - }, - "loading": "加载 MCP 服务中..." - } } } diff --git a/web/src/service/xyzenService.ts b/web/src/service/xyzenService.ts index e6df76c2..80dff70e 100644 --- a/web/src/service/xyzenService.ts +++ b/web/src/service/xyzenService.ts @@ -19,7 +19,10 @@ interface MessageEvent { | "tool_call_response" | "insufficient_balance" | "error" - | "topic_updated"; + | "topic_updated" + | "thinking_start" + | "thinking_chunk" + | "thinking_end"; data: | Message | { diff --git a/web/src/store/slices/chatSlice.ts b/web/src/store/slices/chatSlice.ts index 2cc388a5..57cd9f64 100644 --- a/web/src/store/slices/chatSlice.ts +++ b/web/src/store/slices/chatSlice.ts @@ -35,14 +35,22 @@ function groupToolMessagesWithAssistant(messages: Message[]): Message[] { arguments: { ...(toolCall.arguments || {}) }, }); - const cloneMessage = (message: Message): Message => ({ - ...message, - toolCalls: message.toolCalls - ? message.toolCalls.map((toolCall) => cloneToolCall(toolCall)) - : undefined, - attachments: message.attachments ? [...message.attachments] : undefined, - citations: message.citations ? [...message.citations] : undefined, - }); + const cloneMessage = (message: Message): Message => { + const backendThinkingContent = ( + message as Message & { thinking_content?: string } + ).thinking_content; + + return { + ...message, + toolCalls: message.toolCalls + ? message.toolCalls.map((toolCall) => cloneToolCall(toolCall)) + : undefined, + attachments: message.attachments ? [...message.attachments] : undefined, + citations: message.citations ? [...message.citations] : undefined, + // Map thinking_content from backend to thinkingContent for frontend + thinkingContent: backendThinkingContent ?? message.thinkingContent, + }; + }; for (const msg of messages) { if (msg.role !== "tool") { @@ -581,12 +589,14 @@ export const createChatSlice: StateCreator< } case "streaming_start": { - // Convert loading message to streaming message + // Convert loading or thinking message to streaming message channel.responding = true; + const eventData = event.data as { id: string }; + + // First check for loading message const loadingIndex = channel.messages.findIndex( (m) => m.isLoading, ); - const eventData = event.data as { id: string }; if (loadingIndex !== -1) { // eslint-disable-next-line @typescript-eslint/no-unused-vars const { isLoading: _, ...messageWithoutLoading } = @@ -597,18 +607,33 @@ export const createChatSlice: StateCreator< isStreaming: true, content: "", }; - } else { - // No loading present (backend may skip sending "loading"). Create a streaming message now. - channel.messages.push({ - id: eventData.id, - clientId: generateClientId(), - role: "assistant" as const, - content: "", - isNewMessage: true, - created_at: new Date().toISOString(), + break; + } + + // Check for existing message with same ID (e.g., after thinking_end set isThinking=false) + const existingIndex = channel.messages.findIndex( + (m) => m.id === eventData.id, + ); + if (existingIndex !== -1) { + // Convert existing message to streaming - keep thinking content if present + channel.messages[existingIndex] = { + ...channel.messages[existingIndex], + isThinking: false, isStreaming: true, - }); + }; + break; } + + // No loading or existing message found, create a streaming message now + channel.messages.push({ + id: eventData.id, + clientId: generateClientId(), + role: "assistant" as const, + content: "", + isNewMessage: true, + created_at: new Date().toISOString(), + isStreaming: true, + }); break; } @@ -981,6 +1006,68 @@ export const createChatSlice: StateCreator< break; } + case "thinking_start": { + // Start thinking mode - find or create the assistant message + channel.responding = true; + const eventData = event.data as { id: string }; + const loadingIndex = channel.messages.findIndex( + (m) => m.isLoading, + ); + if (loadingIndex !== -1) { + // Convert loading message to thinking message + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { isLoading: _, ...messageWithoutLoading } = + channel.messages[loadingIndex]; + channel.messages[loadingIndex] = { + ...messageWithoutLoading, + id: eventData.id, + isThinking: true, + thinkingContent: "", + content: "", + }; + } else { + // No loading present, create a thinking message + channel.messages.push({ + id: eventData.id, + clientId: `thinking-${Date.now()}`, + role: "assistant" as const, + content: "", + isNewMessage: true, + created_at: new Date().toISOString(), + isThinking: true, + thinkingContent: "", + }); + } + break; + } + + case "thinking_chunk": { + // Append to thinking content + const eventData = event.data as { id: string; content: string }; + const thinkingIndex = channel.messages.findIndex( + (m) => m.id === eventData.id, + ); + if (thinkingIndex !== -1) { + const currentThinking = + channel.messages[thinkingIndex].thinkingContent ?? ""; + channel.messages[thinkingIndex].thinkingContent = + currentThinking + eventData.content; + } + break; + } + + case "thinking_end": { + // End thinking mode + const eventData = event.data as { id: string }; + const endThinkingIndex = channel.messages.findIndex( + (m) => m.id === eventData.id, + ); + if (endThinkingIndex !== -1) { + channel.messages[endThinkingIndex].isThinking = false; + } + break; + } + case "topic_updated": { const eventData = event.data as { id: string; diff --git a/web/src/store/types.ts b/web/src/store/types.ts index 8effeb8c..ae9f5be8 100644 --- a/web/src/store/types.ts +++ b/web/src/store/types.ts @@ -67,6 +67,9 @@ export interface Message { attachments?: MessageAttachment[]; // Search citations from built-in search citations?: SearchCitation[]; + // Thinking/reasoning content from models like Claude, DeepSeek R1, OpenAI o1 + isThinking?: boolean; + thinkingContent?: string; } export interface KnowledgeContext { diff --git a/web/yarn.lock b/web/yarn.lock index ed19142e..bb444f4f 100644 --- a/web/yarn.lock +++ b/web/yarn.lock @@ -5343,14 +5343,12 @@ __metadata: react-dnd-html5-backend: "npm:^16.0.1" react-dom: "npm:^19.1.0" react-i18next: "npm:^16.5.0" - react-image-crop: "npm:^11.0.10" react-lite-youtube-embed: "npm:^3.3.3" react-markdown: "npm:^10.1.0" react-player: "npm:3.3.1" react-textarea-autosize: "npm:^8.5.9" react-tweet: "npm:^3.3.0" react-use-measure: "npm:^2.1.7" - react-use-websocket: "npm:^4.13.0" rehype-katex: "npm:^7.0.1" remark-gfm: "npm:^4.0.1" remark-math: "npm:^6.0.0" @@ -13385,15 +13383,6 @@ __metadata: languageName: node linkType: hard -"react-image-crop@npm:^11.0.10": - version: 11.0.10 - resolution: "react-image-crop@npm:11.0.10" - peerDependencies: - react: ">=16.13.1" - checksum: 10c0/2c2c7066c3a51838fc074a4a125ee2063170e01c26e3b1c8fc8b66155cf7ef576b5747089b4f365ce9dd2728794c39f57643b37eb78116e9db3f4b6299d3f682 - languageName: node - linkType: hard - "react-is@npm:^16.13.1, react-is@npm:^16.7.0": version: 16.13.1 resolution: "react-is@npm:16.13.1" @@ -13566,13 +13555,6 @@ __metadata: languageName: node linkType: hard -"react-use-websocket@npm:^4.13.0": - version: 4.13.0 - resolution: "react-use-websocket@npm:4.13.0" - checksum: 10c0/92f0941c67984f3b43979a2e5aa9a358d1e2b01591575b09fdcdd638d0c4275f6c5e180d1173632bc0fd564458afc3583643635eaeb4e0b8ce059555576661f3 - languageName: node - linkType: hard - "react@npm:^19.1.0": version: 19.2.3 resolution: "react@npm:19.2.3"