feature (model): Adding PRGC model and fixing some structure bugs
xiangking committed Aug 17, 2021
1 parent ad941e6 commit 4fef0a9
Showing 9 changed files with 1,123 additions and 28 deletions.
39 changes: 33 additions & 6 deletions ark_nlp/factory/task/base/_sequence_classification.py
@@ -109,6 +109,20 @@ def _on_step_begin_record(
         **kwargs
     ):
         pass
 
+    def _get_train_loss(
+        self,
+        inputs,
+        logits,
+        verbose=True,
+        **kwargs
+    ):
+        # compute loss
+        loss = self._compute_loss(inputs, logits, **kwargs)
+
+        self._compute_loss_record(inputs, logits, loss, verbose, **kwargs)
+
+        return loss
+
     def _compute_loss(
         self,
@@ -118,9 +132,7 @@ def _compute_loss(
         **kwargs
     ):
         loss = self.loss_function(logits, inputs['label_ids'])
 
-        self._compute_loss_record(inputs, logits, loss, verbose, **kwargs)
-
 
         return loss
 
     def _compute_loss_record(
@@ -187,6 +199,9 @@ def _on_optimize(
     def _on_step_end(
         self,
         step,
+        inputs,
+        logits,
+        loss,
         verbose=True,
         show_step=100,
         **kwargs
@@ -250,12 +265,24 @@ def _on_evaluate_epoch_begin(self, **kwargs):
             self.ema.copy_to(self.module.parameters())
 
         self._on_evaluate_epoch_begin_record(**kwargs)
 
+    def _get_evaluate_loss(
+        self,
+        inputs,
+        logits,
+        verbose=True,
+        **kwargs
+    ):
+        # compute loss
+        loss = self._compute_loss(inputs, logits, **kwargs)
+
+        return loss
+
     def _on_evaluate_step_end(self, inputs, logits, **kwargs):
 
         with torch.no_grad():
             # compute loss
-            loss = self._compute_loss(inputs, logits, **kwargs)
+            loss = self._get_evaluate_loss(inputs, logits, **kwargs)
 
             labels = inputs['label_ids'].cpu()
             logits = logits.cpu()
@@ -365,7 +392,7 @@ def fit(
                 logits = self.module(**inputs)
 
                 # compute loss
-                loss = self._compute_loss(inputs, logits, **kwargs)
+                loss = self._get_train_loss(inputs, logits, **kwargs)
 
                 # loss backward
                 loss = self._on_backward(inputs, logits, loss, **kwargs)
@@ -374,7 +401,7 @@ def fit(
                 step = self._on_optimize(step, **kwargs)
 
                 # step evaluate
-                self._on_step_end(step, **kwargs)
+                self._on_step_end(step, inputs, logits, loss, **kwargs)
 
             self._on_epoch_end(epoch, **kwargs)
 
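The refactor in this file separates loss computation from loss bookkeeping: _compute_loss now returns only the raw loss, _get_train_loss wraps it and records to the training logs, and _get_evaluate_loss reuses it at evaluation time without touching those logs. A minimal sketch of what a subclass is left to implement, assuming the base class in this module is importable as SequenceClassificationTask (the import path, subclass name, and loss body are illustrative, not part of the diff):

from ark_nlp.factory.task.base._sequence_classification import SequenceClassificationTask  # assumed class name

class MyClassificationTask(SequenceClassificationTask):

    def _compute_loss(self, inputs, logits, verbose=True, **kwargs):
        # Return only the raw loss: _get_train_loss handles log recording,
        # and _get_evaluate_loss reuses this same code path at eval time.
        return self.loss_function(logits, inputs['label_ids'])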
9 changes: 5 additions & 4 deletions ark_nlp/factory/task/base/_token_classification.py
@@ -54,9 +54,7 @@ def _compute_loss(
                 torch.tensor(self.loss_function.ignore_index).type_as(inputs['label_ids'])
             )
             loss = self.loss_function(active_logits, active_labels)
 
-        self._compute_loss_record(inputs, logits, loss, verbose, **kwargs)
-
 
         return loss
 
     def _compute_loss_record(
@@ -73,6 +71,9 @@ def _compute_loss_record(
     def _on_step_end(
         self,
         step,
+        inputs,
+        logits,
+        loss,
         verbose=True,
         print_step=100,
         **kwargs
@@ -112,7 +113,7 @@ def _on_evaluate_step_end(self, inputs, logits, **kwargs):
 
         with torch.no_grad():
             # compute loss
-            loss = self._compute_loss(inputs, logits, **kwargs)
+            loss = self._get_evaluate_loss(inputs, logits, **kwargs)
 
             self.evaluate_logs['labels'].append(inputs['label_ids'].cpu())
             self.evaluate_logs['logits'].append(logits.cpu())
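The first hunk in this file keeps the standard attention-mask trick for token-level losses: labels at padded positions are rewritten to the loss function's ignore_index so they drop out of the cross-entropy. A self-contained sketch of that pattern (shapes and variable names are illustrative, not the library's exact code):

import torch
import torch.nn as nn

num_labels = 5
loss_fct = nn.CrossEntropyLoss()                 # ignore_index defaults to -100

logits = torch.randn(2, 8, num_labels)           # [batch, seq_len, num_labels]
labels = torch.randint(0, num_labels, (2, 8))    # [batch, seq_len]
attention_mask = torch.ones(2, 8, dtype=torch.long)

active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, num_labels)
# Padded positions receive ignore_index, so CrossEntropyLoss skips them
active_labels = torch.where(
    active_loss,
    labels.view(-1),
    torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)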
19 changes: 4 additions & 15 deletions ark_nlp/factory/task/named_entity_recognition.py
@@ -96,16 +96,14 @@ def _compute_loss(
         **kwargs
     ):
         loss = -1 * self.module.crf(emissions=logits, tags=inputs['label_ids'], mask=inputs['attention_mask'])
 
-        self._compute_loss_record(inputs, logits, loss, verbose, **kwargs)
-
         return loss
 
     def _on_evaluate_step_end(self, inputs, logits, **kwargs):
 
         with torch.no_grad():
             # compute loss
-            loss = self._compute_loss(inputs, logits, **kwargs)
+            loss = self._get_evaluate_loss(inputs, logits, **kwargs)
 
             tags = self.module.crf.decode(logits, inputs['attention_mask'])
             tags = tags.squeeze(0)
@@ -178,17 +176,14 @@ def _compute_loss(
 
         span_loss *= span_mask
         loss = torch.sum(span_loss) / inputs['span_mask'].size()[0]
 
-        if self.logs:
-            self._compute_loss_record(inputs, logits, loss, verbose, **kwargs)
-
         return loss
 
     def _on_evaluate_step_end(self, inputs, logits, **kwargs):
 
         with torch.no_grad():
             # compute loss
-            loss = self._compute_loss(inputs, logits, **kwargs)
+            loss = self._get_evaluate_loss(inputs, logits, **kwargs)
 
             logits = torch.nn.functional.softmax(logits, dim=-1)
 
@@ -239,9 +234,6 @@ def _compute_loss(
         **kwargs
     ):
         loss = self.loss_function(logits, inputs['label_ids'])
 
-        if self.logs:
-            self._compute_loss_record(inputs, logits, loss, verbose, **kwargs)
-
         return loss
 
@@ -274,7 +266,7 @@ def _on_evaluate_step_end(self, inputs, logits, **kwargs):
         with torch.no_grad():
 
             # compute loss
-            loss = self._compute_loss(inputs, logits, **kwargs)
+            loss = self._get_evaluate_loss(inputs, logits, **kwargs)
 
             numerate, denominator = conlleval.global_pointer_f1_score(inputs['label_ids'].cpu(), logits.cpu())
             self.evaluate_logs['numerate'] += numerate
@@ -330,9 +322,6 @@ def _compute_loss(
 
         loss = start_loss + end_loss
 
-        if self.logs:
-            self._compute_loss_record(inputs, logits, loss, verbose, **kwargs)
-
         return loss
 
     def _on_evaluate_epoch_begin(self, **kwargs):
@@ -349,7 +338,7 @@ def _on_evaluate_step_end(self, inputs, logits, **kwargs):
 
         with torch.no_grad():
             # compute loss
-            loss = self._compute_loss(inputs, logits, **kwargs)
+            loss = self._get_evaluate_loss(inputs, logits, **kwargs)
 
             length = inputs['attention_mask'].cpu().numpy().sum() - 2
 
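The CRF task at the top of this file trains on the negative log-likelihood returned by the CRF layer and switches to Viterbi decoding at evaluation time. A sketch of the same call pattern using the third-party pytorch-crf package (ark_nlp ships its own CRF module, whose decode returns a tensor rather than lists; this is an illustration only):

import torch
from torchcrf import CRF

num_tags = 5
crf = CRF(num_tags, batch_first=True)

emissions = torch.randn(2, 8, num_tags)          # [batch, seq_len, num_tags]
tags = torch.randint(0, num_tags, (2, 8))        # gold tag ids
mask = torch.ones(2, 8, dtype=torch.bool)        # attention mask

# The forward pass returns the log-likelihood, so training negates it
loss = -1 * crf(emissions, tags=tags, mask=mask)

# Viterbi decoding for evaluation; returns one tag-id list per sequence
best_paths = crf.decode(emissions, mask=mask)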
@@ -176,9 +176,6 @@ def _compute_loss(
 
         loss = self.loss_function(logits, inputs)
 
-        if self.logs:
-            self._compute_loss_record(inputs, inputs['label_ids'], logits, loss, verbose, **kwargs)
-
         return loss
 
     def _compute_loss_record(
17 changes: 17 additions & 0 deletions ark_nlp/model/re/prgc_bert/__init__.py
@@ -0,0 +1,17 @@
+from ark_nlp.model.re.prgc_bert.prgc_relation_extraction_dataset import PRGCREDataset
+from ark_nlp.model.re.prgc_bert.prgc_relation_extraction_dataset import PRGCREDataset as Dataset
+
+from ark_nlp.processor.tokenizer.transfomer import SpanTokenizer as Tokenizer
+from ark_nlp.processor.tokenizer.transfomer import SpanTokenizer as PRGCRETokenizer
+
+from ark_nlp.nn import BertConfig as PRGCBertConfig
+from ark_nlp.model.re.prgc_bert.prgc_bert import PRGCBert
+
+from ark_nlp.factory.optimizer import get_default_bert_optimizer as get_default_model_optimizer
+from ark_nlp.factory.optimizer import get_default_bert_optimizer as get_default_prgc_bert_optimizer
+
+from ark_nlp.model.re.prgc_bert.prgc_relation_extraction_task import PRGCRETask as Task
+from ark_nlp.model.re.prgc_bert.prgc_relation_extraction_task import PRGCRETask as PRGCRETask
+
+from ark_nlp.model.re.prgc_bert.prgc_relation_extraction_predictor import PRGCREPredictor as Predictor
+from ark_nlp.model.re.prgc_bert.prgc_relation_extraction_predictor import PRGCREPredictor as PRGCREPredictor
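Every component of the new package is exported twice, under a descriptive name and a short alias, so downstream code can import whichever reads better. A usage sketch restricted to names the file above actually exports:

# Short aliases and descriptive names point at the same objects
from ark_nlp.model.re.prgc_bert import Dataset, Tokenizer, Task, Predictor
from ark_nlp.model.re.prgc_bert import PRGCBert, PRGCBertConfig
from ark_nlp.model.re.prgc_bert import get_default_model_optimizer
from ark_nlp.model.re.prgc_bert import PRGCREDataset   # identical to Dataset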