Faster data fraction #1069

Closed
HaokunLiu wants to merge 12 commits

Conversation

HaokunLiu
Member

@HaokunLiu commented Apr 16, 2020

#180
This adds a one-time cost before training, but the code should be more efficient during training.

There isn't a test case for read_record, but I wrote some temporary code to compare the results from the previous code and the new code. You can check it in the file changes. After the PR is approved, I will remove that temporary code.

Ideally, if we can get this through soon, we will be able to save some time and cost in the upcoming taskmaster experiments.
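
For a rough idea, the check is of this shape (illustrative sketch only, not the actual temporary code; the old reader is passed in as a callable here rather than given a name):

# Regression-style check: the new fraction-aware reader should return the
# same examples as the previous rejection-based implementation.
from jiant.utils import serialize

def compare_fraction_readers(read_old, record_file, fraction):
    old_examples = list(read_old(record_file, fraction))  # previous code path
    new_examples = list(serialize.read_records(record_file, fraction=fraction))  # this PR
    assert len(old_examples) == len(new_examples)
    assert all(o == n for o, n in zip(old_examples, new_examples))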

@sleepinyourhat
Contributor

Leaving this to @pyeres, but a couple of high-level comments:

  • Check that logging is correct.
  • When you write permanent tests, one of them should check situations where you use a limited data fraction in one stage of training, but the full dataset in the other stage of training. IIRC, this was a source of some extra complexity previously, and I believe it has come up as something we've found useful.

@pyeres
Contributor

pyeres commented Apr 16, 2020

@HaokunLiu, thanks for the regression tests. Can you give an estimate of time savings w/ these changes? (useful for prioritizing review and any additional work that might be necessary for this PR)

return RepeatableIterator(_iter_fn) if repeatable else _iter_fn()


# temporary backup code, remove before merging
Contributor

Please remove.

Member Author

In case other reviewer would like to check this, I will remove it once the PR is approved

# Clear any previously preprocessed record files before re-indexing: the glob
# pattern matches both the normally preprocessed record file and the additional
# fraction files derived from it, so all of them get removed together.
all_record_files = "%s*" % record_file
for one_record_file in glob.glob(all_record_files):
    if os.path.exists(one_record_file) and os.path.islink(one_record_file):
        os.remove(one_record_file)

_index_split(
    task, split, indexers, vocab, record_file, model_preprocessing_interface
Contributor

Can you put a comment about what is happening here?

@pruksmhc
Contributor

I'm pretty sure I understand what's going on here, but not 100% sure - basically, you're adding functionality of loading the vocabulary for all the tasks before running each experiment for the one-time cost. A few requests.

  1. Please document more clearly what's going on, and flesh out what the "1-time cost" is so a stranger can understand what this PR is for.
  2. Please move the test to tests folder and make it a unit test.

Contributor

@pruksmhc left a comment

Left some comments

@HaokunLiu
Member Author

> I'm pretty sure I understand what's going on here, but not 100% sure - basically, you're adding functionality of loading the vocabulary for all the tasks before running each experiment for the one-time cost. A few requests.
>
>   1. Please document more clearly what's going on, and flesh out what the "1-time cost" is so a stranger can understand what this PR is for.
>   2. Please move the test to tests folder and make it a unit test.

Errrr, this PR is not about that thing... This PR is meant to address low GPU usage when data_fraction is very low, like 1%. Instead of iterating over all the instances and rejecting most of them, this creates an additional preprocessed file that contains all the instances that will actually be used.

This is why, when reload_index is True, we need to use glob to find both the normally preprocessed file and our additional preprocessed files, and delete them all.
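
Schematically, the new read path does something like this (sketch only, not the exact PR code; _hash_float is a stand-in for the existing hash-based selection, and the cache-file naming follows the "__fraction_" suffix used elsewhere in this PR):

import hashlib
import os

from jiant.utils import serialize  # write_records / read_records from this PR


def _hash_float(example):
    # Stand-in: map each example to a deterministic float in [0, 1).
    digest = hashlib.md5(repr(example).encode("utf-8")).hexdigest()
    return int(digest, 16) / float(16 ** len(digest))


def read_record_fraction(record_file, fraction):
    # On the first fractional read, materialize the kept examples into a
    # companion file; later reads stream that file directly instead of
    # re-scanning the full record file and rejecting most of it.
    cache_file = record_file + "__fraction_" + str(fraction)
    if not os.path.exists(cache_file):
        kept = [ex for ex in serialize.read_records(record_file) if _hash_float(ex) <= fraction]
        serialize.write_records(kept, cache_file)
    return serialize.read_records(cache_file)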

@HaokunLiu
Member Author

> Leaving this to @pyeres, but a couple of high-level comments:
>
>   • Check that logging is correct.
>   • When you write permanent tests, one of them should check situations where you use a limited data fraction in one stage of training, but the full dataset in the other stage of training. IIRC, this was a source of some extra complexity previously, and I believe it has come up as something we've found useful.

The issue you mentioned will be addressed in #1070. This PR is orthogonal to it.

Contributor

@pyeres left a comment

@HaokunLiu — thanks for the PR, looks good. Please see comments for two small change requests.

if hash_float > fraction:
    continue
example = pkl.loads(blob)
examples.append(example)
Contributor

Please use generator/yield to avoid loading examples into memory here (before they're written out again).
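
Something along these lines (sketch; assumes the surrounding loop yields (hash_float, blob) pairs):

import pickle as pkl


def _iter_kept_examples(records, fraction):
    # Yield each kept example as it is decoded, instead of appending to a list,
    # so the selected fraction never has to sit in memory before being written
    # back out.
    for hash_float, blob in records:
        if hash_float > fraction:
            continue
        yield pkl.loads(blob)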

jiant/utils/serialize.py (resolved)
@@ -0,0 +1,57 @@
from jiant.utils import serialize
Contributor

Can you please make this into a unit test?

Member Author

I think this is more like a regression test. I don't see an easy way to convert it into a unit test.

Contributor

Hey @HaokunLiu — I tried to write something similar to the tests you wrote, but structured them more like unit tests. They still end up involving write_records, but they'll add some protection against changes going forward:

import os
import tempfile
import unittest

from jiant.utils import serialize


class TestReadRecords(unittest.TestCase):

    def test_read_records_without_data_fraction(self):
        """write records then read records (with no data fraction arg), check that records match"""
        data_file_name = "data.pb64"
        fake_example_count = 100
        fake_examples = [i for i in range(fake_example_count)]
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            fake_filepath = os.path.join(tmp_dir_path, data_file_name)
            serialize.write_records(fake_examples, fake_filepath)
            fake_examples_read = serialize.read_records(fake_filepath)
            self.assertCountEqual(fake_examples, fake_examples_read)

    def test_read_records_with_data_fraction(self):
        """write examples, read without fraction, read with fraction, check expected files exist"""
        filename = "data.pb64"
        frac = 0.1
        fake_example_count = 100
        fake_examples = [i for i in range(fake_example_count)]
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            fake_filepath = os.path.join(tmp_dir_path, filename)
            serialize.write_records(fake_examples, fake_filepath)
            fake_examples_read = serialize.read_records(fake_filepath)
            fake_examples_read_frac = serialize.read_records(fake_filepath, fraction=frac)
            file_list = os.listdir(tmp_dir_path)
            self.assertLess(len(list(fake_examples_read_frac)), len(list(fake_examples_read)))
            self.assertTrue(set(fake_examples_read_frac).issubset(set(fake_examples_read)))
            self.assertCountEqual(file_list, [filename, filename + "__fraction_" + str(frac)])

    def test_read_records_with_data_fraction_without_prior_full_data_read(self):
        """write examples, then read with fraction (no full read before fractional read)"""
        filename = "data.pb64"
        frac = 0.1
        fake_example_count = 100
        fake_examples = [i for i in range(fake_example_count)]
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            fake_filepath = os.path.join(tmp_dir_path, filename)
            serialize.write_records(fake_examples, fake_filepath)
            fake_examples_frac = serialize.read_records(fake_filepath, fraction=frac)
            self.assertLess(len(list(fake_examples_frac)), len(list(fake_examples)))
            self.assertTrue(set(fake_examples_frac).issubset(set(fake_examples)))

    def test_read_records_with_data_fraction_from_cache(self):
        """write examples, read with fraction, then read with fraction again (from cache)"""
        filename = "data.pb64"
        frac = 0.1
        fake_example_count = 100
        fake_examples = [i for i in range(fake_example_count)]
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            fake_filepath = os.path.join(tmp_dir_path, filename)
            serialize.write_records(fake_examples, fake_filepath)
            fake_examples_frac = serialize.read_records(fake_filepath, fraction=frac)
            os.remove(fake_filepath)  # remove full file to make sure we get frac examples from cache
            fake_examples_frac_cache = serialize.read_records(fake_filepath, fraction=frac)
            self.assertCountEqual(list(fake_examples_frac), list(fake_examples_frac_cache))

Contributor

@pyeres left a comment

Hey @HaokunLiu — after you remove your regression tests and add some version of the unit tests posted here, I think you're ready to merge. Please ping me in this thread for a final look / approval.

@pyeres
Contributor

pyeres commented May 1, 2020

Based on logs from some recent experiments, it looks like these changes speed up data reading; however, there's a bottleneck elsewhere (with these changes alone, the speedup is <10%).

Planning to close this PR. Issue #180 ("training_data_fraction is slow") remains open to track the issue. If it becomes a priority again, the changes in this PR should be considered as part of a broader fix.

@pyeres pyeres closed this May 1, 2020
@HaokunLiu HaokunLiu deleted the efficient_data_frac branch May 4, 2020 17:44
@jeswan jeswan added the jiant-v1-legacy Relevant to versions <= v1.3.2 label Sep 17, 2020