Adding Bert-style model for MetaCAT #419
Conversation
medcat/meta_cat.py
Outdated
```diff
-def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> Dict:
+def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data_=None) -> Dict:
```
This is a confusing method signature now; why is `data_` needed alongside `data_loaded`?
`data_loaded` is the actual dataset that we've loaded (from the medcat export); `data_` is the oversampled additional data that will be added to the actual data.
Ah okay, `data_` should be renamed to something more descriptive.
I've replaced it with `data_oversampled`.
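For reference, the clarified signature would look roughly like this (a sketch based on the thread; the type annotation and default of the new parameter are assumptions, not the merged code):

```python
from typing import Dict, List, Optional

def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None,
              data_oversampled: Optional[List] = None) -> Dict:
    """data_loaded: the dataset loaded from the medcat export;
    data_oversampled: additional oversampled samples merged into it."""
    ...
```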
medcat/meta_cat.py
Outdated
```python
category_name = g_config['category_name']
if category_name not in data:
    raise Exception(
        "The category name does not exist in this json file. You've provided '{}', while the possible options are: {}".format(
            category_name, " | ".join(list(data.keys()))))

data = data[category_name]
if data_:
```
You're running the tokenizer here on `sample[0]`, but in `prepare_from_json` the tokenizer is passed in and run for the LSTM model. Can we reuse `prepare_from_json`?
`self.tokenizer` is used for both:
- `prepare_from_json`, which processes the actual data from the medcat export (used by both BERT and LSTM)
- processing `data_` (the oversampled data)

I can try to reuse `prepare_from_json`; however, the structure of the oversampled data would need to match the medcat export one (which is complex), and it would entail re-processing some bits (calculating cpos and others).
Restructured to include this in `data_utils` under the `prepare_for_oversampled_data` function.
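For readers of the thread, a rough sketch of what such a helper could look like; the per-sample structure (text plus label) and the tokenizer call are assumptions inferred from the discussion, not the merged code:

```python
from typing import List, Tuple

def prepare_for_oversampled_data(data: List[Tuple[str, str]], tokenizer) -> List:
    """Bring oversampled (text, label) pairs into the same shape that
    prepare_from_json produces: [token_ids, center_position, label]."""
    prepared = []
    for text, label in data:
        token_ids = tokenizer(text)['input_ids']  # assumed tokenizer-wrapper call
        cpos = 0  # assumption: the centre-token index would be (re)calculated here
        prepared.append([token_ids, cpos, label])
    return prepared
```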
medcat/meta_cat.py
Outdated
```python
self.model = self.get_model(embeddings=self.embeddings)

if self.config.model.load_model_dict_:
```
Do we need to be more specific here that we're loading the BERT model implementation of MetaCAT?
We've added the info (in the logger) in the `get_model` function:

```python
def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
    """Get the model.

    Args:
        embeddings (Optional[Tensor]):
            The embedding tensor

    Raises:
        ValueError: If the meta model is not LSTM or BERT

    Returns:
        nn.Module:
            The module
    """
    config = self.config
    from medcat.utils.meta_cat.models import LSTM
    from medcat.utils.meta_cat.models import BertForMetaAnnotation

    if config.model['model_name'] == 'lstm':
        model: Union[LSTM, BertForMetaAnnotation] = LSTM(embeddings, config)
        logger.info("LSTM model used for classification")
    elif config.model['model_name'] == 'bert':
        model = BertForMetaAnnotation(config)
        if not config.model.model_freeze_layers:
            peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8,
                                     lora_alpha=16, target_modules=["query", "value"],
                                     lora_dropout=0.2)
            model = get_peft_model(model, peft_config)
            # model.print_trainable_parameters()
        logger.info("BERT model used for classification")
    else:
        raise ValueError("Unknown model name: %s" % config.model['model_name'])

    return model
```
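As a note on the design choice visible above: when `model_freeze_layers` is disabled, `get_peft_model` leaves the BERT base weights frozen and trains only rank-8 LoRA adapters injected into the attention query/value projections, which is what keeps fine-tuning cheap here.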
medcat/meta_cat.py
Outdated
```diff
@@ -319,7 +355,7 @@ def save(self, save_dir_path: str) -> None:
     self.tokenizer.save(save_dir_path)

     # Save config
-    self.config.save(os.path.join(save_dir_path, 'config.json'))
+    # self.config.save(os.path.join(save_dir_path, 'config.json'))
```
So we're not saving the config? This change shouldn't affect the existing Bi-LSTM models, or a user who wants to use the Bi-LSTM implementation.
We aren't saving the config as of now. The reason is that the version of the config being saved is drastically different from the one we are passing, which causes issues when the config is read back and modified. Here is the version being saved:
{ "general": { "py/object": "medcat.config_meta_cat.General", "py/state": { "__dict__": { "device": "cpu", "disable_component_lock": false, "seed": 13, "description": "No description", "category_name": "Status", "category_value2id": {"Other": 1, "Confirmed": 0}, "vocab_size": 30522, "lowercase": true, "cntx_left": 20, "cntx_right": 10, "replace_center": null, "batch_size_eval": 5000, "annotate_overlapping": false, "tokenizer_name": "bert-tokenizer", "save_and_reuse_tokens": false, "pipe_batch_size_in_chars": 20000000, "span_group": null }, "__fields_set__": { "py/set": [ "device" , "lowercase" , "save_and_reuse_tokens" , "seed" , "category_value2id" , "replace_center" , "pipe_batch_size_in_chars", "annotate_overlapping" , "disable_component_lock" , "batch_size_eval" , "cntx_right" , "cntx_left" , "tokenizer_name" , "vocab_size" , "category_name" ] }, "__private_attribute_values__": {} } }, "model": { "py/object": "medcat.config_meta_cat.Model", "py/state": { "__dict__": { "model_name": "bert", "model_variant": "bert-base-uncased", "model_freeze_layers": false, "num_layers": 3, "input_size": 1536, "hidden_size": 32, "dropout": 0.25, "category_undersample": "Other", "model_architecture_config": { "fc2" : true , "fc3" : false, "lr_scheduler": true }, "num_directions": 2, "nclasses": 2, "padding_idx": 0, "emb_grad": true, "ignore_cpos": false, "load_model_dict_": false, "train_on_full_data": true }, "__fields_set__": { "py/set": [ "dropout" , "ignore_cpos" , "padding_idx" , "hidden_size" , "train_on_full_data" , "model_name" , "emb_grad" , "input_size" , "num_directions" , "nclasses" , "model_architecture_config", "load_model_dict_" , "model_variant" , "model_freeze_layers" , "num_layers" , "category_undersample" ] }, "__private_attribute_values__": {} } }, "train": { "py/object": "medcat.config_meta_cat.Train", "py/state": { "__dict__": { "batch_size": 64, "nepochs": 20, "lr": 0.0005, "test_size": 0.2, "shuffle_data": true, "class_weights": [0.3, 0.65], "score_average": "weighted", "prerequisites": {}, "cui_filter": null, "auto_save_model": true, "last_train_on": null, "metric": {"base": "macro avg", "score": "f1-score"}, "loss_funct": "cross_entropy", "gamma": 3, "loss_function": "focal_loss" }, "__fields_set__": { "py/set": [ "batch_size" , "lr" , "auto_save_model", "class_weights" , "nepochs" , "test_size" , "gamma" , "prerequisites" , "score_average" , "loss_function" , "cui_filter" , "shuffle_data" ] }, "__private_attribute_values__": {} } } }
Sorry, I don't understand your point about why the config isn't being saved. The modified config class needs to be backwards compatible with previous models. If we need a BERT-specific implementation config, there should be a new class for it.
Fixed this; the config is being saved now.
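A quick round trip of the kind that was failing, as a sketch; it assumes the config class exposes a `load()` counterpart to `save()`:

```python
import os

# Save, reload, and modify the config: the round trip that previously broke
meta_cat.config.save(os.path.join(save_dir_path, 'config.json'))
loaded = type(meta_cat.config).load(os.path.join(save_dir_path, 'config.json'))
loaded.train['nepochs'] = 10  # reading the saved config back and modifying it should work
```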
medcat/meta_cat.py
Outdated
```python
try:
    meta_cat.model.load_state_dict(torch.load(model_save_path, map_location=device))
except:
```
Can we be more specific about the exception caught, and at least pass the exception into the log message?
I've made the change to display a warning message (this version of the snippet is outdated).
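For reference, a minimal sketch of the narrower handling being asked for, assuming a module-level `logger` and that a state-dict mismatch surfaces as a `RuntimeError` (both assumptions, not the final PR code):

```python
import logging

import torch

logger = logging.getLogger(__name__)

try:
    meta_cat.model.load_state_dict(torch.load(model_save_path, map_location=device))
except RuntimeError as err:  # e.g. missing or unexpected keys in the state dict
    # Surface the underlying error instead of swallowing it with a bare `except:`
    logger.warning("Could not load model dict from %s: %s", model_save_path, err)
```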
```diff
-return x, cpos, y
+# cpos = torch.tensor(cpos, dtype=torch.long).to(device)
+attention_masks = (x != 0).type(torch.int)
+return x, cpos, attention_masks, y
```
Need to update the docstring. Also, why is the cpos data movement commented out?
Earlier, cpos contained a single index; now it contains a dynamic number of indices for each data point. Because of this ragged structure, it cannot be converted to a tensor directly.
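To illustrate with hypothetical values: a ragged list of index lists has no rectangular shape, so `torch.tensor` rejects it:

```python
import torch

cpos = [[3], [5, 6, 7], [2, 3]]  # hypothetical: per-sample centre-token indices
torch.tensor(cpos, dtype=torch.long)  # ValueError: expected sequences of equal length
```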
medcat/utils/meta_cat/ml_utils.py
Outdated
```diff
@@ -110,8 +113,8 @@ def predict(model: nn.Module, data: List[Tuple[List[int], int, Optional[int]]],
     return predictions, confidences


-def split_list_train_test(data: List, test_size: float, shuffle: bool = True) -> Tuple:
+def split_list_train_test(data: List, test_size: int, shuffle: bool = True) -> Tuple:
     """Shuffle and randomply split data
```
The method signature has now been changed. Just use a new param here, `test_size_int: int`, and keep the float, which is a percentage size for the test set. If both are set, raise a `ValueError('test_size and test_size_int both set, only use one')`.
Done!
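A minimal sketch of the suggested dual-parameter version, assuming `test_size_int` as the new name per the comment above and a simple shuffle-and-slice split (not the exact merged code):

```python
import random
from typing import List, Optional, Tuple


def split_list_train_test(data: List, test_size: Optional[float] = None,
                          test_size_int: Optional[int] = None,
                          shuffle: bool = True) -> Tuple[List, List]:
    """Shuffle and randomly split data into train and test sets."""
    if test_size is not None and test_size_int is not None:
        raise ValueError('test_size and test_size_int both set, only use one')
    if shuffle:
        random.shuffle(data)
    # Interpret the float as a fraction of the data, the int as an absolute test-set size
    n_test = test_size_int if test_size_int is not None else int(len(data) * (test_size or 0.2))
    return data[n_test:], data[:n_test]
```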
```python
optimizer.step()
if config.model.model_architecture_config is not None:
```
Huh, I don't see an `optim.zero_grad()`, so gradients accumulate over all batches?!
Might be worth going with the Hugging Face `Trainer` here, as that includes a battle-tested training loop.
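For reference, the suggested approach would look roughly like this; the dataset wiring (`train_ds`, `test_ds`) is a placeholder, not code from this PR:

```python
from transformers import Trainer, TrainingArguments

# train_ds / test_ds are placeholders for suitably prepared torch datasets
args = TrainingArguments(output_dir='metacat_out', num_train_epochs=20,
                         per_device_train_batch_size=64)
trainer = Trainer(model=model, args=args,
                  train_dataset=train_ds, eval_dataset=test_ds)
trainer.train()  # the loop handles zero_grad/backward/step internally
```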
I've added `zero_grad()` for each batch.
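i.e. the standard PyTorch pattern, sketched here with illustrative names (`train_loader` and `criterion` are not from the PR):

```python
for x, y in train_loader:      # illustrative loader yielding (inputs, labels)
    optimizer.zero_grad()      # clear gradients left over from the previous batch
    loss = criterion(model(x), y)
    loss.backward()            # accumulate fresh gradients for this batch only
    optimizer.step()
```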
Needs flake8 fixes, but also some method signature updates to ensure backwards compatibility with current models / training code.
1. Added `model.zero_grad()` to clear accumulated gradients
2. Fixed the config save issue
3. Restructured data preparation for oversampled data
Pushing the ml_utils file which was missed in the last commit.
The workflow for inference is: `load()` and inference. For training: `init()` and `train()`. Train will never load the model dict, except when `phase_number` is set to 2 for the second phase of 2-phase learning.
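In code, the two training phases would be driven roughly like this; the exact location of `phase_number` in the config is an assumption here, so check the PR for where it actually lives:

```python
# Phase 1: train from scratch; the model dict is not loaded
meta_cat.config.model['phase_number'] = 1
meta_cat.train_raw(data_loaded, save_dir_path='status_model')

# Phase 2: the weights saved by phase 1 are loaded and training continues
meta_cat.config.model['phase_number'] = 2
meta_cat.train_raw(data_loaded, save_dir_path='status_model')
```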
I've made the changes to the PR.
A lot of my comments might come from a point of ignorance; I don't have good enough knowledge of how these things work under the hood.
But what I do want to make sure we retain is previous behaviour. And I've got a feeling quite a few of these changes change previous (default) behaviour.
Ideally, we'd have regression testing for this. Which would be able to fish out issues that would rise from some of these changes. But alas, we don't have that set up (at least not yet).
Furthermore, please add at least some tests for the new code that has been added. Especially for BERT-based MetaCATs. Otherwise someone might come with another PR, change a few parameters, and all of this might be broken the next release.
I didn't check, but some of the MetaCAT tests may actually now be running with BERT because of the changes in defaults. But we don't really want that. New code generally requires new tests. And old tests should keep testing what they tested before. Otherwise we end up thinking something is automatically tested when it's not.
medcat/config.py
Outdated
```diff
@@ -103,6 +103,7 @@ def save(self, save_path: str) -> None:
         save_path(str): Where to save the created json file
     """
     # We want to save the dict here, not the whole class
+
```
No need to add "changes" where there are none.
This hasn't been removed.
medcat/config_meta_cat.py
Outdated
```diff
@@ -28,7 +27,7 @@ class General(MixingConfig, BaseModel):
     """Number of annotations to be meta-annotated at once in eval"""
     annotate_overlapping: bool = False
     """If set meta_anns will be calcualted for doc._.ents, otherwise for doc.ents"""
-    tokenizer_name: str = 'bbpe'
+    tokenizer_name: str = 'bert-tokenizer'
```
Changing the default here could change existing behaviour. We want to avoid that.
I've reverted the change. I made BERT the default because I wanted to use it for the meta_cat tests.
medcat/config_meta_cat.py
Outdated
```diff
@@ -48,11 +47,19 @@ class Config:


 class Model(MixingConfig, BaseModel):
     """The model part of the metaCAT config"""
-    model_name: str = 'lstm'
+    model_name: str = 'bert'
```
Changing the default here could change existing behaviour. We want to avoid that.
Perhaps a separate method to get a default config suitable for BERT rather than LSTM?
I've reverted the change
medcat/config_meta_cat.py
Outdated
```diff
 num_layers: int = 2
-input_size: int = 300
+input_size: int = 100
```
Changing the default here could change existing behaviour. We want to avoid that.
If this is relevant to BERT but not LSTM, perhaps a separate method to get a default config suitable for BERT rather than LSTM?
If this should change in general, we can leave it be.
I've reverted the change. input_size has to be 768 for BERT; this setting is only for LSTM.
medcat/config_meta_cat.py
Outdated
```diff
@@ -70,13 +77,15 @@ class Config:


 class Train(MixingConfig, BaseModel):
     """The train part of the metaCAT config"""
-    batch_size: int = 100
+    batch_size: int = 32
     nepochs: int = 50
```
Changing the default here could change existing behaviour. We want to avoid that.
I've reverted the change
medcat/meta_cat.py
Outdated
```diff
@@ -421,18 +473,37 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe
     start = ent.start_char
     end = ent.end_char

+    flag = 0
```
This change in behaviour seems to affect both the LSTM and BERT approaches. Is that deliberate? If so, it would require some documentation, since it looks like new/changed behaviour. If not, then we'd need to make sure this only runs for BERT.
Yes, it changes the way we extract the tokens for the medical entity in question, which affects both LSTM and BERT. I've optimised the implementation further and added comments.
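A hypothetical reconstruction of what "the way we extract the tokens" now means, based on the diff below: every token whose character span overlaps the entity is collected, instead of a single centre index:

```python
# Gather all token indices whose character span overlaps the entity
# (offset_mapping holds (start, end) character offsets per token)
ctoken_idx = [i for i, (tkn_start, tkn_end) in enumerate(offset_mapping)
              if tkn_start < end and tkn_end > start]
_start = max(0, ctoken_idx[0] - cntx_left)  # context window anchors on the first entity token
```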
```diff
 last_ind = ind

-_start = max(0, ind - cntx_left)
+# _start = max(0, ind - cntx_left)
+_start = max(0, ctoken_idx[0] - cntx_left)
```
Is this a bug fix? Again, it seems to affect everything, so the change would need to be documented somewhere.
It's the same change as above.
medcat/utils/meta_cat/data_utils.py
Outdated
```diff
@@ -33,6 +34,8 @@ def prepare_from_json(data: Dict,
         {'Experiencer': 'Patient'} - Take care that the CASE has to match whatever is in the data. Defaults to `{}`.
     lowercase (bool):
         Should the text be lowercased before tokenization. Defaults to True.
+    cui_filter:
```
Why was this moved around and given less documentation? Doesn't look like there would be a reason for that.
I've fixed this
medcat/utils/meta_cat/data_utils.py
Outdated
```python
start = ann['start']
end = ann['end']

# Get the index of the center token
flag = 0
```
Again, looks like some kind of new behaviour. Why is this added? Will this retain old behaviour?
It's the same change: it changes the way we extract the tokens for the medical entity in question.
medcat/utils/meta_cat/data_utils.py
Outdated
```diff
 Returns:
     dict:
-        New data with integeres inplace of strings for categry values.
+        New data with integers inplace of strings for categry values.
```
Looks like this now returns a tuple of 3 things. Would be great to document that.
I've added it to the docstrings.
BERT test cases: testing for the BERT model along with 2-phase learning
I've added tests to metacat_test for BERT. It performs all the tests for BERT, and also checks the 2-phase learning by running phase 1 and phase 2.
* Small addition to contribution guidelines (#420)
* CU-8694cbcpu: Allow specifying an AU Snomed when preprocessing (#421)
* CU-8694dpy1c: Return empty generator upon empty stream (#423)
* CU-8694dpy1c: Return empty generator upon empty stream
* CU-8694dpy1c: Fix empty generator returns
* CU-8694dpy1c: Simplify empty generator returns
* Relation extraction (#173)
* Added files.
* More additions to rel extraction.
* Rel base.
* Update.
* Updates.
* Dependency parsing.
* Updates.
* Added pre-training steps.
* Added training & model utils.
* Cleanup & fixes.
* Update.
* Evaluation updates for pretraining.
* Removed duplicate relation storage.
* Moved RE model file location.
* Structure revisions.
* Added custom config for RE.
* Implemented custom dataset loader for RE.
* More changes.
* Small fix.
* Latest additions to RelCAT (pipe + predictions)
* Setup.py fix.
* RE utils update.
* rel model update.
* rel dataset + tokenizer improvements.
* RelCAT updates.
* RelCAT saving/loading improvements.
* RelCAT saving/loading improvements.
* RelCAT model fixes.
* Attempted gpu learning fix. Dataset label generation fixes.
* Minor train dataset gen fix.
* Minor train dataset gen fix No.2.
* Config updates.
* Gpu support fixes. Added label stats.
* Evaluation stat fixes.
* Cleaned stat output mode during training.
* Build fix.
* removed unused dependencies and fixed code formatting
* Mypy compliance.
* Fixed linting.
* More Gpu mode train fixes.
* Fixed model saving/loading issues when using other baes models.
* More fixes to stat evaluation. Added proper CAT integration of RelCAT.
* Setup.py typo fix.
* RelCAT loading fix.
* RelCAT Config changes.
* Type fix. Minor additions to RelCAT model.
* Type fixes.
* Type corrections.
* RelCAT update.
* Type fixes.
* Fixed type issue.
* RelCATConfig: added seed param.
* Adaptations to the new codebase + type fixes.
* Doc/type fixes.
* Fixed input size issue for model.
* Fixed issue(s) with model size and config.
* RelCAT: updated configs to new style.
* RelCAT: removed old refs to logging.
* Fixed GPU training + added extra stat print for train set.
* Type fixes.
* Updated dev requirements.
* Linting.
* Fixed pin_memory issue when training on CPU.
* Updated RelCAT dataset get + default config.
* Updated RelDS generator + default config
* Linting.
* Updated RelDatset + config.
* Pushing updates to model. Made changes to: 1) Extracting a given number of context tokens left and right of the entities 2) Extracting the hidden state from BERT for all the tokens of the entities and performing max pooling on them
* Fixing formatting
* Update rel_dataset.py
* Update rel_dataset.py
* Update rel_dataset.py
* RelCAT: added test resource files.
* RelCAT: Fixed model load/checkpointing.
* RelCAT: updated to pipe spacy doc call.
* RelCAT: added tests.
* Fixed lint/type issues & added rel tag to test DS.
* Fixed ann id to token issue.
* RelCAT: updated test dataset + tests.
* RelCAT: updates to requested changes + dataset improvements.
* RelCAT: updated docs/logs according to comments.
* RelCAT: type fix.
* RelCAT: mct export dataset updates.
* RelCAT: test updates + requested changes p2.
* RelCAT: log for MCT export train.
* Updated docs + split train_test & dataset for benchmarks.
* type fixes.

---------
Co-authored-by: Shubham Agarwal <66172189+shubham-s-agarwal@users.noreply.github.com>
Co-authored-by: mart-r <mart.ratas@gmail.com>

* CU-8694fae3r: Avoid publishing PyPI release when doing GH pre-releases (#424)
* CU-8694fae3r: Avoid publishing PyPI release when doing GH pre-releases
* CU-8694fae3r: Fix pre-releases tagging
* CU-8694fae3r: Allow actions to run on release edit

---------
Co-authored-by: Mart Ratas <mart.ratas@gmail.com>
Co-authored-by: Vlad Dinu <62345326+vladd-bit@users.noreply.github.com>
Please remove the whitespace change you did in config.py; it pollutes the changes since you didn't actually change any functionality in that module.

Please fix the tests: create another class in the same module. Changing state like you've done now can create a situation where the tests pass in one environment (or even on one execution) but not another.

As for the optimised algorithm: the only reason I'm concerned about it is that I haven't dug into it closely enough to really understand what's happening, and because of that I'm wary of approving changes that could change existing behaviour. However, if you're certain it retains existing behaviour, this can be left as is.
medcat/meta_cat.py
Outdated
```python
    Returns:
        nn.Module:
            The module
    """
    config = self.config
    from medcat.utils.meta_cat.models import LSTM
```
You should be able to just use the common base class (`nn.Module`) for the type annotation.
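i.e. something like this sketch, dropping the `Union` in favour of the shared base class:

```python
model: nn.Module
if config.model['model_name'] == 'lstm':
    model = LSTM(embeddings, config)
elif config.model['model_name'] == 'bert':
    model = BertForMetaAnnotation(config)
```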
tests/test_meta_cat.py
Outdated
```diff
@@ -89,6 +93,30 @@ def test_predict_spangroup(self):

         n_meta_cat.config.general.span_group = None

+    def test_z_bert_meta_cat(self):
```
This is not a good way to test your changes. You're changing the state of the class (changing the value of `MetaCATTests.meta_cat`). The order in which the tests are run is not tied to the order in which they are written in the class, so this change to the class state could make the other tests in the class run on the MetaCAT you've set in this test rather than the one set up in the `setUpClass` method.

I would recommend making a new class that extends `unittest.TestCase`, has its own state, and deals with this new type only.
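A minimal sketch of that suggestion; the fixture details (category name, tokenizer wiring) are illustrative assumptions, not the PR's actual test code:

```python
import unittest

from transformers import AutoTokenizer

from medcat.config_meta_cat import ConfigMetaCAT
from medcat.meta_cat import MetaCAT
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT


class BertMetaCATTests(unittest.TestCase):
    """Own fixture and state: nothing here can leak into the LSTM-based tests."""

    @classmethod
    def setUpClass(cls):
        config = ConfigMetaCAT()
        config.general['category_name'] = 'Status'  # illustrative category
        config.model['model_name'] = 'bert'
        tokenizer = TokenizerWrapperBERT(AutoTokenizer.from_pretrained('bert-base-uncased'))
        cls.meta_cat = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config)

    def test_uses_bert(self):
        self.assertEqual(self.meta_cat.config.model['model_name'], 'bert')
```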
Looks good to me.
This reverts commit fbcdb70.
lgtm
The following changes have been made: