Skip to content

Commit

Permalink
Merge pull request #19 from MPI-Dortmund/resursive-umap-fix
Browse files Browse the repository at this point in the history
Fix a bug when calculating recursive umaps
  • Loading branch information
thorstenwagner authored Nov 30, 2023
2 parents ea501de + a6f1f2e commit 6dcbd4a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
8 changes: 4 additions & 4 deletions src/napari_tomotwin/_tests/test_make_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_make_targets_single_cluster_medoid(self):
"2": [5, 6, 7],
"filepath": ["a.mrc","b.mrc","c.mrc"]
}
cluster = np.array([1,1,1])
cluster = pd.Series(np.array([1,1,1]))
with tempfile.TemporaryDirectory() as tmpdirname:
_run(clusters=cluster,
embeddings=pd.DataFrame(fake_embedding),
Expand Down Expand Up @@ -47,7 +47,7 @@ def test_make_targets_two_clusters_medoid(self):
"2": [5, 6, 7, 10, 11, 12],
}
fake_embedding['filepath'] = [f"{i}.mrc" for i in range(len(fake_embedding["X"]))]
cluster = np.array([1,1,1,2,2,2])
cluster = pd.Series(np.array([1,1,1,2,2,2]))
with tempfile.TemporaryDirectory() as tmpdirname:
_run(clusters=cluster,
embeddings=pd.DataFrame(fake_embedding),
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_make_targets_single_cluster_average(self):
"2": [5, 6, 7],
"filepath": ["a.mrc","b.mrc","c.mrc"]
}
cluster = np.array([1,1,1])
cluster = pd.Series(np.array([1,1,1]))
with tempfile.TemporaryDirectory() as tmpdirname:
_run(clusters=cluster,
embeddings=pd.DataFrame(fake_embedding),
Expand All @@ -109,7 +109,7 @@ def test_make_targets_single_cluster_no_coords_written(self):
"2": [5, 6, 7],
"filepath": ["a.mrc","b.mrc","c.mrc"]
}
cluster = np.array([1,1,1])
cluster = pd.Series(np.array([1,1,1]))
with tempfile.TemporaryDirectory() as tmpdirname:
_run(clusters=cluster,
embeddings=pd.DataFrame(fake_embedding),
Expand Down
16 changes: 11 additions & 5 deletions src/napari_tomotwin/make_targets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pathlib
import sys
from typing import List, Tuple, Literal, Callable

import numpy as np
Expand Down Expand Up @@ -40,13 +41,18 @@ def _make_targets(embeddings: pd.DataFrame, clusters: pd.DataFrame, avg_func: Ca
target_locations = {

}
print(clusters)
print(set(clusters))
for cluster in set(clusters):

if cluster == 0:
continue
cluster_embeddings = embeddings.loc[clusters == cluster, :]
clmask = (clusters == cluster).to_numpy()

cluster_embeddings = embeddings.loc[clmask, :]
target, position = avg_func(cluster_embeddings)
target_locations[cluster] = position
sub_embeddings.append(embeddings.loc[clusters == cluster, :])
sub_embeddings.append(embeddings.loc[clmask, :])
target = target.to_frame().T
targets.append(target)
target_names.append(f"cluster_{cluster}")
Expand All @@ -61,15 +67,15 @@ def _run(clusters,
output_folder: pathlib.Path,
average_method_name: Literal["Average", "Medoid"] = "Medoid",
):
assert len(embeddings) == len(clusters), "Cluster and embedding file are not compatible."
assert len(embeddings) == len(clusters), "Cluster and embedding file are not compatible. They have a different number of embeddings"

avg_method = _get_medoid_embedding
if average_method_name == "Average":
avg_method = _get_avg_embedding

print("Make targets")
embeddings = embeddings.reset_index()

#embeddings = embeddings.reset_index()
embeddings = embeddings.drop(columns=["level_0","index"], errors="ignore")
targets, sub_embeddings, target_locations = _make_targets(embeddings, clusters, avg_func=avg_method)

print("Write targets")
Expand Down

0 comments on commit 6dcbd4a

Please sign in to comment.