diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 72e19d54e7..6cb434a7d6 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -635,6 +635,8 @@ def __iter__(self): else: more_examples = False break + if self.shuffle: + random.shuffle(buffer) tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[ "input_ids" ] @@ -649,6 +651,7 @@ def __iter__(self): if len(input_ids) == self.seq_length: examples.append(input_ids) if self.shuffle: + # Shuffle again, otherwise split examples occur in consecutive tensors. random.shuffle(examples) for example in examples: self.current_size += 1