Skip to content

Commit

Permalink
Merge pull request #114 from sagar87/fix/add_labels_from_dataframe
Browse files Browse the repository at this point in the history
Bugfix: add_labels_from_dataframe
  • Loading branch information
MeyerBender authored Jan 9, 2025
2 parents 5ecc33b + 3375442 commit fbf9629
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
13 changes: 6 additions & 7 deletions spatialproteomics/la/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,7 @@ def add_labels_from_dataframe(
assert (
Layers.NEIGHBORHOODS not in self._obj
), f"Already found neighborhoods in the object. Since these are dependent on the labels, please remove them with pp.drop_layers('{Layers.NEIGHBORHOODS}') before adding new labels."
assert self._obj.coords[Dims.CELLS].shape[0] != 0, "No cells found in the object. Cannot add labels."

if df is None:
cells = self._obj.coords[Dims.CELLS].values
Expand All @@ -835,6 +836,11 @@ def add_labels_from_dataframe(
unique_labels = np.unique(formated_labels)
else:
sub = df.loc[:, [cell_col, label_col]].dropna()
# removing cells that are not in the object
sub = sub[sub[cell_col].isin(self._obj.coords[Dims.CELLS].values)]
assert (
sub.shape[0] != 0
), f"Could not find any overlap between the cells in the data frame's {cell_col} column and the cells in the object. Please make sure the cells for which you want to add labels are present in the object."
cells = sub.loc[:, cell_col].to_numpy().squeeze()
labels = sub.loc[:, label_col].to_numpy().squeeze()

Expand Down Expand Up @@ -868,13 +874,6 @@ def add_labels_from_dataframe(
name=Layers.OBS,
)

da = da.where(
da.coords[Dims.CELLS].isin(
self._obj.coords[Dims.CELLS],
),
drop=True,
)

obj = self._obj.copy()
obj = xr.merge([obj.sel(cells=da.cells), da])

Expand Down
18 changes: 18 additions & 0 deletions tests/la/test_add_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,24 @@ def test_add_labels_from_dataframe_unassigned_cells(dataset):
assert 1 in labeled[Layers.OBS].sel(features=Features.LABELS).values


def test_add_labels_from_dataframe_invalid_cells(dataset):
# creating a dummy data frame
cells = dataset.coords[Dims.CELLS].values
num_cells = len(cells)
df = pd.DataFrame(
{
"cell": [num_cells + 5],
"label": ["CT1"],
}
)

with pytest.raises(
AssertionError,
match="Could not find any overlap between the cells in the data frame",
):
dataset.la.add_labels_from_dataframe(df)


def test_add_labels(dataset):
# creating a dummy dict
cells = dataset.coords[Dims.CELLS].values
Expand Down

0 comments on commit fbf9629

Please sign in to comment.