
Adding Bert-style model for MetaCAT #419

Merged: 22 commits into master on May 16, 2024

Conversation

shubham-s-agarwal (Collaborator)

The following changes have been made:

  1. Adding BERT model support for MetaCAT
  2. Including 2-phase learning for BERT
  3. Using LoRA for training BERT's parameters
  4. Support for adding in oversampled data
  5. Retrieving the hidden state of all tokens of the medical entity
  6. Using Focal loss to train the models (a brief sketch follows this list)
  7. Using stratified splitting for the train/test split
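
For the focal loss in item 6, a minimal sketch of what a multi-class focal loss of this kind typically looks like (a generic illustration, not the PR's exact implementation; the gamma and class-weight parameters mirror the train config fields quoted later in the thread):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits: torch.Tensor, targets: torch.Tensor,
               gamma: float = 2.0,
               class_weights: torch.Tensor = None) -> torch.Tensor:
    # Cross-entropy scaled by (1 - p_t)^gamma so easy, well-classified examples
    # contribute less and training focuses on the hard / minority-class ones.
    log_probs = F.log_softmax(logits, dim=-1)
    ce = F.nll_loss(log_probs, targets, weight=class_weights, reduction='none')
    p_t = log_probs.exp().gather(1, targets.unsqueeze(1)).squeeze(1)
    return ((1.0 - p_t) ** gamma * ce).mean()
```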

shubham-s-agarwal added the enhancement (New feature or request) label on Apr 19, 2024
shubham-s-agarwal self-assigned this on Apr 19, 2024

def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> Dict:
def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None,data_=None) -> Dict:
Member

this is a confusing method sig now, why is data_ needed alongside data_loaded

Collaborator Author

data_loaded is the actual dataset that we've loaded (from the MedCAT export); data_ is the oversampled additional data that will be added to the actual data

Member

ah okay - data_ should be renamed to something more descriptive

Collaborator Author

I've replaced it with data_oversampled
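
For reference, a hedged sketch of what the updated signature might look like after the rename (the parameter's type and default are assumptions):

```python
from typing import Dict, List, Optional

def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None,
              data_oversampled: Optional[List] = None) -> Dict:
    """Train MetaCAT from a loaded MedCAT export, optionally mixing in
    additional oversampled samples before training."""
    ...
```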

category_name = g_config['category_name']
if category_name not in data:
    raise Exception(
        "The category name does not exist in this json file. You've provided '{}', while the possible options are: {}".format(
            category_name, " | ".join(list(data.keys()))))

data = data[category_name]
if data_:
Member

you're running the tokenizer here on sample[0], but in prepare_from_json the tokenizer is passed in and run for the LSTM model.

Can we reuse prepare_from_json?

Collaborator Author

self.tokenizer is used for both:

  • prepare_from_json which processes the actual data from the medcat export (used by bert and lstm)
  • processing data_ (oversampled data)

I can try and re-use prepare_from_json, however the structure for the oversampled data would need to match the medcat export one (which is complex) and would entail re-processing some bits (calculating cpos and others)

Collaborator Author

Restructured to include this in data_utils under the prepare_for_oversampled_data function
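
A rough sketch of what such a helper could look like; the per-sample layout (text plus label) and the tokenizer call are assumptions rather than the PR's actual code:

```python
from typing import Iterable, List, Tuple

def prepare_for_oversampled_data(data_oversampled: Iterable[Tuple[str, str]],
                                 tokenizer) -> List:
    # Tokenize each extra sample so it matches the (token_ids, entity_token_indices,
    # label) layout produced by prepare_from_json for the regular export data.
    prepared = []
    for text, label in data_oversampled:
        token_ids = tokenizer(text)['input_ids']
        prepared.append([token_ids, list(range(len(token_ids))), label])
    return prepared
```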


self.model = self.get_model(embeddings=self.embeddings)

if self.config.model.load_model_dict_:
Member

do we need to be more specific here that we're loading the bert model impl of MetaCAT?

shubham-s-agarwal (Collaborator Author), Apr 19, 2024

we've added the info (in the logger) in the get_model function

```python
def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
    """Get the model

    Args:
        embeddings (Optional[Tensor]):
            The embedding tensor

    Raises:
        ValueError: If the meta model is not LSTM or BERT

    Returns:
        nn.Module:
            The module
    """
    config = self.config
    from medcat.utils.meta_cat.models import LSTM
    from medcat.utils.meta_cat.models import BertForMetaAnnotation
    if config.model['model_name'] == 'lstm':
        model: Union[LSTM, BertForMetaAnnotation] = LSTM(embeddings, config)
        logger.info("LSTM model used for classification")

    elif config.model['model_name'] == 'bert':
        model = BertForMetaAnnotation(config)

        if not config.model.model_freeze_layers:
            peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=16,
                                     target_modules=["query", "value"], lora_dropout=0.2)

            model = get_peft_model(model, peft_config)
            # model.print_trainable_parameters()

        logger.info("BERT model used for classification")
```

@@ -319,7 +355,7 @@ def save(self, save_dir_path: str) -> None:
self.tokenizer.save(save_dir_path)

# Save config
self.config.save(os.path.join(save_dir_path, 'config.json'))
# self.config.save(os.path.join(save_dir_path, 'config.json'))
Member

so we're saving config? This change shouldn't affect the existing Bi-LSTM models, or if a user wants to use the bi-LSTM implementation

Collaborator Author

we aren't saving the config as of now; the reason is that the version of the config being saved is drastically different from the one we pass in, which causes issues when the config is read and modified again. Here is the version being saved:

{ "general": { "py/object": "medcat.config_meta_cat.General", "py/state": { "__dict__": { "device": "cpu", "disable_component_lock": false, "seed": 13, "description": "No description", "category_name": "Status", "category_value2id": {"Other": 1, "Confirmed": 0}, "vocab_size": 30522, "lowercase": true, "cntx_left": 20, "cntx_right": 10, "replace_center": null, "batch_size_eval": 5000, "annotate_overlapping": false, "tokenizer_name": "bert-tokenizer", "save_and_reuse_tokens": false, "pipe_batch_size_in_chars": 20000000, "span_group": null }, "__fields_set__": { "py/set": [ "device" , "lowercase" , "save_and_reuse_tokens" , "seed" , "category_value2id" , "replace_center" , "pipe_batch_size_in_chars", "annotate_overlapping" , "disable_component_lock" , "batch_size_eval" , "cntx_right" , "cntx_left" , "tokenizer_name" , "vocab_size" , "category_name" ] }, "__private_attribute_values__": {} } }, "model": { "py/object": "medcat.config_meta_cat.Model", "py/state": { "__dict__": { "model_name": "bert", "model_variant": "bert-base-uncased", "model_freeze_layers": false, "num_layers": 3, "input_size": 1536, "hidden_size": 32, "dropout": 0.25, "category_undersample": "Other", "model_architecture_config": { "fc2" : true , "fc3" : false, "lr_scheduler": true }, "num_directions": 2, "nclasses": 2, "padding_idx": 0, "emb_grad": true, "ignore_cpos": false, "load_model_dict_": false, "train_on_full_data": true }, "__fields_set__": { "py/set": [ "dropout" , "ignore_cpos" , "padding_idx" , "hidden_size" , "train_on_full_data" , "model_name" , "emb_grad" , "input_size" , "num_directions" , "nclasses" , "model_architecture_config", "load_model_dict_" , "model_variant" , "model_freeze_layers" , "num_layers" , "category_undersample" ] }, "__private_attribute_values__": {} } }, "train": { "py/object": "medcat.config_meta_cat.Train", "py/state": { "__dict__": { "batch_size": 64, "nepochs": 20, "lr": 0.0005, "test_size": 0.2, "shuffle_data": true, "class_weights": [0.3, 0.65], "score_average": "weighted", "prerequisites": {}, "cui_filter": null, "auto_save_model": true, "last_train_on": null, "metric": {"base": "macro avg", "score": "f1-score"}, "loss_funct": "cross_entropy", "gamma": 3, "loss_function": "focal_loss" }, "__fields_set__": { "py/set": [ "batch_size" , "lr" , "auto_save_model", "class_weights" , "nepochs" , "test_size" , "gamma" , "prerequisites" , "score_average" , "loss_function" , "cui_filter" , "shuffle_data" ] }, "__private_attribute_values__": {} } } }

Member

sorry, I don't understand your point about why the config isn't being saved? The modified config class needs to be backwards compatible with previous models. If we need a BERT-specific implementation config, there should be a new class that is used

Collaborator Author

Fixed this, config is being saved now


try:
    meta_cat.model.load_state_dict(torch.load(model_save_path, map_location=device))
except:
Member

Can we be more specific on the exception caught, and pass the exception into the log message at least.

Collaborator Author

I've made the change to display a warning message (this version is outdated)
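
Continuing the snippet above, a sketch of the kind of change discussed here; the exception types caught are assumptions:

```python
import logging
import torch

logger = logging.getLogger(__name__)

try:
    meta_cat.model.load_state_dict(torch.load(model_save_path, map_location=device))
except (RuntimeError, KeyError) as err:
    # Surface the underlying error instead of silently swallowing it
    logger.warning("Loading the model state dict failed: %s", err)
```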

return x, cpos, y
# cpos = torch.tensor(cpos, dtype=torch.long).to(device)
attention_masks = (x != 0).type(torch.int)
return x, cpos, attention_masks, y
Member

need to update the docstring; why is the cpos data movement commented out?

Collaborator Author

Previously, cpos contained a single index; now it contains a dynamic number of indices for each data point, so it can no longer be converted to a tensor.
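
An illustrative collate step after this change (function and variable names are assumptions): cpos stays a plain list of lists because each entity now carries a variable number of token indices, while the padded input ids yield the attention mask shown in the diff:

```python
import torch

def collate_batch(batch, pad_id=0):
    x = [sample[0] for sample in batch]
    cpos = [sample[1] for sample in batch]   # variable-length index lists, not a tensor
    y = torch.tensor([sample[2] for sample in batch], dtype=torch.long)

    longest = max(len(seq) for seq in x)
    x = torch.tensor([seq + [pad_id] * (longest - len(seq)) for seq in x],
                     dtype=torch.long)
    attention_masks = (x != 0).type(torch.int)  # mirrors the line in the diff
    return x, cpos, attention_masks, y
```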

@@ -110,8 +113,8 @@ def predict(model: nn.Module, data: List[Tuple[List[int], int, Optional[int]]],
return predictions, confidences


def split_list_train_test(data: List, test_size: float, shuffle: bool = True) -> Tuple:
"""Shuffle and randomply split data
def split_list_train_test(data: List, test_size: int, shuffle: bool = True) -> Tuple:
Member

method sig has now been changed - just add a new param here, test_size_int: int, and keep the float, which is a % size for the test set. If both are set, raise a ValueError('test_size and test_size_int both set, only use one')

Collaborator Author

Done!
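
A sketch along the lines of the suggestion above (implementation details are assumptions):

```python
import random
from typing import List, Optional, Tuple

def split_list_train_test(data: List, test_size: Optional[float] = None,
                          test_size_int: Optional[int] = None,
                          shuffle: bool = True) -> Tuple[List, List]:
    """Shuffle and split data; test_size is a fraction, test_size_int an absolute count."""
    if test_size is not None and test_size_int is not None:
        raise ValueError('test_size and test_size_int both set, only use one')
    if shuffle:
        random.shuffle(data)
    n_test = test_size_int if test_size_int is not None else int(len(data) * (test_size or 0.2))
    return data[n_test:], data[:n_test]
```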

optimizer.step()
if config.model.model_architecture_config is not None:
Member

huh - I don't see an optim.zero_grad()
so gradients accumulate over all batches?!

Member

Might be worth going with the Hugging Face Trainer here, as that includes a battle-tested training loop.

Collaborator Author

I've added zero_grad() for each batch
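
A minimal sketch of the corrected loop shape; the objects and the model's call signature are assumptions:

```python
# train_loader, model, criterion and optimizer are assumed to be set up as in
# the rest of the training code; only the per-batch structure is illustrated.
for x, cpos, attention_masks, y in train_loader:
    optimizer.zero_grad()               # clear gradients left over from the previous batch
    logits = model(x, attention_masks)  # forward pass; the real signature may differ
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()
```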

tomolopolis (Member) left a comment

needs flake8 fixes but also some method sig updates to ensure backwards compatibility with current models / training code.


1) Added model.zero_grad to clear accumulated gradients
2) Fixed config save issue
3) Re-structured data preparation for oversampled data
Pushing ml_utils file which was missed in the last commit
The workflow for inference is: load() and inference
For training: init() and train()
Train will always not load the model dict, except when the phase_number is set to 2 for 2 phase learning's second phase
shubham-s-agarwal (Collaborator Author)

I've made the changes to the PR.
The workflow for inference is: load() and inference
For training: init() and train()
Train will not load the model dict, except when phase_number is set to 2 for the second phase of 2-phase learning
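
A hedged usage sketch of that workflow; load() and train_raw() come from the thread and diff, the remaining argument names are illustrative:

```python
from medcat.meta_cat import MetaCAT

# Inference: load a previously trained MetaCAT and run it as part of the pipeline
meta_cat = MetaCAT.load(save_dir_path='meta_status_model')   # path is illustrative
doc = meta_cat(doc)   # annotates meta-annotations on the spaCy doc

# Training: build a fresh instance and train it from a MedCAT export;
# per the thread, the saved model dict is only loaded when phase_number is set to 2
meta_cat = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config)
results = meta_cat.train_raw(data_loaded=export_data, save_dir_path='meta_status_model')
```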

mart-r (Collaborator) left a comment

A lot of my comments might come from a point of ignorance. I don't have good enough knowledge of how these things work under the hood.

But what I do want to make sure we retain is previous behaviour. And I've got a feeling quite a few of these changes change previous (default) behaviour.
Ideally, we'd have regression testing for this. Which would be able to fish out issues that would rise from some of these changes. But alas, we don't have that set up (at least not yet).

Furthermore, please add at least some tests for the new code that has been added. Especially for BERT-based MetaCATs. Otherwise someone might come with another PR, change a few parameters, and all of this might be broken the next release.
I didn't check, but some of the MetaCAT tests may actually now be running with BERT because of the changes in defaults. But we don't really want that. New code generally requires new tests. And old tests should keep testing what they tested before. Otherwise we end up thinking something is automatically tested when it's not.

medcat/config.py Outdated
@@ -103,6 +103,7 @@ def save(self, save_path: str) -> None:
save_path(str): Where to save the created json file
"""
# We want to save the dict here, not the whole class

Collaborator

No need to add "changes" where there are none.

Collaborator

This hasn't been removed.

@@ -28,7 +27,7 @@ class General(MixingConfig, BaseModel):
"""Number of annotations to be meta-annotated at once in eval"""
annotate_overlapping: bool = False
"""If set meta_anns will be calcualted for doc._.ents, otherwise for doc.ents"""
tokenizer_name: str = 'bbpe'
tokenizer_name: str = 'bert-tokenizer'
Collaborator

Changing the default here could change existing behaviour. We want to avoid that.

Collaborator Author

I've reverted the change; I had made bert the default because I wanted to use it for meta_cat_tests.

@@ -48,11 +47,19 @@ class Config:

class Model(MixingConfig, BaseModel):
"""The model part of the metaCAT config"""
model_name: str = 'lstm'
model_name: str = 'bert'
Collaborator

Changing the default here could change existing behaviour. We want to avoid that.

Perhaps a separate method to get a default config suitable for BERT rather than LSTM?

Collaborator Author

I've reverted the change
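
A sketch of the reviewer's suggestion of a separate helper for BERT defaults, leaving the LSTM defaults untouched (the helper name and the exact fields it sets are assumptions):

```python
from medcat.config_meta_cat import ConfigMetaCAT

def default_bert_meta_cat_config() -> ConfigMetaCAT:
    # Hand out a BERT-flavoured config on request instead of changing the LSTM defaults.
    config = ConfigMetaCAT()
    config.model.model_name = 'bert'
    config.model.model_variant = 'bert-base-uncased'
    config.model.input_size = 768                    # BERT hidden size, per the thread
    config.general.tokenizer_name = 'bert-tokenizer'
    return config
```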

num_layers: int = 2
input_size: int = 300
input_size: int = 100
Collaborator

Changing the default here could change existing behaviour. We want to avoid that.

If this is relevant to BERT but not LSTM, perhaps a separate method to get a default config suitable for BERT rather than LSTM?
If this should change in general, we can leave it be.

Collaborator Author

I've reverted the change; input_size has to be 768 for BERT, so this setting only applies to the LSTM.

@@ -70,13 +77,15 @@ class Config:

class Train(MixingConfig, BaseModel):
"""The train part of the metaCAT config"""
batch_size: int = 100
nepochs: int = 50
batch_size: int = 32
Collaborator

Changing the default here could change existing behaviour. We want to avoid that.

Collaborator Author

I've reverted the change

@@ -421,18 +473,37 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe
start = ent.start_char
end = ent.end_char

flag = 0
Collaborator

This change in behaviour seems to affect both LSTM and BERT approaches.
Is that deliberate? If so, it would require some documentation since it looks like new/changed behaviour.
If no, then we'd need to make sure this only runs for BERT.

Collaborator Author

Yes, it changes the way we extract the tokens for the medical entity in question, which affects both LSTM and BERT.
I've optimised the implementation further and added comments
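
An illustrative take on that extraction (variable names follow the diff; the offset comparison is an assumption): every token falling inside the entity's character span is collected, and the context window is built around those indices rather than a single centre token:

```python
# Collect all token indices whose character offsets lie within the entity span
ctoken_idx = [i for i, (tkn_start, tkn_end) in enumerate(offset_mapping)
              if tkn_start >= start and tkn_end <= end]

# Context window around the whole entity instead of around one centre token
_start = max(0, ctoken_idx[0] - cntx_left)
_end = min(len(input_ids), ctoken_idx[-1] + 1 + cntx_right)
```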

last_ind = ind

_start = max(0, ind - cntx_left)
# _start = max(0, ind - cntx_left)
_start = max(0, ctoken_idx[0] - cntx_left)
Collaborator

Is this a bug fix?
Again, seems to affect everything, so change would need to be documented somewhere.

Collaborator Author

It's the same change as above

@@ -33,6 +34,8 @@ def prepare_from_json(data: Dict,
{'Experiencer': 'Patient'} - Take care that the CASE has to match whatever is in the data. Defaults to `{}`.
lowercase (bool):
Should the text be lowercased before tokenization. Defaults to True.
cui_filter:
Collaborator

Why was this moved around and given less documentation? Doesn't look like there would be a reason for that.

Collaborator Author

I've fixed this

start = ann['start']
end = ann['end']

# Get the index of the center token
flag = 0
Collaborator

Again, looks like some kind of new behaviour. Why is this added? Will this retain old behaviour?

Collaborator Author

It's the same change: it changes the way we extract the tokens for the medical entity in question.


Returns:
dict:
New data with integeres inplace of strings for categry values.
New data with integers inplace of strings for categry values.
Collaborator

Looks like this now returns a tuple of 3 things. Would be great to document that.

Collaborator Author

I've added it in the docstrings

shubham-s-agarwal (Collaborator Author)

I've added tests to metacat_test for BERT. It performs all the tests for BERT, and also checks the 2 phase learning by running phase 1 and phase 2.

shubham-s-agarwal and others added 2 commits May 8, 2024 08:35
* Small addition to contribution guidelines (#420)

* CU-8694cbcpu: Allow specifying an AU Snomed when preprocessing (#421)

* CU-8694dpy1c: Return empty generator upon empty stream (#423)

* CU-8694dpy1c: Return empty generator upon empty stream

* CU-8694dpy1c: Fix empty generator returns

* CU-8694dpy1c: Simplify empty generator returns

* Relation extraction (#173)

* Added files.

* More additions to rel extraction.

* Rel base.

* Update.

* Updates.

* Dependency parsing.

* Updates.

* Added pre-training steps.

* Added training & model utils.

* Cleanup & fixes.

* Update.

* Evaluation updates for pretraining.

* Removed duplicate relation storage.

* Moved RE model file location.

* Structure revisions.

* Added custom config for RE.

* Implemented custom dataset loader for RE.

* More changes.

* Small fix.

* Latest additions to RelCAT (pipe + predictions)

* Setup.py fix.

* RE utils update.

* rel model update.

* rel dataset + tokenizer improvements.

* RelCAT updates.

* RelCAT saving/loading improvements.

* RelCAT saving/loading improvements.

* RelCAT model fixes.

* Attempted gpu learning fix. Dataset label generation fixes.

* Minor train dataset gen fix.

* Minor train dataset gen fix No.2.

* Config updates.

* Gpu support fixes. Added label stats.

* Evaluation stat fixes.

* Cleaned stat output mode during training.

* Build fix.

* removed unused dependencies and fixed code formatting

* Mypy compliance.

* Fixed linting.

* More Gpu mode train fixes.

* Fixed model saving/loading issues when using other baes models.

* More fixes to stat evaluation. Added proper CAT integration of RelCAT.

* Setup.py typo fix.

* RelCAT loading fix.

* RelCAT Config changes.

* Type fix. Minor additions to RelCAT model.

* Type fixes.

* Type corrections.

* RelCAT update.

* Type fixes.

* Fixed type issue.

* RelCATConfig: added seed param.

* Adaptations to the new codebase + type fixes..

* Doc/type fixes.

* Fixed input size issue for model.

* Fixed issue(s) with model size and config.

* RelCAT: updated configs to new style.

* RelCAT: removed old refs to logging.

* Fixed GPU training + added extra stat print for train set.

* Type fixes.

* Updated dev requirements.

* Linting.

* Fixed pin_memory issue when training on CPU.

* Updated RelCAT dataset get + default config.

* Updated RelDS generator + default config

* Linting.

* Updated RelDatset + config.

* Pushing updates to model

Made changes to:
1) Extracting given number of context tokens left and right of the entities
2) Extracting hidden state from bert for all the tokens of the entities and performing max pooling on them

* Fixing formatting

* Update rel_dataset.py

* Update rel_dataset.py

* Update rel_dataset.py

* RelCAT: added test resource files.

* RelCAT: Fixed model load/checkpointing.

* RelCAT: updated to pipe spacy doc call.

* RelCAT: added tests.

* Fixed lint/type issues & added rel tag to test DS.

* Fixed ann id to token issue.

* RelCAT: updated test dataset + tests.

* RelCAT: updates to requested changes + dataset improvements.

* RelCAT: updated docs/logs according to commends.

* RelCAT: type fix.

* RelCAT: mct export dataset updates.

* RelCAT: test updates + requested changes p2.

* RelCAT: log for MCT export train.

* Updated docs + split train_test & dataset for benchmarks.

* type fixes.

---------

Co-authored-by: Shubham Agarwal <66172189+shubham-s-agarwal@users.noreply.github.com>
Co-authored-by: mart-r <mart.ratas@gmail.com>

* CU-8694fae3r: Avoid publishing PyPI release when doing GH pre-releases (#424)

* CU-8694fae3r: Avoid publishing PyPI release when doing GH pre-releases

* CU-8694fae3r: Fix pre-releases tagging

* CU-8694fae3r: Allow actions to run on release edit

---------

Co-authored-by: Mart Ratas <mart.ratas@gmail.com>
Co-authored-by: Vlad Dinu <62345326+vladd-bit@users.noreply.github.com>
mart-r (Collaborator) left a comment

Please remove the whitespace change you did in config.py - it pollutes the changes since you didn't actually change any functionality in that module.

Please fix the tests. Create another class in the same module.
Change of state like you've done now can create a situation where the tests may run in one environment (or even on one execution) but not another.

As for the optimised algorithm: the only reason I'm concerned about it is that I haven't dived into it closely enough to really understand what's happening. And because of that I'm afraid of approving changes that could change existing behaviour.
However, if you're certain it retains existing behaviour, this can be left as is.



    Returns:
        nn.Module:
            The module
    """
    config = self.config
    from medcat.utils.meta_cat.models import LSTM
Collaborator

You should be able to just use the common base class (nn.Module) for type annotation.

@@ -89,6 +93,30 @@ def test_predict_spangroup(self):

n_meta_cat.config.general.span_group = None

def test_z_bert_meta_cat(self):
Collaborator

This is not a good way to test your changes.
You're changing the state of the class (changing the value of MetaCATTests.meta_cat). The order in which the tests run is not tied to the order in which they are written in the class. As such, this change of state could make the other tests in the class run against the MetaCAT you've set in this test rather than the one set up in the setUpClass method.

I would recommend making a new class that extends unittest.TestCase that has its own state and deals with this new type only.
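
A sketch of that suggested structure; the fixture details are assumptions, and real tests would need a proper tokenizer and dataset:

```python
import unittest

from medcat.config_meta_cat import ConfigMetaCAT
from medcat.meta_cat import MetaCAT


class BERTMetaCATTests(unittest.TestCase):
    """Keeps its own MetaCAT instance so the LSTM tests keep their original fixture."""

    @classmethod
    def setUpClass(cls):
        config = ConfigMetaCAT()
        config.model.model_name = 'bert'
        # Tokenizer/embedding wiring omitted; real tests would supply them.
        cls.meta_cat = MetaCAT(tokenizer=None, embeddings=None, config=config)

    def test_uses_bert(self):
        self.assertEqual(self.meta_cat.config.model.model_name, 'bert')
```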

mart-r (Collaborator) left a comment

Looks good to me.

tomolopolis (Member) left a comment

lgtm

shubham-s-agarwal merged commit fbe9745 into master on May 16, 2024
5 checks passed
shubham-s-agarwal deleted the metacat_bert branch on May 16, 2024 09:13