Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cu 86947ja9y dill old weights #27

Closed
wants to merge 9 commits into from
Binary file added examples/cdb_old_broken_weights_in_config.dat
Binary file not shown.
11 changes: 8 additions & 3 deletions medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def attempt_unpack(cls, zip_path: str) -> str:
def load_model_pack(cls,
zip_path: str,
meta_cat_config_dict: Optional[Dict] = None,
ner_config_dict: Optional[Dict] = None,
load_meta_models: bool = True,
load_addl_ner: bool = True) -> "CAT":
"""Load everything within the 'model pack', i.e. the CDB, config, vocab and any MetaCAT models
Expand All @@ -346,6 +347,10 @@ def load_model_pack(cls,
A config dict that will overwrite existing configs in meta_cat.
e.g. meta_cat_config_dict = {'general': {'device': 'cpu'}}.
Defaults to None.
ner_config_dict (Optional[Dict]):
A config dict that will overwrite existing configs in transformers ner.
e.g. ner_config_dict = {'general': {'chunking_overlap_window': 6}.
Defaults to None.
load_meta_models (bool):
Whether to load MetaCAT models if present (Default value True).
load_addl_ner (bool):
Expand Down Expand Up @@ -381,15 +386,15 @@ def load_model_pack(cls,
else:
vocab = None

# Find meta models in the model_pack
# Find ner models in the model_pack
trf_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('trf_')] if load_addl_ner else []
addl_ner = []
for trf_path in trf_paths:
trf = TransformersNER.load(save_dir_path=trf_path)
trf = TransformersNER.load(save_dir_path=trf_path,config_dict=ner_config_dict)
trf.cdb = cdb # Set the cat.cdb to be the CDB of the TRF model
addl_ner.append(trf)

# Find meta models in the model_pack
# Find metacat models in the model_pack
meta_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('meta_')] if load_meta_models else []
meta_cats = []
for meta_path in meta_paths:
Expand Down
42 changes: 40 additions & 2 deletions medcat/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from datetime import datetime
from pydantic import BaseModel, Extra, ValidationError
from pydantic.dataclasses import Any, Callable, Dict, Optional, Union
from pydantic.fields import ModelField
from typing import List, Set, Tuple, cast
from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union
from multiprocessing import cpu_count
import logging
import jsonpickle
Expand Down Expand Up @@ -33,6 +32,45 @@ def __getitem__(self, arg: str) -> Any:
except AttributeError as e:
raise KeyError from e

def _attempt_fix_weighted_average_function(self, waf):
try:
waf(1)
return waf
except TypeError:
# this means we need to apply the fix
return self._fix_waf(waf)

def _fix_waf(self, waf):
if not str(waf.func).startswith("<function weighted_average at "):
logger.warning("It seems the value of "
"`config.linking.weighted_average_function` "
"in the config does not work properly. While we "
"are aware of the issue and know how to fix the "
"default value, doing so for arbitrary methods "
"is not trivial. This is the case we've found. "
"The method does not seem to work properly, but "
"it has a non-default value so we are unable to "
"perform a fix for it. This is more than likely "
"to cause the an error when running the pipe. "
"To fix this, change the value of "
"`config.linking.weighted_average_function` "
"manually before using the CAT instance")
return waf
logging.warning("Fixing config.linking.weighted_average_function "
"since the value saved does not work properly. "
"This is usually due to having loaded a model "
"that was originally saved in older versions of "
"python and thus something has gone wrong when "
"loading the method. This fix should not affect "
"usage, but if you wish to avoid the warning "
"you may want to save the model pack again using "
"a newer version of python (3.11 or later).")
return partial(weighted_average, *waf.args, **waf.keywords)

def __setattr__(self, arg: str, val) -> None:
if isinstance(self, Linking) and arg == "weighted_average_function":
val = self._attempt_fix_weighted_average_function(val)
super().__setattr__(arg, val)

def __setitem__(self, arg: str, val) -> None:
setattr(self, arg, val)
Expand Down
2 changes: 2 additions & 0 deletions medcat/config_transformers_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class General(MixingConfig, BaseModel):
"""How many characters are piped at once into the meta_cat class"""
ner_aggregation_strategy: str = 'simple'
"""Agg strategy for HF pipeline for NER"""
chunking_overlap_window: Optional[int] = 5
"""Size of the overlap window used for chunking"""
test_size: float = 0.2
last_train_on: Optional[int] = None
verbose_metrics: bool = False
Expand Down
6 changes: 4 additions & 2 deletions medcat/ner/transformers_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ def __init__(self, cdb, config: Optional[ConfigTransformersNER] = None,
else:
self.training_arguments = training_arguments


def create_eval_pipeline(self):
self.ner_pipe = pipeline(model=self.model, task="ner", tokenizer=self.tokenizer.hf_tokenizer)

if self.config.general['chunking_overlap_window'] is None:
logger.warning("Chunking overlap window attribute in the config is set to None, hence chunking is disabled. Be cautious, PII data MAY BE REVEALED. To enable chunking, set the value to 0 or above.")
self.ner_pipe = pipeline(model=self.model, task="ner", tokenizer=self.tokenizer.hf_tokenizer,stride=self.config.general['chunking_overlap_window'])
if not hasattr(self.ner_pipe.tokenizer, '_in_target_context_manager'):
# NOTE: this will fix the DeID model(s) created before medcat 1.9.3
# though this fix may very well be unstable
Expand Down
30 changes: 27 additions & 3 deletions medcat/utils/ner/deid.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,18 @@
- config
- cdb
"""
from typing import Union, Tuple, Any, List, Iterable, Optional
from typing import Union, Tuple, Any, List, Iterable, Optional, Dict
import logging

