diff --git a/credit/datasets/era5_multistep_batcher.py b/credit/datasets/era5_multistep_batcher.py index 83512a5..dc453ff 100644 --- a/credit/datasets/era5_multistep_batcher.py +++ b/credit/datasets/era5_multistep_batcher.py @@ -464,16 +464,13 @@ def worker_process(start_idx, end_idx): self.results[k] = sample # Store the result keyed by its order index # Split tasks among workers - tasks_per_worker = max(1, len(args) // self.num_workers) processes = [] - - for i in range(self.num_workers): - start_idx = i * tasks_per_worker - end_idx = min((i + 1) * tasks_per_worker, len(args)) - if start_idx < end_idx: # Check if there are tasks for this worker - p = multiprocessing.Process(target=worker_process, args=(start_idx, end_idx)) - processes.append(p) - p.start() + splits = np.array_split(range(len(args)), self.num_workers) + start_ends = [(split[0], split[-1] + 1) for split in splits if len(split)] # Skip empty splits (len(args) < num_workers) + for start_idx, end_idx in start_ends: + p = multiprocessing.Process(target=worker_process, args=(start_idx, end_idx)) + processes.append(p) + p.start() # Wait for all processes to finish for p in processes: @@ -769,14 +766,14 @@ def __exit__(self, exc_type, exc_val, exc_tb): data_config = setup_data_loading(conf) epoch = 0 - batch_size = 2 + batch_size = 5 data_config["forecast_len"] = 6 - data_config["history_len"] = 3 + data_config["history_len"] = 1 shuffle = True rank = 0 world_size = 2 - num_workers = 2 + num_workers = 4 set_globals(data_config, namespace=globals())