diff --git a/.gitignore b/.gitignore index 99c4be7e5..dcd2743f0 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,9 @@ venv db.sqlite3 .ipynb_checkpoints +# vscode +.vscode + #tmp and similar files .nfs* *.log diff --git a/medcat/cat.py b/medcat/cat.py index c6eb9c840..8df7526b7 100644 --- a/medcat/cat.py +++ b/medcat/cat.py @@ -33,6 +33,7 @@ from medcat.linking.context_based_linker import Linker from medcat.preprocessing.cleaners import prepare_name from medcat.meta_cat import MetaCAT +from medcat.rel_cat import RelCAT from medcat.utils.meta_cat.data_utils import json_to_fake_spacy from medcat.config import Config from medcat.vocab import Vocab @@ -64,6 +65,8 @@ class CAT(object): meta_cats (list of medcat.meta_cat.MetaCAT, optional): A list of models that will be applied sequentially on each detected annotation. + rel_cats (list of medcat.rel_cat.RelCAT, optional) + List of models applied sequentially on all detected annotations. Attributes (limited): cdb (medcat.cdb.CDB): @@ -89,6 +92,7 @@ def __init__(self, vocab: Union[Vocab, None] = None, config: Optional[Config] = None, meta_cats: List[MetaCAT] = [], + rel_cats: List[RelCAT] = [], addl_ner: Union[TransformersNER, List[TransformersNER]] = []) -> None: self.cdb = cdb self.vocab = vocab @@ -100,6 +104,7 @@ def __init__(self, self.config = config self.cdb.config = config self._meta_cats = meta_cats + self._rel_cats = rel_cats self._addl_ner = addl_ner if isinstance(addl_ner, list) else [addl_ner] self._create_pipeline(self.config) @@ -133,6 +138,9 @@ def _create_pipeline(self, config: Config): for meta_cat in self._meta_cats: self.pipe.add_meta_cat(meta_cat, meta_cat.config.general.category_name) + for rel_cat in self._rel_cats: + self.pipe.add_rel_cat(rel_cat, "_".join(list(rel_cat.config.general["labels2idx"].keys()))) + # Set max document length self.pipe.spacy_nlp.max_length = config.preprocessing.max_document_length @@ -297,6 +305,10 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M name = comp[0] meta_path = os.path.join(save_dir_path, "meta_" + name) comp[1].save(meta_path) + if isinstance(comp[1], RelCAT): + name = comp[0] + rel_path = os.path.join(save_dir_path, "rel_" + name) + comp[1].save(rel_path) # Add a model card also, why not model_card_path = os.path.join(save_dir_path, "model_card.json") @@ -341,7 +353,8 @@ def load_model_pack(cls, meta_cat_config_dict: Optional[Dict] = None, ner_config_dict: Optional[Dict] = None, load_meta_models: bool = True, - load_addl_ner: bool = True) -> "CAT": + load_addl_ner: bool = True, + load_rel_models: bool = True) -> "CAT": """Load everything within the 'model pack', i.e. the CDB, config, vocab and any MetaCAT models (if present) @@ -360,6 +373,8 @@ def load_model_pack(cls, Whether to load MetaCAT models if present (Default value True). load_addl_ner (bool): Whether to load additional NER models if present (Default value True). + load_rel_models (bool): + Whether to load RelCAT models if present (Default value True). Returns: CAT: The resulting CAT object. @@ -367,6 +382,7 @@ def load_model_pack(cls, from medcat.cdb import CDB from medcat.vocab import Vocab from medcat.meta_cat import MetaCAT + from medcat.rel_cat import RelCAT model_pack_path = cls.attempt_unpack(zip_path) @@ -409,8 +425,15 @@ def load_model_pack(cls, meta_cats.append(MetaCAT.load(save_dir_path=meta_path, config_dict=meta_cat_config_dict)) - cat = cls(cdb=cdb, config=cdb.config, vocab=vocab, meta_cats=meta_cats, addl_ner=addl_ner) + # Find Rel models in model_pack + rel_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('rel_')] if load_rel_models else [] + rel_cats = [] + for rel_path in rel_paths: + rel_cats.append(RelCAT.load(load_path=rel_path)) + + cat = cls(cdb=cdb, config=cdb.config, vocab=vocab, meta_cats=meta_cats, addl_ner=addl_ner, rel_cats=rel_cats) logger.info(cat.get_model_card()) # Print the model card + return cat def __call__(self, text: Optional[str], do_train: bool = False) -> Optional[Doc]: @@ -1092,8 +1115,8 @@ def get_entities_multi_texts(self, elif out[i].get('text', '') != text: out.insert(i, self._doc_to_out(None, only_cui, addl_info)) # type: ignore - cnf_annotation_output = self.config.annotation_output - if not cnf_annotation_output.include_text_in_output: + cnf_annotation_output = getattr(self.config, 'annotation_output', {}) + if not (cnf_annotation_output.get('include_text_in_output', False)): for o in out: if o is not None: o.pop('text', None) diff --git a/medcat/config_rel_cat.py b/medcat/config_rel_cat.py new file mode 100644 index 000000000..54fe142dd --- /dev/null +++ b/medcat/config_rel_cat.py @@ -0,0 +1,98 @@ +import logging +from typing import Dict, Any, List +from medcat.config import MixingConfig, BaseModel, Optional, Extra + + +class General(MixingConfig, BaseModel): + """The General part of the RelCAT config""" + device: str = "cpu" + relation_type_filter_pairs: List = [] + """Map from category values to ID, if empty it will be autocalculated during training""" + vocab_size: Optional[int] = None + lowercase: bool = True + """If true all input text will be lowercased""" + cntx_left: int = 15 + """Number of tokens to take from the left of the concept""" + cntx_right: int = 15 + """Number of tokens to take from the right of the concept""" + window_size: int = 300 + """Max acceptable dinstance between entities (in characters), care when using this as it can produce sentences that are over 512 tokens (limit is given by tokenizer)""" + + mct_export_max_non_rel_sample_size:int = 200 + """Limit the number of 'Other' samples selected for training/test. This is applied per encountered medcat project, sample_size/num_projects. """ + mct_export_create_addl_rels: bool = False + """When processing relations from a MedCAT export, relations labeled as 'Other' are created from all the annotations pairs available""" + + tokenizer_name: str = "bert" + model_name: str = "bert-base-uncased" + log_level: int = logging.INFO + max_seq_length: int = 512 + tokenizer_special_tokens: bool = False + annotation_schema_tag_ids: List = [] + """If a foreign non-MCAT trainer dataset is used, you can insert your own Rel entity token delimiters into the tokenizer, \ + copy those token IDs here, and also resize your tokenizer embeddings and adjust the hidden_size of the model, this will depend on the number of tokens you introduce""" + labels2idx: Dict = {} + idx2labels: Dict = {} + pin_memory: bool = True + seed: int = 13 + task: str = "train" + + +class Model(MixingConfig, BaseModel): + """The model part of the RelCAT config""" + input_size: int = 300 + hidden_size: int = 768 + hidden_layers: int = 3 + """ hidden_size * 5, 5 being the number of tokens, default (s1,s2,e1,e2+CLS)""" + model_size: int = 5120 + dropout: float = 0.2 + num_directions: int = 2 + """2 - bidirectional model, 1 - unidirectional""" + + padding_idx: int = -1 + emb_grad: bool = True + """If True the embeddings will also be trained""" + ignore_cpos: bool = False + """If set to True center positions will be ignored when calculating represenation""" + + class Config: + extra = Extra.allow + validate_assignment = True + + +class Train(MixingConfig, BaseModel): + """The train part of the RelCAT config""" + nclasses: int = 2 + """Number of classes that this model will output""" + batch_size: int = 25 + nepochs: int = 1 + lr: float = 1e-4 + adam_epsilon: float = 1e-4 + test_size: float = 0.2 + gradient_acc_steps: int = 1 + multistep_milestones: List[int] = [ + 2, 4, 6, 8, 12, 15, 18, 20, 22, 24, 26, 30] + multistep_lr_gamma: float = 0.8 + max_grad_norm: float = 1.0 + shuffle_data: bool = True + """Used only during training, if set the dataset will be shuffled before train/test split""" + class_weights: Optional[Any] = None + score_average: str = "weighted" + """What to use for averaging F1/P/R across labels""" + auto_save_model: bool = True + """Should the model be saved during training for best results""" + + class Config: + extra = Extra.allow + validate_assignment = True + + +class ConfigRelCAT(MixingConfig, BaseModel): + """The RelCAT part of the config""" + general: General = General() + model: Model = Model() + train: Train = Train() + + class Config: + extra = Extra.allow + validate_assignment = True diff --git a/medcat/pipe.py b/medcat/pipe.py index c047b57c1..65f552c30 100644 --- a/medcat/pipe.py +++ b/medcat/pipe.py @@ -13,6 +13,7 @@ from medcat.linking.context_based_linker import Linker from medcat.meta_cat import MetaCAT from medcat.ner.vocab_based_ner import NER +from medcat.rel_cat import RelCAT from medcat.utils.normalizers import TokenNormalizer, BasicSpellChecker from medcat.config import Config from medcat.pipeline.pipe_runner import PipeRunner @@ -161,6 +162,13 @@ def add_meta_cat(self, meta_cat: MetaCAT, name: Optional[str] = None) -> None: # Used for sharing pre-processed data/tokens Doc.set_extension('share_tokens', default=None, force=True) + def add_rel_cat(self, rel_cat: RelCAT, name: Optional[str] = None) -> None: + component_name = spacy.util.get_object_name(rel_cat) + name = name if name is not None else component_name + Language.component(name=component_name, func=rel_cat) + self._nlp.add_pipe(component_name, name=name, last=True) + # dictionary containing relations of the form {} + Doc.set_extension("relations", default=[], force=True) def add_addl_ner(self, addl_ner: TransformersNER, name: Optional[str] = None) -> None: component_name = spacy.util.get_object_name(addl_ner) @@ -169,6 +177,7 @@ def add_addl_ner(self, addl_ner: TransformersNER, name: Optional[str] = None) -> self._nlp.add_pipe(component_name, name=name, last=True) Doc.set_extension('ents', default=[], force=True) + Doc.set_extension('relations', default=[], force=True) Span.set_extension('confidence', default=-1, force=True) Span.set_extension('id', default=0, force=True) Span.set_extension('cui', default=-1, force=True) diff --git a/medcat/rel_cat.py b/medcat/rel_cat.py new file mode 100644 index 000000000..374997927 --- /dev/null +++ b/medcat/rel_cat.py @@ -0,0 +1,666 @@ +import json +import logging +import os +import torch.optim +import torch +import torch.nn as nn + +from tqdm import tqdm +from datetime import date, datetime +from transformers import BertConfig +from medcat.cdb import CDB +from medcat.config import Config +from medcat.config_rel_cat import ConfigRelCAT +from medcat.pipeline.pipe_runner import PipeRunner +from medcat.utils.relation_extraction.tokenizer import TokenizerWrapperBERT +from spacy.tokens import Doc +from typing import Dict, Iterable, Iterator, cast +from transformers import AutoTokenizer +from torch.utils.data import DataLoader +from torch.optim import Adam +from torch.optim.lr_scheduler import MultiStepLR +from medcat.utils.meta_cat.ml_utils import set_all_seeds +from medcat.utils.relation_extraction.models import BertModel_RelationExtraction +from medcat.utils.relation_extraction.pad_seq import Pad_Sequence +from medcat.utils.relation_extraction.utils import create_tokenizer_pretrain, load_results, load_state, save_results, save_state, split_list_train_test_by_class +from medcat.utils.relation_extraction.rel_dataset import RelData + + +class RelCAT(PipeRunner): + """The RelCAT class used for training 'Relation-Annotation' models, i.e., annotation of relations + between clinical concepts. + + Args: + cdb (CDB): cdb, this is used when creating relation datasets. + + tokenizer (TokenizerWrapperBERT): + The Huggingface tokenizer instance. This can be a pre-trained tokenzier instance from + a BERT-style model. For now, only BERT models are supported. + + config (ConfigRelCAT): + the configuration for RelCAT. Param descriptions available in ConfigRelCAT docs. + + task (str, optional): What task is this model supposed to handle. Defaults to "train" + init_model (bool, optional): loads default model. Defaults to False. + + """ + + + name = "rel_cat" + + log = logging.getLogger(__name__) + + def __init__(self, cdb: CDB, tokenizer: TokenizerWrapperBERT, config: ConfigRelCAT = ConfigRelCAT(), task="train", init_model=False): + self.config = config + self.tokenizer: TokenizerWrapperBERT = tokenizer + self.cdb = cdb + + logging.basicConfig(level=self.config.general.log_level) + self.log.setLevel(self.config.general.log_level) + + self.is_cuda_available = torch.cuda.is_available() + self.device = torch.device( + "cuda" if self.is_cuda_available and self.config.general.device != "cpu" else "cpu") + + self.model_config = BertConfig() + self.model: BertModel_RelationExtraction + self.task: str = task + self.checkpoint_path: str = "./" + self.optimizer: Adam = None # type: ignore + self.scheduler: MultiStepLR = None # type: ignore + self.best_f1: float = 0.0 + self.epoch: int = 0 + + self.pad_id = self.tokenizer.hf_tokenizers.pad_token_id + self.padding_seq = Pad_Sequence(seq_pad_value=self.pad_id, + label_pad_value=self.pad_id) + + set_all_seeds(config.general.seed) + + if init_model: + self._get_model() + + def save(self, save_path: str) -> None: + """ Saves model and its dependencies to specified save_path folder. + The CDB is obviously not saved, it is however necessary to save the tokenizer used. + + Args: + save_path (str): folder path in which to save the model & deps. + """ + + assert self.config is not None + self.config.save(os.path.join(save_path, "config.json")) + + assert self.model_config is not None + self.model_config.vocab_size = self.tokenizer.hf_tokenizers.vocab_size + self.model_config.to_json_file( + os.path.join(save_path, "model_config.json")) + assert self.tokenizer is not None + self.tokenizer.save(os.path.join(save_path)) + + assert self.model is not None + self.model.bert_model.resize_token_embeddings( + self.tokenizer.hf_tokenizers.vocab_size) + save_state(self.model, optimizer=self.optimizer, scheduler=self.scheduler, epoch=self.epoch, best_f1=self.best_f1, + path=save_path, model_name=self.config.general.model_name, + task=self.task, is_checkpoint=False, final_export=True) + + def _get_model(self): + """ Used only for model initialisation. + """ + self.model = BertModel_RelationExtraction(pretrained_model_name_or_path="bert-base-uncased", + relcat_config=self.config, + model_config=self.model_config) + + @classmethod + def load(cls, load_path: str = "./") -> "RelCAT": + + cdb = CDB(config=Config()) + if os.path.exists(os.path.join(load_path, "cdb.dat")): + cdb = CDB.load(os.path.join(load_path, "cdb.dat")) + else: + cls.log.info("The default CDB file name 'cdb.dat' doesn't exist in the specified path, you will need to load & set \ + a CDB manually via rel_cat.cdb = CDB.load('path') ") + + config_path = os.path.join(load_path, "config.json") + config = ConfigRelCAT() + if os.path.exists(config_path): + config = cast(ConfigRelCAT, ConfigRelCAT.load( + os.path.join(load_path, "config.json"))) + cls.log.info("Loaded config.json") + + tokenizer = None + tokenizer_path = os.path.join(load_path, config.general.tokenizer_name) + + if "bert" in config.general.tokenizer_name: + tokenizer_path = load_path + + if os.path.exists(tokenizer_path): + tokenizer = TokenizerWrapperBERT.load(tokenizer_path) + + cls.log.info("Tokenizer loaded from:" + tokenizer_path) + elif config.general.model_name: + cls.log.info("Attempted to load Tokenizer from path:" + tokenizer_path + + ", but it doesn't exist, loading default toknizer from model_name config.general.model_name:" + config.general.model_name) + tokenizer = TokenizerWrapperBERT(AutoTokenizer.from_pretrained(pretrained_model_name_or_path=config.general.model_name), + max_seq_length=config.general.max_seq_length, + add_special_tokens=config.general.tokenizer_special_tokens + ) + create_tokenizer_pretrain(tokenizer, tokenizer_path) + else: + cls.log.info("Attempted to load Tokenizer from path:" + tokenizer_path + + ", but it doesn't exist, loading default toknizer from model_name config.general.model_name:bert-base-uncased") + tokenizer = TokenizerWrapperBERT(AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased"), + max_seq_length=config.general.max_seq_length, + add_special_tokens=config.general.tokenizer_special_tokens + ) + + model_config = BertConfig() + model_config_path = os.path.join(load_path, "model_config.json") + + if os.path.exists(model_config_path): + cls.log.info("Loaded config from : " + model_config_path) + model_config = BertConfig.from_json_file(model_config_path) # type: ignore + else: + try: + model_config = BertConfig.from_pretrained( + pretrained_model_name_or_path=config.general.model_name, num_hidden_layers=config.model.hidden_layers) # type: ignore + except Exception as e: + cls.log.error("%s", str(e)) + cls.log.info("Config for HF model not found: " + + config.general.model_name + ". Using bert-base-uncased.") + model_config = BertConfig.from_pretrained( + pretrained_model_name_or_path="bert-base-uncased") # type: ignore + + model_config.vocab_size = tokenizer.hf_tokenizers.vocab_size + + rel_cat = cls(cdb=cdb, config=config, + tokenizer=tokenizer, + task=config.general.task) + + rel_cat.model_config = model_config + + device = torch.device("cuda" if torch.cuda.is_available( + ) and config.general.device != "cpu" else "cpu") + + try: + model_path = os.path.join(load_path, "model.dat") + + if os.path.exists(os.path.join(load_path, config.general.model_name)): + rel_cat.model = BertModel_RelationExtraction(pretrained_model_name_or_path=config.general.model_name, + relcat_config=config, + model_config=model_config) + else: + rel_cat.model = BertModel_RelationExtraction( + pretrained_model_name_or_path="", + relcat_config=config, + model_config=model_config) + rel_cat.model.load_state_dict( + torch.load(model_path, map_location=device)) + + cls.log.info("Loaded HF model : " + config.general.model_name) + except Exception as e: + cls.log.error("%s", str(e)) + cls.log.error("Failed to load specified HF model, defaulting to 'bert-base-uncased', loading...") + rel_cat.model = BertModel_RelationExtraction( + pretrained_model_name_or_path="bert-base-uncased", + relcat_config=config, + model_config=model_config) + + rel_cat.model.bert_model.resize_token_embeddings((len(tokenizer.hf_tokenizers))) + + rel_cat.optimizer = None # type: ignore + rel_cat.scheduler = None # type: ignore + + rel_cat.epoch, rel_cat.best_f1 = load_state(rel_cat.model, rel_cat.optimizer, rel_cat.scheduler, path=load_path, + model_name=config.general.model_name, + file_prefix=config.general.task, + device=device, + config=config) + + return rel_cat + + def _create_test_train_datasets(self, data: Dict, split_sets:bool = False): + train_data: Dict = {} + test_data: Dict = {} + + if split_sets: + train_data["output_relations"], test_data["output_relations"] = split_list_train_test_by_class(data["output_relations"], + test_size=self.config.train.test_size) + + test_data_label_names = [rec[4] for rec in test_data["output_relations"]] + + test_data["nclasses"], test_data["labels2idx"], test_data["idx2label"] = RelData.get_labels( + test_data_label_names, self.config) + + for idx in range(len(test_data["output_relations"])): + test_data["output_relations"][idx][5] = test_data["labels2idx"][test_data["output_relations"][idx][4]] + else: + train_data["output_relations"] = data["output_relations"] + + for k, v in data.items(): + if k != "output_relations": + train_data[k] = [] + test_data[k] = [] + + train_data_label_names = [rec[4] + for rec in train_data["output_relations"]] + + train_data["nclasses"], train_data["labels2idx"], train_data["idx2label"] = RelData.get_labels( + train_data_label_names, self.config) + + for idx in range(len(train_data["output_relations"])): + train_data["output_relations"][idx][5] = train_data["labels2idx"][train_data["output_relations"][idx][4]] + + return train_data, test_data + + def train(self, export_data_path:str = "", train_csv_path:str = "", test_csv_path:str = "", checkpoint_path: str = "./"): + + if self.is_cuda_available: + self.log.info("Training on device:", + torch.cuda.get_device_name(0), self.device) + + self.model = self.model.to(self.device) + + # resize vocab just in case more tokens have been added + self.model_config.vocab_size = len(self.tokenizer.hf_tokenizers) + + train_rel_data = RelData( + cdb=self.cdb, config=self.config, tokenizer=self.tokenizer) + test_rel_data = RelData( + cdb=self.cdb, config=self.config, tokenizer=self.tokenizer) + + if train_csv_path != "": + if test_csv_path != "": + train_rel_data.dataset, _ = self._create_test_train_datasets( + train_rel_data.create_base_relations_from_csv(train_csv_path), split_sets=False) + test_rel_data.dataset, _ = self._create_test_train_datasets( + train_rel_data.create_base_relations_from_csv(test_csv_path), split_sets=False) + else: + train_rel_data.dataset, test_rel_data.dataset = self._create_test_train_datasets( + train_rel_data.create_base_relations_from_csv(train_csv_path), split_sets=True) + + elif export_data_path != "": + export_data = {} + with open(export_data_path) as f: + export_data = json.load(f) + train_rel_data.dataset, test_rel_data.dataset = self._create_test_train_datasets( + train_rel_data.create_relations_from_export(export_data), split_sets=True) + else: + raise ValueError("NO DATA HAS BEEN PROVIDED (JSON/CSV/spacy_DOCS)") + + train_dataset_size = len(train_rel_data) + batch_size = train_dataset_size if train_dataset_size < self.config.train.batch_size else self.config.train.batch_size + train_dataloader = DataLoader(train_rel_data, batch_size=batch_size, shuffle=self.config.train.shuffle_data, + num_workers=0, collate_fn=self.padding_seq, + pin_memory=self.config.general.pin_memory) + test_dataset_size = len(test_rel_data) + test_batch_size = test_dataset_size if test_dataset_size < self.config.train.batch_size else self.config.train.batch_size + test_dataloader = DataLoader(test_rel_data, batch_size=test_batch_size, shuffle=self.config.train.shuffle_data, + num_workers=0, collate_fn=self.padding_seq, + pin_memory=self.config.general.pin_memory) + + criterion = nn.CrossEntropyLoss(ignore_index=-1) + + if self.optimizer is None: + parameters = filter(lambda p: p.requires_grad, self.model.parameters()) + self.optimizer = Adam(parameters, lr=self.config.train.lr) + + if self.scheduler is None: + self.scheduler = MultiStepLR( + self.optimizer, milestones=self.config.train.multistep_milestones, + gamma=self.config.train.multistep_lr_gamma) # type: ignore + + self.epoch, self.best_f1 = load_state( + self.model, self.optimizer, self.scheduler, load_best=False, path=checkpoint_path, device=self.device) + + self.log.info("Starting training process...") + + losses_per_epoch, accuracy_per_epoch, f1_per_epoch = load_results( + path=checkpoint_path) + + if train_rel_data.dataset["nclasses"] > self.config.train.nclasses: + self.config.train.nclasses = train_rel_data.dataset["nclasses"] + self.model.relcat_config.train.nclasses = self.config.train.nclasses + + self.config.general.labels2idx.update( + train_rel_data.dataset["labels2idx"]) + self.config.general.idx2labels = { + int(v): k for k, v in self.config.general["labels2idx"].items()} + + gradient_acc_steps = self.config.train.gradient_acc_steps + max_grad_norm = self.config.train.max_grad_norm + + _epochs = self.epoch + self.config.train.nepochs + + for epoch in range(0, _epochs): + start_time = datetime.now().time() + total_loss = 0.0 + + loss_per_batch = [] + accuracy_per_batch = [] + + self.log.info( + "Total epochs on this model: %d | currently training epoch %d" % (_epochs, epoch)) + + pbar = tqdm(total=train_dataset_size) + + for i, data in enumerate(train_dataloader, 0): + self.model.train() + self.model.zero_grad() + + current_batch_size = len(data[0]) + token_ids, e1_e2_start, labels, _, _ = data + + attention_mask = ( + token_ids != self.pad_id).float().to(self.device) + + token_type_ids = torch.zeros( + (token_ids.shape[0], token_ids.shape[1])).long().to(self.device) + + labels = labels.to(self.device) + + model_output, classification_logits = self.model( + input_ids=token_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + e1_e2_start=e1_e2_start + ) + + batch_loss = criterion( + classification_logits.view(-1, self.config.train.nclasses).to(self.device), labels.squeeze(1)) + + batch_loss.backward() + batch_loss = batch_loss / gradient_acc_steps + + total_loss += batch_loss.item() / current_batch_size + + batch_acc, _, batch_precision, batch_f1, _, _, batch_stats_per_label = self.evaluate_( + classification_logits, labels, ignore_idx=-1) + + loss_per_batch.append(batch_loss / current_batch_size) + accuracy_per_batch.append(batch_acc) + + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), max_grad_norm) + + if (i % gradient_acc_steps) == 0: + self.optimizer.step() + self.scheduler.step() + + if ((i + 1) % current_batch_size == 0): + self.log.debug( + "[Epoch: %d, loss per batch, accuracy per batch: %.3f, %.3f, average total loss %.3f , total loss %.3f]" % + (epoch, loss_per_batch[-1], accuracy_per_batch[-1], total_loss / (i + 1), total_loss)) + + pbar.update(current_batch_size) + + pbar.close() + + if len(loss_per_batch) > 0: + losses_per_epoch.append( + sum(loss_per_batch) / len(loss_per_batch)) + self.log.info("Losses at Epoch %d: %.5f" % + (epoch, losses_per_epoch[-1])) + + if len(accuracy_per_batch) > 0: + accuracy_per_epoch.append( + sum(accuracy_per_batch) / len(accuracy_per_batch)) + self.log.info("Train accuracy at Epoch %d: %.5f" % + (epoch, accuracy_per_epoch[-1])) + + total_loss = total_loss / (i + 1) + + end_time = datetime.now().time() + + self.log.info( + "======================== TRAIN SET TEST RESULTS ========================") + _ = self.evaluate_results(train_dataloader, self.pad_id) + + self.log.info( + "======================== TEST SET TEST RESULTS ========================") + results = self.evaluate_results(test_dataloader, self.pad_id) + + f1_per_epoch.append(results['f1']) + + self.log.info("Epoch finished, took " + str(datetime.combine(date.today(), + end_time) - datetime.combine(date.today(), + start_time)) + " seconds") + + self.epoch += 1 + + if len(f1_per_epoch) > 0 and f1_per_epoch[-1] > self.best_f1: + self.best_f1 = f1_per_epoch[-1] + save_state(self.model, self.optimizer, self.scheduler, self.epoch, self.best_f1, checkpoint_path, + model_name=self.config.general.model_name, task=self.task, is_checkpoint=False) + + if (epoch % 1) == 0: + save_results({"losses_per_epoch": losses_per_epoch, "accuracy_per_epoch": accuracy_per_epoch, + "f1_per_epoch": f1_per_epoch, "epoch": epoch}, file_prefix="train", path=checkpoint_path) + save_state(self.model, self.optimizer, self.scheduler, self.epoch, self.best_f1, checkpoint_path, + model_name=self.config.general.model_name, task=self.task, is_checkpoint=True) + + def evaluate_(self, output_logits, labels, ignore_idx): + # ignore index (padding) when calculating accuracy + idxs = (labels != ignore_idx).squeeze() + labels_ = labels.squeeze()[idxs].to(self.device) + pred_labels = torch.softmax(output_logits, dim=1).max(1)[1] + pred_labels = pred_labels[idxs].to(self.device) + + true_labels = labels_.cpu().numpy().tolist( + ) if labels_.is_cuda else labels_.numpy().tolist() + pred_labels = pred_labels.cpu().numpy().tolist( + ) if pred_labels.is_cuda else pred_labels.numpy().tolist() + + unique_labels = set(true_labels) + + batch_size = len(true_labels) + + stat_per_label = dict() + + total_tp, total_fp, total_tn, total_fn = 0, 0, 0, 0 + acc, micro_recall, micro_precision, micro_f1 = 0, 0, 0, 0 + + for label in unique_labels: + stat_per_label[label] = { + "tp": 0, "fp": 0, "tn": 0, "fn": 0, "f1": 0.0, "acc": 0.0, "prec": 0.0, "recall": 0.0} + for true_label_idx in range(len(true_labels)): + if true_labels[true_label_idx] == label: + if pred_labels[true_label_idx] == label: + stat_per_label[label]["tp"] += 1 + total_tp += 1 + if pred_labels[true_label_idx] != label: + stat_per_label[label]["fp"] += 1 + total_fp += 1 + elif true_labels[true_label_idx] != label and label == pred_labels[true_label_idx]: + stat_per_label[label]["fn"] += 1 + total_fn += 1 + else: + stat_per_label[label]["tn"] += 1 + total_tn += 1 + + lbl_tp_tn = stat_per_label[label]["tn"] + \ + stat_per_label[label]["tp"] + + lbl_tp_fn = stat_per_label[label]["fn"] + \ + stat_per_label[label]["tp"] + lbl_tp_fn = lbl_tp_fn if lbl_tp_fn > 0.0 else 1.0 + + lbl_tp_fp = stat_per_label[label]["tp"] + \ + stat_per_label[label]["fp"] + lbl_tp_fp = lbl_tp_fp if lbl_tp_fp > 0.0 else 1.0 + + stat_per_label[label]["acc"] = lbl_tp_tn / batch_size + stat_per_label[label]["prec"] = stat_per_label[label]["tp"] / lbl_tp_fp + stat_per_label[label]["recall"] = stat_per_label[label]["tp"] / lbl_tp_fn + + lbl_re_pr = stat_per_label[label]["recall"] + \ + stat_per_label[label]["prec"] + lbl_re_pr = lbl_re_pr if lbl_re_pr > 0.0 else 1.0 + + stat_per_label[label]["f1"] = ( + 2 * (stat_per_label[label]["recall"] * stat_per_label[label]["prec"])) / lbl_re_pr + + tp_fn = total_fn + total_tp + tp_fn = tp_fn if tp_fn > 0.0 else 1.0 + + tp_fp = total_fp + total_tp + tp_fp = tp_fp if tp_fp > 0.0 else 1.0 + + micro_recall = total_tp / tp_fn + micro_precision = total_tp / tp_fp + + re_pr = micro_recall + micro_precision + re_pr = re_pr if re_pr > 0.0 else 1.0 + micro_f1 = (2 * (micro_recall * micro_precision)) / re_pr + + acc = total_tp / batch_size + + return acc, micro_recall, micro_precision, micro_f1, pred_labels, true_labels, stat_per_label + + def evaluate_results(self, data_loader, pad_id): + self.log.info("Evaluating test samples...") + criterion = nn.CrossEntropyLoss(ignore_index=-1) + total_loss, total_acc, total_f1, total_recall, total_precision = 0.0, 0.0, 0.0, 0.0, 0.0 + all_batch_stats_per_label = [] + + self.model.eval() + + for i, data in enumerate(data_loader): + with torch.no_grad(): + token_ids, e1_e2_start, labels, _, _ = data + attention_mask = (token_ids != pad_id).float().to(self.device) + token_type_ids = torch.zeros( + (token_ids.shape[0], token_ids.shape[1])).long().to(self.device) + + labels = labels.to(self.device) + + model_output, pred_classification_logits = self.model(token_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask, Q=None, + e1_e2_start=e1_e2_start) + + batch_loss = criterion(pred_classification_logits.view( + -1, self.config.train.nclasses).to(self.device), labels.squeeze(1)) + total_loss += batch_loss.item() + + batch_accuracy, batch_recall, batch_precision, batch_f1, pred_labels, true_labels, batch_stats_per_label = \ + self.evaluate_(pred_classification_logits, + labels, ignore_idx=-1) + + all_batch_stats_per_label.append(batch_stats_per_label) + + total_acc += batch_accuracy + total_recall += batch_recall + total_precision += batch_precision + total_f1 += batch_f1 + + final_stats_per_label = {} + + for batch_label_stats in all_batch_stats_per_label: + for label_id, stat_dict in batch_label_stats.items(): + + if label_id not in final_stats_per_label.keys(): + final_stats_per_label[label_id] = stat_dict + else: + for stat, score in stat_dict.items(): + final_stats_per_label[label_id][stat] += score + + for label_id, stat_dict in final_stats_per_label.items(): + for stat_name, value in stat_dict.items(): + final_stats_per_label[label_id][stat_name] = value / (i + 1) + + total_loss = total_loss / (i + 1) + total_acc = total_acc / (i + 1) + total_precision = total_precision / (i + 1) + total_f1 = total_f1 / (i + 1) + total_recall = total_recall / (i + 1) + + results = { + "loss": total_loss, + "accuracy": total_acc, + "precision": total_precision, + "recall": total_recall, + "f1": total_f1 + } + + self.log.info("==================== Evaluation Results ===================") + self.log.info(" no. of batches:" + str(i + 1)) + for key in sorted(results.keys()): + self.log.info(" %s = %0.3f" % (key, results[key])) + self.log.info("----------------------- class stats -----------------------") + for label_id, stat_dict in final_stats_per_label.items(): + self.log.info("label: %s | f1: %0.3f | prec : %0.3f | acc: %0.3f | recall: %0.3f " % ( + self.config.general.idx2labels[label_id], + stat_dict["f1"], + stat_dict["prec"], + stat_dict["acc"], + stat_dict["recall"] + )) + self.log.info("-----------------------------------------------------------") + self.log.info("===========================================================") + + return results + + def pipe(self, stream: Iterable[Doc], *args, **kwargs) -> Iterator[Doc]: + + predict_rel_dataset = RelData( + cdb=self.cdb, config=self.config, tokenizer=self.tokenizer) + + self.model = self.model.to(self.device) # type: ignore + + for doc_id, doc in enumerate(stream, 0): + predict_rel_dataset.dataset, _ = self._create_test_train_datasets( + predict_rel_dataset.create_base_relations_from_doc(doc, str(doc_id)), False) + + predict_dataloader = DataLoader(predict_rel_dataset, shuffle=False, batch_size=self.config.train.batch_size, + num_workers=0, collate_fn=self.padding_seq, + pin_memory=self.config.general.pin_memory) + + total_rel_found = len(predict_rel_dataset.dataset["output_relations"]) + rel_idx = -1 + + self.log.info("total relations for doc: " + str(total_rel_found)) + self.log.info("processing...") + + pbar = tqdm(total=total_rel_found) + + for i, data in enumerate(predict_dataloader): + with torch.no_grad(): + token_ids, e1_e2_start, labels, _, _ = data + + attention_mask = (token_ids != self.pad_id).float() + token_type_ids = torch.zeros( + token_ids.shape[0], token_ids.shape[1]).long() + + model_output, pred_classification_logits = self.model( + token_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, + e1_e2_start=e1_e2_start) # type: ignore + + for i, pred_rel_logits in enumerate(pred_classification_logits): + rel_idx += 1 + + confidence = torch.softmax( + pred_rel_logits, dim=0).max(0) + predicted_label_id = confidence[1].item() + + doc._.relations.append({"relation": self.config.general.idx2labels[predicted_label_id], + "label_id": predicted_label_id, + "ent1_text": predict_rel_dataset.dataset["output_relations"][rel_idx][ + 2], + "ent2_text": predict_rel_dataset.dataset["output_relations"][rel_idx][ + 3], + "confidence": float("{:.3f}".format(confidence[0])), + "start_ent_pos": "", + "end_ent_pos": "", + "start_entity_id": + predict_rel_dataset.dataset["output_relations"][rel_idx][8], + "end_entity_id": + predict_rel_dataset.dataset["output_relations"][rel_idx][9]}) + pbar.update(len(token_ids)) + pbar.close() + + yield doc + + def __call__(self, doc: Doc) -> Doc: + doc = next(self.pipe(iter([doc]))) + return doc diff --git a/medcat/utils/relation_extraction/__init__.py b/medcat/utils/relation_extraction/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/medcat/utils/relation_extraction/models.py b/medcat/utils/relation_extraction/models.py new file mode 100644 index 000000000..d0003a1c5 --- /dev/null +++ b/medcat/utils/relation_extraction/models.py @@ -0,0 +1,219 @@ +import logging +from typing import Any, List, Optional, Tuple +import torch +from torch import nn +from transformers.models.bert.modeling_bert import BertPreTrainingHeads, BertModel +from transformers.models.bert.configuration_bert import BertConfig +from medcat.config_rel_cat import ConfigRelCAT + + +class BertModel_RelationExtraction(nn.Module): + """ BertModel class for RelCAT + """ + + name = "bertmodel_relcat" + + log = logging.getLogger(__name__) + + def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: BertConfig): + """ Class to hold the BERT model + model_config + + Args: + pretrained_model_name_or_path (str): path to load the model from, + this can be a HF model i.e: "bert-base-uncased", if left empty, it is normally assumed that a model is loaded from 'model.dat' + using the RelCAT.load() method. So if you are initializing/training a model from scratch be sure to base it on some model. + relcat_config (ConfigRelCAT): relcat config. + model_config (BertConfig): HF bert config for model. + """ + super(BertModel_RelationExtraction, self).__init__() + + self.relcat_config: ConfigRelCAT = relcat_config + self.model_config: BertConfig = model_config + + self.bert_model:BertModel = BertModel(config=model_config) + + if pretrained_model_name_or_path != "": + self.bert_model = BertModel.from_pretrained(pretrained_model_name_or_path, config=model_config) + + for param in self.bert_model.parameters(): + param.requires_grad = False + + self.drop_out = nn.Dropout(self.model_config.hidden_dropout_prob) + + if self.relcat_config.general.task == "pretrain": + self.activation = nn.Tanh() + self.cls = BertPreTrainingHeads(self.model_config) + + self.relu = nn.ReLU() + + # dense layers + self.fc1 = nn.Linear(self.relcat_config.model.model_size, self.relcat_config.model.hidden_size) + self.fc2 = nn.Linear(self.relcat_config.model.hidden_size, int(self.relcat_config.model.hidden_size / 2)) + self.fc3 = nn.Linear(int(self.relcat_config.model.hidden_size / 2), self.relcat_config.train.nclasses) + + self.log.info("RelCAT BertConfig: " + str(self.model_config)) + + def get_annotation_schema_tag(self, sequence_output: torch.Tensor, input_ids: torch.Tensor, special_tag: List) -> torch.Tensor: + """ Gets to token sequences from the sequence_ouput for the specific token + tag ids in self.relcat_config.general.annotation_schema_tag_ids. + + Args: + sequence_output (torch.Tensor): hidden states/embeddings for each token in the input text + input_ids (torch.Tensor): input token ids + special_tag (List): special annotation token id pairs + + Returns: + torch.Tensor: new seq_tags + """ + + idx_start = torch.where(input_ids == special_tag[0]) # returns: row ids, idx of token[0]/star token in row + idx_end = torch.where(input_ids == special_tag[1]) # returns: row ids, idx of token[1]/end token in row + + seen = [] # List to store seen elements and their indices + duplicate_indices = [] + + for i in range(len(idx_start[0])): + if idx_start[0][i] in seen: + duplicate_indices.append(i) + else: + seen.append(idx_start[0][i]) + + if len(duplicate_indices) > 0: + self.log.info("Duplicate entities found, removing them...") + for idx_remove in duplicate_indices: + idx_start_0 = torch.cat((idx_start[0][:idx_remove], idx_start[0][idx_remove + 1:])) + idx_start_1 = torch.cat((idx_start[1][:idx_remove], idx_start[1][idx_remove + 1:])) + idx_start = (idx_start_0, idx_start_1) # type: ignore + + seen = [] + duplicate_indices = [] + + for i in range(len(idx_end[0])): + if idx_end[0][i] in seen: + duplicate_indices.append(i) + else: + seen.append(idx_end[0][i]) + + if len(duplicate_indices) > 0: + self.log.info("Duplicate entities found, removing them...") + for idx_remove in duplicate_indices: + idx_end_0 = torch.cat((idx_end[0][:idx_remove], idx_end[0][idx_remove + 1:])) + idx_end_1 = torch.cat((idx_end[1][:idx_remove], idx_end[1][idx_remove + 1:])) + idx_end = (idx_end_0, idx_end_1) # type: ignore + + assert len(idx_start[0]) == input_ids.shape[0] + assert len(idx_start[0]) == len(idx_end[0]) + sequence_output_entities = [] + + for i in range(len(idx_start[0])): + to_append = sequence_output[i, idx_start[1][i] + 1:idx_end[1][i], ] + + # to_append = torch.sum(to_append, dim=0) + to_append, _ = torch.max(to_append, axis=0) # type: ignore + + sequence_output_entities.append(to_append) + sequence_output_entities = torch.stack(sequence_output_entities) + + return sequence_output_entities + + def output2logits(self, pooled_output: torch.Tensor, sequence_output: torch.Tensor, input_ids: torch.Tensor, e1_e2_start: torch.Tensor) -> torch.Tensor: + """ + + Args: + pooled_output (torch.Tensor): embedding of the CLS token + sequence_output (torch.Tensor): hidden states/embeddings for each token in the input text + input_ids (torch.Tensor): input token ids. + e1_e2_start (torch.Tensor): annotation tags token position + + Returns: + torch.Tensor: classification probabilities for each token. + """ + + new_pooled_output = pooled_output + + if self.relcat_config.general.annotation_schema_tag_ids: + annotation_schema_tag_ids_ = [self.relcat_config.general.annotation_schema_tag_ids[i:i + 2] for i in + range(0, len(self.relcat_config.general.annotation_schema_tag_ids), 2)] + seq_tags = [] + + # for each pair of tags (e1,s1) and (e2,s2) + for each_tags in annotation_schema_tag_ids_: + seq_tags.append(self.get_annotation_schema_tag( + sequence_output, input_ids, each_tags)) + + seq_tags = torch.stack(seq_tags, dim=0) + + new_pooled_output = torch.cat((pooled_output, *seq_tags), dim=1) + else: + e1e2_output = [] + temp_e1 = [] + temp_e2 = [] + + for i, seq in enumerate(sequence_output): + # e1e2 token sequences + temp_e1.append(seq[e1_e2_start[i][0]]) + temp_e2.append(seq[e1_e2_start[i][1]]) + + e1e2_output.append(torch.stack(temp_e1, dim=0)) + e1e2_output.append(torch.stack(temp_e2, dim=0)) + + new_pooled_output = torch.cat((pooled_output, *e1e2_output), dim=1) + + del e1e2_output + del temp_e2 + del temp_e1 + + x = self.drop_out(new_pooled_output) + x = self.fc1(x) + x = self.drop_out(x) + x = self.fc2(x) + classification_logits = self.fc3(x) + return classification_logits.to(self.relcat_config.general.device) + + def forward(self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Any = None, + head_mask: Any = None, + encoder_hidden_states: Any = None, + encoder_attention_mask: Any = None, + Q: Any = None, + e1_e2_start: Any = None, + pooled_output: Any = None) -> Tuple[torch.Tensor, torch.Tensor]: + + if input_ids is not None: + input_shape = input_ids.size() + else: + raise ValueError("You have to specify input_ids") + + if attention_mask is None: + attention_mask = torch.ones( + input_shape, device=self.relcat_config.general.device) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + input_shape, device=self.relcat_config.general.device) + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=self.relcat_config.general.device) + + input_ids = input_ids.to(self.relcat_config.general.device) + attention_mask = attention_mask.to(self.relcat_config.general.device) + encoder_attention_mask = encoder_attention_mask.to( + self.relcat_config.general.device) + + self.bert_model = self.bert_model.to(self.relcat_config.general.device) + + model_output = self.bert_model(input_ids=input_ids, attention_mask=attention_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask) + + # (batch_size, sequence_length, hidden_size) + sequence_output = model_output[0] + pooled_output = model_output[1] + + classification_logits = self.output2logits( + pooled_output, sequence_output, input_ids, e1_e2_start) + + return model_output, classification_logits.to(self.relcat_config.general.device) diff --git a/medcat/utils/relation_extraction/pad_seq.py b/medcat/utils/relation_extraction/pad_seq.py new file mode 100644 index 000000000..b23738f15 --- /dev/null +++ b/medcat/utils/relation_extraction/pad_seq.py @@ -0,0 +1,53 @@ +from typing import List, Tuple +import torch +from torch import Tensor, LongTensor +from torch.nn.utils.rnn import pad_sequence + + +class Pad_Sequence(): + + def __init__(self, seq_pad_value: int, label_pad_value: int = -1): + """ Used in rel_cat.py in RelCAT to create DataLoaders for train/test datasets. + collate_fn for dataloader to collate sequences of different input_ids, ent1/ent2, and label + lengths into a fixed length batch. + This is applied per batch and not on the whole DataLoader data, + padded x sequence, y sequence, x lengths and y lengths of batch. + + Args: + seq_pad_value (int): pad value for input_ids. + label_pad_value (int): pad value for labels. Defaults to -1. + """ + self.seq_pad_value: int = seq_pad_value + self.label_pad_value: int = label_pad_value + + def __call__(self, batch: List[torch.Tensor]) -> Tuple[Tensor, List, Tensor, LongTensor, LongTensor]: + """ Pads a batch of input_ids. + + Args: + batch (List[torch.Tensor]): gets the batch of Tensors from + RelData.dataset (check __getitem__() method for data returned) + and pads the token sequence + labels as needed + See https://pytorch.org/docs/stable/_modules/torch/nn/utils/rnn.html#pad_sequence + for extra info. + + Returns: + Tuple[Tensor, Tensor, Tensor, LongTensor, LongTensor]: padded data + padded input ids, ent1&ent2 start token pos, padded labels, padded input_id_lengths, padded labels length + """ + sorted_batch = sorted(batch, key=lambda x: x[0].shape[0], reverse=True) + + # input ids + seqs = [x[0] for x in sorted_batch] + seqs_padded = pad_sequence( + seqs, batch_first=True, padding_value=self.seq_pad_value) + x_lengths = torch.LongTensor([len(x) for x in seqs]) + + # label_ids + labels = list(map(lambda x: x[2], sorted_batch)) + labels_padded = pad_sequence( + labels, batch_first=True, padding_value=self.label_pad_value) + y_lengths = torch.LongTensor([len(x) for x in labels]) + + ent1_ent2_start_pos = list(map(lambda x: x[1], sorted_batch)) + + return seqs_padded, ent1_ent2_start_pos, labels_padded, x_lengths, y_lengths diff --git a/medcat/utils/relation_extraction/rel_dataset.py b/medcat/utils/relation_extraction/rel_dataset.py new file mode 100644 index 000000000..714cccd30 --- /dev/null +++ b/medcat/utils/relation_extraction/rel_dataset.py @@ -0,0 +1,687 @@ +from ast import literal_eval +from typing import Any, Iterable, List, Dict, Tuple, Union +from torch.utils.data import Dataset +from spacy.tokens import Doc +import logging +import pandas +import random +import torch +from medcat.cdb import CDB +from medcat.config_rel_cat import ConfigRelCAT +from medcat.utils.meta_cat.data_utils import Span +from medcat.utils.relation_extraction.tokenizer import TokenizerWrapperBERT + + +class RelData(Dataset): + + name = "rel_dataset" + + log = logging.getLogger(__name__) + + def __init__(self, tokenizer: TokenizerWrapperBERT, config: ConfigRelCAT, cdb: CDB = CDB()): + """ Use this class to create a dataset for relation annotations from CSV exports, + MedCAT exports or Spacy Documents (assuming the documents got generated by MedCAT, + if they did not then please set the required paramenters manually to match MedCAT output, + see /medcat/cat.py#_add_nested_ent) + + If you are using this to create relations from CSV it is assumed that your entities/concepts of + interest are surrounded by the special tokens, see create_base_relations_from_csv doc. + + Args: + tokenizer (TokenizerWrapperBERT): okenizer used to generate token ids from input text + config (ConfigRelCAT): same config used in RelCAT + cdb (CDB): Optional, used to add concept ids and types to detected ents, + useful when creating datasets from MedCAT output. Defaults to CDB(). + """ + + self.cdb: CDB = cdb + self.config: ConfigRelCAT = config + self.tokenizer: TokenizerWrapperBERT = tokenizer + self.dataset: Dict[Any, Any] = {} + + self.log.setLevel(self.config.general.log_level) + + def generate_base_relations(self, docs: Iterable[Doc]) -> List[Dict]: + """ Util function, should be used if you want to train from spacy docs + + Args: + docs (Iterable[Doc]): Generate relations from Spacy CAT docs. + + Returns: + output_relations: List[Dict] : [] + "output_relations": relation_instances, <-- see create_base_relations_from_doc/csv + for data columns + "nclasses": self.config.model.padding_idx, <-- dummy class + "labels2idx": {}, + "idx2label": {}} + ] + """ + + output_relations = [] + for doc_id, doc in enumerate(docs): + output_relations.append( + self.create_base_relations_from_doc(doc, doc_id=str(doc_id),)) + + return output_relations + + def create_base_relations_from_csv(self, csv_path: str): + """ + Assumes the columns are as follows ["relation_token_span_ids", "ent1_ent2_start", "ent1", "ent2", "label", + "label_id", "ent1_type", "ent2_type", "ent1_id", "ent2_id", "ent1_cui", "ent2_cui", "doc_id", "sents"], + last column is the actual source text. + + The entities inside the text MUST be annotated with special tokens i.e: + ...some text..[s1] first entity [e1].....[s2] second entity [e2]........ + You have to store the start position, aka index position of token [e1] and also of token [e2] in + the (ent1_ent2_start) column. + + Args: + csv_path (str): path to csv file, must have specific columns, tab separated, + + Returns: + Dict : { + "output_relations": relation_instances, <-- see create_base_relations_from_doc/csv + for data columns + "nclasses": self.config.model.padding_idx, <-- dummy class + "labels2idx": {}, + "idx2label": {}} + } + """ + + df = pandas.read_csv(csv_path, index_col=False, + encoding='utf-8', sep="\t") + + tmp_col_rel_token_col = df.pop("relation_token_span_ids") + + df.insert(0, "relation_token_span_ids", tmp_col_rel_token_col) + + text_cols = ["sents", "text"] + + df["ent1_ent2_start"] = df["ent1_ent2_start"].apply( + lambda x: literal_eval(str(x))) + + for col in text_cols: + if col in df.columns: + out_rels = [] + for row_idx in range(len(df[col])): + _text = df.iloc[row_idx][col] + _ent1_ent2_start = df.iloc[row_idx]["ent1_ent2_start"] + _rels = self.create_base_relations_from_doc( + _text, doc_id=str(row_idx), ent1_ent2_tokens_start_pos=_ent1_ent2_start,) + out_rels.append(_rels) + + rows_to_remove = [] + for row_idx in range(len(out_rels)): + if len(out_rels[row_idx]["output_relations"]) < 1: + rows_to_remove.append(row_idx) + + relation_token_span_ids = [] + out_ent1_ent2_starts = [] + + for rel in out_rels: + if len(rel["output_relations"]) > 0: + relation_token_span_ids.append( + rel["output_relations"][0][0]) + out_ent1_ent2_starts.append( + rel["output_relations"][0][1]) + else: + relation_token_span_ids.append([]) + out_ent1_ent2_starts.append([]) + + df["label"] = [i.strip() for i in df["label"]] + + df["relation_token_span_ids"] = relation_token_span_ids + df["ent1_ent2_start"] = out_ent1_ent2_starts + + df = df.drop(index=rows_to_remove) + df = df.drop(columns=col) + break + + nclasses, labels2idx, idx2label = RelData.get_labels( + df["label"], self.config) + + output_relations = df.values.tolist() + + self.log.info("CSV dataset | No. of relations detected:" + str(len(output_relations)) + + "| from : " + csv_path + " | nclasses: " + str(nclasses) + " | idx2label: " + str(idx2label)) + + self.log.info("Samples per class: ") + for label_num in list(idx2label.keys()): + sample_count = 0 + for output_relation in output_relations: + if label_num == output_relation[5]: + sample_count += 1 + self.log.info( + " label: " + idx2label[label_num] + " | samples: " + str(sample_count)) + + # replace/update label_id with actual detected label number + for idx in range(len(output_relations)): + output_relations[idx][5] = labels2idx[output_relations[idx][4]] + + return {"output_relations": output_relations, "nclasses": nclasses, "labels2idx": labels2idx, "idx2label": idx2label} + + def create_base_relations_from_doc(self, doc: Union[Doc, str], doc_id: str, ent1_ent2_tokens_start_pos: Union[List, Tuple] = (-1, -1)) -> Dict: + """ Creates a list of tuples based on pairs of entities detected (relation, ent1, ent2) for one spacy document or text string. + + Args: + doc (Union[Doc, str]): SpacyDoc or string of text, each will get handled slightly differently + doc_id (str): document id + ent1_ent2_tokens_start_pos (Union[List, Tuple], optional): start of [s1][s2] tokens, if left default + we assume we are dealing with a SpacyDoc. Defaults to (-1, -1). + + Returns: + Dict : { + "output_relations": relation_instances, <-- see create_base_relations_from_doc/csv + for data columns + "nclasses": self.config.model.padding_idx, <-- dummy class + "labels2idx": {}, + "idx2label": {}} + } + """ + relation_instances = [] + + chars_to_exclude = ":!@#$%^&*()-+?_=.,;<>/[]{}" + tokenizer_data = None + + if isinstance(doc, str): + tokenizer_data = self.tokenizer(doc, truncation=False) + doc_text = doc + elif isinstance(doc, Doc): + tokenizer_data = self.tokenizer(doc.text, truncation=False) + doc_text = doc.text + + doc_length = len(tokenizer_data["tokens"]) + + if ent1_ent2_tokens_start_pos != (-1, -1): + ent1_token_start_pos, ent2_token_start_pos = ent1_ent2_tokens_start_pos[0],\ + ent1_ent2_tokens_start_pos[1] + # add + 1 to the pos cause of [CLS] + if self.config.general.annotation_schema_tag_ids: + ent1_token_start_pos, ent2_token_start_pos = ent1_ent2_tokens_start_pos[0] + 1,\ + ent1_ent2_tokens_start_pos[1] + 1 + + ent1_start_char_pos, _ = tokenizer_data["offset_mapping"][ent1_token_start_pos] + ent2_start_char_pos, _ = tokenizer_data["offset_mapping"][ent2_token_start_pos] + + if abs(ent2_start_char_pos - ent1_start_char_pos) <= self.config.general.window_size: + + ent1_left_ent_context_token_pos_end = ent1_token_start_pos - \ + self.config.general.cntx_left + + left_context_start_char_pos = 0 + right_context_start_end_pos = len(doc_text) - 1 + + if ent1_left_ent_context_token_pos_end < 0: + ent1_left_ent_context_token_pos_end = 0 + else: + left_context_start_char_pos = tokenizer_data[ + "offset_mapping"][ent1_left_ent_context_token_pos_end][0] + + ent2_right_ent_context_token_pos_end = ent2_token_start_pos + \ + self.config.general.cntx_right + + # get end of 2nd ent token (if using tags) + if self.config.general.annotation_schema_tag_ids: + far_pos = -1 + for tkn_id in self.config.general.annotation_schema_tag_ids: + pos = [i for i in range( + 0, doc_length) if tokenizer_data["input_ids"][i] == tkn_id][0] + far_pos = pos if far_pos < pos else far_pos + ent2_right_ent_context_token_pos_end = far_pos + + if ent2_right_ent_context_token_pos_end >= doc_length - 1: + ent2_right_ent_context_token_pos_end = doc_length - 2 + else: + right_context_start_end_pos = tokenizer_data[ + "offset_mapping"][ent2_right_ent_context_token_pos_end][1] + + ent1_token = tokenizer_data["tokens"][ent1_token_start_pos] + ent2_token = tokenizer_data["tokens"][ent2_token_start_pos] + + window_tokenizer_data = self.tokenizer( + doc_text[left_context_start_char_pos:right_context_start_end_pos]) + + # update token loc to match new selection + if self.config.general.annotation_schema_tag_ids: + ent1_token_start_pos = \ + window_tokenizer_data["input_ids"].index( + self.config.general.annotation_schema_tag_ids[0]) + ent2_token_start_pos = \ + window_tokenizer_data["input_ids"].index( + self.config.general.annotation_schema_tag_ids[2]) + else: + ent2_token_start_pos = ent2_token_start_pos - ent1_token_start_pos + ent1_token_start_pos = self.config.general.cntx_left if ent1_token_start_pos - \ + self.config.general.cntx_left > 0 else ent1_token_start_pos + ent2_token_start_pos += ent1_token_start_pos + + ent1_ent2_new_start = ( + ent1_token_start_pos, ent2_token_start_pos) + + en1_start, en1_end = window_tokenizer_data["offset_mapping"][ent1_token_start_pos] + en2_start, en2_end = window_tokenizer_data["offset_mapping"][ent2_token_start_pos] + + relation_instances.append([window_tokenizer_data["input_ids"], ent1_ent2_new_start, ent1_token, ent2_token, "UNK", self.config.model.padding_idx, + None, None, None, None, None, None, doc_id, "", + en1_start, en1_end, en2_start, en2_end]) + + elif isinstance(doc, Doc): + + _ents = doc.ents if len(doc.ents) > 0 else doc._.ents + for ent1_idx in range(0, len(_ents) - 1): + + ent1_token: Span = _ents[ent1_idx] # type: ignore + + if str(ent1_token) not in chars_to_exclude: + ent1_type_id = list( + self.cdb.cui2type_ids.get(ent1_token._.cui, '')) + ent1_types = [self.cdb.addl_info['type_id2name'].get( + tui, '') for tui in ent1_type_id] + + ent2pos = ent1_idx + 1 + + ent1_start = ent1_token.start + ent1_end = ent1_token.end + + # get actual token index from the text + _ent1_token_idx = [i for i in range(len(tokenizer_data["offset_mapping"])) if ent1_start in + range( + tokenizer_data["offset_mapping"][i][0], tokenizer_data["offset_mapping"][i][1] + 1) + or ent1_end in range(tokenizer_data["offset_mapping"][i][0], tokenizer_data["offset_mapping"][i][1] + 1) + ][0] + + left_context_start_char_pos = 0 + ent1_left_ent_context_token_pos_end = _ent1_token_idx - self.config.general.cntx_left + + if ent1_left_ent_context_token_pos_end < 0: + ent1_left_ent_context_token_pos_end = 0 + else: + left_context_start_char_pos = tokenizer_data[ + "offset_mapping"][ent1_left_ent_context_token_pos_end][0] + + for ent2_idx in range(ent2pos, len(_ents)): + ent2_token: Span = _ents[ent2_idx] # type: ignore + + if ent2_token in _ents: + if str(ent2_token) not in chars_to_exclude and str(ent1_token) != str(ent2_token): + ent2_type_id = list( + self.cdb.cui2type_ids.get(ent2_token._.cui, '')) + ent2_types = [self.cdb.addl_info['type_id2name'].get( + tui, '') for tui in ent2_type_id] + + ent2_start = ent2_token.start + ent2_end = ent2_token.end + if ent2_start - ent1_start <= self.config.general.window_size and ent2_start - ent1_start > 0: + _ent2_token_idx = [i for i in range(len(tokenizer_data["offset_mapping"])) if ent2_start in + range( + tokenizer_data["offset_mapping"][i][0], tokenizer_data["offset_mapping"][i][1] + 1) + or ent2_end in + range( + tokenizer_data["offset_mapping"][i][0], tokenizer_data["offset_mapping"][i][1] + 1) + ][0] + + right_context_start_end_pos = len( + doc_text) - 1 + ent2_right_ent_context_token_pos_end = _ent2_token_idx + \ + self.config.general.cntx_right + + if ent2_right_ent_context_token_pos_end >= doc_length - 1: + ent2_right_ent_context_token_pos_end = doc_length - 2 + else: + right_context_start_end_pos = tokenizer_data[ + "offset_mapping"][ent2_right_ent_context_token_pos_end][1] + + tmp_doc_text = doc_text + + # check if a tag is present, and if not so then insert the custom annotation tags in + if self.config.general.annotation_schema_tag_ids[0] not in tokenizer_data["input_ids"]: + _pre_e1 = tmp_doc_text[0: (ent1_start)] + _e1_s2 = tmp_doc_text[( + ent1_end): (ent2_start)] + _e2_end = tmp_doc_text[( + ent2_end): len(doc_text)] + _ent2_token_idx = (_ent2_token_idx + 2) + + annotation_token_text = self.tokenizer.hf_tokenizers.convert_ids_to_tokens( + self.config.general.annotation_schema_tag_ids) + + tmp_doc_text = _pre_e1 + " " + \ + annotation_token_text[0] + " " + \ + str(ent1_token) + " " + \ + annotation_token_text[1] + " " + _e1_s2 + " " + \ + annotation_token_text[2] + " " + str(ent2_token) + " " + \ + annotation_token_text[3] + \ + " " + _e2_end + + ann_tag_token_len = len( + annotation_token_text[0]) + + _left_context_start_char_pos = left_context_start_char_pos - ann_tag_token_len + left_context_start_char_pos = 0 if _left_context_start_char_pos <= 0 \ + else _left_context_start_char_pos + + right_context_start_end_pos = right_context_start_end_pos if right_context_start_end_pos >= len(tmp_doc_text) \ + else right_context_start_end_pos + (ann_tag_token_len * 4) + + window_tokenizer_data = self.tokenizer( + tmp_doc_text[left_context_start_char_pos:right_context_start_end_pos]) + + if self.config.general.annotation_schema_tag_ids: + ent1_token_start_pos = \ + window_tokenizer_data["input_ids"].index( + self.config.general.annotation_schema_tag_ids[0]) + ent2_token_start_pos = \ + window_tokenizer_data["input_ids"].index( + self.config.general.annotation_schema_tag_ids[2]) + else: + ent2_token_start_pos = _ent2_token_idx - _ent1_token_idx if _ent1_token_idx - \ + self.config.general.cntx_left > 0 else _ent2_token_idx + ent1_token_start_pos = self.config.general.cntx_left if _ent1_token_idx - \ + self.config.general.cntx_left > 0 else _ent1_token_idx + ent2_token_start_pos += ent1_token_start_pos + + ent1_ent2_new_start = ( + ent1_token_start_pos, ent2_token_start_pos) + + en1_start, en1_end = window_tokenizer_data[ + "offset_mapping"][ent1_token_start_pos] + en2_start, en2_end = window_tokenizer_data[ + "offset_mapping"][ent2_token_start_pos] + + relation_instances.append([window_tokenizer_data["input_ids"], ent1_ent2_new_start, ent1_token, ent2_token, "UNK", self.config.model.padding_idx, + ent1_types, ent2_types, ent1_token._.id, ent2_token._.id, ent1_token._.cui, ent2_token._.cui, doc_id, "", + en1_start, en1_end, en2_start, en2_end]) + + return {"output_relations": relation_instances, "nclasses": self.config.model.padding_idx, "labels2idx": {}, "idx2label": {}} + + def create_relations_from_export(self, data: Dict): + """ + Args: + data (Dict): + MedCAT Export data. + + Returns: + Dict : { + "output_relations": relation_instances, <-- see create_base_relations_from_doc/csv + for data columns + "nclasses": self.config.model.padding_idx, <-- dummy class + "labels2idx": {}, + "idx2label": {}} + } + """ + + output_relations = [] + + relation_type_filter_pairs = self.config.general.relation_type_filter_pairs + + annotation_token_text = self.tokenizer.hf_tokenizers.convert_ids_to_tokens( + self.config.general.annotation_schema_tag_ids) + + for project in data['projects']: + for doc_id, document in enumerate(project['documents']): + text = str(document['text']) + if len(text) > 0: + annotations = document['annotations'] + relations = document['relations'] + + if self.config.general.lowercase: + text = text.lower() + + tokenizer_data = self.tokenizer(text, truncation=False) + + doc_length_tokens = len(tokenizer_data["tokens"]) + + relation_instances = [] + ann_ids_from_reliations = [] + + ann_ids_ents: Dict[Any, Any] = {} + + _other_rel_subset = [] + + for ent1_idx, ent1_ann in enumerate(annotations): + ann_id = ent1_ann['id'] + ann_ids_ents[ann_id] = {} + ann_ids_ents[ann_id]['cui'] = ent1_ann['cui'] + ann_ids_ents[ann_id]['type_ids'] = list( + self.cdb.cui2type_ids.get(ent1_ann['cui'], '')) + ann_ids_ents[ann_id]['types'] = [self.cdb.addl_info['type_id2name'].get( + tui, '') for tui in ann_ids_ents[ann_id]['type_ids']] + + if self.config.general.mct_export_create_addl_rels: + + for _, ent2_ann in enumerate(annotations[ent1_idx + 1:]): + if abs(ent1_ann["start"] - ent2_ann["start"]) <= self.config.general.window_size: + if ent1_ann["validated"] and ent2_ann["validated"]: + _other_rel_subset.append({ + "start_entity": ent1_ann["id"], + "start_entity_cui": ent1_ann["cui"], + "start_entity_value": ent1_ann["value"], + "start_entity_start_idx": ent1_ann["start"], + "start_entity_end_idx": ent1_ann["end"], + "end_entity": ent2_ann["id"], + "end_entity_cui": ent2_ann["cui"], + "end_entity_value": ent2_ann["value"], + "end_entity_start_idx": ent2_ann["start"], + "end_entity_end_idx": ent2_ann["end"], + "relation": "Other", + "validated": True + }) + + non_rel_sample_size_limit = int(int( + self.config.general.mct_export_max_non_rel_sample_size) / len(data['projects'])) + + if non_rel_sample_size_limit > 0 and len(_other_rel_subset) > 0: + random.shuffle(_other_rel_subset) + _other_rel_subset = _other_rel_subset[0:non_rel_sample_size_limit] + + relations.extend(_other_rel_subset) + + for relation in relations: + ann_start_start_pos = relation['start_entity_start_idx'] + ann_start_end_pos = relation["start_entity_end_idx"] + + ann_end_start_pos = relation['end_entity_start_idx'] + ann_end_end_pos = relation["end_entity_end_idx"] + + start_entity_value = relation['start_entity_value'] + end_entity_value = relation['end_entity_value'] + + start_entity_id = relation['start_entity'] + end_entity_id = relation['end_entity'] + + start_entity_types = ann_ids_ents[start_entity_id]['types'] + end_entity_types = ann_ids_ents[end_entity_id]['types'] + start_entity_cui = ann_ids_ents[start_entity_id]['cui'] + end_entity_cui = ann_ids_ents[end_entity_id]['cui'] + + # if somehow the annotations belong to the same relation but make sense in reverse + if ann_start_start_pos > ann_end_start_pos: + ann_end_start_pos = relation['start_entity_start_idx'] + ann_end_end_pos = relation['start_entity_end_idx'] + + ann_start_start_pos = relation['end_entity_start_idx'] + ann_start_end_pos = relation['end_entity_end_idx'] + + end_entity_value = relation['start_entity_value'] + start_entity_value = relation['end_entity_value'] + + end_entity_cui = ann_ids_ents[start_entity_id]['cui'] + start_entity_cui = ann_ids_ents[end_entity_id]['cui'] + + end_entity_types = ann_ids_ents[start_entity_id]['types'] + start_entity_types = ann_ids_ents[end_entity_id]['types'] + + # switch ids last + start_entity_id = relation['end_entity'] + end_entity_id = relation['start_entity'] + + for ent1type, ent2type in enumerate(relation_type_filter_pairs): + if ent1type not in start_entity_types and ent2type not in end_entity_types: + continue + + ann_ids_from_reliations.extend( + [start_entity_id, end_entity_id]) + + relation_label = relation['relation'].strip() + + if start_entity_id != end_entity_id and relation.get('validated', True): + if abs(ann_start_start_pos - ann_end_start_pos) <= self.config.general.window_size: + + ent1_token_start_pos = [i for i in range(0, doc_length_tokens) if ann_start_start_pos + in range(tokenizer_data["offset_mapping"][i][0], tokenizer_data["offset_mapping"][i][1] + 1)][0] + + ent2_token_start_pos = [i for i in range(0, doc_length_tokens) if ann_end_start_pos + in range(tokenizer_data["offset_mapping"][i][0], tokenizer_data["offset_mapping"][i][1] + 1)][0] + + ent1_left_ent_context_token_pos_end = ent1_token_start_pos - \ + self.config.general.cntx_left + + left_context_start_char_pos = 0 + right_context_start_end_pos = len(text) - 1 + + if ent1_left_ent_context_token_pos_end < 0: + ent1_left_ent_context_token_pos_end = 0 + else: + left_context_start_char_pos = tokenizer_data[ + "offset_mapping"][ent1_left_ent_context_token_pos_end][0] + + ent2_right_ent_context_token_pos_end = ent2_token_start_pos + \ + self.config.general.cntx_right + if ent2_right_ent_context_token_pos_end >= doc_length_tokens - 1: + ent2_right_ent_context_token_pos_end = doc_length_tokens - 2 + else: + right_context_start_end_pos = tokenizer_data[ + "offset_mapping"][ent2_right_ent_context_token_pos_end][1] + + tmp_text = text + # check if a tag is present, and if not so then insert the custom annotation tags in + if self.config.general.annotation_schema_tag_ids[0] not in tokenizer_data["input_ids"]: + _pre_e1 = text[0: (ann_start_start_pos)] + _e1_s2 = text[(ann_start_end_pos): ( + ann_end_start_pos)] + _e2_end = text[( + ann_end_end_pos): len(text)] + + tmp_text = _pre_e1 + " " + \ + annotation_token_text[0] + " " + \ + text[ann_start_start_pos:ann_start_end_pos] + " " + \ + annotation_token_text[1] + " " + \ + _e1_s2 + " " + \ + annotation_token_text[2] + " " + text[ann_end_start_pos:ann_end_end_pos] + \ + " " + \ + annotation_token_text[3] + \ + " " + _e2_end + + ann_tag_token_len = len( + annotation_token_text[0]) + + _left_context_start_char_pos = left_context_start_char_pos - ann_tag_token_len - 2 + left_context_start_char_pos = 0 if _left_context_start_char_pos <= 0 \ + else _left_context_start_char_pos + + _right_context_start_end_pos = right_context_start_end_pos + \ + (ann_tag_token_len * 4) + \ + 8 # 8 for spces + right_context_start_end_pos = len(tmp_text) if right_context_start_end_pos >= len(tmp_text) or _right_context_start_end_pos >= len(tmp_text) \ + else _right_context_start_end_pos + + window_tokenizer_data = self.tokenizer( + tmp_text[left_context_start_char_pos:right_context_start_end_pos]) + + if self.config.general.annotation_schema_tag_ids: + ent1_token_start_pos = \ + window_tokenizer_data["input_ids"].index( + self.config.general.annotation_schema_tag_ids[0]) + ent2_token_start_pos = \ + window_tokenizer_data["input_ids"].index( + self.config.general.annotation_schema_tag_ids[2]) + else: + # update token loc to match new selection + ent2_token_start_pos = ent2_token_start_pos - ent1_token_start_pos + ent1_token_start_pos = self.config.general.cntx_left if ent1_token_start_pos - \ + self.config.general.cntx_left > 0 else ent1_token_start_pos + ent2_token_start_pos += ent1_token_start_pos + + ent1_ent2_new_start = ( + ent1_token_start_pos, ent2_token_start_pos) + en1_start, en1_end = window_tokenizer_data[ + "offset_mapping"][ent1_token_start_pos] + en2_start, en2_end = window_tokenizer_data[ + "offset_mapping"][ent2_token_start_pos] + + relation_instances.append([window_tokenizer_data["input_ids"], ent1_ent2_new_start, start_entity_value, end_entity_value, relation_label, self.config.model.padding_idx, + start_entity_types, end_entity_types, start_entity_id, end_entity_id, start_entity_cui, end_entity_cui, doc_id, "", + en1_start, en1_end, en2_start, en2_end]) + + output_relations.extend(relation_instances) + + all_relation_labels = [relation[4] for relation in output_relations] + + nclasses, labels2idx, idx2label = self.get_labels( + all_relation_labels, self.config) + + # replace label_id with actual detected label number + for idx in range(len(output_relations)): + output_relations[idx][5] = labels2idx[output_relations[idx][4]] + + self.log.info("MCT export dataset | nclasses: " + + str(nclasses) + " | idx2label: " + str(idx2label)) + self.log.info("Samples per class: ") + for label_num in list(idx2label.keys()): + sample_count = 0 + for output_relation in output_relations: + if int(label_num) == int(output_relation[5]): + sample_count += 1 + self.log.info( + " label: " + idx2label[label_num] + " | samples: " + str(sample_count)) + + return {"output_relations": output_relations, "nclasses": nclasses, "labels2idx": labels2idx, "idx2label": idx2label} + + @classmethod + def get_labels(cls, relation_labels: List[str], config: ConfigRelCAT) -> Tuple[int, Dict[str, Any], Dict[int, Any]]: + """ This is used to update labels in config with unencountered classes/labels ( if any are encountered during training). + + Args: + relation_labels (List[str]): new labels to add + config (ConfigRelCAT): config + + Returns: + Any: _description_ + """ + curr_class_id = 0 + + config_labels2idx: Dict = config.general.labels2idx + config_idx2labels: Dict = config.general.idx2labels + + relation_labels = [relation_label.strip() + for relation_label in relation_labels] + + for relation_label in set(relation_labels): + if relation_label not in config_labels2idx.keys(): + while curr_class_id in [int(label_idx) for label_idx in config_idx2labels.keys()]: + curr_class_id += 1 + config_labels2idx[relation_label] = curr_class_id + config_idx2labels[curr_class_id] = relation_label + + return len(config_labels2idx.keys()), config_labels2idx, config_idx2labels, + + def __len__(self) -> int: + """ + Returns: + int: num of rels records + """ + return len(self.dataset['output_relations']) + + def __getitem__(self, idx: int) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor]: + """ + + Args: + idx (int): index of item in the dataset dict + + Returns: + Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor]: long tensors of the following the columns : input_ids, ent1&ent2 token start pos idx, label_ids + """ + + return torch.LongTensor(self.dataset['output_relations'][idx][0]),\ + torch.LongTensor(self.dataset['output_relations'][idx][1]),\ + torch.LongTensor([self.dataset['output_relations'][idx][5]]) diff --git a/medcat/utils/relation_extraction/tokenizer.py b/medcat/utils/relation_extraction/tokenizer.py new file mode 100644 index 000000000..af3db5145 --- /dev/null +++ b/medcat/utils/relation_extraction/tokenizer.py @@ -0,0 +1,71 @@ +import os +from typing import Optional +from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast + + +class TokenizerWrapperBERT(BertTokenizerFast): + ''' Wrapper around a huggingface BERT tokenizer so that it works with the + RelCAT models. + + Args: + hf_tokenizers (`transformers.models.bert.tokenization_bert_fast.BertTokenizerFast`): + A huggingface Fast BERT. + ''' + name = 'bert-tokenizer' + + def __init__(self, hf_tokenizers=None, max_seq_length: Optional[int] = None, add_special_tokens: Optional[bool] = False): + self.hf_tokenizers = hf_tokenizers + self.max_seq_length = max_seq_length + self.add_special_tokens = add_special_tokens + + def __call__(self, text, truncation: Optional[bool] = True): + if isinstance(text, str): + result = self.hf_tokenizers.encode_plus(text, return_offsets_mapping=True, return_length=True, return_token_type_ids=True, return_attention_mask=True, + add_special_tokens=self.add_special_tokens, max_length=self.max_seq_length, padding="longest", truncation=truncation) + + return {'offset_mapping': result['offset_mapping'], + 'input_ids': result['input_ids'], + 'tokens': self.hf_tokenizers.convert_ids_to_tokens(result['input_ids']), + 'token_type_ids': result['token_type_ids'], + 'attention_mask': result['attention_mask'], + 'length': result['length'] + } + elif isinstance(text, list): + results = self.hf_tokenizers._batch_encode_plus(text, return_offsets_mapping=True, return_length=True, return_token_type_ids=True, + add_special_tokens=self.add_special_tokens, max_length=self.max_seq_length,truncation=truncation) + output = [] + for ind in range(len(results['input_ids'])): + output.append({ + 'offset_mapping': results['offset_mapping'][ind], + 'input_ids': results['input_ids'][ind], + 'tokens': self.hf_tokenizers.convert_ids_to_tokens(results['input_ids'][ind]), + 'token_type_ids': results['token_type_ids'][ind], + 'attention_mask': results['attention_mask'][ind], + 'length': result['length'] + }) + return output + else: + raise Exception( + "Unsuported input type, supported: text/list, but got: {}".format(type(text))) + + def save(self, dir_path): + path = os.path.join(dir_path, self.name) + self.hf_tokenizers.save_pretrained(path) + + @classmethod + def load(cls, dir_path, **kwargs): + tokenizer = cls() + path = os.path.join(dir_path, cls.name) + tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained( + path, **kwargs) + + return tokenizer + + def get_size(self): + return len(self.hf_tokenizers.vocab) + + def token_to_id(self, token): + return self.hf_tokenizers.convert_tokens_to_ids(token) + + def get_pad_id(self): + return self.hf_tokenizers.pad_token_id diff --git a/medcat/utils/relation_extraction/utils.py b/medcat/utils/relation_extraction/utils.py new file mode 100644 index 000000000..b0a2b8094 --- /dev/null +++ b/medcat/utils/relation_extraction/utils.py @@ -0,0 +1,277 @@ +import os +import pickle +from typing import Any, Dict, List, Tuple +import numpy as np +import torch +import logging +import random + +from pandas.core.series import Series +from medcat.config_rel_cat import ConfigRelCAT + +from medcat.preprocessing.tokenizers import TokenizerWrapperBERT +from medcat.utils.relation_extraction.models import BertModel_RelationExtraction + + +def split_list_train_test_by_class(data: List, test_size: float = 0.2, shuffle: bool = True) -> Tuple[List, List]: + """ + + Args: + data (List): "output_relations": relation_instances, <-- see create_base_relations_from_doc/csv + for data columns + test_size (float): Defaults to 0.2. + shuffle (bool): shuffle data randomly. Defaults to True. + + Returns: + Tuple[List, List]: train and test datasets + """ + + if shuffle: + random.shuffle(data) + + train_data = [] + test_data = [] + + row_id_labels = {row_idx: data[row_idx][5] for row_idx in range(len(data))} + count_per_label = {lbl: list(row_id_labels.values()).count( + lbl) for lbl in set(row_id_labels.values())} + + for lbl_id, count in count_per_label.items(): + _test_records_size = int(count * test_size) + tmp_count = 0 + if _test_records_size not in [0, 1]: + for row_idx, _lbl_id in row_id_labels.items(): + if _lbl_id == lbl_id: + if tmp_count < _test_records_size: + test_data.append(data[row_idx]) + tmp_count += 1 + else: + train_data.append(data[row_idx]) + else: + for row_idx, _lbl_id in row_id_labels.items(): + if _lbl_id == lbl_id: + train_data.append(data[row_idx]) + test_data.append(data[row_idx]) + + return train_data, test_data + + +def load_bin_file(file_name, path="./") -> Any: + with open(os.path.join(path, file_name), 'rb') as f: + data = pickle.load(f) + return data + + +def save_bin_file(file_name, data, path="./"): + with open(os.path.join(path, file_name), "wb") as f: + pickle.dump(data, f) + + +def save_state(model: BertModel_RelationExtraction, optimizer: torch.optim.Adam, scheduler: torch.optim.lr_scheduler.MultiStepLR, epoch:int = 1, best_f1:float = 0.0, path:str = "./", model_name: str = "BERT", task:str = "train", is_checkpoint=False, final_export=False) -> None: + """ Used by RelCAT.save() and RelCAT.train() + Saves the RelCAT model state. + For checkpointing multiple files are created, best_f1, loss etc. score. + If you want to export the model after training set final_export=True and leave is_checkpoint=False. + + Args: + model (BertModel_RelationExtraction): model + optimizer (torch.optim.Adam, optional): Defaults to None. + scheduler (torch.optim.lr_scheduler.MultiStepLR, optional): Defaults to None. + epoch (int): Defaults to None. + best_f1 (float): Defaults to None. + path (str):Defaults to "./". + model_name (str): . Defaults to "BERT". This is used to checkpointing only. + task (str): Defaults to "train". This is used to checkpointing only. + is_checkpoint (bool): Defaults to False. + final_export (bool): Defaults to False, if True then is_checkpoint must be False also. Exports model.state_dict(), out into"model.dat". + """ + + model_name = model_name.replace("/", "_") + file_name = "%s_checkpoint_%s.dat" % (task, model_name) + + if not is_checkpoint: + file_name = "%s_best_%s.dat" % (task, model_name) + if final_export: + file_name = "model.dat" + torch.save(model.state_dict(), os.path.join(path, file_name)) + + if is_checkpoint: + torch.save({ + 'epoch': epoch, + 'state_dict': model.state_dict(), + 'best_f1': best_f1, + 'optimizer': optimizer.state_dict(), + 'scheduler': scheduler.state_dict() + }, os.path.join(path, file_name)) + + +def load_state(model: BertModel_RelationExtraction, optimizer, scheduler, path="./", model_name="BERT", file_prefix="train", load_best=False, device: torch.device =torch.device("cpu"), config: ConfigRelCAT = ConfigRelCAT()) -> Tuple[int, int]: + """ Used by RelCAT.load() and RelCAT.train() + + Args: + model (BertModel_RelationExtraction): model, it has to be initialized before calling this method via BertModel_RelationExtraction(...) + optimizer (_type_): optimizer + scheduler (_type_): scheduler + path (str, optional): Defaults to "./". + model_name (str, optional): Defaults to "BERT". + file_prefix (str, optional): Defaults to "train". + load_best (bool, optional): Defaults to False. + device (torch.device, optional): Defaults to torch.device("cpu"). + config (ConfigRelCAT): Defaults to ConfigRelCAT(). + + Returns: + Tuple (int, int): last epoch and f1 score. + """ + + model_name = model_name.replace("/", "_") + logging.info("Attempting to load RelCAT model on device: " + str(device)) + checkpoint_path = os.path.join( + path, file_prefix + "_checkpoint_%s.dat" % model_name) + best_path = os.path.join( + path, file_prefix + "_best_%s.dat" % model_name) + start_epoch, best_f1, checkpoint = 0, 0, None + + if load_best is True and os.path.isfile(best_path): + checkpoint = torch.load(best_path, map_location=device) + logging.info("Loaded best model.") + elif os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location=device) + logging.info("Loaded checkpoint model.") + + if checkpoint is not None: + start_epoch = checkpoint['epoch'] + best_f1 = checkpoint['best_f1'] + model.load_state_dict(checkpoint['state_dict']) + model.to(device) + + if optimizer is None: + optimizer = torch.optim.Adam( + [{"params": model.module.parameters(), "lr": config.train.lr}]) + + if scheduler is None: + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, + milestones=config.train.multistep_milestones, + gamma=config.train.multistep_lr_gamma) + optimizer.load_state_dict(checkpoint['optimizer']) + scheduler.load_state_dict(checkpoint['scheduler']) + logging.info("Loaded model and optimizer.") + + return start_epoch, best_f1 + + +def save_results(data, model_name: str = "BERT", path: str = "./", file_prefix: str = "train"): + save_bin_file(file_prefix + "_losses_accuracy_f1_per_epoch_%s.dat" % + model_name, data, path) + + +def load_results(path, model_name: str = "BERT", file_prefix: str = "train") -> Tuple[List, List, List]: + data_dict_path = os.path.join( + path, file_prefix + "_losses_accuracy_f1_per_epoch_%s.dat" % model_name) + + data_dict: Dict = {"losses_per_epoch": [], + "accuracy_per_epoch": [], "f1_per_epoch": []} + if os.path.isfile(data_dict_path): + data_dict = load_bin_file(data_dict_path) + + return data_dict["losses_per_epoch"], data_dict["accuracy_per_epoch"], data_dict["f1_per_epoch"] + + +def put_blanks(relation_data: List, blanking_threshold: float = 0.5) -> List: + """ + Args: + relation_data (List): tuple containing token (sentence_token_span , ent1 , ent2) + Puts blanks randomly in the relation. Used for pre-training. + blanking_threshold (float): % threshold to blank token ids. Defaults to 0.5. + + Returns: + List: data + """ + + blank_ent1 = np.random.uniform() + blank_ent2 = np.random.uniform() + + blanked_relation = relation_data + + sentence_token_span, ent1, ent2, label, label_id, ent1_types, ent2_types, ent1_id, ent2_id, ent1_cui, ent2_cui, doc_id = ( + *relation_data, ) + + if blank_ent1 >= blanking_threshold: + blanked_relation = [sentence_token_span, "[BLANK]", ent2, label, label_id, + ent1_types, ent2_types, ent1_id, ent2_id, ent1_cui, ent2_cui, doc_id] + + if blank_ent2 >= blanking_threshold: + blanked_relation = [sentence_token_span, ent1, "[BLANK]", label, label_id, + ent1_types, ent2_types, ent1_id, ent2_id, ent1_cui, ent2_cui, doc_id] + + return blanked_relation + + +def create_tokenizer_pretrain(tokenizer: TokenizerWrapperBERT, tokenizer_path: str): + """ + This method simply adds special tokens that we enouncter + + Args: + tokenizer (TokenizerWrapperBERT): BERT tokenizer. + tokenizer_path (str): path where tokenizer is to be saved. + """ + + + tokenizer.hf_tokenizers.add_tokens( + ["[BLANK]", "[ENT1]", "[ENT2]", "[/ENT1]", "[/ENT2]"], special_tokens=True) + tokenizer.hf_tokenizers.add_tokens( + ["[s1]", "[e1]", "[s2]", "[e2]"], special_tokens=True) + tokenizer.save(tokenizer_path) + + +# Used for creating data sets for pretraining +def tokenize(relations_dataset: Series, tokenizer: TokenizerWrapperBERT, mask_probability: float = 0.5) -> Tuple: + (tokens, span_1_pos, span_2_pos), ent1_text, ent2_text, label, label_id, ent1_types, ent2_types, ent1_id, ent2_id, ent1_cui, ent2_cui, doc_id = relations_dataset + + cls_token = tokenizer.hf_tokenizers.cls_token + sep_token = tokenizer.hf_tokenizers.sep_token + + tokens = [token.lower() for token in tokens if tokens != '[BLANK]'] + + forbidden_indices = [i for i in range( + span_1_pos[0], span_1_pos[1])] + [i for i in range(span_2_pos[0], span_2_pos[1])] + + pool_indices = [i for i in range( + len(tokens)) if i not in forbidden_indices] + + masked_indices = np.random.choice(pool_indices, + size=round(mask_probability * + len(pool_indices)), + replace=False) + + masked_for_pred = [token.lower() for idx, token in enumerate( + tokens) if (idx in masked_indices)] + + tokens = [token if (idx not in masked_indices) + else tokenizer.hf_tokenizers.mask_token for idx, token in enumerate(tokens)] + + if (ent1_text == "[BLANK]") and (ent2_text != "[BLANK]"): + tokens = [cls_token] + tokens[:span_1_pos[0]] + ["[ENT1]", "[BLANK]", "[/ENT1]"] + \ + tokens[span_1_pos[1]:span_2_pos[0]] + ["[ENT2]"] + tokens[span_2_pos[0]:span_2_pos[1]] + ["[/ENT2]"] + tokens[span_2_pos[1]:] + [sep_token] + + elif (ent1_text == "[BLANK]") and (ent2_text == "[BLANK]"): + tokens = [cls_token] + tokens[:span_1_pos[0]] + ["[ENT1]", "[BLANK]", "[/ENT1]"] + \ + tokens[span_1_pos[1]:span_2_pos[0]] + ["[ENT2]", "[BLANK]", + "[/ENT2]"] + tokens[span_2_pos[1]:] + [sep_token] + + elif (ent1_text != "[BLANK]") and (ent2_text == "[BLANK]"): + tokens = [cls_token] + tokens[:span_1_pos[0]] + ["[ENT1]"] + tokens[span_1_pos[0]:span_1_pos[1]] + ["[/ENT1]"] + \ + tokens[span_1_pos[1]:span_2_pos[0]] + ["[ENT2]", "[BLANK]", + "[/ENT2]"] + tokens[span_2_pos[1]:] + [sep_token] + + elif (ent1_text != "[BLANK]") and (ent2_text != "[BLANK]"): + tokens = [cls_token] + tokens[:span_1_pos[0]] + ["[ENT1]"] + tokens[span_1_pos[0]:span_1_pos[1]] + ["[/ENT1]"] + \ + tokens[span_1_pos[1]:span_2_pos[0]] + ["[ENT2]"] + tokens[span_2_pos[0]:span_2_pos[1]] + ["[/ENT2]"] + tokens[span_2_pos[1]:] + [sep_token] + + ent1_ent2_start = ([i for i, e in enumerate(tokens) if e == "[ENT1]"][0], [ + i for i, e in enumerate(tokens) if e == "[ENT2]"][0]) + + token_ids = tokenizer.hf_tokenizers.convert_tokens_to_ids(tokens) + masked_for_pred = tokenizer.hf_tokenizers.convert_tokens_to_ids( + masked_for_pred) + + return token_ids, masked_for_pred, ent1_ent2_start diff --git a/setup.py b/setup.py index 061afaac4..cfb824727 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ long_description_content_type="text/markdown", url="https://github.com/CogStack/MedCAT", packages=['medcat', 'medcat.utils', 'medcat.preprocessing', 'medcat.ner', 'medcat.linking', 'medcat.datasets', - 'medcat.tokenizers', 'medcat.utils.meta_cat', 'medcat.pipeline', 'medcat.utils.ner', + 'medcat.tokenizers', 'medcat.utils.meta_cat', 'medcat.pipeline', 'medcat.utils.ner', 'medcat.utils.relation_extraction', 'medcat.utils.saving', 'medcat.utils.regression', 'medcat.stats'], install_requires=[ 'numpy>=1.22.0,<1.26.0', # 1.22.0 is first to support python 3.11; post 1.26.0 there's issues with scipy diff --git a/tests/resources/medcat_rel_test.csv b/tests/resources/medcat_rel_test.csv new file mode 100644 index 000000000..fbaecbbb5 --- /dev/null +++ b/tests/resources/medcat_rel_test.csv @@ -0,0 +1,2 @@ +relation_token_span_ids ent1_ent2_start ent1 ent2 label label_id ent1_type ent2_type ent1_id ent2_id ent1_cui ent2_cui doc_id text + (59, 69) degeneration discrete disease/disability_procedure 1 disease procedure 956 977 33359002 263869007 0 EXAM:,MRI LEFT KNEE WITHOUT CONTRAST,CLINICAL:,This is a 53-year-old female with left knee pain being evaluated for ACL tear.,FINDINGS:,This examination was performed on 10-14-05.,Normal medial meniscus without intrasubstance [s1] degeneration [e1], surface fraying or [s2] discrete [e2] meniscal tear.,There is a discoid lateral meniscus and although there may be minimal superficial fraying along the inner edge of the body, there is no discrete tear (series #6 images #7-12).,There is a near-complete or complete tear of the femoral attachment of the anterior cruciate ligament. The ligament has a balled-up appearance consistent with at least partial retraction of most of the fibers of the ligament. There may be a few fibers still intact (series #4 images #12-14 series #5 images #12-14). The tibial fibers are normal.,Normal posterior cruciate ligament.,There is a sprain of the medial collateral ligament, with mild separation of the deep and superficial fibers at the femoral attachment (series #7 images #6-12). There is no complete tear or discontinuity and there is no meniscocapsular separation.,There is a sprain of the lateral ligament complex without focal tear or discontinuity of any of the intraarticular components.,Normal iliotibial band.,Normal quadriceps and patellar tendons.,There is contusion within the posterolateral corner of the tibia. There is also contusion within the patella at the midline patellar ridge where there is an area of focal chondral flattening (series #8 images #10-13). The medial and lateral patellar facets are otherwise normal as is the femoral trochlea in the there is no patellar subluxation.,There is a mild strain of the vastus medialis oblique muscle extending into the medial patellofemoral ligament and medial patellar retinaculum but there is no complete tear or discontinuity.,Normal lateral patellar retinaculum. There is a joint effusion and plica.,IMPRESSION:, Discoid lateral meniscus without a tear although there may be minimal superficial fraying along the inner edge of the body. Near-complete if not complete tear of the femoral attachment of the anterior cruciate ligament. Medial capsule sprain with associated strain of the vastus medialis oblique muscle. There is focal contusion within the patella at the midline patella ridge. Joint effusion and plica. diff --git a/tests/resources/medcat_rel_train.csv b/tests/resources/medcat_rel_train.csv new file mode 100644 index 000000000..e768e3c82 --- /dev/null +++ b/tests/resources/medcat_rel_train.csv @@ -0,0 +1,4 @@ +relation_token_span_ids ent1_ent2_start ent1 ent2 label label_id ent1_type ent2_type ent1_id ent2_id ent1_cui ent2_cui doc_id text + (41, 47) severe pain emergency room disease/disability_procedure 1 disease procedure 1060 1039 76948002 225728007 1 REASON FOR CONSULTATION: , Left hip fracture.,HISTORY OF PRESENT ILLNESS: , The patient is a pleasant 53-year-old female with a known history of sciatica, apparently presented to the [s1] emergency room [e1] due to [s2] severe pain [e2] in the left lower extremity and unable to bear weight. History was obtained from the patient. As per the history, she reported that she has been having back pain with left leg pain since past 4 weeks. She has been using a walker for ambulation due to disabling pain in her left thigh and lower back. She was seen by her primary care physician and was scheduled to go for MRI yesterday. However, she was walking and her right foot got caught on some type of rug leading to place excessive weight on her left lower extremity to prevent her fall. Since then, she was unable to ambulate. The patient called paramedics and was brought to the emergency room. She denied any history of fall. She reported that she stepped the wrong way causing the pain to become worse. She is complaining of severe pain in her lower extremity and back pain. Denies any tingling or numbness. Denies any neurological symptoms. Denies any bowel or bladder incontinence.,X-rays were obtained which were remarkable for left hip fracture. Orthopedic consultation was called for further evaluation and management. On further interview with the patient, it is noted that she has a history of malignant melanoma, which was diagnosed approximately 4 to 5 years ago. She underwent surgery at that time and subsequently, she was noted to have a spread to the lymphatic system and lymph nodes for which she underwent surgery in 3/2008.,PAST MEDICAL HISTORY: , Sciatica and melanoma.,PAST SURGICAL HISTORY: ,As discussed above, surgery for melanoma and hysterectomy.,ALLERGIES: , NONE.,SOCIAL HISTORY: , Denies any tobacco or alcohol use. She is divorced with 2 children. She lives with her son.,PHYSICAL EXAMINATION:,GENERAL: The patient is well developed, well nourished in mild distress secondary to left lower extremity and back pain.,MUSCULOSKELETAL: Examination of the left lower extremity, there is presence of apparent shortening and external rotation deformity. Tenderness to palpation is present. Leg rolling is positive for severe pain in the left proximal hip. Further examination of the spine is incomplete secondary to severe leg pain. She is unable to perform a straight leg raising. EHL/EDL 5/5. 2+ pulses are present distally. Calf is soft and nontender. Homans sign is negative. Sensation to light touch is intact.,IMAGING:, AP view of the hip is reviewed. Only 1 limited view is obtained. This is a poor quality x-ray with a lot of soft tissue shadow. This x-ray is significant for basicervical-type femoral neck fracture. Lesser trochanter is intact. This is a high intertrochanteric fracture/basicervical. There is presence of lytic lesion around the femoral neck, which is not well delineated on this particular x-ray. We need to order repeat x-rays including AP pelvis, femur, and knee.,LABS:, Have been reviewed.,ASSESSMENT: , The patient is a 53-year-old female with probable pathological fracture of the left proximal femur.,DISCUSSION AND PLAN: , Nature and course of the diagnosis has been discussed with the patient. Based on her presentation without any history of obvious fall or trauma and past history of malignant melanoma, this appears to be a pathological fracture of the left proximal hip. At the present time, I would recommend obtaining a bone scan and repeat x-rays, which will include AP pelvis, femur, hip including knee. She denies any pain elsewhere. She does have a past history of back pain and sciatica, but at the present time, this appears to be a metastatic bone lesion with pathological fracture. I have discussed the case with Dr. X and recommended oncology consultation.,With the above fracture and presentation, she needs a left hip hemiarthroplasty versus calcar hemiarthroplasty, cemented type. Indication, risk, and benefits of left hip hemiarthroplasty has been discussed with the patient, which includes, but not limited to bleeding, infection, nerve injury, blood vessel injury, dislocation early and late, persistent pain, leg length discrepancy, myositis ossificans, intraoperative fracture, prosthetic fracture, need for conversion to total hip replacement surgery, revision surgery, DVT, pulmonary embolism, risk of anesthesia, need for blood transfusion, and cardiac arrest. She understands above and is willing to undergo further procedure. The goal and the functional outcome have been explained. Further plan will be discussed with her once we obtain the bone scan and the radiographic studies. We will also await for the oncology feedback and clearance.,Thank you very much for allowing me to participate in the care of this patient. I will continue to follow up. + (991, 1010) pulmonary embolism cardiac arrest non_relation 0 disease procedure 1029 1044 59282003 410429000 1 REASON FOR CONSULTATION: , Left hip fracture.,HISTORY OF PRESENT ILLNESS: , The patient is a pleasant 53-year-old female with a known history of sciatica, apparently presented to the emergency room due to severe pain in the left lower extremity and unable to bear weight. History was obtained from the patient. As per the history, she reported that she has been having back pain with left leg pain since past 4 weeks. She has been using a walker for ambulation due to disabling pain in her left thigh and lower back. She was seen by her primary care physician and was scheduled to go for MRI yesterday. However, she was walking and her right foot got caught on some type of rug leading to place excessive weight on her left lower extremity to prevent her fall. Since then, she was unable to ambulate. The patient called paramedics and was brought to the emergency room. She denied any history of fall. She reported that she stepped the wrong way causing the pain to become worse. She is complaining of severe pain in her lower extremity and back pain. Denies any tingling or numbness. Denies any neurological symptoms. Denies any bowel or bladder incontinence.,X-rays were obtained which were remarkable for left hip fracture. Orthopedic consultation was called for further evaluation and management. On further interview with the patient, it is noted that she has a history of malignant melanoma, which was diagnosed approximately 4 to 5 years ago. She underwent surgery at that time and subsequently, she was noted to have a spread to the lymphatic system and lymph nodes for which she underwent surgery in 3/2008.,PAST MEDICAL HISTORY: , Sciatica and melanoma.,PAST SURGICAL HISTORY: ,As discussed above, surgery for melanoma and hysterectomy.,ALLERGIES: , NONE.,SOCIAL HISTORY: , Denies any tobacco or alcohol use. She is divorced with 2 children. She lives with her son.,PHYSICAL EXAMINATION:,GENERAL: The patient is well developed, well nourished in mild distress secondary to left lower extremity and back pain.,MUSCULOSKELETAL: Examination of the left lower extremity, there is presence of apparent shortening and external rotation deformity. Tenderness to palpation is present. Leg rolling is positive for severe pain in the left proximal hip. Further examination of the spine is incomplete secondary to severe leg pain. She is unable to perform a straight leg raising. EHL/EDL 5/5. 2+ pulses are present distally. Calf is soft and nontender. Homans sign is negative. Sensation to light touch is intact.,IMAGING:, AP view of the hip is reviewed. Only 1 limited view is obtained. This is a poor quality x-ray with a lot of soft tissue shadow. This x-ray is significant for basicervical-type femoral neck fracture. Lesser trochanter is intact. This is a high intertrochanteric fracture/basicervical. There is presence of lytic lesion around the femoral neck, which is not well delineated on this particular x-ray. We need to order repeat x-rays including AP pelvis, femur, and knee.,LABS:, Have been reviewed.,ASSESSMENT: , The patient is a 53-year-old female with probable pathological fracture of the left proximal femur.,DISCUSSION AND PLAN: , Nature and course of the diagnosis has been discussed with the patient. Based on her presentation without any history of obvious fall or trauma and past history of malignant melanoma, this appears to be a pathological fracture of the left proximal hip. At the present time, I would recommend obtaining a bone scan and repeat x-rays, which will include AP pelvis, femur, hip including knee. She denies any pain elsewhere. She does have a past history of back pain and sciatica, but at the present time, this appears to be a metastatic bone lesion with pathological fracture. I have discussed the case with Dr. X and recommended oncology consultation.,With the above fracture and presentation, she needs a left hip hemiarthroplasty versus calcar hemiarthroplasty, cemented type. Indication, risk, and benefits of left hip hemiarthroplasty has been discussed with the patient, which includes, but not limited to bleeding, infection, nerve injury, blood vessel injury, dislocation early and late, persistent pain, leg length discrepancy, myositis ossificans, intraoperative fracture, prosthetic fracture, need for conversion to total hip replacement surgery, revision surgery, DVT, [s1] pulmonary embolism [e1], risk of anesthesia, need for blood transfusion, and [s2]cardiac arrest [e2]. She understands above and is willing to undergo further procedure. The goal and the functional outcome have been explained. Further plan will be discussed with her once we obtain the bone scan and the radiographic studies. We will also await for the oncology feedback and clearance.,Thank you very much for allowing me to participate in the care of this patient. I will continue to follow up. + (41, 53) emergency room left lower extremity disease/disability_procedure 1 disease procedure 1021 1039 225728007 32153003 1 REASON FOR CONSULTATION: , Left hip fracture.,HISTORY OF PRESENT ILLNESS: , The patient is a pleasant 53-year-old female with a known history of sciatica, apparently presented to the [s1] emergency room [e1] due to severe pain in the [s2] left lower extremity [e2] and unable to bear weight. History was obtained from the patient. As per the history, she reported that she has been having back pain with left leg pain since past 4 weeks. She has been using a walker for ambulation due to disabling pain in her left thigh and lower back. She was seen by her primary care physician and was scheduled to go for MRI yesterday. However, she was walking and her right foot got caught on some type of rug leading to place excessive weight on her left lower extremity to prevent her fall. Since then, she was unable to ambulate. The patient called paramedics and was brought to the emergency room. She denied any history of fall. She reported that she stepped the wrong way causing the pain to become worse. She is complaining of severe pain in her lower extremity and back pain. Denies any tingling or numbness. Denies any neurological symptoms. Denies any bowel or bladder incontinence.,X-rays were obtained which were remarkable for left hip fracture. Orthopedic consultation was called for further evaluation and management. On further interview with the patient, it is noted that she has a history of malignant melanoma, which was diagnosed approximately 4 to 5 years ago. She underwent surgery at that time and subsequently, she was noted to have a spread to the lymphatic system and lymph nodes for which she underwent surgery in 3/2008.,PAST MEDICAL HISTORY: , Sciatica and melanoma.,PAST SURGICAL HISTORY: ,As discussed above, surgery for melanoma and hysterectomy.,ALLERGIES: , NONE.,SOCIAL HISTORY: , Denies any tobacco or alcohol use. She is divorced with 2 children. She lives with her son.,PHYSICAL EXAMINATION:,GENERAL: The patient is well developed, well nourished in mild distress secondary to left lower extremity and back pain.,MUSCULOSKELETAL: Examination of the left lower extremity, there is presence of apparent shortening and external rotation deformity. Tenderness to palpation is present. Leg rolling is positive for severe pain in the left proximal hip. Further examination of the spine is incomplete secondary to severe leg pain. She is unable to perform a straight leg raising. EHL/EDL 5/5. 2+ pulses are present distally. Calf is soft and nontender. Homans sign is negative. Sensation to light touch is intact.,IMAGING:, AP view of the hip is reviewed. Only 1 limited view is obtained. This is a poor quality x-ray with a lot of soft tissue shadow. This x-ray is significant for basicervical-type femoral neck fracture. Lesser trochanter is intact. This is a high intertrochanteric fracture/basicervical. There is presence of lytic lesion around the femoral neck, which is not well delineated on this particular x-ray. We need to order repeat x-rays including AP pelvis, femur, and knee.,LABS:, Have been reviewed.,ASSESSMENT: , The patient is a 53-year-old female with probable pathological fracture of the left proximal femur.,DISCUSSION AND PLAN: , Nature and course of the diagnosis has been discussed with the patient. Based on her presentation without any history of obvious fall or trauma and past history of malignant melanoma, this appears to be a pathological fracture of the left proximal hip. At the present time, I would recommend obtaining a bone scan and repeat x-rays, which will include AP pelvis, femur, hip including knee. She denies any pain elsewhere. She does have a past history of back pain and sciatica, but at the present time, this appears to be a metastatic bone lesion with pathological fracture. I have discussed the case with Dr. X and recommended oncology consultation.,With the above fracture and presentation, she needs a left hip hemiarthroplasty versus calcar hemiarthroplasty, cemented type. Indication, risk, and benefits of left hip hemiarthroplasty has been discussed with the patient, which includes, but not limited to bleeding, infection, nerve injury, blood vessel injury, dislocation early and late, persistent pain, leg length discrepancy, myositis ossificans, intraoperative fracture, prosthetic fracture, need for conversion to total hip replacement surgery, revision surgery, DVT, pulmonary embolism, risk of anesthesia, need for blood transfusion, and cardiac arrest. She understands above and is willing to undergo further procedure. The goal and the functional outcome have been explained. Further plan will be discussed with her once we obtain the bone scan and the radiographic studies. We will also await for the oncology feedback and clearance.,Thank you very much for allowing me to participate in the care of this patient. I will continue to follow up. diff --git a/tests/resources/medcat_trainer_export_relations.json b/tests/resources/medcat_trainer_export_relations.json new file mode 100644 index 000000000..9ca1bfd96 --- /dev/null +++ b/tests/resources/medcat_trainer_export_relations.json @@ -0,0 +1,4533 @@ +{ + "projects": [ + { + "name": "Example Project - SNOMED CT All", + "id": 2, + "cuis": "", + "project_status": "A", + "project_locked": false, + "documents": [ + { + "id": 21, + "name": "0-81322", + "text": "EXAM:,MRI LEFT KNEE WITHOUT CONTRAST,CLINICAL:,This is a 53-year-old female with left knee pain being evaluated for ACL tear.,FINDINGS:,This examination was performed on 10-14-05.,Normal medial meniscus without intrasubstance degeneration, surface fraying or discrete meniscal tear.,There is a discoid lateral meniscus and although there may be minimal superficial fraying along the inner edge of the body, there is no discrete tear (series #6 images #7-12).,There is a near-complete or complete tear of the femoral attachment of the anterior cruciate ligament. The ligament has a balled-up appearance consistent with at least partial retraction of most of the fibers of the ligament. There may be a few fibers still intact (series #4 images #12-14; series #5 images #12-14). The tibial fibers are normal.,Normal posterior cruciate ligament.,There is a sprain of the medial collateral ligament, with mild separation of the deep and superficial fibers at the femoral attachment (series #7 images #6-12). There is no complete tear or discontinuity and there is no meniscocapsular separation.,There is a sprain of the lateral ligament complex without focal tear or discontinuity of any of the intraarticular components.,Normal iliotibial band.,Normal quadriceps and patellar tendons.,There is contusion within the posterolateral corner of the tibia. There is also contusion within the patella at the midline patellar ridge where there is an area of focal chondral flattening (series #8 images #10-13). The medial and lateral patellar facets are otherwise normal as is the femoral trochlea in the there is no patellar subluxation.,There is a mild strain of the vastus medialis oblique muscle extending into the medial patellofemoral ligament and medial patellar retinaculum but there is no complete tear or discontinuity.,Normal lateral patellar retinaculum. There is a joint effusion and plica.,IMPRESSION:, Discoid lateral meniscus without a tear although there may be minimal superficial fraying along the inner edge of the body. Near-complete if not complete tear of the femoral attachment of the anterior cruciate ligament. Medial capsule sprain with associated strain of the vastus medialis oblique muscle. There is focal contusion within the patella at the midline patella ridge. Joint effusion and plica.", + "last_modified": "2024-03-21 12:33:28.464465", + "annotations": [ + { + "id": 940, + "user": "admin", + "cui": "202099003", + "value": "discoid lateral meniscus", + "start": 294, + "end": 318, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.887245", + "last_modified": "2024-03-21 12:54:09.725742", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 941, + "user": "admin", + "cui": "202099003", + "value": "Discoid lateral meniscus", + "start": 1905, + "end": 1929, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.894252", + "last_modified": "2024-03-21 12:54:16.815702", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 942, + "user": "admin", + "cui": "245928007", + "value": "patellar retinaculum", + "start": 1749, + "end": 1769, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.898731", + "last_modified": "2024-03-21 12:54:16.117101", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 943, + "user": "admin", + "cui": "245928007", + "value": "patellar retinaculum", + "start": 1833, + "end": 1853, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.903564", + "last_modified": "2024-03-21 12:54:16.554192", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 944, + "user": "admin", + "cui": "34411009", + "value": "lateral ligament", + "start": 1115, + "end": 1131, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.907755", + "last_modified": "2024-03-21 12:54:13.015985", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 945, + "user": "admin", + "cui": "18033002", + "value": "patellar tendons", + "start": 1263, + "end": 1279, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.911457", + "last_modified": "2024-03-21 12:54:13.900229", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 946, + "user": "admin", + "cui": "11716007", + "value": "femoral trochlea", + "start": 1569, + "end": 1585, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.915777", + "last_modified": "2024-03-21 12:54:15.250069", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 947, + "user": "admin", + "cui": "7480001", + "value": "iliotibial band", + "start": 1224, + "end": 1239, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.920674", + "last_modified": "2024-03-21 12:54:13.599444", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 948, + "user": "admin", + "cui": "90069004", + "value": "posterolateral", + "start": 1311, + "end": 1325, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.925301", + "last_modified": "2024-03-21 12:54:14.049785", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 949, + "user": "admin", + "cui": "387637008", + "value": "joint effusion", + "start": 1866, + "end": 1880, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.929598", + "last_modified": "2024-03-21 12:54:16.685589", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 950, + "user": "admin", + "cui": "387637008", + "value": "Joint effusion", + "start": 2286, + "end": 2300, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.934026", + "last_modified": "2024-03-21 12:54:18.459770", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 951, + "user": "admin", + "cui": "239720000", + "value": "meniscal tear", + "start": 268, + "end": 281, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.938627", + "last_modified": "2024-03-21 12:54:07.378249", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 952, + "user": "admin", + "cui": "263722006", + "value": "complete tear", + "start": 487, + "end": 500, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.942438", + "last_modified": "2024-03-21 12:54:10.406439", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 953, + "user": "admin", + "cui": "263722006", + "value": "complete tear", + "start": 1015, + "end": 1028, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.946925", + "last_modified": "2024-03-21 12:54:12.724874", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 954, + "user": "admin", + "cui": "263722006", + "value": "complete tear", + "start": 1786, + "end": 1799, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.951414", + "last_modified": "2024-03-21 12:54:16.267250", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 955, + "user": "admin", + "cui": "263722006", + "value": "complete tear", + "start": 2051, + "end": 2064, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.956173", + "last_modified": "2024-03-21 12:54:17.250445", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 956, + "user": "admin", + "cui": "33359002", + "value": "degeneration", + "start": 226, + "end": 238, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.959620", + "last_modified": "2024-03-21 12:53:54.021156", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 957, + "user": "admin", + "cui": "26283006", + "value": "superficial", + "start": 353, + "end": 364, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.963439", + "last_modified": "2024-03-21 12:54:09.840589", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 958, + "user": "admin", + "cui": "26283006", + "value": "superficial", + "start": 932, + "end": 943, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.967688", + "last_modified": "2024-03-21 12:54:12.276253", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 959, + "user": "admin", + "cui": "26396009", + "value": "subluxation", + "start": 1614, + "end": 1625, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.972635", + "last_modified": "2024-03-21 12:54:15.401342", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 960, + "user": "admin", + "cui": "26283006", + "value": "superficial", + "start": 1975, + "end": 1986, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.976384", + "last_modified": "2024-03-21 12:54:16.952231", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 961, + "user": "admin", + "cui": "1431002", + "value": "attachment", + "start": 516, + "end": 526, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.980286", + "last_modified": "2024-03-21 12:54:10.558286", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 962, + "user": "admin", + "cui": "385433004", + "value": "consistent", + "start": 602, + "end": 612, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.984803", + "last_modified": "2024-03-21 12:54:10.692442", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 963, + "user": "admin", + "cui": "37794007", + "value": "retraction", + "start": 635, + "end": 645, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.988561", + "last_modified": "2024-03-21 12:54:10.973770", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 964, + "user": "admin", + "cui": "397406000", + "value": "collateral", + "start": 874, + "end": 884, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.991914", + "last_modified": "2024-03-21 12:54:12.145268", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 965, + "user": "admin", + "cui": "1431002", + "value": "attachment", + "start": 966, + "end": 976, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:08.995766", + "last_modified": "2024-03-21 12:54:12.426715", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 966, + "user": "admin", + "cui": "246093002", + "value": "components", + "start": 1205, + "end": 1215, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.000526", + "last_modified": "2024-03-21 12:54:13.465503", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 967, + "user": "admin", + "cui": "21989003", + "value": "quadriceps", + "start": 1248, + "end": 1258, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.004644", + "last_modified": "2024-03-21 12:54:13.749820", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 968, + "user": "admin", + "cui": "1431002", + "value": "attachment", + "start": 2080, + "end": 2090, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.008021", + "last_modified": "2024-03-21 12:54:17.400256", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 969, + "user": "admin", + "cui": "47429007", + "value": "associated", + "start": 2154, + "end": 2164, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.011822", + "last_modified": "2024-03-21 12:54:17.870739", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 970, + "user": "admin", + "cui": "30989003", + "value": "knee pain", + "start": 86, + "end": 95, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.016210", + "last_modified": "2024-03-21 12:54:22.608262", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 971, + "user": "admin", + "cui": "398166005", + "value": "performed", + "start": 157, + "end": 166, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.020390", + "last_modified": "2024-03-21 12:54:24.352563", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 972, + "user": "admin", + "cui": "66211004", + "value": "extending", + "start": 1688, + "end": 1697, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.024007", + "last_modified": "2024-03-21 12:54:15.699438", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 973, + "user": "admin", + "cui": "263543005", + "value": "CONTRAST", + "start": 28, + "end": 36, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.027362", + "last_modified": "2024-03-21 12:54:22.274480", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 974, + "user": "admin", + "cui": "58147004", + "value": "CLINICAL", + "start": 37, + "end": 45, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.031279", + "last_modified": "2024-03-21 12:54:22.440527", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 975, + "user": "admin", + "cui": "239725005", + "value": "ACL tear", + "start": 116, + "end": 124, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.035900", + "last_modified": "2024-03-21 12:54:22.724976", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 976, + "user": "admin", + "cui": "163121000000106", + "value": "FINDINGS", + "start": 126, + "end": 134, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.039910", + "last_modified": "2024-03-21 12:54:22.975558", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 977, + "user": "admin", + "cui": "263869007", + "value": "discrete", + "start": 259, + "end": 267, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.043389", + "last_modified": "2024-03-21 12:53:57.874497", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 978, + "user": "admin", + "cui": "263869007", + "value": "discrete", + "start": 419, + "end": 427, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.046918", + "last_modified": "2024-03-21 12:54:10.126731", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 979, + "user": "admin", + "cui": "149016008", + "value": "may be a", + "start": 691, + "end": 699, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.051499", + "last_modified": "2024-03-21 12:54:11.124482", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 980, + "user": "admin", + "cui": "264265004", + "value": "chondral", + "start": 1452, + "end": 1460, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.055233", + "last_modified": "2024-03-21 12:54:14.501854", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 981, + "user": "admin", + "cui": "255609007", + "value": "partial", + "start": 627, + "end": 634, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.058899", + "last_modified": "2024-03-21 12:54:10.841722", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 982, + "user": "admin", + "cui": "103360007", + "value": "complex", + "start": 1132, + "end": 1139, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.063199", + "last_modified": "2024-03-21 12:54:13.168582", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 983, + "user": "admin", + "cui": "49370004", + "value": "lateral", + "start": 1514, + "end": 1521, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.067437", + "last_modified": "2024-03-21 12:54:14.953925", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 984, + "user": "admin", + "cui": "21114003", + "value": "oblique", + "start": 1673, + "end": 1680, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.071387", + "last_modified": "2024-03-21 12:54:15.549512", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 985, + "user": "admin", + "cui": "49370004", + "value": "lateral", + "start": 1825, + "end": 1832, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.074967", + "last_modified": "2024-03-21 12:54:16.405230", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 986, + "user": "admin", + "cui": "21114003", + "value": "oblique", + "start": 2195, + "end": 2202, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.078350", + "last_modified": "2024-03-21 12:54:18.002678", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 987, + "user": "admin", + "cui": "255561001", + "value": "medial", + "start": 187, + "end": 193, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.083224", + "last_modified": "2024-03-21 12:54:24.812622", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 988, + "user": "admin", + "cui": "13039001", + "value": "series", + "start": 434, + "end": 440, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.087291", + "last_modified": "2024-03-21 12:54:10.274186", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 989, + "user": "admin", + "cui": "11163003", + "value": "intact", + "start": 717, + "end": 723, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.091298", + "last_modified": "2024-03-21 12:54:11.274459", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 990, + "user": "admin", + "cui": "13039001", + "value": "series", + "start": 725, + "end": 731, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.094782", + "last_modified": "2024-03-21 12:54:11.408124", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 991, + "user": "admin", + "cui": "13039001", + "value": "series", + "start": 750, + "end": 756, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.099274", + "last_modified": "2024-03-21 12:54:11.557206", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 992, + "user": "admin", + "cui": "12611008", + "value": "tibial", + "start": 780, + "end": 786, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.103289", + "last_modified": "2024-03-21 12:54:11.708989", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 993, + "user": "admin", + "cui": "384709000", + "value": "sprain", + "start": 853, + "end": 859, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.107056", + "last_modified": "2024-03-21 12:54:11.858187", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 994, + "user": "admin", + "cui": "255561001", + "value": "medial", + "start": 867, + "end": 873, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.110447", + "last_modified": "2024-03-21 12:54:12.008187", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 995, + "user": "admin", + "cui": "13039001", + "value": "series", + "start": 978, + "end": 984, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.115323", + "last_modified": "2024-03-21 12:54:12.574935", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 996, + "user": "admin", + "cui": "384709000", + "value": "sprain", + "start": 1101, + "end": 1107, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.119668", + "last_modified": "2024-03-21 12:54:12.862159", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 997, + "user": "admin", + "cui": "13039001", + "value": "series", + "start": 1473, + "end": 1479, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.123199", + "last_modified": "2024-03-21 12:54:14.649104", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 998, + "user": "admin", + "cui": "255561001", + "value": "medial", + "start": 1503, + "end": 1509, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.126658", + "last_modified": "2024-03-21 12:54:14.804656", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 999, + "user": "admin", + "cui": "70746003", + "value": "facets", + "start": 1531, + "end": 1537, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.130854", + "last_modified": "2024-03-21 12:54:15.099335", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1000, + "user": "admin", + "cui": "255561001", + "value": "medial", + "start": 1707, + "end": 1713, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.135093", + "last_modified": "2024-03-21 12:54:15.853155", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1001, + "user": "admin", + "cui": "255561001", + "value": "medial", + "start": 1742, + "end": 1748, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.139375", + "last_modified": "2024-03-21 12:54:16.001783", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1002, + "user": "admin", + "cui": "255561001", + "value": "Medial", + "start": 2127, + "end": 2133, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.142805", + "last_modified": "2024-03-21 12:54:17.552066", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1003, + "user": "admin", + "cui": "384709000", + "value": "sprain", + "start": 2142, + "end": 2148, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.147204", + "last_modified": "2024-03-21 12:54:17.699228", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1004, + "user": "admin", + "cui": "260521003", + "value": "inner", + "start": 383, + "end": 388, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.151492", + "last_modified": "2024-03-21 12:54:09.990905", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1005, + "user": "admin", + "cui": "87017008", + "value": "focal", + "start": 1148, + "end": 1153, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.155060", + "last_modified": "2024-03-21 12:54:13.315902", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1006, + "user": "admin", + "cui": "26833005", + "value": "ridge", + "start": 1414, + "end": 1419, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.158568", + "last_modified": "2024-03-21 12:54:14.201695", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1007, + "user": "admin", + "cui": "87017008", + "value": "focal", + "start": 1446, + "end": 1451, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.162416", + "last_modified": "2024-03-21 12:54:14.350554", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1008, + "user": "admin", + "cui": "260521003", + "value": "inner", + "start": 2005, + "end": 2010, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.166453", + "last_modified": "2024-03-21 12:54:17.107780", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1009, + "user": "admin", + "cui": "87017008", + "value": "focal", + "start": 2221, + "end": 2226, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.169890", + "last_modified": "2024-03-21 12:54:18.151276", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1010, + "user": "admin", + "cui": "26833005", + "value": "ridge", + "start": 2279, + "end": 2284, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:46:09.173355", + "last_modified": "2024-03-21 12:54:18.301673", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + } + ], + "relations": [ + { + "start_entity": 977, + "start_entity_cui": "263869007", + "start_entity_value": "discrete", + "start_entity_start_idx": 259, + "start_entity_end_idx": 267, + "end_entity": 956, + "end_entity_cui": "33359002", + "end_entity_value": "degeneration", + "end_entity_start_idx": 226, + "end_entity_end_idx": 238, + "user": "admin", + "relation": "disease/disability_procedure", + "validated": true + } + ] + }, + { + "id": 22, + "name": "1-88811", + "text": "REASON FOR CONSULTATION: , Left hip fracture.,HISTORY OF PRESENT ILLNESS: , The patient is a pleasant 53-year-old female with a known history of sciatica, apparently presented to the emergency room due to severe pain in the left lower extremity and unable to bear weight. History was obtained from the patient. As per the history, she reported that she has been having back pain with left leg pain since past 4 weeks. She has been using a walker for ambulation due to disabling pain in her left thigh and lower back. She was seen by her primary care physician and was scheduled to go for MRI yesterday. However, she was walking and her right foot got caught on some type of rug leading to place excessive weight on her left lower extremity to prevent her fall. Since then, she was unable to ambulate. The patient called paramedics and was brought to the emergency room. She denied any history of fall. She reported that she stepped the wrong way causing the pain to become worse. She is complaining of severe pain in her lower extremity and back pain. Denies any tingling or numbness. Denies any neurological symptoms. Denies any bowel or bladder incontinence.,X-rays were obtained which were remarkable for left hip fracture. Orthopedic consultation was called for further evaluation and management. On further interview with the patient, it is noted that she has a history of malignant melanoma, which was diagnosed approximately 4 to 5 years ago. She underwent surgery at that time and subsequently, she was noted to have a spread to the lymphatic system and lymph nodes for which she underwent surgery in 3/2008.,PAST MEDICAL HISTORY: , Sciatica and melanoma.,PAST SURGICAL HISTORY: ,As discussed above, surgery for melanoma and hysterectomy.,ALLERGIES: , NONE.,SOCIAL HISTORY: , Denies any tobacco or alcohol use. She is divorced with 2 children. She lives with her son.,PHYSICAL EXAMINATION:,GENERAL: The patient is well developed, well nourished in mild distress secondary to left lower extremity and back pain.,MUSCULOSKELETAL: Examination of the left lower extremity, there is presence of apparent shortening and external rotation deformity. Tenderness to palpation is present. Leg rolling is positive for severe pain in the left proximal hip. Further examination of the spine is incomplete secondary to severe leg pain. She is unable to perform a straight leg raising. EHL/EDL 5/5. 2+ pulses are present distally. Calf is soft and nontender. Homans sign is negative. Sensation to light touch is intact.,IMAGING:, AP view of the hip is reviewed. Only 1 limited view is obtained. This is a poor quality x-ray with a lot of soft tissue shadow. This x-ray is significant for basicervical-type femoral neck fracture. Lesser trochanter is intact. This is a high intertrochanteric fracture/basicervical. There is presence of lytic lesion around the femoral neck, which is not well delineated on this particular x-ray. We need to order repeat x-rays including AP pelvis, femur, and knee.,LABS:, Have been reviewed.,ASSESSMENT: , The patient is a 53-year-old female with probable pathological fracture of the left proximal femur.,DISCUSSION AND PLAN: , Nature and course of the diagnosis has been discussed with the patient. Based on her presentation without any history of obvious fall or trauma and past history of malignant melanoma, this appears to be a pathological fracture of the left proximal hip. At the present time, I would recommend obtaining a bone scan and repeat x-rays, which will include AP pelvis, femur, hip including knee. She denies any pain elsewhere. She does have a past history of back pain and sciatica, but at the present time, this appears to be a metastatic bone lesion with pathological fracture. I have discussed the case with Dr. X and recommended oncology consultation.,With the above fracture and presentation, she needs a left hip hemiarthroplasty versus calcar hemiarthroplasty, cemented type. Indication, risk, and benefits of left hip hemiarthroplasty has been discussed with the patient, which includes, but not limited to bleeding, infection, nerve injury, blood vessel injury, dislocation early and late, persistent pain, leg length discrepancy, myositis ossificans, intraoperative fracture, prosthetic fracture, need for conversion to total hip replacement surgery, revision surgery, DVT, pulmonary embolism, risk of anesthesia, need for blood transfusion, and cardiac arrest. She understands above and is willing to undergo further procedure. The goal and the functional outcome have been explained. Further plan will be discussed with her once we obtain the bone scan and the radiographic studies. We will also await for the oncology feedback and clearance.,Thank you very much for allowing me to participate in the care of this patient. I will continue to follow up.", + "last_modified": "2024-03-21 12:33:28.488556", + "annotations": [ + { + "id": 1011, + "user": "admin", + "cui": "161432005", + "value": "history of malignant melanoma", + "start": 1382, + "end": 1411, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:09.954741", + "last_modified": "2024-03-21 12:55:42.363452", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1012, + "user": "admin", + "cui": "161432005", + "value": "history of malignant melanoma", + "start": 3347, + "end": 3376, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:09.961058", + "last_modified": "2024-03-21 12:55:52.478126", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1013, + "user": "admin", + "cui": "52734007", + "value": "total hip replacement surgery", + "start": 4323, + "end": 4352, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:09.964688", + "last_modified": "2024-03-21 12:55:56.825336", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1014, + "user": "admin", + "cui": "127287001", + "value": "intertrochanteric fracture", + "start": 2802, + "end": 2828, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:09.968320", + "last_modified": "2024-03-21 12:55:49.881776", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1015, + "user": "admin", + "cui": "213270002", + "value": "intraoperative fracture", + "start": 4254, + "end": 4277, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:09.971714", + "last_modified": "2024-03-21 12:55:56.366960", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1016, + "user": "admin", + "cui": "446050000", + "value": "primary care physician", + "start": 541, + "end": 563, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:09.975035", + "last_modified": "2024-03-21 12:55:39.347171", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1017, + "user": "admin", + "cui": "5913000", + "value": "femoral neck fracture", + "start": 2733, + "end": 2754, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:09.978817", + "last_modified": "2024-03-21 12:55:49.546224", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1018, + "user": "admin", + "cui": "268029009", + "value": "pathological fracture", + "start": 3120, + "end": 3141, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:09.982082", + "last_modified": "2024-03-21 12:55:51.282602", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1019, + "user": "admin", + "cui": "268029009", + "value": "pathological fracture", + "start": 3399, + "end": 3420, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:09.985489", + "last_modified": "2024-03-21 12:55:52.645645", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1020, + "user": "admin", + "cui": "268029009", + "value": "pathological fracture", + "start": 3748, + "end": 3769, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:09.988919", + "last_modified": "2024-03-21 12:55:54.548989", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1021, + "user": "admin", + "cui": "32153003", + "value": "left lower extremity", + "start": 224, + "end": 244, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:09.992197", + "last_modified": "2024-03-21 12:55:38.238308", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1022, + "user": "admin", + "cui": "32153003", + "value": "left lower extremity", + "start": 724, + "end": 744, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:09.995806", + "last_modified": "2024-03-21 12:55:39.747698", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1023, + "user": "admin", + "cui": "165232002", + "value": "bladder incontinence", + "start": 1152, + "end": 1172, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:09.999232", + "last_modified": "2024-03-21 12:55:41.806866", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1024, + "user": "admin", + "cui": "5880005", + "value": "PHYSICAL EXAMINATION", + "start": 1895, + "end": 1915, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.003108", + "last_modified": "2024-03-21 12:55:45.173783", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1025, + "user": "admin", + "cui": "32153003", + "value": "left lower extremity", + "start": 2003, + "end": 2023, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.006535", + "last_modified": "2024-03-21 12:55:45.902156", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1026, + "user": "admin", + "cui": "32153003", + "value": "left lower extremity", + "start": 2076, + "end": 2096, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.010271", + "last_modified": "2024-03-21 12:55:46.303955", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1027, + "user": "admin", + "cui": "57662003", + "value": "injury, blood vessel", + "start": 4135, + "end": 4155, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.013605", + "last_modified": "2024-03-21 12:55:55.310682", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1028, + "user": "admin", + "cui": "44551007", + "value": "myositis ossificans", + "start": 4233, + "end": 4252, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.016979", + "last_modified": "2024-03-21 12:55:56.252400", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1029, + "user": "admin", + "cui": "59282003", + "value": "pulmonary embolism", + "start": 4377, + "end": 4395, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.020525", + "last_modified": "2024-03-21 12:55:57.127503", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1030, + "user": "admin", + "cui": "116859006", + "value": "blood transfusion", + "start": 4426, + "end": 4443, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.023777", + "last_modified": "2024-03-21 12:55:57.714351", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1031, + "user": "admin", + "cui": "224994002", + "value": "excessive weight", + "start": 700, + "end": 716, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.027248", + "last_modified": "2024-03-21 12:55:39.616282", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1032, + "user": "admin", + "cui": "89890002", + "value": "lymphatic system", + "start": 1557, + "end": 1573, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.030538", + "last_modified": "2024-03-21 12:55:43.111941", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1033, + "user": "admin", + "cui": "261554009", + "value": "revision surgery", + "start": 4354, + "end": 4370, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.034040", + "last_modified": "2024-03-21 12:55:56.977592", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1034, + "user": "admin", + "cui": "428942009", + "value": "history of fall", + "start": 893, + "end": 908, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.037434", + "last_modified": "2024-03-21 12:55:40.609310", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1035, + "user": "admin", + "cui": "61685007", + "value": "lower extremity", + "start": 1031, + "end": 1046, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.040696", + "last_modified": "2024-03-21 12:55:41.064797", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1036, + "user": "admin", + "cui": "392521001", + "value": "MEDICAL HISTORY", + "start": 1638, + "end": 1653, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.044223", + "last_modified": "2024-03-21 12:55:43.399211", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1037, + "user": "admin", + "cui": "106028002", + "value": "MUSCULOSKELETAL", + "start": 2039, + "end": 2054, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.047629", + "last_modified": "2024-03-21 12:55:46.165897", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1038, + "user": "admin", + "cui": "417662000", + "value": "past history of", + "start": 3634, + "end": 3649, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.050955", + "last_modified": "2024-03-21 12:55:53.625927", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1039, + "user": "admin", + "cui": "225728007", + "value": "emergency room", + "start": 183, + "end": 197, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.054330", + "last_modified": "2024-03-21 12:55:36.951663", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1040, + "user": "admin", + "cui": "225728007", + "value": "emergency room", + "start": 861, + "end": 875, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.058215", + "last_modified": "2024-03-21 12:55:40.323486", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1041, + "user": "admin", + "cui": "160476009", + "value": "SOCIAL HISTORY", + "start": 1783, + "end": 1797, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.061698", + "last_modified": "2024-03-21 12:55:44.413399", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1042, + "user": "admin", + "cui": "248324001", + "value": "well nourished", + "start": 1958, + "end": 1972, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.065207", + "last_modified": "2024-03-21 12:55:45.498560", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1043, + "user": "admin", + "cui": "244696009", + "value": "proximal femur", + "start": 3154, + "end": 3168, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.068543", + "last_modified": "2024-03-21 12:55:51.450347", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1044, + "user": "admin", + "cui": "410429000", + "value": "cardiac arrest", + "start": 4449, + "end": 4463, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.071929", + "last_modified": "2024-03-21 12:55:57.868597", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1045, + "user": "admin", + "cui": "287047008", + "value": "left leg pain", + "start": 386, + "end": 399, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.075231", + "last_modified": "2024-03-21 12:55:38.779030", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1046, + "user": "admin", + "cui": "161891005", + "value": "and back pain", + "start": 1047, + "end": 1060, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.078909", + "last_modified": "2024-03-21 12:55:41.201006", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1047, + "user": "admin", + "cui": "26175008", + "value": "approximately", + "start": 1433, + "end": 1446, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.082320", + "last_modified": "2024-03-21 12:55:42.651799", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1048, + "user": "admin", + "cui": "161891005", + "value": "and back pain", + "start": 2024, + "end": 2037, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.085703", + "last_modified": "2024-03-21 12:55:46.066797", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1049, + "user": "admin", + "cui": "1199008", + "value": "neurological", + "start": 1108, + "end": 1120, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.089141", + "last_modified": "2024-03-21 12:55:41.503366", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1050, + "user": "admin", + "cui": "236886002", + "value": "hysterectomy", + "start": 1750, + "end": 1762, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.092461", + "last_modified": "2024-03-21 12:55:44.143161", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1051, + "user": "admin", + "cui": "2603003", + "value": "secondary to", + "start": 1990, + "end": 2002, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.095703", + "last_modified": "2024-03-21 12:55:45.748813", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1052, + "user": "admin", + "cui": "10828004", + "value": "positive for", + "start": 2225, + "end": 2237, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.098865", + "last_modified": "2024-03-21 12:55:47.231737", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1053, + "user": "admin", + "cui": "2603003", + "value": "secondary to", + "start": 2324, + "end": 2336, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.102261", + "last_modified": "2024-03-21 12:55:47.687626", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1054, + "user": "admin", + "cui": "29627003", + "value": "femoral neck", + "start": 2889, + "end": 2901, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.105584", + "last_modified": "2024-03-21 12:55:50.373009", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1055, + "user": "admin", + "cui": "246105001", + "value": "presentation", + "start": 3279, + "end": 3291, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.108841", + "last_modified": "2024-03-21 12:55:52.344676", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1056, + "user": "admin", + "cui": "717351000000103", + "value": "present time", + "start": 3455, + "end": 3467, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.112172", + "last_modified": "2024-03-21 12:55:52.915195", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1057, + "user": "admin", + "cui": "717351000000103", + "value": "present time", + "start": 3685, + "end": 3697, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.115315", + "last_modified": "2024-03-21 12:55:54.094112", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1058, + "user": "admin", + "cui": "246105001", + "value": "presentation", + "start": 3876, + "end": 3888, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.118527", + "last_modified": "2024-03-21 12:55:54.697678", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1059, + "user": "admin", + "cui": "258106000", + "value": "radiographic", + "start": 4669, + "end": 4681, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.121868", + "last_modified": "2024-03-21 12:55:58.945039", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1060, + "user": "admin", + "cui": "76948002", + "value": "severe pain", + "start": 205, + "end": 216, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.125137", + "last_modified": "2024-03-21 12:55:38.090063", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1061, + "user": "admin", + "cui": "76948002", + "value": "severe pain", + "start": 1012, + "end": 1023, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.128385", + "last_modified": "2024-03-21 12:55:40.912311", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1062, + "user": "admin", + "cui": "160573003", + "value": "alcohol use", + "start": 1823, + "end": 1834, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.131342", + "last_modified": "2024-03-21 12:55:44.582822", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1063, + "user": "admin", + "cui": "52101004", + "value": "presence of", + "start": 2107, + "end": 2118, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.134640", + "last_modified": "2024-03-21 12:55:46.453522", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1064, + "user": "admin", + "cui": "76948002", + "value": "severe pain", + "start": 2238, + "end": 2249, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.137698", + "last_modified": "2024-03-21 12:55:47.384464", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1065, + "user": "admin", + "cui": "247311004", + "value": "light touch", + "start": 2520, + "end": 2531, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.140618", + "last_modified": "2024-03-21 12:55:48.764267", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1066, + "user": "admin", + "cui": "386134007", + "value": "significant", + "start": 2699, + "end": 2710, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.143665", + "last_modified": "2024-03-21 12:55:49.392813", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1067, + "user": "admin", + "cui": "52101004", + "value": "presence of", + "start": 2853, + "end": 2864, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.147041", + "last_modified": "2024-03-21 12:55:50.037329", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1068, + "user": "admin", + "cui": "87642003", + "value": "dislocation", + "start": 4164, + "end": 4175, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.150030", + "last_modified": "2024-03-21 12:55:55.475042", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1069, + "user": "admin", + "cui": "110287002", + "value": "discrepancy", + "start": 4220, + "end": 4231, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.153222", + "last_modified": "2024-03-21 12:55:56.099625", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1070, + "user": "admin", + "cui": "66216009", + "value": "understands", + "start": 4470, + "end": 4481, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.156135", + "last_modified": "2024-03-21 12:55:58.033569", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1071, + "user": "admin", + "cui": "224130005", + "value": "lives with", + "start": 1875, + "end": 1885, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.159278", + "last_modified": "2024-03-21 12:55:45.021923", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1072, + "user": "admin", + "cui": "129350004", + "value": "shortening", + "start": 2128, + "end": 2138, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.162414", + "last_modified": "2024-03-21 12:55:46.607237", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1073, + "user": "admin", + "cui": "247348008", + "value": "Tenderness", + "start": 2173, + "end": 2183, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.165320", + "last_modified": "2024-03-21 12:55:46.928828", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1074, + "user": "admin", + "cui": "223482009", + "value": "DISCUSSION", + "start": 3170, + "end": 3180, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.168690", + "last_modified": "2024-03-21 12:55:51.604035", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1075, + "user": "admin", + "cui": "77879006", + "value": "metastatic", + "start": 3720, + "end": 3730, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.171733", + "last_modified": "2024-03-21 12:55:54.245370", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1076, + "user": "admin", + "cui": "31807009", + "value": "persistent", + "start": 4192, + "end": 4202, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.174643", + "last_modified": "2024-03-21 12:55:55.797229", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1077, + "user": "admin", + "cui": "249839000", + "value": "leg length", + "start": 4209, + "end": 4219, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.178107", + "last_modified": "2024-03-21 12:55:55.948159", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1078, + "user": "admin", + "cui": "272148004", + "value": "conversion", + "start": 4309, + "end": 4319, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.181148", + "last_modified": "2024-03-21 12:55:56.674691", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1079, + "user": "admin", + "cui": "33653009", + "value": "anesthesia", + "start": 4405, + "end": 4415, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.184348", + "last_modified": "2024-03-21 12:55:57.429690", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1080, + "user": "admin", + "cui": "40143009", + "value": "functional", + "start": 4551, + "end": 4561, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.187474", + "last_modified": "2024-03-21 12:55:58.489951", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1081, + "user": "admin", + "cui": "246105001", + "value": "presented", + "start": 166, + "end": 175, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.190470", + "last_modified": "2024-03-21 12:56:03.682992", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1082, + "user": "admin", + "cui": "161891005", + "value": "back pain", + "start": 371, + "end": 380, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.193647", + "last_modified": "2024-03-21 12:55:38.642645", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1083, + "user": "admin", + "cui": "723941000000100", + "value": "interview", + "start": 1327, + "end": 1336, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.196745", + "last_modified": "2024-03-21 12:55:42.097455", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1084, + "user": "admin", + "cui": "439401001", + "value": "diagnosed", + "start": 1423, + "end": 1432, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.199922", + "last_modified": "2024-03-21 12:55:42.539081", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1085, + "user": "admin", + "cui": "419076005", + "value": "ALLERGIES", + "start": 1764, + "end": 1773, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.203234", + "last_modified": "2024-03-21 12:55:44.277113", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1086, + "user": "admin", + "cui": "113011001", + "value": "palpation", + "start": 2187, + "end": 2196, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.206512", + "last_modified": "2024-03-21 12:55:47.081732", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1087, + "user": "admin", + "cui": "55919000", + "value": "including", + "start": 2990, + "end": 2999, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.209677", + "last_modified": "2024-03-21 12:55:50.829212", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1088, + "user": "admin", + "cui": "41747008", + "value": "bone scan", + "start": 3499, + "end": 3508, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.213204", + "last_modified": "2024-03-21 12:55:53.069372", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1089, + "user": "admin", + "cui": "55919000", + "value": "including", + "start": 3569, + "end": 3578, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.216251", + "last_modified": "2024-03-21 12:55:53.490362", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1090, + "user": "admin", + "cui": "161891005", + "value": "back pain", + "start": 3650, + "end": 3659, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.219398", + "last_modified": "2024-03-21 12:55:53.778336", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1091, + "user": "admin", + "cui": "40733004", + "value": "infection", + "start": 4118, + "end": 4127, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.222530", + "last_modified": "2024-03-21 12:55:55.155669", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1092, + "user": "admin", + "cui": "71388002", + "value": "procedure", + "start": 4522, + "end": 4531, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.225541", + "last_modified": "2024-03-21 12:55:58.339443", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1093, + "user": "admin", + "cui": "41747008", + "value": "bone scan", + "start": 4651, + "end": 4660, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.228590", + "last_modified": "2024-03-21 12:55:58.791826", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1094, + "user": "admin", + "cui": "260695007", + "value": "clearance", + "start": 4741, + "end": 4750, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.231708", + "last_modified": "2024-03-21 12:55:59.097032", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1095, + "user": "admin", + "cui": "260358002", + "value": "very much", + "start": 4762, + "end": 4771, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.234647", + "last_modified": "2024-03-21 12:55:59.246861", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1096, + "user": "admin", + "cui": "308273005", + "value": "follow up", + "start": 4852, + "end": 4861, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.237946", + "last_modified": "2024-03-21 12:56:00.035649", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1097, + "user": "admin", + "cui": "23056005", + "value": "sciatica", + "start": 145, + "end": 153, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.240895", + "last_modified": "2024-03-21 12:56:03.517373", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1098, + "user": "admin", + "cui": "398092000", + "value": "obtained", + "start": 285, + "end": 293, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.244167", + "last_modified": "2024-03-21 12:55:38.372251", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1099, + "user": "admin", + "cui": "44077006", + "value": "numbness", + "start": 1086, + "end": 1094, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.247369", + "last_modified": "2024-03-21 12:55:41.369851", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1100, + "user": "admin", + "cui": "398092000", + "value": "obtained", + "start": 1186, + "end": 1194, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.250312", + "last_modified": "2024-03-21 12:55:41.942715", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1101, + "user": "admin", + "cui": "23056005", + "value": "Sciatica", + "start": 1657, + "end": 1665, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.253506", + "last_modified": "2024-03-21 12:55:43.549130", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1102, + "user": "admin", + "cui": "2092003", + "value": "melanoma", + "start": 1670, + "end": 1678, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.256512", + "last_modified": "2024-03-21 12:55:43.702524", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1103, + "user": "admin", + "cui": "2092003", + "value": "melanoma", + "start": 1737, + "end": 1745, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.259380", + "last_modified": "2024-03-21 12:55:43.914744", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1104, + "user": "admin", + "cui": "20295000", + "value": "divorced", + "start": 1844, + "end": 1852, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.262942", + "last_modified": "2024-03-21 12:55:44.739134", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1105, + "user": "admin", + "cui": "160499008", + "value": "children", + "start": 1860, + "end": 1868, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.265925", + "last_modified": "2024-03-21 12:55:44.856597", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1106, + "user": "admin", + "cui": "261074009", + "value": "external", + "start": 2143, + "end": 2151, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.269178", + "last_modified": "2024-03-21 12:55:46.774858", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1107, + "user": "admin", + "cui": "40415009", + "value": "proximal", + "start": 2262, + "end": 2270, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.272217", + "last_modified": "2024-03-21 12:55:47.536152", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1108, + "user": "admin", + "cui": "10601006", + "value": "leg pain", + "start": 2344, + "end": 2352, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.275110", + "last_modified": "2024-03-21 12:55:48.007882", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1109, + "user": "admin", + "cui": "260385009", + "value": "negative", + "start": 2496, + "end": 2504, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.278321", + "last_modified": "2024-03-21 12:55:48.613220", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1110, + "user": "admin", + "cui": "398092000", + "value": "obtained", + "start": 2610, + "end": 2618, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.281472", + "last_modified": "2024-03-21 12:55:49.088482", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1111, + "user": "admin", + "cui": "2931005", + "value": "probable", + "start": 3111, + "end": 3119, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.284805", + "last_modified": "2024-03-21 12:55:51.148755", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1112, + "user": "admin", + "cui": "40415009", + "value": "proximal", + "start": 3433, + "end": 3441, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.288110", + "last_modified": "2024-03-21 12:55:52.780751", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1113, + "user": "admin", + "cui": "23056005", + "value": "sciatica", + "start": 3664, + "end": 3672, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.291256", + "last_modified": "2024-03-21 12:55:53.924638", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1114, + "user": "admin", + "cui": "55919000", + "value": "includes", + "start": 4079, + "end": 4087, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.295010", + "last_modified": "2024-03-21 12:55:55.003529", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1115, + "user": "admin", + "cui": "103325001", + "value": "need for", + "start": 4300, + "end": 4308, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.298310", + "last_modified": "2024-03-21 12:55:56.538734", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1116, + "user": "admin", + "cui": "103325001", + "value": "need for", + "start": 4417, + "end": 4425, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.301866", + "last_modified": "2024-03-21 12:55:57.562541", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1117, + "user": "admin", + "cui": "255238004", + "value": "continue", + "start": 4840, + "end": 4848, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.305371", + "last_modified": "2024-03-21 12:55:59.733420", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1118, + "user": "admin", + "cui": "39104002", + "value": "ILLNESS", + "start": 65, + "end": 72, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.308823", + "last_modified": "2024-03-21 12:56:02.998816", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1119, + "user": "admin", + "cui": "116154003", + "value": "patient", + "start": 80, + "end": 87, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.312665", + "last_modified": "2024-03-21 12:56:03.180897", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1120, + "user": "admin", + "cui": "116154003", + "value": "patient", + "start": 303, + "end": 310, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.316076", + "last_modified": "2024-03-21 12:55:38.510359", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1121, + "user": "admin", + "cui": "129006008", + "value": "walking", + "start": 625, + "end": 632, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.319567", + "last_modified": "2024-03-21 12:55:39.464007", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1122, + "user": "admin", + "cui": "116699007", + "value": "prevent", + "start": 748, + "end": 755, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.322938", + "last_modified": "2024-03-21 12:55:39.885595", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1123, + "user": "admin", + "cui": "116154003", + "value": "patient", + "start": 812, + "end": 819, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.326365", + "last_modified": "2024-03-21 12:55:40.037331", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1124, + "user": "admin", + "cui": "68369002", + "value": "brought", + "start": 846, + "end": 853, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.329874", + "last_modified": "2024-03-21 12:55:40.189360", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1125, + "user": "admin", + "cui": "23981006", + "value": "causing", + "start": 955, + "end": 962, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.333429", + "last_modified": "2024-03-21 12:55:40.763274", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1126, + "user": "admin", + "cui": "116154003", + "value": "patient", + "start": 1346, + "end": 1353, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.336844", + "last_modified": "2024-03-21 12:55:42.227521", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1127, + "user": "admin", + "cui": "116154003", + "value": "patient", + "start": 1931, + "end": 1938, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.340251", + "last_modified": "2024-03-21 12:55:45.328785", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1128, + "user": "admin", + "cui": "398166005", + "value": "perform", + "start": 2372, + "end": 2379, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.343624", + "last_modified": "2024-03-21 12:55:48.160076", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1129, + "user": "admin", + "cui": "116154003", + "value": "patient", + "start": 3074, + "end": 3081, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.346962", + "last_modified": "2024-03-21 12:55:50.978229", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1130, + "user": "admin", + "cui": "116154003", + "value": "patient", + "start": 3256, + "end": 3263, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.350238", + "last_modified": "2024-03-21 12:55:52.193153", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1131, + "user": "admin", + "cui": "55919000", + "value": "include", + "start": 3539, + "end": 3546, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.353731", + "last_modified": "2024-03-21 12:55:53.340699", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1132, + "user": "admin", + "cui": "116154003", + "value": "patient", + "start": 4064, + "end": 4071, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.357075", + "last_modified": "2024-03-21 12:55:54.835592", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1133, + "user": "admin", + "cui": "30207005", + "value": "risk of", + "start": 4397, + "end": 4404, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.360390", + "last_modified": "2024-03-21 12:55:57.275999", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1134, + "user": "admin", + "cui": "225466006", + "value": "willing", + "start": 4495, + "end": 4502, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.363759", + "last_modified": "2024-03-21 12:55:58.187194", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1135, + "user": "admin", + "cui": "116154003", + "value": "patient", + "start": 4823, + "end": 4830, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.367017", + "last_modified": "2024-03-21 12:55:59.579610", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1136, + "user": "admin", + "cui": "88952004", + "value": "REASON", + "start": 0, + "end": 6, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.370415", + "last_modified": "2024-03-21 12:56:02.843052", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1137, + "user": "admin", + "cui": "42752001", + "value": "due to", + "start": 198, + "end": 204, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.373696", + "last_modified": "2024-03-21 12:55:37.936895", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1138, + "user": "admin", + "cui": "705406009", + "value": "walker", + "start": 442, + "end": 448, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.376995", + "last_modified": "2024-03-21 12:55:39.065236", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1139, + "user": "admin", + "cui": "42752001", + "value": "due to", + "start": 464, + "end": 470, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.380462", + "last_modified": "2024-03-21 12:55:39.197216", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1140, + "user": "admin", + "cui": "441889009", + "value": "denied", + "start": 882, + "end": 888, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.383608", + "last_modified": "2024-03-21 12:55:40.475925", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1141, + "user": "admin", + "cui": "410677005", + "value": "spread", + "start": 1543, + "end": 1549, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.386962", + "last_modified": "2024-03-21 12:55:42.955388", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1142, + "user": "admin", + "cui": "24484000", + "value": "severe", + "start": 2337, + "end": 2343, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.390130", + "last_modified": "2024-03-21 12:55:47.839329", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1143, + "user": "admin", + "cui": "8499008", + "value": "pulses", + "start": 2422, + "end": 2428, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.393538", + "last_modified": "2024-03-21 12:55:48.442968", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1144, + "user": "admin", + "cui": "11163003", + "value": "intact", + "start": 2535, + "end": 2541, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.396863", + "last_modified": "2024-03-21 12:55:48.932534", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1145, + "user": "admin", + "cui": "85756007", + "value": "tissue", + "start": 2669, + "end": 2675, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.400110", + "last_modified": "2024-03-21 12:55:49.240016", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1146, + "user": "admin", + "cui": "11163003", + "value": "intact", + "start": 2778, + "end": 2784, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.403363", + "last_modified": "2024-03-21 12:55:49.714670", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1147, + "user": "admin", + "cui": "52988006", + "value": "lesion", + "start": 2871, + "end": 2877, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.406496", + "last_modified": "2024-03-21 12:55:50.206541", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1148, + "user": "admin", + "cui": "27582007", + "value": "repeat", + "start": 2976, + "end": 2982, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.409813", + "last_modified": "2024-03-21 12:55:50.694943", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1149, + "user": "admin", + "cui": "246425008", + "value": "Nature", + "start": 3193, + "end": 3199, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.412981", + "last_modified": "2024-03-21 12:55:51.888238", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1150, + "user": "admin", + "cui": "288524001", + "value": "course", + "start": 3204, + "end": 3210, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.416043", + "last_modified": "2024-03-21 12:55:52.043592", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1151, + "user": "admin", + "cui": "27582007", + "value": "repeat", + "start": 3513, + "end": 3519, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.419605", + "last_modified": "2024-03-21 12:55:53.223968", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1152, + "user": "admin", + "cui": "52988006", + "value": "lesion", + "start": 3736, + "end": 3742, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.423023", + "last_modified": "2024-03-21 12:55:54.379681", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1153, + "user": "admin", + "cui": "398092000", + "value": "obtain", + "start": 4640, + "end": 4646, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.426681", + "last_modified": "2024-03-21 12:55:58.642397", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1154, + "user": "admin", + "cui": "36692007", + "value": "known", + "start": 128, + "end": 133, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.430158", + "last_modified": "2024-03-21 12:56:03.346889", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1155, + "user": "admin", + "cui": "258705008", + "value": "weeks", + "start": 413, + "end": 418, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.433456", + "last_modified": "2024-03-21 12:55:38.928075", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1156, + "user": "admin", + "cui": "113276009", + "value": "bowel", + "start": 1143, + "end": 1148, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.436824", + "last_modified": "2024-03-21 12:55:41.656841", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1157, + "user": "admin", + "cui": "258707000", + "value": "years", + "start": 1454, + "end": 1459, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.440099", + "last_modified": "2024-03-21 12:55:42.806259", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1158, + "user": "admin", + "cui": "38000004", + "value": "lymph", + "start": 1578, + "end": 1583, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.443495", + "last_modified": "2024-03-21 12:55:43.264034", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1159, + "user": "admin", + "cui": "732766004", + "value": "5. 2", + "start": 2415, + "end": 2420, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.446899", + "last_modified": "2024-03-21 12:55:48.293097", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1160, + "user": "admin", + "cui": "721963009", + "value": "order", + "start": 2970, + "end": 2975, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.450208", + "last_modified": "2024-03-21 12:55:50.543437", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + }, + { + "id": 1161, + "user": "admin", + "cui": "264499004", + "value": "early", + "start": 4176, + "end": 4181, + "validated": true, + "correct": true, + "deleted": false, + "alternative": false, + "killed": false, + "irrelevant": false, + "create_time": "2024-03-21 12:55:10.453601", + "last_modified": "2024-03-21 12:55:55.645644", + "comment": null, + "manually_created": false, + "acc": 1.0, + "meta_anns": {} + } + ], + "relations": [ + { + "start_entity": 1060, + "start_entity_cui": "76948002", + "start_entity_value": "severe pain", + "start_entity_start_idx": 205, + "start_entity_end_idx": 216, + "end_entity": 1039, + "end_entity_cui": "225728007", + "end_entity_value": "emergency room", + "end_entity_start_idx": 183, + "end_entity_end_idx": 197, + "user": "admin", + "relation": "disease/disability_procedure", + "validated": true + }, + { + "start_entity": 1021, + "start_entity_cui": "32153003", + "start_entity_value": "left lower extremity", + "start_entity_start_idx": 224, + "start_entity_end_idx": 244, + "end_entity": 1039, + "end_entity_cui": "225728007", + "end_entity_value": "emergency room", + "end_entity_start_idx": 183, + "end_entity_end_idx": 197, + "user": "admin", + "relation": "disease/disability_procedure", + "validated": true + }, + { + "start_entity": 1029, + "start_entity_cui": "59282003", + "start_entity_value": "pulmonary embolism", + "start_entity_start_idx": 4377, + "start_entity_end_idx": 4395, + "end_entity": 1044, + "end_entity_cui": "410429000", + "end_entity_value": "cardiac arrest", + "end_entity_start_idx": 4449, + "end_entity_end_idx": 4463, + "user": "admin", + "relation": "non_relation", + "validated": true + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/tests/test_pipe.py b/tests/test_pipe.py index 8ce47cfb5..e626d139c 100644 --- a/tests/test_pipe.py +++ b/tests/test_pipe.py @@ -6,12 +6,14 @@ from medcat.config import Config from medcat.pipe import Pipe from medcat.meta_cat import MetaCAT +from medcat.rel_cat import RelCAT from medcat.preprocessing.taggers import tag_skip_and_punct from medcat.preprocessing.tokenizers import spacy_split_all from medcat.utils.normalizers import BasicSpellChecker, TokenNormalizer from medcat.ner.vocab_based_ner import NER from medcat.linking.context_based_linker import Linker from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT +from medcat.utils.relation_extraction.tokenizer import TokenizerWrapperBERT as RelTokenizerWrapperBERT from transformers import AutoTokenizer @@ -41,7 +43,9 @@ def setUpClass(cls) -> None: cls.linker = Linker(cls.cdb, cls.vocab, cls.config) _tokenizer = TokenizerWrapperBERT(hf_tokenizers=AutoTokenizer.from_pretrained("bert-base-uncased")) + _tokenizer_rel = RelTokenizerWrapperBERT(hf_tokenizers=AutoTokenizer.from_pretrained("bert-base-uncased")) cls.meta_cat = MetaCAT(tokenizer=_tokenizer) + cls.rel_cat = RelCAT(cls.cdb, tokenizer=_tokenizer_rel, init_model=True) cls.text = "stop of CDB - I was running and then Movar Virus attacked and CDb" cls.undertest = Pipe(tokenizer=spacy_split_all, config=cls.config) @@ -56,6 +60,7 @@ def setUp(self) -> None: PipeTests.undertest.force_remove(PipeTests.ner.name) PipeTests.undertest.force_remove(PipeTests.linker.name) PipeTests.undertest.force_remove(PipeTests.meta_cat.name) + PipeTests.undertest.force_remove(PipeTests.rel_cat.name) def test_add_tagger(self): PipeTests.undertest.add_tagger(tagger=tag_skip_and_punct, name=tag_skip_and_punct.name, additional_fields=["is_punct"]) @@ -82,7 +87,12 @@ def test_add_meta_cat(self): PipeTests.undertest.add_meta_cat(PipeTests.meta_cat) self.assertEqual(PipeTests.meta_cat.name, Language.get_factory_meta(PipeTests.meta_cat.name).factory) - + + def test_add_rel_cat(self): + PipeTests.undertest.add_rel_cat(PipeTests.rel_cat) + + self.assertEqual(PipeTests.rel_cat.name, Language.get_factory_meta(PipeTests.rel_cat.name).factory) + def test_stopwords_loading(self): self.assertEqual(PipeTests.undertest._nlp.Defaults.stop_words, PipeTests.config.preprocessing.stopwords) doc = PipeTests.undertest(PipeTests.text) @@ -95,6 +105,7 @@ def test_batch_multi_process(self): PipeTests.undertest.add_ner(PipeTests.ner) PipeTests.undertest.add_linker(PipeTests.linker) PipeTests.undertest.add_meta_cat(PipeTests.meta_cat) + PipeTests.undertest.add_rel_cat(PipeTests.rel_cat) PipeTests.undertest.set_error_handler(_error_handler) docs = list(self.undertest.batch_multi_process([PipeTests.text, PipeTests.text, PipeTests.text], n_process=1, batch_size=1)) @@ -114,6 +125,7 @@ def _generate_texts(texts): PipeTests.undertest.add_ner(PipeTests.ner) PipeTests.undertest.add_linker(PipeTests.linker) PipeTests.undertest.add_meta_cat(PipeTests.meta_cat) + PipeTests.undertest.add_rel_cat(PipeTests.rel_cat) docs = list(self.undertest(_generate_texts([PipeTests.text, None, PipeTests.text]))) @@ -128,6 +140,7 @@ def test_callable_with_single_text(self): PipeTests.undertest.add_ner(PipeTests.ner) PipeTests.undertest.add_linker(PipeTests.linker) PipeTests.undertest.add_meta_cat(PipeTests.meta_cat) + PipeTests.undertest.add_rel_cat(PipeTests.rel_cat) doc = self.undertest(PipeTests.text) @@ -139,6 +152,7 @@ def test_callable_with_multi_texts(self): PipeTests.undertest.add_ner(PipeTests.ner) PipeTests.undertest.add_linker(PipeTests.linker) PipeTests.undertest.add_meta_cat(PipeTests.meta_cat) + PipeTests.undertest.add_rel_cat(PipeTests.rel_cat) docs = list(self.undertest([PipeTests.text, None, PipeTests.text])) diff --git a/tests/test_rel_cat.py b/tests/test_rel_cat.py new file mode 100644 index 000000000..8f6db4261 --- /dev/null +++ b/tests/test_rel_cat.py @@ -0,0 +1,111 @@ +import os +import shutil +import unittest +import json + +from medcat.cdb import CDB +from medcat.config_rel_cat import ConfigRelCAT +from medcat.rel_cat import RelCAT +from medcat.utils.relation_extraction.rel_dataset import RelData +from medcat.utils.relation_extraction.tokenizer import TokenizerWrapperBERT +from medcat.utils.relation_extraction.models import BertModel_RelationExtraction + +from transformers.models.auto.tokenization_auto import AutoTokenizer +from transformers.models.bert.configuration_bert import BertConfig + +import spacy +from spacy.tokens import Span, Doc + +class RelCATTests(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + config = ConfigRelCAT() + config.general.device = "cpu" + config.general.model_name = "bert-base-uncased" + config.train.batch_size = 1 + config.train.nclasses = 3 + config.model.hidden_size= 256 + config.model.model_size = 2304 + + tokenizer = TokenizerWrapperBERT(AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=config.general.model_name, + config=config), add_special_tokens=True) + + SPEC_TAGS = ["[s1]", "[e1]", "[s2]", "[e2]"] + + tokenizer.hf_tokenizers.add_tokens(SPEC_TAGS, special_tokens=True) + config.general.annotation_schema_tag_ids = tokenizer.hf_tokenizers.convert_tokens_to_ids(SPEC_TAGS) + + cls.tmp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp") + os.makedirs(cls.tmp_dir, exist_ok=True) + + cls.save_model_path = os.path.join(cls.tmp_dir, "test_model") + os.makedirs(cls.save_model_path, exist_ok=True) + + cdb = CDB.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb.dat")) + + cls.medcat_export_with_rels_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "medcat_trainer_export_relations.json") + cls.medcat_rels_csv_path_train = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "medcat_rel_train.csv") + cls.medcat_rels_csv_path_test = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "medcat_rel_test.csv") + + cls.mct_file_test = {} + with open(cls.medcat_export_with_rels_path, "r+") as f: + cls.mct_file_test = json.loads(f.read())["projects"][0]["documents"][1] + + cls.config_rel_cat: ConfigRelCAT = config + cls.rel_cat: RelCAT = RelCAT(cdb, tokenizer=tokenizer, config=config, init_model=True) + + cls.rel_cat.model.bert_model.resize_token_embeddings(len(tokenizer.hf_tokenizers)) + + cls.finished = False + cls.tokenizer = tokenizer + + def test_train_csv_no_tags(self) -> None: + self.rel_cat.config.train.epochs = 2 + self.rel_cat.train(train_csv_path=self.medcat_rels_csv_path_train, test_csv_path=self.medcat_rels_csv_path_test, checkpoint_path=self.tmp_dir) + self.rel_cat.save(self.save_model_path) + + def test_train_mctrainer(self) -> None: + self.rel_cat = RelCAT.load(self.save_model_path) + self.rel_cat.config.general.mct_export_create_addl_rels = True + self.rel_cat.config.general.mct_export_max_non_rel_sample_size = 10 + self.rel_cat.config.train.test_size = 0.1 + self.rel_cat.config.train.nclasses = 3 + self.rel_cat.model.relcat_config.train.nclasses = 3 + self.rel_cat.model.bert_model.resize_token_embeddings(len(self.tokenizer.hf_tokenizers)) + + self.rel_cat.train(export_data_path=self.medcat_export_with_rels_path, checkpoint_path=self.tmp_dir) + + def test_train_predict(self) -> None: + Span.set_extension('id', default=0, force=True) + Span.set_extension('cui', default=None, force=True) + Doc.set_extension('ents', default=[], force=True) + Doc.set_extension('relations', default=[], force=True) + nlp = spacy.blank("en") + doc = nlp(self.mct_file_test["text"]) + + for ann in self.mct_file_test["annotations"]: + tkn_idx = [] + for ind, word in enumerate(doc): + end_char = word.idx + len(word.text) + if end_char <= ann['end'] and end_char > ann['start']: + tkn_idx.append(ind) + entity = Span(doc, min(tkn_idx), max(tkn_idx) + 1, label=ann["value"]) + entity._.cui = ann["cui"] + doc._.ents.append(entity) + + self.rel_cat.model.bert_model.resize_token_embeddings(len(self.tokenizer.hf_tokenizers)) + + doc = self.rel_cat(doc) + self.finished = True + + assert len(doc._.relations) > 0 + + def tearDown(self) -> None: + if self.finished: + if os.path.exists(self.tmp_dir): + shutil.rmtree(self.tmp_dir) + +if __name__ == '__main__': + unittest.main()