Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Use AsyncSession for user management #4491

Merged
merged 7 commits into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 37 additions & 29 deletions src/backend/base/langflow/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from langflow.logging.logger import configure, logger
from langflow.main import setup_app
from langflow.services.database.models.folder.utils import create_default_folder_if_it_doesnt_exist
from langflow.services.database.utils import session_getter
from langflow.services.database.utils import async_session_getter
from langflow.services.deps import async_session_scope, get_db_service, get_settings_service
from langflow.services.settings.constants import DEFAULT_SUPERUSER
from langflow.services.utils import initialize_services
Expand Down Expand Up @@ -419,30 +419,35 @@ def superuser(
) -> None:
"""Create a superuser."""
configure(log_level=log_level)
initialize_services()
db_service = get_db_service()
with session_getter(db_service) as session:
from langflow.services.auth.utils import create_super_user

if create_super_user(db=session, username=username, password=password):
# Verify that the superuser was created
from langflow.services.database.models.user.model import User
async def _create_superuser():
await initialize_services()
async with async_session_getter(db_service) as session:
from langflow.services.auth.utils import create_super_user

if await create_super_user(db=session, username=username, password=password):
# Verify that the superuser was created
from langflow.services.database.models.user.model import User

stmt = select(User).where(User.username == username)
user: User = (await session.exec(stmt)).first()
if user is None or not user.is_superuser:
typer.echo("Superuser creation failed.")
return
# Now create the first folder for the user
result = await create_default_folder_if_it_doesnt_exist(session, user.id)
if result:
typer.echo("Default folder created successfully.")
else:
msg = "Could not create default folder."
raise RuntimeError(msg)
typer.echo("Superuser created successfully.")

user: User = session.exec(select(User).where(User.username == username)).first()
if user is None or not user.is_superuser:
typer.echo("Superuser creation failed.")
return
# Now create the first folder for the user
result = create_default_folder_if_it_doesnt_exist(session, user.id)
if result:
typer.echo("Default folder created successfully.")
else:
msg = "Could not create default folder."
raise RuntimeError(msg)
typer.echo("Superuser created successfully.")
typer.echo("Superuser creation failed.")

else:
typer.echo("Superuser creation failed.")
asyncio.run(_create_superuser())


# command to copy the langflow database from the cache to the current directory
Expand Down Expand Up @@ -494,7 +499,7 @@ def migration(
):
raise typer.Abort

