please accept this pull request ASAP PLEASE! #113

Closed
wants to merge 15 commits into from
47 changes: 26 additions & 21 deletions InstructorEmbedding/instructor.py
@@ -23,7 +23,7 @@ def batch_to_device(batch, target_device: str):
return batch


class InstructorPooling(nn.Module):
class INSTRUCTORPooling(nn.Module):
"""Performs pooling (max or mean) on the token embeddings.

Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding.
@@ -245,7 +245,7 @@ def load(input_path):
) as config_file:
config = json.load(config_file)

return InstructorPooling(**config)
return INSTRUCTORPooling(**config)


def import_from_string(dotted_path):
@@ -271,7 +271,7 @@ def import_from_string(dotted_path):
raise ImportError(msg)


class InstructorTransformer(Transformer):
class INSTRUCTORTransformer(Transformer):
def __init__(
self,
model_name_or_path: str,
@@ -378,7 +378,7 @@ def load(input_path: str):

with open(sbert_config_path, encoding="UTF-8") as config_file:
config = json.load(config_file)
return InstructorTransformer(model_name_or_path=input_path, **config)
return INSTRUCTORTransformer(model_name_or_path=input_path, **config)

def tokenize(self, texts):
"""
@@ -420,7 +420,7 @@ def tokenize(self, texts):

input_features = self.tokenize(instruction_prepended_input_texts)
instruction_features = self.tokenize(instructions)
input_features = Instructor.prepare_input_features(
input_features = INSTRUCTOR.prepare_input_features(
input_features, instruction_features
)
else:
@@ -430,7 +430,7 @@ def tokenize(self, texts):
return output


class Instructor(SentenceTransformer):
class INSTRUCTOR(SentenceTransformer):
@staticmethod
def prepare_input_features(
input_features, instruction_features, return_data_type: str = "pt"
@@ -510,27 +510,32 @@ def smart_batching_collate(self, batch):

input_features = self.tokenize(instruction_prepended_input_texts)
instruction_features = self.tokenize(instructions)
input_features = Instructor.prepare_input_features(
input_features = INSTRUCTOR.prepare_input_features(
input_features, instruction_features
)
batched_input_features.append(input_features)

return batched_input_features, labels

def _load_sbert_model(self, model_path, token = None, cache_folder = None, revision = None, trust_remote_code = False):
def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=None, trust_remote_code=False):
"""
Loads a full sentence-transformers model
"""
# Taken mostly from: https://github.com/UKPLab/sentence-transformers/blob/66e0ee30843dd411c64f37f65447bb38c7bf857a/sentence_transformers/util.py#L544
download_kwargs = {
"repo_id": model_path,
"revision": revision,
"library_name": "sentence-transformers",
"token": token,
"cache_dir": cache_folder,
"tqdm_class": disabled_tqdm,
}
model_path = snapshot_download(**download_kwargs)
# copied from https://github.com/UKPLab/sentence-transformers/blob/66e0ee30843dd411c64f37f65447bb38c7bf857a/sentence_transformers/util.py#L559
# because we need to get files outside of the allow_patterns too
# If file is local
if os.path.isdir(model_path):
model_path = str(model_path)
else:
# If model_path is a Hugging Face repository ID, download the model
download_kwargs = {
"repo_id": model_path,
"revision": revision,
"library_name": "InstructorEmbedding",
"token": token,
"cache_dir": cache_folder,
"tqdm_class": disabled_tqdm,
}

# Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
config_sentence_transformers_json_path = os.path.join(
@@ -559,9 +564,9 @@ def _load_sbert_model(self, model_path, token = None, cache_folder = None, revision = None, trust_remote_code = False):
modules = OrderedDict()
for module_config in modules_config:
if module_config["idx"] == 0:
module_class = InstructorTransformer
module_class = INSTRUCTORTransformer
elif module_config["idx"] == 1:
module_class = InstructorPooling
module_class = INSTRUCTORPooling
else:
module_class = import_from_string(module_config["type"])
module = module_class.load(os.path.join(model_path, module_config["path"]))
@@ -619,7 +624,7 @@ def encode(
input_was_string = True

if device is None:
device = self._target_device
device = self.device

self.to(device)

10 changes: 10 additions & 0 deletions README.md
@@ -1,3 +1,13 @@
## My Personal Fork

This is a fork of the Instructor model because the original repository is no longer maintained. I've also made some improvements to the source code:

1) Fixed it to work with versions of the `sentence-transformers` library above 2.2.2.
2) Models are now downloaded properly from the Hugging Face Hub using the newer snapshot-download API.
3) Added the ability to specify where the model is downloaded with the "cache_dir" parameter (see the usage sketch below).
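
Below is a minimal usage sketch of the forked package. It assumes the package is installed as `InstructorEmbedding`, that the `hkunlp/instructor-large` weights are used, and that the cache directory is passed through the `cache_folder` keyword exposed by `sentence-transformers`; the exact keyword name in this fork (the "cache_dir" mentioned above) may differ.

```python
from InstructorEmbedding import INSTRUCTOR

# Load the model; the snapshot is downloaded from the Hugging Face Hub,
# or reused from the given cache directory if it is already present.
model = INSTRUCTOR("hkunlp/instructor-large", cache_folder="/path/to/model-cache")

# Each input is an [instruction, text] pair.
sentences = [
    ["Represent the science title:",
     "3D ActionSLAM: wearable person tracking in multi-floor environments"],
    ["Represent the Medicine sentence for clustering:",
     "Dynamical Scalar Degree of Freedom in Horava-Lifshitz Gravity"],
]
embeddings = model.encode(sentences)
print(embeddings.shape)  # (number of inputs, embedding dimension)
```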

## What follows is the original repository's readme file. Ignore the quantization section, however, because PyTorch has changed its API since then.

# One Embedder, Any Task: Instruction-Finetuned Text Embeddings

This repository contains the code and pre-trained models for our paper [One Embedder, Any Task: Instruction-Finetuned Text Embeddings](https://arxiv.org/abs/2212.09741). Please refer to our [project page](https://instructor-embedding.github.io/) for a quick project overview.
26 changes: 12 additions & 14 deletions requirements.txt
@@ -1,14 +1,12 @@
transformers==4.20.0
datasets>=2.2.0
pyarrow==8.0.0
jsonlines
numpy
requests>=2.26.0
scikit_learn>=1.0.2
scipy
sentence_transformers>=2.2.0
torch
tqdm
rich
tensorboard
huggingface-hub>=0.19.0
transformers>=4.20,<5.0
datasets>=2.20,<3.0
pyarrow>=17.0,<18.0
numpy>=1.0,<=1.26.4
requests>=2.26,<3.0
scikit_learn>=1.0.2,<2.0
scipy>=1.14,<2.0
sentence-transformers>=3.0.1,<4.0
torch>=2.0
tqdm>=4.0,<5.0
rich>=13.0,<14.0
huggingface-hub>=0.24.1
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@
setup(
name='InstructorEmbedding',
packages=['InstructorEmbedding'],
version='1.0.1',
version='1.0.2',
license='Apache License 2.0',
description='Text embedding tool',
long_description=readme,