diff --git a/pv211_utils/datasets.py b/pv211_utils/datasets.py index 300ba06..0787299 100644 --- a/pv211_utils/datasets.py +++ b/pv211_utils/datasets.py @@ -264,7 +264,7 @@ def load_validation_judgements(self) -> ArqmathJudgements: self._load_judgements(year2)) if q.query_id in self.load_validation_queries().keys()} - def load_answers(self, answer_class=ArqmathAnswerBase) -> OrderedDict: + def load_answers(self, answer_class=ArqmathAnswerBase, cache_directory="/var/tmp/pv211/") -> OrderedDict: """Load answers. Returns @@ -275,10 +275,10 @@ def load_answers(self, answer_class=ArqmathAnswerBase) -> OrderedDict: return arqmath_loader.load_answers( text_format=self.text_format, answer_class=answer_class, - cache_download=f'/var/tmp/pv211/arqmath2020_answers_{self.text_format}.json.gz' + cache_download=(cache_directory + f'arqmath2020_answers_{self.text_format}.json.gz') ) - def load_questions(self, question_class=ArqmathQuestionBase) -> OrderedDict: + def load_questions(self, question_class=ArqmathQuestionBase, cache_directory="/var/tmp/pv211/") -> OrderedDict: """Load questions. Returns @@ -290,7 +290,8 @@ def load_questions(self, question_class=ArqmathQuestionBase) -> OrderedDict: text_format=self.text_format, answers=self.load_answers(), question_class=question_class, - cache_download=f'/var/tmp/pv211/arqmath2020_questions_{self.text_format}.json.gz') + cache_download=(cache_directory + f'arqmath2020_questions_{self.text_format}.json.gz') + ) class CranfieldDataset(): @@ -578,7 +579,7 @@ def load_validation_judgements(self) -> TrecJudgements: subset="validation")) if q.query_id in self.load_validation_queries().keys()} - def load_documents(self, document_class=TrecDocumentBase) -> OrderedDict: + def load_documents(self, document_class=TrecDocumentBase, cache_directory="/var/tmp/pv211/") -> OrderedDict: """Load documents. Returns @@ -587,7 +588,7 @@ def load_documents(self, document_class=TrecDocumentBase) -> OrderedDict: Dictionary of (document_id: Document) form. """ return trec_loader.load_documents(document_class=document_class, - cache_download='/var/tmp/pv211/trec_documents.json.gz') + cache_download=(cache_directory + 'trec_documents.json.gz')) class BeirDataset(): @@ -744,7 +745,8 @@ class CQADupStackDataset(): """ - def __init__(self, download_location: str = "datasets", validation_split_size: float = 0.2) -> None: + def __init__(self, download_location: str = "/var/tmp/pv211/cqa_datasets", + validation_split_size: float = 0.2) -> None: """Check if arguments have legal values and construct attributes for BeirDataset object. diff --git a/pv211_utils/eval.py b/pv211_utils/eval.py index 0e3203a..9899532 100644 --- a/pv211_utils/eval.py +++ b/pv211_utils/eval.py @@ -6,7 +6,7 @@ from .entities import JudgementBase from .leaderboard import LeaderboardBase from .irsystem import IRSystemBase -from .evaluation_metrics import calc_map +from .evaluation_metrics import mean_average_precision from IPython.display import display, Markdown @@ -94,9 +94,7 @@ def evaluate(self, queries: OrderedDict, submit_result: bool = True) -> None: """ time_before = datetime.now() - # result = mean_average_precision(self.system, queries, self.judgements, self.k, self.num_workers) - m = calc_map() - result = m.mean_average_precision(self.system, queries, self.judgements, self.k, self.num_workers) + result = mean_average_precision(self.system, queries, self.judgements, self.k, self.num_workers) time_after = datetime.now() map_score = result * 100.0 diff --git a/pv211_utils/evaluation_metrics.py b/pv211_utils/evaluation_metrics.py index e2ed2b1..31ae4a7 100644 --- a/pv211_utils/evaluation_metrics.py +++ b/pv211_utils/evaluation_metrics.py @@ -16,11 +16,14 @@ from .entities import JudgementBase, QueryBase from .irsystem import IRSystemBase -from typing import Set, OrderedDict +from typing import Set, OrderedDict, Optional from multiprocessing import Pool, get_context from functools import partial from math import log2 -import abc +from tqdm import tqdm + + +_CURR_SYSTEM = None def _judgements_obj_to_id(old_judgements: Set[JudgementBase]) -> Set: @@ -31,13 +34,13 @@ def _judgements_obj_to_id(old_judgements: Set[JudgementBase]) -> Set: return new_judgements -def _calc_recall(system: IRSystemBase, judgements: Set, k: int, - query: QueryBase) -> float: +def _calc_recall(judgements: Set, k: int, query: QueryBase) -> float: num_relevant = 0 num_relevant_topk = 0 current_rank = 1 + global _CURR_SYSTEM - for document in system.search(query): + for document in _CURR_SYSTEM.search(query): if (query.query_id, document.document_id) in judgements: num_relevant += 1 if current_rank <= k: @@ -52,13 +55,13 @@ def _calc_recall(system: IRSystemBase, judgements: Set, k: int, return recall -def _calc_precision(system: IRSystemBase, judgements: Set, k: int, - query: QueryBase) -> float: +def _calc_precision(judgements: Set, k: int, query: QueryBase) -> float: num_relevant = 0 precision = 0.0 current_rank = 1 + global _CURR_SYSTEM - for document in system.search(query): + for document in _CURR_SYSTEM.search(query): if current_rank > k: break if (query.query_id, document.document_id) in judgements: @@ -70,13 +73,13 @@ def _calc_precision(system: IRSystemBase, judgements: Set, k: int, return precision -def _calc_average_precision(system: IRSystemBase, judgements: Set, k: int, - query: QueryBase) -> float: +def _calc_average_precision(judgements: Set, k: int, query: QueryBase) -> float: num_relevant = 0 average_precision = 0.0 current_rank = 1 + global _CURR_SYSTEM - for document in system.search(query): + for document in _CURR_SYSTEM.search(query): if current_rank > k: break if (query.query_id, document.document_id) in judgements: @@ -89,81 +92,13 @@ def _calc_average_precision(system: IRSystemBase, judgements: Set, k: int, return average_precision -class calc_map(abc.ABC): - system = None - judgements = None - k = None - _CURRENT_INSTANCE = None - - @classmethod - def _calc_average_precision(csl, query: QueryBase) -> float: - num_relevant = 0 - average_precision = 0.0 - current_rank = 1 - - for document in csl._CURRENT_INSTANCE.system.search(query): - if current_rank > csl._CURRENT_INSTANCE.k: - break - if (query.query_id, document.document_id) in csl._CURRENT_INSTANCE.judgements: - num_relevant += 1 - average_precision += num_relevant / current_rank - current_rank += 1 - - average_precision /= num_relevant if num_relevant > 0 else 1 - - return average_precision - - def mean_average_precision(self, system: IRSystemBase, queries: OrderedDict, - judgements: Set[JudgementBase], - k: int, num_processes: int) -> float: - """Evaluate system for given queries and judgements with mean average precision - metric. Where first k documents will be used in evaluation. - - Arguments - --------- - system : IRSystemBase - System to be evaluated. - queries : OrderedDict - Queries to be searched. - judgements : Set[JudgementBase] - Judgements. - k : int - Parameter defining evaluation depth. - num_processes : int - Parallelization parameter defining number of processes to be used to run the evaluation. - - Returns - ------- - float - Mean average precision score from interval [0, 1]. - """ - map_score = 0.0 - - self.system = system - self.judgements = _judgements_obj_to_id(judgements) - self.k = k - self.__class__._CURRENT_INSTANCE = self - - if num_processes == 1: - for query in list(queries.values()): - map_score += self.__class__._calc_average_precision(query) - else: - with get_context("fork").Pool(processes=num_processes) as process_pool: - for precision in process_pool.imap(self.__class__._calc_average_precision, list(queries.values())): - map_score += precision - - map_score /= len(queries) - self.__class__._CURRENT_INSTANCE = None - return map_score - - -def _calc_ndcg(system: IRSystemBase, judgements: Set, k: int, - query: QueryBase) -> float: +def _calc_ndcg(judgements: Set, k: int, query: QueryBase) -> float: num_relevant = 0 dcg = 0.0 current_rank = 1 + global _CURR_SYSTEM - for document in system.search(query): + for document in _CURR_SYSTEM.search(query): if current_rank > k: break if (query.query_id, document.document_id) in judgements: @@ -177,14 +112,14 @@ def _calc_ndcg(system: IRSystemBase, judgements: Set, k: int, return dcg / idcg -def _calc_bpref(system: IRSystemBase, judgements: Set, k: int, - query: QueryBase) -> float: +def _calc_bpref(judgements: Set, k: int, query: QueryBase) -> float: num_relevant = 0 relevant_doc_ranks = [] current_rank = 1 bpref = 0.0 + global _CURR_SYSTEM - for document in system.search(query): + for document in _CURR_SYSTEM.search(query): if current_rank > k: break if (query.query_id, document.document_id) in judgements: @@ -200,7 +135,7 @@ def _calc_bpref(system: IRSystemBase, judgements: Set, k: int, def mean_average_precision(system: IRSystemBase, queries: OrderedDict, judgements: Set[JudgementBase], - k: int, num_processes: int) -> float: + k: int, num_processes: Optional[int] = None) -> float: """Evaluate system for given queries and judgements with mean average precision metric. Where first k documents will be used in evaluation. @@ -224,24 +159,31 @@ def mean_average_precision(system: IRSystemBase, queries: OrderedDict, """ map_score = 0.0 + global _CURR_SYSTEM + _CURR_SYSTEM = system + + query_values = tqdm(list(queries.values())) + if num_processes == 1: - for query in list(queries.values()): - map_score += _calc_average_precision(system, _judgements_obj_to_id(judgements), k, query) + for query in query_values: + map_score += _calc_average_precision(_judgements_obj_to_id(judgements), k, query) else: - worker_avg_precision = partial(_calc_average_precision, system, + worker_avg_precision = partial(_calc_average_precision, _judgements_obj_to_id(judgements), k) with get_context("fork").Pool(processes=num_processes) as process_pool: - for precision in process_pool.imap(worker_avg_precision, list(queries.values())): + for precision in process_pool.imap(worker_avg_precision, query_values): map_score += precision map_score /= len(queries) + _CURR_SYSTEM = None + return map_score def mean_precision(system: IRSystemBase, queries: OrderedDict, - judgements: Set[JudgementBase], k: int, num_processes: int) -> float: + judgements: Set[JudgementBase], k: int, num_processes: Optional[int] = None) -> float: """Evaluate system for given queries and judgements with mean precision metric. Where first k documents will be used in evaluation. @@ -265,22 +207,29 @@ def mean_precision(system: IRSystemBase, queries: OrderedDict, """ mp_score = 0 + global _CURR_SYSTEM + _CURR_SYSTEM = system + + query_values = tqdm(list(queries.values())) + if num_processes == 1: - for query in list(queries.values()): - mp_score += _calc_precision(system, _judgements_obj_to_id(judgements), k, query) + for query in query_values: + mp_score += _calc_precision(_judgements_obj_to_id(judgements), k, query) else: - worker_precision = partial(_calc_precision, system, + worker_precision = partial(_calc_precision, _judgements_obj_to_id(judgements), k) with Pool(processes=num_processes) as process_pool: - for precision in process_pool.imap(worker_precision, list(queries.values())): + for precision in process_pool.imap(worker_precision, query_values): mp_score += precision + _CURR_SYSTEM = None + return mp_score / len(queries) def mean_recall(system: IRSystemBase, queries: OrderedDict, - judgements: Set[JudgementBase], k: int, num_processes: int) -> float: + judgements: Set[JudgementBase], k: int, num_processes: Optional[int] = None) -> float: """Evaluate system for given queries and judgements with mean recall metric. Where first k documents will be used in evaluation. @@ -303,24 +252,32 @@ def mean_recall(system: IRSystemBase, queries: OrderedDict, Mean recall score from interval [0, 1]. """ mr_score = 0 + + global _CURR_SYSTEM + _CURR_SYSTEM = system + + query_values = tqdm(list(queries.values())) + if num_processes == 1: - for query in list(queries.values()): - mr_score += _calc_recall(system, _judgements_obj_to_id(judgements), k, query) + for query in query_values: + mr_score += _calc_recall(_judgements_obj_to_id(judgements), k, query) else: - worker_recall = partial(_calc_recall, system, + worker_recall = partial(_calc_recall, _judgements_obj_to_id(judgements), k) with Pool(processes=num_processes) as process_pool: - for recall in process_pool.imap(worker_recall, list(queries.values())): + for recall in process_pool.imap(worker_recall, query_values): mr_score += recall + _CURR_SYSTEM = None + return mr_score / len(queries) def normalized_discounted_cumulative_gain(system: IRSystemBase, queries: OrderedDict, judgements: Set[JudgementBase], - k: int, num_processes: int) -> float: + k: int, num_processes: Optional[int] = None) -> float: """Evaluate system for given queries and judgements with normalized discounted cumulative gain metric. Where first k documents will be used in evaluation. @@ -344,22 +301,29 @@ def normalized_discounted_cumulative_gain(system: IRSystemBase, """ ndcg_score = 0 + global _CURR_SYSTEM + _CURR_SYSTEM = system + + query_values = tqdm(list(queries.values())) + if num_processes == 1: - for query in list(queries.values()): - ndcg_score += _calc_ndcg(system, _judgements_obj_to_id(judgements), k, query) + for query in query_values: + ndcg_score += _calc_ndcg(_judgements_obj_to_id(judgements), k, query) else: - worker_ndcg = partial(_calc_ndcg, system, + worker_ndcg = partial(_calc_ndcg, _judgements_obj_to_id(judgements), k) with Pool(processes=num_processes) as process_pool: - for dcg in process_pool.imap(worker_ndcg, list(queries.values())): + for dcg in process_pool.imap(worker_ndcg, query_values): ndcg_score += dcg + _CURR_SYSTEM = None + return ndcg_score / len(queries) def mean_bpref(system: IRSystemBase, queries: OrderedDict, - judgements: Set[JudgementBase], k: int, num_processes: int) -> float: + judgements: Set[JudgementBase], k: int, num_processes: Optional[int] = None) -> float: """Evaluate system for given queries and judgements with bpref metric. Where first k documents will be used in evaluation. @@ -391,15 +355,23 @@ def mean_bpref(system: IRSystemBase, queries: OrderedDict, Bpref score from interval [0, 1]. """ bpref_score = 0 + + global _CURR_SYSTEM + _CURR_SYSTEM = system + + query_values = tqdm(list(queries.values())) + if num_processes == 1: - for query in list(queries.values()): - bpref_score += _calc_bpref(system, _judgements_obj_to_id(judgements), k, query) + for query in query_values: + bpref_score += _calc_bpref(_judgements_obj_to_id(judgements), k, query) else: - worker_bpref = partial(_calc_bpref, system, + worker_bpref = partial(_calc_bpref, _judgements_obj_to_id(judgements), k) with Pool(processes=num_processes) as process_pool: - for bpref in process_pool.imap(worker_bpref, list(queries.values())): + for bpref in process_pool.imap(worker_bpref, query_values): bpref_score += bpref + _CURR_SYSTEM = None + return bpref_score / len(queries) diff --git a/script/download_datasets.py b/script/download_datasets.py index 021f2d9..1a5ebb5 100644 --- a/script/download_datasets.py +++ b/script/download_datasets.py @@ -29,12 +29,37 @@ def download_arqmath(root_directory: Path) -> None: load_questions(text_format, answers, cache_download=questions_pathname) +def download_cqadupstack(root_directory: Path) -> None: + from pv211_utils.beir.entities import RawBeirDataset, RawBeirDatasets + from pv211_utils.beir.loader import load_beir_datasets + + android = RawBeirDataset("android") + english = RawBeirDataset("english") + gaming = RawBeirDataset("gaming") + gis = RawBeirDataset("gis") + mathematica = RawBeirDataset("mathematica") + physics = RawBeirDataset("physics") + programmers = RawBeirDataset("programmers") + stats = RawBeirDataset("stats") + tex = RawBeirDataset("tex") + unix = RawBeirDataset("unix") + webmasters = RawBeirDataset("webmasters") + wordpress = RawBeirDataset("wordpress") + + datasets = RawBeirDatasets(datasets=[android, english, gaming, gis, + mathematica, physics, programmers, + stats, tex, unix, webmasters, wordpress], + download_location=str(root_directory/"cqa_datasets")) + load_beir_datasets(datasets) + + def main() -> None: umask(0o000) root_directory = Path('/var') / 'tmp' / 'pv211' root_directory.mkdir(parents=True, exist_ok=True, mode=0o777) download_trec(root_directory) download_arqmath(root_directory) + download_cqadupstack(root_directory) if __name__ == '__main__':