Commit
remove beir dependency (#30)
edknv authored Nov 17, 2023
1 parent 5f0a189 commit a8bcda2
Showing 5 changed files with 101 additions and 7 deletions.
4 changes: 2 additions & 2 deletions crossfit/backend/torch/op/vector_search.py
@@ -1,9 +1,9 @@
 import cupy as cp
 import torch
-from beir.retrieval.search.dense import util as utils

 from crossfit.data.array.conversion import convert_array
 from crossfit.op.vector_search import ExactSearchOp
+from crossfit.utils import math_utils


 class TorchExactSearch(ExactSearchOp):
@@ -20,7 +20,7 @@ def __init__(
         self.metric = metric
         self.embedding_col = embedding_col
         self.normalize = False
-        self.score_functions = {"cos_sim": utils.cos_sim, "dot": utils.dot_score}
+        self.score_functions = {"cos_sim": math_utils.cos_sim, "dot": math_utils.dot_score}
         self.score_function_desc = {"cos_sim": "Cosine Similarity", "dot": "Dot Product"}

     def search_tensors(self, queries, corpus):
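For illustration, a minimal sketch of how the swapped-in scorers behave, assuming the score_functions wiring shown in the hunk above; the tensors and shapes below are placeholders, not values from the repository.

import torch

from crossfit.utils import math_utils

# Hypothetical embeddings: 4 queries and 10 corpus items of dimension 8.
queries = torch.randn(4, 8)
corpus = torch.randn(10, 8)

# The op now resolves its scorer from crossfit.utils.math_utils
# instead of beir.retrieval.search.dense.util.
score_functions = {"cos_sim": math_utils.cos_sim, "dot": math_utils.dot_score}
scores = score_functions["cos_sim"](queries, corpus)
print(scores.shape)  # torch.Size([4, 10])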
7 changes: 3 additions & 4 deletions crossfit/dataset/beir/raw.py
@@ -3,9 +3,8 @@
 from dataclasses import dataclass
 from typing import Dict, List, Union

-from beir import util
-
 from crossfit.dataset.home import CF_HOME
+from crossfit.utils import file_utils


 @dataclass
@@ -225,10 +224,10 @@ def download_raw(name, out_dir=None, overwrite=False) -> str:
     url = BEIR_DATASETS[name].download_link

     print("Downloading {} ...".format(name))
-    util.download_url(url, zip_file)
+    file_utils.download_url(url, zip_file)

     print("Unzipping {} ...".format(name))
-    util.unzip(zip_file, raw_dir)
+    file_utils.unzip(zip_file, raw_dir)

     return output_path

53 changes: 53 additions & 0 deletions crossfit/utils/file_utils.py
@@ -0,0 +1,53 @@
+import logging
+import os
+import zipfile
+
+import requests
+from tqdm.autonotebook import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+def download_url(url: str, save_path: str, chunk_size: int = 1024):
+    """Download url with progress bar using tqdm
+    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
+    Args:
+        url (str): downloadable url
+        save_path (str): local path to save the downloaded file
+        chunk_size (int, optional): chunking of files. Defaults to 1024.
+    """
+    r = requests.get(url, stream=True)
+    total = int(r.headers.get("Content-Length", 0))
+    with open(save_path, "wb") as fd, tqdm(
+        desc=save_path,
+        total=total,
+        unit="iB",
+        unit_scale=True,
+        unit_divisor=chunk_size,
+    ) as bar:
+        for data in r.iter_content(chunk_size=chunk_size):
+            size = fd.write(data)
+            bar.update(size)
+
+
+def unzip(zip_file: str, out_dir: str):
+    zip_ = zipfile.ZipFile(zip_file, "r")
+    zip_.extractall(path=out_dir)
+    zip_.close()
+
+
+def download_and_unzip(url: str, out_dir: str, chunk_size: int = 1024) -> str:
+    os.makedirs(out_dir, exist_ok=True)
+    dataset = url.split("/")[-1]
+    zip_file = os.path.join(out_dir, dataset)
+
+    if not os.path.isfile(zip_file):
+        logger.info("Downloading {} ...".format(dataset))
+        download_url(url, zip_file, chunk_size)
+
+    if not os.path.isdir(zip_file.replace(".zip", "")):
+        logger.info("Unzipping {} ...".format(dataset))
+        unzip(zip_file, out_dir)
+
+    return os.path.join(out_dir, dataset.replace(".zip", ""))
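For illustration, a rough usage sketch of the new helpers; the URL and output directory below are placeholders, not values taken from this repository.

from crossfit.utils import file_utils

# Download the archive with a tqdm progress bar (skipped if the zip already
# exists), unzip it (skipped if already extracted), and return the dataset dir.
dataset_dir = file_utils.download_and_unzip(
    "https://example.com/datasets/scifact.zip",  # placeholder URL
    "/tmp/crossfit_raw",                         # placeholder output directory
)
print(dataset_dir)  # /tmp/crossfit_raw/scifact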
43 changes: 43 additions & 0 deletions crossfit/utils/math_utils.py
@@ -0,0 +1,43 @@
+import torch
+
+
+def dot_score(a: torch.Tensor, b: torch.Tensor):
+    """
+    Computes the dot-product dot_prod(a[i], b[j]) for all i and j.
+    :return: Matrix with res[i][j] = dot_prod(a[i], b[j])
+    """
+    if not isinstance(a, torch.Tensor):
+        a = torch.tensor(a)
+
+    if not isinstance(b, torch.Tensor):
+        b = torch.tensor(b)
+
+    if len(a.shape) == 1:
+        a = a.unsqueeze(0)
+
+    if len(b.shape) == 1:
+        b = b.unsqueeze(0)
+
+    return torch.mm(a, b.transpose(0, 1))
+
+
+def cos_sim(a: torch.Tensor, b: torch.Tensor):
+    """
+    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
+    :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
+    """
+    if not isinstance(a, torch.Tensor):
+        a = torch.tensor(a)
+
+    if not isinstance(b, torch.Tensor):
+        b = torch.tensor(b)
+
+    if len(a.shape) == 1:
+        a = a.unsqueeze(0)
+
+    if len(b.shape) == 1:
+        b = b.unsqueeze(0)
+
+    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
+    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
+    return torch.mm(a_norm, b_norm.transpose(0, 1))
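For illustration, a small sanity-check sketch of the two scorers; the inputs are random placeholders.

import torch

from crossfit.utils import math_utils

a = torch.randn(3, 5)  # 3 embeddings of dimension 5
b = torch.randn(7, 5)  # 7 embeddings of dimension 5

dot = math_utils.dot_score(a, b)  # (3, 7) matrix of raw dot products
cos = math_utils.cos_sim(a, b)    # (3, 7) matrix of cosine similarities in [-1, 1]

# cos_sim is dot_score applied to L2-normalized rows.
a_n = torch.nn.functional.normalize(a, p=2, dim=1)
b_n = torch.nn.functional.normalize(b, p=2, dim=1)
assert torch.allclose(cos, math_utils.dot_score(a_n, b_n))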
1 change: 0 additions & 1 deletion requirements/pytorch.txt
@@ -2,4 +2,3 @@ torch>=1.0
 transformers
 curated-transformers
 bitsandbytes
-beir
