
Relation extraction #173

Merged May 1, 2024 (119 commits; the diff below shows changes from the first 22 commits)

Commits:
b20e7c8
Added files.
vladd-bit Aug 24, 2021
eec6c59
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Aug 31, 2021
56220aa
More additions to rel extraction.
vladd-bit Sep 1, 2021
7ad88f5
Rel base.
vladd-bit Sep 3, 2021
233ce36
Update.
vladd-bit Sep 6, 2021
85a7015
Updates.
vladd-bit Sep 10, 2021
5003548
Dependency parsing.
vladd-bit Oct 1, 2021
541b47d
Updates.
vladd-bit Oct 13, 2021
c042b0d
Added pre-training steps.
vladd-bit Oct 15, 2021
87d0c0c
Added training & model utils.
vladd-bit Oct 18, 2021
4f42696
Cleanup & fixes.
vladd-bit Oct 19, 2021
018d811
Update.
vladd-bit Oct 21, 2021
f3d3f44
Evaluation updates for pretraining.
vladd-bit Oct 27, 2021
e5f354e
Removed duplicate relation storage.
vladd-bit Nov 9, 2021
c69de67
Merged master.
vladd-bit Nov 9, 2021
031d256
Moved RE model file location.
vladd-bit Nov 12, 2021
2259a6b
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Nov 16, 2021
1c469e9
Structure revisions.
vladd-bit Nov 22, 2021
423b4e1
Added custom config for RE.
vladd-bit Dec 13, 2021
8ae9abb
Implemented custom dataset loader for RE.
vladd-bit Dec 13, 2021
186416c
More changes.
vladd-bit Dec 13, 2021
451e33f
Small fix.
vladd-bit Dec 13, 2021
8b36413
Latest additions to RelCAT (pipe + predictions)
vladd-bit Jan 19, 2022
2fb8fc9
Setup.py fix.
vladd-bit Jan 19, 2022
930dd11
RE utils update.
vladd-bit Jan 19, 2022
24b2841
rel model update.
vladd-bit Jan 19, 2022
193ecb1
rel dataset + tokenizer improvements.
vladd-bit Jan 19, 2022
03111a7
RelCAT updates.
vladd-bit Jan 19, 2022
7ab60f4
RelCAT saving/loading improvements.
vladd-bit Jan 21, 2022
40875f3
RelCAT saving/loading improvements.
vladd-bit Jan 21, 2022
810d1dc
RelCAT model fixes.
vladd-bit Jan 21, 2022
11dcb32
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Jan 21, 2022
72187f6
Attempted gpu learning fix. Dataset label generation fixes.
vladd-bit Jan 24, 2022
5f67a4c
Minor train dataset gen fix.
vladd-bit Jan 24, 2022
cfc0e91
Minor train dataset gen fix No.2.
vladd-bit Jan 24, 2022
9f4b220
Config updates.
vladd-bit Jan 25, 2022
19afa81
Gpu support fixes. Added label stats.
vladd-bit Jan 25, 2022
8eb1665
Evaluation stat fixes.
vladd-bit Jan 26, 2022
6e86fa2
Cleaned stat output mode during training.
vladd-bit Jan 26, 2022
5cee8cf
Build fix.
vladd-bit Jan 26, 2022
223ac9a
removed unused dependencies and fixed code formatting
vladd-bit Jan 26, 2022
ea7d68c
Mypy compliance.
vladd-bit Jan 26, 2022
1ea9738
Fixed linting.
vladd-bit Jan 27, 2022
9f6609e
More Gpu mode train fixes.
vladd-bit Jan 28, 2022
1782c0b
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Jan 28, 2022
fb86869
Fixed model saving/loading issues when using other baes models.
vladd-bit Jan 31, 2022
df21543
More fixes to stat evaluation. Added proper CAT integration of RelCAT.
vladd-bit Feb 3, 2022
92a5e08
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Feb 3, 2022
87d1a9c
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Mar 11, 2022
ced1627
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Mar 14, 2022
7b69710
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Mar 28, 2022
37fd212
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Apr 4, 2022
f0eda2b
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Apr 8, 2022
10269b9
Setup.py typo fix.
vladd-bit Apr 8, 2022
b8a45b2
Merge branch 'relation_extraction' of https://github.com/CogStack/Med…
vladd-bit Apr 8, 2022
20203ac
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit May 10, 2022
f057139
RelCAT loading fix.
vladd-bit May 10, 2022
197a27a
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Jul 21, 2022
86fd509
RelCAT Config changes.
vladd-bit Aug 1, 2022
79dc069
Type fix. Minor additions to RelCAT model.
vladd-bit Aug 1, 2022
323c895
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Aug 1, 2022
f1c56bf
Type fixes.
vladd-bit Aug 1, 2022
a78ff86
Type corrections.
vladd-bit Aug 2, 2022
f09ceb2
RelCAT update.
vladd-bit Mar 21, 2023
32574f2
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Mar 21, 2023
c081c3e
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit May 22, 2023
e2e48b5
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Dec 11, 2023
4ce5ba3
Type fixes.
vladd-bit Dec 12, 2023
21c09ff
Merge branch 'relation_extraction' of https://github.com/CogStack/Med…
vladd-bit Dec 13, 2023
8123689
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Dec 13, 2023
57ab0c5
Fixed type issue.
vladd-bit Dec 13, 2023
9da5aa6
RelCATConfig: added seed param.
vladd-bit Dec 13, 2023
009e832
Adaptations to the new codebase + type fixes..
vladd-bit Dec 15, 2023
1a7d130
Doc/type fixes.
vladd-bit Dec 19, 2023
53dba6a
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Dec 20, 2023
92613ed
Fixed input size issue for model.
vladd-bit Jan 8, 2024
a49a44a
Fixed issue(s) with model size and config.
vladd-bit Jan 16, 2024
6456e6e
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Jan 16, 2024
5aac9ab
RelCAT: updated configs to new style.
vladd-bit Jan 19, 2024
9c50b30
RelCAT: removed old refs to logging.
vladd-bit Jan 19, 2024
b071607
Merge branches 'relation_extraction' and 'master' of https://github.c…
vladd-bit Jan 29, 2024
89d9128
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Feb 7, 2024
e6e99cb
Fixed GPU training + added extra stat print for train set.
vladd-bit Feb 7, 2024
307d194
Type fixes.
vladd-bit Feb 7, 2024
fb7efe3
Updated dev requirements.
vladd-bit Feb 7, 2024
c235daf
Linting.
vladd-bit Feb 7, 2024
fcdf2e3
Merge branches 'relation_extraction' and 'master' of https://github.c…
vladd-bit Feb 9, 2024
aad0a73
Fixed pin_memory issue when training on CPU.
vladd-bit Feb 9, 2024
8a9026b
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Mar 8, 2024
f94e349
Updated RelCAT dataset get + default config.
vladd-bit Mar 21, 2024
0770356
Updated RelDS generator + default config
vladd-bit Mar 25, 2024
bdf20f5
Linting.
vladd-bit Mar 25, 2024
f7b5aaf
Updated RelDatset + config.
vladd-bit Apr 3, 2024
3e827cf
Merge branch 'relation_extraction' of https://github.com/CogStack/Med…
vladd-bit Apr 3, 2024
aaf6533
Pushing updates to model
shubham-s-agarwal Apr 8, 2024
18f9bb8
Fixing formatting
shubham-s-agarwal Apr 8, 2024
503513c
Update rel_dataset.py
shubham-s-agarwal Apr 8, 2024
040821b
Update rel_dataset.py
shubham-s-agarwal Apr 8, 2024
ed7c8d5
Update rel_dataset.py
shubham-s-agarwal Apr 8, 2024
8d0bfe4
RelCAT: added test resource files.
vladd-bit Apr 9, 2024
3f3a780
RelCAT: Fixed model load/checkpointing.
vladd-bit Apr 10, 2024
3f56824
RelCAT: updated to pipe spacy doc call.
vladd-bit Apr 12, 2024
b7a4987
RelCAT: added tests.
vladd-bit Apr 12, 2024
77d27b0
Merge branch 'relation_extraction' of https://github.com/CogStack/Med…
vladd-bit Apr 12, 2024
a9258a2
Fixed lint/type issues & added rel tag to test DS.
vladd-bit Apr 15, 2024
0ed70fb
Fixed ann id to token issue.
vladd-bit Apr 15, 2024
8db2e76
RelCAT: updated test dataset + tests.
vladd-bit Apr 18, 2024
6eea6b7
RelCAT: updates to requested changes + dataset improvements.
vladd-bit Apr 18, 2024
6972310
RelCAT: updated docs/logs according to commends.
vladd-bit Apr 18, 2024
d03316c
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Apr 18, 2024
8cb12a4
RelCAT: type fix.
vladd-bit Apr 18, 2024
d10318a
RelCAT: mct export dataset updates.
vladd-bit Apr 19, 2024
12acaeb
RelCAT: test updates + requested changes p2.
vladd-bit Apr 19, 2024
4c14a3a
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Apr 19, 2024
382cefc
RelCAT: log for MCT export train.
vladd-bit Apr 19, 2024
35b0913
Updated docs + split train_test & dataset for benchmarks.
vladd-bit Apr 26, 2024
d48bc41
type fixes.
vladd-bit Apr 26, 2024
3068516
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Apr 26, 2024
72643fc
Merge branch 'master' into relation_extraction
mart-r Apr 29, 2024
10 changes: 9 additions & 1 deletion .gitignore
@@ -18,6 +18,9 @@ venv
db.sqlite3
.ipynb_checkpoints

# vscode
.vscode

#tmp and similar files
.nfs*
*.log
@@ -42,4 +45,9 @@ tmp.py

# models files
*.dat
!examples/*.dat
# others
/medcat/config/*
/medcat/models/*
/tutorial/version_control/docker_ssh_config

!examples/*.dat
44 changes: 44 additions & 0 deletions medcat/config_re.py
@@ -0,0 +1,44 @@
from typing import Dict, Any
from medcat.config import ConfigMixin


class ConfigRE(ConfigMixin):

def __init__(self) -> None:

self.general: Dict[str, Any] = {
'device': 'cpu',
'seed': 13,
# Pairs of entity types between which relations are predicted, e.g. [("Disease", "Symptom"), ("entity1_type", "entity2_type"), ...]
'relation_type_filter_pairs': [],
'category_value2id': {}, # Map from category values to ID, if empty it will be autocalculated during training
'vocab_size': None, # Will be set automatically if the tokenizer is provided during init
'lowercase': True, # If true all input text will be lowercased
'ent_context_left': 2, # Number of entities to take from the left of the concept
'ent_context_right': 2, # Number of entities to take from the right of the concept
'window_size': 300, # Maximum acceptable distance between entities (in characters)
'batch_size_eval': 5000, # Number of annotations processed at once during evaluation
'tokenizer_name': 'BERT', # Tokenizer name, used in the same way as in MetaCAT
'pipe_batch_size_in_chars': 20000000, # How many characters are piped at once through the model
}
self.model: Dict[str, Any] = {
'model_name': 'BERT',
'input_size': 300,
'hidden_size': 300,
'dropout': 0.5,
'nclasses': 2, # Number of classes that this model will output
'padding_idx': -1,
'emb_grad': True, # If True the embeddings will also be trained
}

self.train: Dict[str, Any] = {
'batch_size': 100,
'nepochs': 50,
'lr': 0.001,
'test_size': 0.1,
'shuffle_data': True, # Used only during training, if set the dataset will be shuffled before train/test split
'class_weights': None,
'score_average': 'weighted', # What to use for averaging F1/P/R across labels
'prerequisites': {},
'auto_save_model': True, # Whether the model should be saved during training whenever the best result improves
}
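For reference, the three-section config pattern above can be exercised standalone. The sketch below is illustrative only (ToyConfigRE is a hypothetical stand-in, not MedCAT's ConfigMixin); it assumes nothing beyond each section being a plain dict read by key:

```python
# Illustrative stand-in for the ConfigRE pattern: three plain-dict
# sections (general/model/train) whose defaults can be overridden
# per instance before handing the config to a component.
class ToyConfigRE:  # hypothetical name, not part of MedCAT
    def __init__(self):
        self.general = {'device': 'cpu', 'seed': 13, 'window_size': 300}
        self.model = {'model_name': 'BERT', 'nclasses': 2, 'dropout': 0.5}
        self.train = {'batch_size': 100, 'nepochs': 50, 'lr': 0.001, 'test_size': 0.1}

config = ToyConfigRE()
config.train['lr'] = 1e-4      # override a default
config.model['nclasses'] = 4   # e.g. four relation labels
print(config.train['lr'], config.model['nclasses'])
```

Components then read only the keys they need, so new options can be added without changing call signatures.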
293 changes: 293 additions & 0 deletions medcat/relation_extraction.py
@@ -0,0 +1,293 @@
import json
import logging
import os
from datetime import date, datetime
from typing import Optional

import torch
import torch.nn as nn
import torch.optim
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BertConfig

from medcat.cdb import CDB
from medcat.config_re import ConfigRE
from medcat.utils.meta_cat.ml_utils import split_list_train_test
from medcat.utils.relation_extraction.models import BertModel_RelationExtraction
from medcat.utils.relation_extraction.pad_seq import Pad_Sequence
from medcat.utils.relation_extraction.rel_dataset import RelData
from medcat.utils.relation_extraction.tokenizer import TokenizerWrapperBERT
from medcat.utils.relation_extraction.utils import create_tokenizer_pretrain, load_results, load_state, save_results
from seqeval.metrics import f1_score, precision_score, recall_score

class RelationExtraction(object):
Review comment (Member): To Zeljko's prior comment. This module should be similar to MetaCAT, i.e. subclass PipeRunner, so that pipe and call etc. are available. Ultimately we want an API like:

CAT(cdb, meta_cats=[ MetaCAT .. ], rel_cats=[ RelCAT(...), ... ])

Reply (Member Author): partially done, probably need to add this to the pipe file as well now

name : str = "rel"

def __init__(self, cdb: CDB, config: ConfigRE = ConfigRE(), re_model_path: str = "", tokenizer: Optional[TokenizerWrapperBERT] = None, task="train"):

self.config = config
self.tokenizer = tokenizer
self.cdb = cdb

self.learning_rate = config.train["lr"]
self.batch_size = config.train["batch_size"]
self.n_classes = config.model["nclasses"]

self.is_cuda_available = torch.cuda.is_available()

self.device = torch.device("cuda:0" if self.is_cuda_available else "cpu")
self.hf_model_name = "bert-large-uncased"

self.model_config = BertConfig.from_pretrained(self.hf_model_name)


if self.tokenizer is None:
tokenizer_path = os.path.join(re_model_path, "BERT_tokenizer_relation_extraction")
if os.path.exists(tokenizer_path):
self.tokenizer = TokenizerWrapperBERT.load(tokenizer_path)
else:
self.tokenizer = TokenizerWrapperBERT(AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-large-uncased"),
max_seq_length=self.model_config.max_position_embeddings)
create_tokenizer_pretrain(self.tokenizer)

self.model_config.vocab_size = len(self.tokenizer.hf_tokenizers)

self.model = BertModel_RelationExtraction.from_pretrained(pretrained_model_name_or_path=self.hf_model_name,
model_size=self.hf_model_name,
config=self.model_config,
task=task,
n_classes=self.n_classes)

self.model.resize_token_embeddings(self.model_config.vocab_size)

if self.is_cuda_available:
self.model = self.model.to(self.device)

unfrozen_layers = ["classifier", "pooler", "encoder.layer.11",
"classification_layer", "blanks_linear", "lm_linear", "cls"]

for name, param in self.model.named_parameters():
param.requires_grad = any(layer in name for layer in unfrozen_layers)
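The name-based unfreezing above can be illustrated without torch. A minimal sketch, assuming parameter names follow the usual HuggingFace dotted naming (the names below are hypothetical examples, not taken from the PR):

```python
# Decide trainability purely from substring matches on parameter names,
# mirroring the unfrozen_layers check used in __init__ above.
unfrozen_layers = ["classifier", "pooler", "encoder.layer.11"]

param_names = [
    "bert.embeddings.word_embeddings.weight",
    "bert.encoder.layer.0.attention.self.query.weight",
    "bert.encoder.layer.11.output.dense.weight",
    "classifier.weight",
]

trainable = {name: any(layer in name for layer in unfrozen_layers)
             for name in param_names}
print(trainable)  # only the last two names match
```

Substring matching keeps the config simple, at the cost that e.g. "encoder.layer.11" would also match a hypothetical "encoder.layer.110".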

self.pad_id = self.tokenizer.hf_tokenizers.pad_token_id
self.padding_seq = Pad_Sequence(seq_pad_value=self.pad_id,\
label_pad_value=self.pad_id,\
label2_pad_value=-1)
Review comment (Member): this init method is only indented 3 spaces, can we consistently use 4.

Review comment (Member): another style thing - can we add a space before the line continuation backslash, i.e. check medcat_ner.py l86:88; actually, as this is a list they're not needed at all.

def create_test_train_datasets(self, data):
train_data, test_data = {}, {}
train_data["output_relations"], test_data["output_relations"] = split_list_train_test(data["output_relations"],
test_size=self.config.train["test_size"], shuffle=False)
for k, v in data.items():
if k != "output_relations":
train_data[k] = v
test_data[k] = v

return train_data, test_data
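create_test_train_datasets delegates the actual split to split_list_train_test. A toy reimplementation of that behaviour as used here (shuffle=False, a test fraction carved off the front — the ordering convention is an assumption for illustration, not MedCAT's own code):

```python
import random

def toy_split_train_test(items, test_size, shuffle=True, seed=13):
    # Hypothetical stand-in for medcat's split_list_train_test:
    # take the first len(items)*test_size entries as the test set.
    items = list(items)
    if shuffle:
        random.Random(seed).shuffle(items)
    n_test = int(len(items) * test_size)
    return items[n_test:], items[:n_test]

train_part, test_part = toy_split_train_test(range(100), test_size=0.1, shuffle=False)
print(len(train_part), len(test_part))  # 90 10
```

With test_size=0.1 (the config default) a 100-relation dataset yields a 90/10 split; the non-relation keys are shared between both dicts unchanged.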

def train(self, export_data_path="", csv_path="", docs=None, checkpoint_path="./", num_epoch=1, gradient_acc_steps=1, multistep_lr_gamma=0.8, max_grad_norm=1.0):

train_rel_data = RelData(cdb=self.cdb, config=self.config, tokenizer=self.tokenizer)
test_rel_data = RelData(cdb=CDB(self.cdb.config), config=self.config, tokenizer=None)

if csv_path != "":
train_rel_data.dataset, test_rel_data.dataset = self.create_test_train_datasets(train_rel_data.create_base_relations_from_csv(csv_path))
elif export_data_path != "":
with open(export_data_path) as f:
export_data = json.load(f)

train_rel_data.dataset, test_rel_data.dataset = self.create_test_train_datasets(train_rel_data.create_relations_from_export(export_data))
else:
logging.error("No data has been provided (JSON export / CSV / spacy docs)")
return

train_dataset_size = len(train_rel_data)
batch_size = min(train_dataset_size, self.batch_size)
train_dataloader = DataLoader(train_rel_data, batch_size=batch_size, shuffle=True,
num_workers=0, collate_fn=self.padding_seq, pin_memory=False)

test_dataset_size = len(test_rel_data)
test_batch_size = min(test_dataset_size, self.batch_size)
test_dataloader = DataLoader(test_rel_data, batch_size=test_batch_size, shuffle=True,
num_workers=0, collate_fn=self.padding_seq, pin_memory=False)

criterion = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = torch.optim.Adam([{"params": self.model.parameters(), "lr": self.learning_rate}])

scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2, 4, 6, 8, 12, 15, 18, 20, 22, 24, 26, 30], gamma=multistep_lr_gamma)
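The MultiStepLR schedule multiplies the learning rate by gamma at each milestone epoch, so the rate at any epoch can be computed by hand. An illustrative, torch-free sketch of that rule:

```python
def lr_at_epoch(base_lr, milestones, gamma, epoch):
    # MultiStepLR semantics: one gamma decay per milestone already passed.
    return base_lr * gamma ** sum(1 for m in milestones if epoch >= m)

milestones = [2, 4, 6, 8, 12, 15, 18, 20, 22, 24, 26, 30]
print(lr_at_epoch(0.001, milestones, 0.8, 0))  # 0.001 (no milestone passed)
print(lr_at_epoch(0.001, milestones, 0.8, 5))  # 0.00064 (two decays: epochs 2 and 4)
```

With the config default lr=0.001 and gamma=0.8, the rate has dropped by roughly an order of magnitude after the 12 listed milestones.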

start_epoch, best_pred = load_state(self.model, optimizer, scheduler, load_best=False)

logging.info("Starting training process...")

losses_per_epoch, accuracy_per_epoch, test_f1_per_epoch = load_results(path=checkpoint_path)

for epoch in range(start_epoch, num_epoch):
start_time = datetime.now().time()
self.model.train()
total_loss = 0.0

losses_per_batch = []
total_acc = 0.0
accuracy_per_batch = []

for i, data in enumerate(train_dataloader, 0):
token_ids, e1_e2_start, labels, _, _, _ = data

attention_mask = (token_ids != self.pad_id).float()
token_type_ids = torch.zeros((token_ids.shape[0], token_ids.shape[1])).long()

classification_logits = self.model(input_ids=token_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
e1_e2_start=e1_e2_start)
loss = criterion(classification_logits, labels.squeeze(1))
loss = loss/gradient_acc_steps

loss.backward()

torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

if (i % gradient_acc_steps) == 0:
optimizer.step()
optimizer.zero_grad()

total_loss += loss.item()
total_acc += self.evaluate_(classification_logits, labels, ignore_idx=-1)[0]

if (i % batch_size) == (batch_size - 1):
losses_per_batch.append(gradient_acc_steps*total_loss/batch_size)
accuracy_per_batch.append(total_acc/batch_size)

print('[Epoch: %d, %5d/ %d points] total loss, accuracy per batch: %.3f, %.3f' %
(epoch + 1, (i + 1)*self.batch_size, train_dataset_size, losses_per_batch[-1], accuracy_per_batch[-1]))
total_loss = 0.0
total_acc = 0.0

end_time = datetime.now().time()
scheduler.step()
results = self.evaluate_results(test_dataloader, self.pad_id)
if len(losses_per_batch) > 0:
losses_per_epoch.append(sum(losses_per_batch)/len(losses_per_batch))
print("Losses at Epoch %d: %.7f" % (epoch + 1, losses_per_epoch[-1]))
if len(accuracy_per_batch) > 0:
accuracy_per_epoch.append(sum(accuracy_per_batch)/len(accuracy_per_batch))
print("Train accuracy at Epoch %d: %.7f" % (epoch + 1, accuracy_per_epoch[-1]))

test_f1_per_epoch.append(results['f1'])

print("Epoch finished, took " + str(datetime.combine(date.today(), end_time) - datetime.combine(date.today(), start_time)))
print("Test f1 at Epoch %d: %.7f" % (epoch + 1, test_f1_per_epoch[-1]))

if len(accuracy_per_epoch) > 0 and accuracy_per_epoch[-1] > best_pred:
best_pred = accuracy_per_epoch[-1]
torch.save({
'epoch': epoch + 1,
'state_dict': self.model.state_dict(),
'best_acc': accuracy_per_epoch[-1],
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
}, os.path.join("./data/", "training_model_best_BERT.dat"))

if (epoch % 1) == 0: # checkpoint every epoch
save_results(losses_per_epoch, accuracy_per_epoch, test_f1_per_epoch, file_prefix="train")
torch.save({
'epoch': epoch + 1,
'state_dict': self.model.state_dict(),
'best_acc': 0,
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict()
}, os.path.join("./", "training_checkpoint_BERT.dat"))

def evaluate_(self, output, labels, ignore_idx):
# ignore the padding index when computing accuracy
idxs = (labels != ignore_idx).squeeze()
out_labels = torch.softmax(output, dim=1).max(1)[1]
l = labels.squeeze()[idxs]
o = out_labels[idxs]

if len(idxs) > 1:
acc = (l == o).sum().item()/len(idxs)
else:
acc = (l == o).sum().item()

l = l.cpu().numpy().tolist() if l.is_cuda else l.numpy().tolist()
o = o.cpu().numpy().tolist() if o.is_cuda else o.numpy().tolist()

return acc, (o, l)
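evaluate_ scores accuracy only over positions whose gold label differs from ignore_idx. A torch-free sketch of that masking logic (illustrative only, operating on plain lists rather than tensors):

```python
def masked_accuracy(pred_labels, true_labels, ignore_idx=-1):
    # Keep only positions whose gold label is not the ignore index,
    # then score exact matches among what remains.
    kept = [(p, t) for p, t in zip(pred_labels, true_labels) if t != ignore_idx]
    if not kept:
        return 0.0
    return sum(p == t for p, t in kept) / len(kept)

print(masked_accuracy([1, 0, 1, 1], [1, -1, 1, 0]))  # 2 of 3 scored positions
```

Masking matters here because padded positions carry the label pad value (-1) and would otherwise drag the accuracy down.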

def evaluate_results(self, dataset, pad_id):
logging.info("Evaluating test samples...")
acc = 0
out_labels = []
true_labels = []
self.model.eval()

with torch.no_grad():
for i, data in enumerate(dataset):
token_ids, e1_e2_start, labels, _, _, _ = data
attention_mask = (token_ids != pad_id).float()
token_type_ids = torch.zeros((token_ids.shape[0], token_ids.shape[1])).long()

if self.is_cuda_available:
token_ids = token_ids.cuda()
labels = labels.cuda()
attention_mask = attention_mask.cuda()
token_type_ids = token_type_ids.cuda()

classification_logits = self.model(token_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, Q=None,
e1_e2_start=e1_e2_start)

accuracy, (o, l) = self.evaluate_(classification_logits, labels, ignore_idx=-1)

out_labels.append([str(i) for i in o])
true_labels.append([str(i) for i in l])
acc += accuracy

accuracy = acc/(i + 1)
results = {
"accuracy": accuracy,
"precision": precision_score(true_labels, out_labels),
"recall": recall_score(true_labels, out_labels),
"f1": f1_score(true_labels, out_labels)
}

logging.info("***** Eval results *****")
for key in sorted(results.keys()):
logging.info(" %s = %s", key, str(results[key]))

return results
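evaluate_results delegates precision/recall/F1 to seqeval. For plain class labels the computation reduces to the usual true/false positive counts; a dependency-free sketch for the binary case (illustrative, not seqeval's entity-level logic):

```python
def binary_prf(true, pred, positive=1):
    # Count true positives, false positives and false negatives for one class,
    # then derive precision, recall and their harmonic mean (F1).
    tp = sum(t == positive and p == positive for t, p in zip(true, pred))
    fp = sum(t != positive and p == positive for t, p in zip(true, pred))
    fn = sum(t == positive and p != positive for t, p in zip(true, pred))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

p, r, f = binary_prf([1, 1, 0, 0], [1, 0, 1, 0])
print(p, r, f)  # 0.5 0.5 0.5
```

Note that seqeval itself is designed for BIO-tagged sequence labels; the sketch above only mirrors what the scores mean for flat per-relation labels.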