From a8bcda288d995bd6db14e6ac52001244a78a8ea5 Mon Sep 17 00:00:00 2001 From: edknv <109497216+edknv@users.noreply.github.com> Date: Fri, 17 Nov 2023 10:59:09 -0800 Subject: [PATCH] remove beir dependency (#30) --- crossfit/backend/torch/op/vector_search.py | 4 +- crossfit/dataset/beir/raw.py | 7 ++- crossfit/utils/file_utils.py | 53 ++++++++++++++++++++++ crossfit/utils/math_utils.py | 43 ++++++++++++++++++ requirements/pytorch.txt | 1 - 5 files changed, 101 insertions(+), 7 deletions(-) create mode 100644 crossfit/utils/file_utils.py create mode 100644 crossfit/utils/math_utils.py diff --git a/crossfit/backend/torch/op/vector_search.py b/crossfit/backend/torch/op/vector_search.py index a2422ed..eca1bd6 100644 --- a/crossfit/backend/torch/op/vector_search.py +++ b/crossfit/backend/torch/op/vector_search.py @@ -1,9 +1,9 @@ import cupy as cp import torch -from beir.retrieval.search.dense import util as utils from crossfit.data.array.conversion import convert_array from crossfit.op.vector_search import ExactSearchOp +from crossfit.utils import math_utils class TorchExactSearch(ExactSearchOp): @@ -20,7 +20,7 @@ def __init__( self.metric = metric self.embedding_col = embedding_col self.normalize = False - self.score_functions = {"cos_sim": utils.cos_sim, "dot": utils.dot_score} + self.score_functions = {"cos_sim": math_utils.cos_sim, "dot": math_utils.dot_score} self.score_function_desc = {"cos_sim": "Cosine Similarity", "dot": "Dot Product"} def search_tensors(self, queries, corpus): diff --git a/crossfit/dataset/beir/raw.py b/crossfit/dataset/beir/raw.py index 753c214..583919d 100644 --- a/crossfit/dataset/beir/raw.py +++ b/crossfit/dataset/beir/raw.py @@ -3,9 +3,8 @@ from dataclasses import dataclass from typing import Dict, List, Union -from beir import util - from crossfit.dataset.home import CF_HOME +from crossfit.utils import file_utils @dataclass @@ -225,10 +224,10 @@ def download_raw(name, out_dir=None, overwrite=False) -> str: url = 
BEIR_DATASETS[name].download_link print("Downloading {} ...".format(name)) - util.download_url(url, zip_file) + file_utils.download_url(url, zip_file) print("Unzipping {} ...".format(name)) - util.unzip(zip_file, raw_dir) + file_utils.unzip(zip_file, raw_dir) return output_path
diff --git a/crossfit/utils/file_utils.py b/crossfit/utils/file_utils.py
new file mode 100644
index 0000000..f799d7a
--- /dev/null
+++ b/crossfit/utils/file_utils.py
@@ -0,0 +1,72 @@
+import logging
+import os
+import zipfile
+
+import requests
+from tqdm.autonotebook import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+def download_url(url: str, save_path: str, chunk_size: int = 1024):
+    """Download url with progress bar using tqdm
+    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
+
+    Args:
+        url (str): downloadable url
+        save_path (str): local path to save the downloaded file
+        chunk_size (int, optional): bytes read per chunk. Defaults to 1024.
+
+    Raises:
+        requests.HTTPError: if the server answers with a 4xx/5xx status.
+    """
+    # The context manager releases the streamed connection even on errors.
+    with requests.get(url, stream=True) as r:
+        # Fail before save_path is created so an error page is never stored as data.
+        r.raise_for_status()
+        total = int(r.headers.get("Content-Length", 0))
+        with open(save_path, "wb") as fd, tqdm(
+            desc=save_path,
+            total=total,
+            unit="iB",
+            unit_scale=True,
+            # Binary prefixes are 1024-based regardless of the chunk size.
+            unit_divisor=1024,
+        ) as bar:
+            for data in r.iter_content(chunk_size=chunk_size):
+                size = fd.write(data)
+                bar.update(size)
+
+
+def unzip(zip_file: str, out_dir: str):
+    """Extract every member of the archive zip_file into out_dir."""
+    # "with" closes the archive even if extraction raises.
+    with zipfile.ZipFile(zip_file, "r") as zip_:
+        zip_.extractall(path=out_dir)
+
+
+def download_and_unzip(url: str, out_dir: str, chunk_size: int = 1024) -> str:
+    """Download the zip archive at url into out_dir and extract it there.
+
+    The download is skipped when the archive file already exists and the
+    extraction is skipped when the target directory already exists.
+
+    Returns:
+        str: out_dir joined with the archive name with ".zip" removed.
+    """
+    os.makedirs(out_dir, exist_ok=True)
+    dataset = url.split("/")[-1]
+    zip_file = os.path.join(out_dir, dataset)
+    # Build the target once so the existence check and the return value agree
+    # even when out_dir itself contains ".zip".
+    extracted_dir = os.path.join(out_dir, dataset.replace(".zip", ""))
+
+    if not os.path.isfile(zip_file):
+        logger.info("Downloading %s ...", dataset)
+        download_url(url, zip_file, chunk_size)
+
+    if not os.path.isdir(extracted_dir):
+        logger.info("Unzipping %s ...", dataset)
+        unzip(zip_file, out_dir)
+
+    return extracted_dir
diff --git a/crossfit/utils/math_utils.py
b/crossfit/utils/math_utils.py
new file mode 100644
index 0000000..30d25f7
--- /dev/null
+++ b/crossfit/utils/math_utils.py
@@ -0,0 +1,35 @@
+import torch
+
+
+def _as_2d_tensor(x) -> torch.Tensor:
+    """Convert x to a tensor and promote a 1-D vector to a single-row matrix."""
+    if not isinstance(x, torch.Tensor):
+        x = torch.tensor(x)
+
+    if len(x.shape) == 1:
+        x = x.unsqueeze(0)
+
+    return x
+
+
+def dot_score(a: torch.Tensor, b: torch.Tensor):
+    """
+    Computes the dot-product dot_prod(a[i], b[j]) for all i and j.
+    Non-tensor inputs are converted with torch.tensor; 1-D inputs count as one row.
+    :return: Matrix with res[i][j] = dot_prod(a[i], b[j])
+    """
+    a = _as_2d_tensor(a)
+    b = _as_2d_tensor(b)
+    return torch.mm(a, b.transpose(0, 1))
+
+
+def cos_sim(a: torch.Tensor, b: torch.Tensor):
+    """
+    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
+    Non-tensor inputs are converted with torch.tensor; 1-D inputs count as one row.
+    :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
+    """
+    # Rows are L2-normalized first, so the matrix product yields cosines.
+    a_norm = torch.nn.functional.normalize(_as_2d_tensor(a), p=2, dim=1)
+    b_norm = torch.nn.functional.normalize(_as_2d_tensor(b), p=2, dim=1)
+    return torch.mm(a_norm, b_norm.transpose(0, 1))
diff --git a/requirements/pytorch.txt b/requirements/pytorch.txt
index e536521..1ca20fa 100644
--- a/requirements/pytorch.txt
+++ b/requirements/pytorch.txt
@@ -2,4 +2,3 @@ torch>=1.0
 transformers
 curated-transformers
 bitsandbytes
-beir
\ No newline at end of file