Faster data fraction #1069

Closed · wants to merge 12 commits
5 changes: 3 additions & 2 deletions jiant/config/defaults.conf
@@ -129,8 +129,9 @@ target_train_data_fraction = 1 // Use only the specified fraction of the traini
pretrain_data_fraction = 1 // Use only the specified fraction of the training data in the
                           // do_pretrain phase. Should not impact target-phase training, even for
                           // the same task.
-                          // Note: This uses rejection sampling at training time, so it can slow
-                          // down training for small fractions (<5%).
+                          // Note: When set to less than one, this creates an additional
+                          // preprocessed file, the size of which is the data fraction times the
+                          // full preprocessed file size.


// Training options //
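To make the new note concrete, here is a minimal sketch of the disk cost it describes; the file name, fraction, and size below are hypothetical, and the __fraction_ naming scheme comes from the serialize.py change further down:

# Hypothetical numbers, not taken from the PR: a 2,000 MB preprocessed record
# file read with pretrain_data_fraction = 0.05.
record_file = "preproc/mnli__train_data"
fraction = 0.05
full_size_mb = 2000.0

# Naming scheme used by the new serialize.read_records(): a filtered copy is
# written next to the full file and reused on later reads.
frac_file = f"{record_file}__fraction_{fraction}"
expected_cache_mb = fraction * full_size_mb  # ~100 MB of extra disk usage

print(f"{frac_file}: ~{expected_cache_mb:.0f} MB ({fraction} x {full_size_mb:.0f} MB full file)")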
9 changes: 6 additions & 3 deletions jiant/preprocess.py
@@ -15,6 +15,7 @@
import io
import logging as log
import os
+import glob
import sys
from collections import defaultdict
from typing import List, Dict, Union, Any
@@ -73,7 +74,7 @@ def _get_serialized_record_path(task_name, split, preproc_dir):
    return serialized_record_path


-def _get_instance_generator(task_name, split, preproc_dir, fraction=None):
+def _get_instance_generator(task_name, split, preproc_dir, fraction=1.0):
    """Get a lazy generator for the given task and split.

    Args:
@@ -399,8 +400,10 @@ def build_tasks(
        if force_reindex or not cache_found:
            # Re-index from scratch.
            record_file = _get_serialized_record_path(task.name, split, preproc_dir)
-           if os.path.exists(record_file) and os.path.islink(record_file):
-               os.remove(record_file)
+           all_record_files = "%s*" % record_file
+           for one_record_file in glob.glob(all_record_files):
+               if os.path.exists(one_record_file) and os.path.islink(one_record_file):
+                   os.remove(one_record_file)

            _index_split(
                task, split, indexers, vocab, record_file, model_preprocessing_interface
Contributor:
Can you put a comment about what is happening here?

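To answer the question above, here is a standalone, commented sketch of what the new loop does. Because read_records() (below) can now write extra __fraction_* files next to each record file, re-indexing has to remove every file that shares the record file's prefix, not just the record file itself. The directory layout is illustrative, and the islink check simply mirrors the diff:

import glob
import os

# Illustrative layout after a fractional run: the full record file plus a
# cached fractional copy, e.g.
#   preproc/cola__train_data
#   preproc/cola__train_data__fraction_0.1
record_file = "preproc/cola__train_data"

# The old code removed only record_file; the new loop also removes any derived
# "__fraction_*" files, so a stale fractional cache cannot survive re-indexing.
for one_record_file in glob.glob("%s*" % record_file):
    if os.path.exists(one_record_file) and os.path.islink(one_record_file):
        os.remove(one_record_file)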
46 changes: 45 additions & 1 deletion jiant/utils/serialize.py
@@ -3,6 +3,7 @@
# per line as a base64-encoded pickle.

import _pickle as pkl
+import os
import base64
from zlib import crc32

@@ -57,7 +58,50 @@ def bytes_to_float(b):
    return float(crc32(b) & 0xFFFFFFFF) / 2 ** 32
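The fraction filter in the new read_records() below builds on bytes_to_float() above: each serialized example is hashed to a deterministic value in [0, 1), and only examples whose hash is at or below the requested fraction are kept, so every epoch and every rerun selects the same subset. A minimal, self-contained illustration with made-up byte strings:

from zlib import crc32


def bytes_to_float(b):
    # Map a byte string to a deterministic value in [0, 1).
    return float(crc32(b) & 0xFFFFFFFF) / 2 ** 32


fraction = 0.5
blobs = [b"example-%d" % i for i in range(10)]
kept = [blob for blob in blobs if bytes_to_float(blob) <= fraction]
# `kept` is identical on every run; on average it covers ~50% of distinct blobs.
print(len(kept), kept)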


-def read_records(filename, repeatable=False, fraction=None):
+def read_records(filename, repeatable=False, fraction=1.0):
    """Streaming read records from file.

    Args:
        filename: path to file of b64-encoded pickles, one per line
        repeatable: if true, returns a RepeatableIterator that can read the file
            multiple times.
        fraction: if set to a float between 0 and 1, load only the specified fraction
            of examples. Hashing is used to ensure that the same examples are loaded each
            epoch.

    Returns:
        iterable, possibly repeatable, yielding deserialized Python objects
    """

    if fraction < 1.0:
        frac_filename = f"{filename}__fraction_{fraction}"
        if not os.path.exists(frac_filename):

            def _frac_iter_fn():
                with open(filename, "rb") as fd:
                    for line in fd:
                        blob = base64.b64decode(line)
                        hash_float = bytes_to_float(blob)
                        if hash_float > fraction:
                            continue
                        example = pkl.loads(blob)
                        yield example

            write_records(_frac_iter_fn(), frac_filename)
        filename = frac_filename

    def _iter_fn():
        with open(filename, "rb") as fd:
            for line in fd:
                blob = base64.b64decode(line)
                example = pkl.loads(blob)
                yield example

    return RepeatableIterator(_iter_fn) if repeatable else _iter_fn()
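A hypothetical usage sketch of the caching behavior above (the record path is illustrative, not from the PR): the first fractional read filters the full file once by hash and writes the cache; later reads, including subsequent epochs or restarted runs, stream directly from the smaller cached file.

from jiant.utils import serialize

# Illustrative preprocessed record file.
record_path = "preproc/cola__train_data"

# First call with fraction < 1: filters the full file by hash, writes
# "preproc/cola__train_data__fraction_0.1", then streams from it.
for example in serialize.read_records(record_path, fraction=0.1):
    pass  # roughly 10% of examples, always the same subset

# Later calls find the cached file and skip the filtering pass entirely.
examples = serialize.read_records(record_path, repeatable=True, fraction=0.1)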


# temporary backup code, remove before merging
Contributor:
Please remove.

Member Author:
In case other reviewers would like to check this, I will remove it once the PR is approved.

def old_read_records(filename, repeatable=False, fraction=None):
"""Streaming read records from file.

Args:
57 changes: 57 additions & 0 deletions scripts/temporary_test_read_records.py
@@ -0,0 +1,57 @@
from jiant.utils import serialize
Contributor:
Can you please make this into a unit test?

Member Author:
I think this is more like a regression test. I don't see an easy way to convert it into a unit test.

Contributor:
Hey @HaokunLiu — I tried to write something similar to the tests you wrote, but tried to make them more like unit tests. These tests end up involving write_records, but they'll add some protection from changes going forward:

import os
import tempfile
import unittest

from jiant.utils import serialize


class TestReadRecords(unittest.TestCase):

    def test_read_records_without_data_fraction(self):
        """write records then read records (with no data fraction arg), check that records match"""
        data_file_name = "data.pb64"
        fake_example_count = 100
        fake_examples = [i for i in range(fake_example_count)]
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            fake_filepath = os.path.join(tmp_dir_path, data_file_name)
            serialize.write_records(fake_examples, fake_filepath)
            fake_examples_read = serialize.read_records(fake_filepath)
            self.assertCountEqual(fake_examples, fake_examples_read)

    def test_read_records_with_data_fraction(self):
        """write examples, read without fraction, read with fraction, check expected files exist"""
        filename = "data.pb64"
        frac = 0.1
        fake_example_count = 100
        fake_examples = [i for i in range(fake_example_count)]
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            fake_filepath = os.path.join(tmp_dir_path, filename)
            serialize.write_records(fake_examples, fake_filepath)
            # materialize the reads so the generators can be counted and compared below
            fake_examples_read = list(serialize.read_records(fake_filepath))
            fake_examples_read_frac = list(serialize.read_records(fake_filepath, fraction=frac))
            file_list = os.listdir(tmp_dir_path)
            self.assertLess(len(fake_examples_read_frac), len(fake_examples_read))
            self.assertTrue(set(fake_examples_read_frac).issubset(set(fake_examples_read)))
            self.assertCountEqual(file_list, [filename, filename + "__fraction_" + str(frac)])

    def test_read_records_with_data_fraction_without_prior_full_data_read(self):
        """write examples, then read with fraction (no full read before fractional read)"""
        filename = "data.pb64"
        frac = 0.1
        fake_example_count = 100
        fake_examples = [i for i in range(fake_example_count)]
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            fake_filepath = os.path.join(tmp_dir_path, filename)
            serialize.write_records(fake_examples, fake_filepath)
            fake_examples_frac = list(serialize.read_records(fake_filepath, fraction=frac))
            self.assertLess(len(fake_examples_frac), len(fake_examples))
            self.assertTrue(set(fake_examples_frac).issubset(set(fake_examples)))

    def test_read_records_with_data_fraction_from_cache(self):
        """write examples, read with fraction, then read with fraction again (from cache)"""
        filename = "data.pb64"
        frac = 0.1
        fake_example_count = 100
        fake_examples = [i for i in range(fake_example_count)]
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            fake_filepath = os.path.join(tmp_dir_path, filename)
            serialize.write_records(fake_examples, fake_filepath)
            fake_examples_frac = serialize.read_records(fake_filepath, fraction=frac)
            os.remove(fake_filepath)  # remove full file to make sure we read frac examples from cache
            fake_examples_frac_cache = serialize.read_records(fake_filepath, fraction=frac)
            self.assertCountEqual(list(fake_examples_frac), list(fake_examples_frac_cache))

import os
import torch

frac_tasks = [
("edges-ner-ontonotes-5k", "edges-ner-ontonotes", 0.10059147788999316),
("edges-srl-ontonotes-5k", "edges-srl-ontonotes", 0.02160013824088474),
("edges-coref-ontonotes-5k", "edges-coref-ontonotes", 0.11968307920626182),
("edges-pos-ontonotes-5k", "edges-pos-ontonotes", 0.06170173381872031),
("edges-nonterminal-ontonotes-5k", "edges-nonterminal-ontonotes", 0.04524354600816194),
("edges-dep-ud-ewt-5k", "edges-dep-ud-ewt", 0.4001280409731114),
("se-probing-word-content-5k", "se-probing-word-content", 0.05),
("se-probing-tree-depth-5k", "se-probing-tree-depth", 0.05),
("se-probing-top-constituents-5k", "se-probing-top-constituents", 0.05),
("se-probing-bigram-shift-5k", "se-probing-bigram-shift", 0.05),
("se-probing-past-present-5k", "se-probing-past-present", 0.05),
("se-probing-subj-number-5k", "se-probing-subj-number", 0.05),
("se-probing-obj-number-5k", "se-probing-obj-number", 0.05),
("se-probing-odd-man-out-5k", "se-probing-odd-man-out", 0.05),
("se-probing-coordination-inversion-5k", "se-probing-coordination-inversion", 0.05),
("se-probing-sentence-length-5k", "se-probing-sentence-length", 0.0500020000800032),
("cola-5k", "cola", 0.584726932522512),
("sst-20k", "sst", 0.29696060817532555),
("socialiqa-20k", "socialiqa", 0.5986231667165519),
("ccg-20k", "ccg", 0.5261081152176772),
("qqp-20k", "qqp", 0.05496831076884176),
("mnli-20k", "mnli", 0.05092920331447255),
("scitail-20k", "scitail", 0.8476012883539583),
("qasrl-20k", "qasrl", 0.031348502873874),
("qamr-20k", "qamr", 0.3951397806974217),
("cosmosqa-20k", "cosmosqa", 0.7917029530520149),
("hellaswag-20k", "hellaswag", 0.5011903270266884),
("record-20k", "record", 0.011130552476102704),
("winogrande-20k", "winogrande", 0.49507401356502795),
]

for limited_size_task, task_name, frac in frac_tasks:
    filename = os.path.join(
        "/scratch/hl3236/jiant_results/",
        f"optuna_{task_name}",
        "preproc",
        f"{task_name}__train_data",
    )
    print(f"{limited_size_task} start")
    if os.path.exists(filename):
        data_old = serialize.old_read_records(filename, repeatable=False, fraction=frac)
        data_new = serialize.read_records(filename, repeatable=False, fraction=frac)
        for instance_old, instance_new in zip(data_old, data_new):
            td_old, td_new = instance_old.as_tensor_dict(), instance_new.as_tensor_dict()
            for key in td_old:
                assert repr(td_old[key]) == repr(td_new[key]), (
                    f"{limited_size_task}, {key} mismatch \n"
                    f"old: {repr(td_old[key])} \nnew: {repr(td_new[key])}"
                )
        print(f"{limited_size_task} checked")
    else:
        print(f"{limited_size_task} data not available")