Skip to content

Commit 69f0bc8

Browse files
committed
[!!!] tests passing
1 parent e935a78 commit 69f0bc8

File tree

2 files changed

+32
-33
lines changed

2 files changed

+32
-33
lines changed

spd/clustering/pipeline/dist_utils.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import json
44
import os
55
import subprocess
6-
import time
76
from collections.abc import Callable
87
from dataclasses import dataclass
98
from pathlib import Path
@@ -118,7 +117,6 @@ def distribute_clustering(
118117
workers_per_device: int = 1,
119118
log_fn: Callable[[str], None] | None = None,
120119
log_fn_error: Callable[[str], None] | None = None,
121-
wait_before_next_loop: float = 1.0,
122120
) -> list[dict[str, str | None]]:
123121
"""Distribute clustering tasks across multiple devices via subprocess.
124122
@@ -127,8 +125,9 @@ def distribute_clustering(
127125
128126
The concurrency model:
129127
- Total concurrency = workers_per_device x len(devices)
130-
- Tasks are round-robin assigned to devices
131-
- Each device can have up to workers_per_device tasks running concurrently
128+
- Each task's device search starts at a round-robin position (task index modulo number of devices)
129+
- If the target device is full, uses the next device with free capacity, scanning in order from the target
130+
- If all devices are full, blocks on the first process in the active list (on whichever device it runs) until it finishes
132131
133132
Args:
134133
config_path: Path to clustering configuration file
@@ -139,7 +138,6 @@ def distribute_clustering(
139138
workers_per_device: Maximum concurrent workers per device
140139
log_fn: Optional logging function for info messages
141140
log_fn_error: Optional logging function for error messages
142-
wait_before_next_loop: Seconds to wait before checking for free slots among all devices again
143141
144142
Returns:
145143
List of result dictionaries from each batch processing
@@ -165,37 +163,35 @@ def distribute_clustering(
165163
results: list[dict[str, str | None]] = []
166164

167165
n_files: int = len(data_files)
168-
device_idx: int
169-
attempts: int
170166
try:
171167
for idx, dataset in enumerate(data_files):
172-
# Find next device with capacity using round-robin starting point
168+
# Find a device with capacity, starting from round-robin position
173169
device_idx = idx % n_devices
174-
attempts = 0
175-
while device_active_counts[devices[device_idx]] >= workers_per_device:
176-
# Wait for any process to finish if all devices are at capacity
177-
if all(count >= workers_per_device for count in device_active_counts.values()):
178-
# wait for the first process to finish
179-
active_proc = active[0]
180-
result = _read_json_result(active_proc.json_fd, active_proc.dataset_path)
181-
active_proc.proc.wait()
182-
# store it
183-
results.append(result)
184-
device_active_counts[active_proc.device] -= 1
185-
# log it
186-
log_fn(
187-
f"Process {active_proc.proc.pid} finished, freeing slot on {active_proc.device}"
188-
)
189-
# remove from active list
190-
active.pop(0)
191-
192-
# Try next device
193-
device_idx = (device_idx + 1) % n_devices
194-
attempts += 1
195-
if attempts >= n_devices:
196-
# We've checked all devices, start from beginning
197-
attempts = 0
198-
time.sleep(wait_before_next_loop)
170+
171+
# Check if we need to wait for a device to free up
172+
while all(count >= workers_per_device for count in device_active_counts.values()):
173+
# All devices are at capacity - block on the first active process (regardless of device) to free a slot
174+
log_fn(
175+
f"All devices at capacity ({workers_per_device} workers each). Waiting for any process to finish..."
176+
)
177+
178+
# Wait for the first process (any device)
179+
active_proc = active[0]
180+
result = _read_json_result(active_proc.json_fd, active_proc.dataset_path)
181+
active_proc.proc.wait()
182+
results.append(result)
183+
device_active_counts[active_proc.device] -= 1
184+
log_fn(
185+
f"Process {active_proc.proc.pid} finished, freeing slot on {active_proc.device}"
186+
)
187+
active.pop(0)
188+
189+
# Now find a device with capacity, starting from our round-robin position
190+
for i in range(n_devices):
191+
check_idx = (device_idx + i) % n_devices
192+
if device_active_counts[devices[check_idx]] < workers_per_device:
193+
device_idx = check_idx
194+
break
199195

200196
device: str = devices[device_idx]
201197

spd/clustering/plotting/activations.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ def plot_activations(
5858
for key in act_dict:
5959
act_dict[key] = act_dict[key][:n_samples_max]
6060

61+
# Update n_samples to reflect the truncated size
62+
n_samples = act_concat.shape[0]
63+
6164
# Raw activations
6265
axs_act: Sequence[plt.Axes]
6366
_fig1: plt.Figure

0 commit comments

Comments
 (0)