diff --git a/llava/train/llava_trainer.py b/llava/train/llava_trainer.py index ea436af4..824bf258 100644 --- a/llava/train/llava_trainer.py +++ b/llava/train/llava_trainer.py @@ -205,6 +205,8 @@ def __init__( def __iter__(self): indices = list(range(len(self.dataset))) + random.seed(self.seed + self.epoch) + random.shuffle(indices) # if we don't shuffle here, the final ( len(self.dataset) - self.total_size ) samples will be dropped forever # 1. split the full indices first (note: without drop last at this moment) indices_list = []