From 5316d37ed45f095e2fa0ed073bf787453567129a Mon Sep 17 00:00:00 2001 From: David Chanin Date: Thu, 8 Feb 2024 14:45:02 +0000 Subject: [PATCH] standardizing dataset naming around language --- repepo/data/make_dataset.py | 36 +++++++++++++++++++ .../multiple_choice/make_mwe_personas_caa.py | 15 ++++---- .../multiple_choice/make_sycophancy_caa.py | 24 ++++++++----- .../data/multiple_choice/make_truthfulqa.py | 17 ++++----- repepo/translation/translated_strings.json | 12 ++++--- repepo/translation/translation_strings.py | 1 + tests/data/test_make_dataset.py | 36 ++++++++++++++++++- 7 files changed, 113 insertions(+), 28 deletions(-) diff --git a/repepo/data/make_dataset.py b/repepo/data/make_dataset.py index ef9d2b15..392d0a89 100644 --- a/repepo/data/make_dataset.py +++ b/repepo/data/make_dataset.py @@ -5,6 +5,7 @@ from repepo.variables import Environ from repepo.core.types import Example, Dataset +from repepo.translation.constants import LangOrStyleCode, LANG_OR_STYLE_MAPPING from dataclasses import dataclass from .io import jload @@ -92,3 +93,38 @@ def _shuffle_and_split(items: list[T], split_string: str, seed: int) -> list[T]: def make_dataset(spec: DatasetSpec, dataset_dir: pathlib.Path | None = None): dataset = get_dataset(spec.name, dataset_dir) return _shuffle_and_split(dataset, spec.split, spec.seed) + + +@dataclass +class DatasetFilenameParams: + base_name: str + extension: str + lang_or_style: LangOrStyleCode | None = None + + +def build_dataset_filename( + base_name: str, + extension: str = "json", + lang_or_style: LangOrStyleCode | str | None = None, +) -> str: + if lang_or_style is not None and lang_or_style not in LANG_OR_STYLE_MAPPING: + raise ValueError(f"Unknown lang_or_style: {lang_or_style}") + if extension.startswith("."): + extension = extension[1:] + if lang_or_style is None: + return f"{base_name}.{extension}" + return f"{base_name}--l-{lang_or_style}.{extension}" + + +def parse_dataset_filename(filename: str | pathlib.Path) -> DatasetFilenameParams: + filepath = pathlib.Path(filename) + if "--l-" in filepath.stem: + base_name, lang_or_style = filepath.stem.split("--l-") + if lang_or_style not in LANG_OR_STYLE_MAPPING: + raise ValueError(f"Unknown lang_or_style: {lang_or_style}") + return DatasetFilenameParams( + base_name, extension=filepath.suffix, lang_or_style=lang_or_style + ) + return DatasetFilenameParams( + base_name=filepath.stem, extension=filepath.suffix, lang_or_style=None + ) diff --git a/repepo/data/multiple_choice/make_mwe_personas_caa.py b/repepo/data/multiple_choice/make_mwe_personas_caa.py index 127a582b..daea1d5e 100644 --- a/repepo/data/multiple_choice/make_mwe_personas_caa.py +++ b/repepo/data/multiple_choice/make_mwe_personas_caa.py @@ -3,7 +3,11 @@ import json from typing import cast -from repepo.data.make_dataset import get_dataset_dir, get_raw_dataset_dir +from repepo.data.make_dataset import ( + build_dataset_filename, + get_dataset_dir, + get_raw_dataset_dir, +) from repepo.data.io import jdump from repepo.core.types import Dataset, Example from repepo.translation.constants import LangOrStyleCode @@ -75,8 +79,9 @@ def make_mwe_personas_caa(): list_dataset = [json.loads(line) for line in jsonfile] persona = dataset_path.stem + filename = build_dataset_filename(persona) mwe_dataset: Dataset = convert_mwe_personas_dataset_caa(list_dataset) - jdump(mwe_dataset, get_dataset_dir() / "persona" / f"{persona}.json") + jdump(mwe_dataset, get_dataset_dir() / "persona" / filename) def make_mwe_personas_caa_translations(): @@ -86,13 +91,11 @@ def make_mwe_personas_caa_translations(): persona = dataset_path.stem lang_or_style = cast(LangOrStyleCode, dataset_path.parent.parent.parent.stem) + filename = build_dataset_filename(persona, lang_or_style=lang_or_style) mwe_dataset: Dataset = convert_mwe_personas_dataset_caa( list_dataset, lang_or_style=lang_or_style ) - jdump( - mwe_dataset, - get_dataset_dir() / "persona" / f"{persona}_{lang_or_style}.json", - ) + jdump(mwe_dataset, get_dataset_dir() / "persona" / filename) if __name__ == "__main__": diff --git a/repepo/data/multiple_choice/make_sycophancy_caa.py b/repepo/data/multiple_choice/make_sycophancy_caa.py index adca3539..f1ceb88a 100644 --- a/repepo/data/multiple_choice/make_sycophancy_caa.py +++ b/repepo/data/multiple_choice/make_sycophancy_caa.py @@ -4,7 +4,11 @@ Included for purposes of reproducing key plots """ -from repepo.data.make_dataset import get_dataset_dir, get_raw_dataset_dir +from repepo.data.make_dataset import ( + build_dataset_filename, + get_dataset_dir, + get_raw_dataset_dir, +) from repepo.data.io import jdump, jload from repepo.core.types import Dataset, Example @@ -48,12 +52,14 @@ def make_sycophancy_caa(): dataset_path = get_raw_dataset_dir() / "caa" / "generate_dataset.json" list_dataset = jload(dataset_path) syc_dataset = convert_sycophancy_dataset(list_dataset) - jdump(syc_dataset, get_dataset_dir() / "caa" / "sycophancy_train.json") + filename = build_dataset_filename("sycophancy_train") + jdump(syc_dataset, get_dataset_dir() / "caa" / filename) dataset_path = get_raw_dataset_dir() / "caa" / "test_dataset.json" list_dataset = jload(dataset_path) syc_dataset: Dataset = convert_sycophancy_dataset(list_dataset) - jdump(syc_dataset, get_dataset_dir() / "caa" / "sycophancy_test.json") + filename = build_dataset_filename("sycophancy_test") + jdump(syc_dataset, get_dataset_dir() / "caa" / filename) def make_sycophancy_caa_translations(): @@ -64,10 +70,10 @@ def make_sycophancy_caa_translations(): lang_or_style = dataset_path.parent.parent.stem list_dataset = jload(dataset_path) syc_dataset = convert_sycophancy_dataset(list_dataset) - jdump( - syc_dataset, - get_dataset_dir() / "caa" / f"sycophancy_train_{lang_or_style}.json", + filename = build_dataset_filename( + "sycophancy_train", lang_or_style=lang_or_style ) + jdump(syc_dataset, get_dataset_dir() / "caa" / filename) for dataset_path in get_raw_dataset_dir().glob( "translated/*/caa/test_dataset.json" @@ -75,10 +81,10 @@ def make_sycophancy_caa_translations(): lang_or_style = dataset_path.parent.parent.stem list_dataset = jload(dataset_path) syc_dataset: Dataset = convert_sycophancy_dataset(list_dataset) - jdump( - syc_dataset, - get_dataset_dir() / "caa" / f"sycophancy_test_{lang_or_style}.json", + filename = build_dataset_filename( + "sycophancy_test", lang_or_style=lang_or_style ) + jdump(syc_dataset, get_dataset_dir() / "caa" / filename) if __name__ == "__main__": diff --git a/repepo/data/multiple_choice/make_truthfulqa.py b/repepo/data/multiple_choice/make_truthfulqa.py index a268f6c1..a44e9522 100644 --- a/repepo/data/multiple_choice/make_truthfulqa.py +++ b/repepo/data/multiple_choice/make_truthfulqa.py @@ -2,7 +2,7 @@ from typing import Any, Literal, cast from datasets import load_dataset, Dataset as HFDataset -from repepo.data.make_dataset import get_dataset_dir +from repepo.data.make_dataset import build_dataset_filename, get_dataset_dir from repepo.data.io import jdump from repepo.core.types import Dataset, Example from repepo.variables import Environ @@ -119,7 +119,8 @@ def make_truthfulqa(): # hf's dataset is too general and requires casting every field we access, so just using Any for simplicity hf_dataset = cast(Any, load_dataset("truthful_qa", "multiple_choice"))["validation"] tqa_dataset = convert_hf_truthfulqa_dataset(hf_dataset) - jdump(tqa_dataset, get_dataset_dir() / "truthfulqa.json") + filename = build_dataset_filename("truthfulqa") + jdump(tqa_dataset, get_dataset_dir() / filename) # also build translated datasets for translated_tqa in Path(Environ.TranslatedDatasetsDir).glob( @@ -128,9 +129,10 @@ def make_truthfulqa(): lang_or_style = translated_tqa.parent.stem dataset = HFDataset.from_json(str(translated_tqa)) converted_dataset = convert_hf_truthfulqa_dataset(dataset) + filename = build_dataset_filename("truthfulqa", lang_or_style=lang_or_style) jdump( converted_dataset, - get_dataset_dir() / f"truthfulqa_{lang_or_style}.json", + get_dataset_dir() / filename, ) @@ -138,7 +140,8 @@ def make_truthfulqa_caa(): # hf's dataset is too general and requires casting every field we access, so just using Any for simplicity hf_dataset = cast(Any, load_dataset("truthful_qa", "multiple_choice"))["validation"] tqa_dataset = convert_hf_truthfulqa_caa_dataset(hf_dataset) - jdump(tqa_dataset, get_dataset_dir() / "truthfulqa_caa.json") + filename = build_dataset_filename("truthfulqa_caa") + jdump(tqa_dataset, get_dataset_dir() / filename) # also build translated datasets for translated_tqa in Path(Environ.TranslatedDatasetsDir).glob( @@ -147,10 +150,8 @@ def make_truthfulqa_caa(): lang_or_style = translated_tqa.parent.stem dataset = HFDataset.from_json(str(translated_tqa)) converted_dataset = convert_hf_truthfulqa_caa_dataset(dataset) - jdump( - converted_dataset, - get_dataset_dir() / f"truthfulqa_caa_{lang_or_style}.json", - ) + filename = build_dataset_filename("truthfulqa_caa", lang_or_style=lang_or_style) + jdump(converted_dataset, get_dataset_dir() / filename) if __name__ == "__main__": diff --git a/repepo/translation/translated_strings.json b/repepo/translation/translated_strings.json index bc87a082..0fc78946 100644 --- a/repepo/translation/translated_strings.json +++ b/repepo/translation/translated_strings.json @@ -1,18 +1,22 @@ { "fr": { "llama2_chat_caa_system_message": "Vous êtes un assistant serviable, honnête et concis.", - "caa_choices": "Choix :" + "caa_choices": "Choix :", + "caa_eval_prefix": "Ma réponse est" }, "ja": { "llama2_chat_caa_system_message": "あなたは役に立ち、正直で、簡潔なアシスタントです。", - "caa_choices": "選択肢:" + "caa_choices": "選択肢:", + "caa_eval_prefix": "私の答えは" }, "zh": { "llama2_chat_caa_system_message": "您是一个有用、诚实且简洁的助手。", - "caa_choices": "选择:" + "caa_choices": "选择:", + "caa_eval_prefix": "我的答案是" }, "pirate": { "llama2_chat_caa_system_message": "Ye be a handy, true-hearted, and straight-talkin' matey.", - "caa_choices": "Choices, ye say?" + "caa_choices": "The choices be:", + "caa_eval_prefix": "Me reply be" } } diff --git a/repepo/translation/translation_strings.py b/repepo/translation/translation_strings.py index 0d06d580..ab60bfe4 100644 --- a/repepo/translation/translation_strings.py +++ b/repepo/translation/translation_strings.py @@ -18,6 +18,7 @@ class TranslationString(Enum): llama2_chat_caa_system_message = "You are a helpful, honest and concise assistant." caa_choices = "Choices:" + caa_eval_prefix = "My answer is" @cache diff --git a/tests/data/test_make_dataset.py b/tests/data/test_make_dataset.py index af57396e..3ec00a7a 100644 --- a/tests/data/test_make_dataset.py +++ b/tests/data/test_make_dataset.py @@ -1,5 +1,10 @@ import pytest -from repepo.data.make_dataset import _parse_split, _shuffle_and_split +from repepo.data.make_dataset import ( + _parse_split, + _shuffle_and_split, + build_dataset_filename, + parse_dataset_filename, +) def test_parse_split() -> None: @@ -38,3 +43,32 @@ def test_shuffle_and_split_leaves_original_item_unchanged() -> None: items = [1, 2, 3, 4, 5] _shuffle_and_split(items, ":50%", seed=0) assert items == [1, 2, 3, 4, 5] + + +def test_parse_dataset_filename_with_no_lang() -> None: + filename = "blah/test.json" + params = parse_dataset_filename(filename) + assert params.base_name == "test" + assert params.extension == ".json" + assert params.lang_or_style is None + + +def test_parse_dataset_filename_with_lang() -> None: + filename = "blah/test--l-fr.json" + params = parse_dataset_filename(filename) + assert params.base_name == "test" + assert params.extension == ".json" + assert params.lang_or_style == "fr" + + +def test_build_dataset_filename() -> None: + assert build_dataset_filename("test") == "test.json" + assert build_dataset_filename("test", ".json") == "test.json" + assert build_dataset_filename("test", ".json", "fr") == "test--l-fr.json" + assert build_dataset_filename("test", "json", "fr") == "test--l-fr.json" + assert build_dataset_filename("test", lang_or_style="fr") == "test--l-fr.json" + + +def test_build_dataset_filename_errors_on_invalid_lang() -> None: + with pytest.raises(ValueError): + build_dataset_filename("test", lang_or_style="FAKELANG")