from medcat.cat import CAT
from medcat.utils.ner.model import NerModel

from medcat.utils.ner.helpers import _deid_text as deid_text, replace_entities_in_text


logger = logging.getLogger(__name__)


class DeIdModel(NerModel):
"""The DeID model.

Expand Down Expand Up @@ -93,6 +97,25 @@ def deid_multi_texts(self,
Returns:
List[str]: List of deidentified documents.
"""
# NOTE: we assume we're using the 1st (and generally only)
# additional NER model.
# the same assumption is made in the `train` method
chunking_overlap_window = self.cat._addl_ner[0].config.general.chunking_overlap_window
if chunking_overlap_window is not None:
logger.warning("Chunking overlap window has been set to %s. "
"This may cause multiprocessing to stall in certain"
"environments and/or situations and has not been"
"fully tested.",
chunking_overlap_window)
logger.warning("If the following hangs forever (i.e doesn't finish) "
"but you still wish to run on multiple processes you can set "
"`cat._addl_ner[0].config.general.chunking_overlap_window = None` "
"and then either a) save the model on disk and load it back up, or "
" b) call `cat._addl_ner[0].create_eval_pipeline()` to recreate the pipe. "
"However, this will remove chunking from the input text, which means "
"only the first 512 tokens will be recognised and thus only the "
"first part of longer documents (those with more than 512) tokens"
"will be deidentified. ")
entities = self.cat.get_entities_multi_texts(texts, addl_info=addl_info,
n_process=n_process, batch_size=batch_size)
out = []
Expand All @@ -110,7 +133,7 @@ def deid_multi_texts(self,
return out

@classmethod
def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel':
def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) -> 'DeIdModel':
"""Load DeId model from model pack.

The method first loads the CAT instance.
Expand All @@ -119,6 +142,7 @@ def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel':
valid DeId model.

Args:
config: Config for DeId model pack (primarily for stride of overlap window)
model_pack_path (str): The model pack path.

Raises:
Expand All @@ -127,7 +151,7 @@ def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel':
Returns:
DeIdModel: The resulting DeI model.
"""
ner_model = NerModel.load_model_pack(model_pack_path)
ner_model = NerModel.load_model_pack(model_pack_path,config=config)
cat = ner_model.cat
if not cls._is_deid_model(cat):
raise ValueError(
Expand Down
7 changes: 4 additions & 3 deletions medcat/utils/ner/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Tuple, Union, Optional
from typing import Any, List, Tuple, Union, Optional, Dict

from spacy.tokens import Doc

Expand Down Expand Up @@ -94,16 +94,17 @@ def create(cls, ner: Union[TransformersNER, List[TransformersNER]]) -> 'NerModel
return cls(cat)

@classmethod
def load_model_pack(cls, model_pack_path: str) -> 'NerModel':
def load_model_pack(cls, model_pack_path: str,config: Optional[Dict] = None) -> 'NerModel':
"""Load NER model from model pack.

The method first wraps the loaded CAT instance.

Args:
config: Config for DeId model pack (primarily for stride of overlap window)
model_pack_path (str): The model pack path.

Returns:
NerModel: The resulting DeI model.
"""
cat = CAT.load_model_pack(model_pack_path)
cat = CAT.load_model_pack(model_pack_path,ner_config_dict=config)
return cls(cat)
37 changes: 36 additions & 1 deletion medcat/utils/preprocess_snomed.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,32 @@ def get_all_children(sctid, pt2ch):
return result


def get_direct_refset_mapping(in_dict: dict) -> dict:
"""This method uses the output from Snomed.map_snomed2icd10 or
Snomed.map_snomed2opcs4 and removes the metadata and maps each
SNOMED CUI to the prioritised list of the target ontology CUIs.

The input dict is expected to be in the following format:
- Keys are SnomedCT CUIs
- The values are lists of dictionaries, each list item (at least)
- Has a key 'code' that specifies the target onotlogy CUI
- Has a key 'mapPriority' that specifies the priority

Args:
in_dict (dict): The input dict.

Returns:
dict: The map from Snomed CUI to list of priorities list of target ontology CUIs.
"""
ret_dict = dict()
for k, vals in in_dict.items():
# sort such that highest priority values are first
svals = sorted(vals, key=lambda el: el['mapPriority'], reverse=True)
# only keep the code / CUI
ret_dict[k] = [v['code'] for v in svals]
return ret_dict


class Snomed:
"""
Pre-process SNOMED CT release files.
Expand All @@ -53,6 +79,15 @@ def __init__(self, data_path, uk_ext=False, uk_drug_ext=False):
self.release = data_path[-16:-8]
self.uk_ext = uk_ext
self.uk_drug_ext = uk_drug_ext
self.opcs_refset_id = "1126441000000105"
if ((self.uk_ext or self.uk_drug_ext) and
# using lexicographical comparison below
# e.g "20240101" > "20231122" results in True
# yet "20231121" > "20231122" reults in False
len(self.release) == len("20231122") and self.release >= "20231122"):
# NOTE for UK extensions starting from 20231122 the
# OPCS4 refset ID seems to be different
self.opcs_refset_id = '1382401000000109'

def to_concept_df(self):
"""
Expand Down Expand Up @@ -398,7 +433,7 @@ def _map_snomed2refset(self):
mapping_df = pd.concat(dfs2merge)
del dfs2merge
if self.uk_ext or self.uk_drug_ext:
opcs_df = mapping_df[mapping_df['refsetId'] == '1126441000000105']
opcs_df = mapping_df[mapping_df['refsetId'] == self.opcs_refset_id]
icd10_df = mapping_df[mapping_df['refsetId']
== '999002271000000101']
return icd10_df, opcs_df
Expand Down
Loading
Loading