Commit 5debb76

manavgup and Claude authored
feat: Add production-grade cross-encoder reranking (#548)
* fix: Remove arbitrary 10K char limit + fix datetime.utcnow() deprecation

  **Critical Fix - Message Content Length**:
  - Increased ConversationMessageInput.content max_length from 10,000 to 100,000 characters
  - **Problem**: LLM responses frequently exceed 10K chars, especially with:
    - Chain of Thought reasoning (adds 8K-16K chars)
    - Code examples and technical documentation
    - Long document summaries
    - Claude can output ~32,000 chars, GPT-4 ~16,000 chars
  - **Impact**: Users hit `string_too_long` validation failures on message submission
  - **Solution**: Raised the limit to 100,000 chars (safe for all LLM use cases)

  **Deprecation Fix - datetime.utcnow()**:
  - Replaced all datetime.utcnow() calls with datetime.now(UTC)
  - **Files**: conversation_schema.py (9 occurrences), conversation_service.py (4 occurrences)
  - **Reason**: datetime.utcnow() is deprecated in Python 3.12+
  - **Migration**: Added the UTC import and changed:
    - datetime.utcnow() → datetime.now(UTC)
    - default_factory=datetime.utcnow → default_factory=lambda: datetime.now(UTC)

  **Error Resolved**:
  ```
  ValidationError: 1 validation error for ConversationMessageInput
  content
    String should have at most 10000 characters [type=string_too_long]
  ```

  **Testing**:
  ✅ Schema validation works with 50,000+ char content
  ✅ datetime.now(UTC) produces timezone-aware timestamps
  ✅ No breaking changes to the API

  **Files Changed**:
  - backend/rag_solution/schemas/conversation_schema.py
  - backend/rag_solution/services/conversation_service.py

  Fixes: user-reported runtime error in the conversation service
  Related: Python 3.12 deprecation warnings (Issue #520)

  Signed-off-by: manavgup <manavg@gmail.com>

* feat: Add production-grade cross-encoder reranking

  Implements fast, high-quality document reranking using cross-encoder models
  from sentence-transformers, replacing slow LLM-based reranking. Also fixes an
  LLM hallucination bug in the non-CoT path.

  ## Performance Improvements

  ### Reranking Speed (250x faster)
  - Before: 20-30s (LLM-based reranking)
  - After: 80ms (cross-encoder)
  - Model: cross-encoder/ms-marco-MiniLM-L-6-v2

  ### End-to-End Query Speed (12.5x faster)
  - Before: 100s (broken by LLM hallucination)
  - After stop sequences: 35s (still using LLM reranking)
  - After cross-encoder: 8-22s ✅

  ## Quality Improvements
  - Precision-focused scoring (0-1 relevance scores)
  - Trained on the MS MARCO dataset (530K query-document pairs)
  - Industry-standard approach (used by Cohere, Pinecone, Weaviate)
  - Maintains quality while achieving the 250x speedup

  ## Changes Made

  1. **reranker.py**: Added CrossEncoderReranker class
     - Uses the sentence-transformers library
     - Batch processing for efficiency
     - Comprehensive logging and error handling
     - Model caching (7s first load, 1s subsequent)
  2. **pipeline_service.py**: Integrated cross-encoder into the pipeline
     - Added cross-encoder branch in get_reranker()
     - Fallback to SimpleReranker on errors
     - User-level reranker selection
  3. **config.py**: Added cross-encoder configuration
     - RERANKER_TYPE=cross-encoder option
     - CROSS_ENCODER_MODEL setting (default: ms-marco-MiniLM-L-6-v2)
  4. **watsonx.py**: Fixed LLM hallucination bug
     - Added stop_sequences: ["##", "\n\nQuestion:", "\n\n##"]
     - Prevents the LLM from generating extra, unwanted Q&A pairs
  5. **user_provider_service.py**: Enhanced system prompt
     - Explicit instructions to answer only the user's question
     - Prevents multi-question generation
  6. **pyproject.toml**: Added sentence-transformers dependency
     - Version: ^5.1.2

  ## Configuration

  Add to .env:
  ```bash
  ENABLE_RERANKING=true
  RERANKER_TYPE=cross-encoder
  CROSS_ENCODER_MODEL=cross-encoder/ms-marco-MiniLM-L-6-v2
  ```

  ## Testing Results
  ✅ No-CoT + top_k=20: 8s (80ms reranking)
  ✅ No-CoT + top_k=5: 22s (includes 7s model load on first request)
  ✅ CoT + top_k=5: 27s (70ms reranking)

  All queries return correct, concise answers with proper source attribution.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude <noreply@anthropic.com>
  Signed-off-by: manavgup <manavg@gmail.com>

* fix: Address PR #548 review - fix linting error and failing test

  - Fix B905: Add strict=False to zip() in CrossEncoderReranker
  - Fix failing test: Update test_message_content_validation_max_length to use
    the correct max_length (100000)
  - Add integration strategy to MASTER_ISSUES_ROADMAP.md

  The test was expecting max_length=10000, but the actual schema allows
  max_length=100000.

  Signed-off-by: manavgup <manavg@gmail.com>

* fix: Address PR #548 review comments - critical fixes

  1. ✅ Remove debug print() statements from watsonx.py
     - Removed 18 lines of print() debug output (lines 384-401)
     - Changed INFO logging to DEBUG level with a guard check
     - Truncate text previews to 100 chars for readability
  2. ✅ Change zip() strict parameter to True
     - Changed strict=False to strict=True in CrossEncoderReranker.rerank()
     - Safer failure mode: raises if the lists are misaligned
  3. ✅ Add comprehensive unit tests for CrossEncoderReranker
     - Created tests/unit/retrieval/test_cross_encoder_reranker.py
     - 35 unit tests covering all functionality
     - Tests initialization, reranking, top-k, empty input, async, errors
     - All tests passing

  Addresses issues from:
  - PR review comment #3470377166
  - pr-reviewer agent findings (confidence 95, 88, 85)

  Remaining items (lower priority):
  - Move sentence_transformers import to module level
  - Add type hint to __init__
  - Fix stop sequence specificity
  - Add model security validation
  - Add documentation

  Signed-off-by: manavgup <manavg@gmail.com>

* fix: Address CRITICAL schema issues and async deprecations

  1. ✅ Fix CRITICAL QueryResult schema mismatch
     - Removed collection_id and collection_name from QueryResult creation
     - The QueryResult schema only has: chunk, score, embeddings
     - Collection info is preserved in chunk.metadata
     - Fixes a Pydantic validation error that would occur at runtime
  2. ✅ Fix async deprecation warnings
     - Changed get_event_loop() → get_running_loop()
     - Removed the unnecessary lambda wrapper in the executor call
     - Python 3.10+ compatible
  3. ✅ Add error handling to cross-encoder prediction
     - Wrap model.predict() in try-except
     - Raise ValueError with context on failure
     - Better error messages for debugging
  4. ✅ Optimize debug logging in watsonx.py
     - Limit to the first 5 texts (was unlimited)
     - Add a summary line for the remaining texts
     - Prevents expensive logging loops for large batches

  These fixes address the most critical issues found in PR review:
  - Schema mismatch would cause runtime errors (confidence 100%)
  - Async deprecation causes warnings in Python 3.10+ (confidence 95%)
  - Missing error handling could cause cryptic failures (confidence 85%)
  - Debug logging could impact performance (confidence 80%)

  Signed-off-by: manavgup <manavg@gmail.com>

* fix: Final fixes for PR #548 - tests, linting, and documentation

  1. ✅ Fix 3 failing error handling tests
     - Updated tests to expect ValueError (wrapped exceptions)
     - Tests now match the new error handling behavior
     - All 35 tests passing
  2. ✅ Add comprehensive MkDocs documentation
     - Created docs/features/cross-encoder-reranking.md (640 lines)
     - Added to docs/features/index.md navigation
     - Complete usage guide, API reference, troubleshooting
     - Performance comparisons and migration guide
  3. ✅ Fix all linting issues
     - Fixed import ordering (removed blank lines, unused imports)
     - Added missing logging import in watsonx.py
     - Fixed MyPy type issues (None checks for result.chunk)
     - Removed an unused loop variable in a test
     - All ruff/mypy checks pass

  Changes:
  - backend/rag_solution/generation/providers/watsonx.py: Added logging import
  - backend/rag_solution/retrieval/reranker.py: Fixed None check for result.chunk.text
  - tests/unit/retrieval/test_cross_encoder_reranker.py: Fixed test expectations + imports
  - docs/features/cross-encoder-reranking.md: NEW - comprehensive documentation
  - docs/features/index.md: Added cross-encoder reranking to the features list

  All files now pass:
  ✅ Ruff formatting and linting
  ✅ MyPy type checking
  ✅ Import ordering (isort)
  ✅ All 35 unit tests passing

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude <noreply@anthropic.com>
  Signed-off-by: manavgup <manavg@gmail.com>

---------

Signed-off-by: manavgup <manavg@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
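For context, here is what cross-encoder scoring looks like in isolation with sentence-transformers — a minimal sketch, not code from this commit; the query and documents are invented, and only the model name matches the commit's default:

```python
from sentence_transformers import CrossEncoder

# Same default model this commit configures via CROSS_ENCODER_MODEL.
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "how does cross-encoder reranking work"
docs = [
    "Cross-encoders jointly encode the query and document and output a relevance score.",
    "Bi-encoders embed query and document separately and compare them with cosine similarity.",
]

# predict() scores each (query, document) pair; higher score = more relevant.
scores = model.predict([(query, doc) for doc in docs])

# Reranking is just sorting the documents by their scores, descending.
for doc, score in sorted(zip(docs, scores, strict=True), key=lambda p: p[1], reverse=True):
    print(f"{score:.3f}  {doc}")
```

Unlike a bi-encoder, the cross-encoder sees query and document together in one forward pass, which is why it is both more accurate and too slow to score a whole corpus — hence its use as a second-stage reranker over a small candidate set.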
1 parent (9fcb102) · commit 5debb76

File tree

13 files changed: +1,700 −47 lines

backend/core/config.py

Lines changed: 5 additions & 1 deletion
@@ -153,7 +153,7 @@ class Settings(BaseSettings):

     # Reranking settings
     enable_reranking: Annotated[bool, Field(default=True, alias="ENABLE_RERANKING")]
-    reranker_type: Annotated[str, Field(default="llm", alias="RERANKER_TYPE")]  # Options: llm, simple
+    reranker_type: Annotated[str, Field(default="llm", alias="RERANKER_TYPE")]  # Options: llm, simple, cross-encoder
     reranker_top_k: Annotated[
         int | None, Field(default=5, alias="RERANKER_TOP_K")
     ]  # Default 5 for optimal quality/speed
@@ -162,6 +162,10 @@ class Settings(BaseSettings):
     reranker_prompt_template_name: Annotated[
         str, Field(default="reranking", alias="RERANKER_PROMPT_TEMPLATE_NAME")
     ]  # Template name for reranking prompts
+    # Cross-encoder reranker settings (production-grade, ~100ms vs 20-30s for LLM)
+    cross_encoder_model: Annotated[
+        str, Field(default="cross-encoder/ms-marco-MiniLM-L-6-v2", alias="CROSS_ENCODER_MODEL")
+    ]  # Fast cross-encoder for reranking

     # Podcast Generation settings
     # Environment: "development" uses FastAPI BackgroundTasks + local filesystem
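To illustrate how these pydantic-settings fields pick up values, here is a minimal, self-contained sketch, trimmed to the two knobs this hunk touches; the class name is hypothetical, and the env-var names come from the `alias` arguments above:

```python
import os
from typing import Annotated

from pydantic import Field
from pydantic_settings import BaseSettings


class RerankSettings(BaseSettings):
    reranker_type: Annotated[str, Field(default="llm", alias="RERANKER_TYPE")]
    cross_encoder_model: Annotated[
        str, Field(default="cross-encoder/ms-marco-MiniLM-L-6-v2", alias="CROSS_ENCODER_MODEL")
    ]


# The alias doubles as the environment variable name, so a .env entry
# RERANKER_TYPE=cross-encoder has the same effect as this:
os.environ["RERANKER_TYPE"] = "cross-encoder"
settings = RerankSettings()
print(settings.reranker_type)        # cross-encoder
print(settings.cross_encoder_model)  # cross-encoder/ms-marco-MiniLM-L-6-v2
```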

backend/rag_solution/generation/providers/watsonx.py

Lines changed: 10 additions & 3 deletions
@@ -2,6 +2,7 @@

 from __future__ import annotations

+import logging
 import time
 from collections.abc import Generator, Sequence
 from typing import Any
@@ -182,13 +183,14 @@ def _get_generation_params(
         if params is None:
             raise ValueError("No LLM parameters found for user")

-        # Convert to WatsonX format
+        # Convert to WatsonX format with stop sequences
         return {
             GenParams.DECODING_METHOD: "sample",
             GenParams.MAX_NEW_TOKENS: params.max_new_tokens,
             GenParams.TEMPERATURE: params.temperature,
             GenParams.TOP_K: params.top_k,
             GenParams.TOP_P: params.top_p,
+            GenParams.STOP_SEQUENCES: ["##", "\n\nQuestion:", "\n\n##"],  # Stop at markdown headers or new questions
         }

     def generate_text(
@@ -379,8 +381,13 @@ def get_embeddings(self, texts: str | Sequence[str]) -> EmbeddingsList:
         if isinstance(texts, str):
            texts = [texts]

-        logger.debug("Generating embeddings for %d texts", len(texts))
-        logger.debug("Embeddings client: %s", self.embeddings_client)
+        # Debug logging for embeddings generation (limited to first 5 for performance)
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug("Generating embeddings for %d texts", len(texts))
+            for idx, text in enumerate(texts[:5], 1):
+                logger.debug("Text %d (length: %d chars): %s", idx, len(text), text[:100])
+            if len(texts) > 5:
+                logger.debug("... and %d more texts", len(texts) - 5)

         # Add a configurable delay to prevent rate limiting
         settings = get_settings()
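The logging change in the last hunk is an instance of a general pattern: guard any loop that only produces DEBUG output, so iteration and string slicing are skipped entirely at higher log levels. The same idea as a standalone sketch (function name is illustrative):

```python
import logging

logger = logging.getLogger(__name__)


def preview_texts(texts: list[str]) -> None:
    # isEnabledFor() short-circuits the whole block, so a batch of
    # thousands of texts costs nothing when DEBUG is off.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Generating embeddings for %d texts", len(texts))
        for idx, text in enumerate(texts[:5], 1):  # cap the preview at 5 items
            logger.debug("Text %d (length: %d chars): %s", idx, len(text), text[:100])
        if len(texts) > 5:
            logger.debug("... and %d more texts", len(texts) - 5)
```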

backend/rag_solution/retrieval/reranker.py

Lines changed: 139 additions & 5 deletions
@@ -275,11 +275,17 @@ async def _score_batch_async(self, query: str, batch: list[QueryResult]) -> list
         formatted_prompts = self._create_reranking_prompts(query, batch)

         try:
-            # Call LLM provider asynchronously
-            responses = await self.llm_provider.generate_text(
-                user_id=self.user_id,
-                prompt=formatted_prompts,
-                template=None,
+            # Call LLM provider (synchronous - run in executor to avoid blocking)
+            import asyncio
+
+            loop = asyncio.get_event_loop()
+            responses = await loop.run_in_executor(
+                None,
+                lambda: self.llm_provider.generate_text(
+                    user_id=self.user_id,
+                    prompt=formatted_prompts,
+                    template=None,
+                ),
             )

             # Extract scores from responses
@@ -461,3 +467,131 @@ async def rerank_async(
         Async version of rerank - SimpleReranker doesn't need concurrency, just wraps sync method.
         """
         return self.rerank(query, results, top_k)
+
+
+class CrossEncoderReranker(BaseReranker):
+    """Fast cross-encoder reranker using sentence-transformers.
+
+    Production-grade reranker that uses a cross-encoder model to score
+    query-document pairs. Much faster than LLM-based reranking (~100ms vs 20-30s).
+
+    Models:
+        - cross-encoder/ms-marco-MiniLM-L-12-v2: Best accuracy (12 layers)
+        - cross-encoder/ms-marco-MiniLM-L-6-v2: Faster, good accuracy (6 layers)
+        - cross-encoder/ms-marco-TinyBERT-L-2-v2: Fastest, decent accuracy (2 layers)
+    """
+
+    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
+        """
+        Initialize cross-encoder reranker.
+
+        Args:
+            model_name: HuggingFace model name for cross-encoder
+        """
+        from sentence_transformers import CrossEncoder
+
+        self.model_name = model_name
+        logger.info("Loading cross-encoder model: %s", model_name)
+        start_time = time.time()
+        self.model = CrossEncoder(model_name)
+        load_time = time.time() - start_time
+        logger.info("Cross-encoder loaded in %.2fs", load_time)
+
+    def rerank(
+        self,
+        query: str,
+        results: list[QueryResult],
+        top_k: int | None = None,
+    ) -> list[QueryResult]:
+        """
+        Rerank results using cross-encoder model.
+
+        Cross-encoders score query-document pairs directly, providing more accurate
+        relevance scoring than bi-encoder cosine similarity. This is the industry
+        standard for production reranking (used by OpenAI, Anthropic, Cohere, etc.).
+
+        Args:
+            query: The search query
+            results: List of QueryResult objects to rerank
+            top_k: Optional number of top results to return (defaults to len(results))
+
+        Returns:
+            Reranked list of QueryResult objects with updated scores
+
+        Raises:
+            ValueError: If model prediction fails
+        """
+        if not results:
+            logger.debug("No results to rerank")
+            return []
+
+        if top_k is None:
+            top_k = len(results)
+
+        logger.debug(
+            "Reranking %d results with cross-encoder (top_k=%d, model=%s)",
+            len(results),
+            top_k,
+            self.model_name,
+        )
+
+        # Create query-document pairs for cross-encoder
+        start_time = time.time()
+        pairs = [[query, result.chunk.text if result.chunk and result.chunk.text else ""] for result in results]
+
+        # Score all pairs with cross-encoder (fast: ~100ms for 20 docs)
+        try:
+            scores = self.model.predict(pairs)
+        except Exception as e:
+            logger.error("Cross-encoder prediction failed: %s", e)
+            raise ValueError(f"Reranking failed for model {self.model_name}: {e}") from e
+
+        rerank_time = time.time() - start_time
+
+        # Combine results with scores (strict=True for safety)
+        scored_results = list(zip(results, scores, strict=True))
+
+        # Sort by cross-encoder scores (descending)
+        sorted_results = sorted(scored_results, key=lambda x: x[1], reverse=True)
+
+        # Update QueryResult scores with cross-encoder scores
+        # Note: QueryResult schema only has chunk, score, embeddings
+        # Collection info is preserved in the chunk object
+        reranked_results = []
+        for result, ce_score in sorted_results:
+            new_result = QueryResult(
+                chunk=result.chunk,
+                score=float(ce_score),  # Convert numpy float to Python float
+                embeddings=result.embeddings,
+            )
+            reranked_results.append(new_result)
+
+        # Return top_k results
+        final_results = reranked_results[:top_k]
+
+        logger.info(
+            "Reranked %d results → %d results in %.3fs (model=%s)",
+            len(results),
+            len(final_results),
+            rerank_time,
+            self.model_name,
+        )
+
+        return final_results
+
+    async def rerank_async(
+        self,
+        query: str,
+        results: list[QueryResult],
+        top_k: int | None = None,
+    ) -> list[QueryResult]:
+        """
+        Async version of rerank.
+
+        Cross-encoder inference is CPU-bound and relatively fast (~100ms),
+        so we run it in an executor to avoid blocking the event loop.
+        """
+        import asyncio
+
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(None, self.rerank, query, results, top_k)
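A hedged usage sketch of the new class — `candidates` is a hypothetical placeholder for a list of QueryResult from first-stage retrieval, and the field accesses follow the code above:

```python
from rag_solution.retrieval.reranker import CrossEncoderReranker

# First instantiation downloads/loads the model (~7s per the commit message);
# subsequent loads hit the local HuggingFace cache (~1s).
reranker = CrossEncoderReranker()  # default: cross-encoder/ms-marco-MiniLM-L-6-v2

# candidates: list[QueryResult] returned by the vector store (not shown here).
top = reranker.rerank(query="What is the refund policy?", results=candidates, top_k=5)
for result in top:
    # result.score now holds the cross-encoder relevance score, not the
    # original bi-encoder similarity.
    print(f"{result.score:.3f}  {result.chunk.text[:80]}")
```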

backend/rag_solution/schemas/conversation_schema.py

Lines changed: 11 additions & 11 deletions
@@ -4,7 +4,7 @@
 context management, and question suggestions.
 """

-from datetime import datetime
+from datetime import UTC, datetime
 from enum import Enum
 from typing import Any
 from uuid import uuid4
@@ -132,9 +132,9 @@ def to_output( # pylint: disable=too-many-arguments,too-many-positional-argumen
     ) -> "ConversationSessionOutput":
         """Convert input to output schema using Pydantic 2+ model validation."""
         if created_at is None:
-            created_at = datetime.utcnow()
+            created_at = datetime.now(UTC)
         if updated_at is None:
-            updated_at = datetime.utcnow()
+            updated_at = datetime.now(UTC)

         # Use model_dump() to get all input data, then update with additional fields
         data = self.model_dump()
@@ -163,8 +163,8 @@ class ConversationSessionOutput(BaseModel):
     max_messages: int = Field(..., description="Maximum number of messages")
     is_archived: bool = Field(default=False, description="Whether the session is archived")
     is_pinned: bool = Field(default=False, description="Whether the session is pinned")
-    created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation timestamp")
-    updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update timestamp")
+    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), description="Creation timestamp")
+    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC), description="Last update timestamp")
     metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
     message_count: int = Field(default=0, description="Number of messages in the session")

@@ -234,7 +234,7 @@ class ConversationMessageInput(BaseModel):
     """Input schema for conversation messages."""

     session_id: UUID4 = Field(..., description="ID of the session")
-    content: str = Field(..., min_length=1, max_length=10000, description="Message content")
+    content: str = Field(..., min_length=1, max_length=100000, description="Message content")
     role: MessageRole = Field(..., description="Role of the message sender")
     message_type: MessageType = Field(..., description="Type of message")
     metadata: MessageMetadata | dict[str, Any] | None = Field(default=None, description="Message metadata")
@@ -246,7 +246,7 @@ class ConversationMessageInput(BaseModel):
     def to_output(self, message_id: UUID4, created_at: datetime | None = None) -> "ConversationMessageOutput":
         """Convert input to output schema using Pydantic 2+ model validation."""
         if created_at is None:
-            created_at = datetime.utcnow()
+            created_at = datetime.now(UTC)

         # Use model_dump() to get all input data, then update with additional fields
         data = self.model_dump()
@@ -263,7 +263,7 @@ class ConversationMessageOutput(BaseModel):
     content: str = Field(..., description="Message content")
     role: MessageRole = Field(..., description="Role of the message sender")
     message_type: MessageType = Field(..., description="Type of message")
-    created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation timestamp")
+    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), description="Creation timestamp")
     metadata: MessageMetadata | None = Field(default=None, description="Message metadata")
     token_count: int | None = Field(default=None, description="Token count for this message")
     execution_time: float | None = Field(default=None, description="Execution time in seconds")
@@ -406,7 +406,7 @@ class ExportOutput(BaseModel):
     session_data: ConversationSessionOutput = Field(..., description="Session information")
     messages: list[ConversationMessageOutput] = Field(..., description="All messages in session")
     export_format: ExportFormat = Field(..., description="Format of the export")
-    export_timestamp: datetime = Field(default_factory=datetime.utcnow, description="Export timestamp")
+    export_timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC), description="Export timestamp")
     metadata: dict[str, Any] = Field(default_factory=dict, description="Export metadata")

@@ -463,7 +463,7 @@ class ConversationSummaryOutput(BaseModel):
     important_decisions: list[str] = Field(default_factory=list, description="Important decisions made")
     unresolved_questions: list[str] = Field(default_factory=list, description="Questions still unresolved")
     summary_strategy: SummarizationStrategy = Field(..., description="Strategy used for summarization")
-    created_at: datetime = Field(default_factory=datetime.utcnow, description="Summary creation timestamp")
+    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), description="Summary creation timestamp")
     metadata: dict[str, Any] = Field(default_factory=dict, description="Additional summary metadata")

     @classmethod
@@ -609,7 +609,7 @@ class ConversationExportOutput(BaseModel):
     messages: list[ConversationMessageOutput] = Field(..., description="Exported messages")
     summaries: list[ConversationSummaryOutput] = Field(default_factory=list, description="Conversation summaries")
     export_format: ExportFormat = Field(..., description="Format of the export")
-    export_timestamp: datetime = Field(default_factory=datetime.utcnow, description="Export timestamp")
+    export_timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC), description="Export timestamp")
     total_messages: int = Field(..., ge=0, description="Total number of messages exported")
     total_tokens: int = Field(default=0, ge=0, description="Total tokens in exported content")
     file_size_bytes: int = Field(default=0, ge=0, description="Size of exported file in bytes")
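The `lambda` wrapper in these hunks deserves a note: `default_factory` must be a zero-argument callable, and `datetime.now(UTC)` takes an argument, so it has to be wrapped (whereas the deprecated `datetime.utcnow` could be passed bare). A minimal standalone example with a hypothetical model:

```python
from datetime import UTC, datetime

from pydantic import BaseModel, Field


class Stamped(BaseModel):
    # datetime.utcnow (deprecated in Python 3.12, returns a *naive* datetime)
    # could be passed directly as the factory; datetime.now(UTC) needs an
    # argument, hence the lambda. The result is timezone-aware.
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))


stamp = Stamped()
assert stamp.created_at.tzinfo is not None  # timezone-aware timestamp
```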

backend/rag_solution/services/conversation_service.py

Lines changed: 5 additions & 5 deletions
@@ -6,7 +6,7 @@

 import logging
 import re
-from datetime import datetime, timedelta
+from datetime import UTC, datetime, timedelta
 from typing import Any
 from uuid import UUID

@@ -788,7 +788,7 @@ async def get_session_statistics(self, session_id: UUID, user_id: UUID) -> Sessi
             cot_usage_count=cot_usage_count,
             context_enhancement_count=context_enhancement_count,
             created_at=session.created_at,
-            last_activity=datetime.utcnow(),
+            last_activity=datetime.now(UTC),
             metadata={
                 "total_llm_calls": total_llm_calls,
                 "cot_token_count": cot_token_count,
@@ -811,7 +811,7 @@ async def export_session(self, session_id: UUID, user_id: UUID, export_format: s
             "session_data": session,
             "messages": messages,
             "export_format": export_format,
-            "export_timestamp": datetime.utcnow(),
+            "export_timestamp": datetime.now(UTC),
             "metadata": {"cot_integration": True, "context_enhancement": True},
         }

@@ -1198,7 +1198,7 @@ def cleanup_expired_sessions(self) -> int:
         """Clean up expired sessions and return count of cleaned sessions."""

         # Sessions expire after 7 days of inactivity
-        expiry_date = datetime.utcnow() - timedelta(days=7)
+        expiry_date = datetime.now(UTC) - timedelta(days=7)

         expired_sessions = (
             self.db.query(ConversationSession)
@@ -1382,7 +1382,7 @@ async def generate_conversation_summary(self, session_id: UUID, user_id: UUID, s
             "topics": list(topics)[:10],  # Limit to top 10 topics
             "total_tokens": stats.total_tokens,
             "cot_usage_count": stats.cot_usage_count,
-            "generated_at": datetime.utcnow().isoformat(),
+            "generated_at": datetime.now(UTC).isoformat(),
         }

     def _generate_brief_summary(
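One migration caveat this file illustrates: aware and naive datetimes cannot be compared, so the switch to `datetime.now(UTC)` has to be applied consistently everywhere timestamps are compared (such as the expiry check above). A small demonstration of what a half-migrated codebase would hit:

```python
from datetime import UTC, datetime, timedelta

expiry_date = datetime.now(UTC) - timedelta(days=7)  # timezone-aware

naive = datetime(2024, 1, 1)  # no tzinfo, like the old datetime.utcnow()
try:
    naive < expiry_date
except TypeError as exc:
    # "can't compare offset-naive and offset-aware datetimes" — the reason
    # all 13 call sites (9 in the schema, 4 in the service) changed together.
    print(exc)
```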

backend/rag_solution/services/pipeline_service.py

Lines changed: 15 additions & 3 deletions
@@ -148,7 +148,7 @@ def get_reranker(self, user_id: UUID4) -> BaseReranker | None:
             user_id: User UUID for creating LLM-based reranker

         Returns:
-            Reranker instance (LLMReranker or SimpleReranker), or None if disabled
+            Reranker instance (CrossEncoderReranker, LLMReranker or SimpleReranker), or None if disabled
         """
         if not self.settings.enable_reranking:
             return None
@@ -157,10 +157,22 @@ def get_reranker(self, user_id: UUID4) -> BaseReranker | None:

         # pylint: disable=import-outside-toplevel
         # Justification: Lazy import to avoid circular dependency
-        from rag_solution.retrieval.reranker import LLMReranker, SimpleReranker
+        from rag_solution.retrieval.reranker import CrossEncoderReranker, LLMReranker, SimpleReranker
         from rag_solution.schemas.prompt_template_schema import PromptTemplateType

-        if self.settings.reranker_type == "llm":
+        if self.settings.reranker_type == "cross-encoder":
+            try:
+                logger.debug("Creating cross-encoder reranker for user %s", user_id)
+                reranker = CrossEncoderReranker(model_name=self.settings.cross_encoder_model)
+                logger.debug("Cross-encoder reranker created successfully for user %s", user_id)
+                return reranker
+            except Exception as e:  # pylint: disable=broad-exception-caught
+                # Justification: Fallback to simple reranker for any initialization error
+                logger.warning(
+                    "Failed to create cross-encoder reranker for user %s: %s, using simple reranker", user_id, e
+                )
+                return SimpleReranker()
+        elif self.settings.reranker_type == "llm":
             try:
                 # Get LLM provider
                 provider_config = self.llm_provider_service.get_default_provider()
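The control flow here is a fail-open factory: a reranker that cannot initialize (missing model download, absent dependency) degrades to the cheap baseline instead of failing the whole query. Reduced to its skeleton — a sketch with names mirroring the diff, not the full method:

```python
def pick_reranker(settings):
    if not settings.enable_reranking:
        return None  # reranking disabled entirely
    if settings.reranker_type == "cross-encoder":
        try:
            return CrossEncoderReranker(model_name=settings.cross_encoder_model)
        except Exception:
            # Fail open: initialization problems should not break search.
            return SimpleReranker()
    # ... the "llm" branch builds an LLMReranker from the user's provider config ...
    return SimpleReranker()
```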

backend/rag_solution/services/user_provider_service.py

Lines changed: 3 additions & 0 deletions
@@ -124,6 +124,9 @@ def _create_default_rag_template(self, user_id: UUID4) -> PromptTemplateOutput:
             template_type=PromptTemplateType.RAG_QUERY,
             system_prompt=(
                 "You are a helpful AI assistant specializing in answering questions based on the given context. "
+                "Answer ONLY the user's question that is provided. "
+                "Do not generate additional questions or topics. "
+                "Provide a single, focused, concise answer based on the context.\n\n"
                 "Format your responses using Markdown for better readability:\n"
                 "- Use **bold** for emphasis on key points\n"
                 "- Use bullet points (- or *) for lists\n"
