Skip to content

Commit

Permalink
Only assert reshuffle if we are in train mode and we specify a data u…
Browse files Browse the repository at this point in the history
…psample factor (mlfoundations#655)
  • Loading branch information
humzaiqbal authored and Interpause committed May 23, 2024
1 parent f981720 commit 2c9a4a8
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/training/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,9 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokeni
num_samples = args.val_num_samples or 0

shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc

if is_train and args.train_data_upsampling_factors is not None:
assert resampled, "--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)."

if resampled:
pipeline = [ResampledShards2(
Expand All @@ -354,8 +357,6 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokeni
epoch=shared_epoch,
)]
else:
assert args.train_data_upsampling_factors is None,\
"--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)."
pipeline = [wds.SimpleShardList(input_shards)]

# at this point we have an iterator over all the shards
Expand Down

0 comments on commit 2c9a4a8

Please sign in to comment.