🎨 Make search consistent with the lamindb implementation #95

Merged · 12 commits · Nov 22, 2024
183 changes: 81 additions & 102 deletions lamin_utils/_search.py
@@ -1,129 +1,108 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal

from lamin_utils import logger
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import pandas as pd
from pandas import DataFrame, Series


def _contains(col: Series, string: str, case_sensitive: bool, fields_convert: dict):
if col.name not in fields_convert:
return [False] * len(col)
if fields_convert[col.name]:
col = col.astype(str)
return col.str.contains(string, case=case_sensitive)


def _ranks(col: Series, string: str, case_sensitive: bool, fields_convert: dict):
if col.name not in fields_convert:
return [0] * len(col)
if fields_convert[col.name]:
col = col.astype(str)
exact_rank = col.str.fullmatch(string, case=case_sensitive) * 200
synonym_rank = (
col.str.match(rf"(?:^|.*\|){string}(?:\|.*|$)", case=case_sensitive) * 200
)
sub_rank = (
col.str.match(
rf"(?:^|.*[ \|\.,;:]){string}(?:[ \|\.,;:].*|$)", case=case_sensitive
)
* 10
)
startswith_rank = (
col.str.match(rf"(?:^|.*\|){string}[^ ]*(?:\|.*|$)", case=case_sensitive) * 8
)
right_rank = col.str.match(rf"(?:^|.*[ \|]){string}.*", case=case_sensitive) * 2
left_rank = col.str.match(rf".*{string}(?:$|[ \|\.,;:].*)", case=case_sensitive) * 2
contains_rank = col.str.contains(string, case=case_sensitive).astype("int32")
return (
exact_rank
+ synonym_rank
+ sub_rank
+ startswith_rank
+ right_rank
+ left_rank
+ contains_rank
)


def search(
df: pd.DataFrame,
df: DataFrame,
string: str,
field: str = "name",
*,
field: str | list[str] | None = None,
limit: int | None = 20,
synonyms_field: str | None = "synonyms",
case_sensitive: bool = False,
synonyms_sep: str = "|",
keep: Literal["first", "last", False] = "first",
) -> pd.DataFrame:
_show_rank: bool = False,
) -> DataFrame:
"""Search a given string against a field.

Args:
df: The DataFrame to search in.
string: The input string to match against the field values.
field: The name of the field to search against.
limit: The maximum number of top results to return. If None, returns all results.
synonyms_field: The name of the field containing synonyms.
If None, no synonym matching is performed.
case_sensitive: Whether the match should be case sensitive. Defaults to False.
synonyms_sep: The separator used in the synonyms field.
keep: Determines which duplicates to keep when grouping results.
Options are "first", "last", or False (keep all).
field: The field or fields to search. Search all fields containing strings by default.
limit: Maximum amount of top results to return.
case_sensitive: Whether the match is case sensitive.

Returns:
A DataFrame of ranked search results.
This DataFrame contains the matched rows from the input DataFrame,
sorted by the match ratio in descending order.
It includes all columns from the input DataFrame plus an additional '__ratio__' column indicating the match score.
sorted by the match rank in descending order.

Raises:
KeyError: If the specified field or synonyms_field is not found in the DataFrame.
ValueError: If an invalid value is provided for the 'keep' parameter.
KeyError: If the specified field is not found in the DataFrame.
"""
import pandas as pd
from pandas.api.types import is_object_dtype, is_string_dtype

from ._map_synonyms import explode_aggregated_column_to_map

def _fuzz(
string: str,
iterable: pd.Series,
case_sensitive: bool = True,
limit: int | None = None,
):
from rapidfuzz import fuzz, process, utils

# use WRatio to account for typos
if " " in string:
scorer = fuzz.QRatio
else:
scorer = fuzz.WRatio

processor = None if case_sensitive else utils.default_process
results = process.extract(
string,
iterable,
scorer=scorer,
limit=limit,
processor=processor,
)
try:
return pd.DataFrame(results).set_index(2)[1]
except KeyError:
# no search results
return None

# empty DataFrame
if df.shape[0] == 0:
if len(df) == 0:
return df

# search against each of the synonyms
if (synonyms_field in df.columns) and (synonyms_field != field):
# creates field_value:synonym
mapper = explode_aggregated_column_to_map(
df,
agg_col=synonyms_field, # type:ignore
target_col=field,
keep=keep,
sep=synonyms_sep,
)
if keep is False:
mapper = mapper.explode()
# adds field_value:field_value to field_value:synonym
df_field = pd.Series(df[field].values, index=df[field], name=field)
df_field.index.name = synonyms_field
df_field = df_field[df_field.index.difference(mapper.index)]
mapper = pd.concat([mapper, df_field])
df_exp = mapper.reset_index()
target_column = synonyms_field
fields_convert = {}
if field is None:
fields = df.columns.to_list()
for f in fields:
df_f = df[f]
if is_object_dtype(df_f):
fields_convert[f] = True
elif is_string_dtype(df_f):
fields_convert[f] = False
else:
if synonyms_field == field:
logger.warning(
"Input field is the same as synonyms field, skipping synonyms matching"
)
df_exp = df[[field]].copy()
target_column = field

# add matching scores as a __ratio__ column
ratios = _fuzz(
string=string,
iterable=df_exp[target_column],
case_sensitive=case_sensitive,
limit=limit,
)
if ratios is None:
return pd.DataFrame(columns=df.columns)
df_exp["__ratio__"] = ratios

if limit is not None:
df_exp = df_exp[~df_exp["__ratio__"].isna()]
# only keep the max score between field and synonyms for each entry
# here groupby is also used to remove duplicates of field values
df_exp_grouped = df_exp.groupby(field).max("__ratio__")
# subset to original field values (as synonyms were mixed in before)
df_exp_grouped = df_exp_grouped[df_exp_grouped.index.isin(df[field])]
df_scored = df.set_index(field).loc[df_exp_grouped.index]
df_scored["__ratio__"] = df_exp_grouped["__ratio__"]

return df_scored.sort_values("__ratio__", ascending=False)
fields = [field] if isinstance(field, str) else field
for f in fields:
fields_convert[f] = not is_string_dtype(df[f])

contains = lambda col: _contains(col, string, case_sensitive, fields_convert)
df_contains = df.loc[df.apply(contains).any(axis=1)]
if len(df_contains) == 0:
return df_contains

ranks = lambda col: _ranks(col, string, case_sensitive, fields_convert)
rank = df_contains.apply(ranks).sum(axis=1)

if _show_rank:
df_contains = df_contains.copy()
df_contains.loc[:, "rank"] = rank

df_result = df_contains.loc[rank.sort_values(ascending=False).index]
return df_result if limit is None else df_result.head(limit)
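For orientation, here is a minimal sketch of how the reworked `search` might be called after this change. The toy DataFrame mirrors the cell-type fixture in `tests/test_search.py` below; the import path follows the diffed module `lamin_utils/_search.py`, and `_show_rank` is the private debugging flag introduced in this PR.

```python
import pandas as pd

from lamin_utils._search import search

# toy frame mirroring the cell-type fixture in tests/test_search.py
df = pd.DataFrame(
    {
        "name": ["T cell", "B cell", "PP cell", "nodal myocyte"],
        "synonyms": [
            "T-cell|T lymphocyte|T-lymphocyte",
            "B lymphocyte|B-lymphocyte|B-cell",
            "type F enteroendocrine cell",
            "cardiac pacemaker cell|myocytus nodalis|P cell",
        ],
    }
)

# search all string columns (the default when field is None)
print(search(df, "P cell"))

# restrict the search to specific fields and cap the number of results
print(search(df, "type F enteroendocrine", field=["synonyms"], limit=5))

# expose the integer rank that results are sorted by
print(search(df, "b cell", case_sensitive=False, _show_rank=True))
```

Unlike the previous rapidfuzz-based version, the result keeps the original index and columns and is ordered by the summed regex rank rather than a fuzzy `__ratio__` score.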
1 change: 0 additions & 1 deletion pyproject.toml
@@ -23,7 +23,6 @@ Home = "https://github.com/laminlabs/lamin-utils"
[project.optional-dependencies]
dev = [
"pandas", # lookup
"rapidfuzz", # search
"pre-commit",
"nox",
"pytest>=6.0",
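With rapidfuzz removed from the dev extras, ranking is now a weighted sum of pandas regex matches. As an illustration of the weights in `_ranks` above, the rank of 223 asserted in `test_search_general` below decomposes as shown in this sketch; calling the internal helper directly on the synonyms of the "nodal myocyte" fixture record is an assumption about usage for illustration, not public API.

```python
import pandas as pd

from lamin_utils._search import _ranks

# synonyms of the "nodal myocyte" record from the test fixture below
synonyms = pd.Series(
    ["cardiac pacemaker cell|myocytus nodalis|P cell"], name="synonyms"
)
rank = _ranks(
    synonyms, "P cell", case_sensitive=False, fields_convert={"synonyms": False}
)
# 200 (synonym) + 10 (delimited sub-match) + 8 (startswith) + 2 (right)
#   + 2 (left) + 1 (contains) = 223; the exact fullmatch (200) does not fire
print(int(rank.iloc[0]))  # 223
```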
79 changes: 49 additions & 30 deletions tests/test_search.py
@@ -10,73 +10,92 @@ def df():
"ontology_id": "CL:0000084",
"name": "T cell",
"synonyms": "T-cell|T lymphocyte|T-lymphocyte",
"description": "A Type Of Lymphocyte Whose Defining Characteristic Is The Expression Of A T Cell Receptor Complex.",
"children": ["CL:0000798", "CL:0002420", "CL:0002419", "CL:0000789"],
},
{
"ontology_id": "CL:0000236",
"name": "B cell",
"synonyms": "B lymphocyte|B-lymphocyte|B-cell",
"description": "A Lymphocyte Of B Lineage That Is Capable Of B Cell Mediated Immunity.",
"children": ["CL:0009114", "CL:0001201"],
},
{
"ontology_id": "CL:0000696",
"name": "PP cell",
"synonyms": "type F enteroendocrine cell",
"description": "A Cell That Stores And Secretes Pancreatic Polypeptide Hormone.",
"children": ["CL:0002680"],
},
{
"ontology_id": "CL:0002072",
"name": "nodal myocyte",
"synonyms": "cardiac pacemaker cell|myocytus nodalis|P cell",
"description": "A Specialized Cardiac Myocyte In The Sinoatrial And Atrioventricular Nodes. The Cell Is Slender And Fusiform Confined To The Nodal Center, Circumferentially Arranged Around The Nodal Artery.",
"children": ["CL:1000409", "CL:1000410"],
},
]
return pd.DataFrame.from_records(records)


def test_search_synonyms(df):
res = search(df=df, string="P cells")
assert res.index[0] == "nodal myocyte"

# without synonyms search
res = search(df=df, synonyms_field=None, string="P cells")
assert res.index[0] == "PP cell"
# these tests also check the ranks of the searched values (res["rank"] below)
# this is needed to perform cross-check with lamindb search
# to recompute the ranks via lamindb
# change .alias to .annotate in lamindb/_record.py def _search(...)
# then run the code below in an empty instance with bionty schema
# import lamindb as ln
# import bionty as bt
# cts = ["CL:0000084", "CL:0000236", "CL:0000696", "CL:0002072"]
# ln.save([bt.CellType.from_source(ontology_id=oid) for oid in cts])
# results = bt.CellType.search("P cell")
# print([(result.name, result.rank) for result in results.list()])
# results = bt.CellType.search("b cell")
# print([(result.name, result.rank) for result in results.list()])
# results = bt.CellType.search("type F enteroendocrine", field="synonyms")
# print([(result.name, result.rank) for result in results.list()])


def test_search_general(df):
res = search(df=df, string="P cell", _show_rank=True)
assert res.iloc[0]["name"] == "nodal myocyte"
assert res.iloc[0]["rank"] == 223
assert len(res) == 2
assert res.iloc[1]["rank"] == 3

# search in name, without synonyms search
res = search(df=df, string="P cell", field="name", _show_rank=True)
assert res.iloc[0]["name"] == "PP cell"
assert res.iloc[0]["rank"] == 3


def test_search_limit(df):
res = search(df=df, string="P cells", limit=1)
res = search(df=df, string="P cell", limit=1)
assert res.shape[0] == 1


def test_search_keep(df):
# TODO: better test here
res = search(df=df, string="enteroendocrine", keep=False)
assert res.index[0] == "PP cell"


def test_search_return_df(df):
res = search(df=df, string="P cells")
assert res.shape == (4, 4)
assert res.iloc[0].name == "nodal myocyte"
res = search(df=df, string="P cell")
assert res.shape == (2, 5)
assert res.iloc[0]["name"] == "nodal myocyte"


def test_search_return_tie_results(df):
res = search(df=df, string="A cell", synonyms_field=None)
assert res.iloc[0].__ratio__ == res.iloc[1].__ratio__


def test_search_non_default_field(df):
res = search(df=df, string="type F enteroendocrine", field="synonyms")
assert res.index[0] == "type F enteroendocrine cell"
def test_search_pass_fields(df):
res = search(
df=df,
string="type F enteroendocrine",
field=["synonyms", "children"],
_show_rank=True,
)
assert res.iloc[0]["synonyms"] == "type F enteroendocrine cell"
assert res.iloc[0]["rank"] == 15


def test_search_case_sensitive(df):
res = search(df=df, string="b cell", case_sensitive=True)
assert res.iloc[0].__ratio__ < 100

res = search(df=df, string="b cell", case_sensitive=False)
assert res.index[0] == "B cell"
assert res.iloc[0].__ratio__ == 100
assert len(res) == 0
res = search(df=df, string="b cell", case_sensitive=False, _show_rank=True)
assert res.iloc[0]["name"] == "B cell"
assert res.iloc[0]["rank"] == 438


def test_search_empty_df():
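Similarly, the case-insensitive rank of 438 asserted in `test_search_case_sensitive` can be reproduced with just the "B cell" record. The reduced frame below is a sketch: the fixture's `children` column is dropped because it contributes nothing to the rank for this query.

```python
import pandas as pd

from lamin_utils._search import search

# "B cell" record from the fixture, without the children column
b_cell = pd.DataFrame(
    [
        {
            "ontology_id": "CL:0000236",
            "name": "B cell",
            "synonyms": "B lymphocyte|B-lymphocyte|B-cell",
            "description": (
                "A Lymphocyte Of B Lineage That Is Capable Of "
                "B Cell Mediated Immunity."
            ),
        }
    ]
)

res = search(b_cell, "b cell", case_sensitive=False, _show_rank=True)
# name "B cell": 200 (exact) + 200 (synonym) + 10 + 8 + 2 + 2 + 1 = 423
# description "... B Cell Mediated ...": 10 + 2 + 2 + 1           =  15
print(res.iloc[0]["rank"])  # 423 + 15 = 438
```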