Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sort param to get test run by id #106

Merged
merged 6 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions arthur_bench/client/bench_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
TestSuiteRequest,
PaginatedTestSuite,
TestSuiteSummary,
TestCaseSortEnum,
CommonSortEnum,
TestSuiteSortEnum,
PaginationRunSortEnum,
PaginationSuiteSortEnum,
)

from arthur_bench.exceptions import ArthurInternalError
Expand All @@ -26,7 +31,7 @@ class BenchClient(ABC):
def get_test_suites(
self,
name: Optional[str] = None,
sort: Optional[str] = None,
sort: PaginationSuiteSortEnum = TestSuiteSortEnum.LAST_RUNTIME_ASC,
scoring_method: Optional[List[str]] = None,
page: int = 1,
page_size: int = 5,
Expand Down Expand Up @@ -72,7 +77,7 @@ def get_test_suite(
def get_runs_for_test_suite(
self,
test_suite_id: str,
sort: Optional[str] = None,
sort: PaginationRunSortEnum = CommonSortEnum.CREATED_AT_ASC,
page: int = 1,
page_size: int = 5,
) -> PaginatedRuns:
Expand Down Expand Up @@ -107,7 +112,7 @@ def get_test_run(
test_run_id: str,
page: int = 1,
page_size: int = 5,
sort: Optional[bool] = None,
sort: Optional[TestCaseSortEnum] = None,
) -> PaginatedRun:
"""
Get a test run by id.
Expand All @@ -116,6 +121,7 @@ def get_test_run(
:param test_run_id: uuid of the test run
:param page: the page to fetch, pagination refers to the test cases
:param page_size: page size to fetch, pagination refers to the test cases
:param sort: sort key to sort the retrieved results
"""
raise NotImplementedError

Expand Down
50 changes: 27 additions & 23 deletions arthur_bench/client/local/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@
RunResult,
ScoringMethod,
ScorerOutputType,
PaginationSuiteSortEnum,
PaginationRunSortEnum,
PaginationSortEnum,
CommonSortEnum,
TestCaseSortEnum,
TestRunSortEnum,
TestSuiteSortEnum,
)

from arthur_bench.utils.loaders import load_suite_from_json, get_file_extension
Expand All @@ -48,20 +55,20 @@
NUM_BINS = 20

SORT_QUERY_TO_FUNC = {
"last_run_time": lambda x: x.last_run_time
TestSuiteSortEnum.LAST_RUNTIME_ASC: lambda x: x.last_run_time
if x.last_run_time is not None
else x.created_at,
"name": lambda x: x.name,
"created_at": lambda x: x.created_at,
"avg_score": lambda x: x.avg_score,
"-last_run_time": lambda x: x.last_run_time
CommonSortEnum.NAME_ASC: lambda x: x.name,
CommonSortEnum.CREATED_AT_ASC: lambda x: x.created_at,
TestRunSortEnum.AVG_SCORE_ASC: lambda x: x.avg_score,
TestSuiteSortEnum.LAST_RUNTIME_DESC: lambda x: x.last_run_time
if x.last_run_time is not None
else x.created_at,
"-name": lambda x: x.name,
"-created_at": lambda x: x.created_at,
"-avg_score": lambda x: x.avg_score,
"score": lambda x: x.score,
"id": lambda x: x.id,
CommonSortEnum.NAME_DESC: lambda x: x.name,
CommonSortEnum.CREATED_AT_DESC: lambda x: x.created_at,
TestRunSortEnum.AVG_SCORE_DESC: lambda x: x.avg_score,
TestCaseSortEnum.SCORE_ASC: lambda x: x.score,
TestCaseSortEnum.SCORE_DESC: lambda x: x.score,
}


Expand Down Expand Up @@ -146,12 +153,12 @@ class PageInfo:


def _paginate(
objs: List, page: int, page_size: int, sort_key: Optional[str] = None
objs: List, page: int, page_size: int, sort_key: Optional[PaginationSortEnum] = None
) -> PageInfo:
"""Paginate sorted files and return iteration indices and page info"""
if sort_key is not None:
desc = sort_key[0] == "-"
sorted_pages = sorted(objs, key=SORT_QUERY_TO_FUNC[sort_key], reverse=desc)
sorted_pages = sorted(objs, key=SORT_QUERY_TO_FUNC.get(sort_key), reverse=desc)
else:
sorted_pages = objs
offset = (page - 1) * page_size
Expand Down Expand Up @@ -281,7 +288,7 @@ def get_test_suite(
def get_test_suites(
self,
name: Optional[str] = None,
sort: Optional[str] = None,
sort: PaginationSuiteSortEnum = TestSuiteSortEnum.LAST_RUNTIME_ASC,
scoring_method: Optional[List[str]] = None,
page: int = 1,
page_size: int = DEFAULT_PAGE_SIZE,
Expand Down Expand Up @@ -338,9 +345,6 @@ def get_test_suites(
)
)

# default sort by last run time
if sort is None:
sort = "last_run_time"
paginate = _paginate(suites, page=page, page_size=page_size, sort_key=sort)
return PaginatedTestSuites(
test_suites=paginate.sorted_pages[paginate.start : paginate.end],
Expand Down Expand Up @@ -400,7 +404,7 @@ def create_new_test_run(
def get_runs_for_test_suite(
self,
test_suite_id: str,
sort: Optional[str] = None,
sort: PaginationRunSortEnum = CommonSortEnum.CREATED_AT_ASC,
page: int = 1,
page_size: int = DEFAULT_PAGE_SIZE,
) -> PaginatedRuns:
Expand All @@ -416,9 +420,6 @@ def get_runs_for_test_suite(
run_resp = TestRunMetadata(**run_obj.dict(), avg_score=float(avg_score))
runs.append(run_resp)

if sort is None:
sort = "created_at"

pagination = _paginate(runs, page, page_size, sort_key=sort)

return PaginatedRuns(
Expand Down Expand Up @@ -465,7 +466,9 @@ def get_summary_statistics(
_summarize_run(run=run_obj, scoring_method=suite.scoring_method)
)

pagination = _paginate(runs, page, page_size, sort_key="avg_score")
pagination = _paginate(
runs, page, page_size, sort_key=TestRunSortEnum.AVG_SCORE_ASC
)
paginated_summary = TestSuiteSummary(
summary=runs,
categorical=suite.scoring_method.output_type
Expand All @@ -485,7 +488,7 @@ def get_test_run(
test_run_id: str,
page: int = 1,
page_size: int = DEFAULT_PAGE_SIZE,
sort: Optional[bool] = True,
sort: Optional[TestCaseSortEnum] = None,
) -> PaginatedRun:
test_suite_name = self._get_suite_name_from_id(test_suite_id)
if test_suite_name is None:
Expand Down Expand Up @@ -521,7 +524,8 @@ def get_test_run(
cases = []

run_results = [RunResult.parse_obj(r) for r in cases]
pagination = _paginate(run_results, page, page_size, sort_key="score")

pagination = _paginate(run_results, page, page_size, sort_key=sort)
return PaginatedRun(
id=uuid.UUID(test_run_id),
name=run_name,
Expand Down
15 changes: 10 additions & 5 deletions arthur_bench/client/rest/bench/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
PaginatedTestSuite,
TestSuiteSummary,
CreateRunResponse,
TestCaseSortEnum,
CommonSortEnum,
TestSuiteSortEnum,
PaginationRunSortEnum,
PaginationSuiteSortEnum,
)

from arthur_bench.models.scoring import (
Expand Down Expand Up @@ -43,7 +48,7 @@ def __init__(self, http_client: HTTPClient):
def get_test_suites(
self,
name: Optional[str] = None,
sort: Optional[str] = None,
sort: PaginationSuiteSortEnum = TestSuiteSortEnum.LAST_RUNTIME_ASC,
scoring_method: Optional[List[str]] = None,
page: int = 1,
page_size: int = 5,
Expand Down Expand Up @@ -177,7 +182,7 @@ def get_summary_statistics(
def get_runs_for_test_suite(
self,
test_suite_id: str,
sort: Optional[str] = None,
sort: PaginationRunSortEnum = CommonSortEnum.CREATED_AT_ASC,
page: int = 1,
page_size: int = 5,
) -> PaginatedRuns:
Expand Down Expand Up @@ -236,7 +241,7 @@ def get_test_run(
test_run_id: str,
page: int = 1,
page_size: int = 5,
sort: Optional[bool] = None,
sort: Optional[TestCaseSortEnum] = TestCaseSortEnum.SCORE_ASC,
) -> PaginatedRun:
"""
Get a test run with input, output, and reference data
Expand All @@ -245,7 +250,7 @@ def get_test_run(
:param test_run_id:
:param page:
:param page_size:
:param sort:
:param sort: sort key to sort the retrieved results
"""

params = {}
Expand All @@ -254,7 +259,7 @@ def get_test_run(
if page_size is not None:
params["page_size"] = page_size # type: ignore
if sort is not None:
params["sort"] = sort
params["sort"] = sort # type: ignore

parsed_resp = cast(
Dict,
Expand Down
31 changes: 31 additions & 0 deletions arthur_bench/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,37 @@ def scoring_method_categorical_defined(cls, values):
# REQUESTS


class CommonSortEnum(str, Enum):
NAME_ASC = "name"
NAME_DESC = "-name"
CREATED_AT_ASC = "created_at"
CREATED_AT_DESC = "-created_at"


class TestSuiteSortEnum(str, Enum):
LAST_RUNTIME_ASC = "last_run_time"
LAST_RUNTIME_DESC = "-last_run_time"


class TestRunSortEnum(str, Enum):
AVG_SCORE_ASC = "avg_score"
AVG_SCORE_DESC = "-avg_score"


class TestCaseSortEnum(str, Enum):
SCORE_ASC = "score"
SCORE_DESC = "-score"


PaginationSuiteSortEnum = Union[CommonSortEnum, TestSuiteSortEnum]

PaginationRunSortEnum = Union[CommonSortEnum, TestRunSortEnum]

PaginationSortEnum = Union[
TestCaseSortEnum, PaginationSuiteSortEnum, PaginationRunSortEnum
]


class TestCaseRequest(BaseModel):
"""
An input, reference output pair.
Expand Down
3 changes: 2 additions & 1 deletion arthur_bench/server/js/src/Bench/TestRun.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ const TestRun = () => {
const [isExpanded, setIsExpanded] = useState(true);
const { css } = useFela();
const { t } = useTranslation(["common"]);
const sort = "score"

const toggleExpanded = () => {
setIsExpanded((prevState) => !prevState);
};

useEffect(() => {
if (testSuiteId && testRunId) {
fetchTestRunDetail(testSuiteId, testRunId, 1, 10);
fetchTestRunDetail(testSuiteId, testRunId, 1, 10, sort);
fetchTestSuiteData(testSuiteId);
fetchTestRunSummary(testSuiteId);
fetchTestRunSummary(testSuiteId, [testRunId]);
Expand Down
16 changes: 12 additions & 4 deletions arthur_bench/server/js/src/Bench/useTestSuites.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,12 @@ export const useTestSuites = () => {
);

const fetchTestRunDetail = useCallback(
(testSuiteId: string, testRunId: string, page?: number, pageSize?: number) => {
arthurAxios.get(`/api/v3/bench/test_suites/${testSuiteId}/runs/${testRunId}?page=${page}&page_size=${pageSize}`).then((r) => {
(testSuiteId: string, testRunId: string, page?: number, pageSize?: number, sort?: string) => {
let url = `/api/v3/bench/test_suites/${testSuiteId}/runs/${testRunId}?page=${page}&page_size=${pageSize}`;
if (sort) {
url += `&sort=${sort}`;
}
arthurAxios.get(url).then((r) => {
const { page, page_size, total_count, total_pages, ...currentTestRun } = r.data;

const pagination: TPagination = {
Expand All @@ -131,9 +135,13 @@ export const useTestSuites = () => {
);

const fetchMultipleTestRunDetails = useCallback(
(testSuiteId: string, testRunIds: string[], page?: number, pageSize?: number) => {
(testSuiteId: string, testRunIds: string[], page?: number, pageSize?: number, sort?: string) => {
const promises = testRunIds.map((testRunId) => {
return arthurAxios.get(`/api/v3/bench/test_suites/${testSuiteId}/runs/${testRunId}?page=${page}&page_size=${pageSize}`);
let url = `/api/v3/bench/test_suites/${testSuiteId}/runs/${testRunId}?page=${page}&page_size=${pageSize}`
if (sort) {
url += `&sort=${sort}`;
}
return arthurAxios.get(url);
});

Promise.all(promises).then((responses) => {
Expand Down
17 changes: 15 additions & 2 deletions arthur_bench/server/run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@

from arthur_bench.client.local.client import LocalBenchClient
from arthur_bench.exceptions import NotFoundError
from arthur_bench.models.models import (
PaginationSuiteSortEnum,
PaginationRunSortEnum,
TestCaseSortEnum,
CommonSortEnum,
TestSuiteSortEnum,
)
from arthur_bench.telemetry.telemetry import send_event, set_track_usage_data
from arthur_bench.telemetry.config import get_or_persist_id, persist_usage_data

Expand All @@ -48,7 +55,9 @@ def test_suites(
request: Request,
page: int = 1,
page_size: int = 5,
sort: Optional[str] = None,
sort: Annotated[
Optional[PaginationSuiteSortEnum], Query()
] = TestSuiteSortEnum.LAST_RUNTIME_ASC,
scoring_method: Annotated[Union[List[str], None], Query()] = None,
name: Optional[str] = None,
):
Expand Down Expand Up @@ -97,7 +106,9 @@ def test_runs(
test_suite_id: uuid.UUID,
page: int = 1,
page_size: int = 5,
sort: Optional[str] = None,
sort: Annotated[
Optional[PaginationRunSortEnum], Query()
] = CommonSortEnum.CREATED_AT_ASC,
):
client = request.app.state.client
try:
Expand Down Expand Up @@ -152,6 +163,7 @@ def test_run_results(
run_id: uuid.UUID,
page: int = 1,
page_size: int = 5,
sort: Annotated[Optional[TestCaseSortEnum], Query()] = None,
):
client = request.app.state.client

Expand All @@ -161,6 +173,7 @@ def test_run_results(
test_run_id=str(run_id),
page=page,
page_size=page_size,
sort=sort,
).json(by_alias=True)
except NotFoundError as e:
return HTTPException(status_code=404, detail=str(e))
Expand Down
Loading