Fixed bug in MultiprocessingBatcher with indices assigned to workers
jsschreck committed Dec 29, 2024
1 parent 3c181c0 commit cf8427f
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions credit/datasets/era5_multistep_batcher.py
@@ -464,16 +464,13 @@ def worker_process(start_idx, end_idx):
             self.results[k] = sample  # Store the result keyed by its order index
 
         # Split tasks among workers
-        tasks_per_worker = max(1, len(args) // self.num_workers)
         processes = []
-
-        for i in range(self.num_workers):
-            start_idx = i * tasks_per_worker
-            end_idx = min((i + 1) * tasks_per_worker, len(args))
-            if start_idx < end_idx:  # Check if there are tasks for this worker
-                p = multiprocessing.Process(target=worker_process, args=(start_idx, end_idx))
-                processes.append(p)
-                p.start()
+        splits = np.array_split(range(len(args)), self.num_workers)
+        start_ends = [(split[0], split[-1] + 1) for split in splits]
+        for start_idx, end_idx in start_ends:
+            p = multiprocessing.Process(target=worker_process, args=(start_idx, end_idx))
+            processes.append(p)
+            p.start()
 
         # Wait for all processes to finish
         for p in processes:
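
The core of the fix is how the list of task indices is partitioned across workers. The standalone sketch below (num_tasks and num_workers are hypothetical stand-ins for len(args) and self.num_workers in the batcher) illustrates why the old floor-division chunking can leave trailing indices unassigned, while np.array_split spreads every index across the workers:

import numpy as np

# Hypothetical values standing in for len(args) and self.num_workers.
num_tasks = 10
num_workers = 4

# Old scheme: floor division drops the remainder, so the last indices may never be assigned.
tasks_per_worker = max(1, num_tasks // num_workers)  # 10 // 4 -> 2
old_ranges = []
for i in range(num_workers):
    start_idx = i * tasks_per_worker
    end_idx = min((i + 1) * tasks_per_worker, num_tasks)
    if start_idx < end_idx:
        old_ranges.append((start_idx, end_idx))
print(old_ranges)  # [(0, 2), (2, 4), (4, 6), (6, 8)] -- indices 8 and 9 are never processed

# New scheme: np.array_split partitions all indices, giving the larger chunks to the first workers.
splits = np.array_split(range(num_tasks), num_workers)
new_ranges = [(split[0], split[-1] + 1) for split in splits]
print(new_ranges)  # covers (0, 3), (3, 6), (6, 8), (8, 10) -- every index is assigned exactly once

The (start, end) reconstruction in the committed code assumes np.array_split hands each worker a contiguous, non-empty slice, which holds whenever len(args) >= num_workers.
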
@@ -769,14 +766,14 @@ def __exit__(self, exc_type, exc_val, exc_tb):
     data_config = setup_data_loading(conf)
 
     epoch = 0
-    batch_size = 2
+    batch_size = 5
     data_config["forecast_len"] = 6
-    data_config["history_len"] = 3
+    data_config["history_len"] = 1
     shuffle = True
 
     rank = 0
     world_size = 2
-    num_workers = 2
+    num_workers = 4
 
     set_globals(data_config, namespace=globals())
 
