diff --git a/ark_nlp/factory/task/base/_sequence_classification.py b/ark_nlp/factory/task/base/_sequence_classification.py index c3c3277..0955e59 100644 --- a/ark_nlp/factory/task/base/_sequence_classification.py +++ b/ark_nlp/factory/task/base/_sequence_classification.py @@ -109,6 +109,20 @@ def _on_step_begin_record( **kwargs ): pass + + def _get_train_loss( + self, + inputs, + logits, + verbose=True, + **kwargs + ): + # 计算损失 + loss = self._compute_loss(inputs, logits, **kwargs) + + self._compute_loss_record(inputs, logits, loss, verbose, **kwargs) + + return loss def _compute_loss( self, @@ -118,9 +132,7 @@ def _compute_loss( **kwargs ): loss = self.loss_function(logits, inputs['label_ids']) - - self._compute_loss_record(inputs, logits, loss, verbose, **kwargs) - + return loss def _compute_loss_record( @@ -187,6 +199,9 @@ def _on_optimize( def _on_step_end( self, step, + inputs, + logits, + loss, verbose=True, show_step=100, **kwargs @@ -250,12 +265,24 @@ def _on_evaluate_epoch_begin(self, **kwargs): self.ema.copy_to(self.module.parameters()) self._on_evaluate_epoch_begin_record(**kwargs) + + def _get_evaluate_loss( + self, + inputs, + logits, + verbose=True, + **kwargs + ): + # 计算损失 + loss = self._compute_loss(inputs, logits, **kwargs) + + return loss def _on_evaluate_step_end(self, inputs, logits, **kwargs): with torch.no_grad(): # compute loss - loss = self._compute_loss(inputs, logits, **kwargs) + loss = self._get_evaluate_loss(inputs, logits, **kwargs) labels = inputs['label_ids'].cpu() logits = logits.cpu() @@ -365,7 +392,7 @@ def fit( logits = self.module(**inputs) # 计算损失 - loss = self._compute_loss(inputs, logits, **kwargs) + loss = self._get_train_loss(inputs, logits, **kwargs) # loss backword loss = self._on_backward(inputs, logits, loss, **kwargs) @@ -374,7 +401,7 @@ def fit( step = self._on_optimize(step, **kwargs) # setp evaluate - self._on_step_end(step, **kwargs) + self._on_step_end(step, inputs, logits, loss, **kwargs) self._on_epoch_end(epoch, **kwargs) diff --git a/ark_nlp/factory/task/base/_token_classification.py b/ark_nlp/factory/task/base/_token_classification.py index 3aa9888..ef43939 100644 --- a/ark_nlp/factory/task/base/_token_classification.py +++ b/ark_nlp/factory/task/base/_token_classification.py @@ -54,9 +54,7 @@ def _compute_loss( torch.tensor(self.loss_function.ignore_index).type_as(inputs['label_ids']) ) loss = self.loss_function(active_logits, active_labels) - - self._compute_loss_record(inputs, logits, loss, verbose, **kwargs) - + return loss def _compute_loss_record( @@ -73,6 +71,9 @@ def _compute_loss_record( def _on_step_end( self, step, + inputs, + logits, + loss, verbose=True, print_step=100, **kwargs @@ -112,7 +113,7 @@ def _on_evaluate_step_end(self, inputs, logits, **kwargs): with torch.no_grad(): # compute loss - loss = self._compute_loss(inputs, logits, **kwargs) + loss = self._get_evaluate_loss(inputs, logits, **kwargs) self.evaluate_logs['labels'].append(inputs['label_ids'].cpu()) self.evaluate_logs['logits'].append(logits.cpu()) diff --git a/ark_nlp/factory/task/named_entity_recognition.py b/ark_nlp/factory/task/named_entity_recognition.py index 7ac4723..10d7e61 100644 --- a/ark_nlp/factory/task/named_entity_recognition.py +++ b/ark_nlp/factory/task/named_entity_recognition.py @@ -96,8 +96,6 @@ def _compute_loss( **kwargs ): loss = -1 * self.module.crf(emissions = logits, tags=inputs['label_ids'], mask=inputs['attention_mask']) - - self._compute_loss_record(inputs, logits, loss, verbose, **kwargs) return loss @@ -105,7 +103,7 @@ def 
_on_evaluate_step_end(self, inputs, logits, **kwargs): with torch.no_grad(): # compute loss - loss = self._compute_loss(inputs, logits, **kwargs) + loss = self._get_evaluate_loss(inputs, logits, **kwargs) tags = self.module.crf.decode(logits, inputs['attention_mask']) tags = tags.squeeze(0) @@ -178,9 +176,6 @@ def _compute_loss( span_loss *= span_mask loss = torch.sum(span_loss) / inputs['span_mask'].size()[0] - - if self.logs: - self._compute_loss_record(inputs, logits, loss, verbose, **kwargs) return loss @@ -188,7 +183,7 @@ def _on_evaluate_step_end(self, inputs, logits, **kwargs): with torch.no_grad(): # compute loss - loss = self._compute_loss(inputs, logits, **kwargs) + loss = self._get_evaluate_loss(inputs, logits, **kwargs) logits = torch.nn.functional.softmax(logits, dim=-1) @@ -239,9 +234,6 @@ def _compute_loss( **kwargs ): loss = self.loss_function(logits, inputs['label_ids']) - - if self.logs: - self._compute_loss_record(inputs, logits, loss, verbose, **kwargs) return loss @@ -274,7 +266,7 @@ def _on_evaluate_step_end(self, inputs, logits, **kwargs): with torch.no_grad(): # compute loss - loss = self._compute_loss(inputs, logits, **kwargs) + loss = self._get_evaluate_loss(inputs, logits, **kwargs) numerate, denominator = conlleval.global_pointer_f1_score(inputs['label_ids'].cpu(), logits.cpu()) self.evaluate_logs['numerate'] += numerate @@ -330,9 +322,6 @@ def _compute_loss( loss = start_loss + end_loss - if self.logs: - self._compute_loss_record(inputs, logits, loss, verbose, **kwargs) - return loss def _on_evaluate_epoch_begin(self, **kwargs): @@ -349,7 +338,7 @@ def _on_evaluate_step_end(self, inputs, logits, **kwargs): with torch.no_grad(): # compute loss - loss = self._compute_loss(inputs, logits, **kwargs) + loss = self._get_evaluate_loss(inputs, logits, **kwargs) length = inputs['attention_mask'].cpu().numpy().sum() - 2 diff --git a/ark_nlp/model/re/casrel_bert/casrel_relation_extraction_task.py b/ark_nlp/model/re/casrel_bert/casrel_relation_extraction_task.py index 3897cf9..183c898 100644 --- a/ark_nlp/model/re/casrel_bert/casrel_relation_extraction_task.py +++ b/ark_nlp/model/re/casrel_bert/casrel_relation_extraction_task.py @@ -176,9 +176,6 @@ def _compute_loss( loss = self.loss_function(logits, inputs) - if self.logs: - self._compute_loss_record(inputs, inputs['label_ids'], logits, loss, verbose, **kwargs) - return loss def _compute_loss_record( diff --git a/ark_nlp/model/re/prgc_bert/__init__.py b/ark_nlp/model/re/prgc_bert/__init__.py new file mode 100644 index 0000000..5a06829 --- /dev/null +++ b/ark_nlp/model/re/prgc_bert/__init__.py @@ -0,0 +1,17 @@ +from ark_nlp.model.re.prgc_bert.prgc_relation_extraction_dataset import PRGCREDataset +from ark_nlp.model.re.prgc_bert.prgc_relation_extraction_dataset import PRGCREDataset as Dataset + +from ark_nlp.processor.tokenizer.transfomer import SpanTokenizer as Tokenizer +from ark_nlp.processor.tokenizer.transfomer import SpanTokenizer as PRGCRETokenizer + +from ark_nlp.nn import BertConfig as PRGCBertConfig +from ark_nlp.model.re.prgc_bert.prgc_bert import PRGCBert + +from ark_nlp.factory.optimizer import get_default_bert_optimizer as get_default_model_optimizer +from ark_nlp.factory.optimizer import get_default_bert_optimizer as get_default_prgc_bert_optimizer + +from ark_nlp.model.re.prgc_bert.prgc_relation_extraction_task import PRGCRETask as Task +from ark_nlp.model.re.prgc_bert.prgc_relation_extraction_task import PRGCRETask as PRGCRETask + +from ark_nlp.model.re.prgc_bert.prgc_relation_extraction_predictor import 
PRGCREPredictor as Predictor +from ark_nlp.model.re.prgc_bert.prgc_relation_extraction_predictor import PRGCREPredictor as PRGCREPredictor \ No newline at end of file diff --git a/ark_nlp/model/re/prgc_bert/prgc_bert.py b/ark_nlp/model/re/prgc_bert/prgc_bert.py new file mode 100644 index 0000000..dae0d9d --- /dev/null +++ b/ark_nlp/model/re/prgc_bert/prgc_bert.py @@ -0,0 +1,211 @@ +import time +import torch +import math +import torch.nn.functional as F + +from torch import nn +from transformers import BertModel +from transformers import BertPreTrainedModel +from collections import Counter + +import torch +import torch.nn as nn + +from transformers import BertPreTrainedModel, BertModel + + +class MultiNonLinearClassifier(nn.Module): + def __init__(self, hidden_size, tag_size, dropout_rate): + super(MultiNonLinearClassifier, self).__init__() + self.tag_size = tag_size + self.linear = nn.Linear(hidden_size, int(hidden_size / 2)) + self.hidden2tag = nn.Linear(int(hidden_size / 2), self.tag_size) + self.dropout = nn.Dropout(dropout_rate) + + def forward(self, input_features): + features_tmp = self.linear(input_features) + features_tmp = nn.ReLU()(features_tmp) + features_tmp = self.dropout(features_tmp) + features_output = self.hidden2tag(features_tmp) + return features_output + + +class SequenceLabelForSO(nn.Module): + def __init__(self, hidden_size, tag_size, dropout_rate): + super(SequenceLabelForSO, self).__init__() + self.tag_size = tag_size + self.linear = nn.Linear(hidden_size, int(hidden_size / 2)) + self.hidden2tag_sub = nn.Linear(int(hidden_size / 2), self.tag_size) + self.hidden2tag_obj = nn.Linear(int(hidden_size / 2), self.tag_size) + self.dropout = nn.Dropout(dropout_rate) + + def forward(self, input_features): + """ + Args: + input_features: (bs, seq_len, h) + """ + features_tmp = self.linear(input_features) + features_tmp = nn.ReLU()(features_tmp) + features_tmp = self.dropout(features_tmp) + sub_output = self.hidden2tag_sub(features_tmp) + obj_output = self.hidden2tag_obj(features_tmp) + return sub_output, obj_output + + +class PRGCBert(BertPreTrainedModel): + def __init__( + self, + config, + seq_tag_size=3, + drop_prob=0.3, + emb_fusion='concat', + corres_mode=None, + biaffine_hidden_size=128, + ): + super().__init__(config) + self.seq_tag_size = seq_tag_size + self.rel_num = config.num_labels + self.emb_fusion = emb_fusion + + # pretrain model + self.bert = BertModel(config) + # sequence tagging + self.sequence_tagging_sub = MultiNonLinearClassifier(config.hidden_size * 2, self.seq_tag_size, drop_prob) + self.sequence_tagging_obj = MultiNonLinearClassifier(config.hidden_size * 2, self.seq_tag_size, drop_prob) + self.sequence_tagging_sum = SequenceLabelForSO(config.hidden_size, self.seq_tag_size, drop_prob) + + # relation judgement + self.rel_judgement = MultiNonLinearClassifier(config.hidden_size, self.rel_num, drop_prob) + self.rel_embedding = nn.Embedding(self.rel_num, config.hidden_size) + + self.corres_mode = corres_mode + if self.corres_mode == 'biaffine': + self.U = torch.nn.Parameter(torch.randn(biaffine_hidden_size, 1, biaffine_hidden_size)) + self.start_encoder = torch.nn.Sequential(torch.nn.Linear(in_features=config.hidden_size, + out_features=biaffine_hidden_size), + torch.nn.ReLU()) + self.end_encoder = torch.nn.Sequential(torch.nn.Linear(in_features=config.hidden_size, + out_features=biaffine_hidden_size), + torch.nn.ReLU()) + else: + # global correspondence + self.global_corres = MultiNonLinearClassifier(config.hidden_size * 2, 1, drop_prob) + + 
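+        # Note (added comment): with corres_mode == 'biaffine' each (subject-start, object-start) token
+        # pair (i, j) is scored with a biaffine form h_i^T U h_j over the projected token vectors; in the
+        # default branch above the two token vectors are concatenated and scored by an MLP. Both variants
+        # produce the global correspondence matrix used to filter candidate triples.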
self.init_weights() + + @staticmethod + def masked_avgpool(sent, mask): + mask_ = mask.masked_fill(mask == 0, -1e9).float() + score = torch.softmax(mask_, -1) + return torch.matmul(score.unsqueeze(1), sent).squeeze(1) + + def forward( + self, + input_ids=None, + attention_mask=None, + seq_tags=None, + potential_rels=None, + rel_threshold=0.1, + **kwargs + + ): + """ + Args: + input_ids: (batch_size, seq_len) + attention_mask: (batch_size, seq_len) + rel_tags: (bs, rel_num) + potential_rels: (bs,), only in train stage. + seq_tags: (bs, 2, seq_len) + corres_tags: (bs, seq_len, seq_len) + ex_params: experiment parameters + """ + + # pre-train model + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + output_hidden_states=True + ) # sequence_output, pooled_output, (hidden_states), (attentions) + + sequence_output = outputs[0] + bs, seq_len, h = sequence_output.size() + + # (bs, h) + h_k_avg = self.masked_avgpool(sequence_output, attention_mask) + # (bs, rel_num) + rel_pred = self.rel_judgement(h_k_avg) + + if self.corres_mode == 'biaffine': + sub_extend = self.start_encoder(sequence_output) + obj_extend = self.end_encoder(sequence_output) + + corres_pred = torch.einsum('bxi,ioj,byj->bxyo', sub_extend, self.U, obj_extend).squeeze(-1) + else: + sub_extend = sequence_output.unsqueeze(2).expand(-1, -1, seq_len, -1) # (bs, s, s, h) + obj_extend = sequence_output.unsqueeze(1).expand(-1, seq_len, -1, -1) # (bs, s, s, h) + # batch x seq_len x seq_len x 2*hidden + corres_pred = torch.cat([sub_extend, obj_extend], 3) + # (bs, seq_len, seq_len) + corres_pred = self.global_corres(corres_pred).squeeze(-1) + + # relation predict and data construction in inference stage + xi, pred_rels = None, None + if seq_tags is None: + # (bs, rel_num) + rel_pred_onehot = torch.where(torch.sigmoid(rel_pred) > rel_threshold, + torch.ones(rel_pred.size(), device=rel_pred.device), + torch.zeros(rel_pred.size(), device=rel_pred.device)) + + # if potential relation is null + for idx, sample in enumerate(rel_pred_onehot): + if 1 not in sample: + # (rel_num,) + max_index = torch.argmax(rel_pred[idx]) + sample[max_index] = 1 + rel_pred_onehot[idx] = sample + + # 2*(sum(x_i),) + bs_idxs, pred_rels = torch.nonzero(rel_pred_onehot, as_tuple=True) + # get x_i + xi_dict = Counter(bs_idxs.tolist()) + xi = [xi_dict[idx] for idx in range(bs)] + + pos_seq_output = [] + pos_potential_rel = [] + pos_attention_mask = [] + for bs_idx, rel_idx in zip(bs_idxs, pred_rels): + # (seq_len, h) + pos_seq_output.append(sequence_output[bs_idx]) + pos_attention_mask.append(attention_mask[bs_idx]) + pos_potential_rel.append(rel_idx) + # (sum(x_i), seq_len, h) + sequence_output = torch.stack(pos_seq_output, dim=0) + # (sum(x_i), seq_len) + attention_mask = torch.stack(pos_attention_mask, dim=0) + # (sum(x_i),) + potential_rels = torch.stack(pos_potential_rel, dim=0) + + # (bs/sum(x_i), h) + rel_emb = self.rel_embedding(potential_rels) + + # relation embedding vector fusion + rel_emb = rel_emb.unsqueeze(1).expand(-1, seq_len, h) + + if self.emb_fusion == 'concat': + # (bs/sum(x_i), seq_len, 2*h) + decode_input = torch.cat([sequence_output, rel_emb], dim=-1) + # (bs/sum(x_i), seq_len, tag_size) + output_sub = self.sequence_tagging_sub(decode_input) + output_obj = self.sequence_tagging_obj(decode_input) + + elif self.emb_fusion == 'sum': + # (bs/sum(x_i), seq_len, h) + decode_input = sequence_output + rel_emb + # (bs/sum(x_i), seq_len, tag_size) + output_sub, output_obj = self.sequence_tagging_sum(decode_input) + + if xi is None: + return 
output_sub, output_obj, corres_pred, rel_pred + else: + + return output_sub, output_obj, corres_pred, pred_rels, xi \ No newline at end of file diff --git a/ark_nlp/model/re/prgc_bert/prgc_relation_extraction_dataset.py b/ark_nlp/model/re/prgc_bert/prgc_relation_extraction_dataset.py new file mode 100644 index 0000000..284f90d --- /dev/null +++ b/ark_nlp/model/re/prgc_bert/prgc_relation_extraction_dataset.py @@ -0,0 +1,168 @@ +""" +# Copyright Xiang Wang, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 + +Author: Xiang Wang, xiangking1995@163.com +Status: Active +""" + +import copy +import json +import torch +import random +import codecs +import numpy as np +import pandas as pd + +from collections import defaultdict +from functools import lru_cache +from torch.utils.data import Dataset +from ark_nlp.dataset.base._dataset import BaseDataset + + +class PRGCREDataset(BaseDataset): + + def __init__(self, *args, **kwargs): + super(PRGCREDataset, self).__init__(*args, **kwargs) + self.sublabel2id = {"B-H": 1, "I-H": 2, "O": 0} + self.oblabel2id = {"B-T": 1, "I-T": 2, "O": 0} + + def _get_categories(self): + return sorted(list(set([triple_[3] for data_ in self.dataset for triple_ in data_['label']]))) + + def _convert_to_dataset(self, data_df): + + dataset = [] + + data_df['text'] = data_df['text'].apply(lambda x: x.strip()) + if not self.is_test: + data_df['label'] = data_df['label'].apply(lambda x: eval(x)) + + feature_names = list(data_df.columns) + for index_, row_ in enumerate(data_df.itertuples()): + + dataset.append({feature_name_: getattr(row_, feature_name_) + for feature_name_ in feature_names}) + return dataset + + def _convert_to_transfomer_ids(self, tokenizer): + """ + 将文本转化成id的形式 + + :param tokenizer: + + ToDo: 将__getitem__部分ID化代码迁移到这部分 + + """ + self.tokenizer = tokenizer + + if self.is_retain_dataset: + self.retain_dataset = copy.deepcopy(self.dataset) + + features = [] + for (index_, row_) in enumerate(self.dataset): + + text = row_['text'] + + if len(text) > self.tokenizer.max_seq_len - 2: + text = text[:self.tokenizer.max_seq_len - 2] + + tokens = self.tokenizer.tokenize(text) + token_mapping = self.tokenizer.get_token_mapping(text, tokens, is_mapping_index=False) + index_token_mapping = self.tokenizer.get_token_mapping(text, tokens) + + start_mapping = {j[0]: i for i, j in enumerate(index_token_mapping) if j} + end_mapping = {j[-1]: i for i, j in enumerate(index_token_mapping) if j} + + input_ids, input_mask, segment_ids = self.tokenizer.sequence_to_ids(tokens) + + if not self.is_train: + triples = [] + + for triple in row_['label']: + sub_head_idx = triple[1] + sub_end_idx = triple[2] + obj_head_idx = triple[5] + obj_end_idx = triple[6] + + if sub_head_idx in start_mapping and obj_head_idx in start_mapping and sub_end_idx in end_mapping and obj_end_idx in end_mapping: + sub_head_idx = start_mapping[sub_head_idx] + obj_head_idx = start_mapping[obj_head_idx] + + triples.append((('H', sub_head_idx + 1, end_mapping[sub_end_idx] + 1 + 1), + ('T', obj_head_idx + 1, end_mapping[obj_end_idx] + 1 + 1), + self.cat2id[triple[3]])) + + feature = { + 'input_ids': input_ids, + 'attention_mask': input_mask, + 'triples': triples, + 'token_mapping': token_mapping + } + + features.append(feature) + + else: + corres_tag = np.zeros((self.tokenizer.max_seq_len, self.tokenizer.max_seq_len)) + 
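+                # corres_tag is the gold global correspondence matrix: cell (i + 1, j + 1) is set to 1 when
+                # token i starts a subject span and token j starts an object span of some gold triple
+                # (the + 1 offset accounts for the [CLS] token added by the tokenizer). rel_tag (built below)
+                # is a multi-hot vector marking every relation type that occurs in this sample.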
rel_tag = len(self.cat2id) * [0] + rel_entities = defaultdict(set) + + for triple in row_['label']: + sub_head_idx = triple[1] + sub_end_idx = triple[2] + obj_head_idx = triple[5] + obj_end_idx = triple[6] + + # construct relation tag + rel_tag[self.cat2id[triple[3]]] = 1 + + if sub_head_idx in start_mapping and obj_head_idx in start_mapping and sub_end_idx in end_mapping and obj_end_idx in end_mapping: + sub_head_idx = start_mapping[sub_head_idx] + obj_head_idx = start_mapping[obj_head_idx] + + corres_tag[sub_head_idx+1][obj_head_idx+1] = 1 + rel_entities[self.cat2id[triple[3]]].add((sub_head_idx, end_mapping[sub_end_idx], obj_head_idx, end_mapping[obj_end_idx])) + + sub_feats = [] + + for rel, en_ll in rel_entities.items(): + # init + tags_sub = self.tokenizer.max_seq_len * [self.sublabel2id['O']] + tags_obj = self.tokenizer.max_seq_len * [self.oblabel2id['O']] + + for en in en_ll: + # get sub and obj head + sub_head_idx, sub_end_idx, obj_head_idx, obj_end_idx = en + + tags_sub[sub_head_idx + 1] = self.sublabel2id['B-H'] + tags_sub[sub_head_idx + 1 + 1 : sub_end_idx + 1 + 1] = (sub_end_idx - sub_head_idx) * [self.sublabel2id['I-H']] + + tags_obj[obj_head_idx + 1] = self.oblabel2id['B-T'] + tags_obj[obj_head_idx + 1 + 1 : obj_end_idx + 1 + 1] = (obj_end_idx - obj_head_idx) * [self.oblabel2id['I-T']] + + seq_tag = [tags_sub, tags_obj] + + feature = { + 'input_ids': input_ids, + 'attention_mask': input_mask, + 'corres_tags': corres_tag, + 'seq_tags': seq_tag, + 'potential_rels': rel, + 'rel_tags': rel_tag, + 'token_mapping': token_mapping + } + + features.append(feature) + + return features + + @property + def to_device_cols(self): + if self.is_train: + return ['input_ids', 'attention_mask', 'corres_tags', 'seq_tags', 'potential_rels', 'rel_tags'] + else: + return ['input_ids', 'attention_mask'] \ No newline at end of file diff --git a/ark_nlp/model/re/prgc_bert/prgc_relation_extraction_predictor.py b/ark_nlp/model/re/prgc_bert/prgc_relation_extraction_predictor.py new file mode 100644 index 0000000..7eb0e0a --- /dev/null +++ b/ark_nlp/model/re/prgc_bert/prgc_relation_extraction_predictor.py @@ -0,0 +1,246 @@ +""" +# Copyright 2021 Xiang Wang, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 + +Author: Xiang Wang, xiangking1995@163.com +Status: Active +""" + +import tqdm +import torch +import numpy as np +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import sklearn.metrics as sklearn_metrics + +from tqdm import tqdm +from torch.autograd import grad +from torch.autograd import Variable +from torch.optim import lr_scheduler +from torch.utils.data import Dataset +from torch.utils.data import DataLoader + + +def get_chunk_type(tok, idx_to_tag): + """ + Args: + tok: id of token, ex 4 + idx_to_tag: dictionary {4: "B-PER", ...} + Returns: + tuple: "B", "PER" + """ + tag_name = idx_to_tag[tok] + content = tag_name.split('-') + tag_class = content[0] + if len(content) == 1: + return tag_class + ht = content[-1] + return tag_class, ht + + +def get_chunks(seq, tags): + """Given a sequence of tags, group entities and their position + Args: + seq: np.array[4, 4, 0, 0, ...] 
sequence of labels + tags: dict["O"] = 4 + Returns: + list of (chunk_type, chunk_start, chunk_end) + Example: + seq = [4, 5, 0, 3] + tags = {"B-PER": 4, "I-PER": 5, "B-LOC": 3} + result = [("PER", 0, 2), ("LOC", 3, 4)] + """ + default1 = tags['O'] + idx_to_tag = {idx: tag for tag, idx in tags.items()} + chunks = [] + chunk_type, chunk_start = None, None + for i, tok in enumerate(seq): + # End of a chunk 1 + if tok == default1 and chunk_type is not None: + # Add a chunk. + chunk = (chunk_type, chunk_start, i) + chunks.append(chunk) + chunk_type, chunk_start = None, None + + # End of a chunk + start of a chunk! + elif tok != default1: + res = get_chunk_type(tok, idx_to_tag) + if len(res) == 1: + continue + tok_chunk_class, ht = get_chunk_type(tok, idx_to_tag) + tok_chunk_type = ht + if chunk_type is None: + chunk_type, chunk_start = tok_chunk_type, i + elif tok_chunk_type != chunk_type or tok_chunk_class == "B": + chunk = (chunk_type, chunk_start, i) + chunks.append(chunk) + chunk_type, chunk_start = tok_chunk_type, i + else: + pass + + # end condition + if chunk_type is not None: + chunk = (chunk_type, chunk_start, len(seq)) + chunks.append(chunk) + + return chunks + + +def tag_mapping_corres(predict_tags, pre_corres, pre_rels=None, label2idx_sub=None, label2idx_obj=None): + """ + Args: + predict_tags: np.array, (xi, 2, max_sen_len) + pre_corres: (seq_len, seq_len) + pre_rels: (xi,) + """ + rel_num = predict_tags.shape[0] + pre_triples = [] + for idx in range(rel_num): + heads, tails = [], [] + pred_chunks_sub = get_chunks(predict_tags[idx][0], label2idx_sub) + pred_chunks_obj = get_chunks(predict_tags[idx][1], label2idx_obj) + pred_chunks = pred_chunks_sub + pred_chunks_obj + for ch in pred_chunks: + if ch[0] == 'H': + heads.append(ch) + elif ch[0] == 'T': + tails.append(ch) + retain_hts = [(h, t) for h in heads for t in tails if pre_corres[h[1]][t[1]] == 1] + for h_t in retain_hts: + if pre_rels is not None: + triple = list(h_t) + [pre_rels[idx]] + else: + triple = list(h_t) + [idx] + pre_triples.append(tuple(triple)) + return pre_triples + + +class PRGCREPredictor(object): + def __init__( + self, + module, + tokernizer, + cat2id, + corres_threshold=0.5 + ): + self.module = module + self.module.task = 'TokenLevel' + + self.corres_threshold = corres_threshold + + self.cat2id = cat2id + self.tokenizer = tokernizer + self.device = list(self.module.parameters())[0].device + + self.id2cat = {} + for cat_, idx_ in self.cat2id.items(): + self.id2cat[idx_] = cat_ + + self.sublabel2id = {"B-H": 1, "I-H": 2, "O": 0} + self.oblabel2id = {"B-T": 1, "I-T": 2, "O": 0} + + def _convert_to_transfomer_ids( + self, + text + ): + if len(text) > self.tokenizer.max_seq_len - 2: + text = text[:self.tokenizer.max_seq_len - 2] + + tokens = self.tokenizer.tokenize(text) + token_mapping = self.tokenizer.get_token_mapping(text, tokens, is_mapping_index=False) + index_token_mapping = self.tokenizer.get_token_mapping(text, tokens) + + start_mapping = {j[0]: i for i, j in enumerate(index_token_mapping) if j} + end_mapping = {j[-1]: i for i, j in enumerate(index_token_mapping) if j} + + input_ids, input_mask, segment_ids = self.tokenizer.sequence_to_ids(tokens) + + features = { + 'input_ids': input_ids, + 'attention_mask': input_mask, + 'token_mapping': token_mapping + } + + return features + + def _get_input_ids( + self, + text + ): + if self.tokenizer.tokenizer_type == 'transfomer': + return self._convert_to_transfomer_ids(text) + else: + raise ValueError("The tokenizer type does not exist") + + def 
_get_module_one_sample_inputs( + self, + features + ): + inputs = {} + for col in features: + if isinstance(features[col], np.ndarray): + inputs[col] = torch.Tensor(features[col]).type(torch.long).unsqueeze(0).to(self.device) + else: + inputs[col] = features[col] + + return inputs + + def predict_one_sample( + self, + text='', + + ): + features = self._get_input_ids(text) + self.module.eval() + + with torch.no_grad(): + + inputs = self._get_module_one_sample_inputs(features) + + logits = self.module(**inputs) + + token_mapping = inputs['token_mapping'] + + output_sub, output_obj, corres_pred, pred_rels, xi = logits + + pred_seq_sub = torch.argmax(torch.softmax(output_sub, dim=-1), dim=-1) + pred_seq_obj = torch.argmax(torch.softmax(output_obj, dim=-1), dim=-1) + pred_seqs = torch.cat([pred_seq_sub.unsqueeze(1), pred_seq_obj.unsqueeze(1)], dim=1) + + mask_tmp1 = inputs['attention_mask'].unsqueeze(-1) + mask_tmp2 = inputs['attention_mask'].unsqueeze(1) + corres_mask = mask_tmp1 * mask_tmp2 + + corres_pred = torch.sigmoid(corres_pred) * corres_mask + pre_corres = torch.where(corres_pred > self.corres_threshold, + torch.ones(corres_pred.size(), device=corres_pred.device), + torch.zeros(corres_pred.size(), device=corres_pred.device)) + + pred_seqs = pred_seqs.detach().cpu().numpy() + pre_corres = pre_corres.detach().cpu().numpy() + + xi = np.array(xi) + pred_rels = pred_rels.detach().cpu().numpy() + xi_index = np.cumsum(xi).tolist() + xi_index.insert(0, 0) + + pre_triples = tag_mapping_corres(predict_tags=pred_seqs[xi_index[0]:xi_index[1]], + pre_corres=pre_corres[0], + pre_rels=pred_rels[xi_index[0]:xi_index[1]], + label2idx_sub=self.sublabel2id, + label2idx_obj=self.oblabel2id) + + triple_set = set() + for _pre_triple in pre_triples: + sub = ''.join([token_mapping[index_] for index_ in range(_pre_triple[0][1]-1, _pre_triple[0][2]-1)]) + obj = ''.join([token_mapping[index_] for index_ in range(_pre_triple[1][1]-1, _pre_triple[1][2]-1)]) + rel = self.id2cat[_pre_triple[2]] + + triple_set.add((sub, rel, obj)) + + return list(triple_set) \ No newline at end of file diff --git a/ark_nlp/model/re/prgc_bert/prgc_relation_extraction_task.py b/ark_nlp/model/re/prgc_bert/prgc_relation_extraction_task.py new file mode 100644 index 0000000..15355a0 --- /dev/null +++ b/ark_nlp/model/re/prgc_bert/prgc_relation_extraction_task.py @@ -0,0 +1,439 @@ +""" +# Copyright 2020 Xiang Wang, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+
+Author: Xiang Wang, xiangking1995@163.com
+Status: Active
+"""
+
+import time
+import tqdm
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import sklearn.metrics as sklearn_metrics
+
+from tqdm import tqdm
+from torch.optim import lr_scheduler
+from torch.autograd import Variable
+from torch.autograd import grad
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+
+from ark_nlp.factory.loss_function import get_loss
+from ark_nlp.factory.optimizer import get_optimizer
+from ark_nlp.factory.task.base._task import Task
+from ark_nlp.factory.task.base._sequence_classification import SequenceClassificationTask
+
+
+def get_chunk_type(tok, idx_to_tag):
+    """
+    Args:
+        tok: id of token, ex 4
+        idx_to_tag: dictionary {4: "B-PER", ...}
+    Returns:
+        tuple: "B", "PER"
+    """
+    tag_name = idx_to_tag[tok]
+    content = tag_name.split('-')
+    tag_class = content[0]
+    if len(content) == 1:
+        return tag_class
+    ht = content[-1]
+    return tag_class, ht
+
+
+def get_chunks(seq, tags):
+    """Given a sequence of tags, group entities and their position
+    Args:
+        seq: np.array[4, 4, 0, 0, ...] sequence of labels
+        tags: dict["O"] = 4
+    Returns:
+        list of (chunk_type, chunk_start, chunk_end)
+    Example:
+        seq = [4, 5, 0, 3]
+        tags = {"B-PER": 4, "I-PER": 5, "B-LOC": 3}
+        result = [("PER", 0, 2), ("LOC", 3, 4)]
+    """
+    default1 = tags['O']
+    idx_to_tag = {idx: tag for tag, idx in tags.items()}
+    chunks = []
+    chunk_type, chunk_start = None, None
+    for i, tok in enumerate(seq):
+        # End of a chunk 1
+        if tok == default1 and chunk_type is not None:
+            # Add a chunk.
+            chunk = (chunk_type, chunk_start, i)
+            chunks.append(chunk)
+            chunk_type, chunk_start = None, None
+
+        # End of a chunk + start of a chunk!
+ elif tok != default1: + res = get_chunk_type(tok, idx_to_tag) + if len(res) == 1: + continue + tok_chunk_class, ht = get_chunk_type(tok, idx_to_tag) + tok_chunk_type = ht + if chunk_type is None: + chunk_type, chunk_start = tok_chunk_type, i + elif tok_chunk_type != chunk_type or tok_chunk_class == "B": + chunk = (chunk_type, chunk_start, i) + chunks.append(chunk) + chunk_type, chunk_start = tok_chunk_type, i + else: + pass + + # end condition + if chunk_type is not None: + chunk = (chunk_type, chunk_start, len(seq)) + chunks.append(chunk) + + return chunks + + +def tag_mapping_corres(predict_tags, pre_corres, pre_rels=None, label2idx_sub=None, label2idx_obj=None): + """ + Args: + predict_tags: np.array, (xi, 2, max_sen_len) + pre_corres: (seq_len, seq_len) + pre_rels: (xi,) + """ + rel_num = predict_tags.shape[0] + pre_triples = [] + for idx in range(rel_num): + heads, tails = [], [] + pred_chunks_sub = get_chunks(predict_tags[idx][0], label2idx_sub) + pred_chunks_obj = get_chunks(predict_tags[idx][1], label2idx_obj) + pred_chunks = pred_chunks_sub + pred_chunks_obj + for ch in pred_chunks: + if ch[0] == 'H': + heads.append(ch) + elif ch[0] == 'T': + tails.append(ch) + retain_hts = [(h, t) for h in heads for t in tails if pre_corres[h[1]][t[1]] == 1] + for h_t in retain_hts: + if pre_rels is not None: + triple = list(h_t) + [pre_rels[idx]] + else: + triple = list(h_t) + [idx] + pre_triples.append(tuple(triple)) + return pre_triples + + +def get_metrics(correct_num, predict_num, gold_num): + p = correct_num / predict_num if predict_num > 0 else 0 + r = correct_num / gold_num if gold_num > 0 else 0 + f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0 + return { + 'correct_num': correct_num, + 'predict_num': predict_num, + 'gold_num': gold_num, + 'precision': p, + 'recall': r, + 'f1': f1 + } + + +class PRGCRETask(SequenceClassificationTask): + + def __init__(self, *args, **kwargs): + + super(PRGCRETask, self).__init__(*args, **kwargs) + if hasattr(self.module, 'task') is False: + self.module.task = 'TokenLevel' + + def _collate_fn_train(self, features): + """将InputFeatures转换为Tensor""" + + input_ids = torch.tensor([f['input_ids'] for f in features], dtype=torch.long) + attention_mask = torch.tensor([f['attention_mask'] for f in features], dtype=torch.long) + seq_tags = torch.tensor([f['seq_tags'] for f in features], dtype=torch.long) + poten_relations = torch.tensor([f['potential_rels'] for f in features], dtype=torch.long) + corres_tags = torch.tensor([f['corres_tags'] for f in features], dtype=torch.long) + rel_tags = torch.tensor([f['rel_tags'] for f in features], dtype=torch.long) + token_mapping = [f['token_mapping'] for f in features] + + tensors = { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'seq_tags': seq_tags, + 'potential_rels': poten_relations, + 'corres_tags': corres_tags, + 'rel_tags': rel_tags, + 'token_mapping': token_mapping + } + + return tensors + + def _collate_fn_evaluate(self, features): + """将InputFeatures转换为Tensor""" + + input_ids = torch.tensor([f['input_ids'] for f in features], dtype=torch.long) + attention_mask = torch.tensor([f['attention_mask'] for f in features], dtype=torch.long) + triples = [f['triples'] for f in features] + token_mapping = [f['token_mapping'] for f in features] + + tensors = { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'triples': triples, + 'token_mapping': token_mapping + } + + return tensors + + def _on_train_begin( + self, + train_data, + validation_data, + batch_size, + lr, + params, + shuffle, + 
train_to_device_cols=None, + **kwargs + ): + + if self.class_num == None: + self.class_num = train_data.class_num + + if train_to_device_cols == None: + self.train_to_device_cols = train_data.to_device_cols + else: + self.train_to_device_cols = train_to_device_cols + + train_generator = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=self._collate_fn_train) + self.train_generator_lenth = len(train_generator) + + self.optimizer = get_optimizer(self.optimizer, self.module, lr, params) + self.optimizer.zero_grad() + + self.module.train() + + self._on_train_begin_record(**kwargs) + + return train_generator + + def _compute_loss( + self, + inputs, + logits, + verbose=True, + **kwargs + ): + batch_size, _ = inputs['input_ids'].size() + + output_sub, output_obj, corres_pred, rel_pred = logits + + mask_tmp1 = inputs['attention_mask'].unsqueeze(-1) + mask_tmp2 = inputs['attention_mask'].unsqueeze(1) + corres_mask = mask_tmp1 * mask_tmp2 + + attention_mask = inputs['attention_mask'].view(-1) + # sequence label loss + loss_func = nn.CrossEntropyLoss(reduction='none') + loss_seq_sub = (loss_func(output_sub.view(-1, self.module.seq_tag_size), + inputs['seq_tags'][:, 0, :].reshape(-1)) * attention_mask).sum() / attention_mask.sum() + loss_seq_obj = (loss_func(output_obj.view(-1, self.module.seq_tag_size), + inputs['seq_tags'][:, 1, :].reshape(-1)) * attention_mask).sum() / attention_mask.sum() + loss_seq = (loss_seq_sub + loss_seq_obj) / 2 + # init + loss_matrix, loss_rel = torch.tensor(0), torch.tensor(0) + + corres_pred = corres_pred.view(batch_size, -1) + corres_mask = corres_mask.view(batch_size, -1) + corres_tags = inputs['corres_tags'].view(batch_size, -1) + + loss_func = nn.BCEWithLogitsLoss(reduction='none') + + loss_matrix = (loss_func(corres_pred, + corres_tags.float()) * corres_mask).sum() / corres_mask.sum() + + loss_func = nn.BCEWithLogitsLoss(reduction='mean') + loss_rel = loss_func(rel_pred, inputs['rel_tags'].float()) + + loss = loss_seq + loss_matrix + loss_rel + + return loss + + def _compute_loss_record( + self, + inputs, + logits, + loss, + verbose, + **kwargs + ): + + self.logs['epoch_loss'] += loss.item() + self.logs['epoch_step'] += 1 + self.logs['global_step'] += 1 + + def fit( + self, + train_data=None, + validation_data=None, + lr=False, + params=None, + batch_size=32, + epochs=1, + **kwargs + ): + self.logs = dict() + + self.id2cat = train_data.id2cat + + train_generator = self._on_train_begin(train_data, validation_data, batch_size, lr, params, shuffle=True, **kwargs) + + for epoch in range(epochs): + + self._on_epoch_begin(**kwargs) + + for step, inputs in enumerate(tqdm(train_generator)): + + self._on_step_begin(epoch, step, inputs, **kwargs) + + inputs = self._get_module_inputs_on_train(inputs, **kwargs) + + # forward + logits = self.module(**inputs) + + # 计算损失 + loss = self._get_train_loss(inputs, logits, **kwargs) + + # loss backword + loss = self._on_backward(inputs, logits, loss, **kwargs) + + # optimize + step = self._on_optimize(step, **kwargs) + + # setp evaluate + self._on_step_end(step, inputs, logits, loss, **kwargs) + + self._on_epoch_end(epoch, **kwargs) + + if validation_data is not None: + self.evaluate(validation_data, **kwargs) + + self._on_train_end(**kwargs) + + def _on_evaluate_begin( + self, + validation_data, + batch_size, + shuffle, + evaluate_to_device_cols=None, + **kwargs + ): + + self.evaluate_id2sublabel = validation_data.sublabel2id + self.evaluate_id2oblabel = validation_data.oblabel2id + + if evaluate_to_device_cols == 
None: + self.evaluate_to_device_cols = validation_data.to_device_cols + else: + self.evaluate_to_device_cols = evaluate_to_device_cols + + generator = DataLoader(validation_data, batch_size=batch_size, shuffle=False, collate_fn=self._collate_fn_evaluate) + + self.module.eval() + + self._on_evaluate_begin_record(**kwargs) + + return generator + + def _on_evaluate_begin_record(self, **kwargs): + + self.evaluate_logs['correct_num'] = 0 + self.evaluate_logs['predict_num'] = 0 + self.evaluate_logs['gold_num'] = 0 + self.evaluate_logs['eval_step'] = 0 + self.evaluate_logs['eval_loss'] = 0 + self.evaluate_logs['eval_example'] = 0 + + + def _on_evaluate_step_end(self, inputs, logits, corres_threshold=0.5, **kwargs): + + batch_size, _ = inputs['input_ids'].size() + token_mappings = inputs['token_mapping'] + + output_sub, output_obj, corres_pred, pred_rels, xi = logits + + pred_seq_sub = torch.argmax(torch.softmax(output_sub, dim=-1), dim=-1) + pred_seq_obj = torch.argmax(torch.softmax(output_obj, dim=-1), dim=-1) + pred_seqs = torch.cat([pred_seq_sub.unsqueeze(1), pred_seq_obj.unsqueeze(1)], dim=1) + + mask_tmp1 = inputs['attention_mask'].unsqueeze(-1) + mask_tmp2 = inputs['attention_mask'].unsqueeze(1) + corres_mask = mask_tmp1 * mask_tmp2 + + corres_pred = torch.sigmoid(corres_pred) * corres_mask + pre_corres = torch.where(corres_pred > corres_threshold, + torch.ones(corres_pred.size(), device=corres_pred.device), + torch.zeros(corres_pred.size(), device=corres_pred.device)) + + pred_seqs = pred_seqs.detach().cpu().numpy() + pre_corres = pre_corres.detach().cpu().numpy() + + xi = np.array(xi) + pred_rels = pred_rels.detach().cpu().numpy() + xi_index = np.cumsum(xi).tolist() + xi_index.insert(0, 0) + + for idx in range(batch_size): + pre_triples = tag_mapping_corres(predict_tags=pred_seqs[xi_index[idx]:xi_index[idx + 1]], + pre_corres=pre_corres[idx], + pre_rels=pred_rels[xi_index[idx]:xi_index[idx + 1]], + label2idx_sub=self.evaluate_id2sublabel, + label2idx_obj=self.evaluate_id2oblabel) + + self.evaluate_logs['correct_num'] += len(set(pre_triples) & set(inputs['triples'][idx])) + self.evaluate_logs['predict_num'] += len(set(pre_triples)) + self.evaluate_logs['gold_num'] += len(set(inputs['triples'][idx])) + + def _on_evaluate_epoch_end( + self, + validation_data, + epoch=1, + is_evaluate_print=True, + **kwargs + ): + + metrics = get_metrics(self.evaluate_logs['correct_num'], self.evaluate_logs['predict_num'] , self.evaluate_logs['gold_num']) + metrics_str = "; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics.items()) + + print(metrics_str) + + def evaluate( + self, + validation_data, + evaluate_batch_size=16, + return_pred=False, + **kwargs + ): + self.evaluate_logs = dict() + + generator = self._on_evaluate_begin(validation_data, evaluate_batch_size, shuffle=False, **kwargs) + + with torch.no_grad(): + + self._on_evaluate_epoch_begin(**kwargs) + + for step, inputs in enumerate(generator): + + inputs = self._get_module_inputs_on_eval(inputs, **kwargs) + + # forward + logits = self.module(**inputs) + + self._on_evaluate_step_end(inputs, logits, **kwargs) + + self._on_evaluate_epoch_end(validation_data, **kwargs) + + self._on_evaluate_end(**kwargs) \ No newline at end of file
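
Below is a minimal end-to-end sketch of how the new PRGC components added in this patch fit together. It is illustrative only: the DataFrames, the checkpoint name and the hyper-parameters are assumptions, and the Dataset/Task constructor arguments follow the usual ark_nlp conventions and may need adjusting.

    from ark_nlp.model.re.prgc_bert import Dataset, Tokenizer, PRGCBertConfig, PRGCBert
    from ark_nlp.model.re.prgc_bert import Task, Predictor, get_default_model_optimizer

    # train_df / dev_df are assumed pandas DataFrames with 'text' and 'label' columns, where each label is a
    # list of triples whose positions 1/2, 3 and 5/6 hold subject start/end, relation type and object start/end
    train_dataset = Dataset(train_df)
    dev_dataset = Dataset(dev_df, categories=train_dataset.categories, is_train=False)  # assumed kwargs; gold triples are kept only when is_train is False

    tokenizer = Tokenizer(vocab='bert-base-chinese', max_seq_len=128)
    train_dataset.convert_to_ids(tokenizer)
    dev_dataset.convert_to_ids(tokenizer)

    config = PRGCBertConfig.from_pretrained('bert-base-chinese', num_labels=len(train_dataset.cat2id))
    dl_module = PRGCBert.from_pretrained('bert-base-chinese', config=config)

    optimizer = get_default_model_optimizer(dl_module)

    model = Task(dl_module, optimizer, 'ce')  # the loss argument is a placeholder: PRGCRETask computes its own joint loss
    model.fit(train_dataset, dev_dataset, lr=2e-5, epochs=5, batch_size=16)

    predictor = Predictor(dl_module, tokenizer, train_dataset.cat2id)
    predictor.predict_one_sample('《红楼梦》的作者是曹雪芹')  # -> [(subject, relation, object), ...]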