[Feature] Add prediction logging #43

Merged: 13 commits, Jul 31, 2024
3 changes: 2 additions & 1 deletion .markdownlint.yaml
@@ -1,3 +1,4 @@
MD013: false
MD040: false
MD025: false
MD025: false
MD028: false
5 changes: 4 additions & 1 deletion README.md
@@ -40,4 +40,7 @@ poetry run python -m jmteb \
```

> [!NOTE]
> Some tasks (e.g., AmazonReviewClassification in classification, JAQKET and Mr.TyDi-ja in retrieval, esci in reranking) are time-consuming and memory-consuming. Heavy retrieval tasks take hours to encode the large corpus, and use much memory for the storage of such vectors. If you want to exclude them, add `--eval_exclude "['amazon_review_classification', 'mrtydi', 'jaqket', 'esci']"`.
> Some tasks (e.g., AmazonReviewClassification in classification, JAQKET and Mr.TyDi-ja in retrieval, esci in reranking) are time-consuming and memory-consuming. Heavy retrieval tasks take hours to encode the large corpus, and use much memory for the storage of such vectors. If you want to exclude them, add `--eval_exclude "['amazon_review_classification', 'mrtydi', 'jaqket', 'esci']"`. Similarly, you can use `--eval_include` to run only the evaluation datasets you want.

> [!NOTE]
> If you want to log model predictions to further analyze the performance of your model, use `--log_predictions true` to enable prediction logging in all evaluators. You can also control this per evaluator in each evaluator's config.
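
For example, a hypothetical run with prediction logging enabled might look like this (the embedder, evaluator, and output paths are placeholders for whatever you already pass):

```
poetry run python -m jmteb \
    --embedder <your embedder config> \
    --evaluators <your evaluator config> \
    --save_dir <your output dir> \
    --log_predictions true
```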
8 changes: 8 additions & 0 deletions src/jmteb/__main__.py
@@ -60,6 +60,9 @@ def main(
parser.add_argument("--overwrite_cache", type=bool, default=False, help="Overwrite the save_dir if it exists")
parser.add_argument("--eval_include", type=list[str], default=None, help="Evaluators to include.")
parser.add_argument("--eval_exclude", type=list[str], default=None, help="Evaluators to exclude.")
parser.add_argument(
"--log_predictions", type=bool, default=False, help="Whether to log predictions for all evaulators."
)

args = parser.parse_args()

@@ -99,6 +102,11 @@ def main(
f"Please check {args.evaluators}"
)

if args.log_predictions:
for k, v in args.evaluators.items():
if hasattr(v, "log_predictions"):
args.evaluators[k].log_predictions = True

main(
text_embedder=args.embedder,
evaluators=args.evaluators,
2 changes: 2 additions & 0 deletions src/jmteb/evaluators/base.py
@@ -19,11 +19,13 @@ class EvaluationResults:
metric_value (float): Value of the main metric.
details (dict[str, Any]): Details of the evaluation.
This includes additional metrics or values that are used to derive the main metric.
predictions (list[Any]): Predictions, e.g., (text, y_true, y_pred) tuples.
"""

metric_name: str
metric_value: float
details: dict[str, Any]
predictions: list[Any] | None = None

def as_dict(self) -> dict[str, Any]:
return {
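
As a rough sketch of how the extended dataclass can be populated (the field names come from this diff; the concrete values are made up for illustration):

```python
from jmteb.evaluators.base import EvaluationResults
from jmteb.evaluators.classification.data import ClassificationPrediction

# Illustrative values only; the field names are the ones defined in this PR.
results = EvaluationResults(
    metric_name="macro_f1",
    metric_value=0.82,
    details={"val_scores": {}, "test_scores": {}},
    predictions=[
        ClassificationPrediction(text="とても良い商品でした", label=1, prediction=1),
    ],
)
print(results.predictions[0].prediction)  # -> 1
```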
6 changes: 5 additions & 1 deletion src/jmteb/evaluators/classification/__init__.py
@@ -1,3 +1,7 @@
from .classifiers import Classifier, KnnClassifier, LogRegClassifier
from .data import ClassificationDataset, ClassificationInstance
from .data import (
ClassificationDataset,
ClassificationInstance,
ClassificationPrediction,
)
from .evaluator import ClassificationEvaluator
7 changes: 7 additions & 0 deletions src/jmteb/evaluators/classification/data.py
@@ -13,6 +13,13 @@ class ClassificationInstance:
label: int


@dataclass
class ClassificationPrediction:
text: str
label: int
prediction: int


class ClassificationDataset(ABC):
@abstractmethod
def __len__(self):
17 changes: 16 additions & 1 deletion src/jmteb/evaluators/classification/evaluator.py
@@ -11,7 +11,7 @@
from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults

from .classifiers import Classifier, KnnClassifier, LogRegClassifier
from .data import ClassificationDataset
from .data import ClassificationDataset, ClassificationPrediction


class ClassificationEvaluator(EmbeddingEvaluator):
@@ -28,6 +28,7 @@ class ClassificationEvaluator(EmbeddingEvaluator):
The first one is specified as the main index.
classifiers (dict[str, Classifier]): classifiers to be evaluated.
prefix (str | None): prefix for sentences. Defaults to None.
log_predictions (bool): whether to log predictions of each datapoint.
"""

def __init__(
@@ -38,6 +39,7 @@ def __init__(
average: str = "macro",
classifiers: dict[str, Classifier] | None = None,
prefix: str | None = None,
log_predictions: bool = False,
) -> None:
self.train_dataset = train_dataset
self.val_dataset = val_dataset
@@ -52,6 +54,7 @@
if average_name.strip().lower() in ("micro", "macro", "samples", "weighted", "binary")
] or ["macro"]
self.prefix = prefix
self.log_predictions = log_predictions
self.main_metric = f"{self.average[0]}_f1"

def __call__(
@@ -119,6 +122,7 @@ def __call__(
"val_scores": val_results,
"test_scores": test_results,
},
predictions=self._format_predictions(self.test_dataset, y_pred) if self.log_predictions else None,
)

@staticmethod
@@ -128,3 +132,14 @@ def _compute_metrics(y_pred: np.ndarray, y_true: list[int], average: list[float]
for average_method in average:
classifier_results[f"{average_method}_f1"] = f1_score(y_true, y_pred, average=average_method)
return classifier_results

@staticmethod
def _format_predictions(dataset: ClassificationDataset, y_pred: np.ndarray) -> list[ClassificationPrediction]:
texts = [item.text for item in dataset]
y_true = [item.label for item in dataset]
y_pred = y_pred.tolist()
assert len(texts) == len(y_true) == len(y_pred)
return [
ClassificationPrediction(text=text, label=label, prediction=pred)
for text, label, pred in zip(texts, y_true, y_pred)
]
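
For a sense of how the new flag surfaces to callers, a minimal usage sketch (the datasets, the embedder, and the exact `__call__` signature are assumptions, not part of this diff):

```python
# Sketch only: train_ds / val_ds / test_ds and embedder are assumed to exist.
evaluator = ClassificationEvaluator(
    train_dataset=train_ds,
    val_dataset=val_ds,
    test_dataset=test_ds,
    log_predictions=True,  # the flag added in this PR
)
results = evaluator(embedder)  # returns EvaluationResults; call signature assumed
for pred in results.predictions or []:
    print(pred.text, pred.label, pred.prediction)  # ClassificationPrediction fields
```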
2 changes: 1 addition & 1 deletion src/jmteb/evaluators/clustering/__init__.py
@@ -1,2 +1,2 @@
from .data import ClusteringDataset, ClusteringInstance
from .data import ClusteringDataset, ClusteringInstance, ClusteringPrediction
from .evaluator import ClusteringEvaluator
7 changes: 7 additions & 0 deletions src/jmteb/evaluators/clustering/data.py
@@ -13,6 +13,13 @@ class ClusteringInstance:
label: int


@dataclass
class ClusteringPrediction:
text: str
label: int
prediction: int


class ClusteringDataset(ABC):
@abstractmethod
def __len__(self):
34 changes: 23 additions & 11 deletions src/jmteb/evaluators/clustering/evaluator.py
@@ -18,7 +18,7 @@
from jmteb.embedders.base import TextEmbedder
from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults

from .data import ClusteringDataset
from .data import ClusteringDataset, ClusteringPrediction


class ClusteringEvaluator(EmbeddingEvaluator):
@@ -32,11 +32,13 @@ def __init__(
test_dataset: ClusteringDataset,
prefix: str | None = None,
random_seed: int | None = None,
log_predictions: bool = False,
) -> None:
self.val_dataset = val_dataset
self.test_dataset = test_dataset
self.prefix = prefix
self.random_seed = random_seed
self.log_predictions = log_predictions
self.main_metric = "v_measure_score"

def __call__(
Expand Down Expand Up @@ -80,20 +82,21 @@ def __call__(
logger.info("Fitting clustering model...")
val_results = {}
for model_name, model_constructor in model_constructors.items():
val_results[model_name] = self._evaluate_clustering_model(val_embeddings, val_labels, model_constructor())
val_results[model_name], _ = self._evaluate_clustering_model(
val_embeddings, val_labels, model_constructor()
)
optimal_clustering_model_name = sorted(
val_results.items(),
key=lambda res: res[1][self.main_metric],
reverse=True,
)[0][0]

test_results = {
optimal_clustering_model_name: self._evaluate_clustering_model(
test_embeddings,
test_labels,
model_constructors[optimal_clustering_model_name](),
)
}
test_scores, test_predictions = self._evaluate_clustering_model(
test_embeddings,
test_labels,
model_constructors[optimal_clustering_model_name](),
)
test_results = {optimal_clustering_model_name: test_scores}

return EvaluationResults(
metric_name=self.main_metric,
@@ -103,12 +106,15 @@
"val_scores": val_results,
"test_scores": test_results,
},
predictions=(
self._format_predictions(self.test_dataset, test_predictions) if self.log_predictions else None
),
)

@staticmethod
def _evaluate_clustering_model(
embeddings: np.ndarray, y_true: list[int], clustering_model: ClusterMixin
) -> dict[str, float]:
) -> tuple[dict[str, float], list[int]]:
y_pred = clustering_model.fit_predict(embeddings)
h_score, c_score, v_score = homogeneity_completeness_v_measure(
labels_pred=y_pred, labels_true=np.array(y_true)
@@ -118,4 +124,10 @@ def _evaluate_clustering_model(
"v_measure_score": v_score,
"homogeneity_score": h_score,
"completeness_score": c_score,
}
}, y_pred.tolist()

@staticmethod
def _format_predictions(dataset: ClusteringDataset, predictions: list[int]) -> list[ClusteringPrediction]:
return [
ClusteringPrediction(item.text, item.label, prediction) for item, prediction in zip(dataset, predictions)
]
2 changes: 2 additions & 0 deletions src/jmteb/evaluators/pair_classification/evaluator.py
@@ -22,6 +22,8 @@ class PairClassificationEvaluator(EmbeddingEvaluator):
test_dataset (PairClassificationDataset): test dataset
sentence1_prefix (str | None): prefix for sentence1. Defaults to None.
sentence2_prefix (str | None): prefix for sentence2. Defaults to None.

# NOTE: Predictions are not logged here, because predictions can differ depending on the metric used.
"""

def __init__(
1 change: 1 addition & 0 deletions src/jmteb/evaluators/reranking/__init__.py
@@ -1,6 +1,7 @@
from .data import (
RerankingDoc,
RerankingDocDataset,
RerankingPrediction,
RerankingQuery,
RerankingQueryDataset,
)
26 changes: 26 additions & 0 deletions src/jmteb/evaluators/reranking/data.py
@@ -22,6 +22,13 @@ class RerankingDoc:
text: str


@dataclass
class RerankingPrediction:
query: str
relevant_docs: list[RerankingDoc]
reranked_relevant_docs: list[RerankingDoc]


class RerankingQueryDataset(ABC):
@abstractmethod
def __len__(self):
@@ -47,6 +54,23 @@ def __getitem__(self, idx) -> RerankingDoc:
def __eq__(self, __value: object) -> bool:
return False

def _build_idx_docid_mapping(self, dataset_attr_name: str = "dataset") -> None:
self.idx_to_docid: dict = {}
self.docid_to_idx: dict = {}
id_key: str | None = getattr(self, "id_key", None)
dataset = getattr(self, dataset_attr_name)
if id_key:
for idx, doc_dict in enumerate(dataset):
self.idx_to_docid[idx] = doc_dict[id_key]
self.docid_to_idx[doc_dict[id_key]] = idx
elif isinstance(dataset[0], RerankingDoc):
for idx, doc in enumerate(dataset):
doc: RerankingDoc
self.idx_to_docid[idx] = doc.id
self.docid_to_idx[doc.id] = idx
else:
raise ValueError(f"Invalid dataset type: list[{type(dataset[0])}]")


class HfRerankingQueryDataset(RerankingQueryDataset):
def __init__(
@@ -131,6 +155,7 @@ def __init__(self, path: str, split: str, name: str | None = None, id_key: str =
self.dataset = datasets.load_dataset(path, split=split, name=name, trust_remote_code=True)
self.id_key = id_key
self.text_key = text_key
self._build_idx_docid_mapping()

def __len__(self):
return len(self.dataset)
@@ -157,6 +182,7 @@ def __init__(self, filename: str, id_key: str = "docid", text_key: str = "text")
self.dataset = corpus
self.id_key = id_key
self.text_key = text_key
self._build_idx_docid_mapping()

def __len__(self):
return len(self.dataset)