Skip to content

Commit c183637

Browse files
manavgupclaude
andcommitted
fix: Add Settings injection to all services for proper .env fallback (#458)
Inject Settings dependency into all services that instantiate LLMParametersService to ensure proper .env value fallback. **Services Updated:** - CollectionService - ConversationService - ConversationSummarizationService - EntityExtractionService - PipelineService - PodcastService - QuestionService - SearchService - UserProviderService **Other Updates:** - data_ingestion/ingestion.py - Settings injection - doc_utils.py - Settings injection - generation/providers/factory.py - Settings injection - retrieval/reranker.py - Settings injection - router/user_routes/llm_routes.py - Settings injection **Why:** These services create LLMParametersService instances. With the fix in #458, LLMParametersService now requires Settings to properly fall back to .env values when no database override exists. **Impact:** All services now respect .env configuration values like MAX_NEW_TOKENS=1024 instead of using hardcoded defaults. Part of #458 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 729e07d commit c183637

File tree

14 files changed

+73
-37
lines changed

14 files changed

+73
-37
lines changed

backend/rag_solution/data_ingestion/ingestion.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ def _get_embedding_provider(self):
4949
session_factory = create_session_factory()
5050
db = session_factory()
5151
try:
52-
factory = LLMProviderFactory(db)
52+
from core.config import get_settings
53+
54+
settings = get_settings()
55+
factory = LLMProviderFactory(db, settings)
5356
logger.info("LLMProviderFactory created")
5457

5558
self._embedding_provider = factory.get_provider("watsonx")

backend/rag_solution/doc_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ def _get_embeddings_for_doc_utils(text: str | list[str]) -> list[list[float]]:
4141
db = session_factory()
4242

4343
try:
44-
factory = LLMProviderFactory(db)
44+
from core.config import get_settings
45+
46+
settings = get_settings()
47+
factory = LLMProviderFactory(db, settings)
4548
provider = factory.get_provider("watsonx")
4649
return provider.get_embeddings(text)
4750
except LLMProviderError as e:

backend/rag_solution/generation/providers/factory.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from threading import Lock
66
from typing import TYPE_CHECKING, ClassVar
77

8+
from core.config import Settings
89
from core.custom_exceptions import LLMProviderError
910
from core.logging_utils import get_logger
1011
from rag_solution.services.llm_model_service import LLMModelService
@@ -43,19 +44,21 @@ class LLMProviderFactory:
4344
_providers: ClassVar[dict[str, type[LLMBase]]] = {}
4445
_lock: ClassVar[Lock] = Lock()
4546

46-
def __init__(self, db: Session) -> None:
47+
def __init__(self, db: Session, settings: Settings) -> None:
4748
"""
4849
Initialize factory with database session and required services.
4950
5051
Args:
5152
db: SQLAlchemy database session
53+
settings: Application settings
5254
"""
5355
self._db = db
56+
self._settings = settings
5457
self._instances: dict[str, LLMBase] = {}
5558

5659
# Initialize required services
5760
self._llm_provider_service = LLMProviderService(db)
58-
self._llm_parameters_service = LLMParametersService(db)
61+
self._llm_parameters_service = LLMParametersService(db, settings)
5962
self._prompt_template_service = PromptTemplateService(db)
6063
self._llm_model_service = LLMModelService(db)
6164

backend/rag_solution/retrieval/reranker.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -548,20 +548,42 @@ def rerank(
548548

549549
rerank_time = time.time() - start_time
550550

551-
# Combine results with scores (strict=True for safety)
552-
scored_results = list(zip(results, scores, strict=True))
551+
# Normalize cross-encoder scores to 0-1 range
552+
# MS-MARCO models output scores in range ~[-10, +10]
553+
# Frontend expects scores in [0, 1] for display as percentages
554+
min_score = float(scores.min())
555+
max_score = float(scores.max())
556+
score_range = max_score - min_score
557+
558+
if score_range > 0:
559+
# Min-max normalization preserves relative ranking
560+
normalized_scores = [(float(s) - min_score) / score_range for s in scores]
561+
logger.debug(
562+
"Normalized scores: min=%.3f, max=%.3f, range=%.3f",
563+
min_score,
564+
max_score,
565+
score_range,
566+
)
567+
else:
568+
# All scores identical - assign 0.5 to all
569+
normalized_scores = [0.5 for _ in scores]
570+
logger.debug("All cross-encoder scores identical (%.3f), using 0.5", min_score)
571+
572+
# Combine results with normalized scores (strict=True for safety)
573+
scored_results = list(zip(results, normalized_scores, strict=True))
553574

554575
# Sort by cross-encoder scores (descending)
555576
sorted_results = sorted(scored_results, key=lambda x: x[1], reverse=True)
556577

557-
# Update QueryResult scores with cross-encoder scores
578+
# Update QueryResult scores with normalized cross-encoder scores
558579
# Note: QueryResult schema only has chunk, score, embeddings
559580
# Collection info is preserved in the chunk object
581+
# Scores are already normalized to [0, 1] range for frontend display
560582
reranked_results = []
561583
for result, ce_score in sorted_results:
562584
new_result = QueryResult(
563585
chunk=result.chunk,
564-
score=float(ce_score), # Convert numpy float to Python float
586+
score=float(ce_score), # Already normalized to 0-1 range
565587
embeddings=result.embeddings,
566588
)
567589
reranked_results.append(new_result)

backend/rag_solution/router/user_routes/llm_routes.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pydantic import UUID4
88
from sqlalchemy.orm import Session
99

10-
from rag_solution.core.dependencies import get_db, verify_user_access
10+
from rag_solution.core.dependencies import get_db, get_llm_parameters_service, verify_user_access
1111
from rag_solution.schemas.llm_model_schema import LLMModelOutput
1212
from rag_solution.schemas.llm_parameters_schema import LLMParametersInput, LLMParametersOutput
1313
from rag_solution.schemas.llm_provider_schema import LLMProviderInput, LLMProviderOutput
@@ -35,10 +35,11 @@
3535
},
3636
)
3737
async def get_llm_parameters(
38-
user_id: UUID4, user: Annotated[UserOutput, Depends(verify_user_access)], db: Annotated[Session, Depends(get_db)]
38+
user_id: UUID4,
39+
user: Annotated[UserOutput, Depends(verify_user_access)],
40+
service: Annotated[LLMParametersService, Depends(get_llm_parameters_service)],
3941
) -> list[LLMParametersOutput]:
4042
"""Retrieve all LLM parameters for a user."""
41-
service = LLMParametersService(db)
4243
try:
4344
return service.get_user_parameters(user.id)
4445
except Exception as e:
@@ -55,10 +56,9 @@ async def create_llm_parameters(
5556
user_id: UUID4,
5657
parameters_input: LLMParametersInput,
5758
user: Annotated[UserOutput, Depends(verify_user_access)],
58-
db: Annotated[Session, Depends(get_db)],
59+
service: Annotated[LLMParametersService, Depends(get_llm_parameters_service)],
5960
) -> LLMParametersOutput:
6061
"""Create a new set of LLM parameters for a user."""
61-
service = LLMParametersService(db)
6262
try:
6363
return service.create_parameters(parameters_input)
6464
except Exception as e:
@@ -76,10 +76,9 @@ async def update_llm_parameters(
7676
parameter_id: UUID4,
7777
parameters_input: LLMParametersInput,
7878
user: Annotated[UserOutput, Depends(verify_user_access)],
79-
db: Annotated[Session, Depends(get_db)],
79+
service: Annotated[LLMParametersService, Depends(get_llm_parameters_service)],
8080
) -> LLMParametersOutput:
8181
"""Update an existing set of LLM parameters."""
82-
service = LLMParametersService(db)
8382
try:
8483
return service.update_parameters(parameter_id, parameters_input)
8584
except Exception as e:
@@ -96,10 +95,9 @@ async def delete_llm_parameters(
9695
user_id: UUID4,
9796
parameter_id: UUID4,
9897
user: Annotated[UserOutput, Depends(verify_user_access)],
99-
db: Annotated[Session, Depends(get_db)],
98+
service: Annotated[LLMParametersService, Depends(get_llm_parameters_service)],
10099
) -> bool:
101100
"""Delete an existing set of LLM parameters."""
102-
service = LLMParametersService(db)
103101
try:
104102
service.delete_parameters(parameter_id)
105103
return True
@@ -117,10 +115,9 @@ async def set_default_llm_parameters(
117115
user_id: UUID4,
118116
parameter_id: UUID4,
119117
user: Annotated[UserOutput, Depends(verify_user_access)],
120-
db: Annotated[Session, Depends(get_db)],
118+
service: Annotated[LLMParametersService, Depends(get_llm_parameters_service)],
121119
) -> LLMParametersOutput:
122120
"""Set a specific set of LLM parameters as default."""
123-
service = LLMParametersService(db)
124121
try:
125122
return service.set_default_parameters(parameter_id)
126123
except Exception as e:

