
Commit

Merge pull request #35 from sbintuitions/dev
[dev to main] v1.2.0
lsz05 authored Jun 20, 2024
2 parents a1756ce + fc2c1a7 commit 4a3a7f4
Showing 20 changed files with 252 additions and 37 deletions.
1 change: 1 addition & 0 deletions .github/pull_request_template.md
@@ -20,6 +20,7 @@ Please set the base branch to `dev`.

## Verification
- [ ] Confirmed that the tests pass
- [ ] Confirmed that the merge target is the dev branch
- [ ] ...

<!--
5 changes: 2 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -12,7 +12,7 @@ description = "The evaluation scripts for JMTEB (Japanese Massive Text Embedding
name = "JMTEB"
packages = [{from = "src", include = "jmteb"}]
readme = "README.md"
version = "1.1.1"
version = "1.2.0"

[tool.poetry.dependencies]
python = ">=3.10,<4.0"
@@ -30,6 +30,7 @@ smart-open = "^7.0.1"
openai = "^1.16.2"
pytest-mock = "^3.14.0"
tiktoken = "^0.6.0"
numpy = "^1.26"

[tool.poetry.group.dev.dependencies]
black = "^23.11.0"
21 changes: 16 additions & 5 deletions src/jmteb/embedders/base.py
@@ -14,11 +14,12 @@ class TextEmbedder(ABC):
The base class of text embedder.
"""

def encode(self, text: str | list[str]) -> np.ndarray:
def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray:
"""Convert a text string or a list of texts to embedding.
Args:
text (str | list[str]): text string, or a list of texts.
prefix (str, optional): the prefix to use for encoding. Defaults to None.
"""
raise NotImplementedError

@@ -31,14 +32,20 @@ def get_output_dim(self) -> int:
raise NotImplementedError

def _batch_encode_and_save_on_disk(
self, text_list: list[str], save_path: str | PathLike[str], batch_size: int = 64, dtype: str = "float32"
self,
text_list: list[str],
save_path: str | PathLike[str],
prefix: str | None = None,
batch_size: int = 64,
dtype: str = "float32",
) -> np.memmap:
"""
Encode a list of texts and save the embeddings on disk using memmap.
Args:
text_list (list[str]): list of texts
save_path (str): path to save the embeddings
prefix (str, optional): the prefix to use for encoding. Defaults to None.
dtype (str, optional): data type. Defaults to "float32".
batch_size (int): batch size. Defaults to 64.
"""
@@ -50,7 +57,7 @@ def _batch_encode_and_save_on_disk(
with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar:
for i in range(0, num_samples, batch_size):
batch = text_list[i : i + batch_size]
batch_embeddings = self.encode(batch)
batch_embeddings = self.encode(batch, prefix=prefix)
batch_embeddings = np.asarray(batch_embeddings, dtype=dtype)
embeddings[i : i + batch_size] = batch_embeddings
pbar.update(len(batch))
@@ -61,6 +68,7 @@ def _batch_encode_and_save_on_disk(
def batch_encode_with_cache(
self,
text_list: list[str],
prefix: str | None = None,
cache_path: str | PathLike[str] | None = None,
overwrite_cache: bool = False,
batch_size: int = 64,
@@ -71,6 +79,7 @@ def batch_encode_with_cache(
Args:
text_list (list[str]): list of texts
prefix (str, optional): the prefix to use for encoding. Defaults to None.
cache_path (str, optional): path to save the embeddings. Defaults to None.
overwrite_cache (bool, optional): whether to overwrite the cache. Defaults to False.
batch_size (int): batch size. Defaults to 64.
@@ -79,12 +88,14 @@ def batch_encode_with_cache(

if cache_path is None:
logger.info("Encoding embeddings")
return self.encode(text_list).astype(dtype)
return self.encode(text_list, prefix=prefix).astype(dtype)

if Path(cache_path).exists() and not overwrite_cache:
logger.info(f"Loading embeddings from {cache_path}")
return np.memmap(cache_path, dtype=dtype, mode="r", shape=(len(text_list), self.get_output_dim()))

logger.info(f"Encoding and saving embeddings to {cache_path}")
embeddings = self._batch_encode_and_save_on_disk(text_list, cache_path, batch_size=batch_size, dtype=dtype)
embeddings = self._batch_encode_and_save_on_disk(
text_list, cache_path, prefix=prefix, batch_size=batch_size, dtype=dtype
)
return embeddings
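
For context, here is a minimal usage sketch of the updated caching path. `DummyEmbedder` is a hypothetical subclass written only for illustration, and the sketch assumes `TextEmbedder` (defined in `src/jmteb/embedders/base.py`) needs no constructor arguments beyond what this diff shows.

```python
import numpy as np

from jmteb.embedders.base import TextEmbedder


class DummyEmbedder(TextEmbedder):
    """Hypothetical embedder used only to illustrate the new `prefix` argument."""

    def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray:
        texts = [text] if isinstance(text, str) else text
        # A real embedder would prepend `prefix` before encoding; here we return fake vectors.
        texts = [f"{prefix}{t}" if prefix else t for t in texts]
        return np.random.rand(len(texts), self.get_output_dim())

    def get_output_dim(self) -> int:
        return 8


embedder = DummyEmbedder()

# No cache path: encode() is called once and `prefix` is forwarded directly.
embeddings = embedder.batch_encode_with_cache(["こんにちは", "世界"], prefix="query: ")

# With a cache path: embeddings are written to a np.memmap on disk and reloaded
# on later calls unless overwrite_cache=True.
cached = embedder.batch_encode_with_cache(
    ["こんにちは", "世界"],
    prefix="query: ",
    cache_path="embeddings.bin",
    overwrite_cache=True,
)
print(embeddings.shape, cached.shape)  # (2, 8) (2, 8)
```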
9 changes: 5 additions & 4 deletions src/jmteb/embedders/openai_embedder.py
@@ -60,13 +60,13 @@ def __init__(self, model: str = "text-embedding-3-small", dim: int | None = None
else:
self.dim = dim

def encode(self, text: str | list[str]) -> np.ndarray:
def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray:
kwargs = {"dimensions": self.dim} if self.model != "text-embedding-ada-002" else {}
# specifying `dimensions` is not allowed for "text-embedding-ada-002"
if isinstance(text, str):
token_ids: list[int] = self.encode_and_truncate_text(text)
token_ids: list[int] = self.encode_and_truncate_text(text, prefix)
else:
token_ids: list[list[int]] = [self.encode_and_truncate_text(t) for t in text]
token_ids: list[list[int]] = [self.encode_and_truncate_text(t, prefix) for t in text]
result = np.asarray(
[
data.embedding
@@ -84,10 +84,11 @@ def get_output_dim(self) -> int:
def get_output_dim(self) -> int:
return self.dim

def encode_and_truncate_text(self, text: str) -> list[int]:
def encode_and_truncate_text(self, text: str, prefix: str | None = None) -> list[int]:
# Refer to https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken
# return a list of token IDs
if not text:
text = " "
logger.warning("Found empty string!")
# Ignore prefix in OpenAIEmbedder
return self.encoding.encode(text)[: self.max_token_length]
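
Note that `OpenAIEmbedder` accepts the new `prefix` argument only for interface compatibility and ignores it, as the comment above states. Below is a small standalone sketch of the tiktoken-based truncation it relies on; the 8191-token limit and the helper name are assumptions for illustration, not taken from this diff.

```python
import tiktoken

# cl100k_base is the tokenizer used by text-embedding-3-* and text-embedding-ada-002.
encoding = tiktoken.get_encoding("cl100k_base")
MAX_TOKEN_LENGTH = 8191  # assumed per-input limit, for illustration only


def encode_and_truncate(text: str) -> list[int]:
    if not text:
        text = " "  # empty strings are replaced with a single space, mirroring the diff above
    return encoding.encode(text)[:MAX_TOKEN_LENGTH]


token_ids = encode_and_truncate("日本語テキスト埋め込みベンチマーク JMTEB")
print(len(token_ids))
```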
26 changes: 24 additions & 2 deletions src/jmteb/embedders/sbert_embedder.py
@@ -15,20 +15,42 @@ def __init__(
batch_size: int = 32,
device: str | None = None,
normalize_embeddings: bool = False,
max_seq_length: int | None = None,
add_eos: bool = False,
tokenizer_kwargs: dict | None = None,
) -> None:
self.model = SentenceTransformer(model_name_or_path)
self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True, tokenizer_kwargs=tokenizer_kwargs)
if max_seq_length:
self.model.max_seq_length = max_seq_length

self.batch_size = batch_size
self.device = device
self.normalize_embeddings = normalize_embeddings
self.max_seq_length = getattr(self.model, "max_seq_length", None)
self.add_eos = add_eos

def encode(self, text: str | list[str]) -> np.ndarray:
def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray:
if self.add_eos:
text = self._add_eos_func(text)
return self.model.encode(
text,
prompt=prefix,
convert_to_numpy=True,
batch_size=self.batch_size,
device=self.device,
normalize_embeddings=self.normalize_embeddings,
)

def _add_eos_func(self, text: str | list[str]) -> str | list[str]:
try:
eos_token = getattr(self.model.tokenizer, "eos_token")
except AttributeError:
return text

if isinstance(text, str):
return text + eos_token
elif isinstance(text, list):
return [t + eos_token for t in text]

def get_output_dim(self) -> int:
return self.model.get_sentence_embedding_dimension()
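
A sketch of how the updated SentenceTransformer-based embedder might be instantiated. The class name `SentenceBertEmbedder`, the model id, and the prefix string are assumptions for illustration; the diff itself only shows the changed constructor arguments and the `prompt=prefix` pass-through to `SentenceTransformer.encode`.

```python
from jmteb.embedders.sbert_embedder import SentenceBertEmbedder  # class name assumed

embedder = SentenceBertEmbedder(
    "cl-nagoya/sup-simcse-ja-base",       # example model id (assumption)
    batch_size=32,
    normalize_embeddings=True,
    max_seq_length=512,                   # overrides the model's default when set
    add_eos=False,                        # set True to append the tokenizer's EOS token
    tokenizer_kwargs={"use_fast": True},  # forwarded to the underlying tokenizer
)

# `prefix` is forwarded as SentenceTransformer's `prompt`, i.e. it is prepended to
# every input before tokenization.
embeddings = embedder.encode(["埋め込みの評価に使う例文です。"], prefix="query: ")
print(embeddings.shape)
```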
6 changes: 6 additions & 0 deletions src/jmteb/evaluators/classification/evaluator.py
@@ -27,6 +27,7 @@ class ClassificationEvaluator(EmbeddingEvaluator):
and delimited by comma, e.g., `macro, micro`.
The first one is specified as the main index.
classifiers (dict[str, Classifier]): classifiers to be evaluated.
prefix (str | None): prefix for sentences. Defaults to None.
"""

def __init__(
@@ -36,6 +37,7 @@ def __init__(
test_dataset: ClassificationDataset,
average: str = "macro",
classifiers: dict[str, Classifier] | None = None,
prefix: str | None = None,
) -> None:
self.train_dataset = train_dataset
self.val_dataset = val_dataset
@@ -49,6 +51,7 @@ def __init__(
for average_name in average
if average_name.strip().lower() in ("micro", "macro", "samples", "weighted", "binary")
] or ["macro"]
self.prefix = prefix
self.main_metric = f"{self.average[0]}_f1"

def __call__(
@@ -60,13 +63,15 @@ def __call__(
logger.info("Encoding training and validation sentences...")
X_train = model.batch_encode_with_cache(
[item.text for item in self.train_dataset],
prefix=self.prefix,
cache_path=Path(cache_dir) / "train_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
y_train = [item.label for item in self.train_dataset]

X_val = model.batch_encode_with_cache(
[item.text for item in self.val_dataset],
prefix=self.prefix,
cache_path=Path(cache_dir) / "val_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
@@ -79,6 +84,7 @@ def __call__(
else:
X_test = model.batch_encode_with_cache(
[item.text for item in self.test_dataset],
prefix=self.prefix,
cache_path=Path(cache_dir) / "test_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
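
A constructor-level sketch of passing the new `prefix`. The import path follows the file layout in this diff; the dataset objects are stand-ins carrying the `.text`/`.label` attributes the evaluator iterates over, since the real `ClassificationDataset` loaders are outside this diff, and the prefix string is just an example.

```python
from dataclasses import dataclass

from jmteb.evaluators.classification.evaluator import ClassificationEvaluator


@dataclass
class _Item:  # stand-in for a ClassificationDataset item (illustration only)
    text: str
    label: int


train = [_Item("とても良い映画だった", 1), _Item("退屈な映画だった", 0)]
val = [_Item("期待以上の出来だった", 1), _Item("時間の無駄だった", 0)]
test = [_Item("傑作だと思う", 1), _Item("二度と見たくない", 0)]

evaluator = ClassificationEvaluator(
    train_dataset=train,      # real code passes ClassificationDataset instances
    val_dataset=val,
    test_dataset=test,
    average="macro, micro",   # the first entry defines the main metric (macro_f1)
    prefix="分類: ",           # forwarded to batch_encode_with_cache for every split
)
# results = evaluator(model, cache_dir="cache/classification")  # model: any TextEmbedder
```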
12 changes: 10 additions & 2 deletions src/jmteb/evaluators/clustering/evaluator.py
@@ -30,9 +30,13 @@ def __init__(
self,
val_dataset: ClusteringDataset,
test_dataset: ClusteringDataset,
prefix: str | None = None,
random_seed: int | None = None,
) -> None:
self.val_dataset = val_dataset
self.test_dataset = test_dataset
self.prefix = prefix
self.random_seed = random_seed
self.main_metric = "v_measure_score"

def __call__(
@@ -44,6 +48,7 @@ def __call__(
logger.info("Converting validation data to embeddings...")
val_embeddings = model.batch_encode_with_cache(
[item.text for item in self.val_dataset],
prefix=self.prefix,
cache_path=Path(cache_dir) / "val_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
@@ -56,16 +61,19 @@ def __call__(
else:
test_embeddings = model.batch_encode_with_cache(
[item.text for item in self.test_dataset],
prefix=self.prefix,
cache_path=Path(cache_dir) / "test_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
test_labels = [item.label for item in self.test_dataset]

n_clusters = len(set(test_labels))
model_constructors: dict[str, Callable[[], ClusterMixin]] = {
"MiniBatchKMeans": lambda: MiniBatchKMeans(n_clusters=n_clusters, n_init="auto"),
"MiniBatchKMeans": lambda: MiniBatchKMeans(
n_clusters=n_clusters, n_init="auto", random_state=self.random_seed
),
"AgglomerativeClustering": lambda: AgglomerativeClustering(n_clusters=n_clusters),
"BisectingKMeans": lambda: BisectingKMeans(n_clusters=n_clusters),
"BisectingKMeans": lambda: BisectingKMeans(n_clusters=n_clusters, random_state=self.random_seed),
"Birch": lambda: Birch(n_clusters=n_clusters),
}

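
The new `random_seed` is forwarded to the scikit-learn clusterers so repeated runs give identical cluster assignments, e.g. `ClusteringEvaluator(val_dataset, test_dataset, random_seed=42)`. Below is a self-contained illustration of that effect using the same constructors; the data is random and only for demonstration.

```python
import numpy as np
from sklearn.cluster import BisectingKMeans, MiniBatchKMeans

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(200, 16)).astype("float32")
n_clusters, random_seed = 4, 42

# With a fixed random_state, two runs of MiniBatchKMeans produce identical labels.
labels_a = MiniBatchKMeans(n_clusters=n_clusters, n_init="auto", random_state=random_seed).fit_predict(embeddings)
labels_b = MiniBatchKMeans(n_clusters=n_clusters, n_init="auto", random_state=random_seed).fit_predict(embeddings)
assert (labels_a == labels_b).all()

# BisectingKMeans gets the same treatment in the diff above.
labels_c = BisectingKMeans(n_clusters=n_clusters, random_state=random_seed).fit_predict(embeddings)
```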
10 changes: 9 additions & 1 deletion src/jmteb/evaluators/pair_classification/evaluator.py
@@ -20,15 +20,21 @@ class PairClassificationEvaluator(EmbeddingEvaluator):
Args:
val_dataset (PairClassificationDataset): validation dataset
test_dataset (PairClassificationDataset): test dataset
sentence1_prefix (str | None): prefix for sentence1. Defaults to None.
sentence2_prefix (str | None): prefix for sentence2. Defaults to None.
"""

def __init__(
self,
val_dataset: PairClassificationDataset,
test_dataset: PairClassificationDataset,
sentence1_prefix: str | None = None,
sentence2_prefix: str | None = None,
) -> None:
self.test_dataset = test_dataset
self.val_dataset = val_dataset
self.sentence1_prefix = sentence1_prefix
self.sentence2_prefix = sentence2_prefix
self.metrics = [ThresholdAccuracyMetric(), ThresholdF1Metric()]
self.main_metric = "binary_f1"

@@ -101,8 +107,8 @@ def __call__(
},
)

@staticmethod
def _convert_to_embeddings(
self,
model: TextEmbedder,
dataset: PairClassificationDataset,
split: str = "test",
@@ -111,11 +117,13 @@ def _convert_to_embeddings(
) -> tuple[np.ndarray, np.ndarray, list[float]]:
embeddings1 = model.batch_encode_with_cache(
[item.sentence1 for item in dataset],
prefix=self.sentence1_prefix,
cache_path=Path(cache_dir) / f"{split}_embeddings1.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
embeddings2 = model.batch_encode_with_cache(
[item.sentence2 for item in dataset],
prefix=self.sentence2_prefix,
cache_path=Path(cache_dir) / f"{split}_embeddings2.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
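
Because `_convert_to_embeddings` now reads `self.sentence1_prefix` and `self.sentence2_prefix`, it changes from a `@staticmethod` to an instance method. A constructor-level sketch follows; the `_Pair` dataclass is a stand-in for the real `PairClassificationDataset` items, shown only for illustration.

```python
from dataclasses import dataclass

from jmteb.evaluators.pair_classification.evaluator import PairClassificationEvaluator


@dataclass
class _Pair:  # stand-in for a PairClassificationDataset item (illustration only)
    sentence1: str
    sentence2: str
    label: int


val = [_Pair("猫が寝ている", "猫が眠っている", 1), _Pair("猫が寝ている", "犬が走っている", 0)]
test = [_Pair("雨が降っている", "外は雨だ", 1), _Pair("雨が降っている", "快晴だ", 0)]

evaluator = PairClassificationEvaluator(
    val_dataset=val,
    test_dataset=test,
    sentence1_prefix=None,   # e.g. "query: " for models trained with asymmetric prefixes
    sentence2_prefix=None,
)
# results = evaluator(model, cache_dir="cache/pair_classification")  # model: any TextEmbedder
```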
9 changes: 9 additions & 0 deletions src/jmteb/evaluators/reranking/evaluator.py
@@ -28,6 +28,8 @@ class RerankingEvaluator(EmbeddingEvaluator):
test_query_dataset (RerankingQueryDataset): test query dataset used for computing the scores
doc_dataset (RerankingDocDataset): document dataset
ndcg_at_k (list[int] | None): top k documents to consider in NDCG (Normalized Discounted Cumulative Gain).
query_prefix (str | None): prefix for queries. Defaults to None.
doc_prefix (str | None): prefix for documents. Defaults to None.
"""

def __init__(
@@ -36,12 +38,16 @@ def __init__(
test_query_dataset: RerankingQueryDataset,
doc_dataset: RerankingDocDataset,
ndcg_at_k: list[int] | None = None,
query_prefix: str | None = None,
doc_prefix: str | None = None,
) -> None:
self.test_query_dataset = test_query_dataset
self.val_query_dataset = val_query_dataset
self.doc_dataset = doc_dataset
self.ndcg_at_k = ndcg_at_k or [10, 20, 40]
self.main_metric = f"ndcg@{self.ndcg_at_k[0]}"
self.query_prefix = query_prefix
self.doc_prefix = doc_prefix

def __call__(
self,
@@ -54,6 +60,7 @@ def __call__(

val_query_embeddings = model.batch_encode_with_cache(
text_list=[item.query for item in self.val_query_dataset],
prefix=self.query_prefix,
cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
@@ -62,11 +69,13 @@ def __call__(
else:
test_query_embeddings = model.batch_encode_with_cache(
text_list=[item.query for item in self.test_query_dataset],
prefix=self.query_prefix,
cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
doc_embeddings = model.batch_encode_with_cache(
text_list=[item.text for item in self.doc_dataset],
prefix=self.doc_prefix,
cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
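
Query and document prefixes are kept separate because instruction-tuned embedding models (for example E5-style models) expect different prefixes on the two sides, such as "query: " and "passage: ". A constructor-level sketch follows; the stand-in dataclasses carry only the `.query` and `.text` fields used in this diff, and the prefix values are examples, not mandated by the code.

```python
from dataclasses import dataclass

from jmteb.evaluators.reranking.evaluator import RerankingEvaluator


@dataclass
class _Query:  # stand-in; only the `.query` field used in this diff is shown
    query: str


@dataclass
class _Doc:  # stand-in; only the `.text` field used in this diff is shown
    text: str


val_queries = [_Query("日本で一番高い山は？")]
test_queries = [_Query("JMTEBとは何ですか？")]
docs = [_Doc("富士山は日本で最も高い山である。"), _Doc("JMTEBは日本語テキスト埋め込みのベンチマークである。")]

evaluator = RerankingEvaluator(
    val_query_dataset=val_queries,
    test_query_dataset=test_queries,
    doc_dataset=docs,
    ndcg_at_k=[10, 20, 40],     # default; the first entry defines the main metric ndcg@10
    query_prefix="query: ",     # example prefixes for an E5-style model
    doc_prefix="passage: ",
)
# results = evaluator(model, cache_dir="cache/reranking")  # model: any TextEmbedder
```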
(The remaining changed files are not shown here.)
