From 7d79572dac78642a3eb01b2c3e6e3e971646c9d4 Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Mon, 13 Mar 2017 19:31:04 +0100 Subject: [PATCH] [src] nnet1: fixing issue in multi-task training (#1491) --- src/nnet/nnet-loss.cc | 36 ++++++++++++++++++++++++++++++------ src/nnet/nnet-loss.h | 2 ++ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/nnet/nnet-loss.cc b/src/nnet/nnet-loss.cc index ba529fcb556..0c1bcfbe4b7 100644 --- a/src/nnet/nnet-loss.cc +++ b/src/nnet/nnet-loss.cc @@ -367,17 +367,41 @@ void MultiTaskLoss::Eval(const VectorBase &frame_weights, // allocate diff matrix, diff->Resize(num_frames, num_output); + /// One vector of frame_weights per loss-function, + /// The original frame weights are multiplied with + /// a mask of `defined targets' according to the 'Posterior'. + std::vector > frmwei_have_tgt; + for (int32 l = 0; l < loss_vec_.size(); l++) { + // copy original weights, + frmwei_have_tgt.push_back(Vector(frame_weights)); + // We need to mask-out the frames for which the 'posterior' is not defined (= is empty): + int32 loss_beg = loss_dim_offset_[l]; // first column of loss target, + int32 loss_end = loss_dim_offset_[l+1]; // (last+1) column of loss target, + for (int32 f = 0; f < num_frames; f++) { + bool tgt_defined = false; + for (int32 p = 0; p < post[f].size(); p++) { + if (post[f][p].first >= loss_beg && post[f][p].first < loss_end) { + tgt_defined = true; + break; + } + } + if (!tgt_defined) { + frmwei_have_tgt[l](f) = 0.0; // set zero_weight for the frame with no targets! + } + } + } + // call the vector of loss functions, CuMatrix diff_aux; - for (int32 i = 0; i < loss_vec_.size(); i++) { - loss_vec_[i]->Eval(frame_weights, - net_out.ColRange(loss_dim_offset_[i], loss_dim_[i]), - tgt_mat_.ColRange(loss_dim_offset_[i], loss_dim_[i]), + for (int32 l = 0; l < loss_vec_.size(); l++) { + loss_vec_[l]->Eval(frmwei_have_tgt[l], + net_out.ColRange(loss_dim_offset_[l], loss_dim_[l]), + tgt_mat_.ColRange(loss_dim_offset_[l], loss_dim_[l]), &diff_aux); // Scale the gradients, - diff_aux.Scale(loss_weights_[i]); + diff_aux.Scale(loss_weights_[l]); // Copy to diff, - diff->ColRange(loss_dim_offset_[i], loss_dim_[i]).CopyFromMat(diff_aux); + diff->ColRange(loss_dim_offset_[l], loss_dim_[l]).CopyFromMat(diff_aux); } } diff --git a/src/nnet/nnet-loss.h b/src/nnet/nnet-loss.h index 1e0558f1b39..56bd9ac0222 100644 --- a/src/nnet/nnet-loss.h +++ b/src/nnet/nnet-loss.h @@ -90,6 +90,7 @@ class Xent : public LossItf { /// Get loss value (frame average), BaseFloat AvgLoss() { + if (frames_.Sum() == 0) return 0.0; return (xentropy_.Sum() - entropy_.Sum()) / frames_.Sum(); } @@ -151,6 +152,7 @@ class Mse : public LossItf { /// Get loss value (frame average), BaseFloat AvgLoss() { + if (frames_ == 0) return 0.0; return loss_ / frames_; }