-
Notifications
You must be signed in to change notification settings - Fork 104
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added files. * More additions to rel extraction. * Rel base. * Update. * Updates. * Dependency parsing. * Updates. * Added pre-training steps. * Added training & model utils. * Cleanup & fixes. * Update. * Evaluation updates for pretraining. * Removed duplicate relation storage. * Moved RE model file location. * Structure revisions. * Added custom config for RE. * Implemented custom dataset loader for RE. * More changes. * Small fix. * Latest additions to RelCAT (pipe + predictions) * Setup.py fix. * RE utils update. * rel model update. * rel dataset + tokenizer improvements. * RelCAT updates. * RelCAT saving/loading improvements. * RelCAT saving/loading improvements. * RelCAT model fixes. * Attempted gpu learning fix. Dataset label generation fixes. * Minor train dataset gen fix. * Minor train dataset gen fix No.2. * Config updates. * Gpu support fixes. Added label stats. * Evaluation stat fixes. * Cleaned stat output mode during training. * Build fix. * removed unused dependencies and fixed code formatting * Mypy compliance. * Fixed linting. * More Gpu mode train fixes. * Fixed model saving/loading issues when using other baes models. * More fixes to stat evaluation. Added proper CAT integration of RelCAT. * Setup.py typo fix. * RelCAT loading fix. * RelCAT Config changes. * Type fix. Minor additions to RelCAT model. * Type fixes. * Type corrections. * RelCAT update. * Type fixes. * Fixed type issue. * RelCATConfig: added seed param. * Adaptations to the new codebase + type fixes.. * Doc/type fixes. * Fixed input size issue for model. * Fixed issue(s) with model size and config. * RelCAT: updated configs to new style. * RelCAT: removed old refs to logging. * Fixed GPU training + added extra stat print for train set. * Type fixes. * Updated dev requirements. * Linting. * Fixed pin_memory issue when training on CPU. * Updated RelCAT dataset get + default config. * Updated RelDS generator + default config * Linting. * Updated RelDatset + config. * Pushing updates to model Made changes to: 1) Extracting given number of context tokens left and right of the entities 2) Extracting hidden state from bert for all the tokens of the entities and performing max pooling on them * Fixing formatting * Update rel_dataset.py * Update rel_dataset.py * Update rel_dataset.py * RelCAT: added test resource files. * RelCAT: Fixed model load/checkpointing. * RelCAT: updated to pipe spacy doc call. * RelCAT: added tests. * Fixed lint/type issues & added rel tag to test DS. * Fixed ann id to token issue. * RelCAT: updated test dataset + tests. * RelCAT: updates to requested changes + dataset improvements. * RelCAT: updated docs/logs according to commends. * RelCAT: type fix. * RelCAT: mct export dataset updates. * RelCAT: test updates + requested changes p2. * RelCAT: log for MCT export train. * Updated docs + split train_test & dataset for benchmarks. * type fixes. --------- Co-authored-by: Shubham Agarwal <66172189+shubham-s-agarwal@users.noreply.github.com> Co-authored-by: mart-r <mart.ratas@gmail.com>
- Loading branch information
1 parent
1caa187
commit abc97fb
Showing
17 changed files
with
6,776 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,9 @@ venv | |
db.sqlite3 | ||
.ipynb_checkpoints | ||
|
||
# vscode | ||
.vscode | ||
|
||
#tmp and similar files | ||
.nfs* | ||
*.log | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import logging | ||
from typing import Dict, Any, List | ||
from medcat.config import MixingConfig, BaseModel, Optional, Extra | ||
|
||
|
||
class General(MixingConfig, BaseModel): | ||
"""The General part of the RelCAT config""" | ||
device: str = "cpu" | ||
relation_type_filter_pairs: List = [] | ||
"""Map from category values to ID, if empty it will be autocalculated during training""" | ||
vocab_size: Optional[int] = None | ||
lowercase: bool = True | ||
"""If true all input text will be lowercased""" | ||
cntx_left: int = 15 | ||
"""Number of tokens to take from the left of the concept""" | ||
cntx_right: int = 15 | ||
"""Number of tokens to take from the right of the concept""" | ||
window_size: int = 300 | ||
"""Max acceptable dinstance between entities (in characters), care when using this as it can produce sentences that are over 512 tokens (limit is given by tokenizer)""" | ||
|
||
mct_export_max_non_rel_sample_size:int = 200 | ||
"""Limit the number of 'Other' samples selected for training/test. This is applied per encountered medcat project, sample_size/num_projects. """ | ||
mct_export_create_addl_rels: bool = False | ||
"""When processing relations from a MedCAT export, relations labeled as 'Other' are created from all the annotations pairs available""" | ||
|
||
tokenizer_name: str = "bert" | ||
model_name: str = "bert-base-uncased" | ||
log_level: int = logging.INFO | ||
max_seq_length: int = 512 | ||
tokenizer_special_tokens: bool = False | ||
annotation_schema_tag_ids: List = [] | ||
"""If a foreign non-MCAT trainer dataset is used, you can insert your own Rel entity token delimiters into the tokenizer, \ | ||
copy those token IDs here, and also resize your tokenizer embeddings and adjust the hidden_size of the model, this will depend on the number of tokens you introduce""" | ||
labels2idx: Dict = {} | ||
idx2labels: Dict = {} | ||
pin_memory: bool = True | ||
seed: int = 13 | ||
task: str = "train" | ||
|
||
|
||
class Model(MixingConfig, BaseModel): | ||
"""The model part of the RelCAT config""" | ||
input_size: int = 300 | ||
hidden_size: int = 768 | ||
hidden_layers: int = 3 | ||
""" hidden_size * 5, 5 being the number of tokens, default (s1,s2,e1,e2+CLS)""" | ||
model_size: int = 5120 | ||
dropout: float = 0.2 | ||
num_directions: int = 2 | ||
"""2 - bidirectional model, 1 - unidirectional""" | ||
|
||
padding_idx: int = -1 | ||
emb_grad: bool = True | ||
"""If True the embeddings will also be trained""" | ||
ignore_cpos: bool = False | ||
"""If set to True center positions will be ignored when calculating represenation""" | ||
|
||
class Config: | ||
extra = Extra.allow | ||
validate_assignment = True | ||
|
||
|
||
class Train(MixingConfig, BaseModel): | ||
"""The train part of the RelCAT config""" | ||
nclasses: int = 2 | ||
"""Number of classes that this model will output""" | ||
batch_size: int = 25 | ||
nepochs: int = 1 | ||
lr: float = 1e-4 | ||
adam_epsilon: float = 1e-4 | ||
test_size: float = 0.2 | ||
gradient_acc_steps: int = 1 | ||
multistep_milestones: List[int] = [ | ||
2, 4, 6, 8, 12, 15, 18, 20, 22, 24, 26, 30] | ||
multistep_lr_gamma: float = 0.8 | ||
max_grad_norm: float = 1.0 | ||
shuffle_data: bool = True | ||
"""Used only during training, if set the dataset will be shuffled before train/test split""" | ||
class_weights: Optional[Any] = None | ||
score_average: str = "weighted" | ||
"""What to use for averaging F1/P/R across labels""" | ||
auto_save_model: bool = True | ||
"""Should the model be saved during training for best results""" | ||
|
||
class Config: | ||
extra = Extra.allow | ||
validate_assignment = True | ||
|
||
|
||
class ConfigRelCAT(MixingConfig, BaseModel): | ||
"""The RelCAT part of the config""" | ||
general: General = General() | ||
model: Model = Model() | ||
train: Train = Train() | ||
|
||
class Config: | ||
extra = Extra.allow | ||
validate_assignment = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.