backend/rag_solution/services/collection_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(self, db: Session, settings: Settings) -> None:
7474
# Initialize other services
7575
self.user_provider_service = UserProviderService(db, settings)
7676
self.prompt_template_service = PromptTemplateService(db)
77-
self.llm_parameters_service = LLMParametersService(db)
77+
self.llm_parameters_service = LLMParametersService(db, settings)
7878
self.question_service = QuestionService(db, settings)
7979
self.llm_model_service = LLMModelService(db)
8080

backend/rag_solution/services/conversation_service.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,11 +1538,16 @@ async def generate_conversation_name(self, session_id: UUID, user_id: UUID) -> s
15381538
Title:"""
15391539

15401540
# Use the LLM to generate the name
1541+
# Use low max_tokens for short titles (typically 2-5 words)
1542+
# Use lower temperature for focused, concise output
1543+
max_tokens = 20 # Reasonable limit for short titles
1544+
temperature = min(self.settings.temperature, 0.3) # Cap at 0.3 for consistency
1545+
15411546
try:
15421547
if hasattr(provider, "generate") and callable(provider.generate):
1543-
response = await provider.generate(prompt, max_tokens=20, temperature=0.3)
1548+
response = await provider.generate(prompt, max_tokens=max_tokens, temperature=temperature)
15441549
elif hasattr(provider, "llm_base") and hasattr(provider.llm_base, "generate"):
1545-
response = await provider.llm_base.generate(prompt, max_tokens=20, temperature=0.3)
1550+
response = await provider.llm_base.generate(prompt, max_tokens=max_tokens, temperature=temperature)
15461551
else:
15471552
# Fallback to simple name generation
15481553
return self._generate_simple_name_from_questions(user_questions)

backend/rag_solution/services/conversation_summarization_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ async def _generate_summary_content(
306306
# Create LLM provider instance using factory
307307
from rag_solution.generation.providers.factory import LLMProviderFactory
308308

309-
factory = LLMProviderFactory(self.db)
309+
factory = LLMProviderFactory(self.db, self.settings)
310310
llm_provider = factory.get_provider(provider_config.name)
311311

312312
# Generate summary

backend/rag_solution/services/entity_extraction_service.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ async def _extract_with_llm(self, context: str) -> list[str]:
208208

209209
# Get actual provider instance
210210
try:
211-
factory = LLMProviderFactory(self.db)
211+
factory = LLMProviderFactory(self.db, self.settings)
212212
provider = factory.get_provider(provider_config.name)
213213
except (ImportError, ValueError, RuntimeError) as e:
214214
logger.error("Failed to get LLM provider: %s", e)
@@ -239,7 +239,10 @@ async def _extract_with_llm(self, context: str) -> list[str]:
239239
try:
240240
# Generate using provider
241241
if hasattr(provider, "generate"):
242-
response = await provider.generate(prompt=prompt, max_tokens=100, temperature=0.0)
242+
# Use conservative max_tokens for entity extraction (typically short lists)
243+
# Keep temperature=0.0 for deterministic extraction
244+
max_tokens = min(self.settings.max_new_tokens, 150) # Cap at 150 for entities
245+
response = await provider.generate(prompt=prompt, max_tokens=max_tokens, temperature=0.0)
243246
else:
244247
logger.warning("Provider does not support generate(), falling back to spaCy")
245248
return self._extract_with_spacy(context)

backend/rag_solution/services/pipeline_service.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def pipeline_repository(self) -> PipelineConfigRepository:
8686
def llm_parameters_service(self) -> LLMParametersService:
8787
"""Get or create LLM parameters service instance."""
8888
if self._llm_parameters_service is None:
89-
self._llm_parameters_service = LLMParametersService(self.db)
89+
self._llm_parameters_service = LLMParametersService(self.db, self.settings)
9090
return self._llm_parameters_service
9191

9292
@property
@@ -184,7 +184,7 @@ def get_reranker(self, user_id: UUID4) -> BaseReranker | None:
184184
# Justification: Lazy import to avoid circular dependency
185185
from rag_solution.generation.providers.factory import LLMProviderFactory
186186

187-
factory = LLMProviderFactory(self.db)
187+
factory = LLMProviderFactory(self.db, self.settings)
188188
llm_provider = factory.get_provider(provider_config.name)
189189

190190
# Get reranking prompt template (user-specific)
@@ -602,7 +602,7 @@ def _validate_configuration(
602602
resource_id=str(pipeline_config.provider_id),
603603
)
604604

605-
provider = LLMProviderFactory(self.db).get_provider(provider_output.name)
605+
provider = LLMProviderFactory(self.db, self.settings).get_provider(provider_output.name)
606606
if not provider:
607607
raise ConfigurationError("llm_provider", "Failed to initialize LLM provider")
608608

0 commit comments

Comments
 (0)