diff --git a/bergson/data.py b/bergson/data.py index 4475585..3f8b65e 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -1,6 +1,7 @@ import json import math import os +import random from dataclasses import dataclass from pathlib import Path from typing import Literal, Sequence @@ -114,7 +115,7 @@ def ceildiv(a: int, b: int) -> int: return -(-a // b) # Equivalent to math.ceil(a / b) but faster for integers -def allocate_batches(doc_lengths: list[int], N: int) -> list[list[int]]: +def allocate_batches(doc_lengths: list[int], N: int, seed: int = 42) -> list[list[int]]: """ Allocate documents into batches that are then distributed evenly across a fixed number of workers. @@ -230,8 +231,13 @@ def allocate_batches(doc_lengths: list[int], N: int) -> list[list[int]]: for b_idx, batch in enumerate(batches): allocation[b_idx % world_size].append(batch) - # sanity: equal # of batches per worker + # Sanity: equal # of batches per worker assert len({len(b) for b in allocation}) == 1 + + # Break any systematic ordering of batches + random.seed(seed) + random.shuffle(allocation[rank]) + return allocation[rank]