From 4d41ffb029c34672b9570c43fa2e89f48f6390b9 Mon Sep 17 00:00:00 2001
From: Mona Lisa
Date: Wed, 4 Oct 2023 16:15:33 +0000
Subject: [PATCH] Only assert resampled if we are in train mode and we specify
 a data upsampling factor

---
 src/training/data.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/training/data.py b/src/training/data.py
index 2ed076d96..07b9fee96 100644
--- a/src/training/data.py
+++ b/src/training/data.py
@@ -345,6 +345,9 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokeni
         num_samples = args.val_num_samples or 0
 
     shared_epoch = SharedEpoch(epoch=epoch)  # create a shared epoch store to sync epoch to dataloader worker proc
+
+    if is_train and args.train_data_upsampling_factors is not None:
+        assert resampled, "--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)."
 
     if resampled:
         pipeline = [ResampledShards2(
@@ -354,8 +357,6 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokeni
             epoch=shared_epoch,
         )]
     else:
-        assert args.train_data_upsampling_factors is None,\
-            "--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)."
         pipeline = [wds.SimpleShardList(input_shards)]
 
     # at this point we have an iterator over all the shards
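
Note (not part of the patch): a minimal sketch of what the relocated guard does, for
context. check_upsampling is a hypothetical helper wrapping just the asserted
condition, and types.SimpleNamespace stands in for the real args object that
get_wds_dataset receives; the real resampled flag is derived from
--dataset-resampled and is_train.

    from types import SimpleNamespace

    def check_upsampling(args, is_train, resampled):
        # Patched behavior: enforce the constraint only when we are training
        # AND an upsampling factor was actually supplied.
        if is_train and args.train_data_upsampling_factors is not None:
            assert resampled, (
                "--train_data_upsampling_factors is only supported when "
                "sampling with replacement (with --dataset-resampled)."
            )

    # Before the patch, the assert lived in the non-resampled branch and fired
    # even on eval passes (is_train=False) whenever upsampling factors were set;
    # now only a train run without --dataset-resampled trips it.
    check_upsampling(SimpleNamespace(train_data_upsampling_factors=[2.0]),
                     is_train=False, resampled=False)  # passes after the patch
    check_upsampling(SimpleNamespace(train_data_upsampling_factors=[2.0]),
                     is_train=True, resampled=True)    # passes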