Skip to content

Commit

Permalink
CU-8693qx9yp Deid chunking - hugging face pipeline approach (#405)
Browse files Browse the repository at this point in the history
* Pushing chunking update

* Update transformers_ner.py

* Pushing update to config

Added NER config in cat load function

* Update cat.py

* Updating chunking overlap

* CU-8693qx9yp: Add warning for deid multiprocessing with (potentially) non-functioning chunking window

* CU-8693qx9yp: Fix linting issue

---------

Co-authored-by: mart-r <mart.ratas@gmail.com>
  • Loading branch information
shubham-s-agarwal and mart-r authored Feb 28, 2024
1 parent 02afddb commit 67f1126
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 12 deletions.
11 changes: 8 additions & 3 deletions medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def attempt_unpack(cls, zip_path: str) -> str:
def load_model_pack(cls,
zip_path: str,
meta_cat_config_dict: Optional[Dict] = None,
ner_config_dict: Optional[Dict] = None,
load_meta_models: bool = True,
load_addl_ner: bool = True) -> "CAT":
"""Load everything within the 'model pack', i.e. the CDB, config, vocab and any MetaCAT models
Expand All @@ -346,6 +347,10 @@ def load_model_pack(cls,
A config dict that will overwrite existing configs in meta_cat.
e.g. meta_cat_config_dict = {'general': {'device': 'cpu'}}.
Defaults to None.
ner_config_dict (Optional[Dict]):
A config dict that will overwrite existing configs in transformers ner.
e.g. ner_config_dict = {'general': {'chunking_overlap_window': 6}.
Defaults to None.
load_meta_models (bool):
Whether to load MetaCAT models if present (Default value True).
load_addl_ner (bool):
Expand Down Expand Up @@ -381,15 +386,15 @@ def load_model_pack(cls,
else:
vocab = None

# Find meta models in the model_pack
# Find ner models in the model_pack
trf_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('trf_')] if load_addl_ner else []
addl_ner = []
for trf_path in trf_paths:
trf = TransformersNER.load(save_dir_path=trf_path)
trf = TransformersNER.load(save_dir_path=trf_path,config_dict=ner_config_dict)
trf.cdb = cdb # Set the cat.cdb to be the CDB of the TRF model
addl_ner.append(trf)

# Find meta models in the model_pack
# Find metacat models in the model_pack
meta_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('meta_')] if load_meta_models else []
meta_cats = []
for meta_path in meta_paths:
Expand Down
2 changes: 2 additions & 0 deletions medcat/config_transformers_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class General(MixingConfig, BaseModel):
"""How many characters are piped at once into the meta_cat class"""
ner_aggregation_strategy: str = 'simple'
"""Agg strategy for HF pipeline for NER"""
chunking_overlap_window: Optional[int] = 5
"""Size of the overlap window used for chunking"""
test_size: float = 0.2
last_train_on: Optional[int] = None
verbose_metrics: bool = False
Expand Down
6 changes: 4 additions & 2 deletions medcat/ner/transformers_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ def __init__(self, cdb, config: Optional[ConfigTransformersNER] = None,
else:
self.training_arguments = training_arguments


def create_eval_pipeline(self):
self.ner_pipe = pipeline(model=self.model, task="ner", tokenizer=self.tokenizer.hf_tokenizer)

if self.config.general['chunking_overlap_window'] is None:
logger.warning("Chunking overlap window attribute in the config is set to None, hence chunking is disabled. Be cautious, PII data MAY BE REVEALED. To enable chunking, set the value to 0 or above.")
self.ner_pipe = pipeline(model=self.model, task="ner", tokenizer=self.tokenizer.hf_tokenizer,stride=self.config.general['chunking_overlap_window'])
if not hasattr(self.ner_pipe.tokenizer, '_in_target_context_manager'):
# NOTE: this will fix the DeID model(s) created before medcat 1.9.3
# though this fix may very well be unstable
Expand Down
30 changes: 27 additions & 3 deletions medcat/utils/ner/deid.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,18 @@
- config
- cdb
"""
from typing import Union, Tuple, Any, List, Iterable, Optional
from typing import Union, Tuple, Any, List, Iterable, Optional, Dict
import logging

from medcat.cat import CAT
from medcat.utils.ner.model import NerModel

from medcat.utils.ner.helpers import _deid_text as deid_text, replace_entities_in_text


logger = logging.getLogger(__name__)


class DeIdModel(NerModel):
"""The DeID model.
Expand Down Expand Up @@ -93,6 +97,25 @@ def deid_multi_texts(self,
Returns:
List[str]: List of deidentified documents.
"""
# NOTE: we assume we're using the 1st (and generally only)
# additional NER model.
# the same assumption is made in the `train` method
chunking_overlap_window = self.cat._addl_ner[0].config.general.chunking_overlap_window
if chunking_overlap_window is not None:
logger.warning("Chunking overlap window has been set to %s. "
"This may cause multiprocessing to stall in certain"
"environments and/or situations and has not been"
"fully tested.",
chunking_overlap_window)
logger.warning("If the following hangs forever (i.e doesn't finish) "
"but you still wish to run on multiple processes you can set "
"`cat._addl_ner[0].config.general.chunking_overlap_window = None` "
"and then either a) save the model on disk and load it back up, or "
" b) call `cat._addl_ner[0].create_eval_pipeline()` to recreate the pipe. "
"However, this will remove chunking from the input text, which means "
"only the first 512 tokens will be recognised and thus only the "
"first part of longer documents (those with more than 512) tokens"
"will be deidentified. ")
entities = self.cat.get_entities_multi_texts(texts, addl_info=addl_info,
n_process=n_process, batch_size=batch_size)
out = []
Expand All @@ -110,7 +133,7 @@ def deid_multi_texts(self,
return out

@classmethod
def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel':
def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) -> 'DeIdModel':
"""Load DeId model from model pack.
The method first loads the CAT instance.
Expand All @@ -119,6 +142,7 @@ def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel':
valid DeId model.
Args:
config: Config for DeId model pack (primarily for stride of overlap window)
model_pack_path (str): The model pack path.
Raises:
Expand All @@ -127,7 +151,7 @@ def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel':
Returns:
DeIdModel: The resulting DeI model.
"""
ner_model = NerModel.load_model_pack(model_pack_path)
ner_model = NerModel.load_model_pack(model_pack_path,config=config)
cat = ner_model.cat
if not cls._is_deid_model(cat):
raise ValueError(
Expand Down
7 changes: 4 additions & 3 deletions medcat/utils/ner/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Tuple, Union, Optional
from typing import Any, List, Tuple, Union, Optional, Dict

from spacy.tokens import Doc

Expand Down Expand Up @@ -94,16 +94,17 @@ def create(cls, ner: Union[TransformersNER, List[TransformersNER]]) -> 'NerModel
return cls(cat)

@classmethod
def load_model_pack(cls, model_pack_path: str) -> 'NerModel':
def load_model_pack(cls, model_pack_path: str,config: Optional[Dict] = None) -> 'NerModel':
"""Load NER model from model pack.
The method first wraps the loaded CAT instance.
Args:
config: Config for DeId model pack (primarily for stride of overlap window)
model_pack_path (str): The model pack path.
Returns:
NerModel: The resulting DeI model.
"""
cat = CAT.load_model_pack(model_pack_path)
cat = CAT.load_model_pack(model_pack_path,ner_config_dict=config)
return cls(cat)
2 changes: 1 addition & 1 deletion tests/utils/ner/test_deid.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def test_can_create_model(self):
deid_model = deid.DeIdModel.create(ner)
self.assertIsNotNone(deid_model)


def _add_model(cls):
cdb = make_or_update_cdb(TRAIN_DATA)
config = transformers_ner.ConfigTransformersNER()
config.general['test_size'] = 0.1 # Usually set this to 0.1-0.2
config.general['chunking_overlap_window'] = None
cls.ner = transformers_ner.TransformersNER(cdb=cdb, config=config)
cls.ner.training_arguments.num_train_epochs = 1 # Use 5-10 normally
# As we are NOT training on a GPU that can, we'll set it to 1
Expand Down

0 comments on commit 67f1126

Please sign in to comment.