Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

standardizing dataset naming around language #100

Merged
merged 1 commit into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions repepo/data/make_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from repepo.variables import Environ
from repepo.core.types import Example, Dataset
from repepo.translation.constants import LangOrStyleCode, LANG_OR_STYLE_MAPPING
from dataclasses import dataclass

from .io import jload
Expand Down Expand Up @@ -92,3 +93,38 @@ def _shuffle_and_split(items: list[T], split_string: str, seed: int) -> list[T]:
def make_dataset(spec: DatasetSpec, dataset_dir: pathlib.Path | None = None):
dataset = get_dataset(spec.name, dataset_dir)
return _shuffle_and_split(dataset, spec.split, spec.seed)


@dataclass
class DatasetFilenameParams:
    """Components of a dataset filename, as returned by ``parse_dataset_filename``."""

    # Filename stem without any language/style suffix (e.g. "sycophancy_train").
    base_name: str
    # File extension. NOTE(review): when produced by parse_dataset_filename this
    # carries a leading dot (pathlib ``.suffix``, e.g. ".json"), whereas
    # build_dataset_filename accepts the extension with or without the dot.
    extension: str
    # Language or style code embedded via the "--l-" marker, or None if absent.
    lang_or_style: LangOrStyleCode | None = None


def build_dataset_filename(
    base_name: str,
    extension: str = "json",
    lang_or_style: LangOrStyleCode | str | None = None,
) -> str:
    """Build a dataset filename, optionally embedding a lang/style code.

    The code is appended to the stem as ``--l-<code>`` so that
    ``parse_dataset_filename`` can recover it later.

    Raises:
        ValueError: if ``lang_or_style`` is given but not a known code.
    """
    if lang_or_style is not None and lang_or_style not in LANG_OR_STYLE_MAPPING:
        raise ValueError(f"Unknown lang_or_style: {lang_or_style}")
    # Accept the extension with or without a leading dot.
    ext = extension.removeprefix(".")
    lang_suffix = "" if lang_or_style is None else f"--l-{lang_or_style}"
    return f"{base_name}{lang_suffix}.{ext}"


def parse_dataset_filename(filename: str | pathlib.Path) -> DatasetFilenameParams:
    """Parse a dataset filename into its base name, extension, and lang/style code.

    Inverse of ``build_dataset_filename``, except that the returned extension
    keeps its leading dot (it comes from ``pathlib.Path.suffix``).

    Args:
        filename: A filename or path; any directory components are ignored.

    Returns:
        A ``DatasetFilenameParams`` with ``lang_or_style=None`` when the stem
        contains no ``--l-`` marker.

    Raises:
        ValueError: if a ``--l-`` marker is present but the code after it is
            not a known lang/style code.
    """
    filepath = pathlib.Path(filename)
    if "--l-" in filepath.stem:
        # build_dataset_filename appends the code last, so split from the
        # right; a plain split() would raise "too many values to unpack" if
        # the base name itself happened to contain the "--l-" marker.
        base_name, lang_or_style = filepath.stem.rsplit("--l-", 1)
        if lang_or_style not in LANG_OR_STYLE_MAPPING:
            raise ValueError(f"Unknown lang_or_style: {lang_or_style}")
        return DatasetFilenameParams(
            base_name, extension=filepath.suffix, lang_or_style=lang_or_style
        )
    return DatasetFilenameParams(
        base_name=filepath.stem, extension=filepath.suffix, lang_or_style=None
    )
15 changes: 9 additions & 6 deletions repepo/data/multiple_choice/make_mwe_personas_caa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import json
from typing import cast

from repepo.data.make_dataset import get_dataset_dir, get_raw_dataset_dir
from repepo.data.make_dataset import (
build_dataset_filename,
get_dataset_dir,
get_raw_dataset_dir,
)
from repepo.data.io import jdump
from repepo.core.types import Dataset, Example
from repepo.translation.constants import LangOrStyleCode
Expand Down Expand Up @@ -75,8 +79,9 @@ def make_mwe_personas_caa():
list_dataset = [json.loads(line) for line in jsonfile]

persona = dataset_path.stem
filename = build_dataset_filename(persona)
mwe_dataset: Dataset = convert_mwe_personas_dataset_caa(list_dataset)
jdump(mwe_dataset, get_dataset_dir() / "persona" / f"{persona}.json")
jdump(mwe_dataset, get_dataset_dir() / "persona" / filename)


def make_mwe_personas_caa_translations():
Expand All @@ -86,13 +91,11 @@ def make_mwe_personas_caa_translations():

