Skip to content

Commit

Permalink
Fix db session used in different threads
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Nov 4, 2024
1 parent f66139a commit 5a668d4
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/backend/base/langflow/api/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


@router.post("/login", response_model=Token)
async def login_to_get_access_token(
def login_to_get_access_token(
response: Response,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: DbSession,
Expand Down
47 changes: 24 additions & 23 deletions src/backend/base/langflow/services/auth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from langflow.services.database.models.api_key.model import ApiKey
from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at
from langflow.services.database.models.user.model import User, UserRead
from langflow.services.deps import get_session, get_settings_service
from langflow.services.deps import get_db_service, get_session, get_settings_service
from langflow.services.settings.service import SettingsService

oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)
Expand All @@ -36,31 +36,32 @@
def api_key_security(
query_param: Annotated[str, Security(api_key_query)],
header_param: Annotated[str, Security(api_key_header)],
db: Annotated[Session, Depends(get_session)],
) -> UserRead | None:
settings_service = get_settings_service()
result: ApiKey | User | None = None
if settings_service.auth_settings.AUTO_LOGIN:
# Get the first user
if not settings_service.auth_settings.SUPERUSER:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing first superuser credentials",
)

result = get_user_by_username(db, settings_service.auth_settings.SUPERUSER)
with get_db_service().with_session() as db:
if settings_service.auth_settings.AUTO_LOGIN:
# Get the first user
if not settings_service.auth_settings.SUPERUSER:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing first superuser credentials",
)

elif not query_param and not header_param:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="An API key must be passed as query or header",
)
result = get_user_by_username(db, settings_service.auth_settings.SUPERUSER)

elif not query_param and not header_param:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="An API key must be passed as query or header",
)

elif query_param:
result = check_key(db, query_param)
elif query_param:
result = check_key(db, query_param)

else:
result = check_key(db, header_param)
else:
result = check_key(db, header_param)

if not result:
raise HTTPException(
Expand All @@ -83,7 +84,7 @@ async def get_current_user(
) -> User:
if token:
return await get_current_user_by_jwt(token, db)
user = await asyncio.to_thread(api_key_security, query_param, header_param, db)
user = await asyncio.to_thread(api_key_security, query_param, header_param)
if user:
return user

Expand Down Expand Up @@ -164,17 +165,17 @@ async def get_current_user_for_websocket(
if token:
return await get_current_user_by_jwt(token, db)
if api_key:
return await asyncio.to_thread(api_key_security, api_key, query_param, db)
return await asyncio.to_thread(api_key_security, api_key, query_param)
return None


def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
if not current_user.is_active:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
return current_user


def get_current_active_superuser(current_user: Annotated[User, Depends(get_current_user)]) -> User:
async def get_current_active_superuser(current_user: Annotated[User, Depends(get_current_user)]) -> User:
if not current_user.is_active:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
if not current_user.is_superuser:
Expand Down
6 changes: 5 additions & 1 deletion src/backend/base/langflow/services/database/service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import time
from contextlib import contextmanager
from datetime import datetime, timezone
Expand Down Expand Up @@ -317,7 +318,7 @@ def create_db_and_tables(self) -> None:

logger.debug("Database and tables created successfully")

async def teardown(self) -> None:
def _teardown(self) -> None:
logger.debug("Tearing down database")
try:
settings_service = get_settings_service()
Expand All @@ -330,3 +331,6 @@ async def teardown(self) -> None:
logger.exception("Error tearing down database")

self.engine.dispose()

async def teardown(self) -> None:
await asyncio.to_thread(self._teardown)
9 changes: 8 additions & 1 deletion src/backend/base/langflow/services/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

from loguru import logger
from sqlmodel import Session, select

Expand Down Expand Up @@ -110,10 +112,15 @@ def teardown_superuser(settings_service, session) -> None:
raise RuntimeError(msg) from exc


def _teardown_superuser():
with get_session() as session:
teardown_superuser(get_settings_service(), session)


async def teardown_services() -> None:
"""Teardown all the services."""
try:
teardown_superuser(get_settings_service(), next(get_session()))
await asyncio.to_thread(_teardown_superuser)
except Exception as exc: # noqa: BLE001
logger.exception(exc)
try:
Expand Down

0 comments on commit 5a668d4

Please sign in to comment.