Merge pull request #160 from databio/dev_search_merge
Update on search & eval
khoroshevskyi authored Jun 6, 2024
2 parents fae0918 + 1a21a48 commit 0029b26
Showing 24 changed files with 213 additions and 374 deletions.
1 change: 0 additions & 1 deletion geniml/cli.py
@@ -153,7 +153,6 @@ def main(test_args=None):
        bbc.add_bed_tokens_to_cache(args.bed_id[0], args.universe_id[0])

    if args.subcommand == "cache-bedset":
-
        if os.path.isdir(args.identifier[0]):
            from .io import BedSet

6 changes: 3 additions & 3 deletions geniml/eval/npt.py
@@ -157,6 +157,9 @@ def cal_snpr(ratio_embed: np.ndarray, ratio_random: np.ndarray) -> np.ndarray:
    return res


+var_dict = {}
+
+
def worker_func(
    i: int,
    K: int,
@@ -181,8 +184,6 @@ def worker_func(
    Returns:
        np.ndarray: An array of overlap ratios.
    """
-    var_dict = {}
-
    if embed_type == "embed":
        embeds = var_dict["embed_rep"]
    elif embed_type == "random":
@@ -210,7 +211,6 @@ def init_worker(
        ref_embed (np.ndarray): Random embeddings.
        region2index (dict[str, int]): A region to index dictionary.
    """
-    var_dict = {}
    var_dict["embed_rep"] = embed_rep
    var_dict["ref_embed"] = ref_embed
    var_dict["region2vec_index"] = region2index
45 changes: 35 additions & 10 deletions geniml/eval/utils.py
@@ -1,9 +1,12 @@
+import os
import pickle
from typing import Dict, List, Tuple, Union

import numpy as np
from gensim.models import Word2Vec

+from ..region2vec import Region2VecExModel
+

def genome_distance(u: Tuple[int, int], v: Tuple[int, int]) -> float:
    """Computes the genome distance between two regions.
@@ -153,22 +156,44 @@ def load_genomic_embeddings(
"""Loads genomic region embeddings based on the type.
Args:
model_path (str): The path to a saved model.
model_path (str): The path to a saved model, or a huggingface repo of a model.
embed_type (str, optional): The model type. Defaults to "region2vec".
Can be "region2vec" or "base".
Can be "region2vec", "base", or "huggingface".
Returns:
tuple[np.ndarray, list[str]]: Embedding vectors and the corresponding
region list.
"""
if embed_type == "region2vec":
model = Word2Vec.load(model_path)
regions_r2v = model.wv.index_to_key
embed_rep = model.wv.vectors
return embed_rep, regions_r2v
elif embed_type == "base":
embed_rep, regions_r2v = load_base_embeddings(model_path)
return embed_rep, regions_r2v
if os.path.exists(model_path):
# try to load local
if embed_type == "region2vec":
model = Word2Vec.load(model_path)
regions_r2v = model.wv.index_to_key
embed_rep = model.wv.vectors
return embed_rep, regions_r2v
elif embed_type == "base":
embed_rep, regions_r2v = load_base_embeddings(model_path)
return embed_rep, regions_r2v

else:
# try to load from huggingface
exmodel = Region2VecExModel(model_path)
embed_rep = exmodel.model.projection.weight.data.numpy()
regions_r2v = [region2vocab_modify(r) for r in exmodel.tokenizer.universe.regions]
# remove embeddings representing unknown token and padding token
return embed_rep[:-2], regions_r2v


def region2vocab_modify(region) -> str:
"""Convert a builtins.Region object to a string in the format of chr:start-end.
Args:
region (builtins.Region): A region stored in tokenizer
Returns:
str: region string in standardized format chr:start-end.
"""
return f"{region.chr}:{region.start}-{region.end}"


def sort_key(x: str) -> Tuple[int, int]:
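With this change, load_genomic_embeddings falls back to Hugging Face when the given path does not exist locally. A usage sketch; the model path and repo name below are hypothetical:

from geniml.eval.utils import load_genomic_embeddings

# local gensim Word2Vec model file (hypothetical path)
embed_rep, regions = load_genomic_embeddings("models/region2vec.model", embed_type="region2vec")

# a path that does not exist locally is treated as a Hugging Face model repo (hypothetical name)
embed_rep, regions = load_genomic_embeddings("databio/r2v-example-model")

print(embed_rep.shape, regions[:3])  # e.g. (N, 100) and ['chr1:100-200', ...]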
1 change: 0 additions & 1 deletion geniml/io/io.py
@@ -179,7 +179,6 @@ def __iter__(self):
        skipped_lines = 0
        max_skipped_lines = 5
        for line in f:
-
            try:
                chr, start, stop = line.split("\t")[:3]
            except ValueError as _:
3 changes: 2 additions & 1 deletion geniml/search/__init__.py
@@ -1,5 +1,6 @@
from .backends import HNSWBackend, QdrantBackend
-from .filebackend_tools import load_local_backend, merge_backends, vec_pairs
+from .filebackend_tools import merge_backends
from .interfaces import BED2BEDSearchInterface, Text2BEDSearchInterface
from .query2vec import BED2Vec, Text2Vec
+from .search_eval import anecdotal_search_from_hf_data
from .utils import rand_eval
27 changes: 23 additions & 4 deletions geniml/search/backends/filebackend.py
@@ -1,7 +1,10 @@
+import json
import os.path
+import pickle
from typing import Dict, List, Union

import hnswlib
+import yaml

from ... import _LOGGER

@@ -17,7 +20,6 @@
# )

import numpy as np
-
from geniml.search.const import (
    DEFAULT_DIM,
    DEFAULT_EF,
@@ -45,7 +47,7 @@ class HNSWBackend(EmSearchBackend):
    def __init__(
        self,
        local_index_path: str = DEFAULT_INDEX_PATH,
-        payloads: dict = {},
+        payloads: Union[dict, str] = dict(),
        space: str = DEFAULT_HNSW_SPACE,
        dim: int = DEFAULT_DIM,
        ef: int = DEFAULT_EF,
@@ -70,7 +72,24 @@ def __init__(
        if os.path.exists(local_index_path):
            self.idx.load_index(local_index_path)
            _LOGGER.info(f"Using index {local_index_path} with {self.idx.element_count} points.")
-            self.payloads = payloads

+            # load payloads:
+            if isinstance(payloads, str):
+                if payloads.endswith(".json"):
+                    with open(payloads, "r") as f:
+                        self.payloads = json.load(f)
+                elif payloads.endswith(".pkl"):
+                    self.payloads = pickle.load(open(payloads, "rb"))
+                elif payloads.endswith(".yaml"):
+                    with open(payloads, "r") as f:
+                        self.payloads = yaml.load(f, Loader=yaml.SafeLoader)
+
+                else:
+                    raise ValueError(
+                        f"payloads should be either a JSON, pickle, or YAML file. You supplied: {payloads.split('.')[-1]}"
+                    )
+            else:
+                self.payloads = payloads
+            # self.payloads = {}
        # save the index to local file path
        else:
@@ -113,7 +132,7 @@ def load(

        # update hnsw index and load embedding vectors
        self.idx.load_index(self.idx_path, max_elements=new_max)
-        self.idx.add_items(vectors, ids)
+        self.idx.add_items(vectors)

        # save hnsw index to local file
        self.idx.save_index(self.idx_path)
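HNSWBackend now accepts payloads either as an in-memory dict or as a path to a serialized file; note that, per the branch above, the file is only parsed when the index at local_index_path already exists and is being reloaded. A sketch with illustrative file names:

from geniml.search import HNSWBackend

# payloads passed inline as a dict
backend = HNSWBackend(local_index_path="index.bin", payloads={0: {"name": "A0001.bed"}}, dim=100)

# or as a path to a saved payloads file (.json, .pkl, or .yaml)
backend = HNSWBackend(local_index_path="index.bin", payloads="payloads.pkl", dim=100)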
15 changes: 14 additions & 1 deletion geniml/search/const.py
@@ -14,9 +14,16 @@

DEFAULT_DIM = 100

+
+# the size of the dynamic list for the nearest neighbors
+# higher ef leads to more accurate but slower search
+# cannot be set lower than the number of queried nearest neighbors k
DEFAULT_EF = 200

-DEFAULT_M = 16
+# the number of bi-directional links created for every new element during construction
+# higher M works better on datasets with high intrinsic dimensionality and/or high recall
+# lower M works better on datasets with low intrinsic dimensionality and/or low recall
+DEFAULT_M = 64

DEFAULT_QUANTIZATION_CONFIG = models.ScalarQuantization(
    scalar=models.ScalarQuantizationConfig(
@@ -25,3 +32,9 @@
        always_ram=True,
    ),
)
+
+
+# for evaluation dataset from huggingface
+HF_INDEX = "index.bin"
+HF_PAYLOADS = "payloads.pkl"
+HF_METADATA = "metadata.json"
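For context, this is how DEFAULT_EF and DEFAULT_M typically feed an hnswlib index; a minimal sketch, assuming cosine space and the 100-dimensional vectors implied by DEFAULT_DIM:

import hnswlib
import numpy as np

dim = 100  # DEFAULT_DIM
index = hnswlib.Index(space="cosine", dim=dim)

# M: bi-directional links per new element; ef_construction: build-time accuracy/speed trade-off
index.init_index(max_elements=10_000, ef_construction=200, M=64)

# ef: size of the dynamic candidate list at query time; must be >= k
index.set_ef(200)

index.add_items(np.random.rand(1_000, dim))
labels, distances = index.knn_query(np.random.rand(1, dim), k=10)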
166 changes: 0 additions & 166 deletions geniml/search/filebackend_tools.py
@@ -10,20 +10,6 @@
_LOGGER = logging.getLogger(PKG_NAME)


-def load_local_backend(bin_path: str, pkl_path: str, dim: int) -> HNSWBackend:
-    """
-    Load a HNSWBackend from the local index file and the saved payloads dictionary
-    :param bin_path: path of the index file (.bin)
-    :param pkl_path: path of saved payloads (.pkl)
-    :param dim: the dimension of vectors stored in the hnsw index
-    :return: the HNSWBackend from saved index and payloads
-    """
-    payloads = pickle.load(open(pkl_path, "rb"))
-    return HNSWBackend(local_index_path=bin_path, payloads=payloads, dim=dim)
-
-
def merge_backends(
    backends_to_merge: List[HNSWBackend], local_index_path: str, dim: int
) -> HNSWBackend:
@@ -53,155 +39,3 @@ def merge_backends(
    result_backend.load(vectors=np.array(result_vecs), payloads=result_payloads)

    return result_backend
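merge_backends (kept by this commit) combines several HNSW file backends into a single index and payload set. A hypothetical call, using the new payloads-as-path constructor from filebackend.py; all file names are illustrative:

from geniml.search import HNSWBackend
from geniml.search.filebackend_tools import merge_backends

shard1 = HNSWBackend(local_index_path="shard1.bin", payloads="shard1_payloads.pkl", dim=100)
shard2 = HNSWBackend(local_index_path="shard2.bin", payloads="shard2_payloads.pkl", dim=100)

merged = merge_backends([shard1, shard2], local_index_path="merged.bin", dim=100)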


-def sample_non_target_vec(
-    max_id: int, matching_ids: List[Union[int, np.int64]], size: int
-) -> List[Union[int, np.int64]]:
-    """
-    When the torch loss function for text2bednn training is CosineEmbeddingLoss,
-    the goal is to maximize the cosine similarity when the input metadata (embedding) matches
-    the input BED file (embedding), and minimize it otherwise. Therefore, besides target pairs
-    (matching metadata embedding and BED embedding), non-target pairs also need sampling.
-    This function samples ids of non-matching vectors in the backend.
-    :param max_id: maximum id = total number of vectors - 1
-    :param matching_ids: ids of matching vectors (target pairs)
-    :param size: number of vectors to sample
-    :return: a list of ids in the backend.
-    """
-
-    # sample range
-    if (size + len(matching_ids)) > max_id:
-        # _LOGGER.error("IndexError: Sample size + matching size should be below the maximum ID")
-        raise IndexError("Sample size + matching size should be below the maximum ID")
-
-    full_range = np.arange(0, max_id)
-
-    # skipping ids of matching vectors
-    eligible_integers = np.setdiff1d(full_range, matching_ids)
-    # sample result
-    sampled_integer = np.random.choice(eligible_integers, size=size, replace=False)
-
-    return list(sampled_integer)
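A quick illustration of the removed helper's contract, with toy numbers:

# 100 vectors total (ids 0..99); ids 3 and 7 are already matched as target pairs
negatives = sample_non_target_vec(max_id=100, matching_ids=[3, 7], size=4)
# four ids drawn from 0..99 excluding 3 and 7, e.g. [41, 12, 88, 5]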


-def reverse_payload(payload: Dict[np.int64, Dict], target_key: str) -> Dict[str, np.int64]:
-    """
-    Reverse the payload dictionary of a HNSWBackend, in this format: {<store id>: <metadata dictionary of that vector>}
-    :param payload: payload dictionary of a HNSWBackend, in this format:
-        {
-            <store id>: <metadata dictionary of that vector>,
-            ...
-        }
-    :param target_key: a key in the metadata dictionary.
-        For example, if the payload dictionary is:
-        {
-            1: {
-                "name": "A0001.bed",
-                "summary": <summary>,
-                ...
-            }
-        }
-        and target_key is "name", the output will be:
-        {
-            "A0001.bed": 1,
-        }
-    :return: the reversed payload dictionary
-    """
-    output_dict = dict()
-    for i in payload.keys():
-        output_dict[payload[i][target_key]] = i
-
-    return output_dict


-def vec_pairs(
-    nl_backend: HNSWBackend,
-    bed_backend: HNSWBackend,
-    nl_payload_key: str = "files",
-    bed_payload_key: str = "name",
-    non_target_pairs: bool = False,
-    non_target_pairs_prop: float = 1.0,
-    exclusions: Union[Set[str], None] = None,
-    exclusion_key: str = "series",
-) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    """
-    The training of geniml.text2bednn needs pairs of vectors (natural language embeddings & region set embeddings).
-    This function extracts vector pairs from the file backends that store embedding vectors of region sets (BED) and metadata.
-    The payloads of the BED backend must contain the file name of each embedding vector.
-    The payloads of the metadata backend must contain the names of the matching files for each metadata string.
-    :param nl_backend: backend where embedding vectors of natural language metadata are stored
-    :param bed_backend: backend where embedding vectors of BED files are stored
-    :param nl_payload_key: the key of matching BED files in the payload of the metadata embedding backend
-    :param bed_payload_key: the key of the BED file name in the payload of the BED embedding backend
-    :param non_target_pairs: whether non-target pairs will be sampled; for details, see the docstring of sample_non_target_vec()
-    :param non_target_pairs_prop: proportion of <number of non-target pairs> / <number of target pairs>
-    :param exclusions: a set of payload values to exclude; BED files whose payload value for exclusion_key is in this set are skipped
-    :param exclusion_key: the payload key checked against exclusions
-    :return: A tuple of 3 np.ndarrays:
-        X: with shape of (n, <natural language embedding dimension>)
-        Y: with shape of (n, <region set embedding dimension>)
-        target: with shape of (n,), containing only 1 and -1, indicating whether each X-Y vector pair is a target pair or not
-        (see the docstring of sample_non_target_vec())
-    """
-    # maximum id of metadata embedding vectors
-    max_num_nl = nl_backend.idx.get_max_elements()
-
-    # maximum id of BED embedding vectors
-    max_num_bed = bed_backend.idx.get_max_elements()
-
-    # lists of embedding vectors
-    X = []
-    Y = []
-
-    # list of 1 and -1, indicating whether each vector pair is a target pair or not
-    target = []
-
-    # reverse the BED backend payload dictionary into {<file name>: <store id>}
-    bed_reversal_payload = reverse_payload(bed_backend.payloads, bed_payload_key)
-
-    # pair vectors
-    for i in range(max_num_nl):
-        nl_vec = nl_backend.idx.get_items([i])[0]
-        bed_vec_ids = []
-        # get target pairs
-        for file_name in nl_backend.payloads[i][nl_payload_key]:
-            # get the HNSWBackend store id if the BED is stored
-            try:
-                bed_id = bed_reversal_payload[file_name]
-            except KeyError:
-                continue
-
-            if exclusions is not None:
-                bed_info = bed_backend.payloads[bed_id][exclusion_key]
-                if bed_info in exclusions:
-                    continue
-
-            bed_vec_ids.append(bed_id)
-
-        if len(bed_vec_ids) == 0:
-            continue
-        bed_vecs = bed_backend.idx.get_items(bed_vec_ids, return_type="numpy")
-        for y_vec in bed_vecs:
-            X.append(nl_vec)
-            Y.append(y_vec)
-            target.append(1)
-
-        # sample non-target pairs if needed for contrastive loss
-        if non_target_pairs:
-            non_match_ids = sample_non_target_vec(
-                max_num_bed, bed_vec_ids, int(non_target_pairs_prop * len(bed_vec_ids))
-            )
-            non_match_vecs = bed_backend.idx.get_items(non_match_ids, return_type="numpy")
-            for y_vec in non_match_vecs:
-                X.append(nl_vec)
-                Y.append(y_vec)
-                target.append(-1)
-
-    return np.array(X), np.array(Y), np.array(target)
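Before this removal, these helpers were chained to assemble text2bednn training data: load both file backends, then extract matched (and optionally sampled non-target) vector pairs. A sketch of that old workflow using the removed functions as defined above; paths and dimensions are hypothetical:

nl_backend = load_local_backend("nl_index.bin", "nl_payloads.pkl", dim=384)
bed_backend = load_local_backend("bed_index.bin", "bed_payloads.pkl", dim=100)

# target pairs plus an equal number of sampled non-target pairs for a contrastive loss
X, Y, target = vec_pairs(nl_backend, bed_backend, non_target_pairs=True, non_target_pairs_prop=1.0)
print(X.shape, Y.shape, target.shape)  # e.g. (n, 384), (n, 100), (n,)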