Merged — showing changes from all commits.
2 changes: 1 addition & 1 deletion — service/app/agents/factory.py

```diff
@@ -73,7 +73,7 @@ async def create_chat_agent(
 
     # Create LangChain model WITH provider-side web search binding.
     # This ensures OpenAI gets `web_search_preview` and Gemini/Vertex gets `google_search`.
-    llm: BaseChatModel = user_provider_manager.create_langchain_model(
+    llm: BaseChatModel = await user_provider_manager.create_langchain_model(
         provider_id,
         model=model_name,
         google_search_enabled=google_search_enabled,
```
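`create_langchain_model` has evidently become a coroutine: this PR adds `await` at both of its call sites, here and in `topic_generator.py` at the bottom of the diff. A distilled sketch of the failure mode the fix addresses (not verbatim app code):

```python
# Sketch of the bug this hunk fixes — names come from the diff above.
llm = user_provider_manager.create_langchain_model(provider_id, model=model_name)
# Without `await`, `llm` is a coroutine object rather than a BaseChatModel,
# so the first use fails, e.g.:
#   AttributeError: 'coroutine' object has no attribute 'ainvoke'
# (and Python emits "coroutine ... was never awaited").

# Fixed — await the now-async factory so `llm` is the constructed model:
llm = await user_provider_manager.create_langchain_model(provider_id, model=model_name)
```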
13 changes: 13 additions & 0 deletions — service/app/api/v1/files.py

```diff
@@ -24,6 +24,7 @@
 from app.middleware.auth import get_current_user
 from app.models.file import FileCreate, FileRead, FileReadWithUrl, FileUpdate
 from app.repos.file import FileRepository
+from app.repos.knowledge_set import KnowledgeSetRepository
 
 logger = logging.getLogger(__name__)
 
@@ -41,6 +42,7 @@ async def upload_file(
     scope: str = Form(FileScope.PRIVATE),
     category: str | None = Form(None),
     folder_id: UUID | None = Form(None),
+    knowledge_set_id: UUID | None = Form(None),
     user_id: str = Depends(get_current_user),
     storage: StorageServiceProto = Depends(get_storage_service),
     db: AsyncSession = Depends(get_session),
@@ -133,6 +135,17 @@ async def upload_file(
     )
 
     file_record = await file_repo.create_file(file_create)
+
+    # Link to knowledge set if provided
+    if knowledge_set_id:
+        try:
+            ks_repo = KnowledgeSetRepository(db)
+            await ks_repo.validate_access(user_id, knowledge_set_id)
+            await ks_repo.link_file_to_knowledge_set(file_record.id, knowledge_set_id)
+        except ValueError as e:
+            logger.warning(f"Failed to link file to knowledge set during upload: {e}")
+            # Don't fail the whole upload if linking fails due to access/existence
+
     await db.commit()
     await db.refresh(file_record)
```
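The upload endpoint gains an optional `knowledge_set_id` form field. Linking is deliberately best-effort: an inaccessible or missing knowledge set logs a warning instead of failing the upload, while a successful link lands in the same `db.commit()` as the file record itself. A client-side sketch — the route path and bearer-token auth are assumptions, since only the form fields appear in this diff:

```python
# Hypothetical client call; only `file` and `knowledge_set_id` come from the
# diff above, the URL and auth scheme are assumed for illustration.
import httpx


async def upload_into_knowledge_set(token: str, knowledge_set_id: str) -> dict:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        with open("notes.pdf", "rb") as f:
            resp = await client.post(
                "/api/v1/files",  # assumed path
                headers={"Authorization": f"Bearer {token}"},
                files={"file": ("notes.pdf", f, "application/pdf")},
                data={"knowledge_set_id": knowledge_set_id},  # new in this PR
            )
        resp.raise_for_status()
        return resp.json()
```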
13 changes: 8 additions & 5 deletions — service/app/api/v1/knowledge_sets.py

```diff
@@ -12,6 +12,7 @@
     KnowledgeSetRead,
     KnowledgeSetUpdate,
     KnowledgeSetWithFileCount,
+    BulkLinkFilesRequest,
 )
 from app.repos.file import FileRepository
 from app.repos.knowledge_set import KnowledgeSetRepository
@@ -333,7 +334,7 @@ async def get_files_in_knowledge_set(
 @router.post("/{knowledge_set_id}/files/bulk-link")
 async def bulk_link_files_to_knowledge_set(
     knowledge_set_id: UUID,
-    file_ids: list[UUID],
+    request: BulkLinkFilesRequest,
     user_id: str = Depends(get_current_user),
     db: AsyncSession = Depends(get_session),
 ) -> dict:
@@ -352,15 +353,17 @@ async def bulk_link_files_to_knowledge_set(
         raise ErrCode.KNOWLEDGE_SET_ACCESS_DENIED.with_messages("Access denied")
 
     # Verify all files exist and user has access
-    for file_id in file_ids:
+    for file_id in request.file_ids:
         file = await file_repo.get_file_by_id(file_id)
         if not file:
             raise ErrCode.FILE_NOT_FOUND.with_messages(f"File {file_id} not found")
         if file.user_id != user_id:
             raise ErrCode.FILE_ACCESS_DENIED.with_messages(f"File {file_id} access denied")
 
     # Bulk link
-    successful, skipped = await knowledge_set_repo.bulk_link_files_to_knowledge_set(file_ids, knowledge_set_id)
+    successful, skipped = await knowledge_set_repo.bulk_link_files_to_knowledge_set(
+        request.file_ids, knowledge_set_id
+    )
     await db.commit()
 
     return {
@@ -382,7 +385,7 @@ async def bulk_link_files_to_knowledge_set(
 @router.post("/{knowledge_set_id}/files/bulk-unlink")
 async def bulk_unlink_files_from_knowledge_set(
     knowledge_set_id: UUID,
-    file_ids: list[UUID],
+    request: BulkLinkFilesRequest,
     user_id: str = Depends(get_current_user),
     db: AsyncSession = Depends(get_session),
 ) -> dict:
@@ -400,7 +403,7 @@ async def bulk_unlink_files_from_knowledge_set(
         raise ErrCode.KNOWLEDGE_SET_ACCESS_DENIED.with_messages("Access denied")
 
     # Bulk unlink
-    count = await knowledge_set_repo.bulk_unlink_files_from_knowledge_set(file_ids, knowledge_set_id)
+    count = await knowledge_set_repo.bulk_unlink_files_from_knowledge_set(request.file_ids, knowledge_set_id)
     await db.commit()
 
     return {
```
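Both bulk endpoints previously took a bare `list[UUID]` as the JSON body; wrapping it in `BulkLinkFilesRequest` makes the body an extensible object (`{"file_ids": [...]}`) that can grow new fields without breaking clients, and gives link and unlink the same shape. The model itself is imported but not defined in this diff — a minimal sketch, assuming a plain Pydantic model:

```python
# Not part of this diff — a guess at the imported request model.
from uuid import UUID

from pydantic import BaseModel, Field


class BulkLinkFilesRequest(BaseModel):
    """JSON body for bulk link/unlink: {"file_ids": ["<uuid>", ...]}."""

    file_ids: list[UUID] = Field(default_factory=list)
```

One observation for a follow-up: the ownership check still issues one `get_file_by_id` query per id, so a large batch costs O(n) round trips; a single `WHERE id = ANY(:file_ids)` fetch would be cheaper.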
65 changes: 42 additions & 23 deletions — service/app/api/v1/providers.py

```diff
@@ -1,14 +1,13 @@
-from typing import Any, NotRequired, cast
+from typing import Any
 from uuid import UUID
 
 from fastapi import APIRouter, Depends, HTTPException
-from litellm.types.utils import ModelInfo
 from sqlmodel.ext.asyncio.session import AsyncSession
 
 from app.common.code.error_code import ErrCodeError, handle_auth_error
 from app.configs import configs
 from app.core.auth import AuthorizationService, get_auth_service
-from app.core.llm.service import LiteLLMService
+from app.core.model_registry import ModelInfo, ModelsDevService
 from app.infra.database import get_session
 from app.middleware.auth import get_current_user
 from app.models.provider import ProviderCreate, ProviderRead, ProviderUpdate
@@ -28,8 +27,10 @@ def _sanitize_provider_read(provider: Any) -> ProviderRead:
     return ProviderRead(**provider_dict)
 
 
-class DefaultModelInfo(ModelInfo, total=False):
-    provider_type: NotRequired[str]
+class DefaultModelInfo(ModelInfo):
+    """Extended ModelInfo with provider_type for default model endpoint."""
+
+    provider_type: str | None = None
 
 
 @router.get("/default-model", response_model=DefaultModelInfo)
@@ -52,40 +53,58 @@ async def get_default_model_config() -> DefaultModelInfo:
     model = default_cfg.model
     provider_type = default_provider.value
 
-    model_info = LiteLLMService.get_model_info(model)
+    model_info = await ModelsDevService.get_model_info_for_key(model)
 
     if not model_info:
-        raise HTTPException(
-            status_code=500,
-            detail=f"Failed to get model info for default model: {model}",
+        # Return a basic ModelInfo if not found in models.dev
+        return DefaultModelInfo(
+            key=model,
+            provider_type=provider_type,
         )
 
-    # Add key and provider_type to ModelInfo
-    result: dict[str, Any] = dict(model_info)
-    result["key"] = model
-    result["provider_type"] = provider_type
-    # Add supported_openai_params if missing (required by ModelInfo)
-    if "supported_openai_params" not in result:
-        result["supported_openai_params"] = None
-    return cast(DefaultModelInfo, result)
+    # Create DefaultModelInfo with provider_type
+    return DefaultModelInfo(
+        key=model,
+        name=model_info.name,
+        max_tokens=model_info.max_tokens,
+        max_input_tokens=model_info.max_input_tokens,
+        max_output_tokens=model_info.max_output_tokens,
+        input_cost_per_token=model_info.input_cost_per_token,
+        output_cost_per_token=model_info.output_cost_per_token,
+        litellm_provider=model_info.litellm_provider,
+        mode=model_info.mode,
+        supports_function_calling=model_info.supports_function_calling,
+        supports_parallel_function_calling=model_info.supports_parallel_function_calling,
+        supports_vision=model_info.supports_vision,
+        supports_audio_input=model_info.supports_audio_input,
+        supports_audio_output=model_info.supports_audio_output,
+        supports_reasoning=model_info.supports_reasoning,
+        supports_structured_output=model_info.supports_structured_output,
+        supports_web_search=model_info.supports_web_search,
+        model_family=model_info.model_family,
+        knowledge_cutoff=model_info.knowledge_cutoff,
+        release_date=model_info.release_date,
+        open_weights=model_info.open_weights,
+        provider_type=provider_type,
+    )
 
 
 @router.get("/templates", response_model=dict[str, list[ModelInfo]])
 async def get_provider_templates() -> dict[str, list[ModelInfo]]:
     """
     Get available provider templates with metadata for the UI.
     Returns configuration templates for all supported LLM providers.
-    Dynamically fetches models from LiteLLM instead of using hardcoded config.
+    Dynamically fetches models from models.dev.
     """
-    return LiteLLMService.get_all_providers_with_models()
+    return await ModelsDevService.get_all_providers_with_models()
 
 
 @router.get("/models", response_model=list[str])
 async def get_supported_models() -> list[str]:
     """
-    Get a list of all models supported by the system (via LiteLLM).
+    Get a list of all models supported by the system (via models.dev).
     """
-    return LiteLLMService.list_supported_models()
+    return await ModelsDevService.list_all_models()
 
 
 @router.get("/available-models", response_model=dict[str, list[ModelInfo]])
@@ -112,7 +131,7 @@ async def get_available_models_for_user(
     result: dict[str, list[ModelInfo]] = {}
 
     for provider in providers:
-        models = LiteLLMService.get_models_by_provider(provider.provider_type)
+        models = await ModelsDevService.get_models_by_provider_type(provider.provider_type)
         if models:
             result[str(provider.id)] = models
 
@@ -144,7 +163,7 @@ async def get_provider_available_models(
     except ErrCodeError as e:
         raise handle_auth_error(e)
 
-    models = LiteLLMService.get_models_by_provider(str(provider.provider_type))
+    models = await ModelsDevService.get_models_by_provider_type(str(provider.provider_type))
     return models
```
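This registry swap is the bulk of the PR: every synchronous `LiteLLMService` lookup becomes an awaited `ModelsDevService` call, `ModelInfo` moves from LiteLLM's `TypedDict` to `app.core.model_registry` (presumably a Pydantic model, given the attribute access and the new class-with-defaults `DefaultModelInfo`), and an unknown default model now degrades to a minimal response instead of a 500. The new module is not shown in this diff; the interface implied by its call sites:

```python
# Reconstructed from the call sites above — a sketch, not the real
# app.core.model_registry module. `ModelInfo` is the Pydantic model this
# PR imports in place of LiteLLM's TypedDict.
class ModelsDevService:
    """Async model registry backed by models.dev (interface implied by this diff)."""

    @staticmethod
    async def get_model_info_for_key(model: str) -> "ModelInfo | None": ...

    @staticmethod
    async def get_all_providers_with_models() -> dict[str, list["ModelInfo"]]: ...

    @staticmethod
    async def list_all_models() -> list[str]: ...

    @staticmethod
    async def get_models_by_provider_type(provider_type: str) -> list["ModelInfo"]: ...
```

If `ModelInfo` is indeed a Pydantic model, the 22-line field-by-field copy in `get_default_model_config` could likely collapse to `DefaultModelInfo(**model_info.model_dump(exclude={"key"}), key=model, provider_type=provider_type)`, at the cost of being less explicit about which fields survive.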
2 changes: 1 addition & 1 deletion — service/app/core/chat/topic_generator.py

```diff
@@ -99,7 +99,7 @@ async def generate_and_update_topic_title(
         f"{message_text}"
     )
 
-    llm = user_provider_manager.create_langchain_model(provider_id, model_name)
+    llm = await user_provider_manager.create_langchain_model(provider_id, model_name)
     response = await llm.ainvoke([HumanMessage(content=prompt)])
     logger.debug(f"LLM response: {response}")
```
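Same root cause as the `factory.py` hunk at the top of this PR: with `create_langchain_model` now a coroutine, this was the second un-awaited call site, and the `await llm.ainvoke(...)` on the next line would have raised `AttributeError` on the returned coroutine object.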