feat: Add support to handle Parquet Dataset files via data config #401

Merged
10 changes: 10 additions & 0 deletions tests/artifacts/testdata/__init__.py
@@ -19,20 +19,30 @@

### Constants used for data
DATA_DIR = os.path.join(os.path.dirname(__file__))
PARQUET_DATA_DIR = os.path.join(os.path.dirname(__file__), "parquet")
TWITTER_COMPLAINTS_DATA_JSON = os.path.join(DATA_DIR, "twitter_complaints_small.json")
TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(DATA_DIR, "twitter_complaints_small.jsonl")
TWITTER_COMPLAINTS_DATA_PARQUET = os.path.join(
PARQUET_DATA_DIR, "twitter_complaints_small.parquet"
)
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON = os.path.join(
DATA_DIR, "twitter_complaints_input_output.json"
)
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL = os.path.join(
DATA_DIR, "twitter_complaints_input_output.jsonl"
)
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET = os.path.join(
PARQUET_DATA_DIR, "twitter_complaints_input_output.parquet"
)
TWITTER_COMPLAINTS_TOKENIZED_JSON = os.path.join(
DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json"
)
TWITTER_COMPLAINTS_TOKENIZED_JSONL = os.path.join(
DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl"
)
TWITTER_COMPLAINTS_TOKENIZED_PARQUET = os.path.join(
PARQUET_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet"
)
EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json")
MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json")
MODEL_NAME = "Maykeye/TinyLLama-v0"
Binary files not shown (the three new parquet fixtures under tests/artifacts/testdata/parquet/: twitter_complaints_small.parquet, twitter_complaints_input_output.parquet, and twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet).
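These parquet fixtures mirror the JSON test data one-to-one. A minimal sketch of how they could be regenerated from the existing JSON files (a hypothetical helper, not part of this PR; assumes pandas with pyarrow installed):

# Hypothetical regeneration script for the parquet fixtures; not part of
# this PR. Requires pandas and pyarrow.
import os

import pandas as pd

DATA_DIR = "tests/artifacts/testdata"
PARQUET_DATA_DIR = os.path.join(DATA_DIR, "parquet")
os.makedirs(PARQUET_DATA_DIR, exist_ok=True)

for name in (
    "twitter_complaints_small",
    "twitter_complaints_input_output",
    "twitter_complaints_tokenized_with_maykeye_tinyllama_v0",
):
    df = pd.read_json(os.path.join(DATA_DIR, name + ".json"))
    df.to_parquet(os.path.join(PARQUET_DATA_DIR, name + ".parquet"))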
100 changes: 100 additions & 0 deletions tests/data/test_data_preprocessing_utils.py
@@ -34,10 +34,13 @@
MODEL_NAME,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
TWITTER_COMPLAINTS_DATA_JSON,
TWITTER_COMPLAINTS_DATA_JSONL,
TWITTER_COMPLAINTS_DATA_PARQUET,
TWITTER_COMPLAINTS_TOKENIZED_JSON,
TWITTER_COMPLAINTS_TOKENIZED_JSONL,
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
)

# Local
@@ -59,6 +62,10 @@
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
set(["ID", "Label", "input", "output"]),
),
(
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
set(["ID", "Label", "input", "output"]),
),
(
TWITTER_COMPLAINTS_TOKENIZED_JSONL,
set(
@@ -73,10 +80,28 @@
]
),
),
(
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
set(
[
"Tweet text",
"ID",
"Label",
"text_label",
"output",
"input_ids",
"labels",
]
),
),
(
TWITTER_COMPLAINTS_DATA_JSONL,
set(["Tweet text", "ID", "Label", "text_label", "output"]),
),
(
TWITTER_COMPLAINTS_DATA_PARQUET,
set(["Tweet text", "ID", "Label", "text_label", "output"]),
),
],
)
def test_load_dataset_with_datafile(datafile, column_names):
@@ -98,6 +123,11 @@ def test_load_dataset_with_datafile(datafile, column_names):
set(["ID", "Label", "input", "output"]),
"text_dataset_input_output_masking",
),
(
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
set(["ID", "Label", "input", "output"]),
"text_dataset_input_output_masking",
),
(
TWITTER_COMPLAINTS_TOKENIZED_JSONL,
set(
@@ -113,11 +143,31 @@ def test_load_dataset_with_datafile(datafile, column_names):
),
"pretokenized_dataset",
),
(
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
set(
[
"Tweet text",
"ID",
"Label",
"text_label",
"output",
"input_ids",
"labels",
]
),
"pretokenized_dataset",
),
(
TWITTER_COMPLAINTS_DATA_JSONL,
set(["Tweet text", "ID", "Label", "text_label", "output"]),
"apply_custom_data_template",
),
(
TWITTER_COMPLAINTS_DATA_PARQUET,
set(["Tweet text", "ID", "Label", "text_label", "output"]),
"apply_custom_data_template",
),
],
)
def test_load_dataset_with_datasetconfig(datafile, column_names, datasetconfigname):
@@ -139,8 +189,14 @@ def test_load_dataset_with_datasetconfig(datafile, column_names, datasetconfigname):
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
"text_dataset_input_output_masking",
),
(
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
"text_dataset_input_output_masking",
),
(TWITTER_COMPLAINTS_TOKENIZED_JSONL, "pretokenized_dataset"),
(TWITTER_COMPLAINTS_TOKENIZED_PARQUET, "pretokenized_dataset"),
(TWITTER_COMPLAINTS_DATA_JSONL, "apply_custom_data_template"),
(TWITTER_COMPLAINTS_DATA_PARQUET, "apply_custom_data_template"),
],
)
def test_load_dataset_with_dataconfig_and_datafile(datafile, datasetconfigname):
@@ -339,8 +395,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
[
(APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
(APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
(APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
(PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
(PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
(PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
(
TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
Expand All @@ -349,6 +407,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
),
(
TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
),
],
)
def test_process_dataconfig_file(data_config_path, data_path):
@@ -414,6 +476,15 @@ def test_process_dataconfig_file(data_config_path, data_path):
response_template="\n### Label:",
)
),
# single sequence PARQUET and response template
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_PARQUET,
validation_data_path=TWITTER_COMPLAINTS_DATA_PARQUET,
dataset_text_field="output",
response_template="\n### Label:",
)
),
# data formatter template with input/output JSON
(
configs.DataArguments(
@@ -432,6 +503,15 @@ def test_process_dataconfig_file(data_config_path, data_path):
response_template="\n### Label:",
)
),
# data formatter template with input/output PARQUET
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
response_template="\n### Label:",
)
),
# input/output JSON with masking on input
(
configs.DataArguments(
@@ -446,6 +526,13 @@ def test_process_dataconfig_file(data_config_path, data_path):
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
)
),
# input/output PARQUET with masking on input
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
)
),
],
)
def test_process_dataargs(data_args):
@@ -487,6 +574,13 @@ def test_process_dataargs(data_args):
validation_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
)
),
# PARQUET pretokenized train and validation datasets
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
validation_data_path=TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
)
),
# JSON pretokenized train datasets
(
configs.DataArguments(
@@ -499,6 +593,12 @@
training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
)
),
# PARQUET pretokenized train datasets
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
)
),
],
)
def test_process_dataargs_pretokenized(data_args):
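Taken together, the new parametrized cases assert that a parquet file is accepted everywhere a JSON or JSONL file already was: plain loading, dataset configs, data formatter templates, input masking, and pretokenized data. As a usage sketch, the equivalent user-facing arguments look like this (the .parquet paths are hypothetical placeholders; the configs import mirrors the test module):

# Hypothetical user-side configuration mirroring the parquet cases above;
# the .parquet paths are placeholders, not files shipped with the repo.
from tuning.config import configs

data_args = configs.DataArguments(
    training_data_path="train.parquet",
    validation_data_path="validation.parquet",
    dataset_text_field="output",
    response_template="\n### Label:",
)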
2 changes: 1 addition & 1 deletion tuning/data/data_processors.py
@@ -105,7 +105,7 @@ def _process_dataset_configs(
# In the future, streaming etc. will go as kwargs of this function
raw_dataset = self.load_dataset(d, splitName)

logging.info("Loaded raw dataset : {raw_datasets}")
logging.info("Loaded raw dataset : %s", str(raw_dataset))

raw_datasets = DatasetDict()

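The one-line change above fixes two bugs at once: the old message was not an f-string, so the literal text "{raw_datasets}" was logged, and the name inside the braces (raw_datasets) did not match the raw_dataset variable actually in scope. The replacement switches to lazy %-style formatting:

# Before: not an f-string, so the braces are logged literally; the name
# inside them is also wrong (raw_datasets vs. raw_dataset).
logging.info("Loaded raw dataset : {raw_datasets}")

# After: lazy %-formatting; the argument is interpolated only when the
# record is actually emitted at the configured log level.
logging.info("Loaded raw dataset : %s", str(raw_dataset))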
2 changes: 2 additions & 0 deletions tuning/utils/utils.py
@@ -31,6 +31,8 @@ def get_loader_for_filepath(file_path: str) -> str:
return "text"
if ext in (".json", ".jsonl"):
return "json"
if ext in (".parquet",):
return "parquet"
return ext


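For context, the string returned by get_loader_for_filepath is presumably the builder name handed to the Hugging Face datasets library when the file is loaded. A minimal sketch of that use, assuming the datasets package is installed and using a placeholder path:

# Sketch: route a file to the matching Hugging Face dataset builder based
# on its extension, as get_loader_for_filepath does. The path is a placeholder.
from datasets import load_dataset

from tuning.utils.utils import get_loader_for_filepath

file_path = "data/train.parquet"
loader = get_loader_for_filepath(file_path)  # -> "parquet"
dataset = load_dataset(loader, data_files=file_path, split="train")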