Commit b14b29c
fix (casrel bert): Fixing loss log bug
xiangking committed Aug 17, 2021
1 parent c47a8f9 commit b14b29c
Showing 1 changed file with 15 additions and 5 deletions.
ark_nlp/model/re/casrel_bert/casrel_relation_extraction_task.py (15 additions, 5 deletions)
@@ -165,6 +165,20 @@ def _on_epoch_begin_record(self, **kwargs):
         self.logs['epoch_loss'] = 0
         self.logs['epoch_example'] = 0
         self.logs['epoch_step'] = 0
+
+    def _get_train_loss(
+        self,
+        inputs,
+        logits,
+        verbose=True,
+        **kwargs
+    ):
+        # compute the loss
+        loss = self._compute_loss(inputs, logits, **kwargs)
+
+        self._compute_loss_record(inputs, logits, loss, verbose, **kwargs)
+
+        return loss
 
     def _compute_loss(
         self,
@@ -176,15 +190,11 @@ def _compute_loss(
 
         loss = self.loss_function(logits, inputs)
 
-        if self.logs:
-            self._compute_loss_record(inputs, inputs['label_ids'], logits, loss, verbose, **kwargs)
-
         return loss
 
     def _compute_loss_record(
         self,
         inputs,
-        lables,
         logits,
         loss,
         verbose,
@@ -283,7 +293,7 @@ def fit(
                 logits = self.module(**inputs)
 
                 # compute the loss
-                loss = self._compute_loss(inputs, logits, **kwargs)
+                loss = self._get_train_loss(inputs, logits, **kwargs)
 
                 loss = self._on_backward(inputs, logits, loss, **kwargs)
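
For readers skimming the diff: loss recording is lifted out of _compute_loss into a new _get_train_loss wrapper, so _compute_loss becomes a pure computation and the epoch-loss log is updated exactly once per training step via the path fit -> _get_train_loss -> _compute_loss_record; the unused, misspelled lables parameter is also dropped from _compute_loss_record, which keeps the remaining arguments aligned. Below is a minimal runnable sketch of the resulting pattern. It is an illustration only: SimpleCasrelStyleTask, fit_step, and the simplified loss_function(logits, inputs['label_ids']) call are hypothetical stand-ins, not ark_nlp's real classes or signatures (the real CasRel loss, as the diff shows, takes the whole inputs dict).

import torch
import torch.nn as nn


class SimpleCasrelStyleTask:
    """Illustrative stand-in for the ark_nlp task class (hypothetical, not the real API)."""

    def __init__(self, module, loss_function):
        self.module = module
        self.loss_function = loss_function
        self.logs = {'epoch_loss': 0.0, 'epoch_step': 0}

    def _compute_loss(self, inputs, logits, **kwargs):
        # After the fix: pure computation, no logging side effects.
        # Simplified here to take the labels directly from inputs.
        return self.loss_function(logits, inputs['label_ids'])

    def _compute_loss_record(self, inputs, logits, loss, verbose, **kwargs):
        # After the fix: the misspelled, unused `lables` parameter is gone,
        # so callers pass (inputs, logits, loss, verbose) and nothing shifts.
        self.logs['epoch_loss'] += loss.item()
        self.logs['epoch_step'] += 1
        if verbose:
            print(f"step {self.logs['epoch_step']}: loss={loss.item():.4f}")

    def _get_train_loss(self, inputs, logits, verbose=True, **kwargs):
        # New wrapper: compute the loss, then record it, once per training step.
        loss = self._compute_loss(inputs, logits, **kwargs)
        self._compute_loss_record(inputs, logits, loss, verbose, **kwargs)
        return loss

    def fit_step(self, inputs):
        # The training loop now routes through _get_train_loss
        # instead of calling _compute_loss directly.
        logits = self.module(**inputs)
        return self._get_train_loss(inputs, logits)


# Tiny usage example with a linear layer and cross-entropy loss:
linear = nn.Linear(4, 2)
task = SimpleCasrelStyleTask(
    module=lambda features, **unused: linear(features),  # hypothetical model
    loss_function=nn.CrossEntropyLoss(),
)
batch = {'features': torch.randn(3, 4), 'label_ids': torch.tensor([0, 1, 0])}
loss = task.fit_step(batch)
loss.backward()

Separating compute from record also means code outside the training loop can call _compute_loss without touching the training logs, which is plausibly the logging bug the commit message refers to.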
