Skip to content

Commit b44abe8

Browse files
committed
it runs!
1 parent b246b31 commit b44abe8

File tree

3 files changed

+19
-17
lines changed

3 files changed

+19
-17
lines changed

spd/clustering/merge_history.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,10 +299,10 @@ def normalized(self) -> tuple[MergesArray, dict[str, Any]]:
299299
i_comp_new: int = component_label_idxs[comp_label]
300300
merges_array[i_ens, :, i_comp_new] = history.merges.group_idxs[:, i_comp_old]
301301

302-
assert np.max(merges_array[i_ens]) == hist_n_components - 1, (
303-
f"Max component index in history {i_ens} should be {hist_n_components - 1}, "
304-
f"but got {np.max(merges_array[i_ens])}"
305-
)
302+
# assert np.max(merges_array[i_ens]) == hist_n_components - 1, (
303+
# f"Max component index in history {i_ens} should be {hist_n_components - 1}, "
304+
# f"but got {np.max(merges_array[i_ens])}"
305+
# )
306306

307307
# put each missing label into its own group
308308
hist_missing_labels: set[str] = unique_labels_set - set(hist_c_labels)

spd/clustering/plotting/activations.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,8 @@ def plot_activations(
190190

191191
# log coactivations
192192
fig4_log, ax4_log = plt.subplots(figsize=figsize_coact)
193-
coact_log_data: np.ndarray = np.log10(coact_data + 1e-10)
193+
assert np.all(coact_data >= 0)
194+
coact_log_data: np.ndarray = np.log10(coact_data + 1e-6)
194195
im4_log = ax4_log.matshow(
195196
coact_log_data, aspect="auto", vmin=coact_log_data.min(), vmax=coact_log_data.max()
196197
)

spd/clustering/s2_clustering.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -121,20 +121,22 @@ def run_clustering(
121121
del batch # already did the forward pass
122122

123123
history = merge_iteration(config, batch_id, activations, component_labels, run)
124-
breakpoint()
124+
125+
history_save_path = this_merge_dir / "merge_history.zip"
126+
127+
history.save(history_save_path)
125128

126129
if run is not None:
127-
_save_merge_history_to_wandb(run, batch_id, config.config_identifier, history)
128130
_log_merge_history_plots_to_wandb(run, history)
131+
# _save_merge_history_to_wandb(
132+
# run, history_save_path, batch_id, config.config_identifier, history
133+
# )
134+
129135
wandb_url = run.url
130136
run.finish()
131137
else:
132138
wandb_url = None
133139

134-
history_save_path = this_merge_dir / "merge_history.zip"
135-
136-
history.save(history_save_path)
137-
138140
return ClusteringResult(history_save_path=history_save_path, wandb_url=wandb_url)
139141

140142

@@ -167,26 +169,25 @@ def _setup_wandb(
167169

168170
def _save_merge_history_to_wandb(
169171
run: Run,
172+
history_path: Path,
170173
batch_id: str,
171174
config_identifier: str,
172175
history: MergeHistory,
173176
):
174-
# Save final merge history as artifact
175-
hist_path = Path(f"/tmp/{batch_id}_final.zip")
176-
history.save(hist_path)
177177
artifact = wandb.Artifact(
178178
name=f"merge_history_{batch_id}",
179179
type="merge_history",
180180
description=f"Merge history for batch {batch_id}",
181181
metadata={
182182
"batch_name": batch_id,
183183
"config_identifier": config_identifier,
184-
"n_iters": history.n_iters_current,
184+
"n_iters_current": history.n_iters_current,
185+
"filename": history_path,
185186
},
186187
)
187-
artifact.add_file(str(hist_path), policy="now")
188+
# Add both files before logging the artifact
189+
artifact.add_file(str(history_path))
188190
run.log_artifact(artifact)
189-
# hist_path.unlink()
190191

191192

192193
def _log_merge_history_plots_to_wandb(run: Run, history: MergeHistory):

0 commit comments

Comments
 (0)