Returning a dictionary from InformationRetrievalEvaluator instead of writing a CSV #2401
Comments
Hello! That would indeed be fairly useful. In the meantime, you could use the following:

import logging
import os

from sentence_transformers.evaluation import InformationRetrievalEvaluator

logger = logging.getLogger(__name__)


class InformationRetrievalEvaluatorWithDict(InformationRetrievalEvaluator):
    # Returns the full nested scores dictionary rather than a single float.
    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1, *args, **kwargs) -> dict:
        if epoch != -1:
            out_txt = " after epoch {}:".format(epoch) if steps == -1 else " in epoch {} after {} steps:".format(epoch, steps)
        else:
            out_txt = ":"

        logger.info("Information Retrieval Evaluation on " + self.name + " dataset" + out_txt)

        scores = self.compute_metrices(model, *args, **kwargs)

        # Write results to disk
        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            # Create the file with a header row on first use, otherwise append.
            write_header = not os.path.isfile(csv_path)
            with open(csv_path, mode="w" if write_header else "a", encoding="utf-8") as f_out:
                if write_header:
                    f_out.write(",".join(self.csv_headers))
                    f_out.write("\n")

                # One row per call: epoch, steps, then every metric for every score function.
                output_data = [epoch, steps]
                for name in self.score_function_names:
                    for k in self.accuracy_at_k:
                        output_data.append(scores[name]['accuracy@k'][k])

                    for k in self.precision_recall_at_k:
                        output_data.append(scores[name]['precision@k'][k])
                        output_data.append(scores[name]['recall@k'][k])

                    for k in self.mrr_at_k:
                        output_data.append(scores[name]['mrr@k'][k])

                    for k in self.ndcg_at_k:
                        output_data.append(scores[name]['ndcg@k'][k])

                    for k in self.map_at_k:
                        output_data.append(scores[name]['map@k'][k])

                f_out.write(",".join(map(str, output_data)))
                f_out.write("\n")

        return scores

Then you can use
This works, thank you! Hope we can see support in the main library soon as well.

I intend to incorporate this into #2449.
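The CSV branch in the subclass above indexes the result as scores[name]['accuracy@k'][k], which suggests a three-level nesting of score function, metric, and cutoff. Assuming that shape holds, here is a minimal sketch of flattening the returned dictionary into single-level keys, which can be handier for logging or experiment trackers. The helper `flatten_scores`, the key format, the `cos_sim` name, and the numbers are all hypothetical and not part of sentence-transformers.

```python
from typing import Dict

# Assumed shape, inferred from the indexing in the CSV-writing branch:
# scores[score_function_name][metric_name][k] -> float
NestedScores = Dict[str, Dict[str, Dict[int, float]]]


def flatten_scores(scores: NestedScores, prefix: str = "") -> Dict[str, float]:
    """Flatten nested scores into keys such as 'cos_sim_accuracy@1'.

    Hypothetical helper for illustration; not a sentence-transformers API.
    """
    flat: Dict[str, float] = {}
    for name, metrics in scores.items():
        for metric, values_at_k in metrics.items():
            base = metric.replace("@k", "")  # 'accuracy@k' -> 'accuracy'
            for k, value in values_at_k.items():
                flat[f"{prefix}{name}_{base}@{k}"] = value
    return flat


# Made-up numbers, only to show the key layout.
example_scores = {
    "cos_sim": {
        "accuracy@k": {1: 0.5, 3: 0.75},
        "map@k": {100: 0.6},
    },
}
flat_scores = flatten_scores(example_scores)
```

With a real model, the dictionary returned by calling the subclass instance on the model could be passed to such a helper in place of `example_scores`, provided the nesting matches the assumption above.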
InformationRetrievalEvaluator is a convenient way to compute metrics, but it writes to a CSV and only returns a MAP score if write_csv is turned off. Would it be possible to return a dictionary of all computed metrics when write_csv=False?