Skip to content

Commit

Permalink
[src] nnet1: fixing issue in multi-task training (kaldi-asr#1491)
Browse files Browse the repository at this point in the history
  • Loading branch information
KarelVesely84 authored and danpovey committed Mar 13, 2017
1 parent bd23a10 commit 1a4dbf6
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 6 deletions.
36 changes: 30 additions & 6 deletions src/nnet/nnet-loss.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,17 +367,41 @@ void MultiTaskLoss::Eval(const VectorBase<BaseFloat> &frame_weights,
// allocate diff matrix,
diff->Resize(num_frames, num_output);

/// One vector of frame_weights per loss-function,
/// The original frame weights are multiplied with
/// a mask of `defined targets' according to the 'Posterior'.
std::vector<Vector<BaseFloat> > frmwei_have_tgt;
for (int32 l = 0; l < loss_vec_.size(); l++) {
// copy original weights,
frmwei_have_tgt.push_back(Vector<BaseFloat>(frame_weights));
// We need to mask-out the frames for which the 'posterior' is not defined (= is empty):
int32 loss_beg = loss_dim_offset_[l]; // first column of loss target,
int32 loss_end = loss_dim_offset_[l+1]; // (last+1) column of loss target,
for (int32 f = 0; f < num_frames; f++) {
bool tgt_defined = false;
for (int32 p = 0; p < post[f].size(); p++) {
if (post[f][p].first >= loss_beg && post[f][p].first < loss_end) {
tgt_defined = true;
break;
}
}
if (!tgt_defined) {
frmwei_have_tgt[l](f) = 0.0; // set zero_weight for the frame with no targets!
}
}
}

// call the vector of loss functions,
CuMatrix<BaseFloat> diff_aux;
for (int32 i = 0; i < loss_vec_.size(); i++) {
loss_vec_[i]->Eval(frame_weights,
net_out.ColRange(loss_dim_offset_[i], loss_dim_[i]),
tgt_mat_.ColRange(loss_dim_offset_[i], loss_dim_[i]),
for (int32 l = 0; l < loss_vec_.size(); l++) {
loss_vec_[l]->Eval(frmwei_have_tgt[l],
net_out.ColRange(loss_dim_offset_[l], loss_dim_[l]),
tgt_mat_.ColRange(loss_dim_offset_[l], loss_dim_[l]),
&diff_aux);
// Scale the gradients,
diff_aux.Scale(loss_weights_[i]);
diff_aux.Scale(loss_weights_[l]);
// Copy to diff,
diff->ColRange(loss_dim_offset_[i], loss_dim_[i]).CopyFromMat(diff_aux);
diff->ColRange(loss_dim_offset_[l], loss_dim_[l]).CopyFromMat(diff_aux);
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/nnet/nnet-loss.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class Xent : public LossItf {

/// Get loss value (frame average),
BaseFloat AvgLoss() {
if (frames_.Sum() == 0) return 0.0;
return (xentropy_.Sum() - entropy_.Sum()) / frames_.Sum();
}

Expand Down Expand Up @@ -151,6 +152,7 @@ class Mse : public LossItf {

/// Get loss value (frame average),
BaseFloat AvgLoss() {
if (frames_ == 0) return 0.0;
return loss_ / frames_;
}

Expand Down

0 comments on commit 1a4dbf6

Please sign in to comment.