diff --git a/pipeline/__init__.py b/pipeline/__init__.py
index 7835bbda..03c9817a 100644
--- a/pipeline/__init__.py
+++ b/pipeline/__init__.py
@@ -5,7 +5,103 @@ from dataset.st_dataset import SummDataset
 from dataset import ScisummnetDataset
-from typing import List, Tuple
+from typing import Dict, List, Tuple, Union, Set
+
+
+def retrieve_task_nodes(model_or_dataset: Union[SummModel, SummDataset]) -> List[str]:
+    """Generates a list of summarization task nodes, as strings,
+    for the given model or dataset.
+
+    Args:
+        model_or_dataset (Union[SummModel, SummDataset]): SummerTime model or dataset
+
+    Returns:
+        List[str]: Model/dataset task types in string form
+    """
+    task_nodes = ["is_single_document"]
+    if model_or_dataset.is_dialogue_based:
+        task_nodes.append("is_dialogue_based")
+    if model_or_dataset.is_multi_document:
+        task_nodes.append("is_multi_document")
+    if model_or_dataset.is_query_based:
+        task_nodes.append("is_query_based")
+    return task_nodes
+
+
+def top_sort_dfs(
+    list_nodes: List[str],
+    graph: Dict[str, List],
+    cur_node: str,
+    sorted_list: List[str],
+    visited: Set,
+):
+    """DFS helper for topological sort.
+
+    Args:
+        list_nodes (List[str]): List of nodes to sort
+        graph (Dict[str, List]): Directed graph of nodes as an adjacency list
+        cur_node (str): Current node in the DFS
+        sorted_list (List[str]): Sorted list to append finished nodes to
+        visited (Set): Tracks which nodes have already been visited
+    """
+    if cur_node in visited:
+        return
+    visited.add(cur_node)
+    for neighbor in graph[cur_node]:
+        if neighbor in list_nodes and neighbor not in visited:
+            top_sort_dfs(list_nodes, graph, neighbor, sorted_list, visited)
+    # Post-order append: a node is added only after all of its descendants.
+    sorted_list.append(cur_node)
+
+
+def top_sort_options(list_nodes: List[str], graph: Dict[str, List]) -> List[str]:
+    """Sorts a list of nodes according to their topological order in the graph.
+
+    Args:
+        list_nodes (List[str]): list of nodes to sort
+        graph (Dict[str, List]): graph defining the topological order of nodes
+
+    Returns:
+        List[str]: topologically sorted list
+    """
+    in_degrees = {}
+    for node in list_nodes:
+        in_degrees[node] = 0
+    for node in list_nodes:
+        for neighbor in graph[node]:
+            if neighbor in list_nodes:
+                in_degrees[neighbor] += 1
+
+    sorted_list = []
+    visited = set()
+    for node in list_nodes:
+        if in_degrees[node] == 0:
+            top_sort_dfs(list_nodes, graph, node, sorted_list, visited)
+    if len(sorted_list) != len(list_nodes):
+        print(f"Warning: task graph over {list_nodes} contains a cycle!")
+        return []
+
+    # Reverse the post-order so predecessors come before their successors.
+    sorted_list.reverse()
+    return sorted_list
+
+
+def create_model_composition_graph() -> Dict[str, List]:
+    """Returns a directed graph where each node is a summarization
+    task, and each edge (A -> B) indicates that a model for task A
+    may be applied on top of a model for task B in a multi-layered
+    summarization pipeline.
+
+    Returns:
+        Dict[str, List]: Adjacency list representation of
+        the graph.
+    """
+    graph = {}
+    graph["is_single_document"] = []
+    graph["is_multi_document"] = ["is_single_document"]
+    graph["is_dialogue_based"] = ["is_multi_document", "is_single_document"]
+    graph["is_query_based"] = ["is_dialogue_based", "is_multi_document"]
+    return graph
 
 
 def get_lxr_train_set(dataset: SummDataset, size: int = 100) -> List[str]:
@@ -31,6 +127,93 @@ def get_lxr_train_set(dataset: SummDataset, size: int = 100) -> List[str]:
     return src
 
 
+def assemble_model_pipeline_2(
+    dataset: SummDataset, model_list: List[SummModel] = SUPPORTED_SUMM_MODELS
+) -> List[Tuple[SummModel, str]]:
+    """
+    Return an initialized list of all model pipelines that match the summarization task of the given dataset.
+
+    :param SummDataset `dataset`: Dataset to retrieve model pipelines for.
+    :param List[SummModel] `model_list`: List of candidate model classes (uninitialized). Defaults to `model.SUPPORTED_SUMM_MODELS`.
+    :returns List of tuples, where each tuple contains an initialized model and the name of that model as `(model, name)`.
+    """
+
+    dataset = dataset if isinstance(dataset, SummDataset) else dataset()
+
+    single_doc_model_list = list(
+        filter(
+            lambda model_cls: not (
+                model_cls.is_dialogue_based
+                or model_cls.is_query_based
+                or model_cls.is_multi_document
+            ),
+            model_list,
+        )
+    )
+    single_doc_model_instances = [
+        model_cls(get_lxr_train_set(dataset))
+        if model_cls == LexRankModel
+        else model_cls()
+        for model_cls in single_doc_model_list
+    ]
+
+    multi_doc_model_list = list(
+        filter(lambda model_cls: model_cls.is_multi_document, model_list)
+    )
+
+    query_based_model_list = list(
+        filter(lambda model_cls: model_cls.is_query_based, model_list)
+    )
+
+    dialogue_based_model_list = list(
+        filter(lambda model_cls: model_cls.is_dialogue_based, model_list)
+    )
+
+    task_node_list = retrieve_task_nodes(dataset)
+    graph = create_model_composition_graph()
+    sorted_task_node_list = top_sort_options(task_node_list, graph)
+
+    if len(sorted_task_node_list) == 0:
+        return [(model, model.model_name) for model in single_doc_model_instances]
+
+    task_node_to_model_list = {
+        "is_single_document": single_doc_model_list,
+        "is_dialogue_based": dialogue_based_model_list,
+        "is_multi_document": multi_doc_model_list,
+        "is_query_based": query_based_model_list,
+    }
+
+    matching_models = []
+    # Reverse so the most basic task (single-document) is handled first and
+    # each later task wraps the pipelines built so far.
+    sorted_task_node_list.reverse()
+    for task_node in sorted_task_node_list:
+        if len(matching_models) == 0:
+            for model_cls in task_node_to_model_list[task_node]:
+                # TODO: How to tell if last task needs model backend?
+                matching_models.append((model_cls, model_cls.model_name))
+        else:
+            new_matching_models = []
+            for model_cls in task_node_to_model_list[task_node]:
+                for model_backend, model_backend_name in matching_models:
+                    new_matching_models.append(
+                        (
+                            model_cls(
+                                model_backend=model_backend,
+                                data=get_lxr_train_set(dataset),
+                            ),
+                            f"{model_cls.model_name} ({model_backend_name})",
+                        )
+                        if model_backend == LexRankModel
+                        else (
+                            model_cls(model_backend=model_backend),
+                            f"{model_cls.model_name} ({model_backend_name})",
+                        )
+                    )
+            matching_models = new_matching_models
+    return matching_models
+
+
 def assemble_model_pipeline(
     dataset: SummDataset, model_list: List[SummModel] = SUPPORTED_SUMM_MODELS
 ) -> List[Tuple[SummModel, str]]:
diff --git a/tests/integration_test.py b/tests/integration_test.py
index 7db778a3..144fb952 100644
--- a/tests/integration_test.py
+++ b/tests/integration_test.py
@@ -3,7 +3,7 @@ from model.base_model import SummModel
 from model import SUPPORTED_SUMM_MODELS
-from pipeline import assemble_model_pipeline
+from pipeline import assemble_model_pipeline, assemble_model_pipeline_2
 from evaluation.base_metric import SummMetric
 from evaluation import SUPPORTED_EVALUATION_METRICS
@@ -87,9 +87,10 @@ def test_all(self):
             "35",
         )
         # matching_model_instances = assemble_model_pipeline(dataset_cls, list(filter(lambda m: m != PegasusModel, SUPPORTED_SUMM_MODELS)))
-        matching_model_instances = assemble_model_pipeline(
+        matching_model_instances = assemble_model_pipeline_2(
             dataset_cls, SUPPORTED_SUMM_MODELS
         )
+        print(matching_model_instances)
         for model, model_name in matching_model_instances:
             test_instances = retrieve_random_test_instances(
                 dataset_instances=dataset_instances, num_instances=1
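
Reviewer note: below is a minimal sketch of how the new graph helpers are expected to compose, for anyone verifying the intended task ordering. The `FakeDialogueQueryDataset` stub is hypothetical; it only mimics the capability flags that `retrieve_task_nodes` reads off a real `SummDataset`. The expected outputs follow from the composition graph defined in this diff.

```python
from pipeline import (
    create_model_composition_graph,
    retrieve_task_nodes,
    top_sort_options,
)


# Hypothetical stand-in for a SummDataset exposing the capability flags
# that retrieve_task_nodes inspects.
class FakeDialogueQueryDataset:
    is_dialogue_based = True
    is_multi_document = False
    is_query_based = True


graph = create_model_composition_graph()
task_nodes = retrieve_task_nodes(FakeDialogueQueryDataset)
# task_nodes == ["is_single_document", "is_dialogue_based", "is_query_based"]

# Most complex task comes first; assemble_model_pipeline_2 then reverses this
# so base single-document models are created before being wrapped.
print(top_sort_options(task_nodes, graph))
# Expected: ["is_query_based", "is_dialogue_based", "is_single_document"]
```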