Faster data fraction #1069
Conversation
Leaving this to @pyeres, but a couple of high-level comments:
@HaokunLiu, thanks for the regression tests. Can you give an estimate of the time savings with these changes? (Useful for prioritizing review and any additional work that might be necessary for this PR.)
    return RepeatableIterator(_iter_fn) if repeatable else _iter_fn()

# temporary backup code, remove before merging
Please remove.
In case other reviewers would like to check this, I will remove it once the PR is approved.
all_record_files = "%s*" % record_file
for one_record_file in glob.glob(all_record_files):
    if os.path.exists(one_record_file) and os.path.islink(one_record_file):
        os.remove(one_record_file)

_index_split(
    task, split, indexers, vocab, record_file, model_preprocessing_interface
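For context, the cleanup step in the diff above can be sketched as a standalone helper. The wrapper name `remove_index_files` is hypothetical (not jiant API); the body mirrors the diff, including the symlink check:

```python
import glob
import os


def remove_index_files(record_file):
    """Illustrative sketch of the cleanup in the diff above.

    Fractional subsets are cached as sibling files of the main record file
    (e.g. record_file + "__fraction_0.1"), so a reload must glob for every
    variant rather than deleting just the base file. As in the diff, only
    symlinked entries are removed here.
    """
    all_record_files = "%s*" % record_file
    for one_record_file in glob.glob(all_record_files):
        if os.path.exists(one_record_file) and os.path.islink(one_record_file):
            os.remove(one_record_file)
```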
Can you put a comment about what is happening here?
I'm pretty sure I understand what's going on here, but not 100% sure - basically, you're adding functionality of loading the vocabulary for all the tasks before running each experiment for the one-time cost. A few requests.
Left some comments
Errrr, this PR is not about that thing... This PR is meant to address low GPU usage when data_fraction is very low, like 1%. Instead of iterating over all the instances and rejecting most of them, this creates an additional preprocessed file that contains only the instances that will actually be used. This is why, when reload_index is True, we need to use glob to find both the normally preprocessed file and our additional preprocessed files, and delete them all.
The issue you mentioned will be addressed in #1070. This PR is orthogonal to it.
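A minimal sketch of the strategy described above, assuming a hash-based filter like the one in the diff; the filename scheme and the helper `read_fraction_cached` are illustrative stand-ins, not the actual jiant code:

```python
import hashlib
import os
import pickle as pkl


def read_fraction_cached(record_file, fraction):
    """Sketch: materialize the fractional subset once, then reuse it.

    The first read with a fraction pays a one-time cost to write a small
    cache file; every later epoch reads only that cache instead of
    iterating over the full dataset and rejecting most instances.
    """
    cache_file = "%s__fraction_%s" % (record_file, fraction)
    if os.path.exists(cache_file):
        # Fast path: the subset was already materialized on a prior read.
        with open(cache_file, "rb") as f:
            return pkl.load(f)
    with open(record_file, "rb") as f:
        examples = pkl.load(f)
    kept = []
    for example in examples:
        # Deterministic per-example hash, so the same subset is
        # selected on every run.
        digest = hashlib.md5(repr(example).encode()).hexdigest()
        hash_float = int(digest, 16) % 10 ** 8 / 10 ** 8
        if hash_float <= fraction:
            kept.append(example)
    with open(cache_file, "wb") as f:
        pkl.dump(kept, f)
    return kept
```

Because the subset file sits next to the full record file, a reload has to glob for both, which is the cleanup this comment thread discusses.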
@HaokunLiu — thanks for the PR, looks good. Please see comments for two small change requests.
jiant/utils/serialize.py
Outdated
if hash_float > fraction:
    continue
example = pkl.loads(blob)
examples.append(example)
Please use generator/yield to avoid loading examples into memory here (before they're written out again).
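The streaming rewrite the reviewer asks for might look like this. Everything here is an assumption for illustration: the record format (one base64-encoded pickle blob per line, suggested by the `.pb64` filenames in the tests) and the names `iter_records` and `keep` are hypothetical, not the actual serialize.py API:

```python
import base64
import pickle as pkl


def iter_records(path, keep):
    """Yield each decoded example one at a time instead of building a list.

    Because this is a generator, the fractional pass never holds the full
    dataset in memory before the kept examples are written back out.
    `keep(blob)` stands in for the hash-fraction check in the diff above.
    """
    with open(path, "rb") as fd:
        for line in fd:
            blob = base64.b64decode(line)
            if not keep(blob):
                continue
            # Decode lazily; the caller consumes one example at a time.
            yield pkl.loads(blob)
```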
@@ -0,0 +1,57 @@
from jiant.utils import serialize |
Can you please make this into a unit test?
I think this is more like a regression test. I don't see an easy way to convert it into a unit test.
Hey @HaokunLiu — I tried to write something similar to the tests you wrote, but tried to make them more like unit tests. These tests end up involving write_records, but they'll add some protection from changes going forward:
class TestReadRecords(unittest.TestCase):
    def test_read_records_without_data_fraction(self):
        """write records then read records (with no data fraction arg), check that records match"""
        data_file_name = "data.pb64"
        fake_example_count = 100
        fake_examples = [i for i in range(fake_example_count)]
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            fake_filepath = os.path.join(tmp_dir_path, data_file_name)
            serialize.write_records(fake_examples, fake_filepath)
            fake_examples_read = serialize.read_records(fake_filepath)
            self.assertCountEqual(fake_examples, fake_examples_read)

    def test_read_records_with_data_fraction(self):
        """write examples, read without fraction, read with fraction, check expected files exist"""
        filename = "data.pb64"
        frac = 0.1
        fake_example_count = 100
        fake_examples = [i for i in range(fake_example_count)]
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            fake_filepath = os.path.join(tmp_dir_path, filename)
            serialize.write_records(fake_examples, fake_filepath)
            fake_examples_read = serialize.read_records(fake_filepath)
            fake_examples_read_frac = serialize.read_records(fake_filepath, fraction=frac)
            file_list = os.listdir(tmp_dir_path)
            self.assertLess(len(list(fake_examples_read_frac)), len(list(fake_examples_read)))
            self.assertTrue(set(fake_examples_read_frac).issubset(set(fake_examples_read)))
            self.assertCountEqual(file_list, [filename, filename + "__fraction_" + str(frac)])

    def test_read_records_with_data_fraction_without_prior_full_data_read(self):
        """write examples, then read with fraction (no full read before fractional read)"""
        filename = "data.pb64"
        frac = 0.1
        fake_example_count = 100
        fake_examples = [i for i in range(fake_example_count)]
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            fake_filepath = os.path.join(tmp_dir_path, filename)
            serialize.write_records(fake_examples, fake_filepath)
            fake_examples_frac = serialize.read_records(fake_filepath, fraction=frac)
            self.assertLess(len(list(fake_examples_frac)), len(list(fake_examples)))
            self.assertTrue(set(fake_examples_frac).issubset(set(fake_examples)))

    def test_read_records_with_data_fraction_from_cache(self):
        """write examples, read with fraction, then read with fraction again (from cache)"""
        filename = "data.pb64"
        frac = 0.1
        fake_example_count = 100
        fake_examples = [i for i in range(fake_example_count)]
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            fake_filepath = os.path.join(tmp_dir_path, filename)
            serialize.write_records(fake_examples, fake_filepath)
            fake_examples_frac = serialize.read_records(fake_filepath, fraction=frac)
            os.remove(fake_filepath)  # remove full file make sure we get frac examples from cache
            fake_examples_frac_cache = serialize.read_records(fake_filepath, fraction=frac)
            self.assertCountEqual(list(fake_examples_frac), list(fake_examples_frac_cache))
Hey @HaokunLiu — after you remove your regression tests and add some version of the unit tests posted here, I think you're ready to merge. Please ping me in this thread for a final look / approval.
Based on logs from some recent experiments, it looks like these changes speed up data reading, but there's a bottleneck elsewhere (with these changes alone the speedup is <10%). Planning to close this PR. Issue "training_data_fraction is slow." #180 remains open to track the issue. If this becomes a priority again, the changes in this PR should be considered as part of a broader fix.
This adds a one-time cost before training, but the code should be more efficient during training.
There isn't a test case for read_records, but I wrote some temporary code to compare the results from the previous code and the new code. You can check it in the file changes. After the PR is approved, I will remove that temporary code.
Ideally, if we can get this through soon, we will be able to save some time and cost in the upcoming taskmaster experiments.