diff --git a/src/backend/base/langflow/__main__.py b/src/backend/base/langflow/__main__.py index 95368d5b36ac..99fa7ceabea8 100644 --- a/src/backend/base/langflow/__main__.py +++ b/src/backend/base/langflow/__main__.py @@ -26,7 +26,7 @@ from langflow.logging.logger import configure, logger from langflow.main import setup_app from langflow.services.database.models.folder.utils import create_default_folder_if_it_doesnt_exist -from langflow.services.database.utils import session_getter +from langflow.services.database.utils import async_session_getter from langflow.services.deps import async_session_scope, get_db_service, get_settings_service from langflow.services.settings.constants import DEFAULT_SUPERUSER from langflow.services.utils import initialize_services @@ -419,30 +419,35 @@ def superuser( ) -> None: """Create a superuser.""" configure(log_level=log_level) - initialize_services() db_service = get_db_service() - with session_getter(db_service) as session: - from langflow.services.auth.utils import create_super_user - if create_super_user(db=session, username=username, password=password): - # Verify that the superuser was created - from langflow.services.database.models.user.model import User + async def _create_superuser(): + await initialize_services() + async with async_session_getter(db_service) as session: + from langflow.services.auth.utils import create_super_user + + if await create_super_user(db=session, username=username, password=password): + # Verify that the superuser was created + from langflow.services.database.models.user.model import User + + stmt = select(User).where(User.username == username) + user: User = (await session.exec(stmt)).first() + if user is None or not user.is_superuser: + typer.echo("Superuser creation failed.") + return + # Now create the first folder for the user + result = await create_default_folder_if_it_doesnt_exist(session, user.id) + if result: + typer.echo("Default folder created successfully.") + else: + msg = "Could not create default folder." + raise RuntimeError(msg) + typer.echo("Superuser created successfully.") - user: User = session.exec(select(User).where(User.username == username)).first() - if user is None or not user.is_superuser: - typer.echo("Superuser creation failed.") - return - # Now create the first folder for the user - result = create_default_folder_if_it_doesnt_exist(session, user.id) - if result: - typer.echo("Default folder created successfully.") else: - msg = "Could not create default folder." - raise RuntimeError(msg) - typer.echo("Superuser created successfully.") + typer.echo("Superuser creation failed.") - else: - typer.echo("Superuser creation failed.") + asyncio.run(_create_superuser()) # command to copy the langflow database from the cache to the current directory @@ -494,7 +499,7 @@ def migration( ): raise typer.Abort - initialize_services(fix_migration=fix) + asyncio.run(initialize_services(fix_migration=fix)) db_service = get_db_service() if not test: db_service.run_migrations() @@ -515,18 +520,20 @@ def api_key( None """ configure(log_level=log_level) - initialize_services() - settings_service = get_settings_service() - auth_settings = settings_service.auth_settings - if not auth_settings.AUTO_LOGIN: - typer.echo("Auto login is disabled. API keys cannot be created through the CLI.") - return async def aapi_key(): + await initialize_services() + settings_service = get_settings_service() + auth_settings = settings_service.auth_settings + if not auth_settings.AUTO_LOGIN: + typer.echo("Auto login is disabled. API keys cannot be created through the CLI.") + return None + async with async_session_scope() as session: from langflow.services.database.models.user.model import User - superuser = (await session.exec(select(User).where(User.username == DEFAULT_SUPERUSER))).first() + stmt = select(User).where(User.username == DEFAULT_SUPERUSER) + superuser = (await session.exec(stmt)).first() if not superuser: typer.echo( "Default superuser not found. This command requires a superuser and AUTO_LOGIN to be enabled." @@ -535,7 +542,8 @@ async def aapi_key(): from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate from langflow.services.database.models.api_key.crud import create_api_key, delete_api_key - api_key = (await session.exec(select(ApiKey).where(ApiKey.user_id == superuser.id))).first() + stmt = select(ApiKey).where(ApiKey.user_id == superuser.id) + api_key = (await session.exec(stmt)).first() if api_key: await delete_api_key(session, api_key.id) diff --git a/src/backend/base/langflow/api/v1/api_key.py b/src/backend/base/langflow/api/v1/api_key.py index 09b542fd09d7..5fa0d117ea74 100644 --- a/src/backend/base/langflow/api/v1/api_key.py +++ b/src/backend/base/langflow/api/v1/api_key.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends, HTTPException, Response -from langflow.api.utils import AsyncDbSession, CurrentActiveUser, DbSession +from langflow.api.utils import AsyncDbSession, CurrentActiveUser from langflow.api.v1.schemas import ApiKeyCreateRequest, ApiKeysResponse from langflow.services.auth import utils as auth_utils @@ -62,7 +62,7 @@ async def save_store_api_key( api_key_request: ApiKeyCreateRequest, response: Response, current_user: CurrentActiveUser, - db: DbSession, + db: AsyncDbSession, ): settings_service = get_settings_service() auth_settings = settings_service.auth_settings @@ -74,7 +74,7 @@ async def save_store_api_key( encrypted = auth_utils.encrypt_api_key(api_key, settings_service=settings_service) current_user.store_api_key = encrypted db.add(current_user) - db.commit() + await db.commit() response.set_cookie( "apikey_tkn_lflw", diff --git a/src/backend/base/langflow/api/v1/login.py b/src/backend/base/langflow/api/v1/login.py index 05d583e75067..e6c93dfe82b8 100644 --- a/src/backend/base/langflow/api/v1/login.py +++ b/src/backend/base/langflow/api/v1/login.py @@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from fastapi.security import OAuth2PasswordRequestForm -from langflow.api.utils import DbSession +from langflow.api.utils import AsyncDbSession from langflow.api.v1.schemas import Token from langflow.services.auth.utils import ( authenticate_user, @@ -21,14 +21,14 @@ @router.post("/login", response_model=Token) -def login_to_get_access_token( +async def login_to_get_access_token( response: Response, form_data: Annotated[OAuth2PasswordRequestForm, Depends()], - db: DbSession, + db: AsyncDbSession, ): auth_settings = get_settings_service().auth_settings try: - user = authenticate_user(form_data.username, form_data.password, db) + user = await authenticate_user(form_data.username, form_data.password, db) except Exception as exc: if isinstance(exc, HTTPException): raise @@ -38,7 +38,7 @@ def login_to_get_access_token( ) from exc if user: - tokens = create_user_tokens(user_id=user.id, db=db, update_last_login=True) + tokens = await create_user_tokens(user_id=user.id, db=db, update_last_login=True) response.set_cookie( "refresh_token_lf", tokens["refresh_token"], @@ -66,9 +66,9 @@ def login_to_get_access_token( expires=None, # Set to None to make it a session cookie domain=auth_settings.COOKIE_DOMAIN, ) - get_variable_service().initialize_user_variables(user.id, db) + await get_variable_service().initialize_user_variables(user.id, db) # Create default folder for user if it doesn't exist - create_default_folder_if_it_doesnt_exist(db, user.id) + await create_default_folder_if_it_doesnt_exist(db, user.id) return tokens raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -78,11 +78,11 @@ def login_to_get_access_token( @router.get("/auto_login") -async def auto_login(response: Response, db: DbSession): +async def auto_login(response: Response, db: AsyncDbSession): auth_settings = get_settings_service().auth_settings if auth_settings.AUTO_LOGIN: - user_id, tokens = create_user_longterm_token(db) + user_id, tokens = await create_user_longterm_token(db) response.set_cookie( "access_token_lf", tokens["access_token"], @@ -93,7 +93,7 @@ async def auto_login(response: Response, db: DbSession): domain=auth_settings.COOKIE_DOMAIN, ) - user = get_user_by_id(db, user_id) + user = await get_user_by_id(db, user_id) if user: if user.store_api_key is None: @@ -124,14 +124,14 @@ async def auto_login(response: Response, db: DbSession): async def refresh_token( request: Request, response: Response, - db: DbSession, + db: AsyncDbSession, ): auth_settings = get_settings_service().auth_settings token = request.cookies.get("refresh_token_lf") if token: - tokens = create_refresh_token(token, db) + tokens = await create_refresh_token(token, db) response.set_cookie( "refresh_token_lf", tokens["refresh_token"], diff --git a/src/backend/base/langflow/api/v1/users.py b/src/backend/base/langflow/api/v1/users.py index a2fcbc42c1b0..8c43985b321b 100644 --- a/src/backend/base/langflow/api/v1/users.py +++ b/src/backend/base/langflow/api/v1/users.py @@ -7,7 +7,7 @@ from sqlmodel import select from sqlmodel.sql.expression import SelectOfScalar -from langflow.api.utils import CurrentActiveUser, DbSession +from langflow.api.utils import AsyncDbSession, CurrentActiveUser, DbSession from langflow.api.v1.schemas import UsersResponse from langflow.services.auth.utils import ( get_current_active_superuser, @@ -25,7 +25,7 @@ @router.post("/", response_model=UserRead, status_code=201) async def add_user( user: UserCreate, - session: DbSession, + session: AsyncDbSession, ) -> User: """Add a new user to the database.""" new_user = User.model_validate(user, from_attributes=True) @@ -33,13 +33,13 @@ async def add_user( new_user.password = get_password_hash(user.password) new_user.is_active = get_settings_service().auth_settings.NEW_USER_IS_ACTIVE session.add(new_user) - session.commit() - session.refresh(new_user) - folder = create_default_folder_if_it_doesnt_exist(session, new_user.id) + await session.commit() + await session.refresh(new_user) + folder = await create_default_folder_if_it_doesnt_exist(session, new_user.id) if not folder: raise HTTPException(status_code=500, detail="Error creating default folder") except IntegrityError as e: - session.rollback() + await session.rollback() raise HTTPException(status_code=400, detail="This username is unavailable.") from e return new_user @@ -58,14 +58,14 @@ async def read_all_users( *, skip: int = 0, limit: int = 10, - session: DbSession, + session: AsyncDbSession, ) -> UsersResponse: """Retrieve a list of users from the database with pagination.""" query: SelectOfScalar = select(User).offset(skip).limit(limit) - users = session.exec(query).fetchall() + users = (await session.exec(query)).fetchall() count_query = select(func.count()).select_from(User) - total_count = session.exec(count_query).first() + total_count = (await session.exec(count_query)).first() return UsersResponse( total_count=total_count, @@ -78,7 +78,7 @@ async def patch_user( user_id: UUID, user_update: UserUpdate, user: CurrentActiveUser, - session: DbSession, + session: AsyncDbSession, ) -> User: """Update an existing user's data.""" update_password = bool(user_update.password) @@ -93,10 +93,10 @@ async def patch_user( raise HTTPException(status_code=400, detail="You can't change your password here") user_update.password = get_password_hash(user_update.password) - if user_db := get_user_by_id(session, user_id): + if user_db := await get_user_by_id(session, user_id): if not update_password: user_update.password = user_db.password - return update_user(user_db, user_update, session) + return await update_user(user_db, user_update, session) raise HTTPException(status_code=404, detail="User not found") @@ -105,7 +105,7 @@ async def reset_password( user_id: UUID, user_update: UserUpdate, user: CurrentActiveUser, - session: DbSession, + session: AsyncDbSession, ) -> User: """Reset a user's password.""" if user_id != user.id: @@ -117,8 +117,8 @@ async def reset_password( raise HTTPException(status_code=400, detail="You can't use your current password") new_password = get_password_hash(user_update.password) user.password = new_password - session.commit() - session.refresh(user) + await session.commit() + await session.refresh(user) return user diff --git a/src/backend/base/langflow/api/v1/variable.py b/src/backend/base/langflow/api/v1/variable.py index 5b3e3e6e839d..5030c6bbf3c5 100644 --- a/src/backend/base/langflow/api/v1/variable.py +++ b/src/backend/base/langflow/api/v1/variable.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, HTTPException from sqlalchemy.exc import NoResultFound -from langflow.api.utils import CurrentActiveUser, DbSession +from langflow.api.utils import AsyncDbSession, CurrentActiveUser from langflow.services.database.models.variable import VariableCreate, VariableRead, VariableUpdate from langflow.services.deps import get_variable_service from langflow.services.variable.constants import GENERIC_TYPE @@ -15,7 +15,7 @@ @router.post("/", response_model=VariableRead, status_code=201) async def create_variable( *, - session: DbSession, + session: AsyncDbSession, variable: VariableCreate, current_user: CurrentActiveUser, ): @@ -30,10 +30,10 @@ async def create_variable( if not variable.value: raise HTTPException(status_code=400, detail="Variable value cannot be empty") - if variable.name in variable_service.list_variables(user_id=current_user.id, session=session): + if variable.name in await variable_service.list_variables(user_id=current_user.id, session=session): raise HTTPException(status_code=400, detail="Variable name already exists") try: - return variable_service.create_variable( + return await variable_service.create_variable( user_id=current_user.id, name=variable.name, value=variable.value, @@ -50,7 +50,7 @@ async def create_variable( @router.get("/", response_model=list[VariableRead], status_code=200) async def read_variables( *, - session: DbSession, + session: AsyncDbSession, current_user: CurrentActiveUser, ): """Read all variables.""" @@ -59,7 +59,7 @@ async def read_variables( msg = "Variable service is not an instance of DatabaseVariableService" raise TypeError(msg) try: - return variable_service.get_all(user_id=current_user.id, session=session) + return await variable_service.get_all(user_id=current_user.id, session=session) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) from e @@ -67,7 +67,7 @@ async def read_variables( @router.patch("/{variable_id}", response_model=VariableRead, status_code=200) async def update_variable( *, - session: DbSession, + session: AsyncDbSession, variable_id: UUID, variable: VariableUpdate, current_user: CurrentActiveUser, @@ -78,7 +78,7 @@ async def update_variable( msg = "Variable service is not an instance of DatabaseVariableService" raise TypeError(msg) try: - return variable_service.update_variable_fields( + return await variable_service.update_variable_fields( user_id=current_user.id, variable_id=variable_id, variable=variable, @@ -94,13 +94,13 @@ async def update_variable( @router.delete("/{variable_id}", status_code=204) async def delete_variable( *, - session: DbSession, + session: AsyncDbSession, variable_id: UUID, current_user: CurrentActiveUser, ) -> None: """Delete a variable.""" variable_service = get_variable_service() try: - variable_service.delete_variable_by_id(user_id=current_user.id, variable_id=variable_id, session=session) + await variable_service.delete_variable_by_id(user_id=current_user.id, variable_id=variable_id, session=session) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/src/backend/base/langflow/base/agents/events.py b/src/backend/base/langflow/base/agents/events.py index 0c409ab6543d..44f9a9d25f2d 100644 --- a/src/backend/base/langflow/base/agents/events.py +++ b/src/backend/base/langflow/base/agents/events.py @@ -1,4 +1,5 @@ # Add helper functions for each event type +import asyncio from collections.abc import AsyncIterator from time import perf_counter from typing import Any, Protocol @@ -249,7 +250,7 @@ async def process_agent_events( agent_message.properties.icon = "Bot" agent_message.properties.state = "partial" # Store the initial message - agent_message = send_message_method(message=agent_message) + agent_message = await asyncio.to_thread(send_message_method, message=agent_message) try: # Create a mapping of run_ids to tool contents tool_blocks_map: dict[str, ToolContent] = {} diff --git a/src/backend/base/langflow/custom/custom_component/custom_component.py b/src/backend/base/langflow/custom/custom_component/custom_component.py index 88ec6cdc600b..5a939ebcb333 100644 --- a/src/backend/base/langflow/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/custom/custom_component/custom_component.py @@ -448,7 +448,7 @@ def list_key_names(self): variable_service = get_variable_service() with session_scope() as session: - return variable_service.list_variables(user_id=self.user_id, session=session) + return variable_service.list_variables_sync(user_id=self.user_id, session=session) def index(self, value: int = 0): """Returns a function that returns the value at the given index in the iterable. diff --git a/src/backend/base/langflow/initial_setup/setup.py b/src/backend/base/langflow/initial_setup/setup.py index 594b9046cbf8..8c88cc672aca 100644 --- a/src/backend/base/langflow/initial_setup/setup.py +++ b/src/backend/base/langflow/initial_setup/setup.py @@ -28,6 +28,7 @@ ) from langflow.services.database.models.user.crud import get_user_by_username from langflow.services.deps import ( + async_session_scope, get_settings_service, get_storage_service, get_variable_service, @@ -519,7 +520,7 @@ def _is_valid_uuid(val): return str(uuid_obj) == val -def load_flows_from_directory() -> None: +async def load_flows_from_directory() -> None: """On langflow startup, this loads all flows from the directory specified in the settings. All flows are uploaded into the default folder for the superuser. @@ -533,8 +534,8 @@ def load_flows_from_directory() -> None: logger.warning("AUTO_LOGIN is disabled, not loading flows from directory") return - with session_scope() as session: - user = get_user_by_username(session, settings_service.auth_settings.SUPERUSER) + async with async_session_scope() as session: + user = await get_user_by_username(session, settings_service.auth_settings.SUPERUSER) if user is None: msg = "Superuser not found in the database" raise NoResultFound(msg) @@ -553,7 +554,7 @@ def load_flows_from_directory() -> None: flow["id"] = no_json_name flow_id = flow.get("id") - existing = find_existing_flow(session, flow_id, flow_endpoint_name) + existing = await find_existing_flow(session, flow_id, flow_endpoint_name) if existing: logger.debug(f"Found existing flow: {existing.name}") logger.info(f"Updating existing flow: {flow_id} with endpoint name {flow_endpoint_name}") @@ -585,15 +586,15 @@ def load_flows_from_directory() -> None: session.add(flow) -def find_existing_flow(session, flow_id, flow_endpoint_name): +async def find_existing_flow(session, flow_id, flow_endpoint_name): if flow_endpoint_name: logger.debug(f"flow_endpoint_name: {flow_endpoint_name}") stmt = select(Flow).where(Flow.endpoint_name == flow_endpoint_name) - if existing := session.exec(stmt).first(): + if existing := (await session.exec(stmt)).first(): logger.debug(f"Found existing flow by endpoint name: {existing.name}") return existing stmt = select(Flow).where(Flow.id == flow_id) - if existing := session.exec(stmt).first(): + if existing := (await session.exec(stmt)).first(): logger.debug(f"Found existing flow by id: {flow_id}") return existing return None @@ -645,7 +646,7 @@ def create_or_update_starter_projects(all_types_dict: dict) -> None: ) -def initialize_super_user_if_needed() -> None: +async def initialize_super_user_if_needed() -> None: settings_service = get_settings_service() if not settings_service.auth_settings.AUTO_LOGIN: return @@ -655,8 +656,8 @@ def initialize_super_user_if_needed() -> None: msg = "SUPERUSER and SUPERUSER_PASSWORD must be set in the settings if AUTO_LOGIN is true." raise ValueError(msg) - with session_scope() as session: - super_user = create_super_user(db=session, username=username, password=password) - get_variable_service().initialize_user_variables(super_user.id, session) - create_default_folder_if_it_doesnt_exist(session, super_user.id) - logger.info("Super user initialized") + async with async_session_scope() as async_session: + super_user = await create_super_user(db=async_session, username=username, password=password) + await get_variable_service().initialize_user_variables(super_user.id, async_session) + await create_default_folder_if_it_doesnt_exist(async_session, super_user.id) + logger.info("Super user initialized") diff --git a/src/backend/base/langflow/main.py b/src/backend/base/langflow/main.py index 6edd353f4e27..20fe1a1d4979 100644 --- a/src/backend/base/langflow/main.py +++ b/src/backend/base/langflow/main.py @@ -89,11 +89,6 @@ async def dispatch(self, request: Request, call_next): def get_lifespan(*, fix_migration=False, version=None): telemetry_service = get_telemetry_service() - def _initialize(): - initialize_services(fix_migration=fix_migration) - setup_llm_caching() - initialize_super_user_if_needed() - @asynccontextmanager async def lifespan(_app: FastAPI): configure(async_file=True) @@ -104,12 +99,13 @@ async def lifespan(_app: FastAPI): else: rprint("[bold green]Starting Langflow...[/bold green]") try: - await asyncio.to_thread(_initialize) + await initialize_services(fix_migration=fix_migration) + await asyncio.to_thread(setup_llm_caching) + await initialize_super_user_if_needed() all_types_dict = await get_and_cache_all_types_dict(get_settings_service()) await asyncio.to_thread(create_or_update_starter_projects, all_types_dict) telemetry_service.start() - await asyncio.to_thread(load_flows_from_directory) - + await load_flows_from_directory() yield except Exception as exc: diff --git a/src/backend/base/langflow/services/auth/utils.py b/src/backend/base/langflow/services/auth/utils.py index 71dec85ec74d..ce23a568090a 100644 --- a/src/backend/base/langflow/services/auth/utils.py +++ b/src/backend/base/langflow/services/auth/utils.py @@ -1,10 +1,9 @@ -import asyncio import base64 import random import warnings from collections.abc import Coroutine from datetime import datetime, timedelta, timezone -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from uuid import UUID from cryptography.fernet import Fernet @@ -12,16 +11,18 @@ from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer from jose import JWTError, jwt from loguru import logger -from sqlmodel import Session +from sqlmodel.ext.asyncio.session import AsyncSession from starlette.websockets import WebSocket from langflow.services.database.models.api_key.crud import check_key -from langflow.services.database.models.api_key.model import ApiKey from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at from langflow.services.database.models.user.model import User, UserRead -from langflow.services.deps import get_db_service, get_session, get_settings_service +from langflow.services.deps import get_async_session, get_db_service, get_settings_service from langflow.services.settings.service import SettingsService +if TYPE_CHECKING: + from langflow.services.database.models.api_key.model import ApiKey + oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False) API_KEY_NAME = "x-api-key" @@ -33,14 +34,14 @@ # Source: https://github.com/mrtolkien/fastapi_simple_security/blob/master/fastapi_simple_security/security_api_key.py -def api_key_security( +async def api_key_security( query_param: Annotated[str, Security(api_key_query)], header_param: Annotated[str, Security(api_key_header)], ) -> UserRead | None: settings_service = get_settings_service() - result: ApiKey | User | None = None + result: ApiKey | User | None - with get_db_service().with_session() as db: + async with get_db_service().with_async_session() as db: if settings_service.auth_settings.AUTO_LOGIN: # Get the first user if not settings_service.auth_settings.SUPERUSER: @@ -49,7 +50,7 @@ def api_key_security( detail="Missing first superuser credentials", ) - result = get_user_by_username(db, settings_service.auth_settings.SUPERUSER) + result = await get_user_by_username(db, settings_service.auth_settings.SUPERUSER) elif not query_param and not header_param: raise HTTPException( @@ -58,18 +59,16 @@ def api_key_security( ) elif query_param: - result = check_key(db, query_param) + result = await check_key(db, query_param) else: - result = check_key(db, header_param) + result = await check_key(db, header_param) if not result: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Invalid or missing API key", ) - if isinstance(result, ApiKey): - return UserRead.model_validate(result.user, from_attributes=True) if isinstance(result, User): return UserRead.model_validate(result, from_attributes=True) msg = "Invalid result type" @@ -80,11 +79,11 @@ async def get_current_user( token: Annotated[str, Security(oauth2_login)], query_param: Annotated[str, Security(api_key_query)], header_param: Annotated[str, Security(api_key_header)], - db: Annotated[Session, Depends(get_session)], + db: Annotated[AsyncSession, Depends(get_async_session)], ) -> User: if token: return await get_current_user_by_jwt(token, db) - user = await asyncio.to_thread(api_key_security, query_param, header_param) + user = await api_key_security(query_param, header_param) if user: return user @@ -95,8 +94,8 @@ async def get_current_user( async def get_current_user_by_jwt( - token: Annotated[str, Depends(oauth2_login)], - db: Annotated[Session, Depends(get_session)], + token: str, + db: AsyncSession, ) -> User: settings_service = get_settings_service() @@ -144,7 +143,7 @@ async def get_current_user_by_jwt( headers={"WWW-Authenticate": "Bearer"}, ) from e - user = get_user_by_id(db, user_id) + user = await get_user_by_id(db, user_id) if user is None or not user.is_active: logger.info("User not found or inactive.") raise HTTPException( @@ -157,7 +156,7 @@ async def get_current_user_by_jwt( async def get_current_user_for_websocket( websocket: WebSocket, - db: Annotated[Session, Depends(get_session)], + db: Annotated[AsyncSession, Depends(get_async_session)], query_param: Annotated[str, Security(api_key_query)], ) -> User | None: token = websocket.query_params.get("token") @@ -165,7 +164,7 @@ async def get_current_user_for_websocket( if token: return await get_current_user_by_jwt(token, db) if api_key: - return await asyncio.to_thread(api_key_security, api_key, query_param) + return await api_key_security(api_key, query_param) return None @@ -207,12 +206,12 @@ def create_token(data: dict, expires_delta: timedelta): ) -def create_super_user( +async def create_super_user( username: str, password: str, - db: Session, + db: AsyncSession, ) -> User: - super_user = get_user_by_username(db, username) + super_user = await get_user_by_username(db, username) if not super_user: super_user = User( @@ -224,17 +223,17 @@ def create_super_user( ) db.add(super_user) - db.commit() - db.refresh(super_user) + await db.commit() + await db.refresh(super_user) return super_user -def create_user_longterm_token(db: Session) -> tuple[UUID, dict]: +async def create_user_longterm_token(db: AsyncSession) -> tuple[UUID, dict]: settings_service = get_settings_service() username = settings_service.auth_settings.SUPERUSER - super_user = get_user_by_username(db, username) + super_user = await get_user_by_username(db, username) if not super_user: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Super user hasn't been created") access_token_expires_longterm = timedelta(days=365) @@ -244,7 +243,7 @@ def create_user_longterm_token(db: Session) -> tuple[UUID, dict]: ) # Update: last_login_at - update_user_last_login_at(super_user.id, db) + await update_user_last_login_at(super_user.id, db) return super_user.id, { "access_token": access_token, @@ -270,7 +269,7 @@ def get_user_id_from_token(token: str) -> UUID: return UUID(int=0) -def create_user_tokens(user_id: UUID, db: Session, *, update_last_login: bool = False) -> dict: +async def create_user_tokens(user_id: UUID, db: AsyncSession, *, update_last_login: bool = False) -> dict: settings_service = get_settings_service() access_token_expires = timedelta(seconds=settings_service.auth_settings.ACCESS_TOKEN_EXPIRE_SECONDS) @@ -287,7 +286,7 @@ def create_user_tokens(user_id: UUID, db: Session, *, update_last_login: bool = # Update: last_login_at if update_last_login: - update_user_last_login_at(user_id, db) + await update_user_last_login_at(user_id, db) return { "access_token": access_token, @@ -296,7 +295,7 @@ def create_user_tokens(user_id: UUID, db: Session, *, update_last_login: bool = } -def create_refresh_token(refresh_token: str, db: Session): +async def create_refresh_token(refresh_token: str, db: AsyncSession): settings_service = get_settings_service() try: @@ -314,12 +313,12 @@ def create_refresh_token(refresh_token: str, db: Session): if user_id is None or token_type == "": raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token") - user_exists = get_user_by_id(db, user_id) + user_exists = await get_user_by_id(db, user_id) if user_exists is None: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token") - return create_user_tokens(user_id, db) + return await create_user_tokens(user_id, db) except JWTError as e: logger.exception("JWT decoding error") @@ -329,8 +328,8 @@ def create_refresh_token(refresh_token: str, db: Session): ) from e -def authenticate_user(username: str, password: str, db: Session) -> User | None: - user = get_user_by_username(db, username) +async def authenticate_user(username: str, password: str, db: AsyncSession) -> User | None: + user = await get_user_by_username(db, username) if not user: return None diff --git a/src/backend/base/langflow/services/database/models/api_key/crud.py b/src/backend/base/langflow/services/database/models/api_key/crud.py index faa210deaf66..7c306335748e 100644 --- a/src/backend/base/langflow/services/database/models/api_key/crud.py +++ b/src/backend/base/langflow/services/database/models/api_key/crud.py @@ -1,13 +1,17 @@ +import asyncio import datetime import secrets -import threading from typing import TYPE_CHECKING from uuid import UUID -from sqlmodel import Session, select +from sqlalchemy.orm import selectinload +from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession +from langflow.services.database.models import User from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate, ApiKeyRead, UnmaskedApiKeyRead +from langflow.services.database.utils import async_session_getter +from langflow.services.deps import get_db_service if TYPE_CHECKING: from sqlmodel.sql.expression import SelectOfScalar @@ -47,33 +51,29 @@ async def delete_api_key(session: AsyncSession, api_key_id: UUID) -> None: await session.commit() -def check_key(session: Session, api_key: str) -> ApiKey | None: +update_total_uses_tasks: set[asyncio.Task] = set() + + +async def check_key(session: AsyncSession, api_key: str) -> User | None: """Check if the API key is valid.""" - query: SelectOfScalar = select(ApiKey).where(ApiKey.api_key == api_key) - api_key_object: ApiKey | None = session.exec(query).first() + query: SelectOfScalar = select(ApiKey).options(selectinload(ApiKey.user)).where(ApiKey.api_key == api_key) + api_key_object: ApiKey | None = (await session.exec(query)).first() if api_key_object is not None: - threading.Thread( - target=update_total_uses, - args=( - session, - api_key_object, - ), - ).start() - return api_key_object + task = asyncio.create_task(update_total_uses(api_key_object.id)) + task.add_done_callback(update_total_uses_tasks.discard) + update_total_uses_tasks.add(task) + return api_key_object.user + return None -def update_total_uses(session, api_key: ApiKey): +async def update_total_uses(api_key_id: UUID): """Update the total uses and last used at.""" - # This is running in a separate thread to avoid slowing down the request - # but session is not thread safe so we need to create a new session - - with Session(session.get_bind()) as new_session: - new_api_key = new_session.get(ApiKey, api_key.id) + async with async_session_getter(get_db_service()) as session: + new_api_key = await session.get(ApiKey, api_key_id) if new_api_key is None: msg = "API Key not found" raise ValueError(msg) new_api_key.total_uses += 1 new_api_key.last_used_at = datetime.datetime.now(datetime.timezone.utc) - new_session.add(new_api_key) - new_session.commit() - return new_api_key + session.add(new_api_key) + await session.commit() diff --git a/src/backend/base/langflow/services/database/models/folder/utils.py b/src/backend/base/langflow/services/database/models/folder/utils.py index 342920123864..c7f8f4aeedd3 100644 --- a/src/backend/base/langflow/services/database/models/folder/utils.py +++ b/src/backend/base/langflow/services/database/models/folder/utils.py @@ -1,6 +1,7 @@ from uuid import UUID -from sqlmodel import Session, and_, select, update +from sqlmodel import and_, select, update +from sqlmodel.ext.asyncio.session import AsyncSession from langflow.services.database.models.flow.model import Flow @@ -8,8 +9,9 @@ from .model import Folder -def create_default_folder_if_it_doesnt_exist(session: Session, user_id: UUID): - folder = session.exec(select(Folder).where(Folder.user_id == user_id)).first() +async def create_default_folder_if_it_doesnt_exist(session: AsyncSession, user_id: UUID): + stmt = select(Folder).where(Folder.user_id == user_id) + folder = (await session.exec(stmt)).first() if not folder: folder = Folder( name=DEFAULT_FOLDER_NAME, @@ -17,9 +19,9 @@ def create_default_folder_if_it_doesnt_exist(session: Session, user_id: UUID): description=DEFAULT_FOLDER_DESCRIPTION, ) session.add(folder) - session.commit() - session.refresh(folder) - session.exec( + await session.commit() + await session.refresh(folder) + await session.exec( update(Flow) .where( and_( @@ -29,12 +31,14 @@ def create_default_folder_if_it_doesnt_exist(session: Session, user_id: UUID): ) .values(folder_id=folder.id) ) - session.commit() + await session.commit() return folder -def get_default_folder_id(session: Session, user_id: UUID): - folder = session.exec(select(Folder).where(Folder.name == DEFAULT_FOLDER_NAME, Folder.user_id == user_id)).first() +async def get_default_folder_id(session: AsyncSession, user_id: UUID): + folder = ( + await session.exec(select(Folder).where(Folder.name == DEFAULT_FOLDER_NAME, Folder.user_id == user_id)) + ).first() if not folder: - folder = create_default_folder_if_it_doesnt_exist(session, user_id) + folder = await create_default_folder_if_it_doesnt_exist(session, user_id) return folder.id diff --git a/src/backend/base/langflow/services/database/models/user/crud.py b/src/backend/base/langflow/services/database/models/user/crud.py index f5b4f74f9d50..b78b706b0e77 100644 --- a/src/backend/base/langflow/services/database/models/user/crud.py +++ b/src/backend/base/langflow/services/database/models/user/crud.py @@ -5,20 +5,23 @@ from loguru import logger from sqlalchemy.exc import IntegrityError from sqlalchemy.orm.attributes import flag_modified -from sqlmodel import Session, select +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession from langflow.services.database.models.user.model import User, UserUpdate -def get_user_by_username(db: Session, username: str) -> User | None: - return db.exec(select(User).where(User.username == username)).first() +async def get_user_by_username(db: AsyncSession, username: str) -> User | None: + stmt = select(User).where(User.username == username) + return (await db.exec(stmt)).first() -def get_user_by_id(db: Session, user_id: UUID) -> User | None: - return db.exec(select(User).where(User.id == user_id)).first() +async def get_user_by_id(db: AsyncSession, user_id: UUID) -> User | None: + stmt = select(User).where(User.id == user_id) + return (await db.exec(stmt)).first() -def update_user(user_db: User | None, user: UserUpdate, db: Session) -> User: +async def update_user(user_db: User | None, user: UserUpdate, db: AsyncSession) -> User: if not user_db: raise HTTPException(status_code=404, detail="User not found") @@ -40,18 +43,18 @@ def update_user(user_db: User | None, user: UserUpdate, db: Session) -> User: flag_modified(user_db, "updated_at") try: - db.commit() + await db.commit() except IntegrityError as e: - db.rollback() + await db.rollback() raise HTTPException(status_code=400, detail=str(e)) from e return user_db -def update_user_last_login_at(user_id: UUID, db: Session): +async def update_user_last_login_at(user_id: UUID, db: AsyncSession): try: user_data = UserUpdate(last_login_at=datetime.now(timezone.utc)) - user = get_user_by_id(db, user_id) - return update_user(user, user_data, db) + user = await get_user_by_id(db, user_id) + return await update_user(user, user_data, db) except Exception: # noqa: BLE001 logger.opt(exception=True).debug("Error updating user last login at") diff --git a/src/backend/base/langflow/services/database/service.py b/src/backend/base/langflow/services/database/service.py index 5107038977bc..91bc9f2c1392 100644 --- a/src/backend/base/langflow/services/database/service.py +++ b/src/backend/base/langflow/services/database/service.py @@ -142,28 +142,29 @@ def with_session(self): @asynccontextmanager async def with_async_session(self): - async with AsyncSession(self.async_engine) as session: + async with AsyncSession(self.async_engine, expire_on_commit=False) as session: yield session - def migrate_flows_if_auto_login(self) -> None: + async def migrate_flows_if_auto_login(self) -> None: # if auto_login is enabled, we need to migrate the flows # to the default superuser if they don't have a user id # associated with them settings_service = get_settings_service() if settings_service.auth_settings.AUTO_LOGIN: - with self.with_session() as session: - flows = session.exec(select(models.Flow).where(models.Flow.user_id is None)).all() + async with self.with_async_session() as session: + stmt = select(models.Flow).where(models.Flow.user_id is None) + flows = (await session.exec(stmt)).all() if flows: logger.debug("Migrating flows to default superuser") username = settings_service.auth_settings.SUPERUSER - user = get_user_by_username(session, username) + user = await get_user_by_username(session, username) if not user: logger.error("Default superuser not found") msg = "Default superuser not found" raise RuntimeError(msg) for flow in flows: flow.user_id = user.id - session.commit() + await session.commit() logger.debug("Flows migrated successfully") def check_schema_health(self) -> bool: @@ -346,20 +347,15 @@ def create_db_and_tables(self) -> None: logger.debug("Database and tables created successfully") - def _teardown(self) -> None: + async def teardown(self) -> None: logger.debug("Tearing down database") try: settings_service = get_settings_service() # remove the default superuser if auto_login is enabled # using the SUPERUSER to get the user - with self.with_session() as session: - teardown_superuser(settings_service, session) - + async with self.with_async_session() as session: + await teardown_superuser(settings_service, session) except Exception: # noqa: BLE001 logger.exception("Error tearing down database") - - self.engine.dispose() - - async def teardown(self) -> None: - await asyncio.to_thread(self._teardown) await self.async_engine.dispose() + await asyncio.to_thread(self.engine.dispose) diff --git a/src/backend/base/langflow/services/database/utils.py b/src/backend/base/langflow/services/database/utils.py index 7337c5778d86..576ec2140101 100644 --- a/src/backend/base/langflow/services/database/utils.py +++ b/src/backend/base/langflow/services/database/utils.py @@ -1,12 +1,13 @@ from __future__ import annotations -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from dataclasses import dataclass from typing import TYPE_CHECKING from alembic.util.exc import CommandError from loguru import logger from sqlmodel import Session, text +from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from langflow.services.database.service import DatabaseService @@ -70,6 +71,19 @@ def session_getter(db_service: DatabaseService): session.close() +@asynccontextmanager +async def async_session_getter(db_service: DatabaseService): + try: + session = AsyncSession(db_service.async_engine) + yield session + except Exception: + logger.exception("Session rollback because of exception") + await session.rollback() + raise + finally: + await session.close() + + @dataclass class Result: name: str diff --git a/src/backend/base/langflow/services/utils.py b/src/backend/base/langflow/services/utils.py index a017e79349e4..d857c6e9efab 100644 --- a/src/backend/base/langflow/services/utils.py +++ b/src/backend/base/langflow/services/utils.py @@ -1,7 +1,8 @@ import asyncio from loguru import logger -from sqlmodel import Session, select +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession from langflow.services.auth.utils import create_super_user, verify_password from langflow.services.cache.factory import CacheServiceFactory @@ -9,13 +10,14 @@ from langflow.services.schema import ServiceType from langflow.services.settings.constants import DEFAULT_SUPERUSER, DEFAULT_SUPERUSER_PASSWORD -from .deps import get_db_service, get_service, get_session, get_settings_service +from .deps import get_db_service, get_service, get_settings_service -def get_or_create_super_user(session: Session, username, password, is_default): +async def get_or_create_super_user(session: AsyncSession, username, password, is_default): from langflow.services.database.models.user.model import User - user = session.exec(select(User).where(User.username == username)).first() + stmt = select(User).where(User.username == username) + user = (await session.exec(stmt)).first() if user and user.is_superuser: return None # Superuser already exists @@ -51,7 +53,7 @@ def get_or_create_super_user(session: Session, username, password, is_default): else: logger.debug("Creating superuser.") try: - return create_super_user(username, password, db=session) + return await create_super_user(username, password, db=session) except Exception as exc: # noqa: BLE001 if "UNIQUE constraint failed: user.username" in str(exc): # This is to deal with workers running this @@ -62,12 +64,12 @@ def get_or_create_super_user(session: Session, username, password, is_default): logger.opt(exception=True).debug("Error creating superuser.") -def setup_superuser(settings_service, session: Session) -> None: +async def setup_superuser(settings_service, session: AsyncSession) -> None: if settings_service.auth_settings.AUTO_LOGIN: logger.debug("AUTO_LOGIN is set to True. Creating default superuser.") else: # Remove the default superuser if it exists - teardown_superuser(settings_service, session) + await teardown_superuser(settings_service, session) username = settings_service.auth_settings.SUPERUSER password = settings_service.auth_settings.SUPERUSER_PASSWORD @@ -75,7 +77,9 @@ def setup_superuser(settings_service, session: Session) -> None: is_default = (username == DEFAULT_SUPERUSER) and (password == DEFAULT_SUPERUSER_PASSWORD) try: - user = get_or_create_super_user(session=session, username=username, password=password, is_default=is_default) + user = await get_or_create_super_user( + session=session, username=username, password=password, is_default=is_default + ) if user is not None: logger.debug("Superuser created successfully.") except Exception as exc: @@ -86,7 +90,7 @@ def setup_superuser(settings_service, session: Session) -> None: settings_service.auth_settings.reset_credentials() -def teardown_superuser(settings_service, session) -> None: +async def teardown_superuser(settings_service, session: AsyncSession) -> None: """Teardown the superuser.""" # If AUTO_LOGIN is True, we will remove the default superuser # from the database. @@ -97,30 +101,27 @@ def teardown_superuser(settings_service, session) -> None: username = DEFAULT_SUPERUSER from langflow.services.database.models.user.model import User - user = session.exec(select(User).where(User.username == username)).first() + stmt = select(User).where(User.username == username) + user = (await session.exec(stmt)).first() # Check if super was ever logged in, if not delete it # if it has logged in, it means the user is using it to login if user and user.is_superuser is True and not user.last_login_at: - session.delete(user) - session.commit() + await session.delete(user) + await session.commit() logger.debug("Default superuser removed successfully.") except Exception as exc: logger.exception(exc) - session.rollback() + await session.rollback() msg = "Could not remove default superuser." raise RuntimeError(msg) from exc -def _teardown_superuser(): - with get_db_service().with_session() as session: - teardown_superuser(get_settings_service(), session) - - async def teardown_services() -> None: """Teardown all the services.""" try: - await asyncio.to_thread(_teardown_superuser) + async with get_db_service().with_async_session() as session: + await teardown_superuser(get_settings_service(), session) except Exception as exc: # noqa: BLE001 logger.exception(exc) try: @@ -156,15 +157,16 @@ def initialize_session_service() -> None: ) -def initialize_services(*, fix_migration: bool = False) -> None: +async def initialize_services(*, fix_migration: bool = False) -> None: """Initialize all the services needed.""" # Test cache connection get_service(ServiceType.CACHE_SERVICE, default=CacheServiceFactory()) # Setup the superuser - initialize_database(fix_migration=fix_migration) - setup_superuser(get_service(ServiceType.SETTINGS_SERVICE), next(get_session())) + await asyncio.to_thread(initialize_database, fix_migration=fix_migration) + async with get_db_service().with_async_session() as session: + await setup_superuser(get_service(ServiceType.SETTINGS_SERVICE), session) try: - get_db_service().migrate_flows_if_auto_login() + await get_db_service().migrate_flows_if_auto_login() except Exception as exc: msg = "Error migrating flows" logger.exception(msg) diff --git a/src/backend/base/langflow/services/variable/base.py b/src/backend/base/langflow/services/variable/base.py index 7cbb91637ed9..482759890aa6 100644 --- a/src/backend/base/langflow/services/variable/base.py +++ b/src/backend/base/langflow/services/variable/base.py @@ -2,6 +2,7 @@ from uuid import UUID from sqlmodel import Session +from sqlmodel.ext.asyncio.session import AsyncSession from langflow.services.base import Service from langflow.services.database.models.variable.model import Variable @@ -13,7 +14,7 @@ class VariableService(Service): name = "variable_service" @abc.abstractmethod - def initialize_user_variables(self, user_id: UUID | str, session: Session) -> None: + async def initialize_user_variables(self, user_id: UUID | str, session: AsyncSession) -> None: """Initialize user variables. Args: @@ -36,7 +37,7 @@ def get_variable(self, user_id: UUID | str, name: str, field: str, session: Sess """ @abc.abstractmethod - def list_variables(self, user_id: UUID | str, session: Session) -> list[str | None]: + def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]: """List all variables. Args: @@ -48,7 +49,19 @@ def list_variables(self, user_id: UUID | str, session: Session) -> list[str | No """ @abc.abstractmethod - def update_variable(self, user_id: UUID | str, name: str, value: str, session: Session) -> Variable: + async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]: + """List all variables. + + Args: + user_id: The user ID. + session: The database session. + + Returns: + A list of variable names. + """ + + @abc.abstractmethod + async def update_variable(self, user_id: UUID | str, name: str, value: str, session: AsyncSession) -> Variable: """Update a variable. Args: @@ -62,7 +75,7 @@ def update_variable(self, user_id: UUID | str, name: str, value: str, session: S """ @abc.abstractmethod - def delete_variable(self, user_id: UUID | str, name: str, session: Session) -> None: + async def delete_variable(self, user_id: UUID | str, name: str, session: AsyncSession) -> None: """Delete a variable. Args: @@ -75,7 +88,7 @@ def delete_variable(self, user_id: UUID | str, name: str, session: Session) -> N """ @abc.abstractmethod - def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: Session) -> None: + async def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: AsyncSession) -> None: """Delete a variable by ID. Args: @@ -85,7 +98,7 @@ def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: """ @abc.abstractmethod - def create_variable( + async def create_variable( self, user_id: UUID | str, name: str, @@ -93,7 +106,7 @@ def create_variable( *, default_fields: list[str], _type: str, - session: Session, + session: AsyncSession, ) -> Variable: """Create a variable. diff --git a/src/backend/base/langflow/services/variable/kubernetes.py b/src/backend/base/langflow/services/variable/kubernetes.py index b5206f2333fc..d37f596d38a3 100644 --- a/src/backend/base/langflow/services/variable/kubernetes.py +++ b/src/backend/base/langflow/services/variable/kubernetes.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import os from typing import TYPE_CHECKING @@ -17,6 +18,7 @@ from uuid import UUID from sqlmodel import Session + from sqlmodel.ext.asyncio.session import AsyncSession from langflow.services.settings.service import SettingsService @@ -28,7 +30,7 @@ def __init__(self, settings_service: SettingsService): self.kubernetes_secrets = KubernetesSecretManager() @override - def initialize_user_variables(self, user_id: UUID | str, session: Session) -> None: + async def initialize_user_variables(self, user_id: UUID | str, session: AsyncSession) -> None: # Check for environment variables that should be stored in the database should_or_should_not = "Should" if self.settings_service.settings.store_environment_variables else "Should not" logger.info(f"{should_or_should_not} store environment variables in the kubernetes.") @@ -45,7 +47,8 @@ def initialize_user_variables(self, user_id: UUID | str, session: Session) -> No try: secret_name = encode_user_id(user_id) - self.kubernetes_secrets.create_secret( + await asyncio.to_thread( + self.kubernetes_secrets.create_secret, name=secret_name, data=variables, ) @@ -75,12 +78,13 @@ def resolve_variable( msg = f"user_id {user_id} variable name {name} not found." raise ValueError(msg) + @override def get_variable( self, user_id: UUID | str, name: str, field: str, - _session: Session, + session: Session, ) -> str: secret_name = encode_user_id(user_id) key, value = self.resolve_variable(secret_name, user_id, name) @@ -92,10 +96,11 @@ def get_variable( raise TypeError(msg) return value - def list_variables( + @override + def list_variables_sync( self, user_id: UUID | str, - _session: Session, + session: Session, ) -> list[str | None]: variables = self.kubernetes_secrets.get_secret(name=encode_user_id(user_id)) if not variables: @@ -109,28 +114,49 @@ def list_variables( names.append(key) return names - def update_variable( + @override + async def list_variables( + self, + user_id: UUID | str, + session: AsyncSession, + ) -> list[str | None]: + return await asyncio.to_thread(self.list_variables_sync, user_id, session.sync_session) + + def _update_variable( self, user_id: UUID | str, name: str, value: str, - _session: Session, ): secret_name = encode_user_id(user_id) secret_key, _ = self.resolve_variable(secret_name, user_id, name) return self.kubernetes_secrets.update_secret(name=secret_name, data={secret_key: value}) - def delete_variable(self, user_id: UUID | str, name: str, _session: Session) -> None: - secret_name = encode_user_id(user_id) + @override + async def update_variable( + self, + user_id: UUID | str, + name: str, + value: str, + session: AsyncSession, + ): + return await asyncio.to_thread(self._update_variable, user_id, name, value) + def _delete_variable(self, user_id: UUID | str, name: str) -> None: + secret_name = encode_user_id(user_id) secret_key, _ = self.resolve_variable(secret_name, user_id, name) self.kubernetes_secrets.delete_secret_key(name=secret_name, key=secret_key) - def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID | str, _session: Session) -> None: - self.delete_variable(user_id, _session, str(variable_id)) + @override + async def delete_variable(self, user_id: UUID | str, name: str, session: AsyncSession) -> None: + await asyncio.to_thread(self._delete_variable, user_id, name) + + @override + async def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID | str, session: AsyncSession) -> None: + await self.delete_variable(user_id, str(variable_id), session) @override - def create_variable( + async def create_variable( self, user_id: UUID | str, name: str, @@ -138,7 +164,7 @@ def create_variable( *, default_fields: list[str], _type: str, - session: Session, + session: AsyncSession, ) -> Variable: secret_name = encode_user_id(user_id) secret_key = name @@ -147,7 +173,9 @@ def create_variable( else: _type = GENERIC_TYPE - self.kubernetes_secrets.upsert_secret(secret_name=secret_name, data={secret_key: value}) + await asyncio.to_thread( + self.kubernetes_secrets.upsert_secret, secret_name=secret_name, data={secret_key: value} + ) variable_base = VariableCreate( name=name, diff --git a/src/backend/base/langflow/services/variable/service.py b/src/backend/base/langflow/services/variable/service.py index 11eb086218d5..f5cf85d0b3a6 100644 --- a/src/backend/base/langflow/services/variable/service.py +++ b/src/backend/base/langflow/services/variable/service.py @@ -17,6 +17,8 @@ from collections.abc import Sequence from uuid import UUID + from sqlmodel.ext.asyncio.session import AsyncSession + from langflow.services.settings.service import SettingsService @@ -24,7 +26,7 @@ class DatabaseVariableService(VariableService, Service): def __init__(self, settings_service: SettingsService): self.settings_service = settings_service - def initialize_user_variables(self, user_id: UUID | str, session: Session) -> None: + async def initialize_user_variables(self, user_id: UUID | str, session: AsyncSession) -> None: if not self.settings_service.settings.store_environment_variables: logger.info("Skipping environment variable storage.") return @@ -34,12 +36,12 @@ def initialize_user_variables(self, user_id: UUID | str, session: Session) -> No if var_name in os.environ and os.environ[var_name].strip(): value = os.environ[var_name].strip() query = select(Variable).where(Variable.user_id == user_id, Variable.name == var_name) - existing = session.exec(query).first() + existing = (await session.exec(query)).first() try: if existing: - self.update_variable(user_id, var_name, value, session) + await self.update_variable(user_id, var_name, value, session) else: - self.create_variable( + await self.create_variable( user_id=user_id, name=var_name, value=value, @@ -76,40 +78,46 @@ def get_variable( # we decrypt the value return auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service) - def get_all(self, user_id: UUID | str, session: Session) -> list[Variable | None]: - return list(session.exec(select(Variable).where(Variable.user_id == user_id)).all()) + async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[Variable | None]: + stmt = select(Variable).where(Variable.user_id == user_id) + return list((await session.exec(stmt)).all()) + + def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]: + variables = session.exec(select(Variable).where(Variable.user_id == user_id)).all() + return [variable.name for variable in variables if variable] - def list_variables(self, user_id: UUID | str, session: Session) -> list[str | None]: - variables = self.get_all(user_id=user_id, session=session) + async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]: + variables = await self.get_all(user_id=user_id, session=session) return [variable.name for variable in variables if variable] - def update_variable( + async def update_variable( self, user_id: UUID | str, name: str, value: str, - session: Session, + session: AsyncSession, ): - variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first() + stmt = select(Variable).where(Variable.user_id == user_id, Variable.name == name) + variable = (await session.exec(stmt)).first() if not variable: msg = f"{name} variable not found." raise ValueError(msg) encrypted = auth_utils.encrypt_api_key(value, settings_service=self.settings_service) variable.value = encrypted session.add(variable) - session.commit() - session.refresh(variable) + await session.commit() + await session.refresh(variable) return variable - def update_variable_fields( + async def update_variable_fields( self, user_id: UUID | str, variable_id: UUID | str, variable: VariableUpdate, - session: Session, + session: AsyncSession, ): query = select(Variable).where(Variable.id == variable_id, Variable.user_id == user_id) - db_variable = session.exec(query).one() + db_variable = (await session.exec(query)).one() db_variable.updated_at = datetime.now(timezone.utc) variable.value = variable.value or "" @@ -121,33 +129,34 @@ def update_variable_fields( setattr(db_variable, key, value) session.add(db_variable) - session.commit() - session.refresh(db_variable) + await session.commit() + await session.refresh(db_variable) return db_variable - def delete_variable( + async def delete_variable( self, user_id: UUID | str, name: str, - session: Session, + session: AsyncSession, ) -> None: stmt = select(Variable).where(Variable.user_id == user_id).where(Variable.name == name) - variable = session.exec(stmt).first() + variable = (await session.exec(stmt)).first() if not variable: msg = f"{name} variable not found." raise ValueError(msg) - session.delete(variable) - session.commit() + await session.delete(variable) + await session.commit() - def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: Session) -> None: - variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.id == variable_id)).first() + async def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: AsyncSession) -> None: + stmt = select(Variable).where(Variable.user_id == user_id, Variable.id == variable_id) + variable = (await session.exec(stmt)).first() if not variable: msg = f"{variable_id} variable not found." raise ValueError(msg) - session.delete(variable) - session.commit() + await session.delete(variable) + await session.commit() - def create_variable( + async def create_variable( self, user_id: UUID | str, name: str, @@ -155,7 +164,7 @@ def create_variable( *, default_fields: Sequence[str] = (), _type: str = GENERIC_TYPE, - session: Session, + session: AsyncSession, ): variable_base = VariableCreate( name=name, @@ -165,6 +174,6 @@ def create_variable( ) variable = Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id}) session.add(variable) - session.commit() - session.refresh(variable) + await session.commit() + await session.refresh(variable) return variable diff --git a/src/backend/tests/blockbuster.py b/src/backend/tests/blockbuster.py index 3a11201eb3f4..c3b3a5f9d497 100644 --- a/src/backend/tests/blockbuster.py +++ b/src/backend/tests/blockbuster.py @@ -90,6 +90,8 @@ def file_op(self, *args, **kwargs): "_read_pyc", }: return func(self, *args, **kwargs) + if frame_info.filename.endswith("settings/service.py") and frame_info.function == "initialize": + return func(self, *args, **kwargs) raise _blocking_error(func) return file_op @@ -104,6 +106,8 @@ def file_op(self, *args, **kwargs): for frame_info in inspect.stack(): if frame_info.filename.endswith("_pytest/assertion/rewrite.py") and frame_info.function == "_write_pyc": return func(self, *args, **kwargs) + if frame_info.filename.endswith("settings/service.py") and frame_info.function == "initialize": + return func(self, *args, **kwargs) if self not in {sys.stdout, sys.stderr}: raise _blocking_error(func) return func(self, *args, **kwargs) diff --git a/src/backend/tests/performance/test_server_init.py b/src/backend/tests/performance/test_server_init.py index 630efd17b0e9..a08165ed4950 100644 --- a/src/backend/tests/performance/test_server_init.py +++ b/src/backend/tests/performance/test_server_init.py @@ -24,7 +24,7 @@ async def test_initialize_services(): """Benchmark the initialization of services.""" from langflow.services.utils import initialize_services - await asyncio.to_thread(initialize_services, fix_migration=False) + await initialize_services(fix_migration=False) settings_service = await asyncio.to_thread(get_settings_service) assert "test_performance.db" in settings_service.settings.database_url @@ -45,8 +45,8 @@ async def test_initialize_super_user(): from langflow.initial_setup.setup import initialize_super_user_if_needed from langflow.services.utils import initialize_services - await asyncio.to_thread(initialize_services, fix_migration=False) - await asyncio.to_thread(initialize_super_user_if_needed) + await initialize_services(fix_migration=False) + await initialize_super_user_if_needed() settings_service = await asyncio.to_thread(get_settings_service) assert "test_performance.db" in settings_service.settings.database_url @@ -69,7 +69,7 @@ async def test_create_starter_projects(): from langflow.interface.types import get_and_cache_all_types_dict from langflow.services.utils import initialize_services - await asyncio.to_thread(initialize_services, fix_migration=False) + await initialize_services(fix_migration=False) settings_service = await asyncio.to_thread(get_settings_service) types_dict = await get_and_cache_all_types_dict(settings_service) await asyncio.to_thread(create_or_update_starter_projects, types_dict) @@ -81,6 +81,6 @@ async def test_load_flows(): """Benchmark loading flows from directory.""" from langflow.initial_setup.setup import load_flows_from_directory - await asyncio.to_thread(load_flows_from_directory) + await load_flows_from_directory() settings_service = await asyncio.to_thread(get_settings_service) assert "test_performance.db" in settings_service.settings.database_url diff --git a/src/backend/tests/unit/services/variable/test_service.py b/src/backend/tests/unit/services/variable/test_service.py index 081c7147f29b..f66da3a04ce0 100644 --- a/src/backend/tests/unit/services/variable/test_service.py +++ b/src/backend/tests/unit/services/variable/test_service.py @@ -1,6 +1,6 @@ from datetime import datetime from unittest.mock import patch -from uuid import uuid4 +from uuid import UUID, uuid4 import pytest from langflow.services.database.models.variable.model import VariableUpdate @@ -8,7 +8,9 @@ from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE from langflow.services.variable.service import DatabaseVariableService -from sqlmodel import Session, SQLModel, create_engine +from sqlalchemy.ext.asyncio import create_async_engine +from sqlmodel import Session, SQLModel +from sqlmodel.ext.asyncio.session import AsyncSession @pytest.fixture @@ -18,114 +20,125 @@ def service(): @pytest.fixture -def session(): - engine = create_engine("sqlite:///:memory:") - SQLModel.metadata.create_all(engine) - with Session(engine) as session: +async def session(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + async with AsyncSession(engine) as session: yield session -def test_initialize_user_variables__create_and_update(service, session): +def _get_variable( + session: Session, + service, + user_id: UUID | str, + name: str, + field: str, +): + return service.get_variable(user_id, name, field, session=session) + + +async def test_initialize_user_variables__create_and_update(service, session: AsyncSession): user_id = uuid4() field = "" good_vars = {k: f"value{i}" for i, k in enumerate(VARIABLES_TO_GET_FROM_ENVIRONMENT)} bad_vars = {"VAR1": "value1", "VAR2": "value2", "VAR3": "value3"} env_vars = {**good_vars, **bad_vars} - service.create_variable(user_id, "OPENAI_API_KEY", "outdate", session=session) + await service.create_variable(user_id, "OPENAI_API_KEY", "outdate", session=session) env_vars["OPENAI_API_KEY"] = "updated_value" with patch.dict("os.environ", env_vars, clear=True): - service.initialize_user_variables(user_id=user_id, session=session) + await service.initialize_user_variables(user_id=user_id, session=session) - variables = service.list_variables(user_id, session=session) + variables = await service.list_variables(user_id, session=session) for name in variables: - value = service.get_variable(user_id, name, field, session=session) + value = await session.run_sync(_get_variable, service, user_id, name, field) assert value == env_vars[name] assert all(i in variables for i in good_vars) assert all(i not in variables for i in bad_vars) -def test_initialize_user_variables__not_found_variable(service, session): +async def test_initialize_user_variables__not_found_variable(service, session: AsyncSession): with patch("langflow.services.variable.service.DatabaseVariableService.create_variable") as m: m.side_effect = Exception() - service.initialize_user_variables(uuid4(), session=session) + await service.initialize_user_variables(uuid4(), session=session) assert True -def test_initialize_user_variables__skipping_environment_variable_storage(service, session): +async def test_initialize_user_variables__skipping_environment_variable_storage(service, session: AsyncSession): service.settings_service.settings.store_environment_variables = False - service.initialize_user_variables(uuid4(), session=session) + await service.initialize_user_variables(uuid4(), session=session) assert True -def test_get_variable(service, session): +async def test_get_variable(service, session: AsyncSession): user_id = uuid4() name = "name" value = "value" field = "" - service.create_variable(user_id, name, value, session=session) + await service.create_variable(user_id, name, value, session=session) - result = service.get_variable(user_id, name, field, session=session) + result = await session.run_sync(_get_variable, service, user_id, name, field) assert result == value -def test_get_variable__valueerror(service, session): +async def test_get_variable__valueerror(service, session: AsyncSession): user_id = uuid4() name = "name" field = "" with pytest.raises(ValueError, match=f"{name} variable not found."): - service.get_variable(user_id, name, field, session) + await session.run_sync(_get_variable, service, user_id, name, field) -def test_get_variable__typeerror(service, session): +async def test_get_variable__typeerror(service, session: AsyncSession): user_id = uuid4() name = "name" value = "value" field = "session_id" _type = CREDENTIAL_TYPE - service.create_variable(user_id, name, value, _type=_type, session=session) + await service.create_variable(user_id, name, value, _type=_type, session=session) with pytest.raises(TypeError) as exc: - service.get_variable(user_id, name, field, session) + await session.run_sync(_get_variable, service, user_id, name, field) assert name in str(exc.value) assert "purpose is to prevent the exposure of value" in str(exc.value) -def test_list_variables(service, session): +async def test_list_variables(service, session: AsyncSession): user_id = uuid4() names = ["name1", "name2", "name3"] value = "value" for name in names: - service.create_variable(user_id, name, value, session=session) + await service.create_variable(user_id, name, value, session=session) - result = service.list_variables(user_id, session=session) + result = await service.list_variables(user_id, session=session) assert all(name in result for name in names) -def test_list_variables__empty(service, session): - result = service.list_variables(uuid4(), session=session) +async def test_list_variables__empty(service, session: AsyncSession): + result = await service.list_variables(uuid4(), session=session) assert not result assert isinstance(result, list) -def test_update_variable(service, session): +async def test_update_variable(service, session: AsyncSession): user_id = uuid4() name = "name" old_value = "old_value" new_value = "new_value" field = "" - service.create_variable(user_id, name, old_value, session=session) + await service.create_variable(user_id, name, old_value, session=session) - old_recovered = service.get_variable(user_id, name, field, session=session) - result = service.update_variable(user_id, name, new_value, session=session) - new_recovered = service.get_variable(user_id, name, field, session=session) + old_recovered = await session.run_sync(_get_variable, service, user_id, name, field) + result = await service.update_variable(user_id, name, new_value, session=session) + new_recovered = await session.run_sync(_get_variable, service, user_id, name, field) assert old_value == old_recovered assert new_value == new_recovered @@ -139,26 +152,26 @@ def test_update_variable(service, session): assert isinstance(result.updated_at, datetime) -def test_update_variable__valueerror(service, session): +async def test_update_variable__valueerror(service, session: AsyncSession): user_id = uuid4() name = "name" value = "value" with pytest.raises(ValueError, match=f"{name} variable not found."): - service.update_variable(user_id, name, value, session=session) + await service.update_variable(user_id, name, value, session=session) -def test_update_variable_fields(service, session): +async def test_update_variable_fields(service, session: AsyncSession): user_id = uuid4() new_name = new_value = "donkey" - variable = service.create_variable(user_id, "old_name", "old_value", session=session) + variable = await service.create_variable(user_id, "old_name", "old_value", session=session) saved = variable.model_dump() variable = VariableUpdate(**saved) variable.name = new_name variable.value = new_value variable.default_fields = ["new_field"] - result = service.update_variable_fields( + result = await service.update_variable_fields( user_id=user_id, variable_id=saved.get("id"), variable=variable, @@ -177,58 +190,58 @@ def test_update_variable_fields(service, session): assert saved.get("updated_at") != result.updated_at -def test_delete_variable(service, session): +async def test_delete_variable(service, session: AsyncSession): user_id = uuid4() name = "name" value = "value" field = "" - service.create_variable(user_id, name, value, session=session) - recovered = service.get_variable(user_id, name, field, session=session) - service.delete_variable(user_id, name, session=session) + await service.create_variable(user_id, name, value, session=session) + recovered = await session.run_sync(_get_variable, service, user_id, name, field) + await service.delete_variable(user_id, name, session=session) with pytest.raises(ValueError, match=f"{name} variable not found."): - service.get_variable(user_id, name, field, session) + await session.run_sync(_get_variable, service, user_id, name, field) assert recovered == value -def test_delete_variable__valueerror(service, session): +async def test_delete_variable__valueerror(service, session: AsyncSession): user_id = uuid4() name = "name" with pytest.raises(ValueError, match=f"{name} variable not found."): - service.delete_variable(user_id, name, session=session) + await service.delete_variable(user_id, name, session=session) -def test_delete_variable_by_id(service, session): +async def test_delete_variable_by_id(service, session: AsyncSession): user_id = uuid4() name = "name" value = "value" field = "field" - saved = service.create_variable(user_id, name, value, session=session) - recovered = service.get_variable(user_id, name, field, session=session) - service.delete_variable_by_id(user_id, saved.id, session=session) + saved = await service.create_variable(user_id, name, value, session=session) + recovered = await session.run_sync(_get_variable, service, user_id, name, field) + await service.delete_variable_by_id(user_id, saved.id, session=session) with pytest.raises(ValueError, match=f"{name} variable not found."): - service.get_variable(user_id, name, field, session) + await session.run_sync(_get_variable, service, user_id, name, field) assert recovered == value -def test_delete_variable_by_id__valueerror(service, session): +async def test_delete_variable_by_id__valueerror(service, session: AsyncSession): user_id = uuid4() variable_id = uuid4() with pytest.raises(ValueError, match=f"{variable_id} variable not found."): - service.delete_variable_by_id(user_id, variable_id, session=session) + await service.delete_variable_by_id(user_id, variable_id, session=session) -def test_create_variable(service, session): +async def test_create_variable(service, session: AsyncSession): user_id = uuid4() name = "name" value = "value" - result = service.create_variable(user_id, name, value, session=session) + result = await service.create_variable(user_id, name, value, session=session) assert result.user_id == user_id assert result.name == name diff --git a/src/backend/tests/unit/test_setup_superuser.py b/src/backend/tests/unit/test_setup_superuser.py index b8fb1cbd1309..b56b815ad090 100644 --- a/src/backend/tests/unit/test_setup_superuser.py +++ b/src/backend/tests/unit/test_setup_superuser.py @@ -1,4 +1,5 @@ -from unittest.mock import MagicMock, patch +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch from langflow.services.settings.constants import ( DEFAULT_SUPERUSER, @@ -91,7 +92,7 @@ @patch("langflow.services.deps.get_settings_service") @patch("langflow.services.deps.get_session") -def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service): +async def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service): mock_settings_service = MagicMock() mock_settings_service.auth_settings.AUTO_LOGIN = True mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER @@ -104,29 +105,28 @@ def test_teardown_superuser_default_superuser(mock_get_session, mock_get_setting mock_session.query.return_value.filter.return_value.first.return_value = mock_user mock_get_session.return_value = iter([mock_session]) - teardown_superuser(mock_settings_service, mock_session) + await teardown_superuser(mock_settings_service, mock_session) mock_session.query.assert_not_called() -@patch("langflow.services.deps.get_settings_service") -@patch("langflow.services.deps.get_session") -def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service): +async def test_teardown_superuser_no_default_superuser(): admin_user_name = "admin_user" mock_settings_service = MagicMock() mock_settings_service.auth_settings.AUTO_LOGIN = False mock_settings_service.auth_settings.SUPERUSER = admin_user_name mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password" # noqa: S105 - mock_get_settings_service.return_value = mock_settings_service - mock_session = MagicMock() + mock_session = AsyncMock(return_value=asyncio.Future()) mock_user = MagicMock() mock_user.is_superuser = False - mock_session.query.return_value.filter.return_value.first.return_value = mock_user - mock_get_session.return_value = [mock_session] + mock_user.last_login_at = None - teardown_superuser(mock_settings_service, mock_session) + mock_result = MagicMock() + mock_result.first.return_value = mock_user + mock_session.exec.return_value = mock_result - mock_session.query.assert_not_called() - mock_session.delete.assert_not_called() - mock_session.commit.assert_not_called() + await teardown_superuser(mock_settings_service, mock_session) + + mock_session.delete.assert_not_awaited() + mock_session.commit.assert_not_awaited() diff --git a/src/backend/tests/unit/test_user.py b/src/backend/tests/unit/test_user.py index 9184a6567afb..6caec318059d 100644 --- a/src/backend/tests/unit/test_user.py +++ b/src/backend/tests/unit/test_user.py @@ -5,18 +5,18 @@ from langflow.services.auth.utils import create_super_user, get_password_hash from langflow.services.database.models.user import UserUpdate from langflow.services.database.models.user.model import User -from langflow.services.database.utils import session_getter +from langflow.services.database.utils import async_session_getter, session_getter from langflow.services.deps import get_db_service, get_settings_service from sqlmodel import select @pytest.fixture -def super_user(client): # noqa: ARG001 +async def super_user(client): # noqa: ARG001 settings_manager = get_settings_service() auth_settings = settings_manager.auth_settings - with session_getter(get_db_service()) as session: - return create_super_user( - db=session, + async with async_session_getter(get_db_service()) as db: + return await create_super_user( + db=db, username=auth_settings.SUPERUSER, password=auth_settings.SUPERUSER_PASSWORD, )