From e103dcdc9aaf9d1d40cf7a72ad4dd0743cdcd7b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Tue, 14 May 2024 16:45:45 +0200 Subject: [PATCH] feat: move api v0 functionality to v1 (#168) # Description Feature branch for migrating API endpoints from v0 to v1. Refs https://github.com/argilla-io/argilla/issues/4773 **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Refactor (change restructuring the codebase without changing functionality) - [ ] Improvement (change adding some improvement to an existing functionality) - [ ] Documentation update **How Has This Been Tested** - [x] Adding new tests. **Checklist** - [ ] I added relevant documentation - [ ] follows the style guidelines of this project - [ ] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [ ] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [ ] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- CHANGELOG.md | 13 + src/argilla_server/apis/routes.py | 4 + src/argilla_server/apis/v0/handlers/users.py | 11 +- .../apis/v0/handlers/workspaces.py | 17 +- .../apis/v1/handlers/authentication.py | 39 +++ src/argilla_server/apis/v1/handlers/info.py | 35 ++ src/argilla_server/apis/v1/handlers/users.py | 93 +++++- .../apis/v1/handlers/workspaces.py | 108 ++++++- src/argilla_server/contexts/accounts.py | 79 +++-- src/argilla_server/contexts/info.py | 55 ++++ .../errors/future/base_errors.py | 8 +- src/argilla_server/policies.py | 45 +++ src/argilla_server/schemas/v0/workspaces.py | 5 - src/argilla_server/schemas/v1/info.py | 25 ++ src/argilla_server/schemas/v1/users.py | 52 +++ src/argilla_server/schemas/v1/workspaces.py | 13 +- src/argilla_server/search_engine/base.py | 4 + .../search_engine/elasticsearch.py | 3 + .../search_engine/opensearch.py | 3 + src/argilla_server/security/model.py | 5 - src/argilla_server/services/info.py | 2 + tests/unit/api/v0/test_authentication.py | 4 +- tests/unit/api/v0/test_users.py | 47 +++ tests/unit/api/v1/authentication/__init__.py | 13 + .../v1/authentication/test_create_token.py | 66 ++++ tests/unit/api/v1/info/__init__.py | 13 + tests/unit/api/v1/info/test_get_status.py | 36 +++ tests/unit/api/v1/info/test_get_version.py | 29 ++ tests/unit/api/v1/users/__init__.py | 13 + tests/unit/api/v1/users/test_create_user.py | 306 ++++++++++++++++++ tests/unit/api/v1/users/test_delete_user.py | 82 +++++ .../api/v1/users/test_get_current_user.py | 43 +++ tests/unit/api/v1/users/test_get_user.py | 71 ++++ tests/unit/api/v1/users/test_list_users.py | 81 +++++ tests/unit/api/v1/workspaces/__init__.py | 13 + .../v1/workspaces/test_create_workspace.py | 112 +++++++ .../workspaces/test_create_workspace_user.py | 140 ++++++++ .../workspaces/test_delete_workspace_user.py | 150 +++++++++ .../workspaces/test_list_workspace_users.py | 130 ++++++++ 39 files changed, 1909 insertions(+), 59 deletions(-) create mode 100644 src/argilla_server/apis/v1/handlers/authentication.py create mode 100644 src/argilla_server/apis/v1/handlers/info.py create mode 100644 src/argilla_server/contexts/info.py create mode 100644 src/argilla_server/schemas/v1/info.py create mode 100644 src/argilla_server/schemas/v1/users.py create mode 100644 tests/unit/api/v1/authentication/__init__.py create mode 100644 tests/unit/api/v1/authentication/test_create_token.py create mode 100644 tests/unit/api/v1/info/__init__.py create mode 100644 tests/unit/api/v1/info/test_get_status.py create mode 100644 tests/unit/api/v1/info/test_get_version.py create mode 100644 tests/unit/api/v1/users/__init__.py create mode 100644 tests/unit/api/v1/users/test_create_user.py create mode 100644 tests/unit/api/v1/users/test_delete_user.py create mode 100644 tests/unit/api/v1/users/test_get_current_user.py create mode 100644 tests/unit/api/v1/users/test_get_user.py create mode 100644 tests/unit/api/v1/users/test_list_users.py create mode 100644 tests/unit/api/v1/workspaces/__init__.py create mode 100644 tests/unit/api/v1/workspaces/test_create_workspace.py create mode 100644 tests/unit/api/v1/workspaces/test_create_workspace_user.py create mode 100644 tests/unit/api/v1/workspaces/test_delete_workspace_user.py create mode 100644 tests/unit/api/v1/workspaces/test_list_workspace_users.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 273c242d..a3679be8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,19 @@ These are the section headers that we use: ## [Unreleased]() +- Added `POST /api/v1/token` endpoint to generate a new API token for a user. ([#138](https://github.com/argilla-io/argilla-server/pull/138)) +- Added `GET /api/v1/me` endpoint to get the current user information. ([#140](https://github.com/argilla-io/argilla-server/pull/140)) +- Added `GET /api/v1/users` endpoint to get a list of all users. ([#142](https://github.com/argilla-io/argilla-server/pull/142)) +- Added `GET /api/v1/users/:user_id` endpoint to get a specific user. ([#166](https://github.com/argilla-io/argilla-server/pull/166)) +- Added `POST /api/v1/users` endpoint to create a new user. ([#146](https://github.com/argilla-io/argilla-server/pull/146)) +- Added `DELETE /api/v1/users` endpoint to delete a user. ([#148](https://github.com/argilla-io/argilla-server/pull/148)) +- Added `POST /api/v1/workspaces` endpoint to create a new workspace. ([#150](https://github.com/argilla-io/argilla-server/pull/150)) +- Added `GET /api/v1/workspaces/:workspace_id/users` endpoint to get the users of a workspace. ([#153](https://github.com/argilla-io/argilla-server/pull/153)) +- Added `POST /api/v1/workspaces/:workspace_id/users` endpoind to add a user to a workspace. ([#156](https://github.com/argilla-io/argilla-server/pull/156)) +- Added `DELETE /api/v1/workspaces/:workspace_id/users/:user_id` endpoint to remove a user from a workspace. ([#158](https://github.com/argilla-io/argilla-server/pull/158)) +- Added `GET /api/v1/version` endpoint to get the current Argilla version. ([#162](https://github.com/argilla-io/argilla-server/pull/162)) +- Added `GET /api/v1/status` endpoint to get Argilla service status. ([#165](https://github.com/argilla-io/argilla-server/pull/165)) + ## [1.28.0](https://github.com/argilla-io/argilla-server/compare/v1.27.0...v1.28.0) ### Added diff --git a/src/argilla_server/apis/routes.py b/src/argilla_server/apis/routes.py index deae7719..e075e58d 100644 --- a/src/argilla_server/apis/routes.py +++ b/src/argilla_server/apis/routes.py @@ -35,12 +35,14 @@ users, workspaces, ) +from argilla_server.apis.v1.handlers import authentication as authentication_v1 from argilla_server.apis.v1.handlers import ( datasets as datasets_v1, ) from argilla_server.apis.v1.handlers import ( fields as fields_v1, ) +from argilla_server.apis.v1.handlers import info as info_v1 from argilla_server.apis.v1.handlers import ( metadata_properties as metadata_properties_v1, ) @@ -113,6 +115,8 @@ def create_api_v1(): APIErrorHandler.configure_app(api_v1) for router in [ + info_v1.router, + authentication_v1.router, datasets_v1.router, fields_v1.router, questions_v1.router, diff --git a/src/argilla_server/apis/v0/handlers/users.py b/src/argilla_server/apis/v0/handlers/users.py index 1fa3b45a..645dd3cd 100644 --- a/src/argilla_server/apis/v0/handlers/users.py +++ b/src/argilla_server/apis/v0/handlers/users.py @@ -23,6 +23,7 @@ from argilla_server.contexts import accounts from argilla_server.database import get_async_db from argilla_server.errors import EntityAlreadyExistsError, EntityNotFoundError +from argilla_server.errors.future import NotUniqueError from argilla_server.policies import UserPolicy, authorize from argilla_server.pydantic_v1 import parse_obj_as from argilla_server.schemas.v0.users import User, UserCreate @@ -90,17 +91,17 @@ async def create_user( ): await authorize(current_user, UserPolicy.create) - user = await accounts.get_user_by_username(db, user_create.username) - if user is not None: - raise EntityAlreadyExistsError(name=user_create.username, type=User) - try: - user = await accounts.create_user(db, user_create) + user = await accounts.create_user(db, user_create.dict(), user_create.workspaces) + telemetry.track_user_created(user) + except NotUniqueError: + raise EntityAlreadyExistsError(name=user_create.username, type=User) except Exception as e: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)) await user.awaitable_attrs.workspaces + return User.from_orm(user) diff --git a/src/argilla_server/apis/v0/handlers/workspaces.py b/src/argilla_server/apis/v0/handlers/workspaces.py index 65c4af74..b23e2aba 100644 --- a/src/argilla_server/apis/v0/handlers/workspaces.py +++ b/src/argilla_server/apis/v0/handlers/workspaces.py @@ -21,10 +21,11 @@ from argilla_server.contexts import accounts from argilla_server.database import get_async_db from argilla_server.errors import EntityAlreadyExistsError, EntityNotFoundError +from argilla_server.errors.future import NotUniqueError from argilla_server.policies import WorkspacePolicy, WorkspaceUserPolicy, authorize from argilla_server.pydantic_v1 import parse_obj_as from argilla_server.schemas.v0.users import User -from argilla_server.schemas.v0.workspaces import Workspace, WorkspaceCreate, WorkspaceUserCreate +from argilla_server.schemas.v0.workspaces import Workspace, WorkspaceCreate from argilla_server.security import auth router = APIRouter(tags=["workspaces"]) @@ -39,11 +40,11 @@ async def create_workspace( ): await authorize(current_user, WorkspacePolicy.create) - if await accounts.get_workspace_by_name(db, workspace_create.name): + try: + workspace = await accounts.create_workspace(db, workspace_create.dict()) + except NotUniqueError: raise EntityAlreadyExistsError(name=workspace_create.name, type=Workspace) - workspace = await accounts.create_workspace(db, workspace_create) - return Workspace.from_orm(workspace) @@ -84,13 +85,11 @@ async def create_workspace_user( if not user: raise EntityNotFoundError(name=str(user_id), type=User) - workspace_user = await accounts.get_workspace_user_by_workspace_id_and_user_id(db, workspace_id, user_id) - if workspace_user is not None: + try: + workspace_user = await accounts.create_workspace_user(db, {"workspace_id": workspace_id, "user_id": user_id}) + except NotUniqueError: raise EntityAlreadyExistsError(name=str(user_id), type=User) - workspace_user = await accounts.create_workspace_user( - db, WorkspaceUserCreate(workspace_id=workspace_id, user_id=user_id) - ) await db.refresh(user, attribute_names=["workspaces"]) return User.from_orm(workspace_user.user) diff --git a/src/argilla_server/apis/v1/handlers/authentication.py b/src/argilla_server/apis/v1/handlers/authentication.py new file mode 100644 index 00000000..7834bc99 --- /dev/null +++ b/src/argilla_server/apis/v1/handlers/authentication.py @@ -0,0 +1,39 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Annotated + +from fastapi import APIRouter, Depends, Form, status +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.contexts import accounts +from argilla_server.database import get_async_db +from argilla_server.errors import UnauthorizedError +from argilla_server.schemas.v1.oauth2 import Token + +router = APIRouter(tags=["Authentication"]) + + +@router.post("/token", status_code=status.HTTP_201_CREATED, response_model=Token) +async def create_token( + *, + db: AsyncSession = Depends(get_async_db), + username: Annotated[str, Form()], + password: Annotated[str, Form()], +): + user = await accounts.authenticate_user(db, username, password) + if not user: + raise UnauthorizedError() + + return Token(access_token=accounts.generate_user_token(user)) diff --git a/src/argilla_server/apis/v1/handlers/info.py b/src/argilla_server/apis/v1/handlers/info.py new file mode 100644 index 00000000..7a1f1d27 --- /dev/null +++ b/src/argilla_server/apis/v1/handlers/info.py @@ -0,0 +1,35 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fastapi import APIRouter, Depends + +from argilla_server.contexts import info +from argilla_server.schemas.v1.info import Status, Version +from argilla_server.search_engine import SearchEngine, get_search_engine + +router = APIRouter(tags=["info"]) + + +@router.get("/version", response_model=Version) +async def get_version(): + return Version(version=info.argilla_version()) + + +@router.get("/status", response_model=Status) +async def get_status(search_engine: SearchEngine = Depends(get_search_engine)): + return Status( + version=info.argilla_version(), + search_engine=await search_engine.info(), + memory=info.memory_status(), + ) diff --git a/src/argilla_server/apis/v1/handlers/users.py b/src/argilla_server/apis/v1/handlers/users.py index f0578f65..23c53b52 100644 --- a/src/argilla_server/apis/v1/handlers/users.py +++ b/src/argilla_server/apis/v1/handlers/users.py @@ -12,24 +12,111 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, Security, status +from fastapi import APIRouter, Depends, HTTPException, Request, Security, status from sqlalchemy.ext.asyncio import AsyncSession +from argilla_server import models, telemetry from argilla_server.contexts import accounts from argilla_server.database import get_async_db -from argilla_server.models import User +from argilla_server.errors.future import NotUniqueError from argilla_server.policies import UserPolicyV1, authorize +from argilla_server.schemas.v1.users import User, UserCreate, Users from argilla_server.schemas.v1.workspaces import Workspaces from argilla_server.security import auth router = APIRouter(tags=["users"]) +@router.get("/me", response_model=User) +async def get_current_user(request: Request, current_user: models.User = Security(auth.get_current_user)): + await telemetry.track_login(request, current_user) + + return current_user + + +@router.get("/users/{user_id}", response_model=User) +async def get_user( + *, + db: AsyncSession = Depends(get_async_db), + user_id: UUID, + current_user: models.User = Security(auth.get_current_user), +): + await authorize(current_user, UserPolicyV1.get) + + user = await accounts.get_user_by_id(db, user_id) + if user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User with id `{user_id}` not found", + ) + + return user + + +@router.get("/users", response_model=Users) +async def list_users( + *, + db: AsyncSession = Depends(get_async_db), + current_user: models.User = Security(auth.get_current_user), +): + await authorize(current_user, UserPolicyV1.list) + + users = await accounts.list_users(db) + + return Users(items=users) + + +@router.post("/users", status_code=status.HTTP_201_CREATED, response_model=User) +async def create_user( + *, + db: AsyncSession = Depends(get_async_db), + user_create: UserCreate, + current_user: models.User = Security(auth.get_current_user), +): + await authorize(current_user, UserPolicyV1.create) + + try: + user = await accounts.create_user(db, user_create.dict()) + + telemetry.track_user_created(user) + except NotUniqueError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)) + + return user + + +@router.delete("/users/{user_id}", response_model=User) +async def delete_user( + *, + db: AsyncSession = Depends(get_async_db), + user_id: UUID, + current_user: models.User = Security(auth.get_current_user), +): + user = await accounts.get_user_by_id(db, user_id) + if user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User with id `{user_id}` not found", + ) + + await authorize(current_user, UserPolicyV1.delete) + + await accounts.delete_user(db, user) + + return user + + @router.get("/users/{user_id}/workspaces", response_model=Workspaces) async def list_user_workspaces( - *, db: AsyncSession = Depends(get_async_db), user_id: UUID, current_user: User = Security(auth.get_current_user) + *, + db: AsyncSession = Depends(get_async_db), + user_id: UUID, + current_user: models.User = Security(auth.get_current_user), ): await authorize(current_user, UserPolicyV1.list_workspaces) diff --git a/src/argilla_server/apis/v1/handlers/workspaces.py b/src/argilla_server/apis/v1/handlers/workspaces.py index 6d809334..fd3adea4 100644 --- a/src/argilla_server/apis/v1/handlers/workspaces.py +++ b/src/argilla_server/apis/v1/handlers/workspaces.py @@ -17,11 +17,15 @@ from fastapi import APIRouter, Depends, HTTPException, Security, status from sqlalchemy.ext.asyncio import AsyncSession +from argilla_server import models from argilla_server.contexts import accounts, datasets from argilla_server.database import get_async_db +from argilla_server.errors import EntityAlreadyExistsError +from argilla_server.errors.future import NotUniqueError from argilla_server.models import User -from argilla_server.policies import WorkspacePolicyV1, authorize -from argilla_server.schemas.v1.workspaces import Workspace, Workspaces +from argilla_server.policies import WorkspacePolicyV1, WorkspaceUserPolicyV1, authorize +from argilla_server.schemas.v1.users import User, Users +from argilla_server.schemas.v1.workspaces import Workspace, WorkspaceCreate, Workspaces, WorkspaceUserCreate from argilla_server.security import auth from argilla_server.services.datasets import DatasetsService @@ -33,7 +37,7 @@ async def get_workspace( *, db: AsyncSession = Depends(get_async_db), workspace_id: UUID, - current_user: User = Security(auth.get_current_user), + current_user: models.User = Security(auth.get_current_user), ): await authorize(current_user, WorkspacePolicyV1.get(workspace_id)) @@ -47,13 +51,30 @@ async def get_workspace( return workspace +@router.post("/workspaces", status_code=status.HTTP_201_CREATED, response_model=Workspace) +async def create_workspace( + *, + db: AsyncSession = Depends(get_async_db), + workspace_create: WorkspaceCreate, + current_user: models.User = Security(auth.get_current_user), +): + await authorize(current_user, WorkspacePolicyV1.create) + + try: + workspace = await accounts.create_workspace(db, workspace_create.dict()) + except NotUniqueError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) + + return workspace + + @router.delete("/workspaces/{workspace_id}", response_model=Workspace) async def delete_workspace( *, db: AsyncSession = Depends(get_async_db), datasets_service: DatasetsService = Depends(DatasetsService.get_instance), workspace_id: UUID, - current_user: User = Security(auth.get_current_user), + current_user: models.User = Security(auth.get_current_user), ): await authorize(current_user, WorkspacePolicyV1.delete) @@ -81,7 +102,9 @@ async def delete_workspace( @router.get("/me/workspaces", response_model=Workspaces) async def list_workspaces_me( - *, db: AsyncSession = Depends(get_async_db), current_user: User = Security(auth.get_current_user) + *, + db: AsyncSession = Depends(get_async_db), + current_user: models.User = Security(auth.get_current_user), ) -> Workspaces: await authorize(current_user, WorkspacePolicyV1.list_workspaces_me) @@ -91,3 +114,78 @@ async def list_workspaces_me( workspaces = await accounts.list_workspaces_by_user_id(db, current_user.id) return Workspaces(items=workspaces) + + +@router.get("/workspaces/{workspace_id}/users", response_model=Users) +async def list_workspace_users( + *, + db: AsyncSession = Depends(get_async_db), + workspace_id: UUID, + current_user: models.User = Security(auth.get_current_user), +): + await authorize(current_user, WorkspaceUserPolicyV1.list(workspace_id)) + + workspace = await accounts.get_workspace_by_id(db, workspace_id) + if workspace is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Workspace with id `{workspace_id}` not found", + ) + + await workspace.awaitable_attrs.users + + return Users(items=workspace.users) + + +@router.post("/workspaces/{workspace_id}/users", status_code=status.HTTP_201_CREATED, response_model=User) +async def create_workspace_user( + *, + db: AsyncSession = Depends(get_async_db), + workspace_id: UUID, + workspace_user_create: WorkspaceUserCreate, + current_user: models.User = Security(auth.get_current_user), +): + await authorize(current_user, WorkspaceUserPolicyV1.create) + + workspace = await accounts.get_workspace_by_id(db, workspace_id) + if workspace is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Workspace with id `{workspace_id}` not found", + ) + + user = await accounts.get_user_by_id(db, workspace_user_create.user_id) + if user is None: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"User with id `{workspace_user_create.user_id}` not found", + ) + + try: + workspace_user = await accounts.create_workspace_user(db, {"workspace_id": workspace.id, "user_id": user.id}) + except NotUniqueError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) + + return workspace_user.user + + +@router.delete("/workspaces/{workspace_id}/users/{user_id}", response_model=User) +async def delete_workspace_user( + *, + db: AsyncSession = Depends(get_async_db), + workspace_id: UUID, + user_id: UUID, + current_user: models.User = Security(auth.get_current_user), +): + workspace_user = await accounts.get_workspace_user_by_workspace_id_and_user_id(db, workspace_id, user_id) + if workspace_user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User with id `{user_id}` not found in workspace with id `{workspace_id}`", + ) + + await authorize(current_user, WorkspaceUserPolicyV1.delete(workspace_user)) + + await accounts.delete_workspace_user(db, workspace_user) + + return await workspace_user.awaitable_attrs.user diff --git a/src/argilla_server/contexts/accounts.py b/src/argilla_server/contexts/accounts.py index e09a20c7..c3af30ce 100644 --- a/src/argilla_server/contexts/accounts.py +++ b/src/argilla_server/contexts/accounts.py @@ -21,9 +21,12 @@ from sqlalchemy.orm import Session, selectinload from argilla_server.enums import UserRole +from argilla_server.errors.future import NotUniqueError from argilla_server.models import User, Workspace, WorkspaceUser from argilla_server.schemas.v0.users import UserCreate -from argilla_server.schemas.v0.workspaces import WorkspaceCreate, WorkspaceUserCreate +from argilla_server.schemas.v0.workspaces import WorkspaceCreate +from argilla_server.security.authentication.jwt import JWT +from argilla_server.security.authentication.userinfo import UserInfo _CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -35,13 +38,18 @@ async def get_workspace_user_by_workspace_id_and_user_id( return result.scalar_one_or_none() -async def create_workspace_user(db: AsyncSession, workspace_user_create: WorkspaceUserCreate) -> WorkspaceUser: - workspace_user = await WorkspaceUser.create( - db, - workspace_id=workspace_user_create.workspace_id, - user_id=workspace_user_create.user_id, - ) +async def create_workspace_user(db: AsyncSession, workspace_user_attrs: dict) -> WorkspaceUser: + workspace_id = workspace_user_attrs["workspace_id"] + user_id = workspace_user_attrs["user_id"] + + if (await get_workspace_user_by_workspace_id_and_user_id(db, workspace_id, user_id)) is not None: + raise NotUniqueError(f"Workspace user with workspace_id `{workspace_id}` and user_id `{user_id}` is not unique") + + workspace_user = await WorkspaceUser.create(db, workspace_id=workspace_id, user_id=user_id) + + # TODO: Once we delete API v0 endpoint we can reduce this to refresh only the user. await db.refresh(workspace_user, attribute_names=["workspace", "user"]) + return workspace_user @@ -73,8 +81,11 @@ async def list_workspaces_by_user_id(db: AsyncSession, user_id: UUID) -> List[Wo return result.scalars().all() -async def create_workspace(db: AsyncSession, workspace_create: WorkspaceCreate) -> Workspace: - return await Workspace.create(db, schema=workspace_create) +async def create_workspace(db: AsyncSession, workspace_attrs: dict) -> Workspace: + if (await get_workspace_by_name(db, workspace_attrs["name"])) is not None: + raise NotUniqueError(f"Workspace name `{workspace_attrs['name']}` is not unique") + + return await Workspace.create(db, name=workspace_attrs["name"]) async def delete_workspace(db: AsyncSession, workspace: Workspace): @@ -108,6 +119,8 @@ async def get_user_by_api_key(db: AsyncSession, api_key: str) -> Union[User, Non async def list_users(db: "AsyncSession") -> Sequence[User]: + # TODO: After removing API v0 implementation we can remove the workspaces eager loading + # because is not used in the new API v1 endpoints. result = await db.execute(select(User).order_by(User.inserted_at.asc()).options(selectinload(User.workspaces))) return result.scalars().all() @@ -117,23 +130,29 @@ async def list_users_by_ids(db: AsyncSession, ids: Iterable[UUID]) -> Sequence[U return result.scalars().all() -async def create_user(db: "AsyncSession", user_create: UserCreate) -> User: +# TODO: After removing API v0 implementation we can remove the workspaces attribute. +# With API v1 the workspaces will be created doing additional requests to other endpoints for it. +async def create_user(db: AsyncSession, user_attrs: dict, workspaces: Union[List[str], None] = None) -> User: + if (await get_user_by_username(db, user_attrs["username"])) is not None: + raise NotUniqueError(f"Username `{user_attrs['username']}` is not unique") + async with db.begin_nested(): user = await User.create( db, - first_name=user_create.first_name, - last_name=user_create.last_name, - username=user_create.username, - role=user_create.role, - password_hash=hash_password(user_create.password), + first_name=user_attrs["first_name"], + last_name=user_attrs["last_name"], + username=user_attrs["username"], + role=user_attrs["role"], + password_hash=hash_password(user_attrs["password"]), autocommit=False, ) - if user_create.workspaces: - for workspace_name in user_create.workspaces: + if workspaces is not None: + for workspace_name in workspaces: workspace = await get_workspace_by_name(db, workspace_name) if not workspace: raise ValueError(f"Workspace '{workspace_name}' does not exist") + await WorkspaceUser.create( db, workspace_id=workspace.id, @@ -150,15 +169,18 @@ async def create_user_with_random_password( db, username: str, first_name: str, - workspaces: List[str] = None, role: UserRole = UserRole.annotator, + workspaces: Union[List[str], None] = None, ) -> User: - password = _generate_random_password() + user_attrs = { + "first_name": first_name, + "last_name": None, + "username": username, + "role": role, + "password": _generate_random_password(), + } - user_create = UserCreate( - first_name=first_name, username=username, role=role, password=password, workspaces=workspaces - ) - return await create_user(db, user_create) + return await create_user(db, user_attrs, workspaces) async def delete_user(db: AsyncSession, user: User) -> User: @@ -188,6 +210,17 @@ def _generate_random_password() -> str: return secrets.token_urlsafe() +def generate_user_token(user: User) -> str: + return JWT.create( + UserInfo( + identity=str(user.id), + name=user.first_name, + username=user.username, + role=user.role, + ), + ) + + async def fetch_users_by_ids_as_dict(db: "AsyncSession", user_ids: List[UUID]) -> Dict[UUID, User]: users = await list_users_by_ids(db, set(user_ids)) return {user.id: user for user in users} diff --git a/src/argilla_server/contexts/info.py b/src/argilla_server/contexts/info.py new file mode 100644 index 00000000..704cc929 --- /dev/null +++ b/src/argilla_server/contexts/info.py @@ -0,0 +1,55 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import psutil + +from argilla_server._version import __version__ + + +def argilla_version() -> str: + return __version__ + + +def memory_status() -> dict: + process = psutil.Process(os.getpid()) + + return {k: _memory_size(v) for k, v in process.memory_info()._asdict().items()} + + +def _memory_size(bytes) -> str: + system = [ + (1024**5, "P"), + (1024**4, "T"), + (1024**3, "G"), + (1024**2, "M"), + (1024**1, "K"), + (1024**0, "B"), + ] + + factor, suffix = None, None + for factor, suffix in system: + if bytes >= factor: + break + + amount = int(bytes / factor) + if isinstance(suffix, tuple): + singular, multiple = suffix + if amount == 1: + suffix = singular + else: + suffix = multiple + + return str(amount) + suffix diff --git a/src/argilla_server/errors/future/base_errors.py b/src/argilla_server/errors/future/base_errors.py index 3d017f14..1e3e7081 100644 --- a/src/argilla_server/errors/future/base_errors.py +++ b/src/argilla_server/errors/future/base_errors.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ["NotFoundError", "AuthenticationError"] +__all__ = ["NotFoundError", "NotUniqueError", "AuthenticationError"] class NotFoundError(Exception): @@ -21,6 +21,12 @@ class NotFoundError(Exception): pass +class NotUniqueError(Exception): + """Custom Argilla not unique error. Use it for situations where an Argilla domain entity already exists violating a constraint.""" + + pass + + class AuthenticationError(Exception): """Custom Argilla unauthorized error. Use it for situations where an request is not authorized to perform an action.""" diff --git a/src/argilla_server/policies.py b/src/argilla_server/policies.py index 196ef159..6164ee72 100644 --- a/src/argilla_server/policies.py +++ b/src/argilla_server/policies.py @@ -77,6 +77,31 @@ async def is_allowed(actor: User) -> bool: return is_allowed +class WorkspaceUserPolicyV1: + @classmethod + def list(cls, workspace_id: UUID) -> PolicyAction: + async def is_allowed(actor: User) -> bool: + return actor.is_owner or ( + actor.is_admin and await _exists_workspace_user_by_user_and_workspace_id(actor, workspace_id) + ) + + return is_allowed + + @classmethod + async def create(cls, actor: User) -> bool: + return actor.is_owner + + @classmethod + def delete(cls, workspace_user: WorkspaceUser) -> PolicyAction: + async def is_allowed(actor: User) -> bool: + return actor.is_owner or ( + actor.is_admin + and await _exists_workspace_user_by_user_and_workspace_id(actor, workspace_user.workspace_id) + ) + + return is_allowed + + class WorkspacePolicy: @classmethod async def list(cls, actor: User) -> bool: @@ -102,6 +127,10 @@ async def is_allowed(actor: User) -> bool: return is_allowed + @classmethod + async def create(cls, actor: User) -> bool: + return actor.is_owner + @classmethod async def delete(cls, actor: User) -> bool: return actor.is_owner @@ -129,6 +158,22 @@ async def is_allowed(actor: User) -> bool: class UserPolicyV1: + @classmethod + async def get(cls, actor: User) -> bool: + return actor.is_owner + + @classmethod + async def list(cls, actor: User) -> bool: + return actor.is_owner + + @classmethod + async def create(cls, actor: User) -> bool: + return actor.is_owner + + @classmethod + async def delete(cls, actor: User) -> bool: + return actor.is_owner + @classmethod async def list_workspaces(cls, actor: User) -> bool: return actor.is_owner diff --git a/src/argilla_server/schemas/v0/workspaces.py b/src/argilla_server/schemas/v0/workspaces.py index b403fe3b..4abf8aaf 100644 --- a/src/argilla_server/schemas/v0/workspaces.py +++ b/src/argilla_server/schemas/v0/workspaces.py @@ -31,10 +31,5 @@ class Config: orm_mode = True -class WorkspaceUserCreate(BaseModel): - user_id: UUID - workspace_id: UUID - - class WorkspaceCreate(BaseModel): name: str = Field(..., regex=WORKSPACE_NAME_REGEX, min_length=1) diff --git a/src/argilla_server/schemas/v1/info.py b/src/argilla_server/schemas/v1/info.py new file mode 100644 index 00000000..8ec7da13 --- /dev/null +++ b/src/argilla_server/schemas/v1/info.py @@ -0,0 +1,25 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argilla_server.pydantic_v1 import BaseModel + + +class Version(BaseModel): + version: str + + +class Status(BaseModel): + version: str + search_engine: dict + memory: dict diff --git a/src/argilla_server/schemas/v1/users.py b/src/argilla_server/schemas/v1/users.py new file mode 100644 index 00000000..1b93d7f6 --- /dev/null +++ b/src/argilla_server/schemas/v1/users.py @@ -0,0 +1,52 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +from typing import List, Optional +from uuid import UUID + +from argilla_server.enums import UserRole +from argilla_server.pydantic_v1 import BaseModel, Field, constr + +USER_USERNAME_REGEX = "^(?!-|_)[A-za-z0-9-_]+$" +USER_PASSWORD_MIN_LENGTH = 8 +USER_PASSWORD_MAX_LENGTH = 100 + + +class User(BaseModel): + id: UUID + first_name: str + last_name: Optional[str] + username: str + role: UserRole + # TODO: We need to move `api_key` outside of this schema and think about a more + # secure way to expose it, along with ways to expire it and create new API keys. + api_key: str + inserted_at: datetime + updated_at: datetime + + class Config: + orm_mode = True + + +class UserCreate(BaseModel): + first_name: constr(min_length=1, strip_whitespace=True) + last_name: Optional[constr(min_length=1, strip_whitespace=True)] + username: str = Field(regex=USER_USERNAME_REGEX, min_length=1) + role: Optional[UserRole] + password: str = Field(min_length=USER_PASSWORD_MIN_LENGTH, max_length=USER_PASSWORD_MAX_LENGTH) + + +class Users(BaseModel): + items: List[User] diff --git a/src/argilla_server/schemas/v1/workspaces.py b/src/argilla_server/schemas/v1/workspaces.py index 1a0f4dcd..071aeaf7 100644 --- a/src/argilla_server/schemas/v1/workspaces.py +++ b/src/argilla_server/schemas/v1/workspaces.py @@ -16,7 +16,10 @@ from typing import List from uuid import UUID -from argilla_server.pydantic_v1 import BaseModel +from argilla_server.constants import ES_INDEX_REGEX_PATTERN +from argilla_server.pydantic_v1 import BaseModel, Field + +WORKSPACE_NAME_REGEX = ES_INDEX_REGEX_PATTERN class Workspace(BaseModel): @@ -29,5 +32,13 @@ class Config: orm_mode = True +class WorkspaceCreate(BaseModel): + name: str = Field(regex=WORKSPACE_NAME_REGEX, min_length=1) + + class Workspaces(BaseModel): items: List[Workspace] + + +class WorkspaceUserCreate(BaseModel): + user_id: UUID diff --git a/src/argilla_server/search_engine/base.py b/src/argilla_server/search_engine/base.py index 2e0f402e..7396c3da 100644 --- a/src/argilla_server/search_engine/base.py +++ b/src/argilla_server/search_engine/base.py @@ -264,6 +264,10 @@ async def new_instance(cls) -> "SearchEngine": async def close(self): pass + @abstractmethod + async def info(self) -> dict: + pass + @classmethod def register(cls, engine_name: str): def decorator(engine_class): diff --git a/src/argilla_server/search_engine/elasticsearch.py b/src/argilla_server/search_engine/elasticsearch.py index c945e471..44b2d5c4 100644 --- a/src/argilla_server/search_engine/elasticsearch.py +++ b/src/argilla_server/search_engine/elasticsearch.py @@ -64,6 +64,9 @@ async def new_instance(cls) -> "ElasticSearchEngine": async def close(self): await self.client.close() + async def info(self) -> dict: + return await self.client.info() + def _mapping_for_vector_settings(self, vector_settings: VectorSettings) -> dict: return { es_field_for_vector_settings(vector_settings): { diff --git a/src/argilla_server/search_engine/opensearch.py b/src/argilla_server/search_engine/opensearch.py index 078ae051..6333dd64 100644 --- a/src/argilla_server/search_engine/opensearch.py +++ b/src/argilla_server/search_engine/opensearch.py @@ -56,6 +56,9 @@ async def new_instance(cls) -> "OpenSearchEngine": async def close(self): await self.client.close() + async def info(self) -> dict: + return await self.client.info() + def _configure_index_settings(self): base_settings = super()._configure_index_settings() return {**base_settings, "index.knn": False} diff --git a/src/argilla_server/security/model.py b/src/argilla_server/security/model.py index 2b6d175a..f62f59c0 100644 --- a/src/argilla_server/security/model.py +++ b/src/argilla_server/security/model.py @@ -28,11 +28,6 @@ USER_PASSWORD_MAX_LENGTH = 100 -class WorkspaceUserCreate(BaseModel): - user_id: UUID - workspace_id: UUID - - class Workspace(BaseModel): id: UUID name: str diff --git a/src/argilla_server/services/info.py b/src/argilla_server/services/info.py index 496f9a75..79fcd4fc 100644 --- a/src/argilla_server/services/info.py +++ b/src/argilla_server/services/info.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# TODO: Once we delete all API v0 endpoints, we can remove this file. + import os from typing import Any, Dict diff --git a/tests/unit/api/v0/test_authentication.py b/tests/unit/api/v0/test_authentication.py index 8e22b4e8..c6eb4d23 100644 --- a/tests/unit/api/v0/test_authentication.py +++ b/tests/unit/api/v0/test_authentication.py @@ -21,12 +21,12 @@ @pytest.mark.asyncio class TestsAuthentication: - async def authenticate(self, async_client: AsyncClient): + async def test_authenticate(self, async_client: AsyncClient): user = await UserFactory.create() response = await async_client.post( "/api/security/token", - data={"username": user.username, "password": "12345678"}, + data={"username": user.username, "password": "1234"}, ) assert response.status_code == 200 assert response.json()["access_token"] diff --git a/tests/unit/api/v0/test_users.py b/tests/unit/api/v0/test_users.py index 2c1bd6c6..5a196b3f 100644 --- a/tests/unit/api/v0/test_users.py +++ b/tests/unit/api/v0/test_users.py @@ -217,6 +217,53 @@ async def test_create_user_with_non_default_role( assert response_body["role"] == UserRole.owner.value +@pytest.mark.asyncio +async def test_create_user_with_first_name_including_leading_and_trailing_spaces( + async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict +): + response = await async_client.post( + "/api/users", + headers=owner_auth_header, + json={ + "first_name": " First name ", + "username": "username", + "password": "12345678", + }, + ) + + assert response.status_code == 200 + + assert (await db.execute(select(func.count(User.id)))).scalar() == 2 + user = (await db.execute(select(User).filter_by(username="username"))).scalar_one() + + assert response.json()["first_name"] == "First name" + assert user.first_name == "First name" + + +@pytest.mark.asyncio +async def test_create_user_with_last_name_including_leading_and_trailing_spaces( + async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict +): + response = await async_client.post( + "/api/users", + headers=owner_auth_header, + json={ + "first_name": "First name", + "last_name": " Last name ", + "username": "username", + "password": "12345678", + }, + ) + + assert response.status_code == 200 + + assert (await db.execute(select(func.count(User.id)))).scalar() == 2 + user = (await db.execute(select(User).filter_by(username="username"))).scalar_one() + + assert response.json()["last_name"] == "Last name" + assert user.last_name == "Last name" + + @pytest.mark.asyncio async def test_create_user_without_authentication(async_client: "AsyncClient", db: "AsyncSession"): user = {"first_name": "first-name", "username": "username", "password": "12345678"} diff --git a/tests/unit/api/v1/authentication/__init__.py b/tests/unit/api/v1/authentication/__init__.py new file mode 100644 index 00000000..55be4179 --- /dev/null +++ b/tests/unit/api/v1/authentication/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/api/v1/authentication/test_create_token.py b/tests/unit/api/v1/authentication/test_create_token.py new file mode 100644 index 00000000..38e835aa --- /dev/null +++ b/tests/unit/api/v1/authentication/test_create_token.py @@ -0,0 +1,66 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from argilla_server.models import User +from httpx import AsyncClient + +from tests.factories import UserFactory + + +@pytest.mark.asyncio +class TestsCreateToken: + def url(self) -> str: + return "/api/v1/token" + + async def test_create_token(self, async_client: AsyncClient): + user = await UserFactory.create() + + response = await async_client.post( + self.url(), + data={ + "username": user.username, + "password": "1234", + }, + ) + + assert response.status_code == 201 + assert response.json()["access_token"] + assert response.json()["token_type"] == "bearer" + + async def test_create_token_with_invalid_username(self, async_client: AsyncClient): + user = await UserFactory.create() + + response = await async_client.post( + self.url(), + data={ + "username": "invalid-username", + "password": "1234", + }, + ) + + assert response.status_code == 401 + + async def test_create_token_with_invalid_password(self, async_client: AsyncClient): + user = await UserFactory.create() + + response = await async_client.post( + self.url(), + data={ + "username": user.username, + "password": "invalid-password", + }, + ) + + assert response.status_code == 401 diff --git a/tests/unit/api/v1/info/__init__.py b/tests/unit/api/v1/info/__init__.py new file mode 100644 index 00000000..55be4179 --- /dev/null +++ b/tests/unit/api/v1/info/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/api/v1/info/test_get_status.py b/tests/unit/api/v1/info/test_get_status.py new file mode 100644 index 00000000..e337bf20 --- /dev/null +++ b/tests/unit/api/v1/info/test_get_status.py @@ -0,0 +1,36 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from argilla_server._version import __version__ +from argilla_server.search_engine import SearchEngine +from httpx import AsyncClient + + +@pytest.mark.asyncio +class TestGetStatus: + def url(self) -> str: + return "/api/v1/status" + + async def test_get_status(self, async_client: AsyncClient, mock_search_engine: SearchEngine): + mock_search_engine.info.return_value = {} + + response = await async_client.get(self.url()) + + assert response.status_code == 200 + + response_json = response.json() + assert response_json["version"] == __version__ + assert "search_engine" in response_json + assert "memory" in response_json diff --git a/tests/unit/api/v1/info/test_get_version.py b/tests/unit/api/v1/info/test_get_version.py new file mode 100644 index 00000000..c00f576f --- /dev/null +++ b/tests/unit/api/v1/info/test_get_version.py @@ -0,0 +1,29 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from argilla_server._version import __version__ +from httpx import AsyncClient + + +@pytest.mark.asyncio +class TestGetVersion: + def url(self) -> str: + return "/api/v1/version" + + async def test_get_version(self, async_client: AsyncClient): + response = await async_client.get(self.url()) + + assert response.status_code == 200 + assert response.json() == {"version": __version__} diff --git a/tests/unit/api/v1/users/__init__.py b/tests/unit/api/v1/users/__init__.py new file mode 100644 index 00000000..55be4179 --- /dev/null +++ b/tests/unit/api/v1/users/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/api/v1/users/test_create_user.py b/tests/unit/api/v1/users/test_create_user.py new file mode 100644 index 00000000..62a32d95 --- /dev/null +++ b/tests/unit/api/v1/users/test_create_user.py @@ -0,0 +1,306 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID + +import pytest +from argilla_server.constants import API_KEY_HEADER_NAME +from argilla_server.enums import UserRole +from argilla_server.models import User +from httpx import AsyncClient +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from tests.factories import UserFactory + + +@pytest.mark.asyncio +class TestCreateUser: + def url(self) -> str: + return "/api/v1/users" + + async def test_create_user(self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "first_name": "First name", + "last_name": "Last name", + "username": "username", + "password": "12345678", + }, + ) + + assert response.status_code == 201 + + assert (await db.execute(select(func.count(User.id)))).scalar() == 2 + user = (await db.execute(select(User).filter_by(username="username"))).scalar_one() + + response_json = response.json() + assert response_json == { + "id": str(user.id), + "first_name": "First name", + "last_name": "Last name", + "username": "username", + "role": UserRole.annotator, + "api_key": user.api_key, + "inserted_at": user.inserted_at.isoformat(), + "updated_at": user.updated_at.isoformat(), + } + + async def test_create_user_with_first_name_including_leading_and_trailing_spaces( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "first_name": " First name ", + "last_name": "Last name", + "username": "username", + "password": "12345678", + }, + ) + + assert response.status_code == 201 + + assert (await db.execute(select(func.count(User.id)))).scalar() == 2 + user = (await db.execute(select(User).filter_by(username="username"))).scalar_one() + + assert response.json()["first_name"] == "First name" + assert user.first_name == "First name" + + async def test_create_user_with_last_name_including_leading_and_trailing_spaces( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "first_name": "First name", + "last_name": " Last name ", + "username": "username", + "password": "12345678", + }, + ) + + assert response.status_code == 201 + + assert (await db.execute(select(func.count(User.id)))).scalar() == 2 + user = (await db.execute(select(User).filter_by(username="username"))).scalar_one() + + assert response.json()["last_name"] == "Last name" + assert user.last_name == "Last name" + + async def test_create_user_without_last_name( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "first_name": "First name", + "username": "username", + "password": "12345678", + }, + ) + + assert response.status_code == 201 + + assert (await db.execute(select(func.count(User.id)))).scalar() == 2 + user = (await db.execute(select(User).filter_by(username="username"))).scalar_one() + + assert response.json()["last_name"] == None + assert user.last_name == None + + async def test_create_user_with_non_default_role( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "first_name": "First name", + "last_name": "Last name", + "username": "username", + "password": "12345678", + "role": UserRole.owner, + }, + ) + + assert response.status_code == 201 + + assert (await db.execute(select(func.count(User.id)))).scalar() == 2 + user = (await db.execute(select(User).filter_by(username="username"))).scalar_one() + + assert response.json()["role"] == UserRole.owner + assert user.role == UserRole.owner + + async def test_create_user_without_authentication(self, db: AsyncSession, async_client: AsyncClient): + response = await async_client.post( + self.url(), + json={ + "first_name": "First name", + "last_name": "Last name", + "username": "username", + "password": "12345678", + }, + ) + + assert response.status_code == 401 + assert (await db.execute(select(func.count(User.id)))).scalar() == 0 + + @pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator]) + async def test_create_user_with_unauthorized_role( + self, db: AsyncSession, async_client: AsyncClient, user_role: UserRole + ): + user = await UserFactory.create(role=user_role) + + response = await async_client.post( + self.url(), + headers={API_KEY_HEADER_NAME: user.api_key}, + json={ + "first_name": "First name", + "last_name": "Last name", + "username": "username", + "password": "12345678", + }, + ) + + assert response.status_code == 403 + assert (await db.execute(select(func.count(User.id)))).scalar() == 1 + + async def test_create_user_with_existent_username( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + await UserFactory.create(username="username") + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "first_name": "First name", + "last_name": "Last name", + "username": "username", + "password": "12345678", + }, + ) + + assert response.status_code == 409 + assert response.json() == {"detail": "Username `username` is not unique"} + + assert (await db.execute(select(func.count(User.id)))).scalar() == 2 + + async def test_create_user_with_invalid_username( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "first_name": "First name", + "last_name": "Last name", + "username": "invalid username", + "password": "12345678", + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(User.id)))).scalar() == 1 + + async def test_create_user_with_invalid_min_length_first_name( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "first_name": "", + "last_name": "Last name", + "username": "username", + "password": "12345678", + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(User.id)))).scalar() == 1 + + async def test_create_user_with_invalid_min_length_last_name( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "first_name": "First name", + "last_name": "", + "username": "username", + "password": "12345678", + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(User.id)))).scalar() == 1 + + async def test_create_user_with_invalid_min_length_password( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "first_name": "First name", + "last_name": "Last name", + "username": "username", + "password": "1234", + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(User.id)))).scalar() == 1 + + async def test_create_user_with_invalid_max_length_password( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "first_name": "First name", + "last_name": "Last name", + "username": "username", + "password": "p" * 101, + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(User.id)))).scalar() == 1 + + async def test_create_user_with_invalid_role( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "first_name": "First name", + "last_name": "Last name", + "username": "username", + "password": "12345678", + "role": "invalid role", + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(User.id)))).scalar() == 1 diff --git a/tests/unit/api/v1/users/test_delete_user.py b/tests/unit/api/v1/users/test_delete_user.py new file mode 100644 index 00000000..0c3b6b58 --- /dev/null +++ b/tests/unit/api/v1/users/test_delete_user.py @@ -0,0 +1,82 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +from uuid import UUID, uuid4 + +import pytest +from argilla_server.constants import API_KEY_HEADER_NAME +from argilla_server.enums import UserRole +from argilla_server.models import User +from httpx import AsyncClient +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from tests.factories import UserFactory + + +@pytest.mark.asyncio +class TestDeleteUser: + def url(self, user_id: UUID) -> str: + return f"/api/v1/users/{user_id}" + + async def test_delete_user(self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict): + user = await UserFactory.create() + + response = await async_client.delete(self.url(user.id), headers=owner_auth_header) + + assert response.status_code == 200 + assert response.json() == { + "id": str(user.id), + "first_name": user.first_name, + "last_name": user.last_name, + "username": user.username, + "role": user.role, + "api_key": user.api_key, + "inserted_at": user.inserted_at.isoformat(), + "updated_at": user.updated_at.isoformat(), + } + + assert (await db.execute(select(func.count(User.id)))).scalar() == 1 + + async def test_delete_user_without_authentication(self, db: AsyncSession, async_client: AsyncClient): + user = await UserFactory.create() + + response = await async_client.delete(self.url(user.id)) + + assert response.status_code == 401 + assert (await db.execute(select(func.count(User.id)))).scalar() == 1 + + @pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator]) + async def test_delete_user_with_unauthorized_role( + self, db: AsyncSession, async_client: AsyncClient, user_role: UserRole + ): + user = await UserFactory.create() + user_with_unauthorized_role = await UserFactory.create(role=user_role) + + response = await async_client.delete( + self.url(user.id), + headers={API_KEY_HEADER_NAME: user_with_unauthorized_role.api_key}, + ) + + assert response.status_code == 403 + assert (await db.execute(select(func.count(User.id)))).scalar() == 2 + + async def test_delete_user_with_nonexistent_user_id(self, async_client: AsyncClient, owner_auth_header: dict): + user_id = uuid4() + + response = await async_client.delete(self.url(user_id), headers=owner_auth_header) + + assert response.status_code == 404 + assert response.json() == {"detail": f"User with id `{user_id}` not found"} diff --git a/tests/unit/api/v1/users/test_get_current_user.py b/tests/unit/api/v1/users/test_get_current_user.py new file mode 100644 index 00000000..e0d29f93 --- /dev/null +++ b/tests/unit/api/v1/users/test_get_current_user.py @@ -0,0 +1,43 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from argilla_server.models import User +from httpx import AsyncClient + + +@pytest.mark.asyncio +class TestGetCurrentUser: + def url(self) -> str: + return "/api/v1/me" + + async def test_get_current_user(self, async_client: AsyncClient, owner: User, owner_auth_header: dict): + response = await async_client.get(self.url(), headers=owner_auth_header) + + assert response.status_code == 200 + assert response.json() == { + "id": str(owner.id), + "first_name": owner.first_name, + "last_name": owner.last_name, + "username": owner.username, + "role": owner.role, + "api_key": owner.api_key, + "inserted_at": owner.inserted_at.isoformat(), + "updated_at": owner.updated_at.isoformat(), + } + + async def test_get_current_user_without_authentication(self, async_client: AsyncClient): + response = await async_client.get(self.url()) + + assert response.status_code == 401 diff --git a/tests/unit/api/v1/users/test_get_user.py b/tests/unit/api/v1/users/test_get_user.py new file mode 100644 index 00000000..da509891 --- /dev/null +++ b/tests/unit/api/v1/users/test_get_user.py @@ -0,0 +1,71 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID, uuid4 + +import pytest +from argilla_server.constants import API_KEY_HEADER_NAME +from argilla_server.enums import UserRole +from httpx import AsyncClient + +from tests.factories import UserFactory + + +@pytest.mark.asyncio +class TestGetUser: + def url(self, user_id: UUID) -> str: + return f"/api/v1/users/{user_id}" + + async def test_get_user(self, async_client: AsyncClient, owner_auth_header: dict): + user = await UserFactory.create() + + response = await async_client.get(self.url(user.id), headers=owner_auth_header) + + assert response.status_code == 200 + assert response.json() == { + "id": str(user.id), + "first_name": user.first_name, + "last_name": user.last_name, + "username": user.username, + "role": UserRole.annotator, + "api_key": user.api_key, + "inserted_at": user.inserted_at.isoformat(), + "updated_at": user.updated_at.isoformat(), + } + + async def test_get_user_without_authentication(self, async_client: AsyncClient): + user = await UserFactory.create() + + response = await async_client.get(self.url(user.id)) + + assert response.status_code == 401 + + @pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator]) + async def test_get_user_with_unauthorized_role(self, async_client: AsyncClient, user_role: UserRole): + user = await UserFactory.create(role=user_role) + + response = await async_client.get( + self.url(user.id), + headers={API_KEY_HEADER_NAME: user.api_key}, + ) + + assert response.status_code == 403 + + async def test_get_user_with_nonexistent_user_id(self, async_client: AsyncClient, owner_auth_header: dict): + user_id = uuid4() + + response = await async_client.get(self.url(user_id), headers=owner_auth_header) + + assert response.status_code == 404 + assert response.json() == {"detail": f"User with id `{user_id}` not found"} diff --git a/tests/unit/api/v1/users/test_list_users.py b/tests/unit/api/v1/users/test_list_users.py new file mode 100644 index 00000000..358d710d --- /dev/null +++ b/tests/unit/api/v1/users/test_list_users.py @@ -0,0 +1,81 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from argilla_server.constants import API_KEY_HEADER_NAME +from argilla_server.enums import UserRole +from argilla_server.models import User +from httpx import AsyncClient + +from tests.factories import UserFactory + + +@pytest.mark.asyncio +class TestListUsers: + def url(self) -> str: + return "/api/v1/users" + + async def test_list_users(self, async_client: AsyncClient, owner: User, owner_auth_header: dict): + user_a, user_b = await UserFactory.create_batch(2) + + response = await async_client.get(self.url(), headers=owner_auth_header) + + assert response.status_code == 200 + assert response.json() == { + "items": [ + { + "id": str(owner.id), + "first_name": owner.first_name, + "last_name": owner.last_name, + "username": owner.username, + "role": owner.role, + "api_key": owner.api_key, + "inserted_at": owner.inserted_at.isoformat(), + "updated_at": owner.updated_at.isoformat(), + }, + { + "id": str(user_a.id), + "first_name": user_a.first_name, + "last_name": user_a.last_name, + "username": user_a.username, + "role": user_a.role, + "api_key": user_a.api_key, + "inserted_at": user_a.inserted_at.isoformat(), + "updated_at": user_a.updated_at.isoformat(), + }, + { + "id": str(user_b.id), + "first_name": user_b.first_name, + "last_name": user_b.last_name, + "username": user_b.username, + "role": user_b.role, + "api_key": user_b.api_key, + "inserted_at": user_b.inserted_at.isoformat(), + "updated_at": user_b.updated_at.isoformat(), + }, + ] + } + + async def test_list_users_without_authentication(self, async_client: AsyncClient): + response = await async_client.get(self.url()) + + assert response.status_code == 401 + + @pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator]) + async def test_list_users_with_unauthorized_role(self, async_client: AsyncClient, user_role: UserRole): + user = await UserFactory.create(role=user_role) + + response = await async_client.get(self.url(), headers={API_KEY_HEADER_NAME: user.api_key}) + + assert response.status_code == 403 diff --git a/tests/unit/api/v1/workspaces/__init__.py b/tests/unit/api/v1/workspaces/__init__.py new file mode 100644 index 00000000..55be4179 --- /dev/null +++ b/tests/unit/api/v1/workspaces/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/api/v1/workspaces/test_create_workspace.py b/tests/unit/api/v1/workspaces/test_create_workspace.py new file mode 100644 index 00000000..c8de03b8 --- /dev/null +++ b/tests/unit/api/v1/workspaces/test_create_workspace.py @@ -0,0 +1,112 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from argilla_server.constants import API_KEY_HEADER_NAME +from argilla_server.enums import UserRole +from argilla_server.models import Workspace +from httpx import AsyncClient +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from tests.factories import UserFactory, WorkspaceFactory + + +@pytest.mark.asyncio +class TestCreateWorkspace: + def url(self) -> str: + return "/api/v1/workspaces" + + async def test_create_workspace(self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={"name": "workspace"}, + ) + + assert response.status_code == 201 + + assert (await db.execute(select(func.count(Workspace.id)))).scalar() == 1 + workspace = (await db.execute(select(Workspace).filter_by(name="workspace"))).scalar_one() + + assert response.json() == { + "id": str(workspace.id), + "name": "workspace", + "inserted_at": workspace.inserted_at.isoformat(), + "updated_at": workspace.updated_at.isoformat(), + } + + async def test_create_workspace_without_authentication(self, db: AsyncSession, async_client: AsyncClient): + response = await async_client.post( + self.url(), + json={"name": "workspace"}, + ) + + assert response.status_code == 401 + assert (await db.execute(select(func.count(Workspace.id)))).scalar() == 0 + + @pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator]) + async def test_create_workspace_with_unauthorized_role( + self, db: AsyncSession, async_client: AsyncClient, user_role: UserRole + ): + user = await UserFactory.create(role=user_role) + + response = await async_client.post( + self.url(), + headers={API_KEY_HEADER_NAME: user.api_key}, + json={"name": "workspace"}, + ) + + assert response.status_code == 403 + assert (await db.execute(select(func.count(Workspace.id)))).scalar() == 0 + + async def test_create_workspace_with_existent_name( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + await WorkspaceFactory.create(name="workspace") + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={"name": "workspace"}, + ) + + assert response.status_code == 409 + assert response.json() == {"detail": "Workspace name `workspace` is not unique"} + + assert (await db.execute(select(func.count(Workspace.id)))).scalar() == 1 + + async def test_create_workspace_with_invalid_name( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={"name": "invalid name"}, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(Workspace.id)))).scalar() == 0 + + async def test_create_workspace_with_invalid_min_length_name( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={"name": ""}, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(Workspace.id)))).scalar() == 0 diff --git a/tests/unit/api/v1/workspaces/test_create_workspace_user.py b/tests/unit/api/v1/workspaces/test_create_workspace_user.py new file mode 100644 index 00000000..07bd2b6b --- /dev/null +++ b/tests/unit/api/v1/workspaces/test_create_workspace_user.py @@ -0,0 +1,140 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID, uuid4 + +import pytest +from argilla_server.constants import API_KEY_HEADER_NAME +from argilla_server.enums import UserRole +from argilla_server.models import WorkspaceUser +from httpx import AsyncClient +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from tests.factories import UserFactory, WorkspaceFactory, WorkspaceUserFactory + + +@pytest.mark.asyncio +class TestCreateWorkspaceUser: + def url(self, workspace_id: UUID) -> str: + return f"/api/v1/workspaces/{workspace_id}/users" + + async def test_create_workspace_user(self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict): + workspace = await WorkspaceFactory.create() + user = await UserFactory.create() + + response = await async_client.post( + self.url(workspace.id), + headers=owner_auth_header, + json={"user_id": str(user.id)}, + ) + + assert response.status_code == 201 + assert response.json() == { + "id": str(user.id), + "first_name": user.first_name, + "last_name": user.last_name, + "username": user.username, + "role": UserRole.annotator, + "api_key": user.api_key, + "inserted_at": user.inserted_at.isoformat(), + "updated_at": user.updated_at.isoformat(), + } + + assert (await db.execute(select(func.count(WorkspaceUser.id)))).scalar() == 1 + assert ( + await db.execute(select(WorkspaceUser).filter_by(workspace_id=workspace.id, user_id=user.id)) + ).scalar_one() + + async def test_create_workspace_user_without_authentication(self, db: AsyncSession, async_client: AsyncClient): + workspace = await WorkspaceFactory.create() + user = await UserFactory.create() + + response = await async_client.post( + self.url(workspace.id), + json={"user_id": str(user.id)}, + ) + + assert response.status_code == 401 + assert (await db.execute(select(func.count(WorkspaceUser.id)))).scalar() == 0 + + @pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator]) + async def test_create_workspace_user_with_unauthorized_role( + self, db: AsyncSession, async_client: AsyncClient, user_role: UserRole + ): + workspace = await WorkspaceFactory.create() + user = await UserFactory.create(role=user_role) + + response = await async_client.post( + self.url(workspace.id), + headers={API_KEY_HEADER_NAME: user.api_key}, + json={"user_id": str(user.id)}, + ) + + assert response.status_code == 403 + assert (await db.execute(select(func.count(WorkspaceUser.id)))).scalar() == 0 + + async def test_create_workspace_user_with_nonexistent_workspace_id( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace_id = uuid4() + user = await UserFactory.create() + + response = await async_client.post( + self.url(workspace_id), + headers=owner_auth_header, + json={"user_id": str(user.id)}, + ) + + assert response.status_code == 404 + assert response.json() == {"detail": f"Workspace with id `{workspace_id}` not found"} + + assert (await db.execute(select(func.count(WorkspaceUser.id)))).scalar() == 0 + + async def test_create_workspace_user_with_nonexistent_user_id( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + user_id = uuid4() + + response = await async_client.post( + self.url(workspace.id), + headers=owner_auth_header, + json={"user_id": str(user_id)}, + ) + + assert response.status_code == 422 + assert response.json() == {"detail": f"User with id `{user_id}` not found"} + + assert (await db.execute(select(func.count(WorkspaceUser.id)))).scalar() == 0 + + async def test_create_workspace_user_with_existent_workspace_id_and_user_id( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + user = await UserFactory.create() + await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=user.id) + + response = await async_client.post( + self.url(workspace.id), + headers=owner_auth_header, + json={"user_id": str(user.id)}, + ) + + assert response.status_code == 409 + assert response.json() == { + "detail": f"Workspace user with workspace_id `{workspace.id}` and user_id `{user.id}` is not unique", + } + + assert (await db.execute(select(func.count(WorkspaceUser.id)))).scalar() == 1 diff --git a/tests/unit/api/v1/workspaces/test_delete_workspace_user.py b/tests/unit/api/v1/workspaces/test_delete_workspace_user.py new file mode 100644 index 00000000..f0a78d6e --- /dev/null +++ b/tests/unit/api/v1/workspaces/test_delete_workspace_user.py @@ -0,0 +1,150 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID, uuid4 + +import pytest +from argilla_server.constants import API_KEY_HEADER_NAME +from argilla_server.enums import UserRole +from argilla_server.models import WorkspaceUser +from httpx import AsyncClient +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from tests.factories import AdminFactory, AnnotatorFactory, UserFactory, WorkspaceFactory, WorkspaceUserFactory + + +@pytest.mark.asyncio +class TestDeleteWorkspaceUser: + def url(self, workspace_id: UUID, user_id: UUID) -> str: + return f"/api/v1/workspaces/{workspace_id}/users/{user_id}" + + async def test_delete_workspace_user(self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict): + workspace = await WorkspaceFactory.create() + user = await UserFactory.create() + await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=user.id) + + response = await async_client.delete( + self.url(workspace.id, user.id), + headers=owner_auth_header, + ) + + assert response.status_code == 200 + assert response.json() == { + "id": str(user.id), + "first_name": user.first_name, + "last_name": user.last_name, + "username": user.username, + "role": UserRole.annotator, + "api_key": user.api_key, + "inserted_at": user.inserted_at.isoformat(), + "updated_at": user.updated_at.isoformat(), + } + + assert (await db.execute(select(func.count(WorkspaceUser.id)))).scalar() == 0 + + async def test_delete_workspace_user_without_authentication(self, db: AsyncSession, async_client: AsyncClient): + workspace = await WorkspaceFactory.create() + user = await UserFactory.create() + await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=user.id) + + response = await async_client.delete(self.url(workspace.id, user.id)) + + assert response.status_code == 401 + assert (await db.execute(select(func.count(WorkspaceUser.id)))).scalar() == 1 + + async def test_delete_workspace_user_as_admin(self, db: AsyncSession, async_client: AsyncClient): + workspace = await WorkspaceFactory.create() + admin = await AdminFactory.create() + await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=admin.id) + + response = await async_client.delete( + self.url(workspace.id, admin.id), + headers={API_KEY_HEADER_NAME: admin.api_key}, + ) + + assert response.status_code == 200 + assert (await db.execute(select(func.count(WorkspaceUser.id)))).scalar() == 0 + + async def test_delete_workspace_user_as_admin_from_different_workspace( + self, db: AsyncSession, async_client: AsyncClient + ): + workspace = await WorkspaceFactory.create() + user = await AdminFactory.create() + await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=user.id) + + other_workspace = await WorkspaceFactory.create() + admin = await AdminFactory.create() + await WorkspaceUserFactory.create(workspace_id=other_workspace.id, user_id=admin.id) + + response = await async_client.delete( + self.url(workspace.id, user.id), + headers={API_KEY_HEADER_NAME: admin.api_key}, + ) + + assert response.status_code == 403 + assert (await db.execute(select(func.count(WorkspaceUser.id)))).scalar() == 2 + + async def test_delete_workspace_user_as_annotator(self, db: AsyncSession, async_client: AsyncClient): + workspace = await WorkspaceFactory.create() + annotator = await AnnotatorFactory.create() + await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=annotator.id) + + response = await async_client.delete( + self.url(workspace.id, annotator.id), + headers={API_KEY_HEADER_NAME: annotator.api_key}, + ) + + assert response.status_code == 403 + assert (await db.execute(select(func.count(WorkspaceUser.id)))).scalar() == 1 + + async def test_delete_workspace_user_with_nonexistent_workspace_id( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + non_existent_workspace_id = uuid4() + workspace = await WorkspaceFactory.create() + user = await UserFactory.create() + await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=user.id) + + response = await async_client.delete( + self.url(non_existent_workspace_id, user.id), + headers=owner_auth_header, + ) + + assert response.status_code == 404 + assert response.json() == { + "detail": f"User with id `{user.id}` not found in workspace with id `{non_existent_workspace_id}`" + } + + assert (await db.execute(select(func.count(WorkspaceUser.id)))).scalar() == 1 + + async def test_delete_workspace_user_with_nonexistent_user_id( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + non_existent_user_id = uuid4() + workspace = await WorkspaceFactory.create() + user = await UserFactory.create() + await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=user.id) + + response = await async_client.delete( + self.url(workspace.id, non_existent_user_id), + headers=owner_auth_header, + ) + + assert response.status_code == 404 + assert response.json() == { + "detail": f"User with id `{non_existent_user_id}` not found in workspace with id `{workspace.id}`" + } + + assert (await db.execute(select(func.count(WorkspaceUser.id)))).scalar() == 1 diff --git a/tests/unit/api/v1/workspaces/test_list_workspace_users.py b/tests/unit/api/v1/workspaces/test_list_workspace_users.py new file mode 100644 index 00000000..95de444e --- /dev/null +++ b/tests/unit/api/v1/workspaces/test_list_workspace_users.py @@ -0,0 +1,130 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID, uuid4 + +import pytest +from argilla_server.constants import API_KEY_HEADER_NAME +from argilla_server.enums import UserRole +from httpx import AsyncClient + +from tests.factories import AdminFactory, AnnotatorFactory, UserFactory, WorkspaceFactory, WorkspaceUserFactory + + +@pytest.mark.asyncio +class TestListWorkspaceUsers: + def url(self, workspace_id: UUID) -> str: + return f"/api/v1/workspaces/{workspace_id}/users" + + async def test_list_workspace_users(self, async_client: AsyncClient, owner_auth_header: dict): + workspace = await WorkspaceFactory.create() + users = await UserFactory.create_batch(3) + await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=users[0].id) + await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=users[1].id) + await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=users[2].id) + + other_workspace = await WorkspaceFactory.create() + other_users = await UserFactory.create_batch(2) + await WorkspaceUserFactory.create(workspace_id=other_workspace.id, user_id=other_users[0].id) + await WorkspaceUserFactory.create(workspace_id=other_workspace.id, user_id=other_users[1].id) + + response = await async_client.get(self.url(workspace.id), headers=owner_auth_header) + + assert response.status_code == 200 + assert response.json() == { + "items": [ + { + "id": str(users[0].id), + "first_name": users[0].first_name, + "last_name": users[0].last_name, + "username": users[0].username, + "role": UserRole.annotator, + "api_key": users[0].api_key, + "inserted_at": users[0].inserted_at.isoformat(), + "updated_at": users[0].updated_at.isoformat(), + }, + { + "id": str(users[1].id), + "first_name": users[1].first_name, + "last_name": users[1].last_name, + "username": users[1].username, + "role": UserRole.annotator, + "api_key": users[1].api_key, + "inserted_at": users[1].inserted_at.isoformat(), + "updated_at": users[1].updated_at.isoformat(), + }, + { + "id": str(users[2].id), + "first_name": users[2].first_name, + "last_name": users[2].last_name, + "username": users[2].username, + "role": UserRole.annotator, + "api_key": users[2].api_key, + "inserted_at": users[2].inserted_at.isoformat(), + "updated_at": users[2].updated_at.isoformat(), + }, + ], + } + + async def test_list_workspace_users_without_authentication(self, async_client: AsyncClient): + workspace = await WorkspaceFactory.create() + + response = await async_client.get(self.url(workspace.id)) + + assert response.status_code == 401 + + async def test_list_workspace_users_as_admin(self, async_client: AsyncClient): + workspace = await WorkspaceFactory.create() + admin = await AdminFactory.create() + await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=admin.id) + + response = await async_client.get( + self.url(workspace.id), + headers={API_KEY_HEADER_NAME: admin.api_key}, + ) + + assert response.status_code == 200 + + async def test_list_workspace_users_as_admin_from_different_workspace(self, async_client: AsyncClient): + workspace = await WorkspaceFactory.create() + admin = await AdminFactory.create() + + response = await async_client.get( + self.url(workspace.id), + headers={API_KEY_HEADER_NAME: admin.api_key}, + ) + + assert response.status_code == 403 + + async def test_list_workspace_users_as_annotator(self, async_client: AsyncClient): + workspace = await WorkspaceFactory.create() + annotator = await AnnotatorFactory.create() + await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=annotator.id) + + response = await async_client.get( + self.url(workspace.id), + headers={API_KEY_HEADER_NAME: annotator.api_key}, + ) + + assert response.status_code == 403 + + async def test_list_workspace_with_nonexistent_workspace_id( + self, async_client: AsyncClient, owner_auth_header: dict + ): + workspace_id = uuid4() + + response = await async_client.get(self.url(workspace_id), headers=owner_auth_header) + + assert response.status_code == 404 + assert response.json() == {"detail": f"Workspace with id `{workspace_id}` not found"}