33import json
44import os
55import subprocess
6- import time
76from collections .abc import Callable
87from dataclasses import dataclass
98from pathlib import Path
@@ -118,7 +117,6 @@ def distribute_clustering(
118117 workers_per_device : int = 1 ,
119118 log_fn : Callable [[str ], None ] | None = None ,
120119 log_fn_error : Callable [[str ], None ] | None = None ,
121- wait_before_next_loop : float = 1.0 ,
122120) -> list [dict [str , str | None ]]:
123121 """Distribute clustering tasks across multiple devices via subprocess.
124122
@@ -127,8 +125,9 @@ def distribute_clustering(
127125
128126 The concurrency model:
129127 - Total concurrency = workers_per_device x len(devices)
130- - Tasks are round-robin assigned to devices
131- - Each device can have up to workers_per_device tasks running concurrently
128+ - Device selection for each task starts at a round-robin position (task index modulo number of devices)
129+ - If target device is full, uses any available device
130+ - If all devices are full, blocks until the first process in the active list finishes (regardless of which device it is on)
132131
133132 Args:
134133 config_path: Path to clustering configuration file
@@ -139,7 +138,6 @@ def distribute_clustering(
139138 workers_per_device: Maximum concurrent workers per device
140139 log_fn: Optional logging function for info messages
141140 log_fn_error: Optional logging function for error messages
142- wait_before_next_loop: Seconds to wait before checking for free slots among all devices again
143141
144142 Returns:
145143 List of result dictionaries from each batch processing
@@ -165,37 +163,35 @@ def distribute_clustering(
165163 results : list [dict [str , str | None ]] = []
166164
167165 n_files : int = len (data_files )
168- device_idx : int
169- attempts : int
170166 try :
171167 for idx , dataset in enumerate (data_files ):
172- # Find next device with capacity using round-robin starting point
168+ # Find a device with capacity, starting from round-robin position
173169 device_idx = idx % n_devices
174- attempts = 0
175- while device_active_counts [ devices [ device_idx ]] >= workers_per_device :
176- # Wait for any process to finish if all devices are at capacity
177- if all ( count >= workers_per_device for count in device_active_counts . values ()):
178- # wait for the first process to finish
179- active_proc = active [ 0 ]
180- result = _read_json_result ( active_proc . json_fd , active_proc . dataset_path )
181- active_proc . proc . wait ()
182- # store it
183- results . append ( result )
184- device_active_counts [ active_proc .device ] -= 1
185- # log it
186- log_fn (
187- f"Process { active_proc .proc . pid } finished, freeing slot on { active_proc . device } "
188- )
189- # remove from active list
190- active . pop ( 0 )
191-
192- # Try next device
193- device_idx = ( device_idx + 1 ) % n_devices
194- attempts += 1
195- if attempts >= n_devices :
196- # We've checked all devices, start from beginning
197- attempts = 0
198- time . sleep ( wait_before_next_loop )
170+
171+ # Check if we need to wait for a device to free up
172+ while all ( count >= workers_per_device for count in device_active_counts . values ()):
173+ # All devices are at capacity - block on the first process in `active` (whichever device it is on), not on whichever finishes first
174+ log_fn (
175+ f"All devices at capacity ( { workers_per_device } workers each). Waiting for any process to finish..."
176+ )
177+
178+ # Wait for the first process (any device)
179+ active_proc = active [ 0 ]
180+ result = _read_json_result ( active_proc .json_fd , active_proc . dataset_path )
181+ active_proc . proc . wait ()
182+ results . append ( result )
183+ device_active_counts [ active_proc .device ] -= 1
184+ log_fn (
185+ f"Process { active_proc . proc . pid } finished, freeing slot on { active_proc . device } "
186+ )
187+ active . pop ( 0 )
188+
189+ # Now find a device with capacity, starting from our round-robin position
190+ for i in range ( n_devices ):
191+ check_idx = ( device_idx + i ) % n_devices
192+ if device_active_counts [ devices [ check_idx ]] < workers_per_device :
193+ device_idx = check_idx
194+ break
199195
200196 device : str = devices [device_idx ]
201197
0 commit comments