persona = dataset_path.stem
lang_or_style = cast(LangOrStyleCode, dataset_path.parent.parent.parent.stem)
filename = build_dataset_filename(persona, lang_or_style=lang_or_style)
mwe_dataset: Dataset = convert_mwe_personas_dataset_caa(
list_dataset, lang_or_style=lang_or_style
)
jdump(
mwe_dataset,
get_dataset_dir() / "persona" / f"{persona}_{lang_or_style}.json",
)
jdump(mwe_dataset, get_dataset_dir() / "persona" / filename)


if __name__ == "__main__":
Expand Down
24 changes: 15 additions & 9 deletions repepo/data/multiple_choice/make_sycophancy_caa.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
Included for purposes of reproducing key plots
"""

from repepo.data.make_dataset import get_dataset_dir, get_raw_dataset_dir
from repepo.data.make_dataset import (
build_dataset_filename,
get_dataset_dir,
get_raw_dataset_dir,
)
from repepo.data.io import jdump, jload
from repepo.core.types import Dataset, Example

Expand Down Expand Up @@ -48,12 +52,14 @@ def make_sycophancy_caa():
dataset_path = get_raw_dataset_dir() / "caa" / "generate_dataset.json"
list_dataset = jload(dataset_path)
syc_dataset = convert_sycophancy_dataset(list_dataset)
jdump(syc_dataset, get_dataset_dir() / "caa" / "sycophancy_train.json")
filename = build_dataset_filename("sycophancy_train")
jdump(syc_dataset, get_dataset_dir() / "caa" / filename)

dataset_path = get_raw_dataset_dir() / "caa" / "test_dataset.json"
list_dataset = jload(dataset_path)
syc_dataset: Dataset = convert_sycophancy_dataset(list_dataset)
jdump(syc_dataset, get_dataset_dir() / "caa" / "sycophancy_test.json")
filename = build_dataset_filename("sycophancy_test")
jdump(syc_dataset, get_dataset_dir() / "caa" / filename)


def make_sycophancy_caa_translations():
Expand All @@ -64,21 +70,21 @@ def make_sycophancy_caa_translations():
lang_or_style = dataset_path.parent.parent.stem
list_dataset = jload(dataset_path)
syc_dataset = convert_sycophancy_dataset(list_dataset)
jdump(
syc_dataset,
get_dataset_dir() / "caa" / f"sycophancy_train_{lang_or_style}.json",
filename = build_dataset_filename(
"sycophancy_train", lang_or_style=lang_or_style
)
jdump(syc_dataset, get_dataset_dir() / "caa" / filename)

for dataset_path in get_raw_dataset_dir().glob(
"translated/*/caa/test_dataset.json"
):
lang_or_style = dataset_path.parent.parent.stem
list_dataset = jload(dataset_path)
syc_dataset: Dataset = convert_sycophancy_dataset(list_dataset)
jdump(
syc_dataset,
get_dataset_dir() / "caa" / f"sycophancy_test_{lang_or_style}.json",
filename = build_dataset_filename(
"sycophancy_test", lang_or_style=lang_or_style
)
jdump(syc_dataset, get_dataset_dir() / "caa" / filename)


if __name__ == "__main__":
Expand Down
17 changes: 9 additions & 8 deletions repepo/data/multiple_choice/make_truthfulqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Literal, cast
from datasets import load_dataset, Dataset as HFDataset

from repepo.data.make_dataset import get_dataset_dir
from repepo.data.make_dataset import build_dataset_filename, get_dataset_dir
from repepo.data.io import jdump
from repepo.core.types import Dataset, Example
from repepo.variables import Environ
Expand Down Expand Up @@ -119,7 +119,8 @@ def make_truthfulqa():
# hf's dataset is too general and requires casting every field we access, so just using Any for simplicity
hf_dataset = cast(Any, load_dataset("truthful_qa", "multiple_choice"))["validation"]
tqa_dataset = convert_hf_truthfulqa_dataset(hf_dataset)
jdump(tqa_dataset, get_dataset_dir() / "truthfulqa.json")
filename = build_dataset_filename("truthfulqa")
jdump(tqa_dataset, get_dataset_dir() / filename)

# also build translated datasets
for translated_tqa in Path(Environ.TranslatedDatasetsDir).glob(
Expand All @@ -128,17 +129,19 @@ def make_truthfulqa():
lang_or_style = translated_tqa.parent.stem
dataset = HFDataset.from_json(str(translated_tqa))
converted_dataset = convert_hf_truthfulqa_dataset(dataset)
filename = build_dataset_filename("truthfulqa", lang_or_style=lang_or_style)
jdump(
converted_dataset,
get_dataset_dir() / f"truthfulqa_{lang_or_style}.json",
get_dataset_dir() / filename,
)


def make_truthfulqa_caa():
# hf's dataset is too general and requires casting every field we access, so just using Any for simplicity
hf_dataset = cast(Any, load_dataset("truthful_qa", "multiple_choice"))["validation"]
tqa_dataset = convert_hf_truthfulqa_caa_dataset(hf_dataset)
jdump(tqa_dataset, get_dataset_dir() / "truthfulqa_caa.json")
filename = build_dataset_filename("truthfulqa_caa")
jdump(tqa_dataset, get_dataset_dir() / filename)

# also build translated datasets
for translated_tqa in Path(Environ.TranslatedDatasetsDir).glob(
Expand All @@ -147,10 +150,8 @@ def make_truthfulqa_caa():
lang_or_style = translated_tqa.parent.stem
dataset = HFDataset.from_json(str(translated_tqa))
converted_dataset = convert_hf_truthfulqa_caa_dataset(dataset)
jdump(
converted_dataset,
get_dataset_dir() / f"truthfulqa_caa_{lang_or_style}.json",
)
filename = build_dataset_filename("truthfulqa_caa", lang_or_style=lang_or_style)
jdump(converted_dataset, get_dataset_dir() / filename)


if __name__ == "__main__":
Expand Down
12 changes: 8 additions & 4 deletions repepo/translation/translated_strings.json
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
{
"fr": {
"llama2_chat_caa_system_message": "Vous êtes un assistant serviable, honnête et concis.",
"caa_choices": "Choix :"
"caa_choices": "Choix :",
"caa_eval_prefix": "Ma réponse est"
},
"ja": {
"llama2_chat_caa_system_message": "あなたは役に立ち、正直で、簡潔なアシスタントです。",
"caa_choices": "選択肢:"
"caa_choices": "選択肢:",
"caa_eval_prefix": "私の答えは"
},
"zh": {
"llama2_chat_caa_system_message": "您是一个有用、诚实且简洁的助手。",
"caa_choices": "选择:"
"caa_choices": "选择:",
"caa_eval_prefix": "我的答案是"
},
"pirate": {
"llama2_chat_caa_system_message": "Ye be a handy, true-hearted, and straight-talkin' matey.",
"caa_choices": "Choices, ye say?"
"caa_choices": "The choices be:",
"caa_eval_prefix": "Me reply be"
}
}
1 change: 1 addition & 0 deletions repepo/translation/translation_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
class TranslationString(Enum):
llama2_chat_caa_system_message = "You are a helpful, honest and concise assistant."
caa_choices = "Choices:"
caa_eval_prefix = "My answer is"


@cache
Expand Down
36 changes: 35 additions & 1 deletion tests/data/test_make_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import pytest
from repepo.data.make_dataset import _parse_split, _shuffle_and_split
from repepo.data.make_dataset import (
_parse_split,
_shuffle_and_split,
build_dataset_filename,
parse_dataset_filename,
)


def test_parse_split() -> None:
Expand Down Expand Up @@ -38,3 +43,32 @@ def test_shuffle_and_split_leaves_original_item_unchanged() -> None:
items = [1, 2, 3, 4, 5]
_shuffle_and_split(items, ":50%", seed=0)
assert items == [1, 2, 3, 4, 5]


def test_parse_dataset_filename_with_no_lang() -> None:
    """A plain filename yields its stem and suffix with no lang code."""
    params = parse_dataset_filename("blah/test.json")
    assert (params.base_name, params.extension, params.lang_or_style) == (
        "test",
        ".json",
        None,
    )


def test_parse_dataset_filename_with_lang() -> None:
    """A '--l-' marker in the stem is split out as the lang code."""
    params = parse_dataset_filename("blah/test--l-fr.json")
    assert (params.base_name, params.extension, params.lang_or_style) == (
        "test",
        ".json",
        "fr",
    )


def test_build_dataset_filename() -> None:
    """Extension dot is optional and the lang code is embedded as '--l-<code>'."""
    cases: list[tuple[tuple, dict, str]] = [
        (("test",), {}, "test.json"),
        (("test", ".json"), {}, "test.json"),
        (("test", ".json", "fr"), {}, "test--l-fr.json"),
        (("test", "json", "fr"), {}, "test--l-fr.json"),
        (("test",), {"lang_or_style": "fr"}, "test--l-fr.json"),
    ]
    for args, kwargs, expected in cases:
        assert build_dataset_filename(*args, **kwargs) == expected


def test_build_dataset_filename_errors_on_invalid_lang() -> None:
    """An unrecognized lang/style code is rejected with ValueError."""
    pytest.raises(ValueError, build_dataset_filename, "test", lang_or_style="FAKELANG")
Loading