Commit 7c488fc ("wip", parent 7ff6246) — 23 files changed, +1254 −1023

`clustering_refactoring_plan.md` — 326 additions, 0 deletions

# Clustering Module Refactoring Plan

## Current Architecture Issues

The clustering module was implemented by an intern and follows a script-based approach rather than a modular library design. Key issues:

1. **Tight Coupling**: Main orchestrator handles too many responsibilities
2. **Complex Process Communication**: Uses file descriptors and JSON parsing
3. **Mixed Concerns**: Visualization, computation, and orchestration intermingled
4. **Hard-coded Dependencies**: Magic numbers and paths scattered throughout
5. **Limited Modularity**: Difficult to use components independently

## Refactoring Steps

### Step 1: Replace File Descriptor Communication

**Problem**: Complex FD-based inter-process communication in `scripts/main.py`

**Solution**: Replace it with exit codes plus structured file outputs

**Current approach:**
```python
proc, json_r = launch_child_with_json_fd(cmd)
result = _read_json_result(json_r, dataset_path)
```

**Proposed approach:**
```python
result_file = temp_dir / f"{dataset_path.stem}_result.json"
cmd.extend(["--result-file", str(result_file)])

proc = subprocess.run(cmd, capture_output=False)
if proc.returncode != 0:
    raise RuntimeError(f"Clustering failed: {dataset_path.stem}")

result = json.loads(result_file.read_text())
```

**Files to modify:**
- `spd/clustering/scripts/main.py` - Remove `launch_child_with_json_fd()` and `_read_json_result()`
- `spd/clustering/scripts/s2_run_clustering.py` - Replace `emit_result()` with file-based output (sketched below)

**Benefits:**
- Eliminates ~50 lines of complex FD management
- Platform-independent communication
- Easier debugging (can inspect result files)
- Cleaner error handling

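For the child side, a minimal sketch of the file-based replacement for `emit_result()` (the `--result-file` flag matches the proposed approach above; the payload fields and the result placeholder are assumptions, not the existing API):

```python
import argparse
import json
from pathlib import Path

def emit_result_to_file(result: dict, result_file: Path) -> None:
    """Write the clustering result as JSON to the agreed-upon path."""
    result_file.parent.mkdir(parents=True, exist_ok=True)
    result_file.write_text(json.dumps(result))

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--result-file", type=Path, required=True)
    args = parser.parse_args()

    result = {"status": "ok"}  # placeholder for the real clustering output
    emit_result_to_file(result, args.result_file)
```
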
### Step 2: Simplified Interface Design

**Problem**: Current interfaces mix computation, I/O, and orchestration, with too many optional parameters

**Proposed 3-Layer Architecture:**

#### Layer 1: Pure Computation (no I/O, no side effects)
```python
# Core clustering algorithm - simplified
def merge_iteration(
    activations: ProcessedActivations,
    config: MergeConfig,
) -> MergeHistory: ...

# Cost computation
def compute_merge_costs(
    coact: Tensor,
    merges: GroupMerge,
    alpha: float,
) -> Tensor: ...

# Ensemble normalization - works on objects
def normalize_ensemble(
    ensemble: MergeHistoryEnsemble,
) -> NormalizedHistories: ...

# Distance computation
def compute_distances(
    normalized: NormalizedHistories,
    method: DistancesMethod = "perm_invariant_hamming",
) -> DistancesArray: ...
```

#### Layer 2: Data Processing (I/O + transformation)
```python
# Component extraction and processing
def extract_and_process_activations(
    model: ComponentModel,
    batch: Tensor,
    config: MergeConfig,
) -> ProcessedActivations: ...

# History loading/saving
def load_merge_histories(paths: list[Path]) -> MergeHistoryEnsemble: ...
def save_merge_history(history: MergeHistory, path: Path) -> None: ...

# Batch processing
def process_data_batch(
    config: MergeRunConfig,
    batch_path: Path,
) -> MergeHistory: ...
```

#### Layer 3: Orchestration (coordination, parallel execution, file management)
```python
import multiprocessing

# Main pipeline - thin orchestrator
def cluster_analysis_pipeline(
    config: MergeRunConfig,
) -> ClusteringResults: ...

# Batch coordination with proper parallelism
class BatchProcessor:
    def __init__(self, n_workers: int = 4):
        self.n_workers = n_workers

    def process_all_batches(
        self,
        batches: list[Path],
        config: MergeRunConfig,
    ) -> list[MergeHistory]:
        # Use multiprocessing.Pool instead of subprocess + FD communication
        with multiprocessing.Pool(self.n_workers) as pool:
            worker_args = [(batch, config) for batch in batches]
            histories = pool.starmap(self._process_single_batch, worker_args)
        return histories

    def _process_single_batch(self, batch_path: Path, config: MergeRunConfig) -> MergeHistory:
        # This runs in a worker process - clean data transformation
        return process_data_batch(config, batch_path)  # Layer 2 function
```

**Key Changes:**
1. **Remove callback complexity** - plotting/logging handled externally
2. **Consistent data types** - functions work on objects, not file paths mixed with objects
3. **Single responsibility** - each function does one thing well
4. **Composable** - pure functions can be easily tested and combined (see the test sketch below)

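As a sketch of what that testability buys, a pytest-style check of `compute_merge_costs` (the toy tensor and the `GroupMerge.identity` constructor are assumptions for illustration):

```python
import torch

def test_compute_merge_costs_is_pure():
    coact = torch.rand(4, 4)                      # toy coactivation matrix
    merges = GroupMerge.identity(n_components=4)  # hypothetical constructor

    first = compute_merge_costs(coact=coact, merges=merges, alpha=1.0)
    second = compute_merge_costs(coact=coact, merges=merges, alpha=1.0)

    # Pure computation: no I/O, deterministic for fixed inputs
    assert torch.equal(first, second)
```
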
### Step 3: Extract Value Objects

**Problem**: Data frequently travels together but isn't grouped

**Proposed Data Classes:**
```python
from dataclasses import dataclass
from typing import Any

@dataclass
class ProcessedActivations:
    activations: Tensor
    labels: list[str]
    metadata: dict[str, Any]

@dataclass
class NormalizedHistories:
    merges_array: MergesArray
    component_labels: list[str]
    metadata: dict[str, Any]

@dataclass
class ClusteringResults:
    histories: list[MergeHistory]
    normalized: NormalizedHistories
    distances: DistancesArray
    config: MergeRunConfig
```

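A short sketch of how these value objects could flow between the layers (the wiring is an assumption; `history_paths` and `config` are presumed to come from the orchestration layer):

```python
# Layer 2: gather per-batch histories from disk
ensemble = load_merge_histories(history_paths)

# Layer 1: pure computation over the grouped data
normalized = normalize_ensemble(ensemble)
distances = compute_distances(normalized)

# One value object carries everything downstream (plots, reports, caching)
results = ClusteringResults(
    histories=list(ensemble),  # assumes the ensemble iterates over MergeHistory
    normalized=normalized,
    distances=distances,
    config=config,
)
```
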
### Step 4: Modern Parallelism Strategy

**Problem**: The current subprocess + FD approach is complex and fragile

**Solution**: Use Python's `multiprocessing.Pool` for clean parallel execution

**Current approach (complex):**
```python
# 100+ lines of subprocess management, FD passing, JSON serialization
proc, json_r = launch_child_with_json_fd(cmd)
result = _read_json_result(json_r, dataset_path)
```

**New approach (simple):**
```python
from multiprocessing import Pool
from pathlib import Path

import torch

class BatchProcessor:
    def __init__(self, n_workers: int = 4, devices: list[str] | None = None):
        self.n_workers = n_workers
        self.devices = devices or ["cuda:0"]

    def process_all_batches(
        self,
        batches: list[Path],
        config: MergeRunConfig,
    ) -> list[MergeHistory]:
        # Create worker arguments with round-robin device assignment
        worker_args = [
            (batch, config, self.devices[i % len(self.devices)])
            for i, batch in enumerate(batches)
        ]

        with Pool(self.n_workers) as pool:
            histories = pool.starmap(self._process_single_batch, worker_args)

        return histories

    @staticmethod
    def _process_single_batch(
        batch_path: Path,
        config: MergeRunConfig,
        device: str,
    ) -> MergeHistory:
        """Runs in worker process - pure computation, no callbacks."""
        # Load model and data
        model = ComponentModel.from_pretrained(config.model_path).to(device)
        batch_data = torch.load(batch_path)

        # Extract and process activations (Layer 2)
        activations = extract_and_process_activations(model, batch_data, config)

        # Run clustering (Layer 1 - pure computation)
        history = merge_iteration(activations, config)

        return history  # Automatically serialized by multiprocessing
```

**Benefits:**
- **~90% code reduction** for parallel execution
- **Native Python** - no shell commands or FD management
- **Automatic serialization** - Python handles `MergeHistory` objects
- **Clean error propagation** - exceptions bubble up properly
- **Easy debugging** - can run single-threaded with `n_workers=1`
- **GPU isolation** - each process gets a separate CUDA context

## Target Architecture

### Core Principles
1. **Preserve tensor math exactly** - don't touch the core algorithms
2. **Layer separation** - pure computation → I/O → orchestration
3. **Clean interfaces** - functions do one thing well
4. **Modern Python** - use multiprocessing, dataclasses, type hints
5. **Testable** - pure functions can be tested in isolation

### Replacing Subprocess Communication with multiprocessing.Pool

**Problem**: Current complex subprocess + FD communication system

**Current approach (100+ lines):**
- `launch_child_with_json_fd()` - complex FD setup
- `distribute_clustering()` - manual process management
- `_read_json_result()` - FD parsing
- Error-prone cross-platform FD handling

**New approach (10 lines):**
```python
with multiprocessing.Pool(n_workers) as pool:
    worker_args = [(batch, config) for batch in batches]
    histories = pool.starmap(process_single_batch, worker_args)
```

**Benefits:**
- **Native Python** - no shell commands or FD management
- **Automatic serialization** - Python handles data passing
- **Better error handling** - exceptions propagate properly
- **Still bypasses the GIL** - each worker is a separate Python interpreter
- **GPU isolation** - each process has a separate CUDA context
- **Simpler debugging** - can run single-threaded easily

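To make the error-handling point concrete, a self-contained toy (batch names and the simulated failure are invented): a worker exception is re-raised in the parent when `starmap` collects results, so failures surface as ordinary Python exceptions rather than nonzero exit codes.

```python
import multiprocessing

def process_single_batch(batch: str, config: dict) -> str:
    if batch == "batch_2":
        raise ValueError(f"bad data in {batch}")  # simulate a worker failure
    return f"history for {batch}"

if __name__ == "__main__":
    batches = ["batch_1", "batch_2", "batch_3"]
    try:
        with multiprocessing.Pool(2) as pool:
            histories = pool.starmap(process_single_batch, [(b, {}) for b in batches])
    except ValueError as err:
        # The worker's exception propagates here, unlike a silent FD/exit-code failure
        print(f"clustering failed: {err}")
```
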
## Implementation Strategy

**⚠️ CRITICAL: Preserve Core Math**
- Do NOT modify functions in `spd/clustering/math/`
- Do NOT modify core tensor operations in `merge.py`, `compute_costs.py`
- These contain complex mathematical algorithms we don't fully understand
- Only refactor the **orchestration and I/O layers** around the math

**Implementation Order:**
1. **Step 1**: Replace FD communication (low risk)
2. **Step 2**: Extract pure computation interfaces (medium risk, but only interface changes)
3. **Step 3**: Create value objects (low risk)
4. **Step 4**: Replace subprocess with multiprocessing (low risk)

**Safety Principles:**
- Keep all existing math functions unchanged
- Preserve existing test behavior exactly
- Create new interfaces that **wrap** existing functions rather than modifying them (see the sketch below)
- Extensive testing at each step

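A sketch of that wrapping pattern (the legacy import path, its parameter names, and the `config` fields are hypothetical stand-ins for the real function):

```python
from spd.clustering.merge import merge_iteration as _legacy_merge_iteration  # hypothetical path

def merge_iteration(
    activations: ProcessedActivations,
    config: MergeConfig,
) -> MergeHistory:
    """New simplified interface; delegates to the untouched legacy function."""
    return _legacy_merge_iteration(
        activations.activations,
        labels=activations.labels,  # hypothetical legacy parameter
        alpha=config.alpha,         # hypothetical legacy parameter
        # ... remaining legacy parameters forwarded from `config` ...
    )
```
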
## Implementation Status

### ✅ Completed Refactoring

We successfully implemented a clean 3-layer architecture:

**Files Created:**
- `spd/clustering/core.py` - Pure computation layer
- `spd/clustering/data_processing.py` - I/O and data transformation
- `spd/clustering/orchestration.py` - Parallel execution with multiprocessing
- `spd/clustering/main_refactored.py` - Backward-compatible CLI wrapper
- `spd/clustering/test_refactoring.py` - Validation tests

### Key Achievements

1. **90% code reduction** in parallel execution (100+ lines → 10 lines)
2. **Clean separation** of computation, I/O, and orchestration
3. **Preserved exact tensor math** - core algorithms untouched
4. **Modern Python** - multiprocessing.Pool instead of subprocess+FDs
5. **Simplified interfaces** - `merge_iteration()` reduced from 8 to 3 parameters
6. **Backward compatible** - can serve as a drop-in replacement

### Results

**Before (scripts/main.py):**
- 370 lines of complex orchestration
- 100+ lines for subprocess/FD management
- Mixed concerns (computation + I/O + parallelism)
- Hard to test and debug

**After (orchestration.py):**
- ~150 lines total for the entire orchestration
- ~10 lines for parallel execution
- Clean layer separation
- Easy to test each layer independently

The refactored code produces functionally identical results while being much cleaner and more maintainable.

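One way `test_refactoring.py` could back that claim is a direct parity check between the two code paths (the fixtures and the `run_legacy_pipeline` helper are illustrative, not the real test file):

```python
def test_refactored_matches_legacy(tmp_path):
    """Same batch + same config should yield identical merge histories."""
    config = make_test_config(tmp_path)  # hypothetical fixture
    batch = make_toy_batch(tmp_path)     # hypothetical fixture

    legacy = run_legacy_pipeline(batch, config)     # old scripts/main.py path
    refactored = process_data_batch(config, batch)  # new Layer 2 entry point

    assert legacy == refactored
```
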
## Implementation Notes

- Keep backward compatibility where possible
- Maintain existing CLI interfaces during transition
- Add comprehensive tests for new components
- **DO NOT TOUCH THE MATH** - only refactor around it
- Preserve all existing functionality while improving structure