Merge pull request #419 from CogStack/metacat_bert
Adding Bert-style model for MetaCAT
shubham-s-agarwal authored May 16, 2024
2 parents 1d78bd0 + 2657515 commit fbe9745
Showing 8 changed files with 562 additions and 187 deletions.
19 changes: 17 additions & 2 deletions medcat/config_meta_cat.py
@@ -1,5 +1,4 @@
from typing import Dict, Any

from medcat.config import MixingConfig, BaseModel, Optional, Extra


@@ -49,10 +48,20 @@ class Config:
class Model(MixingConfig, BaseModel):
"""The model part of the metaCAT config"""
model_name: str = 'lstm'
"""NOTE: When changing model, make sure to change the tokenizer as well"""
model_variant: str = 'bert-base-uncased'
model_freeze_layers: bool = True
num_layers: int = 2
input_size: int = 300
hidden_size: int = 300
dropout: float = 0.5
phase_number: int = 0
"""Indicates whether or not two phase learning is being performed.
1: Phase 1 - Train model on undersampled data
2: Phase 2 - Continue training on full data
0: None - 2 phase learning is not performed"""
category_undersample: str = ''
model_architecture_config: Dict = {'fc2': True, 'fc3': False, 'lr_scheduler': True}
num_directions: int = 2
"""2 - bidirectional model, 1 - unidirectional"""
nclasses: int = 2
@@ -61,7 +70,7 @@ class Model(MixingConfig, BaseModel):
emb_grad: bool = True
"""If True the embeddings will also be trained"""
ignore_cpos: bool = False
"""If set to True center positions will be ignored when calculating represenation"""
"""If set to True center positions will be ignored when calculating representation"""

class Config:
extra = Extra.allow
@@ -77,6 +86,8 @@ class Train(MixingConfig, BaseModel):
shuffle_data: bool = True
"""Used only during training, if set the dataset will be shuffled before train/test split"""
class_weights: Optional[Any] = None
compute_class_weights: bool = False
"""If true and if class weights are not provided, the class weights will be calculated based on the data"""
score_average: str = 'weighted'
"""What to use for averaging F1/P/R across labels"""
prerequisites: dict = {}
@@ -88,6 +99,10 @@ class Train(MixingConfig, BaseModel):
"""When was the last training run"""
metric: Dict[str, str] = {'base': 'weighted avg', 'score': 'f1-score'}
"""What metric should be used for choosing the best model"""
loss_funct: str = 'cross_entropy'
"""Loss function for the model"""
gamma: int = 2
"""Focal Loss - how much the loss focuses on hard-to-classify examples."""

class Config:
extra = Extra.allow
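
Taken together, the new Model and Train fields make two-phase learning and class-imbalance handling configurable without code changes. A minimal sketch of setting them (values are illustrative; treating 'focal_loss' as the alternative to 'cross_entropy' is an assumption based on the gamma field):

from medcat.config_meta_cat import ConfigMetaCAT

config = ConfigMetaCAT()
config.model['model_name'] = 'bert'                  # NOTE: requires the matching BERT tokenizer
config.model['model_variant'] = 'bert-base-uncased'  # HF variant to pull weights/tokenizer from
config.model['model_freeze_layers'] = False          # False -> LoRA adapters are attached instead
config.model['phase_number'] = 1                     # phase 1: train on undersampled data first
config.model['category_undersample'] = 'Status'      # hypothetical label used as the undersampling target
config.train['compute_class_weights'] = True         # derive class weights from the data
config.train['loss_funct'] = 'cross_entropy'         # assumption: 'focal_loss' is the other option
config.train['gamma'] = 2                            # focal-loss focusing parameter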
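
For reference, focal loss scales the per-example cross-entropy by (1 - p_t)^gamma, so well-classified examples (p_t near 1) contribute little and training concentrates on hard ones. A self-contained PyTorch sketch of the idea, not the repository's implementation:

import torch
import torch.nn.functional as F

def focal_loss(logits: torch.Tensor, targets: torch.Tensor, gamma: int = 2) -> torch.Tensor:
    # log p_t for the true class of each example
    log_pt = F.log_softmax(logits, dim=-1).gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    # FL(p_t) = -(1 - p_t)^gamma * log(p_t); gamma = 0 recovers plain cross-entropy
    return (-(1 - pt) ** gamma * log_pt).mean()
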
123 changes: 92 additions & 31 deletions medcat/meta_cat.py
@@ -11,18 +11,18 @@
from medcat.utils.hasher import Hasher
from medcat.config_meta_cat import ConfigMetaCAT
from medcat.utils.meta_cat.ml_utils import predict, train_model, set_all_seeds, eval_model
from medcat.utils.meta_cat.data_utils import prepare_from_json, encode_category_values
from medcat.utils.meta_cat.data_utils import prepare_from_json, encode_category_values, prepare_for_oversampled_data
from medcat.pipeline.pipe_runner import PipeRunner
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase
from medcat.utils.meta_cat.data_utils import Doc as FakeDoc
from medcat.utils.decorators import deprecated
from peft import get_peft_model, LoraConfig, TaskType

# It should be safe to do this always, as all other multiprocessing
# will be finished before data comes to meta_cat
os.environ["TOKENIZERS_PARALLELISM"] = "true"


logger = logging.getLogger(__name__) # separate logger from the package-level one
logger = logging.getLogger(__name__) # separate logger from the package-level one


class MetaCAT(PipeRunner):
@@ -77,7 +77,7 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
The embedding tensor
Raises:
ValueError: If the meta model is not LSTM
ValueError: If the meta model is not LSTM or BERT
Returns:
nn.Module:
@@ -86,7 +86,22 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
config = self.config
if config.model['model_name'] == 'lstm':
from medcat.utils.meta_cat.models import LSTM
model = LSTM(embeddings, config)
model: nn.Module = LSTM(embeddings, config)
logger.info("LSTM model used for classification")

elif config.model['model_name'] == 'bert':
from medcat.utils.meta_cat.models import BertForMetaAnnotation
model = BertForMetaAnnotation(config)

if not config.model.model_freeze_layers:
peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=16,
target_modules=["query", "value"], lora_dropout=0.2)

model = get_peft_model(model, peft_config)
# model.print_trainable_parameters()

logger.info("BERT model used for classification")

else:
raise ValueError("Unknown model name %s" % config.model['model_name'])

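When model_freeze_layers is False, the BERT classifier is wrapped with LoRA adapters via peft, so only the low-rank weights injected into the attention query/value projections are trained. The same pattern on a stock Hugging Face classifier (model name and label count illustrative):

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSequenceClassification

base = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False,
                         r=8, lora_alpha=16,                 # adapter rank and scaling
                         target_modules=["query", "value"],  # inject into attention Q/V projections
                         lora_dropout=0.2)
model = get_peft_model(base, peft_config)
model.print_trainable_parameters()  # adapters are a small fraction of the total weights
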
@@ -107,7 +122,7 @@ def get_hash(self) -> str:
return hasher.hexdigest()

@deprecated(message="Use `train_from_json` or `train_raw` instead")
def train(self, json_path: Union[str, list], save_dir_path: Optional[str] = None) -> Dict:
def train(self, json_path: Union[str, list], save_dir_path: Optional[str] = None, data_oversampled: Optional[list] = None) -> Dict:
"""Train or continue training a model give a json_path containing a MedCATtrainer export. It will
continue training if an existing model is loaded or start new training if the model is blank/new.
@@ -117,13 +132,16 @@ def train(self, json_path: Union[str, list], save_dir_path: Optional[str] = None
save_dir_path (Optional[str]):
In case we have auto_save_model set (meaning the best model will be saved during training)
we need to set a save path. Defaults to `None`.
data_oversampled (Optional[list]):
If oversampling is performed, the oversampled data is passed via this parameter.
Returns:
Dict: The resulting report.
"""
return self.train_from_json(json_path, save_dir_path)
return self.train_from_json(json_path, save_dir_path, data_oversampled=data_oversampled)

def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[str] = None) -> Dict:
def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[str] = None,
data_oversampled: Optional[list] = None) -> Dict:
"""Train or continue training a model give a json_path containing a MedCATtrainer export. It will
continue training if an existing model is loaded or start new training if the model is blank/new.
@@ -133,6 +151,8 @@ def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[s
save_dir_path (Optional[str]):
In case we have auto_save_model set (meaning the best model will be saved during training)
we need to set a save path. Defaults to `None`.
data_oversampled (Optional[list]):
If oversampling is performed, the oversampled data is passed via this parameter.
Returns:
Dict: The resulting report.
@@ -157,9 +177,9 @@ def merge_data_loaded(base, other):
for path in json_path:
with open(path, 'r') as f:
data_loaded = merge_data_loaded(data_loaded, json.load(f))
return self.train_raw(data_loaded, save_dir_path)
return self.train_raw(data_loaded, save_dir_path, data_oversampled=data_oversampled)

def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> Dict:
def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data_oversampled: Optional[list] = None) -> Dict:
"""Train or continue training a model given raw data. It will
continue training if an existing model is loaded or start new training if the model is blank/new.
@@ -187,13 +207,19 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D
save_dir_path (Optional[str]):
In case we have auto_save_model set (meaning the best model will be saved during training)
we need to set a save path. Defaults to `None`.
data_oversampled (Optional[list]):
If oversampling is performed, the oversampled data is passed via this parameter.
Expected format: [[['text','of','the','document'], [index of medical entity], "label"],
[['text','of','the','document'], [index of medical entity], "label"]]
Returns:
Dict: The resulting report.
Raises:
Exception: If no save path is specified, or category name not in data.
AssertionError: If no tokenizer is set
FileNotFoundError: If phase_number is set to 2 and model.dat file is not found
KeyError: If phase_number is set to 2 and model.dat file contains mismatched architecture
"""
g_config = self.config.general
t_config = self.config.train
@@ -212,32 +238,60 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D
replace_center=g_config['replace_center'], prerequisites=t_config['prerequisites'],
lowercase=g_config['lowercase'])

# Check is the name there
# Check if the name is present
category_name = g_config['category_name']
if category_name not in data:
raise Exception(
"The category name does not exist in this json file. You've provided '{}', while the possible options are: {}".format(
category_name, " | ".join(list(data.keys()))))

data = data[category_name]
if data_oversampled:
data_sampled = prepare_for_oversampled_data(data_oversampled, self.tokenizer)
data = data + data_sampled

category_value2id = g_config['category_value2id']
if not category_value2id:
# Encode the category values
data, category_value2id = encode_category_values(data)
data_undersampled, full_data, category_value2id = encode_category_values(data,
category_undersample=self.config.model.category_undersample)
g_config['category_value2id'] = category_value2id
else:
# We already have everything, just get the data
data, _ = encode_category_values(data, existing_category_value2id=category_value2id)

data_undersampled, full_data, category_value2id = encode_category_values(data,
existing_category_value2id=category_value2id,
category_undersample=self.config.model.category_undersample)
g_config['category_value2id'] = category_value2id
# Make sure the config number of classes is the same as the one found in the data
if len(category_value2id) != self.config.model['nclasses']:
logger.warning(
"The number of classes set in the config is not the same as the one found in the data: {} vs {}".format(
self.config.model['nclasses'], len(category_value2id)))
logger.warning("Auto-setting the nclasses value in config and rebuilding the model.")
self.config.model['nclasses'] = len(category_value2id)
self.model = self.get_model(embeddings=self.embeddings)

if self.config.model.phase_number == 2 and save_dir_path is not None:
model_save_path = os.path.join(save_dir_path, 'model.dat')
device = torch.device(g_config['device'])
try:
self.model.load_state_dict(torch.load(model_save_path, map_location=device))
logger.info("Model state loaded from dict for 2 phase learning")

except FileNotFoundError:
raise FileNotFoundError(f"\nError: Model file not found at path: {model_save_path}\nPlease run phase 1 training and then run phase 2.")

except KeyError:
raise KeyError("\nError: Missing key in loaded state dictionary. \nThis might be due to a mismatch between the model architecture and the saved state.")

except Exception as e:
raise Exception(f"\nError: Model state cannot be loaded from dict. {e}")

data = full_data
if self.config.model.phase_number == 1:
data = data_undersampled
if not t_config['auto_save_model']:
logger.info("For phase 1, model state has to be saved. Saving model...")
t_config['auto_save_model'] = True

report = train_model(self.model, data=data, config=self.config, save_dir_path=save_dir_path)

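With these changes, two-phase learning becomes a two-step workflow: phase 1 trains on the undersampled split (force-enabling auto_save_model so model.dat is written), and phase 2 reloads that state from save_dir_path and continues on the full data. A hedged usage sketch, assuming a previously saved MetaCAT directory and a MedCATtrainer export (paths illustrative):

from medcat.meta_cat import MetaCAT

meta_cat = MetaCAT.load('meta_model')

# Phase 1: train on the undersampled data; model.dat is saved to save_dir_path
meta_cat.config.model['phase_number'] = 1
meta_cat.train_from_json('trainer_export.json', save_dir_path='meta_model')

# Phase 2: reload the saved state and continue training on the full data
meta_cat.config.model['phase_number'] = 2
report = meta_cat.train_from_json('trainer_export.json', save_dir_path='meta_model')

# Pre-tokenized oversampled examples can also be passed alongside the export:
data_oversampled = [[['some', 'tokens'], [3], 'label-a'],
                    [['more', 'tokens'], [0], 'label-b']]
meta_cat.train_from_json('trainer_export.json', save_dir_path='meta_model',
                         data_oversampled=data_oversampled)
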
@@ -293,7 +347,7 @@ def eval(self, json_path: str) -> Dict:

# We already have everything, just get the data
category_value2id = g_config['category_value2id']
data, _ = encode_category_values(data, existing_category_value2id=category_value2id)
data, _, _ = encode_category_values(data, existing_category_value2id=category_value2id)

# Run evaluation
assert self.tokenizer is not None
@@ -317,8 +371,8 @@ def save(self, save_dir_path: str) -> None:
# Save tokenizer
assert self.tokenizer is not None
self.tokenizer.save(save_dir_path)

# Save config

self.config.save(os.path.join(save_dir_path, 'config.json'))

# Save the model
@@ -347,7 +401,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA
# Load config
config = cast(ConfigMetaCAT, ConfigMetaCAT.load(os.path.join(save_dir_path, 'config.json')))

# Overwrite loaded paramters with something new
# Overwrite loaded parameters with something new
if config_dict is not None:
config.merge_config(config_dict)

Expand All @@ -358,7 +412,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA
tokenizer = TokenizerWrapperBPE.load(save_dir_path)
elif config.general['tokenizer_name'] == 'bert-tokenizer':
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT
tokenizer = TokenizerWrapperBERT.load(save_dir_path)
tokenizer = TokenizerWrapperBERT.load(save_dir_path, config.model['model_variant'])

# Create meta_cat
meta_cat = cls(tokenizer=tokenizer, embeddings=None, config=config)
@@ -380,7 +434,8 @@ def get_ents(self, doc: Doc) -> Iterable[Span]:
try:
return doc.spans[spangroup_name]
except KeyError:
raise Exception(f"Configuration error MetaCAT was configured to set meta_anns on {spangroup_name} but this spangroup was not set on the doc.")
raise Exception(
f"Configuration error: MetaCAT was configured to set meta_anns on {spangroup_name} but this spangroup was not set on the doc.")

# Should we annotate overlapping entities
if self.config.general['annotate_overlapping']:
@@ -421,18 +476,26 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe
start = ent.start_char
end = ent.end_char

ind = 0
# Start where the last ent was found, cannot be before it as we've sorted
# Updated implementation to extract all the tokens for the medical entity (rather than just the first one)
ctoken_idx = []
for ind, pair in enumerate(offset_mapping[last_ind:]):
if start >= pair[0] and start < pair[1]:
break
ind = last_ind + ind # If we did not start from 0 in the for loop
last_ind = ind
# Check if we've reached the start of the entity
if start <= pair[0] or start <= pair[1]:
if end <= pair[1]:
ctoken_idx.append(ind) # End reached
break
else:
ctoken_idx.append(ind) # Keep going

# Start where the last ent was found, cannot be before it as we've sorted
last_ind += ind # If we did not start from 0 in the for loop

_start = max(0, ctoken_idx[0] - cntx_left)
_end = min(len(input_ids), ctoken_idx[-1] + 1 + cntx_right)

_start = max(0, ind - cntx_left)
_end = min(len(input_ids), ind + 1 + cntx_right)
tkns = input_ids[_start:_end]
cpos = cntx_left + min(0, ind - cntx_left)
cpos_new = [x - _start for x in ctoken_idx]

if replace_center is not None:
if lowercase:
Expand All @@ -447,8 +510,7 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe
ln = e_ind - s_ind # Length of the concept in tokens
assert self.tokenizer is not None
tkns = tkns[:cpos] + self.tokenizer(replace_center)['input_ids'] + tkns[cpos + ln + 1:]

samples.append([tkns, cpos])
samples.append([tkns, cpos_new])
ent_id2ind[ent._.id] = len(samples) - 1

return ent_id2ind, samples
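
The rewritten loop collects every token index covered by the entity's character span rather than only the first, so the centre position passed to the model becomes a list of in-window offsets (cpos_new). A simplified version of the offset-mapping walk in isolation, with toy values:

# Toy offset mapping: (char_start, char_end) per token, in document order
offset_mapping = [(0, 4), (5, 9), (10, 14), (15, 19)]
start, end = 5, 14  # entity spans tokens 1 and 2

ctoken_idx = []
for ind, (tok_start, tok_end) in enumerate(offset_mapping):
    if start <= tok_start or start <= tok_end:  # reached the entity's first token
        ctoken_idx.append(ind)
        if end <= tok_end:                      # entity end reached
            break

print(ctoken_idx)  # [1, 2]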
@@ -544,7 +606,6 @@ def _set_meta_anns(self,
for i, doc in enumerate(docs):
data.extend(doc._.share_tokens[0])
doc_ind2positions[i] = doc._.share_tokens[1]

all_predictions, all_confidences = predict(self.model, data, config)
for i, doc in enumerate(docs):
start_ind, end_ind, ent_id2ind = doc_ind2positions[i]
13 changes: 9 additions & 4 deletions medcat/tokenizers/meta_cat_tokenizers.py
@@ -1,3 +1,4 @@
import logging
import os
from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Union, overload
@@ -26,7 +27,7 @@ def save(self, dir_path: str) -> None: ...

@classmethod
@abstractmethod
def load(cls, dir_path: str, **kwargs) -> Tokenizer: ...
def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> Tokenizer: ...

@abstractmethod
def get_size(self) -> int: ...
@@ -112,7 +113,7 @@ def save(self, dir_path: str) -> None:
self.hf_tokenizers.save_model(dir_path, prefix=self.name)

@classmethod
def load(cls, dir_path: str, **kwargs) -> "TokenizerWrapperBPE":
def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> "TokenizerWrapperBPE":
tokenizer = cls()
vocab_file = os.path.join(dir_path, f'{tokenizer.name}-vocab.json')
merges_file = os.path.join(dir_path, f'{tokenizer.name}-merges.txt')
@@ -186,10 +187,14 @@ def save(self, dir_path: str) -> None:
self.hf_tokenizers.save_pretrained(path)

@classmethod
def load(cls, dir_path: str, **kwargs) -> "TokenizerWrapperBERT":
def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> "TokenizerWrapperBERT":
tokenizer = cls()
path = os.path.join(dir_path, cls.name)
tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(path, **kwargs)
try:
tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(path, **kwargs)
except Exception as e:
logging.warning("Could not load tokenizer from path due to error: {}. Loading from library for model variant: {}".format(e, model_variant))
tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(model_variant)

return tokenizer

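The load signature now threads the model variant through, so a tokenizer missing from disk falls back to downloading the named variant from the Hugging Face hub. A hedged usage sketch (directory name illustrative):

from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT

# Loads from '<dir_path>/bert-tokenizer' if present; otherwise falls back to
# BertTokenizerFast.from_pretrained('bert-base-uncased') pulled from the hub
tokenizer = TokenizerWrapperBERT.load('meta_model', model_variant='bert-base-uncased')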