Skip to content

Commit

Permalink
standardizing dataset naming around language (#100)
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind authored Feb 8, 2024
1 parent 6680de7 commit 8fda2f9
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 28 deletions.
36 changes: 36 additions & 0 deletions repepo/data/make_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from repepo.variables import Environ
from repepo.core.types import Example, Dataset
from repepo.translation.constants import LangOrStyleCode, LANG_OR_STYLE_MAPPING
from dataclasses import dataclass

from .io import jload
Expand Down Expand Up @@ -92,3 +93,38 @@ def _shuffle_and_split(items: list[T], split_string: str, seed: int) -> list[T]:
def make_dataset(spec: DatasetSpec, dataset_dir: pathlib.Path | None = None):
    """Load the dataset named in *spec* and return its shuffled split.

    The dataset is resolved via ``get_dataset`` (from ``dataset_dir`` when
    given), shuffled deterministically with ``spec.seed``, and trimmed to the
    slice described by ``spec.split``.
    """
    return _shuffle_and_split(
        get_dataset(spec.name, dataset_dir), spec.split, spec.seed
    )


@dataclass
class DatasetFilenameParams:
    """Components of a dataset filename, as returned by parse_dataset_filename()."""

    # Filename stem without the extension or the "--l-<code>" suffix.
    base_name: str
    # File extension; parse_dataset_filename() returns it with the leading
    # dot (e.g. ".json"), which build_dataset_filename() also accepts.
    extension: str
    # Language or style code embedded in the filename via "--l-<code>", if any.
    lang_or_style: LangOrStyleCode | None = None


def build_dataset_filename(
    base_name: str,
    extension: str = "json",
    lang_or_style: LangOrStyleCode | str | None = None,
) -> str:
    """Assemble a dataset filename, embedding the language/style code if given.

    A leading dot on *extension* is tolerated, so "json" and ".json" are
    equivalent. Raises ValueError when *lang_or_style* is not a known code.
    """
    if lang_or_style is not None and lang_or_style not in LANG_OR_STYLE_MAPPING:
        raise ValueError(f"Unknown lang_or_style: {lang_or_style}")
    suffix = extension.removeprefix(".")
    stem = base_name if lang_or_style is None else f"{base_name}--l-{lang_or_style}"
    return f"{stem}.{suffix}"


def parse_dataset_filename(filename: str | pathlib.Path) -> DatasetFilenameParams:
    """Split a dataset filename into base name, extension, and lang/style code.

    Inverse of build_dataset_filename(): a "--l-<code>" marker in the stem is
    extracted as the lang_or_style code. The returned extension keeps its
    leading dot (e.g. ".json"), which build_dataset_filename() accepts.

    Raises:
        ValueError: if the embedded code is not in LANG_OR_STYLE_MAPPING.
    """
    filepath = pathlib.Path(filename)
    stem = filepath.stem
    if "--l-" in stem:
        # rsplit with maxsplit=1 so a base name that itself contains "--l-"
        # does not break the 2-tuple unpacking; the code is always the final
        # segment appended by build_dataset_filename().
        base_name, lang_or_style = stem.rsplit("--l-", 1)
        if lang_or_style not in LANG_OR_STYLE_MAPPING:
            raise ValueError(f"Unknown lang_or_style: {lang_or_style}")
        return DatasetFilenameParams(
            base_name, extension=filepath.suffix, lang_or_style=lang_or_style
        )
    return DatasetFilenameParams(
        base_name=stem, extension=filepath.suffix, lang_or_style=None
    )
15 changes: 9 additions & 6 deletions repepo/data/multiple_choice/make_mwe_personas_caa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import json
from typing import cast

from repepo.data.make_dataset import get_dataset_dir, get_raw_dataset_dir
from repepo.data.make_dataset import (
build_dataset_filename,
get_dataset_dir,
get_raw_dataset_dir,
)
from repepo.data.io import jdump
from repepo.core.types import Dataset, Example
from repepo.translation.constants import LangOrStyleCode
Expand Down Expand Up @@ -75,8 +79,9 @@ def make_mwe_personas_caa():
list_dataset = [json.loads(line) for line in jsonfile]

persona = dataset_path.stem
filename = build_dataset_filename(persona)
mwe_dataset: Dataset = convert_mwe_personas_dataset_caa(list_dataset)
jdump(mwe_dataset, get_dataset_dir() / "persona" / f"{persona}.json")
jdump(mwe_dataset, get_dataset_dir() / "persona" / filename)


def make_mwe_personas_caa_translations():
Expand All @@ -86,13 +91,11 @@ def make_mwe_personas_caa_translations():

persona = dataset_path.stem
lang_or_style = cast(LangOrStyleCode, dataset_path.parent.parent.parent.stem)
filename = build_dataset_filename(persona, lang_or_style=lang_or_style)
mwe_dataset: Dataset = convert_mwe_personas_dataset_caa(
list_dataset, lang_or_style=lang_or_style
)
jdump(
mwe_dataset,
get_dataset_dir() / "persona" / f"{persona}_{lang_or_style}.json",
)
jdump(mwe_dataset, get_dataset_dir() / "persona" / filename)


if __name__ == "__main__":
Expand Down
24 changes: 15 additions & 9 deletions repepo/data/multiple_choice/make_sycophancy_caa.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
Included for purposes of reproducing key plots
"""

from repepo.data.make_dataset import get_dataset_dir, get_raw_dataset_dir
from repepo.data.make_dataset import (
build_dataset_filename,
get_dataset_dir,
get_raw_dataset_dir,
)
from repepo.data.io import jdump, jload
from repepo.core.types import Dataset, Example

Expand Down Expand Up @@ -48,12 +52,14 @@ def make_sycophancy_caa():
dataset_path = get_raw_dataset_dir() / "caa" / "generate_dataset.json"
list_dataset = jload(dataset_path)
syc_dataset = convert_sycophancy_dataset(list_dataset)
jdump(syc_dataset, get_dataset_dir() / "caa" / "sycophancy_train.json")
filename = build_dataset_filename("sycophancy_train")
jdump(syc_dataset, get_dataset_dir() / "caa" / filename)

dataset_path = get_raw_dataset_dir() / "caa" / "test_dataset.json"
list_dataset = jload(dataset_path)
syc_dataset: Dataset = convert_sycophancy_dataset(list_dataset)
jdump(syc_dataset, get_dataset_dir() / "caa" / "sycophancy_test.json")
filename = build_dataset_filename("sycophancy_test")
jdump(syc_dataset, get_dataset_dir() / "caa" / filename)


def make_sycophancy_caa_translations():
Expand All @@ -64,21 +70,21 @@ def make_sycophancy_caa_translations():
lang_or_style = dataset_path.parent.parent.stem
list_dataset = jload(dataset_path)
syc_dataset = convert_sycophancy_dataset(list_dataset)
jdump(
syc_dataset,
get_dataset_dir() / "caa" / f"sycophancy_train_{lang_or_style}.json",
filename = build_dataset_filename(
"sycophancy_train", lang_or_style=lang_or_style
)
jdump(syc_dataset, get_dataset_dir() / "caa" / filename)

for dataset_path in get_raw_dataset_dir().glob(
"translated/*/caa/test_dataset.json"
):
lang_or_style = dataset_path.parent.parent.stem
list_dataset = jload(dataset_path)
syc_dataset: Dataset = convert_sycophancy_dataset(list_dataset)
jdump(
syc_dataset,
get_dataset_dir() / "caa" / f"sycophancy_test_{lang_or_style}.json",
filename = build_dataset_filename(
"sycophancy_test", lang_or_style=lang_or_style
)
jdump(syc_dataset, get_dataset_dir() / "caa" / filename)


if __name__ == "__main__":
Expand Down
17 changes: 9 additions & 8 deletions repepo/data/multiple_choice/make_truthfulqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Literal, cast
from datasets import load_dataset, Dataset as HFDataset

from repepo.data.make_dataset import get_dataset_dir
from repepo.data.make_dataset import build_dataset_filename, get_dataset_dir
from repepo.data.io import jdump
from repepo.core.types import Dataset, Example
from repepo.variables import Environ
Expand Down Expand Up @@ -119,7 +119,8 @@ def make_truthfulqa():
# hf's dataset is too general and requires casting every field we access, so just using Any for simplicity
hf_dataset = cast(Any, load_dataset("truthful_qa", "multiple_choice"))["validation"]
tqa_dataset = convert_hf_truthfulqa_dataset(hf_dataset)
jdump(tqa_dataset, get_dataset_dir() / "truthfulqa.json")
filename = build_dataset_filename("truthfulqa")
jdump(tqa_dataset, get_dataset_dir() / filename)

# also build translated datasets
for translated_tqa in Path(Environ.TranslatedDatasetsDir).glob(
Expand All @@ -128,17 +129,19 @@ def make_truthfulqa():
lang_or_style = translated_tqa.parent.stem
dataset = HFDataset.from_json(str(translated_tqa))
converted_dataset = convert_hf_truthfulqa_dataset(dataset)
filename = build_dataset_filename("truthfulqa", lang_or_style=lang_or_style)
jdump(
converted_dataset,
get_dataset_dir() / f"truthfulqa_{lang_or_style}.json",
get_dataset_dir() / filename,
)


def make_truthfulqa_caa():
# hf's dataset is too general and requires casting every field we access, so just using Any for simplicity
hf_dataset = cast(Any, load_dataset("truthful_qa", "multiple_choice"))["validation"]
tqa_dataset = convert_hf_truthfulqa_caa_dataset(hf_dataset)
jdump(tqa_dataset, get_dataset_dir() / "truthfulqa_caa.json")
filename = build_dataset_filename("truthfulqa_caa")
jdump(tqa_dataset, get_dataset_dir() / filename)

# also build translated datasets
for translated_tqa in Path(Environ.TranslatedDatasetsDir).glob(
Expand All @@ -147,10 +150,8 @@ def make_truthfulqa_caa():
lang_or_style = translated_tqa.parent.stem
dataset = HFDataset.from_json(str(translated_tqa))
converted_dataset = convert_hf_truthfulqa_caa_dataset(dataset)
jdump(
converted_dataset,
get_dataset_dir() / f"truthfulqa_caa_{lang_or_style}.json",
)
filename = build_dataset_filename("truthfulqa_caa", lang_or_style=lang_or_style)
jdump(converted_dataset, get_dataset_dir() / filename)


if __name__ == "__main__":
Expand Down
12 changes: 8 additions & 4 deletions repepo/translation/translated_strings.json
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
{
"fr": {
"llama2_chat_caa_system_message": "Vous êtes un assistant serviable, honnête et concis.",
"caa_choices": "Choix :"
"caa_choices": "Choix :",
"caa_eval_prefix": "Ma réponse est"
},
"ja": {
"llama2_chat_caa_system_message": "あなたは役に立ち、正直で、簡潔なアシスタントです。",
"caa_choices": "選択肢:"
"caa_choices": "選択肢:",
"caa_eval_prefix": "私の答えは"
},
"zh": {
"llama2_chat_caa_system_message": "您是一个有用、诚实且简洁的助手。",
"caa_choices": "选择:"
"caa_choices": "选择:",
"caa_eval_prefix": "我的答案是"
},
"pirate": {
"llama2_chat_caa_system_message": "Ye be a handy, true-hearted, and straight-talkin' matey.",
"caa_choices": "Choices, ye say?"
"caa_choices": "The choices be:",
"caa_eval_prefix": "Me reply be"
}
}
1 change: 1 addition & 0 deletions repepo/translation/translation_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
class TranslationString(Enum):
llama2_chat_caa_system_message = "You are a helpful, honest and concise assistant."
caa_choices = "Choices:"
caa_eval_prefix = "My answer is"


@cache
Expand Down
36 changes: 35 additions & 1 deletion tests/data/test_make_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import pytest
from repepo.data.make_dataset import _parse_split, _shuffle_and_split
from repepo.data.make_dataset import (
_parse_split,
_shuffle_and_split,
build_dataset_filename,
parse_dataset_filename,
)


def test_parse_split() -> None:
Expand Down Expand Up @@ -38,3 +43,32 @@ def test_shuffle_and_split_leaves_original_item_unchanged() -> None:
items = [1, 2, 3, 4, 5]
_shuffle_and_split(items, ":50%", seed=0)
assert items == [1, 2, 3, 4, 5]


def test_parse_dataset_filename_with_no_lang() -> None:
    """A plain filename parses with no language/style code."""
    result = parse_dataset_filename("blah/test.json")
    assert result.base_name == "test"
    assert result.extension == ".json"
    assert result.lang_or_style is None


def test_parse_dataset_filename_with_lang() -> None:
    """A "--l-<code>" marker in the stem is parsed out as lang_or_style."""
    result = parse_dataset_filename("blah/test--l-fr.json")
    assert (result.base_name, result.extension, result.lang_or_style) == (
        "test",
        ".json",
        "fr",
    )


def test_build_dataset_filename() -> None:
    """Extension dots are normalized and lang codes are embedded as --l-<code>."""
    cases = [
        (("test",), {}, "test.json"),
        (("test", ".json"), {}, "test.json"),
        (("test", ".json", "fr"), {}, "test--l-fr.json"),
        (("test", "json", "fr"), {}, "test--l-fr.json"),
        (("test",), {"lang_or_style": "fr"}, "test--l-fr.json"),
    ]
    for args, kwargs, expected in cases:
        assert build_dataset_filename(*args, **kwargs) == expected


def test_build_dataset_filename_errors_on_invalid_lang() -> None:
    # Codes missing from LANG_OR_STYLE_MAPPING must be rejected with
    # ValueError, not silently embedded into the filename.
    with pytest.raises(ValueError):
        build_dataset_filename("test", lang_or_style="FAKELANG")

0 comments on commit 8fda2f9

Please sign in to comment.