diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 7c7a2b742..b40a776b3 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -33,7 +33,13 @@ jobs: flake8 medcat - name: Test run: | - timeout 17m python -m unittest discover + all_files=$(git ls-files | grep '^tests/.*\.py$' | grep -v '/__init__\.py$' | sed 's/\.py$//' | sed 's/\//./g') + num_files=$(echo "$all_files" | wc -l) + midpoint=$((num_files / 2)) + first_half_nl=$(echo "$all_files" | head -n $midpoint) + second_half_nl=$(echo "$all_files" | tail -n +$(($midpoint + 1))) + timeout 25m python -m unittest ${first_half_nl[@]} + timeout 25m python -m unittest ${second_half_nl[@]} publish-to-test-pypi: @@ -43,7 +49,7 @@ jobs: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') != true runs-on: ubuntu-20.04 - timeout-minutes: 20 + timeout-minutes: 45 concurrency: publish-to-test-pypi needs: [build] diff --git a/medcat/stats/kfold.py b/medcat/stats/kfold.py new file mode 100644 index 000000000..491173c23 --- /dev/null +++ b/medcat/stats/kfold.py @@ -0,0 +1,436 @@ +from typing import Protocol, Tuple, List, Dict, Optional, Set, Iterable, Callable, cast, Any + +from abc import ABC, abstractmethod +from enum import Enum, auto +from copy import deepcopy + +import numpy as np + +from medcat.utils.checkpoint import Checkpoint +from medcat.utils.cdb_state import captured_state_cdb + +from medcat.stats.stats import get_stats +from medcat.stats.mctexport import MedCATTrainerExport, MedCATTrainerExportProject +from medcat.stats.mctexport import MedCATTrainerExportDocument, MedCATTrainerExportAnnotation +from medcat.stats.mctexport import count_all_annotations, count_all_docs, get_nr_of_annotations +from medcat.stats.mctexport import iter_anns, iter_docs, MedCATTrainerExportProjectInfo + + + +class CDBLike(Protocol): + pass + + +class CATLike(Protocol): + + @property + def cdb(self) -> CDBLike: + pass + + def train_supervised_raw(self, + data: Dict[str, List[Dict[str, dict]]], + reset_cui_count: bool = False, + nepochs: int = 1, + print_stats: int = 0, + use_filters: bool = False, + terminate_last: bool = False, + use_overlaps: bool = False, + use_cui_doc_limit: bool = False, + test_size: float = 0, + devalue_others: bool = False, + use_groups: bool = False, + never_terminate: bool = False, + train_from_false_positives: bool = False, + extra_cui_filter: Optional[Set] = None, + retain_extra_cui_filter: bool = False, + checkpoint: Optional[Checkpoint] = None, + retain_filters: bool = False, + is_resumed: bool = False) -> Tuple: + pass + + +class SplitType(Enum): + """The split type.""" + DOCUMENTS = auto() + """Split over number of documents.""" + ANNOTATIONS = auto() + """Split over number of annotations.""" + DOCUMENTS_WEIGHTED = auto() + """Split over number of documents based on the number of annotations. + So essentially this ensures that the same document isn't in 2 folds + while trying to more equally distribute documents with different number + of annotations. + For example: + If we have 6 documents that we want to split into 3 folds. 
+ The number of annotations per document are as follows: + [40, 40, 20, 10, 5, 5] + If we were to split this trivially over documents, we'd end up + with the 3 folds with number of annotations that are far from even: + [80, 30, 10] + However, if we use the annotations as weights, we would be able to + create folds that have more evenly distributed annotations, e.g: + [[D1,], [D2], [D3, D4, D5, D6]] + where D# denotes the number of the documents, with the number of + annotations being equal: + [ 40, 40, 20 + 10 + 5 + 5 = 40] + """ + + +class FoldCreator(ABC): + """The FoldCreator based on a MCT export. + + Args: + mct_export (MedCATTrainerExport): The MCT export dict. + nr_of_folds (int): Number of folds to create. + use_annotations (bool): Whether to fold on number of annotations or documents. + """ + + def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int) -> None: + self.mct_export = mct_export + self.nr_of_folds = nr_of_folds + + def _find_or_add_doc(self, project: MedCATTrainerExportProject, orig_doc: MedCATTrainerExportDocument + ) -> MedCATTrainerExportDocument: + for existing_doc in project['documents']: + if existing_doc['name'] == orig_doc['name']: + return existing_doc + new_doc: MedCATTrainerExportDocument = deepcopy(orig_doc) + new_doc['annotations'].clear() + project['documents'].append(new_doc) + return new_doc + + def _create_new_project(self, proj_info: MedCATTrainerExportProjectInfo) -> MedCATTrainerExportProject: + (proj_name, proj_id, proj_cuis, proj_tuis) = proj_info + cur_project = cast(MedCATTrainerExportProject, { + 'name': proj_name, + 'id': proj_id, + 'cuis': proj_cuis, + 'documents': [], + }) + # NOTE: Some MCT exports don't declare TUIs + if proj_tuis is not None: + cur_project['tuis'] = proj_tuis + return cur_project + + def _create_export_with_documents(self, relevant_docs: Iterable[Tuple[MedCATTrainerExportProjectInfo, + MedCATTrainerExportDocument]]) -> MedCATTrainerExport: + export: MedCATTrainerExport = { + "projects": [] + } + # helper for finding projects per name + used_projects: Dict[str, MedCATTrainerExportProject] = {} + for proj_info, doc in relevant_docs: + proj_name = proj_info[0] + if proj_name not in used_projects: + cur_project = self._create_new_project(proj_info) # TODO - make sure it's available + export['projects'].append(cur_project) + used_projects[proj_name] = cur_project + else: + cur_project = used_projects[proj_name] + cur_project['documents'].append(doc) + return export + + + @abstractmethod + def create_folds(self) -> List[MedCATTrainerExport]: + """Create folds. + + Raises: + ValueError: If somethign went wrong. + + Returns: + List[MedCATTrainerExport]: The created folds. + """ + + +class SimpleFoldCreator(FoldCreator): + + def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int, + counter: Callable[[MedCATTrainerExport], int]) -> None: + super().__init__(mct_export, nr_of_folds) + self._counter = counter + self.total = self._counter(mct_export) + self.per_fold = self._init_per_fold() + + def _init_per_fold(self) -> List[int]: + per_fold = [self.total // self.nr_of_folds for _ in range(self.nr_of_folds)] + total = sum(per_fold) + if total < self.total: + per_fold[-1] += self.total - total + if any(pf <= 0 for pf in per_fold): + raise ValueError(f"Failed to calculate per-fold items. 
Got: {per_fold}") + return per_fold + + @abstractmethod + def _create_fold(self, fold_nr: int) -> MedCATTrainerExport: + pass + + def create_folds(self) -> List[MedCATTrainerExport]: + return [ + self._create_fold(fold_nr) for fold_nr in range(self.nr_of_folds) + ] + + + +class PerDocsFoldCreator(FoldCreator): + + def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int) -> None: + super().__init__(mct_export, nr_of_folds) + self.nr_of_docs = count_all_docs(self.mct_export) + self.per_doc_simple = self.nr_of_docs // self.nr_of_folds + self._all_docs = list(iter_docs(self.mct_export)) + + def _create_fold(self, fold_nr: int) -> MedCATTrainerExport: + start_nr = self.per_doc_simple * fold_nr + # until the end for last fold, otherwise just the next set of docs + end_nr = self.nr_of_docs if fold_nr == self.nr_of_folds - 1 else start_nr + self.per_doc_simple + relevant_docs = self._all_docs[start_nr: end_nr] + return self._create_export_with_documents(relevant_docs) + + def create_folds(self) -> List[MedCATTrainerExport]: + return [ + self._create_fold(fold_nr) for fold_nr in range(self.nr_of_folds) + ] + + +class PerAnnsFoldCreator(SimpleFoldCreator): + + def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int) -> None: + super().__init__(mct_export, nr_of_folds, count_all_annotations) + + def _add_target_ann(self, project: MedCATTrainerExportProject, + orig_doc: MedCATTrainerExportDocument, + ann: MedCATTrainerExportAnnotation) -> None: + cur_doc: MedCATTrainerExportDocument = self._find_or_add_doc(project, orig_doc) + cur_doc['annotations'].append(ann) + + def _targets(self) -> Iterable[Tuple[MedCATTrainerExportProjectInfo, + MedCATTrainerExportDocument, + MedCATTrainerExportAnnotation]]: + return iter_anns(self.mct_export) + + def _create_fold(self, fold_nr: int) -> MedCATTrainerExport: + per_fold = self.per_fold[fold_nr] + cur_fold: MedCATTrainerExport = { + 'projects': [] + } + cur_project: Optional[MedCATTrainerExportProject] = None + included = 0 + for target in self._targets(): + proj_info, cur_doc, cur_ann = target + proj_name = proj_info[0] + if not cur_project or cur_project['name'] != proj_name: + # first or new project + cur_project = self._create_new_project(proj_info) + cur_fold['projects'].append(cur_project) + self._add_target_ann(cur_project, cur_doc, cur_ann) + included += 1 + if included == per_fold: + break + if included > per_fold: + raise ValueError("Got a larger fold than expected. 
" + f"Expected {per_fold}, got {included}") + return cur_fold + + +class WeightedDocumentsCreator(FoldCreator): + + def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int, + weight_calculator: Callable[[MedCATTrainerExportDocument], int]) -> None: + super().__init__(mct_export, nr_of_folds) + self._weight_calculator = weight_calculator + docs = [(doc, self._weight_calculator(doc[1])) for doc in iter_docs(self.mct_export)] + # descending order in weight + self._weighted_docs = sorted(docs, key=lambda d: d[1], reverse=True) + + def create_folds(self) -> List[MedCATTrainerExport]: + doc_folds: List[List[Tuple[MedCATTrainerExportProjectInfo, MedCATTrainerExportDocument]]] + doc_folds = [[] for _ in range(self.nr_of_folds)] + fold_weights = [0] * self.nr_of_folds + + for item, weight in self._weighted_docs: + # Find the subset with the minimum total weight + min_subset_idx = np.argmin(fold_weights) + # add the most heavily weighted document + doc_folds[min_subset_idx].append(item) + fold_weights[min_subset_idx] += weight + + return [self._create_export_with_documents(docs) for docs in doc_folds] + + +def get_fold_creator(mct_export: MedCATTrainerExport, + nr_of_folds: int, + split_type: SplitType) -> FoldCreator: + """Get the appropriate fold creator. + + Args: + mct_export (MedCATTrainerExport): The MCT export. + nr_of_folds (int): Number of folds to use. + split_type (SplitType): The type of split to use. + + Raises: + ValueError: In case of an unknown split type. + + Returns: + FoldCreator: The corresponding fold creator. + """ + if split_type is SplitType.DOCUMENTS: + return PerDocsFoldCreator(mct_export=mct_export, nr_of_folds=nr_of_folds) + elif split_type is SplitType.ANNOTATIONS: + return PerAnnsFoldCreator(mct_export=mct_export, nr_of_folds=nr_of_folds) + elif split_type is SplitType.DOCUMENTS_WEIGHTED: + return WeightedDocumentsCreator(mct_export=mct_export, nr_of_folds=nr_of_folds, + weight_calculator=get_nr_of_annotations) + else: + raise ValueError(f"Unknown Split Type: {split_type}") + + +def get_per_fold_metrics(cat: CATLike, folds: List[MedCATTrainerExport], + *args, **kwargs) -> List[Tuple]: + metrics = [] + for fold_nr, cur_fold in enumerate(folds): + others = list(folds) + others.pop(fold_nr) + with captured_state_cdb(cat.cdb): + for other in others: + cat.train_supervised_raw(cast(Dict[str, Any], other), *args, **kwargs) + stats = get_stats(cat, cast(Dict[str, Any], cur_fold), do_print=False) + metrics.append(stats) + return metrics + + +def _update_all_weighted_average(joined: List[Dict[str, Tuple[int, float]]], + single: List[Dict[str, float]], cui2count: Dict[str, int]) -> None: + if len(joined) != len(single): + raise ValueError(f"Incompatible lists. 
Joined {len(joined)} and single {len(single)}") + for j, s in zip(joined, single): + _update_one_weighted_average(j, s, cui2count) + + +def _update_one_weighted_average(joined: Dict[str, Tuple[int, float]], + one: Dict[str, float], + cui2count: Dict[str, int]) -> None: + for k in one: + if k not in joined: + joined[k] = (0, 0) + prev_w, prev_val = joined[k] + new_w, new_val = cui2count[k], one[k] + total_w = prev_w + new_w + total_val = (prev_w * prev_val + new_w * new_val) / total_w + joined[k] = (total_w, total_val) + + +def _update_all_add(joined: List[Dict[str, int]], single: List[Dict[str, int]]) -> None: + if len(joined) != len(single): + raise ValueError(f"Incompatible number of stuff: {len(joined)} vs {len(single)}") + for j, s in zip(joined, single): + for k, v in s.items(): + j[k] = j.get(k, 0) + v + + +def _merge_examples(all_examples: Dict, cur_examples: Dict) -> None: + for ex_type, ex_dict in cur_examples.items(): + if ex_type not in all_examples: + all_examples[ex_type] = {} + per_type_examples = all_examples[ex_type] + for ex_cui, cui_examples_list in ex_dict.items(): + if ex_cui not in per_type_examples: + per_type_examples[ex_cui] = [] + per_type_examples[ex_cui].extend(cui_examples_list) + + +def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]] + ) -> Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]: + """The the mean of the provided metrics. + + Args: + metrics (List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]): The metrics. + + Returns: + fps (dict): + False positives for each CUI. + fns (dict): + False negatives for each CUI. + tps (dict): + True positives for each CUI. + cui_prec (dict): + Precision for each CUI. + cui_rec (dict): + Recall for each CUI. + cui_f1 (dict): + F1 for each CUI. + cui_counts (dict): + Number of occurrence for each CUI. + examples (dict): + Examples for each of the fp, fn, tp. Format will be examples['fp']['cui'][]. + """ + # additives + all_fps: Dict[str, int] = {} + all_fns: Dict[str, int] = {} + all_tps: Dict[str, int] = {} + # weighted-averages + all_cui_prec: Dict[str, Tuple[int, float]] = {} + all_cui_rec: Dict[str, Tuple[int, float]] = {} + all_cui_f1: Dict[str, Tuple[int, float]] = {} + # additive + all_cui_counts: Dict[str, int] = {} + # combined + all_additives = [ + all_fps, all_fns, all_tps, all_cui_counts + ] + all_weighted_averages = [ + all_cui_prec, all_cui_rec, all_cui_f1 + ] + # examples + all_examples: dict = {} + for current in metrics: + cur_wa: list = list(current[3:-2]) + cur_counts = current[-2] + _update_all_weighted_average(all_weighted_averages, cur_wa, cur_counts) + # update ones that just need to be added up + cur_adds = list(current[:3]) + [cur_counts] + _update_all_add(all_additives, cur_adds) + # merge examples + cur_examples = current[-1] + _merge_examples(all_examples, cur_examples) + cui_prec: Dict[str, float] = {} + cui_rec: Dict[str, float] = {} + cui_f1: Dict[str, float] = {} + final_wa = [ + cui_prec, cui_rec, cui_f1 + ] + # just remove the weight / count + for df, d in zip(final_wa, all_weighted_averages): + for k, v in d.items(): + df[k] = v[1] # only the value, ingore the weight + return (all_fps, all_fns, all_tps, final_wa[0], final_wa[1], final_wa[2], + all_cui_counts, all_examples) + + +def get_k_fold_stats(cat: CATLike, mct_export_data: MedCATTrainerExport, k: int = 3, + split_type: SplitType = SplitType.DOCUMENTS_WEIGHTED, *args, **kwargs) -> Tuple: + """Get the k-fold stats for the model with the specified data. 
+ + First this will split the MCT export into `k` folds. You can do + this either per document or per-annotation. + + For each of the `k` folds, it will start from the base model, + train it with with the other `k-1` folds and record the metrics. + After that the base model state is restored before doing the next fold. + After all the folds have been done, the metrics are averaged. + + Args: + cat (CATLike): The model pack. + mct_export_data (MedCATTrainerExport): The MCT export. + k (int): The number of folds. Defaults to 3. + split_type (SplitType): Whether to use annodations or docs. Defaults to DOCUMENTS_WEIGHTED. + *args: Arguments passed to the `CAT.train_supervised_raw` method. + **kwargs: Keyword arguments passed to the `CAT.train_supervised_raw` method. + + Returns: + Tuple: The averaged metrics. + """ + creator = get_fold_creator(mct_export_data, k, split_type=split_type) + folds = creator.create_folds() + per_fold_metrics = get_per_fold_metrics(cat, folds, *args, **kwargs) + return get_metrics_mean(per_fold_metrics) diff --git a/medcat/stats/mctexport.py b/medcat/stats/mctexport.py new file mode 100644 index 000000000..54f5a4443 --- /dev/null +++ b/medcat/stats/mctexport.py @@ -0,0 +1,66 @@ +from typing import List, Iterator, Tuple, Any, Optional +from typing_extensions import TypedDict + + +class MedCATTrainerExportAnnotation(TypedDict): + start: int + end: int + cui: str + value: str + + +class MedCATTrainerExportDocument(TypedDict): + name: str + id: Any + last_modified: str + text: str + annotations: List[MedCATTrainerExportAnnotation] + + +class MedCATTrainerExportProject(TypedDict): + name: str + id: Any + cuis: str + tuis: Optional[str] + documents: List[MedCATTrainerExportDocument] + + +MedCATTrainerExportProjectInfo = Tuple[str, Any, str, Optional[str]] +"""The project name, project ID, CUIs str, and TUIs str""" + + +class MedCATTrainerExport(TypedDict): + projects: List[MedCATTrainerExportProject] + + +def iter_projects(export: MedCATTrainerExport) -> Iterator[MedCATTrainerExportProject]: + yield from export['projects'] + + +def iter_docs(export: MedCATTrainerExport + ) -> Iterator[Tuple[MedCATTrainerExportProjectInfo, MedCATTrainerExportDocument]]: + for project in iter_projects(export): + info: MedCATTrainerExportProjectInfo = ( + project['name'], project['id'], project['cuis'], project.get('tuis', None) + ) + for doc in project['documents']: + yield info, doc + + +def iter_anns(export: MedCATTrainerExport + ) -> Iterator[Tuple[MedCATTrainerExportProjectInfo, MedCATTrainerExportDocument, MedCATTrainerExportAnnotation]]: + for proj_info, doc in iter_docs(export): + for ann in doc['annotations']: + yield proj_info, doc, ann + + +def count_all_annotations(export: MedCATTrainerExport) -> int: + return len(list(iter_anns(export))) + + +def count_all_docs(export: MedCATTrainerExport) -> int: + return len(list(iter_docs(export))) + + +def get_nr_of_annotations(doc: MedCATTrainerExportDocument) -> int: + return len(doc['annotations']) diff --git a/medcat/stats/stats.py b/medcat/stats/stats.py index 610d4d2a1..e467e0519 100644 --- a/medcat/stats/stats.py +++ b/medcat/stats/stats.py @@ -60,6 +60,9 @@ def process_project(self, project: dict) -> None: # Add extra filter if set set_project_filters(self.addl_info, self.filters, project, self.extra_cui_filter, self.use_project_filters) + project_name = cast(str, project.get('name')) + project_id = cast(str, project.get('id')) + documents = project["documents"] for dind, doc in tqdm( enumerate(documents), @@ -67,8 +70,7 @@ def 
process_project(self, project: dict) -> None: total=len(documents), leave=False, ): - self.process_document(cast(str, project.get('name')), - cast(str, project.get('id')), doc) + self.process_document(project_name, project_id, doc) def process_document(self, project_name: str, project_id: str, doc: dict) -> None: anns = self._get_doc_annotations(doc) diff --git a/medcat/utils/cdb_state.py b/medcat/utils/cdb_state.py new file mode 100644 index 000000000..794a40109 --- /dev/null +++ b/medcat/utils/cdb_state.py @@ -0,0 +1,179 @@ +import logging +import contextlib +from typing import Dict, TypedDict, Set, List, cast +import numpy as np +import tempfile +import dill + +from copy import deepcopy + + + +logger = logging.getLogger(__name__) # separate logger from the package-level one + + +CDBState = TypedDict( + 'CDBState', + { + 'name2cuis': Dict[str, List[str]], + 'snames': Set[str], + 'cui2names': Dict[str, Set[str]], + 'cui2snames': Dict[str, Set[str]], + 'cui2context_vectors': Dict[str, Dict[str, np.ndarray]], + 'cui2count_train': Dict[str, int], + 'name_isupper': Dict, + 'vocab': Dict[str, int], + }) +"""CDB State. + +This is a dictionary of the parts of the CDB that change during +(supervised) training. It can be used to store and restore the +state of a CDB after modifying it. + +Currently, the following fields are saved: + - name2cuis + - snames + - cui2names + - cui2snames + - cui2context_vectors + - cui2count_train + - name_isupper + - vocab +""" + + +def copy_cdb_state(cdb) -> CDBState: + """Creates a (deep) copy of the CDB state. + + Grabs the fields that correspond to the state, + creates deep copies, and returns the copies. + + Args: + cdb: The CDB from which to grab the state. + + Returns: + CDBState: The copied state. + """ + return cast(CDBState, { + k: deepcopy(getattr(cdb, k)) for k in CDBState.__annotations__ + }) + + +def save_cdb_state(cdb, file_path: str) -> None: + """Saves CDB state in a file. + + Currently uses `dill.dump` to save the relevant fields/values. + + Args: + cdb: The CDB from which to grab the state. + file_path (str): The file to dump the state. + """ + # NOTE: The difference is that we don't create a copy here. + # That is so that we don't have to occupy the memory for + # both copies + the_dict = { + k: getattr(cdb, k) for k in CDBState.__annotations__ + } + logger.debug("Saving CDB state on disk at: '%s'", file_path) + with open(file_path, 'wb') as f: + dill.dump(the_dict, f) + + +def apply_cdb_state(cdb, state: CDBState) -> None: + """Apply the specified state to the specified CDB. + + This overwrites the current state of the CDB with one provided. + + Args: + cdb: The CDB to apply the state to. + state (CDBState): The state to use. + """ + for k, v in state.items(): + setattr(cdb, k, v) + + +def load_and_apply_cdb_state(cdb, file_path: str) -> None: + """Delete current CDB state and apply CDB state from file. + + This first delets the current state of the CDB. + This is to save memory. The idea is that saving the staet + on disk will save on RAM usage. But it wouldn't really + work too well if upon load, two instances were still in + memory. + + Args: + cdb: The CDB to apply the state to. + file_path (str): The file where the state has been saved to. 
+ """ + # clear existing data on CDB + # this is so that we don't occupy the memory for both the loaded + # and the on-CDB data + logger.debug("Clearing CDB state in memory") + for k in CDBState.__annotations__: + val = getattr(cdb, k) + setattr(cdb, k, None) + del val + logger.debug("Loading CDB state from disk from '%s'", file_path) + with open(file_path, 'rb') as f: + data = dill.load(f) + for k in CDBState.__annotations__: + setattr(cdb, k, data[k]) + + +@contextlib.contextmanager +def captured_state_cdb(cdb, save_state_to_disk: bool = False): + """A context manager that captures and re-applies the initial CDB state. + + The context manager captures/copies the initial state of the CDB when entering. + It then allows the user to modify the state (i.e training). + Upon exit re-applies the initial CDB state. + + If RAM is an issue, it is recommended to use `save_state_to_disk`. + Otherwise the copy of the original state will be held in memory. + If saved on disk, a temporary file is used and removed afterwards. + + Args: + cdb: The CDB to use. + save_state_to_disk (bool): Whether to save state on disk or hold in in memory. + Defaults to False. + + Yields: + None + """ + if save_state_to_disk: + with on_disk_memory_capture(cdb): + yield + else: + with in_memory_state_capture(cdb): + yield + + +@contextlib.contextmanager +def in_memory_state_capture(cdb): + """Capture the CDB state in memory. + + Args: + cdb: The CDB to use. + + Yields: + None + """ + state = copy_cdb_state(cdb) + yield + apply_cdb_state(cdb, state) + + +@contextlib.contextmanager +def on_disk_memory_capture(cdb): + """Capture the CDB state in a temporary file. + + Args: + cdb: The CDB to use + + Yields: + None + """ + with tempfile.NamedTemporaryFile() as tf: + save_cdb_state(cdb, tf.name) + yield + load_and_apply_cdb_state(cdb, tf.name) diff --git a/medcat/utils/cdb_utils.py b/medcat/utils/cdb_utils.py index c473ddba4..fefaf1273 100644 --- a/medcat/utils/cdb_utils.py +++ b/medcat/utils/cdb_utils.py @@ -63,7 +63,7 @@ def merge_cdb(cdb1: CDB, ontologies.update(cdb2.addl_info['cui2ontologies'][cui]) if 'cui2description' in cdb2.addl_info: description = cdb2.addl_info['cui2description'][cui] - cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, + cdb._add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, type_ids=cdb2.cui2type_ids[cui], description=description, full_build=to_build) if cui in cdb1.cui2names: if (cui in cdb1.cui2count_train or cui in cdb2.cui2count_train) and not (overwrite_training == 1 and cui in cdb1.cui2count_train): diff --git a/tests/archive_tests/test_cdb_maker_archive.py b/tests/archive_tests/test_cdb_maker_archive.py deleted file mode 100644 index 9e2fc2d72..000000000 --- a/tests/archive_tests/test_cdb_maker_archive.py +++ /dev/null @@ -1,124 +0,0 @@ -import logging -import unittest -import numpy as np -from medcat.cdb import CDB -from medcat.cdb_maker import CDBMaker -from medcat.config import Config -from medcat.preprocessing.cleaners import prepare_name - - -class CdbMakerArchiveTests(unittest.TestCase): - - def setUp(self): - self.config = Config() - self.config.general['log_level'] = logging.DEBUG - self.maker = CDBMaker(self.config) - - # Building a new CDB from two files (full_build) - csvs = ['../examples/cdb.csv', '../examples/cdb_2.csv'] - self.cdb = self.maker.prepare_csvs(csvs, full_build=True) - - def test_prepare_csvs(self): - assert len(self.cdb.cui2names) == 3 - assert len(self.cdb.cui2snames) == 3 - assert 
len(self.cdb.name2cuis) == 5 - assert len(self.cdb.cui2tags) == 3 - assert len(self.cdb.cui2preferred_name) == 2 - assert len(self.cdb.cui2context_vectors) == 3 - assert len(self.cdb.cui2count_train) == 3 - assert self.cdb.name2cuis2status['virus']['C0000039'] == 'P' - assert self.cdb.cui2type_ids['C0000039'] == {'T234', 'T109', 'T123'} - assert self.cdb.addl_info['cui2original_names']['C0000039'] == {'Virus', 'Virus K', 'Virus M', 'Virus Z'} - assert self.cdb.addl_info['cui2description']['C0000039'].startswith("Synthetic") - - def test_name_addition(self): - self.cdb.add_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.pipe.get_spacy_nlp(), {}, self.config), name_status='P', full_build=True) - assert self.cdb.addl_info['cui2original_names']['C0000239'] == {'MY: new,-_! Name.', 'Second csv'} - assert 'my:newname.' in self.cdb.name2cuis - assert 'my:new' in self.cdb.snames - assert 'my:newname.' in self.cdb.name2cuis2status - assert self.cdb.name2cuis2status['my:newname.'] == {'C0000239': 'P'} - - def test_name_removal(self): - self.cdb.remove_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.pipe.get_spacy_nlp(), {}, self.config)) - # Run again to make sure it does not break anything - self.cdb.remove_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.pipe.get_spacy_nlp(), {}, self.config)) - assert len(self.cdb.name2cuis) == 5 - assert 'my:newname.' not in self.cdb.name2cuis2status - - def test_filtering(self): - cuis_to_keep = {'C0000039'} # Because of transition 2 will be kept - self.cdb.filter_by_cui(cuis_to_keep=cuis_to_keep) - assert len(self.cdb.cui2names) == 2 - assert len(self.cdb.name2cuis) == 4 - assert len(self.cdb.snames) == 4 - - def test_vector_addition(self): - self.cdb.reset_training() - np.random.seed(11) - cuis = list(self.cdb.cui2names.keys()) - for i in range(2): - for cui in cuis: - vectors = {} - for cntx_type in self.config.linking['context_vector_sizes']: - vectors[cntx_type] = np.random.rand(300) - self.cdb.update_context_vector(cui, vectors, negative=False) - - assert self.cdb.cui2count_train['C0000139'] == 2 - assert self.cdb.cui2context_vectors['C0000139']['long'].shape[0] == 300 - - - def test_negative(self): - cuis = list(self.cdb.cui2names.keys()) - for cui in cuis: - vectors = {} - for cntx_type in self.config.linking['context_vector_sizes']: - vectors[cntx_type] = np.random.rand(300) - self.cdb.update_context_vector(cui, vectors, negative=True) - - assert self.cdb.cui2count_train['C0000139'] == 2 - assert self.cdb.cui2context_vectors['C0000139']['long'].shape[0] == 300 - - def test_save_and_load(self): - self.cdb.save("./tmp_cdb.dat") - cdb2 = CDB.load('./tmp_cdb.dat') - # Check a random thing - assert cdb2.cui2context_vectors['C0000139']['long'][7] == self.cdb.cui2context_vectors['C0000139']['long'][7] - - def test_training_import(self): - cdb2 = CDB.load('./tmp_cdb.dat') - self.cdb.reset_training() - cdb2.reset_training() - np.random.seed(11) - cuis = list(self.cdb.cui2names.keys()) - for i in range(2): - for cui in cuis: - vectors = {} - for cntx_type in self.config.linking['context_vector_sizes']: - vectors[cntx_type] = np.random.rand(300) - self.cdb.update_context_vector(cui, vectors, negative=False) - - cdb2.import_training(cdb=self.cdb, overwrite=True) - assert cdb2.cui2context_vectors['C0000139']['long'][7] == self.cdb.cui2context_vectors['C0000139']['long'][7] - assert cdb2.cui2count_train['C0000139'] == self.cdb.cui2count_train['C0000139'] - - def test_concept_similarity(self): 
- cdb = CDB(config=self.config) - np.random.seed(11) - for i in range(500): - cui = "C" + str(i) - type_ids = {'T-' + str(i%10)} - cdb._add_concept(cui=cui, names=prepare_name('Name: ' + str(i), self.maker.pipe.get_spacy_nlp(), {}, self.config), ontologies=set(), - name_status='P', type_ids=type_ids, description='', full_build=True) - - vectors = {} - for cntx_type in self.config.linking['context_vector_sizes']: - vectors[cntx_type] = np.random.rand(300) - cdb.update_context_vector(cui, vectors, negative=False) - res = cdb.most_similar('C200', 'long', type_id_filter=['T-0'], min_cnt=1, topn=10, force_build=True) - assert len(res) == 10 - - def test_training_reset(self): - self.cdb.reset_training() - assert len(self.cdb.cui2context_vectors['C0']) == 0 - assert self.cdb.cui2count_train['C0'] == 0 diff --git a/tests/archive_tests/test_ner_archive.py b/tests/archive_tests/test_ner_archive.py deleted file mode 100644 index d41ccd0c7..000000000 --- a/tests/archive_tests/test_ner_archive.py +++ /dev/null @@ -1,139 +0,0 @@ -import logging -import unittest -import numpy as np -from timeit import default_timer as timer -from medcat.cdb import CDB -from medcat.preprocessing.tokenizers import spacy_split_all -from medcat.ner.vocab_based_ner import NER -from medcat.preprocessing.taggers import tag_skip_and_punct -from medcat.pipe import Pipe -from medcat.utils.normalizers import BasicSpellChecker -from medcat.vocab import Vocab -from medcat.preprocessing.cleaners import prepare_name -from medcat.linking.vector_context_model import ContextModel -from medcat.linking.context_based_linker import Linker -from medcat.config import Config - -from ..helper import VocabDownloader - - -class NerArchiveTests(unittest.TestCase): - - def setUp(self) -> None: - self.config = Config() - self.config.general['log_level'] = logging.INFO - cdb = CDB(config=self.config) - - self.nlp = Pipe(tokenizer=spacy_split_all, config=self.config) - self.nlp.add_tagger(tagger=tag_skip_and_punct, - name='skip_and_punct', - additional_fields=['is_punct']) - - # Add a couple of names - cdb.add_names(cui='S-229004', names=prepare_name('Movar', self.nlp, {}, self.config)) - cdb.add_names(cui='S-229004', names=prepare_name('Movar viruses', self.nlp, {}, self.config)) - cdb.add_names(cui='S-229005', names=prepare_name('CDB', self.nlp, {}, self.config)) - # Check - #assert cdb.cui2names == {'S-229004': {'movar', 'movarvirus', 'movarviruses'}, 'S-229005': {'cdb'}} - - downloader = VocabDownloader() - self.vocab_path = downloader.vocab_path - downloader.check_or_download() - - vocab = Vocab.load(self.vocab_path) - # Make the pipeline - self.nlp = Pipe(tokenizer=spacy_split_all, config=self.config) - self.nlp.add_tagger(tagger=tag_skip_and_punct, - name='skip_and_punct', - additional_fields=['is_punct']) - spell_checker = BasicSpellChecker(cdb_vocab=cdb.vocab, config=self.config, data_vocab=vocab) - self.nlp.add_token_normalizer(spell_checker=spell_checker, config=self.config) - ner = NER(cdb, self.config) - self.nlp.add_ner(ner) - - # Add Linker - link = Linker(cdb, vocab, self.config) - self.nlp.add_linker(link) - - self.text = "CDB - I was running and then Movar Virus attacked and CDb" - - def tearDown(self) -> None: - self.nlp.destroy() - - def test_limits_for_tokens_and_uppercase(self): - self.config.ner['max_skip_tokens'] = 1 - self.config.ner['upper_case_limit_len'] = 4 - self.config.linking['disamb_length_limit'] = 2 - - d = self.nlp(self.text) - - assert len(d._.ents) == 2 - assert d._.ents[0]._.link_candidates[0] == 'S-229004' - - def 
test_change_limit_for_skip(self): - self.config.ner['max_skip_tokens'] = 3 - d = self.nlp(self.text) - assert len(d._.ents) == 3 - - def test_change_limit_for_upper_case(self): - self.config.ner['upper_case_limit_len'] = 3 - d = self.nlp(self.text) - assert len(d._.ents) == 4 - - def test_check_name_length_limit(self): - self.config.ner['min_name_len'] = 4 - d = self.nlp(self.text) - assert len(d._.ents) == 2 - - def test_speed(self): - text = "CDB - I was running and then Movar Virus attacked and CDb" - text = text * 300 - self.config.general['spell_check'] = True - start = timer() - for i in range(50): - d = self.nlp(text) - end = timer() - print("Time: ", end - start) - - def test_without_spell_check(self): - # Now without spell check - self.config.general['spell_check'] = False - start = timer() - for i in range(50): - d = self.nlp(self.text) - end = timer() - print("Time: ", end - start) - - - def test_for_linker(self): - self.config = Config() - self.config.general['log_level'] = logging.DEBUG - cdb = CDB(config=self.config) - - # Add a couple of names - cdb.add_names(cui='S-229004', names=prepare_name('Movar', self.nlp, {}, self.config)) - cdb.add_names(cui='S-229004', names=prepare_name('Movar viruses', self.nlp, {}, self.config)) - cdb.add_names(cui='S-229005', names=prepare_name('CDB', self.nlp, {}, self.config)) - cdb.add_names(cui='S-2290045', names=prepare_name('Movar', self.nlp, {}, self.config)) - # Check - #assert cdb.cui2names == {'S-229004': {'movar', 'movarvirus', 'movarviruses'}, 'S-229005': {'cdb'}, 'S-2290045': {'movar'}} - - cuis = list(cdb.cui2names.keys()) - for cui in cuis[0:50]: - vectors = {'short': np.random.rand(300), - 'long': np.random.rand(300), - 'medium': np.random.rand(300) - } - cdb.update_context_vector(cui, vectors, negative=False) - - d = self.nlp(self.text) - vocab = Vocab.load(self.vocab_path) - cm = ContextModel(cdb, vocab, self.config) - cm.train_using_negative_sampling('S-229004') - self.config.linking['train_count_threshold'] = 0 - - cm.train('S-229004', d._.ents[1], d) - - cm.similarity('S-229004', d._.ents[1], d) - - cm.disambiguate(['S-2290045', 'S-229004'], d._.ents[1], 'movar', d) diff --git a/tests/medmentions/make_cdb.py b/tests/medmentions/make_cdb.py deleted file mode 100644 index feb8629d2..000000000 --- a/tests/medmentions/make_cdb.py +++ /dev/null @@ -1,120 +0,0 @@ -from medcat.cdb_maker import CDBMaker -from medcat.config import Config, weighted_average -from functools import partial -import numpy as np -import logging - -from ..helper import VocabDownloader - - -config = Config() -config.general['log_level'] = logging.INFO -config.general['spacy_model'] = 'en_core_sci_lg' -maker = CDBMaker(config) - -# Building a new CDB from two files (full_build) -csvs = ['./tmp_medmentions.csv'] -cdb = maker.prepare_csvs(csvs, full_build=True) - -cdb.save("./tmp_cdb.dat") - - -from medcat.vocab import Vocab -from medcat.cdb import CDB -from medcat.cat import CAT - -downloader = VocabDownloader() -vocab_path = downloader.vocab_path -downloader.check_or_download() - -config = Config() -cdb = CDB.load("./tmp_cdb.dat", config=config) -vocab = Vocab.load(vocab_path) - -cdb.reset_training() - -cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab) -cat.config.ner['min_name_len'] = 3 -cat.config.ner['upper_case_limit_len'] = 3 -cat.config.linking['disamb_length_limit'] = 3 -cat.config.linking['filters'] = {'cuis': set()} -cat.config.linking['train_count_threshold'] = -1 -cat.config.linking['context_vector_sizes'] = {'xlong': 27, 'long': 18, 'medium': 9, 
'short': 3} -cat.config.linking['context_vector_weights'] = {'xlong': 0, 'long': 0.4, 'medium': 0.4, 'short': 0.2} -cat.config.linking['weighted_average_function'] = partial(weighted_average, factor=0.0004) -cat.config.linking['similarity_threshold_type'] = 'dynamic' -cat.config.linking['similarity_threshold'] = 0.35 -cat.config.linking['calculate_dynamic_threshold'] = True - -cat.train(df.text.values, fine_tune=True) - - -cdb.config.general['spacy_disabled_components'] = ['ner', 'parser', 'vectors', 'textcat', - 'entity_linker', 'sentencizer', 'entity_ruler', 'merge_noun_chunks', - 'merge_entities', 'merge_subtokens'] - -%load_ext autoreload -%autoreload 2 - -# Train -_ = cat.train(open("./tmp_medmentions_text_only.txt", 'r'), fine_tune=False) - -_ = cat.train_supervised("/home/ubuntu/data/medmentions/medmentions.json", reset_cui_count=True, nepochs=13, train_from_false_positives=True, print_stats=3, test_size=0.1) -cdb.save("/home/ubuntu/data/umls/2020ab/cdb_trained_medmen.dat") - - -_ = cat.train_supervised("/home/ubuntu/data/medmentions/medmentions.json", reset_cui_count=False, nepochs=13, train_from_false_positives=True, print_stats=3, test_size=0) - -cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab) -cat.config.linking['similarity_threshold'] = 0.1 -cat.config.ner['min_name_len'] = 2 -cat.config.ner['upper_case_limit_len'] = 1 -cat.config.linking['train_count_threshold'] = -2 -cat.config.linking['filters']['cuis'] = set() -cat.config.linking['context_vector_sizes'] = {'xlong': 27, 'long': 18, 'medium': 9, 'short': 3} -cat.config.linking['context_vector_weights'] = {'xlong': 0.1, 'long': 0.4, 'medium': 0.4, 'short': 0.1} -cat.config.linking['similarity_threshold_type'] = 'static' - -cat.config.linking['similarity_threshold_type'] = 'dynamic' -cat.config.linking['similarity_threshold'] = 0.35 -cat.config.linking['calculate_dynamic_threshold'] = True - - -# Print some stats -_ = cat._print_stats(data) - -#Epoch: 0, Prec: 0.4331506351144245, Rec: 0.5207520064957372, F1: 0.47292889758643175 -#p: 0.421 r: 0.507 f1: 0.460 - - -# Remove all names that are numbers -for name in list(cdb.name2cuis.keys()): - if name.replace(".", '').replace("~", '').replace(",", '').replace(":", '').replace("-", '').isnumeric(): - del cdb.name2cuis[name] - print(name) - - -for name in list(cdb.name2cuis.keys()): - if len(name) < 7 and (not name.isalpha()) and len(re.sub("[^A-Za-z]*", '', name)) < 2: - del cdb.name2cuis[name] - print(name) - - - - -# RUN SUPER -cdb = CDB.load("./tmp_cdb.dat") -vocab = Vocab.load(vocab_path) -cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab) - - -# Train supervised -cdb.reset_cui_count() -cat.config.ner['uppe_case_limit_len'] = 1 -cat.config.ner['min_name_len'] = 1 -data_path = "./tmp_medmentions.json" -_ = cat.train_supervised(data_path, use_cui_doc_limit=True, nepochs=30, devalue_others=True, test_size=0.2) - - -cdb = maker.prepare_csvs(csv_paths=csvs) -cdb.save("/home/ubuntu/data/umls/2020ab/cdb_vbg.dat") diff --git a/tests/medmentions/prepare_data.py b/tests/medmentions/prepare_data.py deleted file mode 100644 index 6e1bfdf2e..000000000 --- a/tests/medmentions/prepare_data.py +++ /dev/null @@ -1,7 +0,0 @@ -from medcat.utils.medmentions import original2concept_csv -from medcat.utils.medmentions import original2json -from medcat.utils.medmentions import original2pure_text - -_ = original2json("../../examples/medmentions/medmentions.txt", '../../examples/medmentions/tmp_medmentions.json') -_ = original2concept_csv("../../examples/medmentions/medmentions.txt", 
'../../examples/medmentions/tmp_medmentions.csv') -original2pure_text("../../examples/medmentions/medmentions.txt", '../../examples/medmentions/tmp_medmentions_text_only.txt') diff --git a/tests/resources/medcat_trainer_export_FAKE_CONCEPTS.json b/tests/resources/medcat_trainer_export_FAKE_CONCEPTS.json new file mode 100644 index 000000000..79f1a0ac4 --- /dev/null +++ b/tests/resources/medcat_trainer_export_FAKE_CONCEPTS.json @@ -0,0 +1,84 @@ +{ + "projects": [ + { + "cuis": "", + "tuis": "", + "name": "TEST-PROJ", + "id": "PROJ_FAKE", + "documents": [ + { + "name": "fake_doc_0", + "id": 100, + "last_modified": "-1", + "text": "This virus is called virus M and was read from the second CSV we could find.", + "annotations": [ + { + "cui": "C0000039", + "start": 5, + "end": 10, + "value": "virus" + }, + { + "cui": "C0000139", + "start": 21, + "end": 28, + "value": "virus M" + }, + { + "cui": "C0000239", + "start": 51, + "end": 62, + "value": "second CSV" + } + ] + }, + { + "name": "fake_doc_1", + "id": 101, + "last_modified": "-1", + "text": "We found a virus. Turned out it was virus M. This was the second CSV we looked at.", + "annotations": [ + { + "cui": "C0000039", + "start": 11, + "end": 16, + "value": "virus" + }, + { + "cui": "C0000139", + "start": 36, + "end": 43, + "value": "virus M" + }, + { + "cui": "C0000239", + "start": 58, + "end": 69, + "value": "second CSV" + } + ] + }, + { + "name": "fake_doc_2", + "id": 102, + "last_modified": "-1", + "text": "We opened second CSV and found virus M to be the culprit.", + "annotations": [ + { + "cui": "C0000239", + "start": 10, + "end": 21, + "value": "second CSV" + }, + { + "cui": "C0000139", + "start": 31, + "end": 38, + "value": "virus M" + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/tests/stats/__init__.py b/tests/stats/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/stats/helpers.py b/tests/stats/helpers.py new file mode 100644 index 000000000..80771b11c --- /dev/null +++ b/tests/stats/helpers.py @@ -0,0 +1,17 @@ +from pydantic import create_model_from_typeddict + +from medcat.stats.mctexport import MedCATTrainerExport + + +MCTExportPydanticModel = create_model_from_typeddict(MedCATTrainerExport) + + +def nullify_doc_names_proj_ids(export: MedCATTrainerExport) -> MedCATTrainerExport: + return {'projects': [ + { + 'name': project['name'], + 'documents': sorted([ + {k: v if k != 'name' else '' for k, v in doc.items()} for doc in project['documents'] + ], key=lambda doc: doc['id']) + } for project in export['projects'] + ]} diff --git a/tests/stats/test_kfold.py b/tests/stats/test_kfold.py new file mode 100644 index 000000000..87dcdd454 --- /dev/null +++ b/tests/stats/test_kfold.py @@ -0,0 +1,298 @@ +import os +import json +from typing import Dict, Union, Optional +from copy import deepcopy + +from medcat.stats import kfold +from medcat.cat import CAT +from pydantic.error_wrappers import ValidationError as PydanticValidationError + +import unittest + +from .helpers import MCTExportPydanticModel, nullify_doc_names_proj_ids + + +class MCTExportTests(unittest.TestCase): + EXPORT_PATH = os.path.join(os.path.dirname(__file__), "..", + "resources", "medcat_trainer_export.json") + + @classmethod + def setUpClass(cls) -> None: + with open(cls.EXPORT_PATH) as f: + cls.mct_export = json.load(f) + + def assertIsMCTExport(self, obj): + try: + model = MCTExportPydanticModel(**obj) + except PydanticValidationError as e: + raise AssertionError("Not n MCT export") from e + self.assertIsInstance(model, 
MCTExportPydanticModel) + + +class KFoldCreatorTests(MCTExportTests): + K = 3 + SPLIT_TYPE = kfold.SplitType.DOCUMENTS + + + def setUp(self) -> None: + self.creator = kfold.get_fold_creator(self.mct_export, self.K, split_type=self.SPLIT_TYPE) + self.folds = self.creator.create_folds() + + def test_folding_does_not_modify_initial_export(self): + with open(self.EXPORT_PATH) as f: + export_copy = json.load(f) + self.assertEqual(export_copy, self.mct_export) + + def test_mct_export_has_correct_format(self): + self.assertIsMCTExport(self.mct_export) + + def test_folds_have_docs(self): + for nr, fold in enumerate(self.folds): + with self.subTest(f"Fold-{nr}"): + self.assertGreater(kfold.count_all_docs(fold), 0) + + def test_folds_have_anns(self): + for nr, fold in enumerate(self.folds): + with self.subTest(f"Fold-{nr}"): + self.assertGreater(kfold.count_all_annotations(fold), 0) + + def test_folds_are_mct_exports(self): + for nr, fold in enumerate(self.folds): + with self.subTest(f"Fold-{nr}"): + self.assertIsMCTExport(fold) + + def test_gets_correct_number_of_folds(self): + self.assertEqual(len(self.folds), self.K) + + def test_folds_keep_all_docs(self): + total_docs = 0 + for fold in self.folds: + docs = kfold.count_all_docs(fold) + total_docs += docs + count_all_once = kfold.count_all_docs(self.mct_export) + if self.SPLIT_TYPE is kfold.SplitType.ANNOTATIONS: + # NOTE: This may be greater if split in the middle of a document + # because that document may then exist in both folds + self.assertGreaterEqual(total_docs, count_all_once) + else: + self.assertEqual(total_docs, count_all_once) + + def test_folds_keep_all_anns(self): + total_anns = 0 + for fold in self.folds: + anns = kfold.count_all_annotations(fold) + total_anns += anns + count_all_once = kfold.count_all_annotations(self.mct_export) + self.assertEqual(total_anns, count_all_once) + + def test_1fold_same_as_orig(self): + folds = kfold.get_fold_creator(self.mct_export, 1, split_type=self.SPLIT_TYPE).create_folds() + self.assertEqual(len(folds), 1) + fold, = folds + self.assertIsInstance(fold, dict) + self.assertIsMCTExport(fold) + self.assertEqual( + nullify_doc_names_proj_ids(self.mct_export), + nullify_doc_names_proj_ids(fold), + ) + + def test_has_reasonable_annotations_per_folds(self): + anns_per_folds = [kfold.count_all_annotations(fold) for fold in self.folds] + print(f"ANNS per folds:\n{anns_per_folds}") + docs_per_folds = [kfold.count_all_docs(fold) for fold in self.folds] + print(f"DOCS per folds:\n{docs_per_folds}") + + +# this is a taylor-made export that +# just contains a few "documents" +# with the fake CUIs "annotated" +NEW_EXPORT_PATH = os.path.join(os.path.dirname(__file__), "..", + "resources", "medcat_trainer_export_FAKE_CONCEPTS.json") + + +class KFoldCreatorPerAnnsTests(KFoldCreatorTests): + SPLIT_TYPE = kfold.SplitType.ANNOTATIONS + + +class KFoldCreatorPerWeightedDocsTests(KFoldCreatorTests): + SPLIT_TYPE = kfold.SplitType.DOCUMENTS_WEIGHTED + # should have a total of 435, so 145 per in ideal world + # but we'll allow the following deviation + PERMITTED_MAX_DEVIATION_IN_ANNS = 5 + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.total_anns = kfold.count_all_annotations(cls.mct_export) + cls.expected_anns_per_fold = cls.total_anns // cls.K + cls.expected_lower_bound = cls.expected_anns_per_fold - cls.PERMITTED_MAX_DEVIATION_IN_ANNS + cls.expected_upper_bound = cls.expected_anns_per_fold + cls.PERMITTED_MAX_DEVIATION_IN_ANNS + + def test_has_reasonable_annotations_per_folds(self): + 
anns_per_folds = [kfold.count_all_annotations(fold) for fold in self.folds] + for nr, anns in enumerate(anns_per_folds): + with self.subTest(f"Fold-{nr}"): + self.assertGreater(anns, self.expected_lower_bound) + self.assertLess(anns, self.expected_upper_bound) + # NOTE: as of testing, this will split [146, 145, 144] + # whereas regular per-docs split will have [140, 163, 132] + + +class KFoldCreatorNewExportTests(KFoldCreatorTests): + EXPORT_PATH = NEW_EXPORT_PATH + + +class KFoldCreatorNewExportAnnsTests(KFoldCreatorNewExportTests): + SPLIT_TYPE = kfold.SplitType.ANNOTATIONS + + +class KFoldCreatorNewExportWeightedDocsTests(KFoldCreatorNewExportTests): + SPLIT_TYPE = kfold.SplitType.DOCUMENTS_WEIGHTED + + +class KFoldCATTests(MCTExportTests): + _names = ['fps', 'fns', 'tps', 'prec', 'rec', 'f1', 'counts', 'examples'] + EXPORT_PATH = NEW_EXPORT_PATH + CAT_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "examples") + TOLERANCE_PLACES = 10 # tolerance of 10 digits + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.cat = CAT.load_model_pack(cls.CAT_PATH) + + def setUp(self) -> None: + super().setUp() + self.reg_stats = self.cat._print_stats(self.mct_export, do_print=False) + # TODO - remove + self.maxDiff = 4000 + + # NOTE: Due to floating point errors, sometimes we may get slightly different results + def assertDictsAlmostEqual(self, d1: Dict[str, Union[int, float]], d2: Dict[str, Union[int, float]], + tolerance_places: Optional[int] = None) -> None: + self.assertEqual(d1.keys(), d2.keys()) + tol = tolerance_places if tolerance_places is not None else self.TOLERANCE_PLACES + for k in d1: + v1, v2 = d1[k], d2[k] + self.assertAlmostEqual(v1, v2, places=tol) + + +class KFoldStatsConsistencyTests(KFoldCATTests): + + def test_mct_export_valid(self): + self.assertIsMCTExport(self.mct_export) + + def test_stats_consistent(self): + stats = self.cat._print_stats(self.mct_export, do_print=False) + for name, stats1, stats2 in zip(self._names, self.reg_stats, stats): + with self.subTest(name): + # NOTE: These should be EXACTLY equal since there shouldn't be + # any different additions and the like + self.assertEqual(stats1, stats2) + + +class KFoldMetricsTests(KFoldCATTests): + SPLIT_TYPE = kfold.SplitType.DOCUMENTS + + def test_metrics_1_fold_same_as_normal(self): + stats = kfold.get_k_fold_stats(self.cat, self.mct_export, k=1, + split_type=self.SPLIT_TYPE) + for name, reg, folds1 in zip(self._names, self.reg_stats, stats): + with self.subTest(name): + if name != 'examples': + # NOTE: These may not be exactly equal due to floating point errors + self.assertDictsAlmostEqual(reg, folds1) + else: + self.assertEqual(reg, folds1) + + +class KFoldPerAnnsMetricsTests(KFoldMetricsTests): + SPLIT_TYPE = kfold.SplitType.ANNOTATIONS + + +class KFoldWeightedDocsMetricsTests(KFoldMetricsTests): + SPLIT_TYPE = kfold.SplitType.DOCUMENTS_WEIGHTED + + +class KFoldDuplicatedTests(KFoldCATTests): + COPIES = 3 + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.docs_in_orig = kfold.count_all_docs(cls.mct_export) + cls.anns_in_orig = kfold.count_all_annotations(cls.mct_export) + cls.data_copied: kfold.MedCATTrainerExport = deepcopy(cls.mct_export) + for project in cls.data_copied['projects']: + documents_list = project['documents'] + copies = documents_list + [ + {k: v if k != 'name' else f"{v}_cp_{nr}" for k, v in doc.items()} for nr in range(cls.COPIES - 1) + for doc in documents_list + ] + project['documents'] = copies + cls.docs_in_copy = 
kfold.count_all_docs(cls.data_copied) + cls.anns_in_copy = kfold.count_all_annotations(cls.data_copied) + cls.stats_copied = kfold.get_k_fold_stats(cls.cat, cls.data_copied, k=cls.COPIES) + cls.stats_copied_2 = kfold.get_k_fold_stats(cls.cat, cls.data_copied, k=cls.COPIES) + + # some stats with real model/data will be e.g 0.99 vs 0.9747 + # so in that case, lower it to 1 or so + _stats_consistency_tolerance = 8 + + def test_stats_consistent(self): + for name, one, two in zip(self._names, self.stats_copied, self.stats_copied_2): + with self.subTest(name): + if name == 'examples': + # examples are hard + # sometimes they differ by quite a lot + for etype in one: + ev1, ev2 = one[etype], two[etype] + with self.subTest(f"{name}-{etype}"): + self.assertEqual(ev1.keys(), ev2.keys()) + for cui in ev1: + per_cui_examples1 = ev1[cui] + per_cui_examples2 = ev2[cui] + with self.subTest(f"{name}-{etype}-{cui}-[{self.cat.cdb.cui2preferred_name.get(cui, cui)}]"): + self.assertEqual(len(per_cui_examples1), len(per_cui_examples2), "INCORRECT NUMBER OF ITEMS") + for ex1, ex2 in zip(per_cui_examples1, per_cui_examples2): + self.assertDictsAlmostEqual(ex1, ex2, tolerance_places=self._stats_consistency_tolerance) + continue + self.assertEqual(one, two) + + def test_copy_has_correct_number_documents(self): + self.assertEqual(self.COPIES * self.docs_in_orig, self.docs_in_copy) + + def test_copy_has_correct_number_annotations(self): + self.assertEqual(self.COPIES * self.anns_in_orig, self.anns_in_copy) + + def test_3_fold_identical_folds(self): + folds = kfold.get_fold_creator(self.data_copied, nr_of_folds=self.COPIES, + split_type=kfold.SplitType.DOCUMENTS).create_folds() + self.assertEqual(len(folds), self.COPIES) + for nr, fold in enumerate(folds): + with self.subTest(f"Fold-{nr}"): + # if they're all equal to original, they're eqaul to each other + self.assertEqual( + nullify_doc_names_proj_ids(fold), + nullify_doc_names_proj_ids(self.mct_export) + ) + + def test_metrics_3_fold(self): + stats_simple = self.reg_stats + for name, old, new in zip(self._names, stats_simple, self.stats_copied): + if name == 'examples': + continue + # with self.subTest(name): + if name in ("fps", "fns", "tps", "counts"): + # count should be triples + pass + if name in ("prec", "rec", "f1"): + # these should average to the same ?? 
+ all_keys = old.keys() | new.keys() + for cui in all_keys: + cuiname = self.cat.cdb.cui2preferred_name.get(cui, cui) + with self.subTest(f"{name}-{cui} [{cuiname}]"): + self.assertIn(cui, old.keys(), f"CUI '{cui}' ({cuiname}) not in old") + self.assertIn(cui, new.keys(), f"CUI '{cui}' ({cuiname}) not in new") + v1, v2 = old[cui], new[cui] + self.assertEqual(v1, v2, f"Values not equal for {cui} ({self.cat.cdb.cui2preferred_name.get(cui, cui)})") diff --git a/tests/stats/test_mctexport.py b/tests/stats/test_mctexport.py new file mode 100644 index 000000000..8ef11f556 --- /dev/null +++ b/tests/stats/test_mctexport.py @@ -0,0 +1,38 @@ +import os +import json + +from medcat.stats import mctexport + +import unittest + +from .helpers import MCTExportPydanticModel + + +class MCTExportIterationTests(unittest.TestCase): + EXPORT_PATH = os.path.join(os.path.dirname(__file__), "..", + "resources", "medcat_trainer_export.json") + EXPECTED_DOCS = 27 + EXPECTED_ANNS = 435 + + @classmethod + def setUpClass(cls) -> None: + with open(cls.EXPORT_PATH) as f: + cls.mct_export: mctexport.MedCATTrainerExport = json.load(f) + + def test_conforms_to_template(self): + # NOTE: This uses pydantic to make sure that the MedCATTrainerExport + # type matches the actual export format + model_instance = MCTExportPydanticModel(**self.mct_export) + self.assertIsInstance(model_instance, MCTExportPydanticModel) + + def test_iterates_over_all_docs(self): + self.assertEqual(mctexport.count_all_docs(self.mct_export), self.EXPECTED_DOCS) + + def test_iterates_over_all_anns(self): + self.assertEqual(mctexport.count_all_annotations(self.mct_export), self.EXPECTED_ANNS) + + def test_gets_correct_nr_of_annotations_per_doc(self): + for project in self.mct_export['projects']: + for doc in project["documents"]: + with self.subTest(f"Proj-{project['name']} ({project['id']})-{doc['name']} ({doc['id']})"): + self.assertEqual(mctexport.get_nr_of_annotations(doc), len(doc["annotations"])) diff --git a/tests/utils/test_cdb_state.py b/tests/utils/test_cdb_state.py new file mode 100644 index 000000000..068af128b --- /dev/null +++ b/tests/utils/test_cdb_state.py @@ -0,0 +1,113 @@ +import unittest +import os +from unittest import mock +from typing import Callable, Any, Dict +import tempfile + +from medcat.utils.cdb_state import captured_state_cdb, CDBState, copy_cdb_state +from medcat.cdb import CDB +from medcat.vocab import Vocab +from medcat.cat import CAT + + +class StateTests(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + cls.cdb = CDB.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "examples", "cdb.dat")) + cls.vocab = Vocab.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "examples", "vocab.dat")) + cls.vocab.make_unigram_table() + cls.cdb.config.general.spacy_model = "en_core_web_md" + cls.meta_cat_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp") + cls.undertest = CAT(cdb=cls.cdb, config=cls.cdb.config, vocab=cls.vocab, meta_cats=[]) + cls.initial_state = copy_cdb_state(cls.cdb) + + @classmethod + def _set_info(cls, k: str, v: Any, info_dict: Dict): + info_dict[k] = (len(v), len(str(v))) + + @classmethod + def do_smth_for_each_state_var(cls, cdb: CDB, callback: Callable[[str, Any], None]) -> None: + for k in CDBState.__annotations__: + v = getattr(cdb, k) + callback(k, v) + + +class StateSavedTests(StateTests): + on_disk = False + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + # capture state + with 
captured_state_cdb(cls.cdb, save_state_to_disk=cls.on_disk): + # clear state + cls.do_smth_for_each_state_var(cls.cdb, lambda k, v: v.clear()) + cls.cleared_state = copy_cdb_state(cls.cdb) + # save after state - should be equal to before + cls.restored_state = copy_cdb_state(cls.cdb) + + def test_state_saved(self): + nr_of_targets = len(CDBState.__annotations__) + self.assertGreater(nr_of_targets, 0) + self.assertEqual(len(self.initial_state), nr_of_targets) + self.assertEqual(len(self.cleared_state), nr_of_targets) + self.assertEqual(len(self.restored_state), nr_of_targets) + + def test_clearing_worked(self): + self.assertNotEqual(self.initial_state, self.cleared_state) + for k, v in self.cleared_state.items(): + with self.subTest(k): + # length is 0 + self.assertFalse(v) + + def test_state_restored(self): + self.assertEqual(self.initial_state, self.restored_state) + + +class StateSavedOnDiskTests(StateSavedTests): + on_disk = True + _named_tempory_file = tempfile.NamedTemporaryFile + + @classmethod + def saved_name_temp_file(cls): + tf = cls._named_tempory_file() + cls.temp_file_name = tf.name + return tf + + @classmethod + def setUpClass(cls) -> None: + with mock.patch("builtins.open", side_effect=open) as cls.popen: + with mock.patch("tempfile.NamedTemporaryFile", side_effect=cls.saved_name_temp_file) as cls.pntf: + return super().setUpClass() + + def test_temp_file_called(self): + self.pntf.assert_called_once() + + def test_saved_on_disk(self): + self.popen.assert_called() + self.assertGreaterEqual(self.popen.call_count, 2) + self.popen.assert_has_calls([mock.call(self.temp_file_name, 'wb'), + mock.call(self.temp_file_name, 'rb')]) + + +class StateWithTrainingTests(StateTests): + SUPERVISED_TRAINING_JSON = os.path.join(os.path.dirname(__file__), "..", "resources", "medcat_trainer_export.json") + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + with captured_state_cdb(cls.cdb): + # do training + cls.undertest.train_supervised_from_json(cls.SUPERVISED_TRAINING_JSON) + cls.after_train_state = copy_cdb_state(cls.cdb) + cls.restored_state = copy_cdb_state(cls.cdb) + + +class StateRestoredAfterTrain(StateWithTrainingTests): + + def test_train_state_changed(self): + self.assertNotEqual(self.initial_state, self.after_train_state) + + def test_restored_state_same(self): + self.assertDictEqual(self.initial_state, self.restored_state)
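
Usage sketch: a minimal example of the k-fold validation API added in medcat/stats/kfold.py. The model-pack and export paths below are placeholders, and nepochs is only an illustrative keyword that get_k_fold_stats forwards to CAT.train_supervised_raw; adjust both for a real setup.

import json

from medcat.cat import CAT
from medcat.stats.kfold import get_k_fold_stats, SplitType

# Placeholder paths - substitute a real model pack and MedCATtrainer export.
model_pack_path = "models/medcat_model_pack"
mct_export_path = "exports/medcat_trainer_export.json"

cat = CAT.load_model_pack(model_pack_path)
with open(mct_export_path) as f:
    mct_export = json.load(f)

# Split the export into 3 folds (documents weighted by annotation count),
# train on the other two folds for each fold, evaluate on the held-out fold,
# and average the per-CUI metrics across folds.
fps, fns, tps, prec, rec, f1, counts, examples = get_k_fold_stats(
    cat, mct_export, k=3, split_type=SplitType.DOCUMENTS_WEIGHTED, nepochs=1)

# Report per-CUI F1 alongside the number of annotations seen for that CUI.
for cui, f1_value in sorted(f1.items(), key=lambda kv: kv[1], reverse=True):
    print(cui, counts.get(cui, 0), round(f1_value, 3))

The per-fold training relies on the new medcat/utils/cdb_state.py module to roll the concept database back between folds; the same context manager can also be used directly, with save_state_to_disk=True trading memory for a temporary dill file, as sketched below.

from medcat.utils.cdb_state import captured_state_cdb

with captured_state_cdb(cat.cdb, save_state_to_disk=True):
    # Any supervised training done inside this block is discarded on exit.
    cat.train_supervised_raw(mct_export, nepochs=1)
# Here cat.cdb has been restored to its pre-training state.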