Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add api references and copy buttons, edit theme #58

Merged
merged 9 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions arthur_bench/client/bench_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@


class BenchClient(ABC):
"""
Base class for saving and loading bench data
"""

@abstractmethod
def get_test_suites(
self,
Expand All @@ -27,16 +31,41 @@ def get_test_suites(
page: int = 1,
page_size: int = 5,
) -> PaginatedTestSuites:
"""
Get metadata for all test suites.

:param name: filter test suites by name if provided
:param sort: optional sort key. possible values are 'name', 'last_run_time',
'created_at', use '-' prefix for descending sort.
defaults to 'last_run_time'
:param scoring method: optional filter on scoring method name,
multiple names may be provided
:param page: the page to fetch
:param page_size: page size to fetch
"""
raise NotImplementedError

@abstractmethod
def create_test_suite(self, json_body: TestSuiteRequest) -> PaginatedTestSuite:
"""
Create a new test suite.

:param json_body: test suite request object consisting of test suite metadata
and test cases
"""
raise NotImplementedError

@abstractmethod
def get_test_suite(
self, test_suite_id: str, page: int = 1, page_size: int = 5
) -> PaginatedTestSuite:
"""
Get a test suite by id.

:param test_suite_id: the uuid of the test suite to fetch
:param page: the page to fetch, pagination refers to the test cases
:param page_size: page size to fetch, pagination refers to the test cases
"""
raise NotImplementedError

@abstractmethod
Expand All @@ -47,12 +76,28 @@ def get_runs_for_test_suite(
page: int = 1,
page_size: int = 5,
) -> PaginatedRuns:
"""
Get runs for a given test suite.

:param test_suite_id: the uuid of the test suite
:param sort: optional sort key. possible values are 'name', 'avg_score', and '
created_at'. use '-' prefix for descending sort. defaults to 'created_at'
:param page: the page to fetch
:param page_size: page size to fetch
"""
raise NotImplementedError

@abstractmethod
def create_new_test_run(
self, test_suite_id: str, json_body: CreateRunRequest
) -> CreateRunResponse:
"""
Create a new run for a test suite.

:param test_suite_id: the uuid of the test suite to log a run for
:param json_body: run request containing run_metadata and scored model
generations
"""
raise NotImplementedError

@abstractmethod
Expand All @@ -64,6 +109,14 @@ def get_test_run(
page_size: int = 5,
sort: Optional[bool] = None,
) -> PaginatedRun:
"""
Get a test run by id.

:param test_suite_id: uuid of the test suite
:param test_run_id: uuid of the test run
:param page: the page to fetch, pagination refers to the test cases
:param page_size: page size to fetch, pagination refers to the test cases
"""
raise NotImplementedError

@abstractmethod
Expand All @@ -74,14 +127,30 @@ def get_summary_statistics(
page: int = 1,
page_size: int = 5,
) -> TestSuiteSummary:
"""
Fetch aggregate statistics of a test suite. Returns averages and score
distributions for runs in test suite.

:param test_suite_id: uuid of the test suite
:param run_id: optional run id. run will be included in response regardless of
page information if provided
:param page: the page to fetch
:param page_size: page size to fetch
"""
raise NotImplementedError

@abstractmethod
def delete_test_suite(self, test_suite_id: str):
"""
Delete a test suite. All associated runs will also be deleted
"""
raise NotImplementedError

@abstractmethod
def delete_test_run(self, test_suite_id: str, test_run_id: str):
"""
Delete a test run from a suite.
"""
raise NotImplementedError

def get_suite_if_exists(self, name: str) -> Optional[PaginatedTestSuite]:
Expand Down
59 changes: 57 additions & 2 deletions arthur_bench/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,32 @@


class ScoringMethodType(str, Enum):
"""
Indicates whether the scoring method was provided by the package or a custom
implementation
"""

BuiltIn = "built_in"
Custom = "custom"


class ScoringMethod(BaseModel):
"""
Scoring method configuration
"""

name: str
"""
Name of the scorer
"""
type: ScoringMethodType
"""
Whether the scoring method was bench default or custom implementation
"""
config: dict = {}
"""
Configuration as used by the scorer to_dict and from_dict methods
"""


# REQUESTS
Expand All @@ -46,12 +64,28 @@ class TestSuiteRequest(BaseModel):
"""

name: str
"""
Name of the test suite
"""
description: Optional[str] = None
"""
Optional description of the test suite
"""
scoring_method: ScoringMethod
"""
Scoring configuration to use as criteria for the test suite
"""
test_cases: List[TestCaseRequest] = Field(..., min_items=1)
"""
List of input texts and optional reference outputs to consistently score
model generations against
"""

@validator("test_cases")
def null_reference_outputs_all_or_none(cls, v):
"""
Validate that all or none of test case reference outputs are null
"""
last_ref_output_null = None
for tc in v:
# get ref output value
Expand Down Expand Up @@ -159,6 +193,10 @@ class TestSuiteMetadata(BaseModel):


class PaginatedTestSuites(BaseModel):
"""
Paginated list of test suites.
"""

test_suites: List[TestSuiteMetadata]
page: int
page_size: int
Expand All @@ -179,6 +217,10 @@ class TestCaseResponse(BaseModel):


class PaginatedTestSuite(BaseModel):
"""
Test suite and optional page information
"""

id: UUID
name: str
scoring_method: ScoringMethod
Expand Down Expand Up @@ -217,19 +259,32 @@ class PaginatedRuns(BaseModel):


class HistogramItem(BaseModel):
"""
Boundaries and count for a single bucket of a run histogram
"""

count: int
low: float
high: float


class SummaryItem(BaseModel):
"""
Aggregate statistics for a single run: average score and score distribution
"""

id: UUID
name: str
avg_score: float
histogram: List[HistogramItem]


class TestSuiteSummary(BaseModel):
"""
Aggregate descriptions of runs of a test suite.
Provides averages and score distributions
"""

summary: List[SummaryItem]
page: int
page_size: int
Expand All @@ -252,8 +307,8 @@ class RunResult(BaseModel):

class PaginatedRun(BaseModel):
"""
Paginated list of prompts, reference outputs, and model outputs for a particular
run.
Paginated list of prompts, reference outputs, model outputs, and scores for a
particular run.
"""

id: UUID
Expand Down
20 changes: 20 additions & 0 deletions arthur_bench/models/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,30 @@


class HallucinationScoreRequest(BaseModel):
"""
Request for hallucination classification
"""

response: str
"""
Model generated response
"""
context: str
"""
Context with which to determine if the model generated response is supported
"""


class HallucinationScoreResponse(BaseModel):
"""
Hallucination classification
"""

hallucination: bool
"""
True if hallucination, false otherwise
"""
reason: str
"""
Justification for the hallucination classification
"""
5 changes: 4 additions & 1 deletion arthur_bench/scoring/scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _can_omit(parameter: Parameter):

class Scorer(ABC):
"""
Base class for all scorers.
Base class for all scorers. Compute a float score for a given model generation.
"""

@staticmethod
Expand All @@ -42,6 +42,9 @@ def name() -> str:

@staticmethod
def requires_reference() -> bool:
"""
True if scorer requires reference output to compute score, False otherwise
"""
return True

@abstractmethod
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pydantic>=1.10.0,<2.0
pyjwt<3,>=2.6
requests==2.28.2
requests_toolbelt~=1.0.0
sphinx-copybutton
textstat~=0.7.3
tiktoken~=0.4.0
transformers==4.29.2
Expand Down
7 changes: 1 addition & 6 deletions docs/source/_static/custom_styles_20221207.css
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ p a, div.toctree-wrapper a, .headerlink, div.related-pages div.title {

/* make bold text less bold */
strong {
font-weight: 500;
font-weight: 400;
}

/* unbold headings and reduce rounding in highlight */
Expand Down Expand Up @@ -77,8 +77,3 @@ input.sidebar-search {
input.sidebar-search:focus, input.sidebar-search:hover {
color: var(--color-sidebar-search-foreground);
}

/* disable right sidebar */
.toc-drawer {
display: none;
}
27 changes: 27 additions & 0 deletions docs/source/_static/img/Arthur_Logo_PBW.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading