Adding Bert-style model for MetaCAT #419

Merged: 22 commits, May 16, 2024

Commits
4d36f8a  Pushing changes for bert-style models for MetaCAT (shubham-s-agarwal, Apr 19, 2024)
da9ab06  Pushing fix for LSTM (shubham-s-agarwal, Apr 19, 2024)
cb65fc3  Pushing changes for flake8 and type fixes (shubham-s-agarwal, Apr 19, 2024)
869eeae  Pushing type fixes (shubham-s-agarwal, Apr 19, 2024)
3e02eed  Fixing type issue (shubham-s-agarwal, Apr 19, 2024)
c899c9c  Pushing changes (shubham-s-agarwal, Apr 22, 2024)
d1321b8  Pushing change and type fixes (shubham-s-agarwal, Apr 22, 2024)
9091d9b  Fixing flake8 issues (shubham-s-agarwal, Apr 22, 2024)
c57dcfe  Pushing flake8 fixes (shubham-s-agarwal, Apr 23, 2024)
364fdd4  Pushing fixes for flake8 (shubham-s-agarwal, Apr 23, 2024)
7272168  Pushing flake8 fix (shubham-s-agarwal, Apr 23, 2024)
619c565  Adding peft to list of libraries (shubham-s-agarwal, Apr 23, 2024)
2a546c3  Pushing changes with load and train workflow and type fixes (shubham-s-agarwal, Apr 30, 2024)
8efd2a9  Pushing changes with type hints and new documentation (shubham-s-agarwal, May 7, 2024)
aa5044e  Pushing type fix (shubham-s-agarwal, May 7, 2024)
fcdc867  Fixing type issue (shubham-s-agarwal, May 7, 2024)
88ee8e7  Adding test case for BERT and reverting config changes (shubham-s-agarwal, May 7, 2024)
917dca2  Merging changes from master to metacat_bert branch (#431) (shubham-s-agarwal, May 8, 2024)
563c3d4  Merge branch 'master' into metacat_bert (mart-r, May 8, 2024)
decfbfb  Pushing changed tests and removing empty change (shubham-s-agarwal, May 8, 2024)
fbcdb70  Pushing change for logging (shubham-s-agarwal, May 9, 2024)
2657515  Revert "Pushing change for logging" (shubham-s-agarwal, May 14, 2024)
19 changes: 17 additions & 2 deletions medcat/config_meta_cat.py
@@ -1,5 +1,4 @@
from typing import Dict, Any

from medcat.config import MixingConfig, BaseModel, Optional, Extra


@@ -49,10 +48,20 @@ class Config:
class Model(MixingConfig, BaseModel):
"""The model part of the metaCAT config"""
model_name: str = 'lstm'
"""NOTE: When changing model, make sure to change the tokenizer as well"""
model_variant: str = 'bert-base-uncased'
model_freeze_layers: bool = True
num_layers: int = 2
input_size: int = 300
hidden_size: int = 300
dropout: float = 0.5
phase_number: int = 0
"""Indicates whether or not two phase learning is being performed.
1: Phase 1 - Train model on undersampled data
2: Phase 2 - Continue training on full data
0: None - 2 phase learning is not performed"""
category_undersample: str = ''
model_architecture_config: Dict = {'fc2': True, 'fc3': False, 'lr_scheduler': True}
num_directions: int = 2
"""2 - bidirectional model, 1 - unidirectional"""
nclasses: int = 2
@@ -61,7 +70,7 @@ class Model(MixingConfig, BaseModel):
emb_grad: bool = True
"""If True the embeddings will also be trained"""
ignore_cpos: bool = False
"""If set to True center positions will be ignored when calculating represenation"""
"""If set to True center positions will be ignored when calculating representation"""

class Config:
extra = Extra.allow
@@ -77,6 +86,8 @@ class Train(MixingConfig, BaseModel):
shuffle_data: bool = True
"""Used only during training, if set the dataset will be shuffled before train/test split"""
class_weights: Optional[Any] = None
compute_class_weights: bool = False
"""If true and if class weights are not provided, the class weights will be calculated based on the data"""
score_average: str = 'weighted'
"""What to use for averaging F1/P/R across labels"""
prerequisites: dict = {}
Expand All @@ -88,6 +99,10 @@ class Train(MixingConfig, BaseModel):
"""When was the last training run"""
metric: Dict[str, str] = {'base': 'weighted avg', 'score': 'f1-score'}
"""What metric should be used for choosing the best model"""
loss_funct: str = 'cross_entropy'
"""Loss function for the model"""
gamma: int = 2
"""Focal Loss - how much the loss focuses on hard-to-classify examples."""

class Config:
extra = Extra.allow
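For orientation, below is a minimal sketch of how the configuration fields introduced above might be set before training. It is illustrative only: the attribute paths mirror the fields in this diff, but the 'focal_loss' value for loss_funct is an assumption (only 'cross_entropy' appears as the default here), and the model variant is just an example.

# Minimal sketch (assumes a MedCAT install that includes the fields added in this PR).
from medcat.config_meta_cat import ConfigMetaCAT

config = ConfigMetaCAT()

# BERT-style MetaCAT; the tokenizer has to be switched to match the model.
config.model.model_name = 'bert'
config.model.model_variant = 'bert-base-uncased'
config.model.model_freeze_layers = False      # False -> LoRA adapters are trained via peft
config.general.tokenizer_name = 'bert-tokenizer'

# Two-phase learning: phase 1 trains on undersampled data, phase 2 continues on the full data.
config.model.phase_number = 1

# Training additions from this diff.
config.train.compute_class_weights = True     # derive class weights from the data if none are given
config.train.loss_funct = 'focal_loss'        # assumption: value name; 'cross_entropy' is the default
config.train.gamma = 2                        # focal-loss focusing parameter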
123 changes: 92 additions & 31 deletions medcat/meta_cat.py
@@ -11,18 +11,18 @@
from medcat.utils.hasher import Hasher
from medcat.config_meta_cat import ConfigMetaCAT
from medcat.utils.meta_cat.ml_utils import predict, train_model, set_all_seeds, eval_model
from medcat.utils.meta_cat.data_utils import prepare_from_json, encode_category_values
from medcat.utils.meta_cat.data_utils import prepare_from_json, encode_category_values, prepare_for_oversampled_data
from medcat.pipeline.pipe_runner import PipeRunner
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase
from medcat.utils.meta_cat.data_utils import Doc as FakeDoc
from medcat.utils.decorators import deprecated
from peft import get_peft_model, LoraConfig, TaskType

# It should be safe to do this always, as all other multiprocessing
# will be finished before data comes to meta_cat
os.environ["TOKENIZERS_PARALLELISM"] = "true"


logger = logging.getLogger(__name__) # separate logger from the package-level one
logger = logging.getLogger(__name__) # separate logger from the package-level one


class MetaCAT(PipeRunner):
@@ -77,7 +77,7 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
The embedding tensor

Raises:
ValueError: If the meta model is not LSTM
ValueError: If the meta model is not LSTM or BERT

Returns:
nn.Module:
@@ -86,7 +86,22 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
config = self.config
if config.model['model_name'] == 'lstm':
from medcat.utils.meta_cat.models import LSTM
model = LSTM(embeddings, config)
model: nn.Module = LSTM(embeddings, config)
logger.info("LSTM model used for classification")

elif config.model['model_name'] == 'bert':
from medcat.utils.meta_cat.models import BertForMetaAnnotation
model = BertForMetaAnnotation(config)

if not config.model.model_freeze_layers:
peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=16,
target_modules=["query", "value"], lora_dropout=0.2)

model = get_peft_model(model, peft_config)
# model.print_trainable_parameters()

logger.info("BERT model used for classification")

else:
raise ValueError("Unknown model name %s" % config.model['model_name'])

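As a self-contained illustration of the peft usage above, the sketch below wraps a stock HuggingFace sequence-classification model with the same LoraConfig arguments. BertForMetaAnnotation is internal to this PR, so a generic BertForSequenceClassification stands in for it; the model name and label count are placeholders.

# Illustrative sketch only; requires the `transformers` and `peft` packages.
from transformers import BertForSequenceClassification
from peft import LoraConfig, TaskType, get_peft_model

base_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Same adapter settings as in get_model() above: low-rank updates on the attention
# query/value projections, with the rest of the encoder left frozen.
peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=16,
                         target_modules=["query", "value"], lora_dropout=0.2)
model = get_peft_model(base_model, peft_config)

model.print_trainable_parameters()  # prints trainable vs total parameter counts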
@@ -107,7 +122,7 @@ def get_hash(self) -> str:
return hasher.hexdigest()

@deprecated(message="Use `train_from_json` or `train_raw` instead")
def train(self, json_path: Union[str, list], save_dir_path: Optional[str] = None) -> Dict:
def train(self, json_path: Union[str, list], save_dir_path: Optional[str] = None, data_oversampled: Optional[list] = None) -> Dict:
"""Train or continue training a model give a json_path containing a MedCATtrainer export. It will
continue training if an existing model is loaded or start new training if the model is blank/new.

Expand All @@ -117,13 +132,16 @@ def train(self, json_path: Union[str, list], save_dir_path: Optional[str] = None
save_dir_path (Optional[str]):
In case we have auto_save_model (meaning the best model will be saved during training)
we need to set a save path. Defaults to `None`.
data_oversampled (Optional[list]):
If oversampling is performed, the oversampled data is passed in this parameter.

Returns:
Dict: The resulting report.
"""
return self.train_from_json(json_path, save_dir_path)
return self.train_from_json(json_path, save_dir_path, data_oversampled=data_oversampled)

def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[str] = None) -> Dict:
def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[str] = None,
data_oversampled: Optional[list] = None) -> Dict:
"""Train or continue training a model give a json_path containing a MedCATtrainer export. It will
continue training if an existing model is loaded or start new training if the model is blank/new.

@@ -133,6 +151,8 @@ def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[s
save_dir_path (Optional[str]):
In case we have auto_save_model (meaning the best model will be saved during training)
we need to set a save path. Defaults to `None`.
data_oversampled (Optional[list]):
If oversampling is performed, the oversampled data is passed in this parameter.

Returns:
Dict: The resulting report.
@@ -157,9 +177,9 @@ def merge_data_loaded(base, other):
for path in json_path:
with open(path, 'r') as f:
data_loaded = merge_data_loaded(data_loaded, json.load(f))
return self.train_raw(data_loaded, save_dir_path)
return self.train_raw(data_loaded, save_dir_path, data_oversampled=data_oversampled)

def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> Dict:
def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data_oversampled: Optional[list] = None) -> Dict:
"""Train or continue training a model given raw data. It will
continue training if an existing model is loaded or start new training if the model is blank/new.

@@ -187,13 +207,19 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D
save_dir_path (Optional[str]):
In case we have auto_save_model (meaning the best model will be saved during training)
we need to set a save path. Defaults to `None`.
data_oversampled (Optional[list]):
If oversampling is performed, the oversampled data is passed in this parameter.
The expected format is: [[['text','of','the','document'], [index of medical entity], "label"],
[['text','of','the','document'], [index of medical entity], "label"]]

Returns:
Dict: The resulting report.

Raises:
Exception: If no save path is specified, or category name not in data.
AssertionError: If no tokeniser is set
FileNotFoundError: If phase_number is set to 2 and model.dat file is not found
KeyError: If phase_number is set to 2 and model.dat file contains mismatched architecture
"""
g_config = self.config.general
t_config = self.config.train
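To make the documented data_oversampled layout concrete, a small hand-built example follows; the token lists, entity indices and label values are invented for illustration.

# Hypothetical oversampled examples in the documented format:
# [[tokens of the document], [token index/indices of the medical entity], label]
data_oversampled = [
    [['patient', 'denies', 'chest', 'pain'], [3], 'Negated'],
    [['no', 'history', 'of', 'diabetes'], [3], 'Negated'],
]

# These would then be appended to the training data, e.g. (meta_cat being an initialised MetaCAT instance):
# meta_cat.train_from_json('mct_export.json', save_dir_path='meta_model', data_oversampled=data_oversampled)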
@@ -212,32 +238,60 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D
replace_center=g_config['replace_center'], prerequisites=t_config['prerequisites'],
lowercase=g_config['lowercase'])

# Check is the name there
# Check if the category name is present
category_name = g_config['category_name']
if category_name not in data:
raise Exception(
"The category name does not exist in this json file. You've provided '{}', while the possible options are: {}".format(
category_name, " | ".join(list(data.keys()))))

data = data[category_name]
if data_oversampled:
data_sampled = prepare_for_oversampled_data(data_oversampled, self.tokenizer)
data = data + data_sampled

category_value2id = g_config['category_value2id']
if not category_value2id:
# Encode the category values
data, category_value2id = encode_category_values(data)
data_undersampled, full_data, category_value2id = encode_category_values(data,
category_undersample=self.config.model.category_undersample)
g_config['category_value2id'] = category_value2id
else:
# We already have everything, just get the data
data, _ = encode_category_values(data, existing_category_value2id=category_value2id)

data_undersampled, full_data, category_value2id = encode_category_values(data,
existing_category_value2id=category_value2id,
category_undersample=self.config.model.category_undersample)
g_config['category_value2id'] = category_value2id
# Make sure the config number of classes is the same as the one found in the data
if len(category_value2id) != self.config.model['nclasses']:
logger.warning(
"The number of classes set in the config is not the same as the one found in the data: {} vs {}".format(
self.config.model['nclasses'], len(category_value2id)))
logger.warning("Auto-setting the nclasses value in config and rebuilding the model.")
self.config.model['nclasses'] = len(category_value2id)
self.model = self.get_model(embeddings=self.embeddings)

if self.config.model.phase_number == 2 and save_dir_path is not None:
Review comment (Collaborator):
phase_number is a new config entry. As such, I'm guessing it's only applicable to BERT-based MetaCATs.
If that is the case, it would be better to check to make sure the model_name is bert before doing the following. Otherwise you might unexpectedly end up running this code when using LSTM.
If this is still also relevant to LSTM, it would need to be documented since it would still correspond to new/different behaviour.

Reply (Collaborator, PR author):
Two-phase learning (relating to phase_number) can be used with both BERT and LSTM.
I've added more description in the config where the parameter is defined.

model_save_path = os.path.join(save_dir_path, 'model.dat')
device = torch.device(g_config['device'])
try:
self.model.load_state_dict(torch.load(model_save_path, map_location=device))
logger.info("Model state loaded from dict for 2 phase learning")

except FileNotFoundError:
raise FileNotFoundError(f"\nError: Model file not found at path: {model_save_path}\nPlease run phase 1 training and then run phase 2.")

except KeyError:
raise KeyError("\nError: Missing key in loaded state dictionary. \nThis might be due to a mismatch between the model architecture and the saved state.")

except Exception as e:
raise Exception(f"\nError: Model state cannot be loaded from dict. {e}")

data = full_data
if self.config.model.phase_number == 1:
data = data_undersampled
if not t_config['auto_save_model']:
logger.info("For phase 1, model state has to be saved. Saving model...")
t_config['auto_save_model'] = True

report = train_model(self.model, data=data, config=self.config, save_dir_path=save_dir_path)

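Putting the phase_number handling above together, a possible two-phase training sequence from the caller's side might look as follows; the MetaCAT instance, export path and save directory are assumed.

# Sketch: `meta_cat` is an already initialised MetaCAT instance, 'mct_export.json' a MedCATtrainer export.
save_dir = 'meta_model_dir'

# Phase 1: train on the undersampled data; auto_save_model is forced on so model.dat gets written.
meta_cat.config.model.phase_number = 1
meta_cat.train_from_json('mct_export.json', save_dir_path=save_dir)

# Phase 2: reload model.dat from save_dir and continue training on the full data.
meta_cat.config.model.phase_number = 2
meta_cat.train_from_json('mct_export.json', save_dir_path=save_dir)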
@@ -293,7 +347,7 @@ def eval(self, json_path: str) -> Dict:

# We already have everything, just get the data
category_value2id = g_config['category_value2id']
data, _ = encode_category_values(data, existing_category_value2id=category_value2id)
data, _, _ = encode_category_values(data, existing_category_value2id=category_value2id)

# Run evaluation
assert self.tokenizer is not None
@@ -317,8 +371,8 @@ def save(self, save_dir_path: str) -> None:
# Save tokenizer
assert self.tokenizer is not None
self.tokenizer.save(save_dir_path)

# Save config

mart-r marked this conversation as resolved.
self.config.save(os.path.join(save_dir_path, 'config.json'))

# Save the model
@@ -347,7 +401,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA
# Load config
config = cast(ConfigMetaCAT, ConfigMetaCAT.load(os.path.join(save_dir_path, 'config.json')))

# Overwrite loaded paramters with something new
# Overwrite loaded parameters with something new
if config_dict is not None:
config.merge_config(config_dict)

@@ -358,7 +412,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA
tokenizer = TokenizerWrapperBPE.load(save_dir_path)
elif config.general['tokenizer_name'] == 'bert-tokenizer':
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT
tokenizer = TokenizerWrapperBERT.load(save_dir_path)
tokenizer = TokenizerWrapperBERT.load(save_dir_path, config.model['model_variant'])

# Create meta_cat
meta_cat = cls(tokenizer=tokenizer, embeddings=None, config=config)
@@ -380,7 +434,8 @@ def get_ents(self, doc: Doc) -> Iterable[Span]:
try:
return doc.spans[spangroup_name]
except KeyError:
raise Exception(f"Configuration error MetaCAT was configured to set meta_anns on {spangroup_name} but this spangroup was not set on the doc.")
raise Exception(
f"Configuration error MetaCAT was configured to set meta_anns on {spangroup_name} but this spangroup was not set on the doc.")

# Should we annotate overlapping entities
if self.config.general['annotate_overlapping']:
@@ -421,18 +476,26 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe
start = ent.start_char
end = ent.end_char

ind = 0
# Start where the last ent was found, cannot be before it as we've sorted
# Updated implementation to extract all the tokens of the medical entity (rather than just one)
ctoken_idx = []
for ind, pair in enumerate(offset_mapping[last_ind:]):
if start >= pair[0] and start < pair[1]:
break
ind = last_ind + ind # If we did not start from 0 in the for loop
last_ind = ind
# Checking if we've reached the start of the entity
if start <= pair[0] or start <= pair[1]:
if end <= pair[1]:
ctoken_idx.append(ind) # End reached
break
else:
ctoken_idx.append(ind) # Keep going

# Start where the last ent was found, cannot be before it as we've sorted
last_ind += ind # If we did not start from 0 in the for loop

_start = max(0, ctoken_idx[0] - cntx_left)
Review comment (Collaborator):
Is this a bug fix?
Again, seems to affect everything, so the change would need to be documented somewhere.

Reply (Collaborator, PR author):
It's the same change as above.

_end = min(len(input_ids), ctoken_idx[-1] + 1 + cntx_right)

_start = max(0, ind - cntx_left)
_end = min(len(input_ids), ind + 1 + cntx_right)
tkns = input_ids[_start:_end]
cpos = cntx_left + min(0, ind - cntx_left)
cpos_new = [x - _start for x in ctoken_idx]

if replace_center is not None:
if lowercase:
Expand All @@ -447,8 +510,7 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe
ln = e_ind - s_ind # Length of the concept in tokens
assert self.tokenizer is not None
tkns = tkns[:cpos] + self.tokenizer(replace_center)['input_ids'] + tkns[cpos + ln + 1:]

samples.append([tkns, cpos])
samples.append([tkns, cpos_new])
ent_id2ind[ent._.id] = len(samples) - 1

return ent_id2ind, samples
Expand Down Expand Up @@ -544,7 +606,6 @@ def _set_meta_anns(self,
for i, doc in enumerate(docs):
data.extend(doc._.share_tokens[0])
doc_ind2positions[i] = doc._.share_tokens[1]

all_predictions, all_confidences = predict(self.model, data, config)
for i, doc in enumerate(docs):
start_ind, end_ind, ent_id2ind = doc_ind2positions[i]
13 changes: 9 additions & 4 deletions medcat/tokenizers/meta_cat_tokenizers.py
@@ -1,3 +1,4 @@
import logging
import os
from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Union, overload
@@ -26,7 +27,7 @@ def save(self, dir_path: str) -> None: ...

@classmethod
@abstractmethod
def load(cls, dir_path: str, **kwargs) -> Tokenizer: ...
def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> Tokenizer: ...

@abstractmethod
def get_size(self) -> int: ...
@@ -112,7 +113,7 @@ def save(self, dir_path: str) -> None:
self.hf_tokenizers.save_model(dir_path, prefix=self.name)

@classmethod
def load(cls, dir_path: str, **kwargs) -> "TokenizerWrapperBPE":
def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> "TokenizerWrapperBPE":
tokenizer = cls()
vocab_file = os.path.join(dir_path, f'{tokenizer.name}-vocab.json')
merges_file = os.path.join(dir_path, f'{tokenizer.name}-merges.txt')
@@ -186,10 +187,14 @@ def save(self, dir_path: str) -> None:
self.hf_tokenizers.save_pretrained(path)

@classmethod
def load(cls, dir_path: str, **kwargs) -> "TokenizerWrapperBERT":
def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> "TokenizerWrapperBERT":
tokenizer = cls()
path = os.path.join(dir_path, cls.name)
tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(path, **kwargs)
try:
tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(path, **kwargs)
except Exception as e:
logging.warning("Could not load tokenizer from path due to error: {}. Loading from library for model variant: {}".format(e, model_variant))
tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(model_variant)

return tokenizer

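A small usage sketch of the fallback added above: the wrapper first tries to load the tokenizer files saved next to the model and, failing that, pulls the configured variant from the HuggingFace hub. The directory name is an assumption.

# Sketch; 'meta_model_dir' is a hypothetical saved MetaCAT directory.
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT

tokenizer = TokenizerWrapperBERT.load('meta_model_dir', model_variant='bert-base-uncased')
print(type(tokenizer.hf_tokenizers))  # BertTokenizerFast, loaded from disk or via the hub fallback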