initialize_services(fix_migration=fix)
asyncio.run(initialize_services(fix_migration=fix))
db_service = get_db_service()
if not test:
db_service.run_migrations()
Expand All @@ -515,18 +520,20 @@ def api_key(
None
"""
configure(log_level=log_level)
initialize_services()
settings_service = get_settings_service()
auth_settings = settings_service.auth_settings
if not auth_settings.AUTO_LOGIN:
typer.echo("Auto login is disabled. API keys cannot be created through the CLI.")
return

async def aapi_key():
await initialize_services()
settings_service = get_settings_service()
auth_settings = settings_service.auth_settings
if not auth_settings.AUTO_LOGIN:
typer.echo("Auto login is disabled. API keys cannot be created through the CLI.")
return None

async with async_session_scope() as session:
from langflow.services.database.models.user.model import User

superuser = (await session.exec(select(User).where(User.username == DEFAULT_SUPERUSER))).first()
stmt = select(User).where(User.username == DEFAULT_SUPERUSER)
superuser = (await session.exec(stmt)).first()
if not superuser:
typer.echo(
"Default superuser not found. This command requires a superuser and AUTO_LOGIN to be enabled."
Expand All @@ -535,7 +542,8 @@ async def aapi_key():
from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate
from langflow.services.database.models.api_key.crud import create_api_key, delete_api_key

api_key = (await session.exec(select(ApiKey).where(ApiKey.user_id == superuser.id))).first()
stmt = select(ApiKey).where(ApiKey.user_id == superuser.id)
api_key = (await session.exec(stmt)).first()
if api_key:
await delete_api_key(session, api_key.id)

Expand Down
6 changes: 3 additions & 3 deletions src/backend/base/langflow/api/v1/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from fastapi import APIRouter, Depends, HTTPException, Response

from langflow.api.utils import AsyncDbSession, CurrentActiveUser, DbSession
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
from langflow.api.v1.schemas import ApiKeyCreateRequest, ApiKeysResponse
from langflow.services.auth import utils as auth_utils

Expand Down Expand Up @@ -62,7 +62,7 @@ async def save_store_api_key(
api_key_request: ApiKeyCreateRequest,
response: Response,
current_user: CurrentActiveUser,
db: DbSession,
db: AsyncDbSession,
):
settings_service = get_settings_service()
auth_settings = settings_service.auth_settings
Expand All @@ -74,7 +74,7 @@ async def save_store_api_key(
encrypted = auth_utils.encrypt_api_key(api_key, settings_service=settings_service)
current_user.store_api_key = encrypted
db.add(current_user)
db.commit()
await db.commit()

response.set_cookie(
"apikey_tkn_lflw",
Expand Down
24 changes: 12 additions & 12 deletions src/backend/base/langflow/api/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm

from langflow.api.utils import DbSession
from langflow.api.utils import AsyncDbSession
from langflow.api.v1.schemas import Token
from langflow.services.auth.utils import (
authenticate_user,
Expand All @@ -21,14 +21,14 @@


@router.post("/login", response_model=Token)
def login_to_get_access_token(
async def login_to_get_access_token(
response: Response,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: DbSession,
db: AsyncDbSession,
):
auth_settings = get_settings_service().auth_settings
try:
user = authenticate_user(form_data.username, form_data.password, db)
user = await authenticate_user(form_data.username, form_data.password, db)
except Exception as exc:
if isinstance(exc, HTTPException):
raise
Expand All @@ -38,7 +38,7 @@ def login_to_get_access_token(
) from exc

if user:
tokens = create_user_tokens(user_id=user.id, db=db, update_last_login=True)
tokens = await create_user_tokens(user_id=user.id, db=db, update_last_login=True)
response.set_cookie(
"refresh_token_lf",
tokens["refresh_token"],
Expand Down Expand Up @@ -66,9 +66,9 @@ def login_to_get_access_token(
expires=None, # Set to None to make it a session cookie
domain=auth_settings.COOKIE_DOMAIN,
)
get_variable_service().initialize_user_variables(user.id, db)
await get_variable_service().initialize_user_variables(user.id, db)
# Create default folder for user if it doesn't exist
create_default_folder_if_it_doesnt_exist(db, user.id)
await create_default_folder_if_it_doesnt_exist(db, user.id)
return tokens
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
Expand All @@ -78,11 +78,11 @@ def login_to_get_access_token(


@router.get("/auto_login")
async def auto_login(response: Response, db: DbSession):
async def auto_login(response: Response, db: AsyncDbSession):
auth_settings = get_settings_service().auth_settings

if auth_settings.AUTO_LOGIN:
user_id, tokens = create_user_longterm_token(db)
user_id, tokens = await create_user_longterm_token(db)
response.set_cookie(
"access_token_lf",
tokens["access_token"],
Expand All @@ -93,7 +93,7 @@ async def auto_login(response: Response, db: DbSession):
domain=auth_settings.COOKIE_DOMAIN,
)

user = get_user_by_id(db, user_id)
user = await get_user_by_id(db, user_id)

if user:
if user.store_api_key is None:
Expand Down Expand Up @@ -124,14 +124,14 @@ async def auto_login(response: Response, db: DbSession):
async def refresh_token(
request: Request,
response: Response,
db: DbSession,
db: AsyncDbSession,
):
auth_settings = get_settings_service().auth_settings

token = request.cookies.get("refresh_token_lf")

if token:
tokens = create_refresh_token(token, db)
tokens = await create_refresh_token(token, db)
response.set_cookie(
"refresh_token_lf",
tokens["refresh_token"],
Expand Down
30 changes: 15 additions & 15 deletions src/backend/base/langflow/api/v1/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlmodel import select
from sqlmodel.sql.expression import SelectOfScalar

from langflow.api.utils import CurrentActiveUser, DbSession
from langflow.api.utils import AsyncDbSession, CurrentActiveUser, DbSession
from langflow.api.v1.schemas import UsersResponse
from langflow.services.auth.utils import (
get_current_active_superuser,
Expand All @@ -25,21 +25,21 @@
@router.post("/", response_model=UserRead, status_code=201)
async def add_user(
user: UserCreate,
session: DbSession,
session: AsyncDbSession,
) -> User:
"""Add a new user to the database."""
new_user = User.model_validate(user, from_attributes=True)
try:
new_user.password = get_password_hash(user.password)
new_user.is_active = get_settings_service().auth_settings.NEW_USER_IS_ACTIVE
session.add(new_user)
session.commit()
session.refresh(new_user)
folder = create_default_folder_if_it_doesnt_exist(session, new_user.id)
await session.commit()
await session.refresh(new_user)
folder = await create_default_folder_if_it_doesnt_exist(session, new_user.id)
if not folder:
raise HTTPException(status_code=500, detail="Error creating default folder")
except IntegrityError as e:
session.rollback()
await session.rollback()
raise HTTPException(status_code=400, detail="This username is unavailable.") from e

return new_user
Expand All @@ -58,14 +58,14 @@ async def read_all_users(
*,
skip: int = 0,
limit: int = 10,
session: DbSession,
session: AsyncDbSession,
) -> UsersResponse:
"""Retrieve a list of users from the database with pagination."""
query: SelectOfScalar = select(User).offset(skip).limit(limit)
users = session.exec(query).fetchall()
users = (await session.exec(query)).fetchall()

count_query = select(func.count()).select_from(User)
total_count = session.exec(count_query).first()
total_count = (await session.exec(count_query)).first()

return UsersResponse(
total_count=total_count,
Expand All @@ -78,7 +78,7 @@ async def patch_user(
user_id: UUID,
user_update: UserUpdate,
user: CurrentActiveUser,
session: DbSession,
session: AsyncDbSession,
) -> User:
"""Update an existing user's data."""
update_password = bool(user_update.password)
Expand All @@ -93,10 +93,10 @@ async def patch_user(
raise HTTPException(status_code=400, detail="You can't change your password here")
user_update.password = get_password_hash(user_update.password)

if user_db := get_user_by_id(session, user_id):
if user_db := await get_user_by_id(session, user_id):
if not update_password:
user_update.password = user_db.password
return update_user(user_db, user_update, session)
return await update_user(user_db, user_update, session)
raise HTTPException(status_code=404, detail="User not found")


Expand All @@ -105,7 +105,7 @@ async def reset_password(
user_id: UUID,
user_update: UserUpdate,
user: CurrentActiveUser,
session: DbSession,
session: AsyncDbSession,
) -> User:
"""Reset a user's password."""
if user_id != user.id:
Expand All @@ -117,8 +117,8 @@ async def reset_password(
raise HTTPException(status_code=400, detail="You can't use your current password")
new_password = get_password_hash(user_update.password)
user.password = new_password
session.commit()
session.refresh(user)
await session.commit()
await session.refresh(user)

return user

Expand Down
Loading