Adding Bert-style model for MetaCAT #419
Merged
Commits (22)
4d36f8a  Pushing changes for bert-style models for MetaCAT (shubham-s-agarwal)
da9ab06  Pushing fix for LSTM (shubham-s-agarwal)
cb65fc3  Pushing changes for flake8 and type fixes (shubham-s-agarwal)
869eeae  Pushing type fixes (shubham-s-agarwal)
3e02eed  Fixing type issue (shubham-s-agarwal)
c899c9c  Pushing changes (shubham-s-agarwal)
d1321b8  Pushing change and type fixes (shubham-s-agarwal)
9091d9b  Fixing flake8 issues (shubham-s-agarwal)
c57dcfe  Pushing flake8 fixes (shubham-s-agarwal)
364fdd4  Pushing fixes for flake8 (shubham-s-agarwal)
7272168  Pushing flake8 fix (shubham-s-agarwal)
619c565  Adding peft to list of libraries (shubham-s-agarwal)
2a546c3  Pushing changes with load and train workflow and type fixes (shubham-s-agarwal)
8efd2a9  Pushing changes with type hints and new documentation (shubham-s-agarwal)
aa5044e  Pushing type fix (shubham-s-agarwal)
fcdc867  Fixing type issue (shubham-s-agarwal)
88ee8e7  Adding test case for BERT and reverting config changes (shubham-s-agarwal)
917dca2  Merging changes from master to metacat_bert branch (#431) (shubham-s-agarwal)
563c3d4  Merge branch 'master' into metacat_bert (mart-r)
decfbfb  Pushing changed tests and removing empty change (shubham-s-agarwal)
fbcdb70  Pushing change for logging (shubham-s-agarwal)
2657515  Revert "Pushing change for logging" (shubham-s-agarwal)
Diff view
@@ -11,18 +11,18 @@
from medcat.utils.hasher import Hasher
from medcat.config_meta_cat import ConfigMetaCAT
from medcat.utils.meta_cat.ml_utils import predict, train_model, set_all_seeds, eval_model
from medcat.utils.meta_cat.data_utils import prepare_from_json, encode_category_values
from medcat.utils.meta_cat.data_utils import prepare_from_json, encode_category_values, prepare_for_oversampled_data
from medcat.pipeline.pipe_runner import PipeRunner
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase
from medcat.utils.meta_cat.data_utils import Doc as FakeDoc
from medcat.utils.decorators import deprecated
from peft import get_peft_model, LoraConfig, TaskType

# It should be safe to do this always, as all other multiprocessing
# will be finished before data comes to meta_cat
os.environ["TOKENIZERS_PARALLELISM"] = "true"

logger = logging.getLogger(__name__)  # separate logger from the package-level one


class MetaCAT(PipeRunner):
@@ -77,7 +77,7 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
The embedding densor

Raises:
ValueError: If the meta model is not LSTM
ValueError: If the meta model is not LSTM or BERT

Returns:
nn.Module:
@@ -86,7 +86,22 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
config = self.config
if config.model['model_name'] == 'lstm':
from medcat.utils.meta_cat.models import LSTM
model = LSTM(embeddings, config)
model: nn.Module = LSTM(embeddings, config)
logger.info("LSTM model used for classification")

elif config.model['model_name'] == 'bert':
from medcat.utils.meta_cat.models import BertForMetaAnnotation
model = BertForMetaAnnotation(config)

if not config.model.model_freeze_layers:
peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=16,
target_modules=["query", "value"], lora_dropout=0.2)

model = get_peft_model(model, peft_config)
# model.print_trainable_parameters()

logger.info("BERT model used for classification")

else:
raise ValueError("Unknown model name %s" % config.model['model_name'])
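For readers unfamiliar with peft, the LoRA branch above freezes the base BERT weights and injects small rank-8 adapters into the attention query/value projections, so only the adapters are trained. A minimal standalone sketch of the same pattern, using a generic Hugging Face classifier as a stand-in for MedCAT's BertForMetaAnnotation (model name and label count are illustrative):

```python
from transformers import AutoModelForSequenceClassification
from peft import get_peft_model, LoraConfig, TaskType

# Stand-in base model; the PR wraps MedCAT's own BertForMetaAnnotation instead
base_model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2)

# Same LoRA settings as in the diff above
peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False,
                         r=8, lora_alpha=16,
                         target_modules=["query", "value"], lora_dropout=0.2)

model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()  # only the adapter parameters are trainable
```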
@@ -107,7 +122,7 @@ def get_hash(self) -> str:
return hasher.hexdigest()

@deprecated(message="Use `train_from_json` or `train_raw` instead")
def train(self, json_path: Union[str, list], save_dir_path: Optional[str] = None) -> Dict:
def train(self, json_path: Union[str, list], save_dir_path: Optional[str] = None, data_oversampled: Optional[list] = None) -> Dict:
"""Train or continue training a model give a json_path containing a MedCATtrainer export. It will
continue training if an existing model is loaded or start new training if the model is blank/new.
@@ -117,13 +132,16 @@ def train(self, json_path: Union[str, list], save_dir_path: Optional[str] = None
save_dir_path (Optional[str]):
In case we have aut_save_model (meaning during the training the best model will be saved)
we need to set a save path. Defaults to `None`.
data_oversampled (Optional[list]):
In case of oversampling being performed, the data will be passed in the parameter

Returns:
Dict: The resulting report.
"""
return self.train_from_json(json_path, save_dir_path)
return self.train_from_json(json_path, save_dir_path, data_oversampled=data_oversampled)

def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[str] = None) -> Dict:
def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[str] = None,
data_oversampled: Optional[list] = None) -> Dict:
"""Train or continue training a model give a json_path containing a MedCATtrainer export. It will
continue training if an existing model is loaded or start new training if the model is blank/new.
@@ -133,6 +151,8 @@ def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[s
save_dir_path (Optional[str]):
In case we have aut_save_model (meaning during the training the best model will be saved)
we need to set a save path. Defaults to `None`.
data_oversampled (Optional[list]):
In case of oversampling being performed, the data will be passed in the parameter

Returns:
Dict: The resulting report.
@@ -157,9 +177,9 @@ def merge_data_loaded(base, other):
for path in json_path:
with open(path, 'r') as f:
data_loaded = merge_data_loaded(data_loaded, json.load(f))
return self.train_raw(data_loaded, save_dir_path)
return self.train_raw(data_loaded, save_dir_path, data_oversampled=data_oversampled)

def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> Dict:
def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data_oversampled: Optional[list] = None) -> Dict:
"""Train or continue training a model given raw data. It will
continue training if an existing model is loaded or start new training if the model is blank/new.
@@ -187,13 +207,19 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D
save_dir_path (Optional[str]):
In case we have aut_save_model (meaning during the training the best model will be saved)
we need to set a save path. Defaults to `None`.
data_oversampled (Optional[list]):
In case of oversampling being performed, the data will be passed in the parameter
The format of which is expected: [[['text','of','the','document'], [index of medical entity], "label" ],
['text','of','the','document'], [index of medical entity], "label" ]]

Returns:
Dict: The resulting report.

Raises:
Exception: If no save path is specified, or category name not in data.
AssertionError: If no tokeniser is set
FileNotFoundError: If phase_number is set to 2 and model.dat file is not found
KeyError: If phase_number is set to 2 and model.dat file contains mismatched architecture
"""
g_config = self.config.general
t_config = self.config.train
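To make the data_oversampled format described in the docstring concrete, a hand-built example might look like the sketch below (the documents, entity indices and labels are invented; only the shape follows the docstring):

```python
# Each entry: [tokenised document text, [token index/indices of the medical entity], "label"]
data_oversampled = [
    [['patient', 'denies', 'chest', 'pain'], [3], "Negated"],
    [['long', 'history', 'of', 'diabetes', 'mellitus'], [3, 4], "Affirmed"],
]

# The extra samples are passed through train -> train_from_json -> train_raw, e.g.:
# report = meta_cat.train_from_json("mct_export.json", save_dir_path="out",
#                                   data_oversampled=data_oversampled)
```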
@@ -212,32 +238,60 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D
replace_center=g_config['replace_center'], prerequisites=t_config['prerequisites'],
lowercase=g_config['lowercase'])

# Check is the name there
# Check is the name present
category_name = g_config['category_name']
if category_name not in data:
raise Exception(
"The category name does not exist in this json file. You've provided '{}', while the possible options are: {}".format(
category_name, " | ".join(list(data.keys()))))

data = data[category_name]
if data_oversampled:
data_sampled = prepare_for_oversampled_data(data_oversampled, self.tokenizer)
data = data + data_sampled

category_value2id = g_config['category_value2id']
if not category_value2id:
# Encode the category values
data, category_value2id = encode_category_values(data)
data_undersampled, full_data, category_value2id = encode_category_values(data,
category_undersample=self.config.model.category_undersample)
g_config['category_value2id'] = category_value2id
else:
# We already have everything, just get the data
data, _ = encode_category_values(data, existing_category_value2id=category_value2id)

data_undersampled, full_data, category_value2id = encode_category_values(data,
existing_category_value2id=category_value2id,
category_undersample=self.config.model.category_undersample)
g_config['category_value2id'] = category_value2id
# Make sure the config number of classes is the same as the one found in the data
if len(category_value2id) != self.config.model['nclasses']:
logger.warning(
"The number of classes set in the config is not the same as the one found in the data: {} vs {}".format(
self.config.model['nclasses'], len(category_value2id)))
logger.warning("Auto-setting the nclasses value in config and rebuilding the model.")
self.config.model['nclasses'] = len(category_value2id)
self.model = self.get_model(embeddings=self.embeddings)

if self.config.model.phase_number == 2 and save_dir_path is not None:
model_save_path = os.path.join(save_dir_path, 'model.dat')
device = torch.device(g_config['device'])
try:
self.model.load_state_dict(torch.load(model_save_path, map_location=device))
logger.info("Model state loaded from dict for 2 phase learning")

except FileNotFoundError:
raise FileNotFoundError(f"\nError: Model file not found at path: {model_save_path}\nPlease run phase 1 training and then run phase 2.")

except KeyError:
raise KeyError("\nError: Missing key in loaded state dictionary. \nThis might be due to a mismatch between the model architecture and the saved state.")

except Exception as e:
raise Exception(f"\nError: Model state cannot be loaded from dict. {e}")

data = full_data
if self.config.model.phase_number == 1:
data = data_undersampled
if not t_config['auto_save_model']:
logger.info("For phase 1, model state has to be saved. Saving model...")
t_config['auto_save_model'] = True

report = train_model(self.model, data=data, config=self.config, save_dir_path=save_dir_path)
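Taken together, the phase_number branches above implement a simple two-phase schedule: phase 1 trains on the class-undersampled split and force-saves model.dat, phase 2 reloads that state and continues on the full data. A hedged sketch of how a caller might drive it (the model directory, export path and output directory are placeholders):

```python
from medcat.meta_cat import MetaCAT

meta_cat = MetaCAT.load("meta_cat_model_dir")  # placeholder model directory

# Phase 1: train on the undersampled split; auto_save_model is forced on so model.dat is written
meta_cat.config.model.phase_number = 1
meta_cat.train_from_json("mct_export.json", save_dir_path="two_phase_out")

# Phase 2: model.dat from phase 1 is reloaded, then training continues on the full data
meta_cat.config.model.phase_number = 2
report = meta_cat.train_from_json("mct_export.json", save_dir_path="two_phase_out")
```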
@@ -293,7 +347,7 @@ def eval(self, json_path: str) -> Dict:

# We already have everything, just get the data
category_value2id = g_config['category_value2id']
data, _ = encode_category_values(data, existing_category_value2id=category_value2id)
data, _, _ = encode_category_values(data, existing_category_value2id=category_value2id)

# Run evaluation
assert self.tokenizer is not None
@@ -317,8 +371,8 @@ def save(self, save_dir_path: str) -> None:
# Save tokenizer
assert self.tokenizer is not None
self.tokenizer.save(save_dir_path)

# Save config
self.config.save(os.path.join(save_dir_path, 'config.json'))

# Save the model
@@ -347,7 +401,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA
# Load config
config = cast(ConfigMetaCAT, ConfigMetaCAT.load(os.path.join(save_dir_path, 'config.json')))

# Overwrite loaded paramters with something new
# Overwrite loaded parameters with something new
if config_dict is not None:
config.merge_config(config_dict)
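The config_dict merge shown here is what allows saved settings to be overridden at load time. A small usage sketch, assuming the standard MedCAT config layout (the path and the overridden value are illustrative):

```python
from medcat.meta_cat import MetaCAT

# Load a saved MetaCAT but override selected config entries at load time
meta_cat = MetaCAT.load(
    "meta_cat_model_dir",                        # placeholder directory
    config_dict={'general': {'device': 'cpu'}})  # merged over the saved config
```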
@@ -358,7 +412,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA
tokenizer = TokenizerWrapperBPE.load(save_dir_path)
elif config.general['tokenizer_name'] == 'bert-tokenizer':
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT
tokenizer = TokenizerWrapperBERT.load(save_dir_path)
tokenizer = TokenizerWrapperBERT.load(save_dir_path, config.model['model_variant'])

# Create meta_cat
meta_cat = cls(tokenizer=tokenizer, embeddings=None, config=config)
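Passing config.model['model_variant'] here lets the loaded MetaCAT rebuild the Hugging Face tokenizer that matches the saved BERT variant. A minimal sketch of the underlying idea, assuming the wrapper ultimately delegates to transformers' AutoTokenizer (the variant name is only an example):

```python
from transformers import AutoTokenizer

model_variant = "bert-base-uncased"  # illustrative; the PR reads this from config.model['model_variant']
hf_tokenizer = AutoTokenizer.from_pretrained(model_variant)

# Offset mappings are what prepare_document uses to align entities with tokens
encoded = hf_tokenizer("patient denies chest pain", return_offsets_mapping=True)
print(encoded["input_ids"])
print(encoded["offset_mapping"])
```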
@@ -380,7 +434,8 @@ def get_ents(self, doc: Doc) -> Iterable[Span]:
try:
return doc.spans[spangroup_name]
except KeyError:
raise Exception(f"Configuration error MetaCAT was configured to set meta_anns on {spangroup_name} but this spangroup was not set on the doc.")
raise Exception(
f"Configuration error MetaCAT was configured to set meta_anns on {spangroup_name} but this spangroup was not set on the doc.")

# Should we annotate overlapping entities
if self.config.general['annotate_overlapping']:
@@ -421,18 +476,26 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe
start = ent.start_char
end = ent.end_char

ind = 0
# Start where the last ent was found, cannot be before it as we've sorted
# Updated implementation to extract all the tokens for the medical entity (rather than the one)
ctoken_idx = []
for ind, pair in enumerate(offset_mapping[last_ind:]):
if start >= pair[0] and start < pair[1]:
break
ind = last_ind + ind # If we did not start from 0 in the for loop
last_ind = ind
# Checking if we've reached at the start of the entity
if start <= pair[0] or start <= pair[1]:
if end <= pair[1]:
ctoken_idx.append(ind) # End reached
break
else:
ctoken_idx.append(ind) # Keep going

# Start where the last ent was found, cannot be before it as we've sorted
last_ind += ind # If we did not start from 0 in the for loop

_start = max(0, ctoken_idx[0] - cntx_left)
[Inline review comment] Is this a bug fix?
[Reply] It's the same change as above.
_end = min(len(input_ids), ctoken_idx[-1] + 1 + cntx_right)

_start = max(0, ind - cntx_left)
_end = min(len(input_ids), ind + 1 + cntx_right)
tkns = input_ids[_start:_end]
cpos = cntx_left + min(0, ind - cntx_left)
cpos_new = [x - _start for x in ctoken_idx]

if replace_center is not None:
if lowercase:
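The rewritten loop collects every token whose character offsets fall inside the entity span, rather than only the first matching token. A small self-contained illustration of that matching logic on invented offsets (variable names mirror the diff, the data is made up):

```python
# Character span of a hypothetical entity and per-token (start, end) offsets
start, end = 10, 24
offset_mapping = [(0, 4), (5, 9), (10, 17), (18, 24), (25, 30)]

ctoken_idx = []
for ind, pair in enumerate(offset_mapping):
    # Token overlaps the entity once we've reached its start
    if start <= pair[0] or start <= pair[1]:
        if end <= pair[1]:
            ctoken_idx.append(ind)  # last token of the entity
            break
        else:
            ctoken_idx.append(ind)  # entity continues into the next token

print(ctoken_idx)  # [2, 3]: both tokens of the multi-token entity are kept
```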
@@ -447,8 +510,7 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe
ln = e_ind - s_ind # Length of the concept in tokens
assert self.tokenizer is not None
tkns = tkns[:cpos] + self.tokenizer(replace_center)['input_ids'] + tkns[cpos + ln + 1:]

samples.append([tkns, cpos])
samples.append([tkns, cpos_new])
ent_id2ind[ent._.id] = len(samples) - 1

return ent_id2ind, samples
@@ -544,7 +606,6 @@ def _set_meta_anns(self,
for i, doc in enumerate(docs):
data.extend(doc._.share_tokens[0])
doc_ind2positions[i] = doc._.share_tokens[1]

all_predictions, all_confidences = predict(self.model, data, config)
for i, doc in enumerate(docs):
start_ind, end_ind, ent_id2ind = doc_ind2positions[i]
Review comment: phase_number is a new config entry. As such, I'm guessing it's only applicable to BERT-based MetaCATs. If that is the case, it would be better to check to make sure the model_name is bert before doing the following; otherwise you might unexpectedly end up running this code when using LSTM. If this is still also relevant to LSTM, it would need to be documented, since it would still correspond to new/different behaviour.
Reply: 2-phase learning (relating to phase_number) can be used with both BERT and LSTM. I've added more description in the config where the parameter is defined.
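Since the reply notes that phase_number applies to both model types, a hedged sketch of the relevant config fields might look as follows (only model_name, model_variant, phase_number and category_undersample appear in this diff; the values and the meaning of category_undersample are assumptions for illustration):

```python
from medcat.config_meta_cat import ConfigMetaCAT

config = ConfigMetaCAT()
config.general.category_name = 'Negation'            # illustrative meta-annotation task
config.model.model_name = 'bert'                      # or 'lstm'; 2-phase learning works with either
config.model.model_variant = 'bert-base-uncased'      # illustrative BERT variant
config.model.phase_number = 1                         # 1: train on undersampled data, 2: continue on full data
config.model.category_undersample = 'Affirmed'        # assumed: class whose count drives undersampling
```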