Skip to content

Commit

Permalink
Update ref_index test
Browse files Browse the repository at this point in the history
  • Loading branch information
johnarevalo committed Feb 5, 2025
1 parent c3aecb9 commit bf471d7
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/copairs/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def average_precision(rel_k) -> np.ndarray:
num_pos = rel_k.shape[1]
pr_k = np.arange(1, num_pos + 1, dtype=np.float32) / (rel_k + 1)
ap_values = pr_k.sum(axis=1) / num_pos
return ap_values
return ap_values.astype(np.float32)


def ap_contiguous(
Expand Down Expand Up @@ -399,7 +399,7 @@ def random_ap(num_perm: int, num_pos: int, total: int, seed: int):

# Compute Average Precision (AP) scores for each row of the binary matrix
null_dist = average_precision(rel_k)
return null_dist.astype(np.float32)
return null_dist


def null_dist_cached(
Expand Down
5 changes: 4 additions & 1 deletion src/copairs/map/map.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Optional

import numpy as np
import pandas as pd
Expand All @@ -16,7 +17,7 @@ def mean_average_precision(
null_size: int,
threshold: float,
seed: int,
max_workers: int = 32,
max_workers: Optional[int] = None,
) -> pd.DataFrame:
"""Calculate the Mean Average Precision (mAP) score and associated p-values.
Expand All @@ -38,6 +39,8 @@ def mean_average_precision(
p-value threshold for identifying significant MaP scores.
seed : int
Random seed for reproducibility.
max_workers : int
Number of workers used. Default defined by tqdm's `thread_map`
Returns:
-------
Expand Down
13 changes: 9 additions & 4 deletions tests/test_reference_index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Tests for assign reference index helper function."""

import pytest
import numpy as np
import pandas as pd

Expand All @@ -8,21 +8,26 @@
from tests.helpers import simulate_random_dframe


@pytest.mark.filterwarnings("ignore:invalid value encountered in divide")
def test_assign_reference_index():
SEED = 42
length = 20
vocab_size = {"p": 5, "w": 3, "l": 2}
length = 200
vocab_size = {"p": 5, "w": 3, "l": 4}
n_feats = 5
pos_sameby = ["l"]
pos_diffby = []
neg_sameby = []
neg_diffby = ["l"]
rng = np.random.default_rng(SEED)
meta = simulate_random_dframe(length, vocab_size, pos_sameby, pos_diffby, rng)
# p: Plate, w: Well, l: PerturbationID, t: PerturbationType (is control?)
meta.eval("t=(l=='l1')", inplace=True)
length = len(meta)
feats = rng.uniform(size=(length, n_feats))

ap = average_precision(meta, feats, pos_sameby, pos_diffby, neg_sameby, neg_diffby)
ap = average_precision(
meta, feats, pos_sameby + ["t"], pos_diffby, neg_sameby, neg_diffby + ["t"]
)

ap_ri = average_precision(
assign_reference_index(meta, "l=='l1'"),
Expand Down

0 comments on commit bf471d7

Please sign in to comment.