diff --git a/ehrapy/tools/__init__.py b/ehrapy/tools/__init__.py index 1a32548f..ecd17743 100644 --- a/ehrapy/tools/__init__.py +++ b/ehrapy/tools/__init__.py @@ -2,7 +2,7 @@ from ehrapy.tools.nlp._hpo import HPOMapper try: - from ehrapy.tools.nlp._medcat import MedCAT + from ehrapy.tools.nlp._medcat import * # noqa: E402,F403 except ImportError: pass from ehrapy.tools.nlp._translators import Translator diff --git a/ehrapy/tools/nlp/_medcat.py b/ehrapy/tools/nlp/_medcat.py index c6bdc0dc..1857bcf0 100644 --- a/ehrapy/tools/nlp/_medcat.py +++ b/ehrapy/tools/nlp/_medcat.py @@ -22,6 +22,9 @@ from ehrapy import settings from ehrapy.core.tool_available import check_module_importable + +# TODO: State in docs those models are only needed when not using a model pack (cdb and vocab separatly) +# so this check can be removed spacy_models_modules: list[str] = list(map(lambda model: model.replace("-", "_"), ["en-core-web-md"])) for model in spacy_models_modules: if not check_module_importable(model): @@ -30,6 +33,7 @@ ) +# TODO: Discuss this could be used as a custom class for results @dataclass class AnnotationResult: all_medcat_annotation_results: list | None @@ -49,320 +53,192 @@ def __init__(self, vocabulary: Vocab = None, concept_db: CDB = None, model_pack_ elif model_pack_path is not None: self.cat = CAT.load_model_pack(model_pack_path) - @staticmethod - def create_vocabulary(vocabulary_data: str, replace: bool = True) -> Vocab: - """Creates a MedCAT Vocab and sets it for the MedCAT object. + def update_cat(self, vocabulary: Vocab = None, concept_db: CDB = None): + """Updates the current MedCAT instance with new Vocabularies and Concept Databases. Args: - vocabulary_data: Path to the vocabulary data. - It is a tsv file and must look like: - - <token>\t<word_count>\t<vector_embedding_separated_by_spaces> - house 34444 0.3232 0.123213 1.231231 - replace: Whether to replace existing words in the vocabulary. - - Returns: - Instance of a MedCAT Vocab + vocabulary: Vocabulary to update to. + concept_db: Concept Database to update to. """ - vocabulary = Vocab() - vocabulary.add_words(vocabulary_data, replace=replace) - - return vocabulary + self.cat = CAT(cdb=concept_db, config=concept_db.config, vocab=vocabulary) - @staticmethod - def create_concept_db(csv_path: list[str], config: Config = None) -> CDB: - """Creates a MedCAT concept database and sets it for the MedCAT object. + def update_cat_config(self, concept_db_config: Config) -> None: + """Updates the MedCAT configuration. Args: - csv_path: List of paths to one or more csv files containing all concepts. - The concept csvs must look like: - - cui,name - 1,kidney failure - 7,coronavirus - config: Optional MedCAT concept database configuration. - If not provided a default configuration with config.general['spacy_model'] = 'en_core_sci_md' is created. - Returns: - Instance of a MedCAT CDB concept database + concept_db_config: Concept to update to. """ - if config is None: - config = Config() - config.general["spacy_model"] = "en_core_sci_md" - maker = CDBMaker(config) - concept_db = maker.prepare_csvs(csv_path, full_build=True) + self.concept_db.config = concept_db_config - return concept_db + def set_filter_by_tui(self, tuis: list[str] | None = None): + """Restrict results of annotation step to certain tui's (type unique identifiers). + Note that this will change the MedCat object by updating the concept database config. In every annotation + process that will be run afterwards, entities are shown, only if they fall into the tui's type. + A full list of tui's can be found at: https://lhncbc.nlm.nih.gov/ii/tools/MetaMap/Docs/SemanticTypes_2018AB.txt - def save_vocabulary(self, output_path: str) -> None: - """Saves a MedCAT vocabulary. + As an exmaple: + Setting tuis=["T047", "T048"] will only annotate concepts (identified by a CUI (concept unique identifier)) in UMLS that are either diseases or + syndroms (T047) or mental/behavioural dysfunctions (T048). This is the default value. Args: - output_path: Path to write the vocabulary to. - """ - self.vocabulary.save(output_path) - - def load_vocabulary(self, vocabulary_path) -> Vocab: - """Loads a MedCAT vocabulary. - - Args: - vocabulary_path: Path to load the vocabulary from. - """ - self.vocabulary = Vocab.load(vocabulary_path) - - return self.vocabulary - - def save_concept_db(self, output_path: str) -> None: - """Saves a MedCAT concept database. + tuis: list of TUI's + Returns: - Args: - output_path: Path to save the concept database to. """ - self.concept_db.save(output_path) - - def load_concept_db(self, concept_db_path) -> CDB: - """Loads the concept database. + if tuis is None: + tuis = ['T047', 'T048'] + # the filtered cui's that fall into the type of the filter tui's + cui_filters = set() + for type_id in tuis: + cui_filters.update(self.cat.cdb.addl_info['type_id2cuis'][type_id]) + self.cat.cdb.config.linking['filters']['cuis'] = cui_filters - Args: - concept_db_path: Path to load the concept database from. - """ - self.concept_db = CDB.load(concept_db_path) - return self.concept_db +def create_vocabulary(vocabulary_data: str, replace: bool = True) -> Vocab: + """Creates a MedCAT Vocab and sets it for the MedCAT object. - def load_model_pack(self, model_pack_path) -> None: - """Loads a MedCAt model pack. + Args: + vocabulary_data: Path to the vocabulary data. + It is a tsv file and must look like: - Updates the MedCAT object. + <token>\t<word_count>\t<vector_embedding_separated_by_spaces> + house 34444 0.3232 0.123213 1.231231 + replace: Whether to replace existing words in the vocabulary. - Args: - model_pack_path: Path to save the model from. - """ - self.cat.load_model_pack(model_pack_path) + Returns: + Instance of a MedCAT Vocab + """ + vocabulary = Vocab() + vocabulary.add_words(vocabulary_data, replace=replace) - def update_cat(self, vocabulary: Vocab = None, concept_db: CDB = None): - """Updates the current MedCAT instance with new Vocabularies and Concept Databases. + return vocabulary - Args: - vocabulary: Vocabulary to update to. - concept_db: Concept Database to update to. - """ - self.cat = CAT(cdb=concept_db, config=concept_db.config, vocab=vocabulary) - def update_cat_config(self, concept_db_config: Config) -> None: - """Updates the MedCAT configuration. +def create_concept_db(csv_path: list[str], config: Config = None) -> CDB: + """Creates a MedCAT concept database and sets it for the MedCAT object. - Args: - concept_db_config: Concept to update to. - """ - self.concept_db.config = concept_db_config + Args: + csv_path: List of paths to one or more csv files containing all concepts. + The concept csvs must look like: - def extract_entities_text(self, text: str) -> Doc: - """Extracts entities for a provided text. + cui,name + 1,kidney failure + 7,coronavirus + config: Optional MedCAT concept database configuration. + If not provided a default configuration with config.general['spacy_model'] = 'en_core_sci_md' is created. + Returns: + Instance of a MedCAT CDB concept database + """ + if config is None: + config = Config() + config.general["spacy_model"] = "en_core_sci_md" + maker = CDBMaker(config) + concept_db = maker.prepare_csvs(csv_path, full_build=True) - Args: - text: The text to extract entities from + return concept_db - Returns: - A spacy Doc instance. Extract the entities using. doc.ents - """ - return self.cat(text) - def print_cui(self, doc: Doc, table: bool = False) -> None: - """Prints the concept unique identifier for all entities. +def save_vocabulary(vocab: Vocab, output_path: str) -> None: + """Saves a vocabulary. - Args: - doc: A spacy tokens Doc. - table: Whether to print a Rich table. - """ - if table: - # TODO IMPLEMENT ME - pass - else: - for entity in doc.ents: - # TODO make me pretty by ensuring that the lengths before the - token are aligned - print(f"[bold blue]{entity} - {entity._.cui}") + Args: + output_path: Path to write the vocabulary to. + """ + vocab.save(output_path) - def print_semantic_type(self, doc: Doc, table: bool = False) -> None: - """Prints the semantic types for all entities. - Args: - doc: A spacy tokens Doc. - table: Whether to print a Rich table. - """ - # TODO implement Rich table - for ent in doc.ents: - print(ent, " - ", self.concept_db.cui2type_ids.get(ent._.cui)) +def load_vocabulary(vocabulary_path) -> Vocab: + """Loads a vocabulary. - def print_displacy(self, doc: Doc, style: Literal["deb", "ent"] = "ent") -> None: - """Prints a Doc with displacy + Args: + vocabulary_path: Path to load the vocabulary from. + """ - Args: - doc: A spacy tokens Doc. - style: The Displacy style to render - """ - displacy.render(doc, style=style, jupyter=True) + return Vocab.load(vocabulary_path) - def run_unsupervised_training(self, text: pd.Series, print_statistics: bool = False) -> None: - """Performs MedCAT unsupervised training on a provided text column. - Args: - text: Pandas Series of text to annotate. - print_statistics: Whether to print training statistics after training. - """ - print(f"[bold blue]Training using {len(text)} documents") - self.cat.train(text.values, progress_print=100) +def save_concept_db(cdb, output_path: str) -> None: + """Saves a concept database. - if print_statistics: - self.concept_db.print_stats() + Args: + output_path: Path to save the concept database to. + """ + cdb.save(output_path) - def filter_tui(self, concept_db: CDB, tui_filters: list[str]) -> CDB: - """Filters a concept database by semantic types (TUI). - Args: - concept_db: MedCAT concept database. - tui_filters: A list of semantic type filters. Example: T047 Disease or Syndrome -> "T047" +def load_concept_db(concept_db_path) -> CDB: + """Loads the concept database. - Returns: - A filtered MedCAT concept database. - """ - # TODO Figure out a way not to do this inplace and add an inplace parameter - cui_filters = set() - for tui in tui_filters: - cui_filters.update(concept_db.addl_info["type_id2cuis"][tui]) - concept_db.config.linking["filters"]["cuis"] = cui_filters - print(f"[bold blue]The size of the concept database is now: {len(cui_filters)}") - - return concept_db - - def annotate( - self, - data: np.ndarray | pd.Series, - batch_size_chars=500000, - min_text_length=5, - n_jobs: int = settings.n_jobs, - ) -> AnnotationResult: - """Annotates a set of texts. + Args: + concept_db_path: Path to load the concept database from. + """ - Args: - data: Text data to annotate. - batch_size_chars: Batch size in number of characters - min_text_length: Minimum text length - n_jobs: Number of parallel processes + return CDB.load(concept_db_path) - Returns: - A dictionary: {id: doc_json, id2: doc_json2, ...}, in case out_split_size is used - the last batch will be returned while that and all previous batches will be written to disk (out_save_dir). - """ - if isinstance(data, np.ndarray): - data = pd.Series(data) - data = data[data.apply(lambda word: len(str(word)) > min_text_length)] +def save_model_pack(ep_cat: MedCAT, model_pack_dir: str = ".", name: str = "ehrapy_medcat_model_pack") -> None: + """Saves a MedCAT model pack. - cui_location: dict = {} # CUI to a list of documents where it appears - type_ids_location = {} # TUI to a list of documents where it appears + Args: + ep_cat: ehrapy's custom MedCAT object whose model should be saved + model_pack_dir: Path to save the model to (defaults to current working directory). + name: Name of the new model pack + """ + # TODO Pathing is weird here (home/myname/...) will fo example create dir myname inside home inside the cwd instead of using the path + _ = ep_cat.cat.create_model_pack(model_pack_dir + name) - batch: list = [] - with Progress( - "[progress.description]{task.description}", - SpinnerColumn(), - ) as progress: - progress.add_task("[red]Annotating...") - for text_id, text in data.iteritems(): # type: ignore - batch.append((text_id, text)) - results = self.cat.multiprocessing(batch, batch_size_chars=batch_size_chars, nproc=n_jobs) +def run_unsupervised_training(ep_cat: MedCAT, text: pd.Series, progress_print: int = 100, print_statistics: bool = False) -> None: + """Performs MedCAT unsupervised training on a provided text column. - for doc in list(results.keys()): - for annotation in list(results[doc]["entities"].values()): - if annotation["cui"] in cui_location: - cui_location[annotation["cui"]].append(doc) - else: - cui_location[annotation["cui"]] = [doc] + Args: + ep_cat: ehrapy's custom MedCAT object, that keeps track of the vocab, concept database and the (annotated) results + text: Pandas Series of (free) text to annotate. + progress_print: print progress after that many training documents + print_statistics: Whether to print training statistics after training. + """ + print(f"[bold blue]Running unsupervised training using {len(text)} documents.") + ep_cat.cat.train(text.values, progress_print=progress_print) - for cui in cui_location.keys(): - type_ids_location[list(self.cat.cdb.cui2type_ids[cui])[0]] = cui_location[cui] + if print_statistics: + ep_cat.cat.cdb.print_stats() - patient_to_entities: dict[str, list[str]] = {} - for patient_id, findings in results.items(): - entities = [] - for _, result in findings["entities"].items(): - entities.append(result["pretty_name"]) - patient_to_entities[patient_id] = entities +def annotate_text(ep_cat: MedCAT, obs: pd.DataFrame, text_column: str, n_proc: int = 2, batch_size_chars: int = 500000) -> None: + """Annotate the original free text data. Note this will only annotate non null rows. The result + will be a (large) dict containing all the entities extracted from the free text (and in case filtered before via set_filter_by_tui function). + This dict will be the base for all further analyses, for example coloring umaps by specific diseases. - return AnnotationResult(results, patient_to_entities, cui_location, type_ids_location) + Args: + ep_cat: Ehrapy's custom MedCAT object + obs: AnnData obs containing the free text column + text_column: Name of the column that should be annotated + n_proc: Number of processors to use + batch_size_chars: batch size to control for the variablity between document sizes - def calculate_disease_proportions( - self, adata: AnnData, cui_locations: dict, subject_id_col="subject_id" - ) -> pd.DataFrame: - """Calculates the relative proportion of found diseases as percentages. + """ + non_null_text = _filter_null_values(obs, text_column) + formatted_text_column = _format_df_column(non_null_text, text_column) + results = ep_cat.cat.multiprocessing(formatted_text_column, batch_size_chars=batch_size_chars, nproc=n_proc) + # TODO: Discuss: Should we return a brand new custom "result" object here (as this is basically the base for downstream analysis) or + # TODO: just add it to the existing ehrapy MedCAT object as the "result" attribute. + # for testing and debugging, going with simply returning it + return results - Args: - adata: AnnData object. obs of this object must contain the results. - cui_locations: A dictionary containing the found CUIs and their location. - subject_id_col: The column header in the data containing the patient/subject IDs. - Returns: - A Pandas Dataframe containing the disease percentages. +def _format_df_column(df: pd.DataFrame, column_name: str) -> list[tuple[int, str]]: + """Format the df to match: formatted_data = [(row_id, row_text), (row_id, row_text), ...] + as this is required by medcat's multiprocessing annotation step + """ + # TODO This can be very memory consuming -> possible to use generators here instead (medcat compatible?)? + formatted_data = [] + for id, row in df.iterrows(): + text = row[column_name] + formatted_data.append((id, text)) + return formatted_data - cui nsubjects tui name perc_subjects - """ - cui_subjects: dict[int, list[int]] = {} - cui_subjects_unique: dict[int, set[int]] = {} - for cui in cui_locations: - for location in cui_locations[cui]: - # TODO: int casting is required as AnnData requires indices to be str (maybe we can change this) so we dont need type casting here - subject_id = adata.obs.iat[int(location), list(adata.obs.columns).index(subject_id_col)] - if cui in cui_subjects: - cui_subjects[cui].append(subject_id) - cui_subjects_unique[cui].add(subject_id) - else: - cui_subjects[cui] = [subject_id] - cui_subjects_unique[cui] = {subject_id} - - cui_nsubjects = [("cui", "nsubjects")] - for cui in cui_subjects_unique.keys(): - cui_nsubjects.append((cui, len(cui_subjects_unique[cui]))) # type: ignore - df_cui_nsubjects = pd.DataFrame(cui_nsubjects[1:], columns=cui_nsubjects[0]) - - df_cui_nsubjects = df_cui_nsubjects.sort_values("nsubjects", ascending=False) - # Add type_ids for each CUI - df_cui_nsubjects["type_ids"] = ["unk"] * len(df_cui_nsubjects) - cols = list(df_cui_nsubjects.columns) - for i in range(len(df_cui_nsubjects)): - cui = df_cui_nsubjects.iat[i, cols.index("cui")] - type_ids = self.cat.cdb.cui2type_ids.get(cui, "unk") - df_cui_nsubjects.iat[i, cols.index("type_ids")] = type_ids - - # Add name for each CUI - df_cui_nsubjects["name"] = ["unk"] * len(df_cui_nsubjects) - cols = list(df_cui_nsubjects.columns) - for i in range(len(df_cui_nsubjects)): - cui = df_cui_nsubjects.iat[i, cols.index("cui")] - name = self.cat.cdb.cui2preferred_name.get(cui, "unk") - df_cui_nsubjects.iat[i, cols.index("name")] = name - - # Add the percentage column - total_subjects = len(adata.obs[subject_id_col].unique()) - df_cui_nsubjects["perc_subjects"] = (df_cui_nsubjects["nsubjects"] / total_subjects) * 100 - - df_cui_nsubjects.reset_index(drop=True, inplace=True) - - return df_cui_nsubjects - - def plot_top_diseases(self, df_cui_nsubjects: pd.DataFrame, top_diseases: int = 30) -> None: - """Plots the top n (default: 30) found diseases. - Args: - df_cui_nsubjects: Pandas DataFrame containing the determined annotations. - top_diseases: Number of top diseases to plot - """ - warnings.warn("This function will be moved and likely removed in a future version!", FutureWarning) - # TODO this should ideally be drawn with a Scanpy plot or something - # TODO Needs more options such as saving etc - sns.set(rc={"figure.figsize": (5, 12)}, style="whitegrid", palette="pastel") - f, ax = plt.subplots() - _data = df_cui_nsubjects.iloc[0:top_diseases] - sns.barplot(x="perc_subjects", y="name", data=_data, label="Disorder Name", color="b") - _ = ax.set(xlim=(0, 70), ylabel="Disease Name", xlabel="Percentage of patients with disease") - plt.show() +def _filter_null_values(df: pd.DataFrame, column: str) -> pd.DataFrame: + """Filter null values of a given column and return that column without the null values + """ + return pd.DataFrame(df[column][~df[column].isnull()])