Skip to content

Commit

Permalink
Merge pull request #2 from xiangking/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
XiangWang authored Aug 17, 2021
2 parents 4fef0a9 + b14b29c commit 4349728
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ark_nlp/factory/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_default_optimizer(module, module_name='bert', **kwargs):
if module_name == 'bert':
return get_default_bert_optimizer(module, **kwargs)
elif module_name == 'crf_bert':
return get_default_bert_crf_optimizer(module, **kwargs)
return get_default_crf_bert_optimizer(module, **kwargs)
else:
raise ValueError("The default optimizer does not exist")

Expand Down
18 changes: 16 additions & 2 deletions ark_nlp/model/re/casrel_bert/casrel_relation_extraction_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,20 @@ def _on_epoch_begin_record(self, **kwargs):
self.logs['epoch_loss'] = 0
self.logs['epoch_example'] = 0
self.logs['epoch_step'] = 0

def _get_train_loss(
    self,
    inputs,
    logits,
    verbose=True,
    **kwargs
):
    """Compute the training loss for one step and log it.

    Delegates the actual loss computation to ``self._compute_loss`` and
    then forwards the result to ``self._compute_loss_record`` so the
    running epoch statistics stay up to date.

    Args:
        inputs: the batch fed to the module (dict of model inputs).
        logits: the module's forward output for this batch.
        verbose: passed through to the loss-record hook.
        **kwargs: forwarded unchanged to both hooks.

    Returns:
        The loss value produced by ``self._compute_loss``.
    """
    train_loss = self._compute_loss(inputs, logits, **kwargs)

    # Record the loss (e.g. accumulate epoch totals) before handing it back
    # to the training loop for the backward pass.
    self._compute_loss_record(inputs, logits, train_loss, verbose, **kwargs)

    return train_loss

def _compute_loss(
self,
Expand All @@ -180,7 +194,7 @@ def _compute_loss(

def _compute_loss_record(
self,
inputs,
inputs,
logits,
loss,
verbose,
Expand Down Expand Up @@ -279,7 +293,7 @@ def fit(
logits = self.module(**inputs)

# 计算损失
loss = self._compute_loss(inputs, logits, **kwargs)
loss = self._get_train_loss(inputs, logits, **kwargs)

loss = self._on_backward(inputs, logits, loss, **kwargs)

Expand Down

0 comments on commit 4349728

Please sign in to comment.