From f5c433cda946b8ebbfafad54539ec22d1327e3ce Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Sat, 22 Apr 2023 21:12:11 +0200 Subject: [PATCH 1/6] Add `Response.status` column --- .../versions/e402e9d9245e_create_responses_table.py | 1 + src/argilla/server/models.py | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/argilla/server/alembic/versions/e402e9d9245e_create_responses_table.py b/src/argilla/server/alembic/versions/e402e9d9245e_create_responses_table.py index 97beb5e1ff..b9339a40e0 100644 --- a/src/argilla/server/alembic/versions/e402e9d9245e_create_responses_table.py +++ b/src/argilla/server/alembic/versions/e402e9d9245e_create_responses_table.py @@ -36,6 +36,7 @@ def upgrade() -> None: sa.Column("values", sa.JSON, nullable=False), sa.Column("record_id", sa.Uuid, sa.ForeignKey("records.id", ondelete="CASCADE"), nullable=False, index=True), sa.Column("user_id", sa.Uuid, sa.ForeignKey("users.id", ondelete="SET NULL"), index=True), + sa.Column("status", sa.String, nullable=False, index=True), sa.Column("inserted_at", sa.DateTime, nullable=False), sa.Column("updated_at", sa.DateTime, nullable=False), sa.UniqueConstraint("record_id", "user_id", name="response_record_id_user_id_uq"), diff --git a/src/argilla/server/models.py b/src/argilla/server/models.py index a487385ef7..42b9536d27 100644 --- a/src/argilla/server/models.py +++ b/src/argilla/server/models.py @@ -49,6 +49,12 @@ class UserRole(str, Enum): annotator = "annotator" +class ResponseStatus(str, Enum): + pending = "pending" + submitted = "submitted" + discarded = "discarded" + + class Annotation(Base): __tablename__ = "annotations" @@ -79,6 +85,7 @@ class Response(Base): values: Mapped[dict] = mapped_column(JSON) record_id: Mapped[UUID] = mapped_column(ForeignKey("records.id")) user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id")) + status: Mapped[ResponseStatus] = mapped_column(default=ResponseStatus.pending, index=True, nullable=False) inserted_at: Mapped[datetime] = mapped_column(default=datetime.utcnow) updated_at: Mapped[datetime] = mapped_column(default=default_inserted_at, onupdate=datetime.utcnow) @@ -89,7 +96,7 @@ class Response(Base): def __repr__(self): return ( f"Response(id={str(self.id)!r}, record_id={str(self.record_id)!r}, user_id={str(self.user_id)!r}, " - f"inserted_at={str(self.inserted_at)!r}, updated_at={str(self.updated_at)!r})" + f"status={self.status!r}, inserted_at={str(self.inserted_at)!r}, updated_at={str(self.updated_at)!r})" ) From 18951a8c4c6cedeb946b70bee5d11d239711183f Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Sat, 22 Apr 2023 22:17:00 +0200 Subject: [PATCH 2/6] Add `PUT /api/v1/responses/{response_id}/status` endpoint --- .../server/apis/v1/handlers/responses.py | 21 ++++++++++++++++++- src/argilla/server/contexts/datasets.py | 11 +++++++++- src/argilla/server/schemas/v1/responses.py | 7 +++++++ 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/src/argilla/server/apis/v1/handlers/responses.py b/src/argilla/server/apis/v1/handlers/responses.py index 3bee332d23..c35956f9ac 100644 --- a/src/argilla/server/apis/v1/handlers/responses.py +++ b/src/argilla/server/apis/v1/handlers/responses.py @@ -21,7 +21,11 @@ from argilla.server.database import get_db from argilla.server.models import User from argilla.server.policies import ResponsePolicyV1, authorize -from argilla.server.schemas.v1.responses import Response, ResponseUpdate +from argilla.server.schemas.v1.responses import ( + Response, + ResponseStatusUpdate, + ResponseUpdate, +) from argilla.server.security import auth router = APIRouter(tags=["responses"]) @@ -53,6 +57,21 @@ def update_response( return datasets.update_response(db, response, response_update) +@router.put("/responses/{response_id}/status", response_model=Response) +def update_response_status( + *, + db: Session = Depends(get_db), + response_id: UUID, + response_status_update: ResponseStatusUpdate, + current_user: User = Security(auth.get_current_user), +): + response = _get_response(db, response_id) + + authorize(current_user, ResponsePolicyV1.update(response)) + + return datasets.update_response_status(db, response, response_status_update) + + @router.delete("/responses/{response_id}", response_model=Response) def delete_response( *, diff --git a/src/argilla/server/contexts/datasets.py b/src/argilla/server/contexts/datasets.py index 4885795a1e..9992752f28 100644 --- a/src/argilla/server/contexts/datasets.py +++ b/src/argilla/server/contexts/datasets.py @@ -24,7 +24,7 @@ RecordsCreate, ) from argilla.server.schemas.v1.records import ResponseCreate -from argilla.server.schemas.v1.responses import ResponseUpdate +from argilla.server.schemas.v1.responses import ResponseStatusUpdate, ResponseUpdate from argilla.server.security.model import User from sqlalchemy import and_, func from sqlalchemy.orm import Session, contains_eager, joinedload @@ -230,6 +230,15 @@ def update_response(db: Session, response: Response, response_update: ResponseUp return response +def update_response_status(db: Session, response: Response, response_status_update: ResponseStatusUpdate): + response.status = response_status_update.status + + db.commit() + db.refresh(response) + + return response + + def delete_response(db: Session, response: Response): db.delete(response) db.commit() diff --git a/src/argilla/server/schemas/v1/responses.py b/src/argilla/server/schemas/v1/responses.py index 851a424857..f7149e454a 100644 --- a/src/argilla/server/schemas/v1/responses.py +++ b/src/argilla/server/schemas/v1/responses.py @@ -18,12 +18,15 @@ from pydantic import BaseModel +from argilla.server.models import ResponseStatus + class Response(BaseModel): id: UUID values: Dict[str, Any] record_id: UUID user_id: UUID + status: ResponseStatus inserted_at: datetime updated_at: datetime @@ -33,3 +36,7 @@ class Config: class ResponseUpdate(BaseModel): values: Dict[str, Any] + + +class ResponseStatusUpdate(BaseModel): + status: ResponseStatus From d56b875856f5e00c97ee2d94eb711115e35f8f5d Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Sun, 23 Apr 2023 10:40:01 +0200 Subject: [PATCH 3/6] Add unit tests for `PUT /api/v1/response/{response_id}/status` endpoint --- tests/server/api/v1/test_responses.py | 56 ++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/tests/server/api/v1/test_responses.py b/tests/server/api/v1/test_responses.py index 37d18a5b03..3925a18ba7 100644 --- a/tests/server/api/v1/test_responses.py +++ b/tests/server/api/v1/test_responses.py @@ -16,7 +16,7 @@ from uuid import uuid4 from argilla._constants import API_KEY_HEADER_NAME -from argilla.server.models import Response +from argilla.server.models import Response, ResponseStatus from fastapi.testclient import TestClient from sqlalchemy.orm import Session @@ -54,6 +54,7 @@ def test_update_response(client: TestClient, db: Session, admin_auth_header: dic }, "record_id": str(response.record_id), "user_id": str(response.user_id), + "status": ResponseStatus.pending, "inserted_at": response.inserted_at.isoformat(), "updated_at": datetime.fromisoformat(resp_body["updated_at"]).isoformat(), } @@ -117,6 +118,7 @@ def test_update_response_as_annotator(client: TestClient, db: Session): }, "record_id": str(response.record_id), "user_id": str(response.user_id), + "status": ResponseStatus.pending, "inserted_at": response.inserted_at.isoformat(), "updated_at": datetime.fromisoformat(resp_body["updated_at"]).isoformat(), } @@ -171,6 +173,58 @@ def test_update_response_with_nonexistent_response_id(client: TestClient, db: Se } +def test_update_response_status(client: TestClient, db: Session, admin_auth_header: dict): + response = ResponseFactory.create() + response_json = {"status": "submitted"} + + resp = client.put(f"/api/v1/responses/{response.id}/status", headers=admin_auth_header, json=response_json) + + assert resp.status_code == 200 + assert db.get(Response, response.id).status == "submitted" + + +def test_update_response_status_without_authentication(client: TestClient, db: Session): + response = ResponseFactory.create() + response_json = {"status": "submitted"} + + resp = client.put(f"/api/v1/responses/{response.id}/status", json=response_json) + + assert resp.status_code == 401 + assert db.get(Response, response.id).status == "pending" + + +def test_update_response_status_as_annotator(client: TestClient, db: Session): + annotator = AnnotatorFactory.create() + response = ResponseFactory.create(user=annotator) + response_json = {"status": "submitted"} + + resp = client.put( + f"/api/v1/responses/{response.id}/status", headers={API_KEY_HEADER_NAME: annotator.api_key}, json=response_json + ) + + assert resp.status_code == 200 + assert db.get(Response, response.id).status == "submitted" + + +def test_update_response_status_as_annotator_for_different_user_response(client: TestClient, db: Session): + annotator = AnnotatorFactory.create() + response = ResponseFactory.create() + response_json = {"status": "submitted"} + + resp = client.put( + f"/api/v1/responses/{response.id}/status", headers={API_KEY_HEADER_NAME: annotator.api_key}, json=response_json + ) + + assert resp.status_code == 403 + assert db.get(Response, response.id).status == "pending" + + +def test_update_response_status_with_nonexistent_response_id(client: TestClient, db: Session, admin_auth_header: dict): + resp = client.put(f"/api/v1/responses/{uuid4()}/status", headers=admin_auth_header, json={"status": "submitted"}) + + assert resp.status_code == 404 + + def test_delete_response(client: TestClient, db: Session, admin_auth_header: dict): response = ResponseFactory.create() From 2977dcaf3a63490417b9554ac4b26e8f923e2566 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 24 Apr 2023 13:38:05 +0200 Subject: [PATCH 4/6] Update to display `ResponseStatus` enum `str` value --- src/argilla/server/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/argilla/server/models.py b/src/argilla/server/models.py index 4eb2446f4a..8651fe1047 100644 --- a/src/argilla/server/models.py +++ b/src/argilla/server/models.py @@ -123,7 +123,7 @@ class Response(Base): def __repr__(self): return ( f"Response(id={str(self.id)!r}, record_id={str(self.record_id)!r}, user_id={str(self.user_id)!r}, " - f"status={self.status!r}, inserted_at={str(self.inserted_at)!r}, updated_at={str(self.updated_at)!r})" + f"status={self.status}, inserted_at={str(self.inserted_at)!r}, updated_at={str(self.updated_at)!r})" ) From 6ba0060734dbbd5ad0607cdfeeead14cfd223410 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 24 Apr 2023 17:35:41 +0200 Subject: [PATCH 5/6] Add setting `Response` invalid status unit test --- tests/server/api/v1/test_responses.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/server/api/v1/test_responses.py b/tests/server/api/v1/test_responses.py index 3925a18ba7..b20d30e5c0 100644 --- a/tests/server/api/v1/test_responses.py +++ b/tests/server/api/v1/test_responses.py @@ -225,6 +225,16 @@ def test_update_response_status_with_nonexistent_response_id(client: TestClient, assert resp.status_code == 404 +def test_update_response_status_with_invalid_status(client: TestClient, db: Session, admin_auth_header: dict): + response = ResponseFactory.create() + response_json = {"status": "invalid"} + + resp = client.put(f"/api/v1/responses/{response.id}/status", headers=admin_auth_header, json=response_json) + + assert resp.status_code == 422 + assert db.get(Response, response.id).status == "pending" + + def test_delete_response(client: TestClient, db: Session, admin_auth_header: dict): response = ResponseFactory.create() From 2366e8381b02f6bc8a3c5711254ff9b526a03240 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 24 Apr 2023 17:35:52 +0200 Subject: [PATCH 6/6] Update unit tests to check response body --- tests/server/api/v1/test_responses.py | 39 +++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/tests/server/api/v1/test_responses.py b/tests/server/api/v1/test_responses.py index b20d30e5c0..2e75e78279 100644 --- a/tests/server/api/v1/test_responses.py +++ b/tests/server/api/v1/test_responses.py @@ -174,7 +174,12 @@ def test_update_response_with_nonexistent_response_id(client: TestClient, db: Se def test_update_response_status(client: TestClient, db: Session, admin_auth_header: dict): - response = ResponseFactory.create() + response = ResponseFactory.create( + values={ + "input_ok": {"value": "no"}, + "output_ok": {"value": "no"}, + } + ) response_json = {"status": "submitted"} resp = client.put(f"/api/v1/responses/{response.id}/status", headers=admin_auth_header, json=response_json) @@ -182,6 +187,20 @@ def test_update_response_status(client: TestClient, db: Session, admin_auth_head assert resp.status_code == 200 assert db.get(Response, response.id).status == "submitted" + resp_body = resp.json() + assert resp_body == { + "id": str(response.id), + "values": { + "input_ok": {"value": "no"}, + "output_ok": {"value": "no"}, + }, + "record_id": str(response.record_id), + "user_id": str(response.user_id), + "status": ResponseStatus.submitted, + "inserted_at": response.inserted_at.isoformat(), + "updated_at": datetime.fromisoformat(resp_body["updated_at"]).isoformat(), + } + def test_update_response_status_without_authentication(client: TestClient, db: Session): response = ResponseFactory.create() @@ -195,7 +214,9 @@ def test_update_response_status_without_authentication(client: TestClient, db: S def test_update_response_status_as_annotator(client: TestClient, db: Session): annotator = AnnotatorFactory.create() - response = ResponseFactory.create(user=annotator) + response = ResponseFactory.create( + user=annotator, values={"input_ok": {"value": "no"}, "output_ok": {"value": "no"}} + ) response_json = {"status": "submitted"} resp = client.put( @@ -205,6 +226,20 @@ def test_update_response_status_as_annotator(client: TestClient, db: Session): assert resp.status_code == 200 assert db.get(Response, response.id).status == "submitted" + resp_body = resp.json() + assert resp_body == { + "id": str(response.id), + "values": { + "input_ok": {"value": "no"}, + "output_ok": {"value": "no"}, + }, + "record_id": str(response.record_id), + "user_id": str(response.user_id), + "status": ResponseStatus.submitted, + "inserted_at": response.inserted_at.isoformat(), + "updated_at": datetime.fromisoformat(resp_body["updated_at"]).isoformat(), + } + def test_update_response_status_as_annotator_for_different_user_response(client: TestClient, db: Session): annotator = AnnotatorFactory.create()