[ENH] optionally return raw completions #92

Merged: 3 commits, Jul 20, 2023
Changes from all commits

src/alpaca_eval/annotators/base.py: 18 additions & 4 deletions

@@ -450,6 +450,9 @@ class SingleAnnotator:

     annotation_column : str, optional
         Name of the annotation column in the output dataframe.
+
+    completion_column : str, optional
+        Which column to store the raw completions in. If None, will not store them.
     """

     def __init__(
@@ -464,6 +467,7 @@ def __init__(
         batch_size: int = 1,
         base_dir: utils.AnyPath = constants.EVALUATORS_CONFIG_DIR,
         annotation_column: str = "annotation",
+        completion_column: Optional[str] = None,
     ):
         self.base_dir = Path(base_dir)
         self.prompt_template = self._get_prompt_template(prompt_template)
@@ -481,6 +485,7 @@ def __init__(
         self.is_shuffle = is_shuffle
         self.batch_size = batch_size
         self.annotation_column = annotation_column
+        self.completion_column = completion_column

     ### Public methods ###
     def __call__(self, df_to_annotate: pd.DataFrame, **decoding_kwargs) -> pd.DataFrame:
@@ -507,7 +512,11 @@ def __call__(self, df_to_annotate: pd.DataFrame, **decoding_kwargs) -> pd.DataFrame:

         completions = self.fn_completions(prompts=prompts, **self.completions_kwargs, **decoding_kwargs)

-        df_to_annotate[self.annotation_column] = self._parse_completions(completions=completions["completions"])
+        annotations_to_save, completions_to_save = self._parse_completions(completions=completions["completions"])
+        df_to_annotate[self.annotation_column] = annotations_to_save
+        if self.completion_column is not None:
+            df_to_annotate[self.completion_column] = completions_to_save
+
         for k, v in completions.items():
             if k != "completions":
                 if len(df_to_annotate[self.annotation_column]) == len(v) * self.batch_size:
@@ -563,19 +572,24 @@ def _preprocess(self, df_to_annotate: pd.DataFrame) -> pd.DataFrame:

         return df_to_annotate

-    def _parse_completions(self, completions: list[str]) -> list[Any]:
+    def _parse_completions(self, completions: list[str]) -> tuple[list[Any], list[Any]]:
         """Converts the completions into annotations."""
         all_annotations = []
+        all_completions = []
         for completion in completions:
-            batch_annotations = list(self.fn_completion_parser(completion))
+            batch_annotations = self.fn_completion_parser(completion)
+            batch_annotations = list(batch_annotations)

             if len(batch_annotations) != self.batch_size:
                 logging.warning(
                     f"Found {len(batch_annotations)} annotations in:'''\n{completion}\n''' but expected"
                     f" {self.batch_size}. We are setting all annotations to None."
                 )
                 batch_annotations = [None] * self.batch_size

             all_annotations += batch_annotations
-        return all_annotations
+            all_completions += [completion] * self.batch_size
+        return all_annotations, all_completions

     def _postprocess(self, df_annotated: pd.DataFrame) -> pd.DataFrame:
         """Postprocess the annotated examples."""
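
For context, a minimal standalone sketch of the flow this diff introduces. The free function parse_completions, the lambda parser, and the column names below are illustrative stand-ins, not the library API; the point is that raw completions stay row-aligned with annotations because each completion is repeated batch_size times, which is what lets the optional completion column be attached to the same dataframe.

# Illustrative sketch (assumed names, not the library API): mirrors how
# _parse_completions now returns annotations and raw completions together.
from typing import Any, Callable

import pandas as pd


def parse_completions(
    completions: list[str],
    parser: Callable[[str], list[Any]],
    batch_size: int = 1,
) -> tuple[list[Any], list[Any]]:
    all_annotations: list[Any] = []
    all_completions: list[Any] = []
    for completion in completions:
        batch_annotations = list(parser(completion))
        if len(batch_annotations) != batch_size:
            # Unparsable completion: keep row alignment with None placeholders.
            batch_annotations = [None] * batch_size
        all_annotations += batch_annotations
        # Repeat the raw completion so one copy lines up with each annotation row.
        all_completions += [completion] * batch_size
    return all_annotations, all_completions


# Usage mirroring __call__ when completion_column is not None:
df = pd.DataFrame({"instruction": ["2+2", "1+1"]})
annotations, raw = parse_completions(
    ["Output (a)", "Output (b)"], parser=lambda c: [c[-2]], batch_size=1
)
df["annotation"] = annotations
df["completions"] = raw  # only stored when completion_column is set

Note that unparsable completions still land in the completion column: the None placeholder annotations keep the two lists the same length, so nothing is dropped.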

src/alpaca_eval/completion_parsers.py: 13 additions & 0 deletions

@@ -1,5 +1,6 @@
 import ast
 import copy
+import json
 import logging
 import re
 from typing import Any
@@ -119,3 +120,15 @@ def ranking_parser(completion: str) -> list[Any]:
     except Exception as e:
         logging.error(f"{e}\nContent: {completion}\n" "You must manually fix the score pair.")
         return [np.nan]
+
+
+def json_parser(completion: str, annotation_key: str) -> list[Any]:
+    """Parse the completion by reading it as a JSON and selecting "annotation_key".
+
+    Examples
+    --------
+    >>> completion = ('[{"short_explanation": "that is why", "is_incorporated": true},{"is_incorporated": false}]')
+    >>> json_parser(completion, "is_incorporated")
+    [True, False]
+    """
+    return [d[annotation_key] for d in json.loads(completion.strip())]
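
SingleAnnotator calls fn_completion_parser with the completion as its only argument, so the extra annotation_key parameter of json_parser has to be bound before the parser is handed over. Below is a hedged sketch of that wiring, assuming a functools.partial binding; this is illustrative of the idea, not necessarily how the evaluator configs do it.

import functools
import json
from typing import Any


def json_parser(completion: str, annotation_key: str) -> list[Any]:
    # Same logic as the function added above.
    return [d[annotation_key] for d in json.loads(completion.strip())]


# Bind the key so the parser matches the one-argument call signature
# used in SingleAnnotator._parse_completions.
parser = functools.partial(json_parser, annotation_key="is_incorporated")

completion = '[{"short_explanation": "that is why", "is_incorporated": true}, {"is_incorporated": false}]'
assert parser(completion) == [True, False]

json.loads maps JSON true/false to Python booleans, which is why the parsed result is [True, False].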

tests/test_pairwise_evaluator.py: 10 additions & 10 deletions

@@ -56,22 +56,22 @@ def single_annotator():

 def test_single_annotator(single_annotator, df_to_annotate):
     # Create a sample DataFrame for testing
-
-    single_annotator.fn_completions = MagicMock(
-        return_value={"completions": ["Output (a)", "Output (b)", "not parsable"]}
-    )
+    parsable_completions = ["Output (a)", "Output (b)"]
+    completions = parsable_completions + ["not parsable"]  # add an example that can't be parsed
+    single_annotator.fn_completions = MagicMock(return_value={"completions": completions})
+    # set a completion_column => store it
+    single_annotator.completion_column = "completions"

     # Call the preprocess method
     df_annotated = single_annotator(df_to_annotate)

     assert df_annotated["preference"].tolist() == [1, 2]
     assert df_annotated["instruction"].tolist() == ["2+2", "1+1"]
-    assert df_annotated.columns.tolist() == [
-        "instruction",
-        "output_1",
-        "output_2",
-        "preference",
-    ]
+    assert set(df_annotated.columns.tolist()) == set(
+        ["instruction", "output_1", "output_2", "preference", "completions"]
+    )
+    # check that you also save the completions.
+    assert df_annotated["completions"].tolist() == parsable_completions


 @pytest.fixture