diff --git a/schemas/openapi.json b/schemas/openapi.json index 58fbddeac9..2ea71f4091 100644 --- a/schemas/openapi.json +++ b/schemas/openapi.json @@ -475,7 +475,6 @@ "/v1/datasets/{id}/examples": { "get": { "tags": [ - "datasets", "datasets" ], "summary": "Get examples from a dataset", diff --git a/src/phoenix/server/api/routers/v1/dataset_examples.py b/src/phoenix/server/api/routers/v1/dataset_examples.py deleted file mode 100644 index 6cc4457911..0000000000 --- a/src/phoenix/server/api/routers/v1/dataset_examples.py +++ /dev/null @@ -1,157 +0,0 @@ -from datetime import datetime -from typing import Any, Dict, List, Optional - -from fastapi import APIRouter, HTTPException, Path, Query -from sqlalchemy import and_, func, select -from starlette.requests import Request -from starlette.status import HTTP_404_NOT_FOUND -from strawberry.relay import GlobalID - -from phoenix.db.models import ( - Dataset as ORMDataset, -) -from phoenix.db.models import ( - DatasetExample as ORMDatasetExample, -) -from phoenix.db.models import ( - DatasetExampleRevision as ORMDatasetExampleRevision, -) -from phoenix.db.models import ( - DatasetVersion as ORMDatasetVersion, -) - -from .pydantic_compat import V1RoutesBaseModel -from .utils import ResponseBody, add_errors_to_responses - -router = APIRouter(tags=["datasets"]) - - -class DatasetExample(V1RoutesBaseModel): - id: str - input: Dict[str, Any] - output: Dict[str, Any] - metadata: Dict[str, Any] - updated_at: datetime - - -class ListDatasetExamplesData(V1RoutesBaseModel): - dataset_id: str - version_id: str - examples: List[DatasetExample] - - -class ListDatasetExamplesResponseBody(ResponseBody[ListDatasetExamplesData]): - pass - - -@router.get( - "/datasets/{id}/examples", - operation_id="getDatasetExamples", - summary="Get examples from a dataset", - responses=add_errors_to_responses([HTTP_404_NOT_FOUND]), -) -async def get_dataset_examples( - request: Request, - id: str = Path(description="The ID of the dataset"), - version_id: 
Optional[str] = Query( - default=None, - description=( - "The ID of the dataset version " "(if omitted, returns data from the latest version)" - ), - ), -) -> ListDatasetExamplesResponseBody: - dataset_gid = GlobalID.from_id(id) - version_gid = GlobalID.from_id(version_id) if version_id else None - - if (dataset_type := dataset_gid.type_name) != "Dataset": - raise HTTPException( - detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=HTTP_404_NOT_FOUND - ) - - if version_gid and (version_type := version_gid.type_name) != "DatasetVersion": - raise HTTPException( - detail=f"ID {version_gid} refers to a {version_type}", status_code=HTTP_404_NOT_FOUND - ) - - async with request.app.state.db() as session: - if ( - resolved_dataset_id := await session.scalar( - select(ORMDataset.id).where(ORMDataset.id == int(dataset_gid.node_id)) - ) - ) is None: - raise HTTPException( - detail=f"No dataset with id {dataset_gid} can be found.", - status_code=HTTP_404_NOT_FOUND, - ) - - # Subquery to find the maximum created_at for each dataset_example_id - # timestamp tiebreaks are resolved by the largest id - partial_subquery = select( - func.max(ORMDatasetExampleRevision.id).label("max_id"), - ).group_by(ORMDatasetExampleRevision.dataset_example_id) - - if version_gid: - if ( - resolved_version_id := await session.scalar( - select(ORMDatasetVersion.id).where( - and_( - ORMDatasetVersion.dataset_id == resolved_dataset_id, - ORMDatasetVersion.id == int(version_gid.node_id), - ) - ) - ) - ) is None: - raise HTTPException( - detail=f"No dataset version with id {version_id} can be found.", - status_code=HTTP_404_NOT_FOUND, - ) - # if a version_id is provided, filter the subquery to only include revisions from that - partial_subquery = partial_subquery.filter( - ORMDatasetExampleRevision.dataset_version_id <= resolved_version_id - ) - else: - if ( - resolved_version_id := await session.scalar( - select(func.max(ORMDatasetVersion.id)).where( - ORMDatasetVersion.dataset_id == 
resolved_dataset_id - ) - ) - ) is None: - raise HTTPException( - detail="Dataset has no versions.", - status_code=HTTP_404_NOT_FOUND, - ) - - subquery = partial_subquery.subquery() - # Query for the most recent example revisions that are not deleted - query = ( - select(ORMDatasetExample, ORMDatasetExampleRevision) - .join( - ORMDatasetExampleRevision, - ORMDatasetExample.id == ORMDatasetExampleRevision.dataset_example_id, - ) - .join( - subquery, - (subquery.c.max_id == ORMDatasetExampleRevision.id), - ) - .filter(ORMDatasetExample.dataset_id == resolved_dataset_id) - .filter(ORMDatasetExampleRevision.revision_kind != "DELETE") - .order_by(ORMDatasetExample.id.asc()) - ) - examples = [ - DatasetExample( - id=str(GlobalID("DatasetExample", str(example.id))), - input=revision.input, - output=revision.output, - metadata=revision.metadata_, - updated_at=revision.created_at, - ) - async for example, revision in await session.stream(query) - ] - return ListDatasetExamplesResponseBody( - data=ListDatasetExamplesData( - dataset_id=str(GlobalID("Dataset", str(resolved_dataset_id))), - version_id=str(GlobalID("DatasetVersion", str(resolved_version_id))), - examples=examples, - ) - ) diff --git a/src/phoenix/server/api/routers/v1/datasets.py b/src/phoenix/server/api/routers/v1/datasets.py index 265097bd69..59d9d5ad70 100644 --- a/src/phoenix/server/api/routers/v1/datasets.py +++ b/src/phoenix/server/api/routers/v1/datasets.py @@ -56,12 +56,11 @@ add_dataset_examples, ) from phoenix.server.api.types.Dataset import Dataset as DatasetNodeType -from phoenix.server.api.types.DatasetExample import DatasetExample +from phoenix.server.api.types.DatasetExample import DatasetExample as DatasetExampleNodeType from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVersionNodeType from phoenix.server.api.types.node import from_global_id_with_expected_type from phoenix.server.api.utils import delete_projects, delete_traces -from .dataset_examples import router as 
dataset_examples_router from .pydantic_compat import V1RoutesBaseModel from .utils import ( PaginatedResponseBody, @@ -669,12 +668,135 @@ async def _parse_form_data( ) -# including the dataset examples router here ensures the dataset example routes -# are included in a natural order in the openapi schema and the swagger ui -# -# todo: move the dataset examples routes here and remove the dataset_examples -# sub-module -router.include_router(dataset_examples_router) +class DatasetExample(V1RoutesBaseModel): + id: str + input: Dict[str, Any] + output: Dict[str, Any] + metadata: Dict[str, Any] + updated_at: datetime + + +class ListDatasetExamplesData(V1RoutesBaseModel): + dataset_id: str + version_id: str + examples: List[DatasetExample] + + +class ListDatasetExamplesResponseBody(ResponseBody[ListDatasetExamplesData]): + pass + + +@router.get( + "/datasets/{id}/examples", + operation_id="getDatasetExamples", + summary="Get examples from a dataset", + responses=add_errors_to_responses([HTTP_404_NOT_FOUND]), +) +async def get_dataset_examples( + request: Request, + id: str = Path(description="The ID of the dataset"), + version_id: Optional[str] = Query( + default=None, + description=( + "The ID of the dataset version " "(if omitted, returns data from the latest version)" + ), + ), +) -> ListDatasetExamplesResponseBody: + dataset_gid = GlobalID.from_id(id) + version_gid = GlobalID.from_id(version_id) if version_id else None + + if (dataset_type := dataset_gid.type_name) != "Dataset": + raise HTTPException( + detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=HTTP_404_NOT_FOUND + ) + + if version_gid and (version_type := version_gid.type_name) != "DatasetVersion": + raise HTTPException( + detail=f"ID {version_gid} refers to a {version_type}", status_code=HTTP_404_NOT_FOUND + ) + + async with request.app.state.db() as session: + if ( + resolved_dataset_id := await session.scalar( + select(models.Dataset.id).where(models.Dataset.id == 
int(dataset_gid.node_id)) + ) + ) is None: + raise HTTPException( + detail=f"No dataset with id {dataset_gid} can be found.", + status_code=HTTP_404_NOT_FOUND, + ) + + # Subquery to find the largest revision id for each dataset_example_id + # (the revision with the largest id is treated as the most recent) + partial_subquery = select( + func.max(models.DatasetExampleRevision.id).label("max_id"), + ).group_by(models.DatasetExampleRevision.dataset_example_id) + + if version_gid: + if ( + resolved_version_id := await session.scalar( + select(models.DatasetVersion.id).where( + and_( + models.DatasetVersion.dataset_id == resolved_dataset_id, + models.DatasetVersion.id == int(version_gid.node_id), + ) + ) + ) + ) is None: + raise HTTPException( + detail=f"No dataset version with id {version_id} can be found.", + status_code=HTTP_404_NOT_FOUND, + ) + # if a version_id is provided, only include revisions up to and including that version + partial_subquery = partial_subquery.filter( + models.DatasetExampleRevision.dataset_version_id <= resolved_version_id + ) + else: + if ( + resolved_version_id := await session.scalar( + select(func.max(models.DatasetVersion.id)).where( + models.DatasetVersion.dataset_id == resolved_dataset_id + ) + ) + ) is None: + raise HTTPException( + detail="Dataset has no versions.", + status_code=HTTP_404_NOT_FOUND, + ) + + subquery = partial_subquery.subquery() + # Query for the most recent example revisions that are not deleted + query = ( + select(models.DatasetExample, models.DatasetExampleRevision) + .join( + models.DatasetExampleRevision, + models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id, + ) + .join( + subquery, + (subquery.c.max_id == models.DatasetExampleRevision.id), + ) + .filter(models.DatasetExample.dataset_id == resolved_dataset_id) + .filter(models.DatasetExampleRevision.revision_kind != "DELETE") + .order_by(models.DatasetExample.id.asc()) + ) + examples = [ + DatasetExample( + id=str(GlobalID("DatasetExample", 
str(example.id))), + input=revision.input, + output=revision.output, + metadata=revision.metadata_, + updated_at=revision.created_at, + ) + async for example, revision in await session.stream(query) + ] + return ListDatasetExamplesResponseBody( + data=ListDatasetExamplesData( + dataset_id=str(GlobalID("Dataset", str(resolved_dataset_id))), + version_id=str(GlobalID("DatasetVersion", str(resolved_version_id))), + examples=examples, + ) + ) @router.get( @@ -794,7 +916,7 @@ def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes: records = [ { "example_id": GlobalID( - type_name=DatasetExample.__name__, + type_name=DatasetExampleNodeType.__name__, node_id=str(ex.dataset_example_id), ), **{f"input_{k}": v for k, v in ex.input.items()}, diff --git a/tests/server/api/routers/v1/test_dataset_examples.py b/tests/server/api/routers/v1/test_dataset_examples.py deleted file mode 100644 index 86a15d8312..0000000000 --- a/tests/server/api/routers/v1/test_dataset_examples.py +++ /dev/null @@ -1,220 +0,0 @@ -from strawberry.relay import GlobalID - - -async def test_get_dataset_examples_404s_with_nonexistent_dataset_id(test_client): - global_id = GlobalID("Dataset", str(0)) - response = await test_client.get(f"/v1/datasets/{global_id}/examples") - assert response.status_code == 404 - assert response.content.decode() == f"No dataset with id {global_id} can be found." 
- - -async def test_get_dataset_examples_404s_with_invalid_global_id(test_client, simple_dataset): - global_id = GlobalID("InvalidDataset", str(0)) - response = await test_client.get(f"/v1/datasets/{global_id}/examples") - assert response.status_code == 404 - assert "refers to a InvalidDataset" in response.content.decode() - - -async def test_get_dataset_examples_404s_with_nonexistent_version_id(test_client, simple_dataset): - global_id = GlobalID("Dataset", str(0)) - version_id = GlobalID("DatasetVersion", str(99)) - response = await test_client.get( - f"/v1/datasets/{global_id}/examples", params={"version_id": str(version_id)} - ) - assert response.status_code == 404 - assert response.content.decode() == f"No dataset version with id {version_id} can be found." - - -async def test_get_dataset_examples_404s_with_invalid_version_global_id( - test_client, simple_dataset -): - global_id = GlobalID("Dataset", str(0)) - version_id = GlobalID("InvalidDatasetVersion", str(0)) - response = await test_client.get( - f"/v1/datasets/{global_id}/examples", params={"version_id": str(version_id)} - ) - assert response.status_code == 404 - assert "refers to a InvalidDatasetVersion" in response.content.decode() - - -async def test_get_simple_dataset_examples(test_client, simple_dataset): - global_id = GlobalID("Dataset", str(0)) - response = await test_client.get(f"/v1/datasets/{global_id}/examples") - assert response.status_code == 200 - result = response.json() - data = result["data"] - assert data["dataset_id"] == str(GlobalID("Dataset", str(0))) - assert data["version_id"] == str(GlobalID("DatasetVersion", str(0))) - examples = data["examples"] - assert len(examples) == 1 - expected_examples = [ - { - "id": str(GlobalID("DatasetExample", str(0))), - "input": {"in": "foo"}, - "output": {"out": "bar"}, - "metadata": {"info": "the first reivision"}, - } - ] - for example, expected in zip(examples, expected_examples): - assert "updated_at" in example - example_subset = {k: v for k, 
v in example.items() if k in expected} - assert example_subset == expected - - -async def test_list_simple_dataset_examples_at_each_version(test_client, simple_dataset): - global_id = GlobalID("Dataset", str(0)) - v0 = GlobalID("DatasetVersion", str(0)) - - # one example is created in version 0 - response = await test_client.get( - f"/v1/datasets/{global_id}/examples", params={"version_id": str(v0)} - ) - assert response.status_code == 200 - result = response.json() - data = result["data"] - assert len(data["examples"]) == 1 - - -async def test_list_empty_dataset_examples(test_client, empty_dataset): - global_id = GlobalID("Dataset", str(1)) - response = await test_client.get(f"/v1/datasets/{global_id}/examples") - assert response.status_code == 200 - result = response.json() - data = result["data"] - assert len(data["examples"]) == 0 - - -async def test_list_empty_dataset_examples_at_each_version(test_client, empty_dataset): - global_id = GlobalID("Dataset", str(1)) - v1 = GlobalID("DatasetVersion", str(1)) - v2 = GlobalID("DatasetVersion", str(2)) - v3 = GlobalID("DatasetVersion", str(3)) - - # two examples are created in version 1 - response = await test_client.get( - f"/v1/datasets/{global_id}/examples", params={"version_id": str(v1)} - ) - assert response.status_code == 200 - result = response.json() - data = result["data"] - assert len(data["examples"]) == 2 - - # two examples are patched in version 2 - response = await test_client.get( - f"/v1/datasets/{global_id}/examples", params={"version_id": str(v2)} - ) - assert response.status_code == 200 - result = response.json() - data = result["data"] - assert len(data["examples"]) == 2 - - # two examples are deleted in version 3 - response = await test_client.get( - f"/v1/datasets/{global_id}/examples", params={"version_id": str(v3)} - ) - assert response.status_code == 200 - result = response.json() - data = result["data"] - assert len(data["examples"]) == 0 - - -async def 
test_list_dataset_with_revisions_examples(test_client, dataset_with_revisions): - global_id = GlobalID("Dataset", str(2)) - response = await test_client.get(f"/v1/datasets/{global_id}/examples") - assert response.status_code == 200 - result = response.json() - data = result["data"] - assert data["dataset_id"] == str(GlobalID("Dataset", str(2))) - assert data["version_id"] == str(GlobalID("DatasetVersion", str(9))) - examples = data["examples"] - assert len(examples) == 3 - expected_values = [ - { - "id": str(GlobalID("DatasetExample", str(3))), - "input": {"in": "foo"}, - "output": {"out": "bar"}, - "metadata": {"info": "first revision"}, - }, - { - "id": str(GlobalID("DatasetExample", str(4))), - "input": {"in": "updated foofoo"}, - "output": {"out": "updated barbar"}, - "metadata": {"info": "updating revision"}, - }, - { - "id": str(GlobalID("DatasetExample", str(5))), - "input": {"in": "look at me"}, - "output": {"out": "i have all the answers"}, - "metadata": {"info": "a new example"}, - }, - ] - for example, expected in zip(examples, expected_values): - assert "updated_at" in example - example_subset = {k: v for k, v in example.items() if k in expected} - assert example_subset == expected - - -async def test_list_dataset_with_revisions_examples_at_each_version( - test_client, dataset_with_revisions -): - global_id = GlobalID("Dataset", str(2)) - v4 = GlobalID("DatasetVersion", str(4)) - v5 = GlobalID("DatasetVersion", str(5)) - v6 = GlobalID("DatasetVersion", str(6)) - v7 = GlobalID("DatasetVersion", str(7)) - v8 = GlobalID("DatasetVersion", str(8)) - v9 = GlobalID("DatasetVersion", str(9)) - - # two examples are created in version 4 - response = await test_client.get( - f"/v1/datasets/{global_id}/examples", params={"version_id": str(v4)} - ) - assert response.status_code == 200 - result = response.json() - data = result["data"] - assert len(data["examples"]) == 2 - - # two examples are patched in version 5 - response = await test_client.get( - 
f"/v1/datasets/{global_id}/examples", params={"version_id": str(v5)} - ) - assert response.status_code == 200 - result = response.json() - data = result["data"] - assert len(data["examples"]) == 3 - - # one example is added in version 6 - response = await test_client.get( - f"/v1/datasets/{global_id}/examples", params={"version_id": str(v6)} - ) - assert response.status_code == 200 - result = response.json() - data = result["data"] - assert len(data["examples"]) == 4 - - # one example is deleted in version 7 - response = await test_client.get( - f"/v1/datasets/{global_id}/examples", params={"version_id": str(v7)} - ) - assert response.status_code == 200 - result = response.json() - data = result["data"] - assert len(data["examples"]) == 3 - - # one example is added in version 8 - response = await test_client.get( - f"/v1/datasets/{global_id}/examples", params={"version_id": str(v8)} - ) - assert response.status_code == 200 - result = response.json() - data = result["data"] - assert len(data["examples"]) == 4 - - # one example is deleted in version 9 - response = await test_client.get( - f"/v1/datasets/{global_id}/examples", params={"version_id": str(v9)} - ) - assert response.status_code == 200 - result = response.json() - data = result["data"] - assert len(data["examples"]) == 3 diff --git a/tests/server/api/routers/v1/test_datasets.py b/tests/server/api/routers/v1/test_datasets.py index 060bc3abaa..eba74eb9ef 100644 --- a/tests/server/api/routers/v1/test_datasets.py +++ b/tests/server/api/routers/v1/test_datasets.py @@ -546,3 +546,222 @@ async def test_delete_dataset(test_client, empty_dataset) -> None: assert len((await test_client.get("v1/datasets")).json()["data"]) == 0 with pytest.raises(HTTPStatusError): (await test_client.delete(url)).raise_for_status() + + +async def test_get_dataset_examples_404s_with_nonexistent_dataset_id(test_client): + global_id = GlobalID("Dataset", str(0)) + response = await test_client.get(f"/v1/datasets/{global_id}/examples") + 
assert response.status_code == 404 + assert response.content.decode() == f"No dataset with id {global_id} can be found." + + +async def test_get_dataset_examples_404s_with_invalid_global_id(test_client, simple_dataset): + global_id = GlobalID("InvalidDataset", str(0)) + response = await test_client.get(f"/v1/datasets/{global_id}/examples") + assert response.status_code == 404 + assert "refers to a InvalidDataset" in response.content.decode() + + +async def test_get_dataset_examples_404s_with_nonexistent_version_id(test_client, simple_dataset): + global_id = GlobalID("Dataset", str(0)) + version_id = GlobalID("DatasetVersion", str(99)) + response = await test_client.get( + f"/v1/datasets/{global_id}/examples", params={"version_id": str(version_id)} + ) + assert response.status_code == 404 + assert response.content.decode() == f"No dataset version with id {version_id} can be found." + + +async def test_get_dataset_examples_404s_with_invalid_version_global_id( + test_client, simple_dataset +): + global_id = GlobalID("Dataset", str(0)) + version_id = GlobalID("InvalidDatasetVersion", str(0)) + response = await test_client.get( + f"/v1/datasets/{global_id}/examples", params={"version_id": str(version_id)} + ) + assert response.status_code == 404 + assert "refers to a InvalidDatasetVersion" in response.content.decode() + + +async def test_get_simple_dataset_examples(test_client, simple_dataset): + global_id = GlobalID("Dataset", str(0)) + response = await test_client.get(f"/v1/datasets/{global_id}/examples") + assert response.status_code == 200 + result = response.json() + data = result["data"] + assert data["dataset_id"] == str(GlobalID("Dataset", str(0))) + assert data["version_id"] == str(GlobalID("DatasetVersion", str(0))) + examples = data["examples"] + assert len(examples) == 1 + expected_examples = [ + { + "id": str(GlobalID("DatasetExample", str(0))), + "input": {"in": "foo"}, + "output": {"out": "bar"}, + "metadata": {"info": "the first reivision"}, + } + ] + 
for example, expected in zip(examples, expected_examples): + assert "updated_at" in example + example_subset = {k: v for k, v in example.items() if k in expected} + assert example_subset == expected + + +async def test_list_simple_dataset_examples_at_each_version(test_client, simple_dataset): + global_id = GlobalID("Dataset", str(0)) + v0 = GlobalID("DatasetVersion", str(0)) + + # one example is created in version 0 + response = await test_client.get( + f"/v1/datasets/{global_id}/examples", params={"version_id": str(v0)} + ) + assert response.status_code == 200 + result = response.json() + data = result["data"] + assert len(data["examples"]) == 1 + + +async def test_list_empty_dataset_examples(test_client, empty_dataset): + global_id = GlobalID("Dataset", str(1)) + response = await test_client.get(f"/v1/datasets/{global_id}/examples") + assert response.status_code == 200 + result = response.json() + data = result["data"] + assert len(data["examples"]) == 0 + + +async def test_list_empty_dataset_examples_at_each_version(test_client, empty_dataset): + global_id = GlobalID("Dataset", str(1)) + v1 = GlobalID("DatasetVersion", str(1)) + v2 = GlobalID("DatasetVersion", str(2)) + v3 = GlobalID("DatasetVersion", str(3)) + + # two examples are created in version 1 + response = await test_client.get( + f"/v1/datasets/{global_id}/examples", params={"version_id": str(v1)} + ) + assert response.status_code == 200 + result = response.json() + data = result["data"] + assert len(data["examples"]) == 2 + + # two examples are patched in version 2 + response = await test_client.get( + f"/v1/datasets/{global_id}/examples", params={"version_id": str(v2)} + ) + assert response.status_code == 200 + result = response.json() + data = result["data"] + assert len(data["examples"]) == 2 + + # two examples are deleted in version 3 + response = await test_client.get( + f"/v1/datasets/{global_id}/examples", params={"version_id": str(v3)} + ) + assert response.status_code == 200 + result = 
response.json() + data = result["data"] + assert len(data["examples"]) == 0 + + +async def test_list_dataset_with_revisions_examples(test_client, dataset_with_revisions): + global_id = GlobalID("Dataset", str(2)) + response = await test_client.get(f"/v1/datasets/{global_id}/examples") + assert response.status_code == 200 + result = response.json() + data = result["data"] + assert data["dataset_id"] == str(GlobalID("Dataset", str(2))) + assert data["version_id"] == str(GlobalID("DatasetVersion", str(9))) + examples = data["examples"] + assert len(examples) == 3 + expected_values = [ + { + "id": str(GlobalID("DatasetExample", str(3))), + "input": {"in": "foo"}, + "output": {"out": "bar"}, + "metadata": {"info": "first revision"}, + }, + { + "id": str(GlobalID("DatasetExample", str(4))), + "input": {"in": "updated foofoo"}, + "output": {"out": "updated barbar"}, + "metadata": {"info": "updating revision"}, + }, + { + "id": str(GlobalID("DatasetExample", str(5))), + "input": {"in": "look at me"}, + "output": {"out": "i have all the answers"}, + "metadata": {"info": "a new example"}, + }, + ] + for example, expected in zip(examples, expected_values): + assert "updated_at" in example + example_subset = {k: v for k, v in example.items() if k in expected} + assert example_subset == expected + + +async def test_list_dataset_with_revisions_examples_at_each_version( + test_client, dataset_with_revisions +): + global_id = GlobalID("Dataset", str(2)) + v4 = GlobalID("DatasetVersion", str(4)) + v5 = GlobalID("DatasetVersion", str(5)) + v6 = GlobalID("DatasetVersion", str(6)) + v7 = GlobalID("DatasetVersion", str(7)) + v8 = GlobalID("DatasetVersion", str(8)) + v9 = GlobalID("DatasetVersion", str(9)) + + # two examples are created in version 4 + response = await test_client.get( + f"/v1/datasets/{global_id}/examples", params={"version_id": str(v4)} + ) + assert response.status_code == 200 + result = response.json() + data = result["data"] + assert len(data["examples"]) == 2 + + # 
two examples are patched in version 5 + response = await test_client.get( + f"/v1/datasets/{global_id}/examples", params={"version_id": str(v5)} + ) + assert response.status_code == 200 + result = response.json() + data = result["data"] + assert len(data["examples"]) == 3 + + # one example is added in version 6 + response = await test_client.get( + f"/v1/datasets/{global_id}/examples", params={"version_id": str(v6)} + ) + assert response.status_code == 200 + result = response.json() + data = result["data"] + assert len(data["examples"]) == 4 + + # one example is deleted in version 7 + response = await test_client.get( + f"/v1/datasets/{global_id}/examples", params={"version_id": str(v7)} + ) + assert response.status_code == 200 + result = response.json() + data = result["data"] + assert len(data["examples"]) == 3 + + # one example is added in version 8 + response = await test_client.get( + f"/v1/datasets/{global_id}/examples", params={"version_id": str(v8)} + ) + assert response.status_code == 200 + result = response.json() + data = result["data"] + assert len(data["examples"]) == 4 + + # one example is deleted in version 9 + response = await test_client.get( + f"/v1/datasets/{global_id}/examples", params={"version_id": str(v9)} + ) + assert response.status_code == 200 + result = response.json() + data = result["data"] + assert len(data["examples"]) == 3