From d46ff7e82804f0c7dc83ad23788a5da8e5d4c3c9 Mon Sep 17 00:00:00 2001
From: Vimal Manohar
Date: Wed, 27 Dec 2017 03:22:39 +0530
Subject: [PATCH 1/3] [src] Ensure num-tree-leaves is a multiple of 8 (RE
 NVidia Volta architecture) (#2097)

---
 src/bin/build-tree.cc        |  7 +++-
 src/tree/build-tree-utils.cc |  5 +--
 src/tree/build-tree-utils.h  |  4 ++-
 src/tree/build-tree.cc       | 67 +++++++++++++++++++++++++++++++++---
 src/tree/build-tree.h        |  8 ++++-
 5 files changed, 82 insertions(+), 9 deletions(-)

diff --git a/src/bin/build-tree.cc b/src/bin/build-tree.cc
index 69a94a54611..72774900d61 100644
--- a/src/bin/build-tree.cc
+++ b/src/bin/build-tree.cc
@@ -45,6 +45,7 @@ int main(int argc, char *argv[]) {
     BaseFloat thresh = 300.0;
     BaseFloat cluster_thresh = -1.0;  // negative means use smallest split in splitting phase as thresh.
    int32 max_leaves = 0;
+    bool round_num_leaves = true;
    std::string occs_out_filename;

    ParseOptions po(usage);
@@ -61,6 +62,9 @@ int main(int argc, char *argv[]) {
                "threshold for clustering after tree-building. 0 means "
                "no clustering; -1 means use as a clustering threshold the "
                "likelihood change of the final split.");
+    po.Register("round-num-leaves", &round_num_leaves,
+                "If true, then the number of leaves will be reduced to a "
+                "multiple of 8 by clustering.");

    po.Read(argc, argv);
@@ -127,7 +131,8 @@ int main(int argc, char *argv[]) {
                 thresh,
                 max_leaves,
                 cluster_thresh,
-                 P);
+                 P,
+                 round_num_leaves);

    { // This block is to warn about low counts.
      std::vector<BuildTreeStatsType> split_stats;
diff --git a/src/tree/build-tree-utils.cc b/src/tree/build-tree-utils.cc
index 8e0b6db4204..93ae74deb88 100644
--- a/src/tree/build-tree-utils.cc
+++ b/src/tree/build-tree-utils.cc
@@ -599,7 +599,8 @@ EventMap *SplitDecisionTree(const EventMap &input_map,
 int ClusterEventMapGetMapping(const EventMap &e_in,
                               const BuildTreeStatsType &stats,
                               BaseFloat thresh,
-                              std::vector<EventMap*> *mapping) {
+                              std::vector<EventMap*> *mapping,
+                              int32 min_clusters) {
   // First map stats
   KALDI_ASSERT(stats.size() != 0);
   std::vector<BuildTreeStatsType> split_stats;
@@ -627,7 +628,7 @@ int ClusterEventMapGetMapping(const EventMap &e_in,
       change;
   change = ClusterBottomUp(summed_stats_contiguous,
                            thresh,
-                           0,  // no min-clust: use threshold for now.
+                           min_clusters,  // usually 0
                            NULL,  // don't need clusters out.
                            &assignments);  // this algorithm is quadratic, so might be quite slow.
diff --git a/src/tree/build-tree-utils.h b/src/tree/build-tree-utils.h
index 464fc6b14a3..8f3b1476330 100644
--- a/src/tree/build-tree-utils.h
+++ b/src/tree/build-tree-utils.h
@@ -192,8 +192,10 @@ EventMap *DoTableSplitMultiple(const EventMap &orig,
 // a particular phone, do this by providing a set of "stats" that correspond to just
 // this subset of leaves*. Leaves with no stats will not be clustered.
 // See build-tree.cc for an example of usage.
+// min_clusters can be used with thresh "inf" to get exactly that many clusters.
 int ClusterEventMapGetMapping(const EventMap &e_in,
                               const BuildTreeStatsType &stats,
-                              BaseFloat thresh, std::vector<EventMap*> *mapping);
+                              BaseFloat thresh, std::vector<EventMap*> *mapping,
+                              int32 min_clusters = 0);

 /// This is as ClusterEventMapGetMapping but a more convenient interface
 /// that exposes less of the internals. It uses a bottom-up clustering to
diff --git a/src/tree/build-tree.cc b/src/tree/build-tree.cc
index 62735c55421..ad944c441e7 100644
--- a/src/tree/build-tree.cc
+++ b/src/tree/build-tree.cc
@@ -22,6 +22,7 @@
 #include "util/stl-utils.h"
 #include "tree/build-tree-utils.h"
 #include "tree/clusterable-classes.h"
+#include "tree/build-tree.h"

 namespace kaldi {

@@ -141,7 +142,8 @@ EventMap *BuildTree(Questions &qopts,
                     BaseFloat thresh,
                     int32 max_leaves,
                     BaseFloat cluster_thresh,  // typically == thresh. If negative, use smallest split.
-                    int32 P) {
+                    int32 P,
+                    bool round_num_leaves) {
   KALDI_ASSERT(thresh > 0 || max_leaves > 0);
   KALDI_ASSERT(stats.size() != 0);
   KALDI_ASSERT(!phone_sets.empty()
@@ -212,8 +214,29 @@ EventMap *BuildTree(Questions &qopts,
                                            &num_removed);
     KALDI_LOG << "BuildTree: removed "<< num_removed << " leaves.";

-    int32 num_leaves = 0;
-    EventMap *tree_renumbered = RenumberEventMap(*tree_clustered, &num_leaves);
+    int32 num_leaves_out = 0;
+    EventMap *tree_renumbered;
+    if (round_num_leaves) {
+      // Round the number of leaves to a multiple of 8 by clustering the leaves
+      // and merging them within each cluster.
+      // The final number of leaves will be 'num_leaves_required'.
+      int32 num_leaves_required = ((num_leaves - num_removed) / 8) * 8;
+      std::vector<EventMap*> leaf_mapping;
+
+      int32 num_actually_removed = ClusterEventMapGetMapping(
+          *tree_clustered, stats, std::numeric_limits<BaseFloat>::infinity(),
+          &leaf_mapping, num_leaves_required);
+      KALDI_ASSERT(num_leaves - num_removed
+                   - num_actually_removed == num_leaves_required);
+
+      EventMap* tree_rounded = tree_clustered->Copy(leaf_mapping);
+      DeletePointers(&leaf_mapping);
+      tree_renumbered = RenumberEventMap(*tree_rounded, &num_leaves_out);
+
+      delete tree_rounded;
+    } else {
+      tree_renumbered = RenumberEventMap(*tree_clustered, &num_leaves_out);
+    }

     BaseFloat objf_after_cluster = ObjfGivenMap(stats, *tree_renumbered);

@@ -223,13 +246,49 @@ EventMap *BuildTree(Questions &qopts,
     KALDI_VLOG(1) << "Normalizing over only split phones, this is: "
                   << ((objf_after_cluster-objf_before_cluster) / normalizer_filt)
                   << " per frame.";
-    KALDI_VLOG(1) << "Num-leaves is now "<< num_leaves;
+    KALDI_VLOG(1) << "Num-leaves is now "<< num_leaves_out;

     delete tree_clustered;
     delete tree_split;
     delete tree_stub;
     return tree_renumbered;
   } else {
+    if (round_num_leaves) {
+      // Round the number of leaves to a multiple of 8 by clustering the leaves
+      // and merging them within each cluster.
+      // The final number of leaves will be 'num_leaves_required'.
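The hunks above and below implement the rounding in two steps: first compute the target num_leaves_required = ((num_leaves - num_removed) / 8) * 8, then re-run the bottom-up clustering with an infinite threshold and min_clusters equal to the target, so merging continues until exactly that many leaves remain (a leaf count divisible by 8 keeps the output-layer dimensions friendly to NVidia Volta tensor cores). As a worked example, 4523 leaves round down to 4520, so exactly 3 leaves are merged away; at most 7 ever need to be, which is why the branch below can assert num_actually_removed < 8. The following is a minimal standalone sketch of the same idea, using toy scalar statistics and a squared-error merge cost in place of Kaldi's Clusterable/EventMap machinery (all names in it are illustrative, not Kaldi's):

    // Toy sketch (not Kaldi code): round a leaf count down to a multiple of 8
    // by greedy bottom-up merging, mirroring ClusterBottomUp called with an
    // infinite threshold and min_clusters == target.  Each "leaf" is just a
    // (count, sum, sumsq) triple; the merge cost is the squared-error increase.
    #include <cassert>
    #include <cstdio>
    #include <limits>
    #include <vector>

    struct LeafStats { double count, sum, sumsq; };

    // Sum of squared deviations from the mean for one cluster.
    static double Cost(const LeafStats &s) {
      double mean = s.sum / s.count;
      return s.sumsq - 2.0 * mean * s.sum + s.count * mean * mean;
    }

    int main() {
      std::vector<LeafStats> leaves = {
          {10, 5, 4}, {3, 9, 30}, {7, 2, 2}, {4, 8, 17}, {6, 1, 1},
          {9, 3, 2}, {2, 7, 25}, {5, 5, 6}, {8, 4, 3}, {3, 3, 4}};
      // The real code computes this from (num_leaves - num_removed).
      size_t target = (leaves.size() / 8) * 8;
      while (leaves.size() > target) {
        size_t bi = 0, bj = 1;
        double best = std::numeric_limits<double>::infinity();
        for (size_t i = 0; i < leaves.size(); i++) {
          for (size_t j = i + 1; j < leaves.size(); j++) {
            LeafStats m = {leaves[i].count + leaves[j].count,
                           leaves[i].sum + leaves[j].sum,
                           leaves[i].sumsq + leaves[j].sumsq};
            double delta = Cost(m) - Cost(leaves[i]) - Cost(leaves[j]);
            if (delta < best) { best = delta; bi = i; bj = j; }
          }
        }
        // Merge the cheapest pair; with an infinite threshold we always keep
        // merging until exactly 'target' clusters remain.
        leaves[bi] = {leaves[bi].count + leaves[bj].count,
                      leaves[bi].sum + leaves[bj].sum,
                      leaves[bi].sumsq + leaves[bj].sumsq};
        leaves.erase(leaves.begin() + bj);
      }
      assert(leaves.size() % 8 == 0);
      std::printf("kept %zu leaves\n", leaves.size());
      return 0;
    }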
+ BaseFloat objf_before_cluster = ObjfGivenMap(stats, *tree_split); + + int32 num_leaves_required = (num_leaves / 8) * 8; + std::vector leaf_mapping; + + int32 num_actually_removed = ClusterEventMapGetMapping( + *tree_split, stats, std::numeric_limits::infinity(), + &leaf_mapping, num_leaves_required); + + KALDI_ASSERT(num_actually_removed < 8); + + EventMap* tree_rounded = tree_split->Copy(leaf_mapping); + DeletePointers(&leaf_mapping); + + int32 num_leaves_out; + EventMap* tree_renumbered = RenumberEventMap(*tree_rounded, &num_leaves_out); + + BaseFloat objf_after_cluster = ObjfGivenMap(stats, *tree_renumbered); + + KALDI_VLOG(1) << "Objf change due to clustering " + << ((objf_after_cluster-objf_before_cluster) / normalizer) + << " per frame."; + KALDI_VLOG(1) << "Normalizing over only split phones, this is: " + << ((objf_after_cluster-objf_before_cluster) / normalizer_filt) + << " per frame."; + KALDI_VLOG(1) << "Num-leaves is now "<< num_leaves_out; + + delete tree_stub; + delete tree_rounded; + return tree_renumbered; + } + delete tree_stub; return tree_split; } diff --git a/src/tree/build-tree.h b/src/tree/build-tree.h index 37bb1081ffa..498ac5a8e19 100644 --- a/src/tree/build-tree.h +++ b/src/tree/build-tree.h @@ -75,6 +75,11 @@ namespace kaldi { * @param P [in] The central position of the phone context window, e.g. 1 for a * triphone system. + * @param round_num_leaves [in] If true, then the number of leaves in the + * final tree is made a multiple of 8. This is done by + * further clustering the leaves after they are first + * clustered based on log-likelihood change. + * (See cluster_thresh above) (default: true) * @return Returns a pointer to an EventMap object that is the tree. */ @@ -88,7 +93,8 @@ EventMap *BuildTree(Questions &qopts, BaseFloat thresh, int32 max_leaves, BaseFloat cluster_thresh, // typically == thresh. If negative, use smallest split. - int32 P); + int32 P, + bool round_num_leaves = true); /** From 48656c314815a42486b23f08722bc05bdc844452 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 26 Dec 2017 13:56:58 -0800 Subject: [PATCH 2/3] Revert "[src] Ensure num-tree-leaves is a multiple of 8 (RE NVidia Volta architecture) (#2097)" (#2105) This reverts commit d46ff7e82804f0c7dc83ad23788a5da8e5d4c3c9. --- src/bin/build-tree.cc | 7 +--- src/tree/build-tree-utils.cc | 5 ++- src/tree/build-tree-utils.h | 4 +-- src/tree/build-tree.cc | 67 +++--------------------------------- src/tree/build-tree.h | 8 +---- 5 files changed, 9 insertions(+), 82 deletions(-) diff --git a/src/bin/build-tree.cc b/src/bin/build-tree.cc index 72774900d61..69a94a54611 100644 --- a/src/bin/build-tree.cc +++ b/src/bin/build-tree.cc @@ -45,7 +45,6 @@ int main(int argc, char *argv[]) { BaseFloat thresh = 300.0; BaseFloat cluster_thresh = -1.0; // negative means use smallest split in splitting phase as thresh. int32 max_leaves = 0; - bool round_num_leaves = true; std::string occs_out_filename; ParseOptions po(usage); @@ -62,9 +61,6 @@ int main(int argc, char *argv[]) { "threshold for clustering after tree-building. 0 means " "no clustering; -1 means use as a clustering threshold the " "likelihood change of the final split."); - po.Register("round-num-leaves", &round_num_leaves, - "If true, then the number of leaves will be reduced to a " - "multiple of 8 by clustering."); po.Read(argc, argv); @@ -131,8 +127,7 @@ int main(int argc, char *argv[]) { thresh, max_leaves, cluster_thresh, - P, - round_num_leaves); + P); { // This block is to warn about low counts. 
std::vector split_stats; diff --git a/src/tree/build-tree-utils.cc b/src/tree/build-tree-utils.cc index 93ae74deb88..8e0b6db4204 100644 --- a/src/tree/build-tree-utils.cc +++ b/src/tree/build-tree-utils.cc @@ -599,8 +599,7 @@ EventMap *SplitDecisionTree(const EventMap &input_map, int ClusterEventMapGetMapping(const EventMap &e_in, const BuildTreeStatsType &stats, BaseFloat thresh, - std::vector *mapping, - int32 min_clusters) { + std::vector *mapping) { // First map stats KALDI_ASSERT(stats.size() != 0); std::vector split_stats; @@ -628,7 +627,7 @@ int ClusterEventMapGetMapping(const EventMap &e_in, change; change = ClusterBottomUp(summed_stats_contiguous, thresh, - min_clusters, // usually 0 + 0, // no min-clust: use threshold for now. NULL, // don't need clusters out. &assignments); // this algorithm is quadratic, so might be quite slow. diff --git a/src/tree/build-tree-utils.h b/src/tree/build-tree-utils.h index 8f3b1476330..464fc6b14a3 100644 --- a/src/tree/build-tree-utils.h +++ b/src/tree/build-tree-utils.h @@ -192,10 +192,8 @@ EventMap *DoTableSplitMultiple(const EventMap &orig, // a particular phone, do this by providing a set of "stats" that correspond to just // this subset of leaves*. Leaves with no stats will not be clustered. // See build-tree.cc for an example of usage. -// min_clusters can be used with thresh "inf" to get exactly that many clusters. int ClusterEventMapGetMapping(const EventMap &e_in, const BuildTreeStatsType &stats, - BaseFloat thresh, std::vector *mapping, - int32 min_clusters = 0); + BaseFloat thresh, std::vector *mapping); /// This is as ClusterEventMapGetMapping but a more convenient interface /// that exposes less of the internals. It uses a bottom-up clustering to diff --git a/src/tree/build-tree.cc b/src/tree/build-tree.cc index ad944c441e7..62735c55421 100644 --- a/src/tree/build-tree.cc +++ b/src/tree/build-tree.cc @@ -22,7 +22,6 @@ #include "util/stl-utils.h" #include "tree/build-tree-utils.h" #include "tree/clusterable-classes.h" -#include "tree/build-tree.h" namespace kaldi { @@ -142,8 +141,7 @@ EventMap *BuildTree(Questions &qopts, BaseFloat thresh, int32 max_leaves, BaseFloat cluster_thresh, // typically == thresh. If negative, use smallest split. - int32 P, - bool round_num_leaves) { + int32 P) { KALDI_ASSERT(thresh > 0 || max_leaves > 0); KALDI_ASSERT(stats.size() != 0); KALDI_ASSERT(!phone_sets.empty() @@ -214,29 +212,8 @@ EventMap *BuildTree(Questions &qopts, &num_removed); KALDI_LOG << "BuildTree: removed "<< num_removed << " leaves."; - int32 num_leaves_out = 0; - EventMap *tree_renumbered; - if (round_num_leaves) { - // Round the number of leaves to a multiple of 8 by clustering the leaves - // and merging them within each cluster. - // The final number of leaves will be 'num_leaves_required'. 
- int32 num_leaves_required = ((num_leaves - num_removed) / 8) * 8; - std::vector leaf_mapping; - - int32 num_actually_removed = ClusterEventMapGetMapping( - *tree_clustered, stats, std::numeric_limits::infinity(), - &leaf_mapping, num_leaves_required); - KALDI_ASSERT(num_leaves - num_removed - - num_actually_removed == num_leaves_required); - - EventMap* tree_rounded = tree_clustered->Copy(leaf_mapping); - DeletePointers(&leaf_mapping); - tree_renumbered = RenumberEventMap(*tree_rounded, &num_leaves_out); - - delete tree_rounded; - } else { - tree_renumbered = RenumberEventMap(*tree_clustered, &num_leaves_out); - } + int32 num_leaves = 0; + EventMap *tree_renumbered = RenumberEventMap(*tree_clustered, &num_leaves); BaseFloat objf_after_cluster = ObjfGivenMap(stats, *tree_renumbered); @@ -246,49 +223,13 @@ EventMap *BuildTree(Questions &qopts, KALDI_VLOG(1) << "Normalizing over only split phones, this is: " << ((objf_after_cluster-objf_before_cluster) / normalizer_filt) << " per frame."; - KALDI_VLOG(1) << "Num-leaves is now "<< num_leaves_out; + KALDI_VLOG(1) << "Num-leaves is now "<< num_leaves; delete tree_clustered; delete tree_split; delete tree_stub; return tree_renumbered; } else { - if (round_num_leaves) { - // Round the number of leaves to a multiple of 8 by clustering the leaves - // and merging them within each cluster. - // The final number of leaves will be 'num_leaves_required'. - BaseFloat objf_before_cluster = ObjfGivenMap(stats, *tree_split); - - int32 num_leaves_required = (num_leaves / 8) * 8; - std::vector leaf_mapping; - - int32 num_actually_removed = ClusterEventMapGetMapping( - *tree_split, stats, std::numeric_limits::infinity(), - &leaf_mapping, num_leaves_required); - - KALDI_ASSERT(num_actually_removed < 8); - - EventMap* tree_rounded = tree_split->Copy(leaf_mapping); - DeletePointers(&leaf_mapping); - - int32 num_leaves_out; - EventMap* tree_renumbered = RenumberEventMap(*tree_rounded, &num_leaves_out); - - BaseFloat objf_after_cluster = ObjfGivenMap(stats, *tree_renumbered); - - KALDI_VLOG(1) << "Objf change due to clustering " - << ((objf_after_cluster-objf_before_cluster) / normalizer) - << " per frame."; - KALDI_VLOG(1) << "Normalizing over only split phones, this is: " - << ((objf_after_cluster-objf_before_cluster) / normalizer_filt) - << " per frame."; - KALDI_VLOG(1) << "Num-leaves is now "<< num_leaves_out; - - delete tree_stub; - delete tree_rounded; - return tree_renumbered; - } - delete tree_stub; return tree_split; } diff --git a/src/tree/build-tree.h b/src/tree/build-tree.h index 498ac5a8e19..37bb1081ffa 100644 --- a/src/tree/build-tree.h +++ b/src/tree/build-tree.h @@ -75,11 +75,6 @@ namespace kaldi { * @param P [in] The central position of the phone context window, e.g. 1 for a * triphone system. - * @param round_num_leaves [in] If true, then the number of leaves in the - * final tree is made a multiple of 8. This is done by - * further clustering the leaves after they are first - * clustered based on log-likelihood change. - * (See cluster_thresh above) (default: true) * @return Returns a pointer to an EventMap object that is the tree. */ @@ -93,8 +88,7 @@ EventMap *BuildTree(Questions &qopts, BaseFloat thresh, int32 max_leaves, BaseFloat cluster_thresh, // typically == thresh. If negative, use smallest split. 
- int32 P, - bool round_num_leaves = true); + int32 P); /** From a5561c3cfc370da2c4c99bbe0af0f537b1ce6578 Mon Sep 17 00:00:00 2001 From: Yiming Wang Date: Tue, 26 Dec 2017 23:52:50 -0500 Subject: [PATCH 3/3] [src,scripts] Simplify model combination: do simple average over last n models (#2067) --- egs/wsj/s5/steps/info/chain_dir_info.pl | 3 + egs/wsj/s5/steps/info/nnet3_dir_info.pl | 3 + .../nnet3/train/chain_objf/acoustic_model.py | 17 +- egs/wsj/s5/steps/libs/nnet3/train/common.py | 14 +- .../nnet3/train/frame_level_objf/common.py | 11 +- egs/wsj/s5/steps/nnet3/chain/train.py | 2 +- egs/wsj/s5/steps/nnet3/train_dnn.py | 2 +- egs/wsj/s5/steps/nnet3/train_raw_dnn.py | 2 +- egs/wsj/s5/steps/nnet3/train_raw_rnn.py | 2 +- egs/wsj/s5/steps/nnet3/train_rnn.py | 2 +- src/chainbin/nnet3-chain-combine.cc | 132 +++- src/nnet3/Makefile | 4 +- src/nnet3/nnet-chain-combine.cc | 610 ------------------ src/nnet3/nnet-chain-combine.h | 205 ------ src/nnet3/nnet-combine.cc | 606 ----------------- src/nnet3/nnet-combine.h | 251 ------- src/nnet3bin/nnet3-combine.cc | 137 +++- 17 files changed, 246 insertions(+), 1757 deletions(-) delete mode 100644 src/nnet3/nnet-chain-combine.cc delete mode 100644 src/nnet3/nnet-chain-combine.h delete mode 100644 src/nnet3/nnet-combine.cc delete mode 100644 src/nnet3/nnet-combine.h diff --git a/egs/wsj/s5/steps/info/chain_dir_info.pl b/egs/wsj/s5/steps/info/chain_dir_info.pl index b0adb7e498c..d0fac5292c6 100755 --- a/egs/wsj/s5/steps/info/chain_dir_info.pl +++ b/egs/wsj/s5/steps/info/chain_dir_info.pl @@ -137,6 +137,9 @@ sub get_combine_info { if (m/Combining nnets, objective function changed from (\S+) to (\S+)/) { close(F); return sprintf(" combine=%.3f->%.3f", $1, $2); + } elsif (m/Combining (\S+) nnets, objective function changed from (\S+) to (\S+)/) { + close(F); + return sprintf(" combine=%.3f->%.3f (over %d)", $2, $3, $1); } } } diff --git a/egs/wsj/s5/steps/info/nnet3_dir_info.pl b/egs/wsj/s5/steps/info/nnet3_dir_info.pl index 06d07a63755..4b0e774a592 100755 --- a/egs/wsj/s5/steps/info/nnet3_dir_info.pl +++ b/egs/wsj/s5/steps/info/nnet3_dir_info.pl @@ -137,6 +137,9 @@ sub get_combine_info { if (m/Combining nnets, objective function changed from (\S+) to (\S+)/) { close(F); return sprintf(" combine=%.2f->%.2f", $1, $2); + } elsif (m/Combining (\S+) nnets, objective function changed from (\S+) to (\S+)/) { + close(F); + return sprintf(" combine=%.2f->%.2f (over %d)", $2, $3, $1); } } } diff --git a/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py b/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py index 02a3b4c75d5..5b640510ea1 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py @@ -492,7 +492,7 @@ def compute_progress(dir, iter, run_opts): def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_str, egs_dir, leaky_hmm_coefficient, l2_regularize, xent_regularize, run_opts, - sum_to_one_penalty=0.0): + max_objective_evaluations=30): """ Function to do model combination In the nnet3 setup, the logic @@ -505,9 +505,6 @@ def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_st models_to_combine.add(num_iters) - # TODO: if it turns out the sum-to-one-penalty code is not useful, - # remove support for it. 
- for iter in sorted(models_to_combine): model_file = '{0}/{1}.mdl'.format(dir, iter) if os.path.exists(model_file): @@ -528,12 +525,9 @@ def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_st common_lib.execute_command( """{command} {combine_queue_opt} {dir}/log/combine.log \ - nnet3-chain-combine --num-iters={opt_iters} \ + nnet3-chain-combine \ + --max-objective-evaluations={max_objective_evaluations} \ --l2-regularize={l2} --leaky-hmm-coefficient={leaky} \ - --separate-weights-per-component={separate_weights} \ - --enforce-sum-to-one={hard_enforce} \ - --sum-to-one-penalty={penalty} \ - --enforce-positive-weights=true \ --verbose=3 {dir}/den.fst {raw_models} \ "ark,bg:nnet3-chain-copy-egs ark:{egs_dir}/combine.cegs ark:- | \ nnet3-chain-merge-egs --minibatch-size={num_chunk_per_mb} \ @@ -542,12 +536,9 @@ def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_st {dir}/final.mdl""".format( command=run_opts.command, combine_queue_opt=run_opts.combine_queue_opt, - opt_iters=(20 if sum_to_one_penalty <= 0 else 80), - separate_weights=(sum_to_one_penalty > 0), + max_objective_evaluations=max_objective_evaluations, l2=l2_regularize, leaky=leaky_hmm_coefficient, dir=dir, raw_models=" ".join(raw_model_strings), - hard_enforce=(sum_to_one_penalty <= 0), - penalty=sum_to_one_penalty, num_chunk_per_mb=num_chunk_per_minibatch_str, num_iters=num_iters, egs_dir=egs_dir)) diff --git a/egs/wsj/s5/steps/libs/nnet3/train/common.py b/egs/wsj/s5/steps/libs/nnet3/train/common.py index b3b443ceb4c..2b4fdd92cec 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/common.py @@ -852,6 +852,16 @@ def __init__(self, the final model combination stage. These models will themselves be averages of iteration-number ranges""") + self.parser.add_argument("--trainer.optimization.max-objective-evaluations", + "--trainer.max-objective-evaluations", + type=int, dest='max_objective_evaluations', + default=30, + help="""The maximum number of objective + evaluations in order to figure out the + best number of models to combine. It helps to + speedup if the number of models provided to the + model combination binary is quite large (e.g. + several hundred).""") self.parser.add_argument("--trainer.optimization.do-final-combination", dest='do_final_combination', type=str, action=common_lib.StrToBoolAction, @@ -861,9 +871,7 @@ def __init__(self, last-numbered model as the final.mdl).""") self.parser.add_argument("--trainer.optimization.combine-sum-to-one-penalty", type=float, dest='combine_sum_to_one_penalty', default=0.0, - help="""If > 0, activates 'soft' enforcement of the - sum-to-one penalty in combination (may be helpful - if using dropout). E.g. 1.0e-03.""") + help="""This option is deprecated and does nothing.""") self.parser.add_argument("--trainer.optimization.momentum", type=float, dest='momentum', default=0.0, help="""Momentum used in update computation. 
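The new --trainer.optimization.max-objective-evaluations option caps how many times the combination binary re-scores the moving average: with N models and a budget of E evaluations it probes every ceil(N/E) models, the same ceil-division that appears later in nnet3-chain-combine.cc. A small standalone sketch of that schedule (plain C++; the variable names are illustrative, not from Kaldi):

    // Evaluation schedule implied by --max-objective-evaluations: with
    // num_models models and a budget of max_evals objective evaluations,
    // probe the moving average every ceil(num_models / max_evals) models.
    #include <cstdio>

    int main() {
      int num_models = 300, max_evals = 30;
      int num_to_add = (num_models + max_evals - 1) / max_evals;  // ceil division
      int evals = 0;
      for (int n = 1; n < num_models; n++) {
        // Same test as the binary: probe after every num_to_add-th model,
        // and always after the last one.
        if ((n - 1) % num_to_add == num_to_add - 1 || n == num_models - 1)
          evals++;
      }
      std::printf("stride=%d, evaluations=%d\n", num_to_add, evals);  // 10, 30
      return 0;
    }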
diff --git a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py index f8a69c5ad84..46eec2e3b87 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py @@ -452,7 +452,7 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, minibatch_size_str, run_opts, chunk_width=None, get_raw_nnet_from_am=True, - sum_to_one_penalty=0.0, + max_objective_evaluations=30, use_multitask_egs=False, compute_per_dim_accuracy=False): """ Function to do model combination @@ -501,10 +501,8 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, use_multitask_egs=use_multitask_egs) common_lib.execute_command( """{command} {combine_queue_opt} {dir}/log/combine.log \ - nnet3-combine --num-iters=80 \ - --enforce-sum-to-one={hard_enforce} \ - --sum-to-one-penalty={penalty} \ - --enforce-positive-weights=true \ + nnet3-combine \ + --max-objective-evaluations={max_objective_evaluations} \ --verbose=3 {raw_models} \ "ark,bg:nnet3-copy-egs {multitask_egs_opts} \ {egs_rspecifier} ark:- | \ @@ -513,9 +511,8 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, """.format(command=run_opts.command, combine_queue_opt=run_opts.combine_queue_opt, dir=dir, raw_models=" ".join(raw_model_strings), + max_objective_evaluations=max_objective_evaluations, egs_rspecifier=egs_rspecifier, - hard_enforce=(sum_to_one_penalty <= 0), - penalty=sum_to_one_penalty, mbsize=minibatch_size_str, out_model=out_model, multitask_egs_opts=multitask_egs_opts)) diff --git a/egs/wsj/s5/steps/nnet3/chain/train.py b/egs/wsj/s5/steps/nnet3/chain/train.py index d23c379e104..b62f5510e3c 100755 --- a/egs/wsj/s5/steps/nnet3/chain/train.py +++ b/egs/wsj/s5/steps/nnet3/chain/train.py @@ -554,7 +554,7 @@ def train(args, run_opts): l2_regularize=args.l2_regularize, xent_regularize=args.xent_regularize, run_opts=run_opts, - sum_to_one_penalty=args.combine_sum_to_one_penalty) + max_objective_evaluations=args.max_objective_evaluations) else: logger.info("Copying the last-numbered model to final.mdl") common_lib.force_symlink("{0}.mdl".format(num_iters), diff --git a/egs/wsj/s5/steps/nnet3/train_dnn.py b/egs/wsj/s5/steps/nnet3/train_dnn.py index 87a1fd5afed..073ad3e7d7a 100755 --- a/egs/wsj/s5/steps/nnet3/train_dnn.py +++ b/egs/wsj/s5/steps/nnet3/train_dnn.py @@ -364,7 +364,7 @@ def train(args, run_opts): models_to_combine=models_to_combine, egs_dir=egs_dir, minibatch_size_str=args.minibatch_size, run_opts=run_opts, - sum_to_one_penalty=args.combine_sum_to_one_penalty) + max_objective_evaluations=args.max_objective_evaluations) if args.stage <= num_iters + 1: logger.info("Getting average posterior for purposes of " diff --git a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py index 38396f0b4e7..2d092ceebc7 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py @@ -398,7 +398,7 @@ def train(args, run_opts): models_to_combine=models_to_combine, egs_dir=egs_dir, minibatch_size_str=args.minibatch_size, run_opts=run_opts, get_raw_nnet_from_am=False, - sum_to_one_penalty=args.combine_sum_to_one_penalty, + max_objective_evaluations=args.max_objective_evaluations, use_multitask_egs=use_multitask_egs) else: common_lib.force_symlink("{0}.raw".format(num_iters), diff --git a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py index c9ffcf7ff2c..b51632e7d2c 100755 --- 
a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py @@ -475,7 +475,7 @@ def train(args, run_opts): run_opts=run_opts, chunk_width=args.chunk_width, get_raw_nnet_from_am=False, compute_per_dim_accuracy=args.compute_per_dim_accuracy, - sum_to_one_penalty=args.combine_sum_to_one_penalty) + max_objective_evaluations=args.max_objective_evaluations) else: common_lib.force_symlink("{0}.raw".format(num_iters), "{0}/final.raw".format(args.dir)) diff --git a/egs/wsj/s5/steps/nnet3/train_rnn.py b/egs/wsj/s5/steps/nnet3/train_rnn.py index e6f81b03c3b..005e751cae0 100755 --- a/egs/wsj/s5/steps/nnet3/train_rnn.py +++ b/egs/wsj/s5/steps/nnet3/train_rnn.py @@ -451,7 +451,7 @@ def train(args, run_opts): run_opts=run_opts, minibatch_size_str=args.num_chunk_per_minibatch, chunk_width=args.chunk_width, - sum_to_one_penalty=args.combine_sum_to_one_penalty, + max_objective_evaluations=args.max_objective_evaluations, compute_per_dim_accuracy=args.compute_per_dim_accuracy) if args.stage <= num_iters + 1: diff --git a/src/chainbin/nnet3-chain-combine.cc b/src/chainbin/nnet3-chain-combine.cc index 3c44e6b904c..ca0428553c1 100644 --- a/src/chainbin/nnet3-chain-combine.cc +++ b/src/chainbin/nnet3-chain-combine.cc @@ -1,6 +1,7 @@ // chainbin/nnet3-chain-combine.cc // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey) +// 2017 Yiming Wang // See ../../COPYING for clarification regarding multiple authors // @@ -19,7 +20,65 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "nnet3/nnet-chain-combine.h" +#include "nnet3/nnet-utils.h" +#include "nnet3/nnet-compute.h" +#include "nnet3/nnet-chain-diagnostics.h" + + +namespace kaldi { +namespace nnet3 { + +// Computes and returns the objective function for the examples in 'egs' given +// the model in 'nnet'. If either of batchnorm/dropout test modes is true, we +// make a copy of 'nnet', set test modes on that and evaluate its objective. +// Note: the object that prob_computer->nnet_ refers to should be 'nnet'. +double ComputeObjf(bool batchnorm_test_mode, bool dropout_test_mode, + const std::vector &egs, const Nnet &nnet, + const chain::ChainTrainingOptions &chain_config, + const fst::StdVectorFst &den_fst, + NnetChainComputeProb *prob_computer) { + if (batchnorm_test_mode || dropout_test_mode) { + Nnet nnet_copy(nnet); + if (batchnorm_test_mode) + SetBatchnormTestMode(true, &nnet_copy); + if (dropout_test_mode) + SetDropoutTestMode(true, &nnet_copy); + NnetComputeProbOptions compute_prob_opts; + NnetChainComputeProb prob_computer_test(compute_prob_opts, chain_config, + den_fst, nnet_copy); + return ComputeObjf(false, false, egs, nnet_copy, + chain_config, den_fst, &prob_computer_test); + } else { + prob_computer->Reset(); + std::vector::const_iterator iter = egs.begin(), + end = egs.end(); + for (; iter != end; ++iter) + prob_computer->Compute(*iter); + const ChainObjectiveInfo *objf_info = + prob_computer->GetObjective("output"); + if (objf_info == NULL) + KALDI_ERR << "Error getting objective info (unsuitable egs?)"; + KALDI_ASSERT(objf_info->tot_weight > 0.0); + // inf/nan tot_objf->return -inf objective. + double tot_objf = objf_info->tot_like + objf_info->tot_l2_term; + if (!(tot_objf == tot_objf && tot_objf - tot_objf == 0)) + return -std::numeric_limits::infinity(); + // we prefer to deal with normalized objective functions. 
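The check `!(tot_objf == tot_objf && tot_objf - tot_objf == 0)` in the hunk above is a self-contained finiteness test: NaN fails x == x, and an infinite value fails x - x == 0 because inf - inf is NaN. Under strict IEEE-754 semantics it agrees with std::isfinite, as this small sketch (not from the patch) demonstrates; note that compilers in fast-math mode may optimize the idiom away, which is worth knowing when reading it:

    #include <cassert>
    #include <cmath>
    #include <limits>

    // Same idiom as the patch: true iff x is neither NaN nor +/-infinity.
    static bool IsFiniteIdiom(double x) { return x == x && x - x == 0; }

    int main() {
      double vals[] = {0.0, -3.5,
                       std::numeric_limits<double>::infinity(),
                       -std::numeric_limits<double>::infinity(),
                       std::numeric_limits<double>::quiet_NaN()};
      for (double v : vals)
        assert(IsFiniteIdiom(v) == static_cast<bool>(std::isfinite(v)));
      return 0;
    }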
+    return tot_objf / objf_info->tot_weight;
+  }
+}
+
+// Updates moving average over num_models nnets, given the average over
+// previous (num_models - 1) nnets, and the new nnet.
+void UpdateNnetMovingAverage(int32 num_models,
+                             const Nnet &nnet, Nnet *moving_average_nnet) {
+  KALDI_ASSERT(NumParameters(nnet) == NumParameters(*moving_average_nnet));
+  ScaleNnet((num_models - 1.0) / num_models, moving_average_nnet);
+  AddNnet(nnet, 1.0 / num_models, moving_average_nnet);
+}
+
+}
+}

 int main(int argc, char *argv[]) {
@@ -30,9 +89,11 @@ int main(int argc, char *argv[]) {
     typedef kaldi::int64 int64;

     const char *usage =
-        "Using a subset of training or held-out nnet3+chain examples, compute an\n"
-        "optimal combination of anumber of nnet3 neural nets by maximizing the\n"
-        "'chain' objective function. See documentation of options for more details.\n"
+        "Using a subset of training or held-out nnet3+chain examples, compute\n"
+        "the average over the first n nnet models for the value of n that\n"
+        "maximizes the 'chain' objective function. Note that the order of\n"
+        "models has been reversed before being fed into this binary, so we\n"
+        "are actually combining the last n models.\n"
         "Inputs and outputs are nnet3 raw nnets.\n"
         "\n"
         "Usage: nnet3-chain-combine [options] ... \n"
         " e.g.:\n"
         "  nnet3-combine den.fst 35.raw 36.raw 37.raw 38.raw ark:valid.cegs final.raw\n";

     bool binary_write = true;
+    int32 max_objective_evaluations = 30;
     bool batchnorm_test_mode = false,
         dropout_test_mode = true;
     std::string use_gpu = "yes";
-    NnetCombineConfig combine_config;
     chain::ChainTrainingOptions chain_config;

     ParseOptions po(usage);
     po.Register("binary", &binary_write, "Write output in binary mode");
+    po.Register("max-objective-evaluations", &max_objective_evaluations, "The "
+                "maximum number of objective evaluations in order to figure "
+                "out the best number of models to combine. It helps to speed "
+                "up if the number of models provided to this binary is quite "
+                "large (e.g. several hundred).");
     po.Register("use-gpu", &use_gpu,
                 "yes|no|optional|wait, only has effect if compiled with CUDA");
     po.Register("batchnorm-test-mode", &batchnorm_test_mode,
-                "If true, set test-mode to true on any BatchNormComponents.");
+                "If true, set test-mode to true on any BatchNormComponents "
+                "while evaluating objectives.");
     po.Register("dropout-test-mode", &dropout_test_mode,
                 "If true, set test-mode to true on any DropoutComponents and "
-                "DropoutMaskComponents.");
+                "DropoutMaskComponents while evaluating objectives.");

-    combine_config.Register(&po);
     chain_config.Register(&po);

     po.Read(argc, argv);
@@ -83,11 +149,10 @@ int main(int argc, char *argv[]) {

     Nnet nnet;
     ReadKaldiObject(raw_nnet_rxfilename, &nnet);
-
-    if (batchnorm_test_mode)
-      SetBatchnormTestMode(true, &nnet);
-    if (dropout_test_mode)
-      SetDropoutTestMode(true, &nnet);
+    Nnet moving_average_nnet(nnet), best_nnet(nnet);
+    NnetComputeProbOptions compute_prob_opts;
+    NnetChainComputeProb prob_computer(compute_prob_opts, chain_config,
+                                       den_fst, moving_average_nnet);

     std::vector<NnetChainExample> egs;
     egs.reserve(10000);  // reserve a lot of space to minimize the chance of
@@ -102,29 +167,50 @@ int main(int argc, char *argv[]) {
       KALDI_ASSERT(!egs.empty());
     }

+    // first evaluates the objective using the last model.
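UpdateNnetMovingAverage above keeps a running mean: after n models the average satisfies avg_n = avg_{n-1} * (n-1)/n + x_n/n, so no list of models ever needs to be held in memory. The main loop that follows then periodically scores the average and keeps the best-scoring prefix. Here is a minimal sketch of the whole scheme with plain vectors and a toy objective standing in for the chain objective (all names are illustrative, not Kaldi's); unlike the real binary, which only probes every num_to_add models, this sketch evaluates at every step:

    #include <cstdio>
    #include <vector>

    using Params = std::vector<double>;

    // avg_n = avg_{n-1} * (n-1)/n + x_n / n: the same update as
    // UpdateNnetMovingAverage, with ScaleNnet/AddNnet inlined.
    void UpdateMovingAverage(int n, const Params &x, Params *avg) {
      for (size_t i = 0; i < avg->size(); i++)
        (*avg)[i] = (*avg)[i] * (n - 1.0) / n + x[i] / n;
    }

    // Toy objective: peaks at p[0] == 1.
    double objf(const Params &p) { return -(p[0] - 1.0) * (p[0] - 1.0); }

    int main() {
      // Models listed newest first, as the scripts pass them to the binary.
      std::vector<Params> models = {{2.0}, {1.2}, {0.4}, {0.9}};
      Params avg = models[0], best = avg;
      double best_objf = objf(avg);
      int best_n = 1;
      for (int n = 1; n < (int)models.size(); n++) {
        UpdateMovingAverage(n + 1, models[n], &avg);
        double o = objf(avg);
        if (o > best_objf) { best_objf = o; best = avg; best_n = n + 1; }
      }
      std::printf("best: average of last %d models, objf %.4f\n",
                  best_n, best_objf);
      return 0;
    }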
+    int32 best_num_to_combine = 1;
+    double
+        init_objf = ComputeObjf(batchnorm_test_mode, dropout_test_mode,
+            egs, moving_average_nnet, chain_config, den_fst, &prob_computer),
+        best_objf = init_objf;
+    KALDI_LOG << "objective function using the last model is " << init_objf;

     int32 num_nnets = po.NumArgs() - 3;
-    NnetChainCombiner combiner(combine_config, chain_config,
-                               num_nnets, egs, den_fst, nnet);
-
+    // then each time before we re-evaluate the objective function, we will add
+    // num_to_add models to the moving average.
+    int32 num_to_add = (num_nnets + max_objective_evaluations - 1) /
+                       max_objective_evaluations;
     for (int32 n = 1; n < num_nnets; n++) {
       std::string this_nnet_rxfilename = po.GetArg(n + 2);
       ReadKaldiObject(this_nnet_rxfilename, &nnet);
-      combiner.AcceptNnet(nnet);
+      // updates the moving average
+      UpdateNnetMovingAverage(n + 1, nnet, &moving_average_nnet);
+      // evaluates the objective every time after adding num_to_add models or
+      // all the models to the moving average.
+      if ((n - 1) % num_to_add == num_to_add - 1 || n == num_nnets - 1) {
+        double objf = ComputeObjf(batchnorm_test_mode, dropout_test_mode,
+            egs, moving_average_nnet, chain_config, den_fst, &prob_computer);
+        KALDI_LOG << "Combining last " << n + 1
+                  << " models, objective function is " << objf;
+        if (objf > best_objf) {
+          best_objf = objf;
+          best_nnet = moving_average_nnet;
+          best_num_to_combine = n + 1;
+        }
+      }
     }
+    KALDI_LOG << "Combining " << best_num_to_combine
+              << " nnets, objective function changed from " << init_objf
+              << " to " << best_objf;

-    combiner.Combine();
-
-    nnet = combiner.GetNnet();
     if (HasBatchnorm(nnet))
-      RecomputeStats(egs, chain_config, den_fst, &nnet);
+      RecomputeStats(egs, chain_config, den_fst, &best_nnet);

 #if HAVE_CUDA==1
     CuDevice::Instantiate().PrintProfile();
 #endif

-    WriteKaldiObject(nnet, nnet_wxfilename, binary_write);
-
+    WriteKaldiObject(best_nnet, nnet_wxfilename, binary_write);
     KALDI_LOG << "Finished combining neural nets, wrote model to "
               << nnet_wxfilename;
   } catch(const std::exception &e) {
diff --git a/src/nnet3/Makefile b/src/nnet3/Makefile
index 3236c52d60f..8ddba56b0e0 100644
--- a/src/nnet3/Makefile
+++ b/src/nnet3/Makefile
@@ -22,9 +22,9 @@ OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \
            nnet-example.o nnet-nnet.o nnet-compile-utils.o \
            nnet-utils.o nnet-compute.o nnet-test-utils.o nnet-analyze.o \
            nnet-example-utils.o nnet-training.o \
-           nnet-diagnostics.o nnet-combine.o nnet-am-decodable-simple.o \
+           nnet-diagnostics.o nnet-am-decodable-simple.o \
            nnet-optimize-utils.o nnet-chain-example.o \
-           nnet-chain-training.o nnet-chain-diagnostics.o nnet-chain-combine.o \
+           nnet-chain-training.o nnet-chain-diagnostics.o \
            discriminative-supervision.o nnet-discriminative-example.o \
            nnet-discriminative-diagnostics.o \
            discriminative-training.o nnet-discriminative-training.o \
diff --git a/src/nnet3/nnet-chain-combine.cc b/src/nnet3/nnet-chain-combine.cc
deleted file mode 100644
index c93858fb06e..00000000000
--- a/src/nnet3/nnet-chain-combine.cc
+++ /dev/null
@@ -1,610 +0,0 @@
-// nnet3/nnet-chain-combine.cc
-
-// Copyright 2012-2015  Johns Hopkins University (author: Daniel Povey)
-
-// See ../../COPYING for clarification regarding multiple authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include "nnet3/nnet-chain-combine.h" -#include "nnet3/nnet-utils.h" - -namespace kaldi { -namespace nnet3 { - -NnetChainCombiner::NnetChainCombiner(const NnetCombineConfig &combine_config, - const chain::ChainTrainingOptions &chain_config, - int32 num_nnets, - const std::vector &egs, - const fst::StdVectorFst &den_fst, - const Nnet &first_nnet): - combine_config_(combine_config), - chain_config_(chain_config), - egs_(egs), - den_fst_(den_fst), - nnet_(first_nnet), - num_real_input_nnets_(num_nnets), - nnet_params_(std::min(num_nnets, combine_config_.max_effective_inputs), - NumParameters(first_nnet)), - tot_input_weighting_(nnet_params_.NumRows()) { - - if (combine_config_.sum_to_one_penalty != 0.0 && - combine_config_.enforce_sum_to_one) { - KALDI_WARN << "--sum-to-one-penalty=" << combine_config_.sum_to_one_penalty - << " is nonzero, so setting --enforce-sum-to-one=false."; - combine_config_.enforce_sum_to_one = false; - } - SubVector first_params(nnet_params_, 0); - VectorizeNnet(nnet_, &first_params); - tot_input_weighting_(0) += 1.0; - num_nnets_provided_ = 1; - ComputeUpdatableComponentDims(); - NnetComputeProbOptions compute_prob_opts; - compute_prob_opts.compute_deriv = true; - prob_computer_ = new NnetChainComputeProb(compute_prob_opts, chain_config_, den_fst_, nnet_); -} - -void NnetChainCombiner::ComputeUpdatableComponentDims(){ - updatable_component_dims_.clear(); - for (int32 c = 0; c < nnet_.NumComponents(); c++) { - Component *comp = nnet_.GetComponent(c); - if (comp->Properties() & kUpdatableComponent) { - // For now all updatable components inherit from class UpdatableComponent. - // If that changes in future, we will change this code. - UpdatableComponent *uc = dynamic_cast(comp); - if (uc == NULL) - KALDI_ERR << "Updatable component does not inherit from class " - "UpdatableComponent; change this code."; - updatable_component_dims_.push_back(uc->NumParameters()); - } - } -} - -void NnetChainCombiner::AcceptNnet(const Nnet &nnet) { - KALDI_ASSERT(num_nnets_provided_ < num_real_input_nnets_ && - "You called AcceptNnet too many times."); - int32 num_effective_nnets = nnet_params_.NumRows(); - if (num_effective_nnets == num_real_input_nnets_) { - SubVector this_params(nnet_params_, num_nnets_provided_); - VectorizeNnet(nnet, &this_params); - tot_input_weighting_(num_nnets_provided_) += 1.0; - } else { - // this_index is a kind of warped index, mapping the range - // 0 ... num_real_inputs_nnets_ - 1 onto the range - // 0 ... num_effective_nnets - 1. View the index as falling in - // between two integer indexes and determining weighting factors. - // we could view this as triangular bins. 
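The "triangular bins" mentioned in the deleted AcceptNnet above work as follows: when more input nnets arrive than rows are stored (num_real_input_nnets_ greater than max_effective_inputs), input i of N is placed at the fractional row position i*(M-1)/(N-1) and split between the two neighbouring rows with linear weights; each row is later rescaled by its accumulated weight so that it becomes a weighted average of the inputs that fell into it. A standalone sketch of just that index arithmetic (variable names illustrative, not Kaldi's):

    #include <cmath>
    #include <cstdio>

    int main() {
      int N = 7;  // real input nnets
      int M = 4;  // effective rows actually stored
      double row_weight[4] = {0, 0, 0, 0};
      for (int i = 0; i < N; i++) {
        // Fractional position of input i among the M rows.
        double pos = i * (M - 1) / static_cast<double>(N - 1);
        int lo = static_cast<int>(std::floor(pos));
        double upper = pos - lo, lower = 1.0 - upper;
        row_weight[lo] += lower;
        if (upper > 0.0 && lo + 1 < M) {
          row_weight[lo + 1] += upper;
          std::printf("input %d -> rows %d (w=%.2f) and %d (w=%.2f)\n",
                      i, lo, lower, lo + 1, upper);
        } else {
          std::printf("input %d -> row %d (w=%.2f)\n", i, lo, lower);
        }
      }
      // FinishPreprocessingInput() divides each row by this total, turning
      // the weighted sums into weighted averages.
      for (int r = 0; r < M; r++)
        std::printf("row %d total weight %.2f\n", r, row_weight[r]);
      return 0;
    }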
- BaseFloat this_index = num_nnets_provided_ * (num_effective_nnets - 1) - / static_cast(num_real_input_nnets_ - 1); - int32 lower_index = std::floor(this_index), - upper_index = lower_index + 1; - BaseFloat remaining_part = this_index - lower_index, - lower_weight = 1.0 - remaining_part, - upper_weight = remaining_part; - KALDI_ASSERT(lower_index >= 0 && upper_index <= num_effective_nnets && - lower_weight >= 0.0 && upper_weight >= 0.0 && - lower_weight <= 1.0 && upper_weight <= 1.0); - Vector vec(nnet_params_.NumCols(), kUndefined); - VectorizeNnet(nnet, &vec); - nnet_params_.Row(lower_index).AddVec(lower_weight, vec); - tot_input_weighting_(lower_index) += lower_weight; - if (upper_index == num_effective_nnets) { - KALDI_ASSERT(upper_weight < 0.1); - } else { - nnet_params_.Row(upper_index).AddVec(upper_weight, vec); - tot_input_weighting_(upper_index) += upper_weight; - } - } - num_nnets_provided_++; -} - -void NnetChainCombiner::FinishPreprocessingInput() { - KALDI_ASSERT(num_nnets_provided_ == num_real_input_nnets_ && - "You did not call AcceptInput() enough times."); - int32 num_effective_nnets = nnet_params_.NumRows(); - for (int32 i = 0; i < num_effective_nnets; i++) { - BaseFloat tot_weight = tot_input_weighting_(i); - KALDI_ASSERT(tot_weight > 0.0); // Or would be a coding error. - // Rescale so this row is like a weighted average instead of - // a weighted sum. - if (tot_weight != 1.0) - nnet_params_.Row(i).Scale(1.0 / tot_weight); - } -} - -void NnetChainCombiner::Combine() { - FinishPreprocessingInput(); - - if (!SelfTestDerivatives()) { - KALDI_LOG << "Self-testing model derivatives since parameter-derivatives " - "self-test failed."; - SelfTestModelDerivatives(); - } - - int32 dim = ParameterDim(); - LbfgsOptions lbfgs_options; - lbfgs_options.minimize = false; // We're maximizing. - lbfgs_options.m = dim; // Store the same number of vectors as the dimension - // itself, so this is BFGS. - lbfgs_options.first_step_impr = combine_config_.initial_impr; - - Vector params(dim), deriv(dim); - double objf, initial_objf; - GetInitialParameters(¶ms); - - - OptimizeLbfgs lbfgs(params, lbfgs_options); - - for (int32 i = 0; i < combine_config_.num_iters; i++) { - params.CopyFromVec(lbfgs.GetProposedValue()); - objf = ComputeObjfAndDerivFromParameters(params, &deriv); - KALDI_VLOG(2) << "Iteration " << i << " params = " << params - << ", objf = " << objf << ", deriv = " << deriv; - if (i == 0) initial_objf = objf; - lbfgs.DoStep(objf, deriv); - } - - if (!combine_config_.sum_to_one_penalty) { - KALDI_LOG << "Combining nnets, objective function changed from " - << initial_objf << " to " << objf; - } else { - Vector weights(WeightDim()); - GetWeights(params, &weights); - bool print_weights = true; - double penalty = GetSumToOnePenalty(weights, NULL, print_weights); - // note: initial_objf has no penalty term because it summed exactly - // to one. - KALDI_LOG << "Combining nnets, objective function changed from " - << initial_objf << " to " << objf << " = " - << (objf - penalty) << " + " << penalty; - } - - - // must recompute nnet_ if "params" is not exactly equal to the - // final params that LB - Vector final_params(dim); - final_params.CopyFromVec(lbfgs.GetValue(&objf)); - if (!params.ApproxEqual(final_params, 0.0)) { - // the following call makes sure that nnet_ corresponds to the parameters - // in "params". 
- ComputeObjfAndDerivFromParameters(final_params, &deriv); - } - PrintParams(final_params); -} - - -void NnetChainCombiner::PrintParams(const VectorBase ¶ms) const { - Vector weights(WeightDim()), normalized_weights(WeightDim()); - GetWeights(params, &weights); - GetNormalizedWeights(weights, &normalized_weights); - int32 num_models = nnet_params_.NumRows(), - num_uc = NumUpdatableComponents(); - - if (combine_config_.separate_weights_per_component) { - std::vector updatable_component_names; - for (int32 c = 0; c < nnet_.NumComponents(); c++) { - const Component *comp = nnet_.GetComponent(c); - if (comp->Properties() & kUpdatableComponent) - updatable_component_names.push_back(nnet_.GetComponentName(c)); - } - KALDI_ASSERT(static_cast(updatable_component_names.size()) == - NumUpdatableComponents()); - for (int32 uc = 0; uc < num_uc; uc++) { - std::ostringstream os; - os.width(20); - os << std::left << updatable_component_names[uc] << ": "; - os.width(9); - os.precision(4); - for (int32 m = 0; m < num_models; m++) { - int32 index = m * num_uc + uc; - os << " " << std::left << normalized_weights(index); - } - KALDI_LOG << "Weights for " << os.str(); - } - } else { - int32 c = 0; // arbitrarily chosen; they'll all be the same. - std::ostringstream os; - os.width(9); - os.precision(4); - for (int32 m = 0; m < num_models; m++) { - int32 index = m * num_uc + c; - os << " " << std::left << normalized_weights(index); - } - KALDI_LOG << "Model weights are: " << os.str(); - } - int32 num_effective_nnets = nnet_params_.NumRows(); - if (num_effective_nnets != num_real_input_nnets_) - KALDI_LOG << "Above, only " << num_effective_nnets << " weights were " - "printed due to the the --num-effective-nnets option; " - "there were " << num_real_input_nnets_ << " actual input nnets. " - "Each weight corresponds to a weighted average over a range of " - "nnets in the sequence (with triangular bins)"; -} - -bool NnetChainCombiner::SelfTestDerivatives() { - int32 num_tests = 2; // more properly, this is the number of dimensions in a - // single test. - double delta = 0.001; - int32 dim = ParameterDim(); - - Vector params(dim), deriv(dim); - Vector predicted_changes(num_tests), - observed_changes(num_tests); - - GetInitialParameters(¶ms); - double initial_objf = ComputeObjfAndDerivFromParameters(params, - &deriv); - for (int32 i = 0; i < num_tests; i++) { - Vector new_deriv(dim), offset(dim), new_params(params); - offset.SetRandn(); - new_params.AddVec(delta, offset); - double new_objf = ComputeObjfAndDerivFromParameters(new_params, - &new_deriv); - // for predicted changes, interpolate old and new derivs. - predicted_changes(i) = - 0.5 * VecVec(new_params, deriv) - 0.5 * VecVec(params, deriv) + - 0.5 * VecVec(new_params, new_deriv) - 0.5 * VecVec(params, new_deriv); - observed_changes(i) = new_objf - initial_objf; - } - double threshold = 0.1; - KALDI_LOG << "predicted_changes = " << predicted_changes; - KALDI_LOG << "observed_changes = " << observed_changes; - if (!ApproxEqual(predicted_changes, observed_changes, threshold)) { - KALDI_WARN << "Derivatives self-test failed."; - return false; - } else { - return true; - } -} - - -void NnetChainCombiner::SelfTestModelDerivatives() { - int32 num_tests = 3; // more properly, this is the number of dimensions in a - // single test. 
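The self-test in the deleted code above follows a standard pattern for checking analytic gradients: perturb the parameters, predict the objective change as the inner product of the perturbation with the average of the old and new gradients, and compare against the observed change. A toy version with a quadratic objective, for which the interpolated prediction is exact (a sketch, not Kaldi code):

    #include <cstdio>
    #include <vector>

    // Toy objective f(x) = -sum x_i^2, gradient -2x.
    static double Objf(const std::vector<double> &x, std::vector<double> *g) {
      double f = 0;
      g->resize(x.size());
      for (size_t i = 0; i < x.size(); i++) {
        f -= x[i] * x[i];
        (*g)[i] = -2.0 * x[i];
      }
      return f;
    }

    int main() {
      std::vector<double> x = {0.3, -1.1, 2.0}, g, g_new;
      double f = Objf(x, &g);
      double delta = 0.001;
      std::vector<double> x_new = x;
      for (size_t i = 0; i < x.size(); i++) x_new[i] += delta;  // perturb
      double f_new = Objf(x_new, &g_new);
      // Predicted change: average of old/new gradients dotted with the step.
      double predicted = 0.0;
      for (size_t i = 0; i < x.size(); i++)
        predicted += 0.5 * (g[i] + g_new[i]) * (x_new[i] - x[i]);
      std::printf("predicted %.6g observed %.6g\n", predicted, f_new - f);
      return 0;  // the two agree closely when the gradient code is right
    }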
- int32 dim = ParameterDim(); - - Vector params(dim), deriv(dim); - Vector predicted_changes(num_tests), - observed_changes(num_tests); - - GetInitialParameters(¶ms); - Vector weights(WeightDim()), normalized_weights(WeightDim()); - Vector nnet_params(NnetParameterDim(), kUndefined), - nnet_deriv(NnetParameterDim(), kUndefined); - GetWeights(params, &weights); - GetNormalizedWeights(weights, &normalized_weights); - GetNnetParameters(normalized_weights, &nnet_params); - - double initial_objf = ComputeObjfAndDerivFromNnet(nnet_params, - &nnet_deriv); - - double delta = 0.002 * std::sqrt(VecVec(nnet_params, nnet_params) / - NnetParameterDim()); - - - for (int32 i = 0; i < num_tests; i++) { - Vector new_nnet_deriv(NnetParameterDim()), - offset(NnetParameterDim()), new_nnet_params(nnet_params); - offset.SetRandn(); - new_nnet_params.AddVec(delta, offset); - double new_objf = ComputeObjfAndDerivFromNnet(new_nnet_params, - &new_nnet_deriv); - // for predicted changes, interpolate old and new derivs. - predicted_changes(i) = - 0.5 * VecVec(new_nnet_params, nnet_deriv) - - 0.5 * VecVec(nnet_params, nnet_deriv) + - 0.5 * VecVec(new_nnet_params, new_nnet_deriv) - - 0.5 * VecVec(nnet_params, new_nnet_deriv); - observed_changes(i) = new_objf - initial_objf; - } - double threshold = 0.1; - KALDI_LOG << "model-derivatives: predicted_changes = " << predicted_changes; - KALDI_LOG << "model-derivatives: observed_changes = " << observed_changes; - if (!ApproxEqual(predicted_changes, observed_changes, threshold)) - KALDI_WARN << "Model derivatives self-test failed."; -} - - - - -int32 NnetChainCombiner::ParameterDim() const { - if (combine_config_.separate_weights_per_component) - return NumUpdatableComponents() * nnet_params_.NumRows(); - else - return nnet_params_.NumRows(); -} - - -void NnetChainCombiner::GetInitialParameters(VectorBase *params) const { - KALDI_ASSERT(params->Dim() == ParameterDim()); - params->Set(1.0 / nnet_params_.NumRows()); - if (combine_config_.enforce_positive_weights) { - // we enforce positive weights by treating the params as the log of the - // actual weight. - params->ApplyLog(); - } -} - -void NnetChainCombiner::GetWeights(const VectorBase ¶ms, - VectorBase *weights) const { - KALDI_ASSERT(weights->Dim() == WeightDim()); - if (combine_config_.separate_weights_per_component) { - weights->CopyFromVec(params); - } else { - int32 nc = NumUpdatableComponents(); - // have one parameter per row of nnet_params_, and need to repeat - // the weight for the different components. - for (int32 n = 0; n < nnet_params_.NumRows(); n++) { - for (int32 c = 0; c < nc; c++) - (*weights)(n * nc + c) = params(n); - } - } - // we enforce positive weights by having the weights be the exponential of the - // corresponding parameters. - if (combine_config_.enforce_positive_weights) - weights->ApplyExp(); -} - - -void NnetChainCombiner::GetParamsDeriv(const VectorBase &weights, - const VectorBase &weights_deriv, - VectorBase *param_deriv) { - KALDI_ASSERT(weights.Dim() == WeightDim() && - param_deriv->Dim() == ParameterDim()); - Vector preexp_weights_deriv(weights_deriv); - if (combine_config_.enforce_positive_weights) { - // to enforce positive weights we first compute weights (call these - // preexp_weights) and then take exponential. Note, d/dx exp(x) = exp(x). - // So the derivative w.r.t. the preexp_weights equals the derivative - // w.r.t. the weights, times the weights. 
- preexp_weights_deriv.MulElements(weights); - } - if (combine_config_.separate_weights_per_component) { - param_deriv->CopyFromVec(preexp_weights_deriv); - } else { - int32 nc = NumUpdatableComponents(); - param_deriv->SetZero(); - for (int32 n = 0; n < nnet_params_.NumRows(); n++) - for (int32 c = 0; c < nc; c++) - (*param_deriv)(n) += preexp_weights_deriv(n * nc + c); - } -} - -double NnetChainCombiner::GetSumToOnePenalty( - const VectorBase &weights, - VectorBase *weights_penalty_deriv, - bool print_weights) const { - - KALDI_ASSERT(combine_config_.sum_to_one_penalty >= 0.0); - double penalty = combine_config_.sum_to_one_penalty; - if (penalty == 0.0) { - weights_penalty_deriv->SetZero(); - return 0.0; - } - double ans = 0.0; - int32 num_uc = NumUpdatableComponents(), - num_models = nnet_params_.NumRows(); - Vector tot_weights(num_uc); - std::ostringstream tot_weight_info; - for (int32 c = 0; c < num_uc; c++) { - double this_total_weight = 0.0; - for (int32 m = 0; m < num_models; m++) { - int32 index = m * num_uc + c; - double this_weight = weights(index); - this_total_weight += this_weight; - } - tot_weights(c) = this_total_weight; - // this_total_weight_deriv is the derivative of the penalty - // term w.r.t. this component's total weight. - double this_total_weight_deriv; - if (combine_config_.enforce_positive_weights) { - // if combine_config_.enforce_positive_weights is true, then we choose to - // formulate the penalty in a slightly different way.. this solves the - // problem that with the formulation in the 'else' below, if for some - // reason the total weight is << 1.0, the deriv w.r.t. the actual - // parameters gets tiny [because weight = exp(params)]. - double log_total = log(this_total_weight); - ans += -0.5 * penalty * log_total * log_total; - double log_total_deriv = -1.0 * penalty * log_total; - this_total_weight_deriv = log_total_deriv / this_total_weight; - } else { - ans += -0.5 * penalty * - (this_total_weight - 1.0) * (this_total_weight - 1.0); - this_total_weight_deriv = penalty * (1.0 - this_total_weight); - - } - if (weights_penalty_deriv != NULL) { - KALDI_ASSERT(weights.Dim() == weights_penalty_deriv->Dim()); - for (int32 m = 0; m < num_models; m++) { - int32 index = m * num_uc + c; - (*weights_penalty_deriv)(index) = this_total_weight_deriv; - } - } - } - if (print_weights) { - Vector tot_weights_float(tot_weights); - KALDI_LOG << "Total weights per component: " - << PrintVectorPerUpdatableComponent(nnet_, - tot_weights_float); - } - return ans; -} - -void NnetChainCombiner::GetNnetParameters(const Vector &weights, - VectorBase *nnet_params) const { - KALDI_ASSERT(nnet_params->Dim() == nnet_params_.NumCols()); - nnet_params->SetZero(); - int32 num_uc = NumUpdatableComponents(), - num_models = nnet_params_.NumRows(); - for (int32 m = 0; m < num_models; m++) { - const SubVector src_params(nnet_params_, m); - int32 dim_offset = 0; - for (int32 c = 0; c < num_uc; c++) { - int32 index = m * num_uc + c; - BaseFloat weight = weights(index); - int32 dim = updatable_component_dims_[c]; - const SubVector src_component_params(src_params, dim_offset, - dim); - SubVector dest_component_params(*nnet_params, dim_offset, dim); - dest_component_params.AddVec(weight, src_component_params); - dim_offset += dim; - } - KALDI_ASSERT(dim_offset == nnet_params_.NumCols()); - } -} - -// compare GetNnetParameters. 
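GetNnetParameters above forms the combined parameter vector component by component: model m's slice for component c is scaled by the weight at flat index m * num_components + c and accumulated into the output. A small sketch of that indexing with hard-coded dimensions (all values illustrative):

    #include <cstdio>
    #include <vector>

    int main() {
      const int num_models = 2, num_comp = 2;
      const int comp_dim[num_comp] = {2, 3};  // dims of the two components
      // One flattened parameter vector per model (total dim 2 + 3 = 5).
      double params[num_models][5] = {{1, 1, 1, 1, 1}, {3, 3, 5, 5, 5}};
      // Per-model, per-component weights, indexed m * num_comp + c.
      double w[num_models * num_comp] = {0.5, 0.25, 0.5, 0.75};
      std::vector<double> out(5, 0.0);
      for (int m = 0; m < num_models; m++) {
        int offset = 0;
        for (int c = 0; c < num_comp; c++) {
          for (int d = 0; d < comp_dim[c]; d++)
            out[offset + d] += w[m * num_comp + c] * params[m][offset + d];
          offset += comp_dim[c];
        }
      }
      for (double v : out) std::printf("%.2f ", v);  // 2.00 2.00 4.00 4.00 4.00
      std::printf("\n");
      return 0;
    }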
-void NnetChainCombiner::GetWeightsDeriv( - const VectorBase &nnet_params_deriv, - VectorBase *weights_deriv) { - KALDI_ASSERT(nnet_params_deriv.Dim() == nnet_params_.NumCols() && - weights_deriv->Dim() == WeightDim()); - int32 num_uc = NumUpdatableComponents(), - num_models = nnet_params_.NumRows(); - for (int32 m = 0; m < num_models; m++) { - const SubVector src_params(nnet_params_, m); - int32 dim_offset = 0; - for (int32 c = 0; c < num_uc; c++) { - int32 index = m * num_uc + c; - int32 dim = updatable_component_dims_[c]; - const SubVector src_component_params(src_params, dim_offset, - dim); - const SubVector component_params_deriv(nnet_params_deriv, - dim_offset, dim); - (*weights_deriv)(index) = VecVec(src_component_params, - component_params_deriv); - dim_offset += dim; - } - KALDI_ASSERT(dim_offset == nnet_params_.NumCols()); - } -} - -double NnetChainCombiner::ComputeObjfAndDerivFromNnet( - VectorBase &nnet_params, - VectorBase *nnet_params_deriv) { - BaseFloat sum = nnet_params.Sum(); - // inf/nan parameters->return -inf objective. - if (!(sum == sum && sum - sum == 0)) - return -std::numeric_limits::infinity(); - // Set nnet to have these params. - UnVectorizeNnet(nnet_params, &nnet_); - - prob_computer_->Reset(); - std::vector::const_iterator iter = egs_.begin(), - end = egs_.end(); - for (; iter != end; ++iter) - prob_computer_->Compute(*iter); - const ChainObjectiveInfo *objf_info = - prob_computer_->GetObjective("output"); - if (objf_info == NULL) - KALDI_ERR << "Error getting objective info (unsuitable egs?)"; - KALDI_ASSERT(objf_info->tot_weight > 0.0); - const Nnet &deriv = prob_computer_->GetDeriv(); - VectorizeNnet(deriv, nnet_params_deriv); - // we prefer to deal with normalized objective functions. - nnet_params_deriv->Scale(1.0 / objf_info->tot_weight); - return (objf_info->tot_like + objf_info->tot_l2_term) / objf_info->tot_weight; -} - - -double NnetChainCombiner::ComputeObjfAndDerivFromParameters( - VectorBase ¶ms, - VectorBase *params_deriv) { - Vector weights(WeightDim()), normalized_weights(WeightDim()), - weights_sum_to_one_penalty_deriv(WeightDim()), - normalized_weights_deriv(WeightDim()), weights_deriv(WeightDim()); - Vector - nnet_params(NnetParameterDim(), kUndefined), - nnet_params_deriv(NnetParameterDim(), kUndefined); - GetWeights(params, &weights); - double ans = GetSumToOnePenalty(weights, &weights_sum_to_one_penalty_deriv); - GetNormalizedWeights(weights, &normalized_weights); - GetNnetParameters(normalized_weights, &nnet_params); - ans += ComputeObjfAndDerivFromNnet(nnet_params, &nnet_params_deriv); - if (ans != ans || ans - ans != 0) // NaN or inf - return ans; // No point computing derivative - GetWeightsDeriv(nnet_params_deriv, &normalized_weights_deriv); - GetUnnormalizedWeightsDeriv(weights, normalized_weights_deriv, - &weights_deriv); - weights_deriv.AddVec(1.0, weights_sum_to_one_penalty_deriv); - GetParamsDeriv(weights, weights_deriv, params_deriv); - return ans; -} - - -// enforces the constraint that the weights for each component must sum to one, -// if necessary. 
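GetNormalizedWeights and GetUnnormalizedWeightsDeriv below implement that normalization and its backprop for one component's weights: with w_i = u_i / s and s = sum_j u_j, the chain rule gives dL/du_i = (dL/dw_i - sum_j dL/dw_j * w_j) / s, which is exactly what the inv_sum / sum_deriv bookkeeping in the deleted code computes. A numeric sketch of the same formula (illustrative values, not Kaldi code):

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<double> u = {1.0, 3.0, 4.0};    // unnormalized weights
      std::vector<double> gw = {0.2, -0.1, 0.5};  // upstream grad dL/dw
      double s = 0;
      for (double x : u) s += x;                  // s = 8
      std::vector<double> w(u.size());
      for (size_t i = 0; i < u.size(); i++) w[i] = u[i] / s;
      double dot = 0;                             // sum_j dL/dw_j * w_j
      for (size_t i = 0; i < u.size(); i++) dot += gw[i] * w[i];
      for (size_t i = 0; i < u.size(); i++) {
        double gu = (gw[i] - dot) / s;            // chain rule through 1/s
        std::printf("w[%zu]=%.3f  dL/du[%zu]=%.5f\n", i, w[i], i, gu);
      }
      return 0;
    }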
-void NnetChainCombiner::GetNormalizedWeights( - const VectorBase &unnorm_weights, - VectorBase *norm_weights) const { - if (!combine_config_.enforce_sum_to_one) { - norm_weights->CopyFromVec(unnorm_weights); - return; - } - int32 num_uc = NumUpdatableComponents(), - num_models = nnet_params_.NumRows(); - for (int32 c = 0; c < num_uc; c++) { - double sum = 0.0; - for (int32 m = 0; m < num_models; m++) { - int32 index = m * num_uc + c; - sum += unnorm_weights(index); - } - double inv_sum = 1.0 / sum; // if it's NaN then it's OK, we'll get NaN - // weights and eventually -inf objective. - for (int32 m = 0; m < num_models; m++) { - int32 index = m * num_uc + c; - (*norm_weights)(index) = unnorm_weights(index) * inv_sum; - } - } -} - -void NnetChainCombiner::GetUnnormalizedWeightsDeriv( - const VectorBase &unnorm_weights, - const VectorBase &norm_weights_deriv, - VectorBase *unnorm_weights_deriv) { - if (!combine_config_.enforce_sum_to_one) { - unnorm_weights_deriv->CopyFromVec(norm_weights_deriv); - return; - } - int32 num_uc = NumUpdatableComponents(), - num_models = nnet_params_.NumRows(); - for (int32 c = 0; c < num_uc; c++) { - double sum = 0.0; - for (int32 m = 0; m < num_models; m++) { - int32 index = m * num_uc + c; - sum += unnorm_weights(index); - } - double inv_sum = 1.0 / sum; - double inv_sum_deriv = 0.0; - for (int32 m = 0; m < num_models; m++) { - int32 index = m * num_uc + c; - // in the forward direction, we'd do: - // (*norm_weights)(index) = unnorm_weights(index) * inv_sum; - (*unnorm_weights_deriv)(index) = inv_sum * norm_weights_deriv(index); - inv_sum_deriv += norm_weights_deriv(index) * unnorm_weights(index); - } - // note: d/dx (1/x) = -1/x^2 - double sum_deriv = -1.0 * inv_sum_deriv * inv_sum * inv_sum; - for (int32 m = 0; m < num_models; m++) { - int32 index = m * num_uc + c; - (*unnorm_weights_deriv)(index) += sum_deriv; - } - } -} - - - - -} // namespace nnet3 -} // namespace kaldi diff --git a/src/nnet3/nnet-chain-combine.h b/src/nnet3/nnet-chain-combine.h deleted file mode 100644 index 3aeb3882650..00000000000 --- a/src/nnet3/nnet-chain-combine.h +++ /dev/null @@ -1,205 +0,0 @@ -// nnet3/nnet-chain-combine.h - -// Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_NNET3_NNET_CHAIN_COMBINE_H_ -#define KALDI_NNET3_NNET_CHAIN_COMBINE_H_ - -#include "nnet3/nnet-utils.h" -#include "nnet3/nnet-compute.h" -#include "util/parse-options.h" -#include "itf/options-itf.h" -#include "nnet3/nnet-combine.h" -#include "nnet3/nnet-chain-diagnostics.h" - - -namespace kaldi { -namespace nnet3 { - -// we re-use NnetCombineConfig from nnet-combine.h - -/* - You should use this class as follows: - - Call the constructor, giving it the egs and the first nnet. - - Call AcceptNnet to provide all the other nnets. 
-     (the nnets will
-     be stored in a matrix in CPU memory, to avoid filling up GPU memory).
-   - Call Combine()
-   - Get the resultant nnet with GetNnet().
- */
-class NnetChainCombiner {
- public:
-  /// Caution: this object retains a const reference to the "egs", so don't
-  /// delete them until it goes out of scope.
-  NnetChainCombiner(const NnetCombineConfig &nnet_config,
-                    const chain::ChainTrainingOptions &chain_config,
-                    int32 num_nnets,
-                    const std::vector<NnetChainExample> &egs,
-                    const fst::StdVectorFst &den_fst,
-                    const Nnet &first_nnet);
-
-  /// You should call this function num_nnets-1 times after calling
-  /// the constructor, to provide the remaining nnets.
-  void AcceptNnet(const Nnet &nnet);
-
-  void Combine();
-
-  const Nnet &GetNnet() const { return nnet_; }
-
-  ~NnetChainCombiner() { delete prob_computer_; }
- private:
-  NnetCombineConfig combine_config_;
-  const chain::ChainTrainingOptions &chain_config_;
-
-  const std::vector<NnetChainExample> &egs_;
-
-  const fst::StdVectorFst &den_fst_;
-
-  Nnet nnet_;  // The current neural network.
-
-  NnetChainComputeProb *prob_computer_;
-
-  std::vector<int32> updatable_component_dims_;  // dimension of each updatable
-                                                 // component.
-
-  int32 num_real_input_nnets_;  // number of actual nnet inputs.
-
-  int32 num_nnets_provided_;  // keeps track of the number of calls to AcceptNnet().
-
-  // nnet_params_ are the parameters of the "effective input"
-  // neural nets; they will often be the same as the real inputs,
-  // but if num_real_input_nnets_ > config_.num_effective_nnets, they
-  // will be weighted combinations.
-  Matrix<BaseFloat> nnet_params_;
-
-  // This vector has the same dimension as nnet_params_.NumRows(),
-  // and helps us normalize so each row of nnet_params corresponds to
-  // a weighted average of its inputs (will be all ones if
-  // config_.max_effective_inputs >= the number of nnets provided).
-  Vector<BaseFloat> tot_input_weighting_;
-
-  // returns the parameter dimension, i.e. the dimension of the parameters that
-  // we are optimizing.  This depends on the config, the number of updatable
-  // components and nnet_params_.NumRows(); it will never exceed the number of
-  // updatable components times nnet_params_.NumRows().
-  int32 ParameterDim() const;
-
-  int32 NumUpdatableComponents() const {
-    return updatable_component_dims_.size();
-  }
-  // returns the weight dimension.
-  int32 WeightDim() const {
-    return nnet_params_.NumRows() * NumUpdatableComponents();
-  }
-
-  int32 NnetParameterDim() const { return nnet_params_.NumCols(); }
-
-  // Computes the initial parameters.  The parameters are the underlying thing
-  // that we optimize; their dimension equals ParameterDim().  They are not the
-  // same thing as the nnet parameters.
-  void GetInitialParameters(VectorBase<double> *params) const;
-
-  // Tests that derivatives are accurate.  Prints warning and returns false if not.
-  bool SelfTestDerivatives();
-
-  // Tests that model derivatives are accurate.  Just prints warning if not.
-  void SelfTestModelDerivatives();
-
-
-  // prints the parameters via logging statements.
-  void PrintParams(const VectorBase<double> &params) const;
-
-  // This function computes the objective function (and its derivative, if the
-  // objective function is finite) at the given value of the parameters (the
-  // parameters we're optimizing, i.e. the combination weights; not the nnet
-  // parameters).  This function calls most of the functions below.
-  double ComputeObjfAndDerivFromParameters(
-      VectorBase<double> &params,
-      VectorBase<double> *params_deriv);
-
-
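The accept-then-finalize protocol above (the constructor takes the first nnet, AcceptNnet must then be called exactly num_nnets-1 times before Combine) is guarded at run time by num_nnets_provided_. A toy reduction of the same guard pattern in standard C++, with a plain running average standing in for the nnet combination (all names here are invented for illustration):

    #include <cassert>
    #include <cstdio>
    #include <vector>

    class Averager {
     public:
      explicit Averager(int num_expected) : num_expected_(num_expected) {}
      void Accept(double x) {            // like AcceptNnet()
        assert(static_cast<int>(vals_.size()) < num_expected_ &&
               "Accept called too many times");
        vals_.push_back(x);
      }
      double Combine() const {           // like Combine() + GetNnet()
        assert(static_cast<int>(vals_.size()) == num_expected_ &&
               "Accept not called enough times");
        double sum = 0.0;
        for (double v : vals_) sum += v;
        return sum / num_expected_;
      }
     private:
      int num_expected_;
      std::vector<double> vals_;
    };

    int main() {
      Averager a(3);
      a.Accept(1.0); a.Accept(2.0); a.Accept(6.0);
      std::printf("%g\n", a.Combine());  // prints 3
      return 0;
    }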
-  // Computes the weights from the parameters in a config-dependent way.  The
-  // weight dimension is always (the number of updatable components times
-  // nnet_params_.NumRows()).
-  void GetWeights(const VectorBase<double> &params,
-                  VectorBase<double> *weights) const;
-
-  // Given the raw weights: if config_.enforce_sum_to_one, then compute weights
-  // with the sum-to-one constraint per component included; else just copy
-  // input to output.
-  void GetNormalizedWeights(const VectorBase<double> &unnorm_weights,
-                            VectorBase<double> *norm_weights) const;
-
-  // if config_.sum_to_one_penalty is 0.0, returns 0.0 and sets
-  // weights_penalty_deriv to 0.0; else it computes, for each
-  // updatable component u the total weight w_u, returns the value
-  // -0.5 * config_.sum_to_one_penalty * sum_u (w_u - 1.0)^2;
-  // and sets 'weights_penalty_deriv' to the derivative w.r.t.
-  // the result.
-  // Note: config_.sum_to_one_penalty is exclusive with
-  // config_.enforce_sum_to_one, so there is really no distinction between
-  // normalized and unnormalized weights here (since normalization would be a
-  // no-op).
-  double GetSumToOnePenalty(const VectorBase<double> &weights,
-                            VectorBase<double> *weights_penalty_deriv,
-                            bool print_weights = false) const;
-
-
-  // Computes the nnet-parameter vector from the normalized weights and
-  // nnet_params_, as a vector.  (See the functions Vectorize() and
-  // UnVectorize() for how they relate to the nnet's components' parameters).
-  void GetNnetParameters(const Vector<double> &normalized_weights,
-                         VectorBase<BaseFloat> *nnet_params) const;
-
-  // This function computes the objective function (and its derivative, if the
-  // objective function is finite) at the given value of nnet parameters.  This
-  // involves the nnet computation.
-  double ComputeObjfAndDerivFromNnet(VectorBase<BaseFloat> &nnet_params,
-                                     VectorBase<BaseFloat> *nnet_params_deriv);
-
-  // Given an objective-function derivative with respect to the nnet parameters,
-  // computes the derivative with respect to the (normalized) weights.
-  void GetWeightsDeriv(const VectorBase<BaseFloat> &nnet_params_deriv,
-                       VectorBase<double> *normalized_weights_deriv);
-
-
-  // Computes the derivative w.r.t. the unnormalized weights, by propagating
-  // through the normalization operation.
-  // If config_.enforce_sum_to_one == false, just copies norm_weights_deriv to
-  // unnorm_weights_deriv.
-  void GetUnnormalizedWeightsDeriv(const VectorBase<double> &unnorm_weights,
-                                   const VectorBase<double> &norm_weights_deriv,
-                                   VectorBase<double> *unnorm_weights_deriv);
-
-
-  // Given a derivative w.r.t. the weights, outputs a derivative w.r.t.
-  // the params.
-  void GetParamsDeriv(const VectorBase<double> &weights,
-                      const VectorBase<double> &weight_deriv,
-                      VectorBase<double> *param_deriv);
-
-  void ComputeUpdatableComponentDims();
-  void FinishPreprocessingInput();
-
-};
-
-
-
-}  // namespace nnet3
-}  // namespace kaldi
-
-#endif
diff --git a/src/nnet3/nnet-combine.cc b/src/nnet3/nnet-combine.cc
deleted file mode 100644
index fa570ec96a3..00000000000
--- a/src/nnet3/nnet-combine.cc
+++ /dev/null
@@ -1,606 +0,0 @@
-// nnet3/nnet-combine.cc
-
-// Copyright 2012-2015  Johns Hopkins University (author: Daniel Povey)
-
-// See ../../COPYING for clarification regarding multiple authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//  http://www.apache.org/licenses/LICENSE-2.0
-//
-// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
-// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
-// MERCHANTABLITY OR NON-INFRINGEMENT.
-// See the Apache 2 License for the specific language governing permissions and
-// limitations under the License.
-
-#include "nnet3/nnet-combine.h"
-#include "nnet3/nnet-utils.h"
-
-namespace kaldi {
-namespace nnet3 {
-
-NnetCombiner::NnetCombiner(const NnetCombineConfig &config,
-                           int32 num_nnets,
-                           const std::vector<NnetExample> &egs,
-                           const Nnet &first_nnet):
-    config_(config),
-    egs_(egs),
-    nnet_(first_nnet),
-    num_real_input_nnets_(num_nnets),
-    nnet_params_(std::min(num_nnets, config_.max_effective_inputs),
-                 NumParameters(first_nnet)),
-    tot_input_weighting_(nnet_params_.NumRows()) {
-
-  if (config_.sum_to_one_penalty != 0.0 &&
-      config_.enforce_sum_to_one) {
-    KALDI_WARN << "--sum-to-one-penalty=" << config_.sum_to_one_penalty
-               << " is nonzero, so setting --enforce-sum-to-one=false.";
-    config_.enforce_sum_to_one = false;
-  }
-  SubVector<BaseFloat> first_params(nnet_params_, 0);
-  VectorizeNnet(nnet_, &first_params);
-  tot_input_weighting_(0) += 1.0;
-  num_nnets_provided_ = 1;
-  ComputeUpdatableComponentDims();
-  NnetComputeProbOptions compute_prob_opts;
-  compute_prob_opts.compute_deriv = true;
-  prob_computer_ = new NnetComputeProb(compute_prob_opts, nnet_);
-}
-
-void NnetCombiner::ComputeUpdatableComponentDims(){
-  updatable_component_dims_.clear();
-  for (int32 c = 0; c < nnet_.NumComponents(); c++) {
-    Component *comp = nnet_.GetComponent(c);
-    if (comp->Properties() & kUpdatableComponent) {
-      // For now all updatable components inherit from class UpdatableComponent.
-      // If that changes in future, we will change this code.
-      UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(comp);
-      if (uc == NULL)
-        KALDI_ERR << "Updatable component does not inherit from class "
-            "UpdatableComponent; change this code.";
-      updatable_component_dims_.push_back(uc->NumParameters());
-    }
-  }
-}
-
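AcceptNnet, next, spreads each real input across at most two adjacent rows of nnet_params_ using linearly interpolated ("triangular bin") weights whenever there are more real inputs than rows. Just the index arithmetic, as a standalone standard-C++ sketch (the sizes are invented examples):

    #include <cmath>
    #include <cstdio>

    int main() {
      const int num_real = 7, num_effective = 3;  // example sizes (assumed)
      for (int n = 0; n < num_real; n++) {
        // Map 0..num_real-1 linearly onto 0..num_effective-1.
        double idx = n * (num_effective - 1) /
            static_cast<double>(num_real - 1);
        int lower = static_cast<int>(std::floor(idx));
        double upper_weight = idx - lower,
            lower_weight = 1.0 - upper_weight;
        // Like the real code, the upper bin may fall just past the last row,
        // in which case its weight is (near) zero and it is dropped.
        if (lower + 1 < num_effective)
          std::printf("input %d -> row %d (w=%.3f) + row %d (w=%.3f)\n",
                      n, lower, lower_weight, lower + 1, upper_weight);
        else
          std::printf("input %d -> row %d (w=%.3f)\n", n, lower, lower_weight);
      }
      return 0;
    }

The two weights always sum to one per input, and FinishPreprocessingInput's rescaling by tot_input_weighting_ then turns each row's weighted sum into a weighted average.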
-void NnetCombiner::AcceptNnet(const Nnet &nnet) {
-  KALDI_ASSERT(num_nnets_provided_ < num_real_input_nnets_ &&
-               "You called AcceptNnet too many times.");
-  int32 num_effective_nnets = nnet_params_.NumRows();
-  if (num_effective_nnets == num_real_input_nnets_) {
-    SubVector<BaseFloat> this_params(nnet_params_, num_nnets_provided_);
-    VectorizeNnet(nnet, &this_params);
-    tot_input_weighting_(num_nnets_provided_) += 1.0;
-  } else {
-    // this_index is a kind of warped index, mapping the range
-    // 0 ... num_real_input_nnets_ - 1 onto the range
-    // 0 ... num_effective_nnets - 1.  View the index as falling in
-    // between two integer indexes and determining weighting factors;
-    // we could view this as triangular bins.
-    BaseFloat this_index = num_nnets_provided_ * (num_effective_nnets - 1)
-        / static_cast<BaseFloat>(num_real_input_nnets_ - 1);
-    int32 lower_index = std::floor(this_index),
-        upper_index = lower_index + 1;
-    BaseFloat remaining_part = this_index - lower_index,
-        lower_weight = 1.0 - remaining_part,
-        upper_weight = remaining_part;
-    KALDI_ASSERT(lower_index >= 0 && upper_index <= num_effective_nnets &&
-                 lower_weight >= 0.0 && upper_weight >= 0.0 &&
-                 lower_weight <= 1.0 && upper_weight <= 1.0);
-    Vector<BaseFloat> vec(nnet_params_.NumCols(), kUndefined);
-    VectorizeNnet(nnet, &vec);
-    nnet_params_.Row(lower_index).AddVec(lower_weight, vec);
-    tot_input_weighting_(lower_index) += lower_weight;
-    if (upper_index == num_effective_nnets) {
-      KALDI_ASSERT(upper_weight < 0.1);
-    } else {
-      nnet_params_.Row(upper_index).AddVec(upper_weight, vec);
-      tot_input_weighting_(upper_index) += upper_weight;
-    }
-  }
-  num_nnets_provided_++;
-}
-
-void NnetCombiner::FinishPreprocessingInput() {
-  KALDI_ASSERT(num_nnets_provided_ == num_real_input_nnets_ &&
-               "You did not call AcceptNnet() enough times.");
-  int32 num_effective_nnets = nnet_params_.NumRows();
-  for (int32 i = 0; i < num_effective_nnets; i++) {
-    BaseFloat tot_weight = tot_input_weighting_(i);
-    KALDI_ASSERT(tot_weight > 0.0);  // Or it would be a coding error.
-    // Rescale so this row is like a weighted average instead of
-    // a weighted sum.
-    if (tot_weight != 1.0)
-      nnet_params_.Row(i).Scale(1.0 / tot_weight);
-  }
-}
-
-void NnetCombiner::Combine() {
-  FinishPreprocessingInput();
-
-  if (!SelfTestDerivatives()) {
-    KALDI_LOG << "Self-testing model derivatives since parameter-derivatives "
-        "self-test failed.";
-    SelfTestModelDerivatives();
-  }
-
-  int32 dim = ParameterDim();
-  LbfgsOptions lbfgs_options;
-  lbfgs_options.minimize = false;  // We're maximizing.
-  lbfgs_options.m = dim;  // Store the same number of vectors as the dimension
-                          // itself, so this is BFGS.
-  lbfgs_options.first_step_impr = config_.initial_impr;
-
-  Vector<double> params(dim), deriv(dim);
-  double objf, initial_objf;
-  GetInitialParameters(&params);
-
-
-  OptimizeLbfgs<double> lbfgs(params, lbfgs_options);
-
-  for (int32 i = 0; i < config_.num_iters; i++) {
-    params.CopyFromVec(lbfgs.GetProposedValue());
-    objf = ComputeObjfAndDerivFromParameters(params, &deriv);
-    KALDI_VLOG(2) << "Iteration " << i << " params = " << params
-                  << ", objf = " << objf << ", deriv = " << deriv;
-    if (i == 0) initial_objf = objf;
-    lbfgs.DoStep(objf, deriv);
-  }
-
-  if (!config_.sum_to_one_penalty) {
-    KALDI_LOG << "Combining nnets, objective function changed from "
-              << initial_objf << " to " << objf;
-  } else {
-    Vector<double> weights(WeightDim());
-    GetWeights(params, &weights);
-    bool print_weights = true;
-    double penalty = GetSumToOnePenalty(weights, NULL, print_weights);
-    // note: initial_objf has no penalty term because it summed exactly
-    // to one.
-    KALDI_LOG << "Combining nnets, objective function changed from "
-              << initial_objf << " to " << objf << " = "
-              << (objf - penalty) << " + " << penalty;
-  }
-
-
-  // must recompute nnet_ if "params" is not exactly equal to the
-  // final params that LBFGS returned.
-  Vector<double> final_params(dim);
-  final_params.CopyFromVec(lbfgs.GetValue(&objf));
-  if (!params.ApproxEqual(final_params, 0.0)) {
-    // the following call makes sure that nnet_ corresponds to the parameters
-    // in "params".
-    ComputeObjfAndDerivFromParameters(final_params, &deriv);
-  }
-  PrintParams(final_params);
-
-}
-
-void NnetCombiner::PrintParams(const VectorBase<double> &params) const {
-  Vector<double> weights(WeightDim()), normalized_weights(WeightDim());
-  GetWeights(params, &weights);
-  GetNormalizedWeights(weights, &normalized_weights);
-  int32 num_models = nnet_params_.NumRows(),
-      num_uc = NumUpdatableComponents();
-
-  if (config_.separate_weights_per_component) {
-    std::vector<std::string> updatable_component_names;
-    for (int32 c = 0; c < nnet_.NumComponents(); c++) {
-      const Component *comp = nnet_.GetComponent(c);
-      if (comp->Properties() & kUpdatableComponent)
-        updatable_component_names.push_back(nnet_.GetComponentName(c));
-    }
-    KALDI_ASSERT(static_cast<int32>(updatable_component_names.size()) ==
-                 NumUpdatableComponents());
-    for (int32 uc = 0; uc < num_uc; uc++) {
-      std::ostringstream os;
-      os.width(20);
-      os << std::left << updatable_component_names[uc] << ": ";
-      os.width(9);
-      os.precision(4);
-      for (int32 m = 0; m < num_models; m++) {
-        int32 index = m * num_uc + uc;
-        os << " " << std::left << normalized_weights(index);
-      }
-      KALDI_LOG << "Weights for " << os.str();
-    }
-  } else {
-    int32 c = 0;  // arbitrarily chosen; they'll all be the same.
-    std::ostringstream os;
-    os.width(9);
-    os.precision(4);
-    for (int32 m = 0; m < num_models; m++) {
-      int32 index = m * num_uc + c;
-      os << " " << std::left << normalized_weights(index);
-    }
-    KALDI_LOG << "Model weights are: " << os.str();
-  }
-  int32 num_effective_nnets = nnet_params_.NumRows();
-  if (num_effective_nnets != num_real_input_nnets_)
-    KALDI_LOG << "Above, only " << num_effective_nnets << " weights were "
-        "printed due to the --max-effective-inputs option; "
-        "there were " << num_real_input_nnets_ << " actual input nnets. "
-        "Each weight corresponds to a weighted average over a range of "
-        "nnets in the sequence (with triangular bins).";
-}
-
-bool NnetCombiner::SelfTestDerivatives() {
-  int32 num_tests = 2;  // more properly, this is the number of dimensions in a
-                        // single test.
-  double delta = 0.001;
-  int32 dim = ParameterDim();
-
-  Vector<double> params(dim), deriv(dim);
-  Vector<double> predicted_changes(num_tests),
-      observed_changes(num_tests);
-
-  GetInitialParameters(&params);
-  double initial_objf = ComputeObjfAndDerivFromParameters(params,
-                                                          &deriv);
-  for (int32 i = 0; i < num_tests; i++) {
-    Vector<double> new_deriv(dim), offset(dim), new_params(params);
-    offset.SetRandn();
-    new_params.AddVec(delta, offset);
-    double new_objf = ComputeObjfAndDerivFromParameters(new_params,
-                                                        &new_deriv);
-    // for predicted changes, interpolate old and new derivs.
-    predicted_changes(i) =
-        0.5 * VecVec(new_params, deriv) - 0.5 * VecVec(params, deriv) +
-        0.5 * VecVec(new_params, new_deriv) - 0.5 * VecVec(params, new_deriv);
-    observed_changes(i) = new_objf - initial_objf;
-  }
-  double threshold = 0.1;
-  KALDI_LOG << "predicted_changes = " << predicted_changes;
-  KALDI_LOG << "observed_changes = " << observed_changes;
-  if (!ApproxEqual(predicted_changes, observed_changes, threshold)) {
-    KALDI_WARN << "Derivatives self-test failed.";
-    return false;
-  } else {
-    return true;
-  }
-}
-
-
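SelfTestDerivatives above predicts the objective change using the average of the derivatives at the old and new points, which is exact for quadratic objectives (it is the trapezoid rule applied to the line integral of the gradient). The same check on a scalar toy function, in standard C++ (the function is invented for illustration):

    #include <cstdio>

    static double f(double x) { return -(x - 1.0) * (x - 1.0); }
    static double fderiv(double x) { return -2.0 * (x - 1.0); }

    int main() {
      double x = 0.3, delta = 0.001, x_new = x + delta;
      // Average of old and new derivative, times the step: the same
      // quantity the four VecVec() terms above compute in vector form.
      double predicted = 0.5 * (fderiv(x) + fderiv(x_new)) * (x_new - x);
      double observed = f(x_new) - f(x);
      std::printf("predicted %g observed %g\n", predicted, observed);
      return 0;
    }

Because f here is quadratic, predicted and observed agree to rounding error; the 0.1 relative threshold in the real test allows for the nnet objective being only locally quadratic.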
-void NnetCombiner::SelfTestModelDerivatives() {
-  int32 num_tests = 3;  // more properly, this is the number of dimensions in a
-                        // single test.
-  int32 dim = ParameterDim();
-
-  Vector<double> params(dim), deriv(dim);
-  Vector<double> predicted_changes(num_tests),
-      observed_changes(num_tests);
-
-  GetInitialParameters(&params);
-  Vector<double> weights(WeightDim()), normalized_weights(WeightDim());
-  Vector<BaseFloat> nnet_params(NnetParameterDim(), kUndefined),
-      nnet_deriv(NnetParameterDim(), kUndefined);
-  GetWeights(params, &weights);
-  GetNormalizedWeights(weights, &normalized_weights);
-  GetNnetParameters(normalized_weights, &nnet_params);
-
-  double initial_objf = ComputeObjfAndDerivFromNnet(nnet_params,
-                                                    &nnet_deriv);
-
-  double delta = 0.002 * std::sqrt(VecVec(nnet_params, nnet_params) /
-                                   NnetParameterDim());
-
-
-  for (int32 i = 0; i < num_tests; i++) {
-    Vector<BaseFloat> new_nnet_deriv(NnetParameterDim()),
-        offset(NnetParameterDim()), new_nnet_params(nnet_params);
-    offset.SetRandn();
-    new_nnet_params.AddVec(delta, offset);
-    double new_objf = ComputeObjfAndDerivFromNnet(new_nnet_params,
-                                                  &new_nnet_deriv);
-    // for predicted changes, interpolate old and new derivs.
-    predicted_changes(i) =
-        0.5 * VecVec(new_nnet_params, nnet_deriv) -
-        0.5 * VecVec(nnet_params, nnet_deriv) +
-        0.5 * VecVec(new_nnet_params, new_nnet_deriv) -
-        0.5 * VecVec(nnet_params, new_nnet_deriv);
-    observed_changes(i) = new_objf - initial_objf;
-  }
-  double threshold = 0.1;
-  KALDI_LOG << "model-derivatives: predicted_changes = " << predicted_changes;
-  KALDI_LOG << "model-derivatives: observed_changes = " << observed_changes;
-  if (!ApproxEqual(predicted_changes, observed_changes, threshold))
-    KALDI_WARN << "Model derivatives self-test failed.";
-}
-
-
-
-
-int32 NnetCombiner::ParameterDim() const {
-  if (config_.separate_weights_per_component)
-    return NumUpdatableComponents() * nnet_params_.NumRows();
-  else
-    return nnet_params_.NumRows();
-}
-
-
-void NnetCombiner::GetInitialParameters(VectorBase<double> *params) const {
-  KALDI_ASSERT(params->Dim() == ParameterDim());
-  params->Set(1.0 / nnet_params_.NumRows());
-  if (config_.enforce_positive_weights) {
-    // we enforce positive weights by treating the params as the log of the
-    // actual weight.
-    params->ApplyLog();
-  }
-}
-
-void NnetCombiner::GetWeights(const VectorBase<double> &params,
-                              VectorBase<double> *weights) const {
-  KALDI_ASSERT(weights->Dim() == WeightDim());
-  if (config_.separate_weights_per_component) {
-    weights->CopyFromVec(params);
-  } else {
-    int32 nc = NumUpdatableComponents();
-    // we have one parameter per row of nnet_params_, and need to repeat
-    // the weight for the different components.
-    for (int32 n = 0; n < nnet_params_.NumRows(); n++) {
-      for (int32 c = 0; c < nc; c++)
-        (*weights)(n * nc + c) = params(n);
-    }
-  }
-  // we enforce positive weights by having the weights be the exponential of
-  // the corresponding parameters.
-  if (config_.enforce_positive_weights)
-    weights->ApplyExp();
-}
-
-
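GetInitialParameters and GetWeights above implement --enforce-positive-weights by optimizing log-weights: ApplyLog() at initialization, ApplyExp() in the forward direction, and, in GetParamsDeriv just below, a multiply by the weight in the backward direction, since d/dp exp(p) = exp(p). On scalars, in standard C++ (the values are invented):

    #include <cmath>
    #include <cstdio>

    int main() {
      double initial_weight = 1.0 / 3.0;        // e.g. 1/num_rows
      double param = std::log(initial_weight);  // like GetInitialParameters
      double weight = std::exp(param);          // like GetWeights
      double dobjf_dweight = 0.7;               // some upstream derivative
      // Chain rule through the exp: d objf/d param = d objf/d weight * weight.
      double dobjf_dparam = dobjf_dweight * weight;
      std::printf("param=%g weight=%g dobjf/dparam=%g\n",
                  param, weight, dobjf_dparam);
      return 0;
    }

The reparameterization makes the weight positive for any real-valued parameter, so the unconstrained BFGS optimizer never has to handle the constraint explicitly.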
-void NnetCombiner::GetParamsDeriv(const VectorBase<double> &weights,
-                                  const VectorBase<double> &weights_deriv,
-                                  VectorBase<double> *param_deriv) {
-  KALDI_ASSERT(weights.Dim() == WeightDim() &&
-               param_deriv->Dim() == ParameterDim());
-  Vector<double> preexp_weights_deriv(weights_deriv);
-  if (config_.enforce_positive_weights) {
-    // to enforce positive weights we first compute weights (call these
-    // preexp_weights) and then take the exponential.  Note, d/dx exp(x) = exp(x).
-    // So the derivative w.r.t. the preexp_weights equals the derivative
-    // w.r.t. the weights, times the weights.
-    preexp_weights_deriv.MulElements(weights);
-  }
-  if (config_.separate_weights_per_component) {
-    param_deriv->CopyFromVec(preexp_weights_deriv);
-  } else {
-    int32 nc = NumUpdatableComponents();
-    param_deriv->SetZero();
-    for (int32 n = 0; n < nnet_params_.NumRows(); n++)
-      for (int32 c = 0; c < nc; c++)
-        (*param_deriv)(n) += preexp_weights_deriv(n * nc + c);
-  }
-}
-
-
-double NnetCombiner::GetSumToOnePenalty(
-    const VectorBase<double> &weights,
-    VectorBase<double> *weights_penalty_deriv,
-    bool print_weights) const {
-
-  KALDI_ASSERT(config_.sum_to_one_penalty >= 0.0);
-  double penalty = config_.sum_to_one_penalty;
-  if (penalty == 0.0) {
-    weights_penalty_deriv->SetZero();
-    return 0.0;
-  }
-  double ans = 0.0;
-  int32 num_uc = NumUpdatableComponents(),
-      num_models = nnet_params_.NumRows();
-  Vector<double> tot_weights(num_uc);
-  std::ostringstream tot_weight_info;
-  for (int32 c = 0; c < num_uc; c++) {
-    double this_total_weight = 0.0;
-    for (int32 m = 0; m < num_models; m++) {
-      int32 index = m * num_uc + c;
-      double this_weight = weights(index);
-      this_total_weight += this_weight;
-    }
-    tot_weights(c) = this_total_weight;
-    // this_total_weight_deriv is the derivative of the penalty
-    // term w.r.t. this component's total weight.
-    double this_total_weight_deriv;
-    if (config_.enforce_positive_weights) {
-      // if config_.enforce_positive_weights is true, then we choose to
-      // formulate the penalty in a slightly different way.  This solves the
-      // problem that with the formulation in the 'else' below, if for some
-      // reason the total weight is << 1.0, the deriv w.r.t. the actual
-      // parameters gets tiny [because weight = exp(params)].
-      double log_total = log(this_total_weight);
-      ans += -0.5 * penalty * log_total * log_total;
-      double log_total_deriv = -1.0 * penalty * log_total;
-      this_total_weight_deriv = log_total_deriv / this_total_weight;
-    } else {
-      ans += -0.5 * penalty *
-          (this_total_weight - 1.0) * (this_total_weight - 1.0);
-      this_total_weight_deriv = penalty * (1.0 - this_total_weight);
-
-    }
-    if (weights_penalty_deriv != NULL) {
-      KALDI_ASSERT(weights.Dim() == weights_penalty_deriv->Dim());
-      for (int32 m = 0; m < num_models; m++) {
-        int32 index = m * num_uc + c;
-        (*weights_penalty_deriv)(index) = this_total_weight_deriv;
-      }
-    }
-  }
-  if (print_weights) {
-    Vector<BaseFloat> tot_weights_float(tot_weights);
-    KALDI_LOG << "Total weights per component: "
-              << PrintVectorPerUpdatableComponent(nnet_,
-                                                  tot_weights_float);
-  }
-  return ans;
-}
-
-
-void NnetCombiner::GetNnetParameters(const Vector<double> &weights,
-                                     VectorBase<BaseFloat> *nnet_params) const {
-  KALDI_ASSERT(nnet_params->Dim() == nnet_params_.NumCols());
-  nnet_params->SetZero();
-  int32 num_uc = NumUpdatableComponents(),
-      num_models = nnet_params_.NumRows();
-  for (int32 m = 0; m < num_models; m++) {
-    const SubVector<BaseFloat> src_params(nnet_params_, m);
-    int32 dim_offset = 0;
-    for (int32 c = 0; c < num_uc; c++) {
-      int32 index = m * num_uc + c;
-      BaseFloat weight = weights(index);
-      int32 dim = updatable_component_dims_[c];
-      const SubVector<BaseFloat> src_component_params(src_params, dim_offset,
-                                                      dim);
-      SubVector<BaseFloat> dest_component_params(*nnet_params, dim_offset, dim);
-      dest_component_params.AddVec(weight, src_component_params);
-      dim_offset += dim;
-    }
-    KALDI_ASSERT(dim_offset == nnet_params_.NumCols());
-  }
-}
-
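GetSumToOnePenalty above has two flavors per component total weight w: the quadratic -0.5 * penalty * (w - 1)^2 with derivative penalty * (1 - w), and, under --enforce-positive-weights, the log-domain -0.5 * penalty * log(w)^2 whose derivative -penalty * log(w) / w does not die away when w << 1, which is the problem the comment describes. A small standard-C++ table of both (the penalty value is an invented example):

    #include <cmath>
    #include <cstdio>

    int main() {
      double penalty = 1.0e-3;  // e.g. --sum-to-one-penalty
      for (double w : {0.5, 1.0, 2.0}) {
        double quad = -0.5 * penalty * (w - 1.0) * (w - 1.0);
        double quad_deriv = penalty * (1.0 - w);
        double lg = std::log(w);
        double logd = -0.5 * penalty * lg * lg;  // enforce_positive_weights
        double logd_deriv = -penalty * lg / w;   // chain rule through log
        std::printf("w=%.1f quad=%g (d=%g) log=%g (d=%g)\n",
                    w, quad, quad_deriv, logd, logd_deriv);
      }
      return 0;
    }

Both penalties are zero with zero derivative exactly at w = 1, so they only nudge the total weights toward summing to one rather than enforcing it.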
-// compare GetNnetParameters.
-void NnetCombiner::GetWeightsDeriv(
-    const VectorBase<BaseFloat> &nnet_params_deriv,
-    VectorBase<double> *weights_deriv) {
-  KALDI_ASSERT(nnet_params_deriv.Dim() == nnet_params_.NumCols() &&
-               weights_deriv->Dim() == WeightDim());
-  int32 num_uc = NumUpdatableComponents(),
-      num_models = nnet_params_.NumRows();
-  for (int32 m = 0; m < num_models; m++) {
-    const SubVector<BaseFloat> src_params(nnet_params_, m);
-    int32 dim_offset = 0;
-    for (int32 c = 0; c < num_uc; c++) {
-      int32 index = m * num_uc + c;
-      int32 dim = updatable_component_dims_[c];
-      const SubVector<BaseFloat> src_component_params(src_params, dim_offset,
-                                                      dim);
-      const SubVector<BaseFloat> component_params_deriv(nnet_params_deriv,
-                                                        dim_offset, dim);
-      (*weights_deriv)(index) = VecVec(src_component_params,
-                                       component_params_deriv);
-      dim_offset += dim;
-    }
-    KALDI_ASSERT(dim_offset == nnet_params_.NumCols());
-  }
-}
-
-double NnetCombiner::ComputeObjfAndDerivFromNnet(
-    VectorBase<BaseFloat> &nnet_params,
-    VectorBase<BaseFloat> *nnet_params_deriv) {
-  BaseFloat sum = nnet_params.Sum();
-  // inf/nan parameters->return -inf objective.
-  if (!(sum == sum && sum - sum == 0))
-    return -std::numeric_limits<double>::infinity();
-  // Set nnet to have these params.
-  UnVectorizeNnet(nnet_params, &nnet_);
-
-  prob_computer_->Reset();
-  std::vector<NnetExample>::const_iterator iter = egs_.begin(),
-      end = egs_.end();
-  for (; iter != end; ++iter)
-    prob_computer_->Compute(*iter);
-  double tot_weights,
-      tot_objf = prob_computer_->GetTotalObjective(&tot_weights);
-  KALDI_ASSERT(tot_weights > 0.0);
-  const Nnet &deriv = prob_computer_->GetDeriv();
-  VectorizeNnet(deriv, nnet_params_deriv);
-  // we prefer to deal with normalized objective functions.
-  nnet_params_deriv->Scale(1.0 / tot_weights);
-  return tot_objf / tot_weights;
-}
-
-
-double NnetCombiner::ComputeObjfAndDerivFromParameters(
-    VectorBase<double> &params,
-    VectorBase<double> *params_deriv) {
-  Vector<double> weights(WeightDim()), normalized_weights(WeightDim()),
-      weights_sum_to_one_penalty_deriv(WeightDim()),
-      normalized_weights_deriv(WeightDim()), weights_deriv(WeightDim());
-  Vector<BaseFloat>
-      nnet_params(NnetParameterDim(), kUndefined),
-      nnet_params_deriv(NnetParameterDim(), kUndefined);
-  GetWeights(params, &weights);
-  double ans = GetSumToOnePenalty(weights, &weights_sum_to_one_penalty_deriv);
-  GetNormalizedWeights(weights, &normalized_weights);
-  GetNnetParameters(normalized_weights, &nnet_params);
-  ans += ComputeObjfAndDerivFromNnet(nnet_params, &nnet_params_deriv);
-  if (ans != ans || ans - ans != 0)  // NaN or inf
-    return ans;  // No point computing derivative
-  GetWeightsDeriv(nnet_params_deriv, &normalized_weights_deriv);
-  GetUnnormalizedWeightsDeriv(weights, normalized_weights_deriv,
-                              &weights_deriv);
-  weights_deriv.AddVec(1.0, weights_sum_to_one_penalty_deriv);
-  GetParamsDeriv(weights, weights_deriv, params_deriv);
-  return ans;
-}
-
-
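GetUnnormalizedWeightsDeriv, below, backpropagates through the per-component normalization y_i = x_i / S with S = sum_j x_j in two passes: a direct term gy_i / S per weight, and a shared correction -(sum_j gy_j * x_j) / S^2 added to every weight of the component. A standard-C++ check of that formula against a finite difference (the values are invented):

    #include <cstdio>

    int main() {
      const int n = 3;
      double x[n] = {0.5, 1.0, 2.0};    // unnormalized weights
      double gy[n] = {0.3, -0.2, 0.1};  // some upstream dL/dy
      double S = x[0] + x[1] + x[2], inv = 1.0 / S;
      double gx[n], inv_deriv = 0.0;
      for (int i = 0; i < n; i++) {     // first pass: direct term
        gx[i] = inv * gy[i];
        inv_deriv += gy[i] * x[i];      // accumulates dL/d(1/S)
      }
      double S_deriv = -inv_deriv * inv * inv;  // d(1/S)/dS = -1/S^2
      for (int i = 0; i < n; i++) gx[i] += S_deriv;  // shared term
      // finite-difference check on x[0]:
      double eps = 1e-6, L0 = 0.0;
      for (int i = 0; i < n; i++) L0 += gy[i] * x[i] / S;
      double x0 = x[0] + eps, S1 = x0 + x[1] + x[2];
      double L1 = gy[0] * x0 / S1 + gy[1] * x[1] / S1 + gy[2] * x[2] / S1;
      std::printf("analytic %g, numeric %g\n", gx[0], (L1 - L0) / eps);
      return 0;
    }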
-// enforces the constraint that the weights for each component must sum to one,
-// if necessary.
-void NnetCombiner::GetNormalizedWeights(
-    const VectorBase<double> &unnorm_weights,
-    VectorBase<double> *norm_weights) const {
-  if (!config_.enforce_sum_to_one) {
-    norm_weights->CopyFromVec(unnorm_weights);
-    return;
-  }
-  int32 num_uc = NumUpdatableComponents(),
-      num_models = nnet_params_.NumRows();
-  for (int32 c = 0; c < num_uc; c++) {
-    double sum = 0.0;
-    for (int32 m = 0; m < num_models; m++) {
-      int32 index = m * num_uc + c;
-      sum += unnorm_weights(index);
-    }
-    double inv_sum = 1.0 / sum;  // if it's NaN then it's OK, we'll get NaN
-                                 // weights and eventually -inf objective.
-    for (int32 m = 0; m < num_models; m++) {
-      int32 index = m * num_uc + c;
-      (*norm_weights)(index) = unnorm_weights(index) * inv_sum;
-    }
-  }
-}
-
-void NnetCombiner::GetUnnormalizedWeightsDeriv(
-    const VectorBase<double> &unnorm_weights,
-    const VectorBase<double> &norm_weights_deriv,
-    VectorBase<double> *unnorm_weights_deriv) {
-  if (!config_.enforce_sum_to_one) {
-    unnorm_weights_deriv->CopyFromVec(norm_weights_deriv);
-    return;
-  }
-  int32 num_uc = NumUpdatableComponents(),
-      num_models = nnet_params_.NumRows();
-  for (int32 c = 0; c < num_uc; c++) {
-    double sum = 0.0;
-    for (int32 m = 0; m < num_models; m++) {
-      int32 index = m * num_uc + c;
-      sum += unnorm_weights(index);
-    }
-    double inv_sum = 1.0 / sum;
-    double inv_sum_deriv = 0.0;
-    for (int32 m = 0; m < num_models; m++) {
-      int32 index = m * num_uc + c;
-      // in the forward direction, we'd do:
-      // (*norm_weights)(index) = unnorm_weights(index) * inv_sum;
-      (*unnorm_weights_deriv)(index) = inv_sum * norm_weights_deriv(index);
-      inv_sum_deriv += norm_weights_deriv(index) * unnorm_weights(index);
-    }
-    // note: d/dx (1/x) = -1/x^2
-    double sum_deriv = -1.0 * inv_sum_deriv * inv_sum * inv_sum;
-    for (int32 m = 0; m < num_models; m++) {
-      int32 index = m * num_uc + c;
-      (*unnorm_weights_deriv)(index) += sum_deriv;
-    }
-  }
-}
-
-
-
-
-}  // namespace nnet3
-}  // namespace kaldi
diff --git a/src/nnet3/nnet-combine.h b/src/nnet3/nnet-combine.h
deleted file mode 100644
index 5b60d30b8ed..00000000000
--- a/src/nnet3/nnet-combine.h
+++ /dev/null
@@ -1,251 +0,0 @@
-// nnet3/nnet-combine.h
-
-// Copyright 2012-2015  Johns Hopkins University (author: Daniel Povey)
-
-// See ../../COPYING for clarification regarding multiple authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//  http://www.apache.org/licenses/LICENSE-2.0
-//
-// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
-// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
-// MERCHANTABLITY OR NON-INFRINGEMENT.
-// See the Apache 2 License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef KALDI_NNET3_NNET_COMBINE_H_
-#define KALDI_NNET3_NNET_COMBINE_H_
-
-#include "nnet3/nnet-utils.h"
-#include "nnet3/nnet-compute.h"
-#include "util/parse-options.h"
-#include "itf/options-itf.h"
-#include "nnet3/nnet-diagnostics.h"
-
-
-namespace kaldi {
-namespace nnet3 {
-
-/** Configuration class that controls neural net combination, where we combine
-    a number of neural nets.
-*/
-struct NnetCombineConfig {
-  int32 num_iters;  // The dimension of the space we are optimizing in is
-                    // fairly small (equal to the number of components times
-                    // the number of neural nets we were given), so we optimize
-                    // with BFGS (internally the code uses L-BFGS, but we set
-                    // the number of vectors to be the same as the dimension
-                    // of the space, so it actually is regular BFGS).  The
-                    // num-iters corresponds to the number of function
-                    // evaluations.
-
-
-  BaseFloat initial_impr;
-  int32 max_effective_inputs;
-  bool test_gradient;
-  bool enforce_positive_weights;
-  bool enforce_sum_to_one;
-  BaseFloat sum_to_one_penalty;
-  bool separate_weights_per_component;
-  NnetCombineConfig(): num_iters(60),
-                       initial_impr(0.01),
-                       max_effective_inputs(15),
-                       test_gradient(false),
-                       enforce_positive_weights(false),
-                       enforce_sum_to_one(false),
-                       sum_to_one_penalty(0.0),
-                       separate_weights_per_component(true) { }
-
-  void Register(OptionsItf *po) {
-    po->Register("num-iters", &num_iters, "Maximum number of function "
-                 "evaluations for BFGS to use when optimizing combination weights");
-    po->Register("max-effective-inputs", &max_effective_inputs, "Limits the number of "
-                 "parameters that have to be learned to be equivalent to the number of "
-                 "parameters we'd have to learn if the number of input nnets equalled "
-                 "this number.  Does this by using averages of nnets at close points "
-                 "in the sequence of inputs, as the actual inputs to the computation.");
-    po->Register("initial-impr", &initial_impr, "Amount of objective-function change "
-                 "we aim for on the first iteration (controls the initial step size).");
-    po->Register("test-gradient", &test_gradient, "If true, activate code that "
-                 "tests the gradient is accurate.");
-    po->Register("enforce-positive-weights", &enforce_positive_weights,
-                 "If true, enforce that all weights are positive.");
-    po->Register("enforce-sum-to-one", &enforce_sum_to_one, "If true, enforce that "
-                 "the model weights for each component should sum to one.");
-    po->Register("sum-to-one-penalty", &sum_to_one_penalty, "If >0, a penalty term "
-                 "on the squared difference between sum(weights) for one component,"
-                 " and 1.0.  This is like --enforce-sum-to-one, but done in a 'soft' "
-                 "way (e.g. maybe useful with dropout).  We suggest small values "
-                 "like 10e-3 (for regular nnets) or 1.0e-04 (for chain models).");
-    po->Register("separate-weights-per-component", &separate_weights_per_component,
-                 "If true, have a separate weight for each updatable component in "
-                 "the nnet.");
-  }
-};
-
-
-/*
-  You should use this class as follows:
-   - Call the constructor, giving it the egs and the first nnet.
-   - Call AcceptNnet to provide all the other nnets.  (the nnets will
-     be stored in a matrix in CPU memory, to avoid filling up GPU memory).
-   - Call Combine()
-   - Get the resultant nnet with GetNnet().
- */
-class NnetCombiner {
- public:
-  /// Caution: this object retains a const reference to the "egs", so don't
-  /// delete them until it goes out of scope.
-  NnetCombiner(const NnetCombineConfig &config,
-               int32 num_nnets,
-               const std::vector<NnetExample> &egs,
-               const Nnet &first_nnet);
-  /// You should call this function num_nnets-1 times after calling
-  /// the constructor, to provide the remaining nnets.
-  void AcceptNnet(const Nnet &nnet);
-  void Combine();
-  const Nnet &GetNnet() const { return nnet_; }
-
-  ~NnetCombiner() { delete prob_computer_; }
- private:
-  NnetCombineConfig config_;
-
-  const std::vector<NnetExample> &egs_;
-
-  Nnet nnet_;  // The current neural network.
-
-  NnetComputeProb *prob_computer_;
-
-  std::vector<int32> updatable_component_dims_;  // dimension of each updatable
-                                                 // component.
-
-  int32 num_real_input_nnets_;  // number of actual nnet inputs.
-
-  int32 num_nnets_provided_;  // keeps track of the number of calls to AcceptNnet().
-
-  // nnet_params_ are the parameters of the "effective input"
-  // neural nets; they will often be the same as the real inputs,
-  // but if num_real_input_nnets_ > config_.num_effective_nnets, they
-  // will be weighted combinations.
-  Matrix<BaseFloat> nnet_params_;
-
-  // This vector has the same dimension as nnet_params_.NumRows(),
-  // and helps us normalize so each row of nnet_params corresponds to
-  // a weighted average of its inputs (will be all ones if
-  // config_.max_effective_inputs >= the number of nnets provided).
-  Vector<BaseFloat> tot_input_weighting_;
-
-  // returns the parameter dimension, i.e. the dimension of the parameters that
-  // we are optimizing.  This depends on the config, the number of updatable
-  // components and nnet_params_.NumRows(); it will never exceed the number of
-  // updatable components times nnet_params_.NumRows().
-  int32 ParameterDim() const;
-
-  int32 NumUpdatableComponents() const {
-    return updatable_component_dims_.size();
-  }
-  // returns the weight dimension.
-  int32 WeightDim() const {
-    return nnet_params_.NumRows() * NumUpdatableComponents();
-  }
-
-  int32 NnetParameterDim() const { return nnet_params_.NumCols(); }
-
-  // Computes the initial parameters.  The parameters are the underlying thing
-  // that we optimize; their dimension equals ParameterDim().  They are not the
-  // same thing as the nnet parameters.
-  void GetInitialParameters(VectorBase<double> *params) const;
-
-  // Tests that derivatives are accurate.  Prints warning and returns false if not.
-  bool SelfTestDerivatives();
-
-  // Tests that model derivatives are accurate.  Just prints warning if not.
-  void SelfTestModelDerivatives();
-
-
-  // prints the parameters via logging statements.
-  void PrintParams(const VectorBase<double> &params) const;
-
-  // This function computes the objective function (and its derivative, if the
-  // objective function is finite) at the given value of the parameters (the
-  // parameters we're optimizing, i.e. the combination weights; not the nnet
-  // parameters).  This function calls most of the functions below.
-  double ComputeObjfAndDerivFromParameters(
-      VectorBase<double> &params,
-      VectorBase<double> *params_deriv);
-
-
-  // Computes the weights from the parameters in a config-dependent way.  The
-  // weight dimension is always (the number of updatable components times
-  // nnet_params_.NumRows()).
-  void GetWeights(const VectorBase<double> &params,
-                  VectorBase<double> *weights) const;
-
-  // Given the raw weights: if config_.enforce_sum_to_one, then compute weights
-  // with the sum-to-one constraint per component included; else just copy
-  // input to output.
-  void GetNormalizedWeights(const VectorBase<double> &unnorm_weights,
-                            VectorBase<double> *norm_weights) const;
-
-  // if config_.sum_to_one_penalty is 0.0, returns 0.0 and sets
-  // weights_penalty_deriv to 0.0; else it computes, for each
-  // updatable component u the total weight w_u, returns the value
-  // -0.5 * config_.sum_to_one_penalty * sum_u (w_u - 1.0)^2;
-  // and sets 'weights_penalty_deriv' to the derivative w.r.t.
-  // the result.
-  // Note: config_.sum_to_one_penalty is exclusive with
-  // config_.enforce_sum_to_one, so there is really no distinction between
-  // normalized and unnormalized weights here (since normalization would be a
-  // no-op).
-  double GetSumToOnePenalty(const VectorBase<double> &weights,
-                            VectorBase<double> *weights_penalty_deriv,
-                            bool print_weights = false) const;
-
-
-  // Computes the nnet-parameter vector from the normalized weights and
-  // nnet_params_, as a vector.
-  // (See the functions Vectorize() and UnVectorize() for how they relate to
-  // the nnet's components' parameters).
-  void GetNnetParameters(const Vector<double> &normalized_weights,
-                         VectorBase<BaseFloat> *nnet_params) const;
-
-  // This function computes the objective function (and its derivative, if the
-  // objective function is finite) at the given value of nnet parameters.  This
-  // involves the nnet computation.
-  double ComputeObjfAndDerivFromNnet(VectorBase<BaseFloat> &nnet_params,
-                                     VectorBase<BaseFloat> *nnet_params_deriv);
-
-  // Given an objective-function derivative with respect to the nnet parameters,
-  // computes the derivative with respect to the (normalized) weights.
-  void GetWeightsDeriv(const VectorBase<BaseFloat> &nnet_params_deriv,
-                       VectorBase<double> *normalized_weights_deriv);
-
-
-  // Computes the derivative w.r.t. the unnormalized weights, by propagating
-  // through the normalization operation.
-  // If config_.enforce_sum_to_one == false, just copies norm_weights_deriv to
-  // unnorm_weights_deriv.
-  void GetUnnormalizedWeightsDeriv(const VectorBase<double> &unnorm_weights,
-                                   const VectorBase<double> &norm_weights_deriv,
-                                   VectorBase<double> *unnorm_weights_deriv);
-
-
-  // Given a derivative w.r.t. the weights, outputs a derivative w.r.t.
-  // the params.
-  void GetParamsDeriv(const VectorBase<double> &weights,
-                      const VectorBase<double> &weight_deriv,
-                      VectorBase<double> *param_deriv);
-
-  void ComputeUpdatableComponentDims();
-  void FinishPreprocessingInput();
-
-};
-
-
-
-}  // namespace nnet3
-}  // namespace kaldi
-
-#endif
diff --git a/src/nnet3bin/nnet3-combine.cc b/src/nnet3bin/nnet3-combine.cc
index 128a9642ec4..4bcf4cdfb6d 100644
--- a/src/nnet3bin/nnet3-combine.cc
+++ b/src/nnet3bin/nnet3-combine.cc
@@ -1,6 +1,7 @@
 // nnet3bin/nnet3-combine.cc
 
 // Copyright 2012-2015  Johns Hopkins University (author: Daniel Povey)
+//                2017  Yiming Wang
 
 // See ../../COPYING for clarification regarding multiple authors
 //
@@ -19,8 +20,58 @@
 
 #include "base/kaldi-common.h"
 #include "util/common-utils.h"
-#include "nnet3/nnet-combine.h"
+#include "nnet3/nnet-utils.h"
+#include "nnet3/nnet-compute.h"
+#include "nnet3/nnet-diagnostics.h"
+
+
+namespace kaldi {
+namespace nnet3 {
+
+// Computes and returns the objective function for the examples in 'egs' given
+// the model in 'nnet'.  If either of batchnorm/dropout test modes is true, we
+// make a copy of 'nnet', set test modes on that and evaluate its objective.
+// Note: the object that prob_computer->nnet_ refers to should be 'nnet'.
+double ComputeObjf(bool batchnorm_test_mode, bool dropout_test_mode,
+                   const std::vector<NnetExample> &egs, const Nnet &nnet,
+                   NnetComputeProb *prob_computer) {
+  if (batchnorm_test_mode || dropout_test_mode) {
+    Nnet nnet_copy(nnet);
+    if (batchnorm_test_mode)
+      SetBatchnormTestMode(true, &nnet_copy);
+    if (dropout_test_mode)
+      SetDropoutTestMode(true, &nnet_copy);
+    NnetComputeProbOptions compute_prob_opts;
+    NnetComputeProb prob_computer_test(compute_prob_opts, nnet_copy);
+    return ComputeObjf(false, false, egs, nnet_copy, &prob_computer_test);
+  } else {
+    prob_computer->Reset();
+    std::vector<NnetExample>::const_iterator iter = egs.begin(),
+        end = egs.end();
+    for (; iter != end; ++iter)
+      prob_computer->Compute(*iter);
+    double tot_weights,
+        tot_objf = prob_computer->GetTotalObjective(&tot_weights);
+    KALDI_ASSERT(tot_weights > 0.0);
+    // inf/nan tot_objf->return -inf objective.
+    if (!(tot_objf == tot_objf && tot_objf - tot_objf == 0))
+      return -std::numeric_limits<double>::infinity();
+    // we prefer to deal with normalized objective functions.
+    return tot_objf / tot_weights;
+  }
+}
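ComputeObjf above handles the test-mode flags by copying the model, setting the modes on the copy, and calling itself once with both flags false, so the measuring code path exists only once. The shape of that pattern, reduced to standard C++ (the types and the dummy objective are invented):

    #include <cstdio>

    struct Model { bool test_mode = false; };

    double Evaluate(bool want_test_mode, const Model &m) {
      if (want_test_mode && !m.test_mode) {
        Model copy = m;          // stand-in for copying the nnet
        copy.test_mode = true;   // like SetBatchnormTestMode / SetDropoutTestMode
        return Evaluate(false, copy);  // recurses exactly once
      }
      return m.test_mode ? 1.0 : 0.5;  // dummy objective
    }

    int main() {
      std::printf("%g\n", Evaluate(true, Model()));  // prints 1
      return 0;
    }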
+
+// Updates the moving average over num_models nnets, given the average over
+// the previous (num_models - 1) nnets and the new nnet.
+void UpdateNnetMovingAverage(int32 num_models,
+                             const Nnet &nnet, Nnet *moving_average_nnet) {
+  KALDI_ASSERT(NumParameters(nnet) == NumParameters(*moving_average_nnet));
+  ScaleNnet((num_models - 1.0) / num_models, moving_average_nnet);
+  AddNnet(nnet, 1.0 / num_models, moving_average_nnet);
+}
+
+}
+}
 
 int main(int argc, char *argv[]) {
   try {
@@ -30,9 +81,11 @@
     typedef kaldi::int64 int64;
 
     const char *usage =
-        "Using a subset of training or held-out examples, compute an optimal combination of a\n"
-        "number of nnet3 neural nets by maximizing the objective function. See documentation of\n"
-        "options for more details.  Inputs and outputs are 'raw' nnets.\n"
+        "Using a subset of training or held-out examples, compute the average\n"
+        "over the first n nnet3 models where we maximize the objective function\n"
+        "for n.  Note that the order of models has been reversed before\n"
+        "being fed into this binary.  So we are actually combining the last n\n"
+        "models.  Inputs and outputs are 'raw' nnets.\n"
         "\n"
        "Usage:  nnet3-combine [options] <nnet-in1> <nnet-in2> ... <nnet-inN> <valid-examples-in> <nnet-out>\n"
         "\n"
         " nnet3-combine 1.1.raw 1.2.raw 1.3.raw ark:valid.egs 2.raw\n";
 
     bool binary_write = true;
+    int32 max_objective_evaluations = 30;
     bool batchnorm_test_mode = false,
         dropout_test_mode = true;
     std::string use_gpu = "yes";
-    NnetCombineConfig combine_config;
 
     ParseOptions po(usage);
     po.Register("binary", &binary_write, "Write output in binary mode");
+    po.Register("max-objective-evaluations", &max_objective_evaluations, "The "
+                "maximum number of objective evaluations in order to figure "
+                "out the best number of models to combine.  It helps to speed "
+                "things up if the number of models provided to this binary is "
+                "quite large (e.g. several hundred).");
     po.Register("batchnorm-test-mode", &batchnorm_test_mode,
-                "If true, set test-mode to true on any BatchNormComponents.");
+                "If true, set test-mode to true on any BatchNormComponents "
+                "while evaluating objectives.");
     po.Register("dropout-test-mode", &dropout_test_mode,
                 "If true, set test-mode to true on any DropoutComponents and "
-                "DropoutMaskComponents.");
+                "DropoutMaskComponents while evaluating objectives.");
     po.Register("use-gpu", &use_gpu,
                 "yes|no|optional|wait, only has effect if compiled with CUDA");
 
-    combine_config.Register(&po);
-
     po.Read(argc, argv);
 
     if (po.NumArgs() < 3) {
@@ -75,11 +132,9 @@
 
     Nnet nnet;
     ReadKaldiObject(nnet_rxfilename, &nnet);
-
-    if (batchnorm_test_mode)
-      SetBatchnormTestMode(true, &nnet);
-    if (dropout_test_mode)
-      SetDropoutTestMode(true, &nnet);
+    Nnet moving_average_nnet(nnet), best_nnet(nnet);
+    NnetComputeProbOptions compute_prob_opts;
+    NnetComputeProb prob_computer(compute_prob_opts, moving_average_nnet);
 
     std::vector<NnetExample> egs;
     egs.reserve(10000);  // reserve a lot of space to minimize the chance of
@@ -94,31 +149,49 @@
       KALDI_ASSERT(!egs.empty());
     }
 
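UpdateNnetMovingAverage, defined above, is the incremental-mean identity avg_n = ((n - 1) / n) * avg_{n-1} + x_n / n applied to whole parameter vectors via ScaleNnet and AddNnet: after processing model n, the moving average equals the plain mean of the first n models. Verified on scalars in standard C++ (the values are invented):

    #include <cstdio>

    int main() {
      double xs[] = {4.0, 1.0, 7.0, 2.0};  // stand-ins for model parameters
      double avg = xs[0], sum = xs[0];
      for (int n = 2; n <= 4; n++) {
        avg = (n - 1.0) / n * avg + xs[n - 1] / n;  // ScaleNnet + AddNnet
        sum += xs[n - 1];
        std::printf("n=%d incremental=%g direct=%g\n", n, avg, sum / n);
      }
      return 0;
    }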
+    // first evaluates the objective using the last model.
+    int32 best_num_to_combine = 1;
+    double
+        init_objf = ComputeObjf(batchnorm_test_mode, dropout_test_mode,
+                                egs, moving_average_nnet, &prob_computer),
+        best_objf = init_objf;
+    KALDI_LOG << "objective function using the last model is " << init_objf;
+
     int32 num_nnets = po.NumArgs() - 2;
-    if (num_nnets > 1 || !combine_config.enforce_sum_to_one) {
-      NnetCombiner combiner(combine_config, num_nnets, egs, nnet);
-
-      for (int32 n = 1; n < num_nnets; n++) {
-        ReadKaldiObject(po.GetArg(1 + n), &nnet);
-        combiner.AcceptNnet(nnet);
+    // then each time before we re-evaluate the objective function, we will add
+    // num_to_add models to the moving average.
+    int32 num_to_add = (num_nnets + max_objective_evaluations - 1) /
+        max_objective_evaluations;
+    for (int32 n = 1; n < num_nnets; n++) {
+      ReadKaldiObject(po.GetArg(1 + n), &nnet);
+      // updates the moving average
+      UpdateNnetMovingAverage(n + 1, nnet, &moving_average_nnet);
+      // evaluates the objective every time after adding num_to_add models, or
+      // after adding all the models, to the moving average.
+      if ((n - 1) % num_to_add == num_to_add - 1 || n == num_nnets - 1) {
+        double objf = ComputeObjf(batchnorm_test_mode, dropout_test_mode,
+                                  egs, moving_average_nnet, &prob_computer);
+        KALDI_LOG << "Combining last " << n + 1
+                  << " models, objective function is " << objf;
+        if (objf > best_objf) {
+          best_objf = objf;
+          best_nnet = moving_average_nnet;
+          best_num_to_combine = n + 1;
+        }
       }
-      combiner.Combine();
+    }
+    KALDI_LOG << "Combining " << best_num_to_combine
+              << " nnets, objective function changed from " << init_objf
+              << " to " << best_objf;
+
+    if (HasBatchnorm(nnet))
+      RecomputeStats(egs, &best_nnet);
 
 #if HAVE_CUDA==1
     CuDevice::Instantiate().PrintProfile();
 #endif
-      nnet = combiner.GetNnet();
-      if (HasBatchnorm(nnet))
-        RecomputeStats(egs, &nnet);
-      WriteKaldiObject(nnet, nnet_wxfilename, binary_write);
-    } else {
-      KALDI_LOG << "Copying the single input model directly to the output, "
-                << "without any combination.";
-      if (HasBatchnorm(nnet))
-        RecomputeStats(egs, &nnet);
-      WriteKaldiObject(nnet, nnet_wxfilename, binary_write);
-    }
+
+    WriteKaldiObject(best_nnet, nnet_wxfilename, binary_write);
 
     KALDI_LOG << "Finished combining neural nets, wrote model to "
               << nnet_wxfilename;
   } catch(const std::exception &e) {
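The num_to_add expression above is ceiling division, ceil(num_nnets / max_objective_evaluations), which caps how many times the objective is evaluated no matter how many models are supplied. A standard-C++ check of the resulting schedule for one invented size:

    #include <cstdio>

    int main() {
      int num_nnets = 250, max_evals = 30;  // example sizes (assumed)
      int num_to_add = (num_nnets + max_evals - 1) / max_evals;  // ceil
      int evals = 0;
      // Same trigger condition as the combination loop above.
      for (int n = 1; n < num_nnets; n++)
        if ((n - 1) % num_to_add == num_to_add - 1 || n == num_nnets - 1)
          evals++;
      std::printf("num_to_add=%d evaluations=%d (max %d)\n",
                  num_to_add, evals, max_evals);  // 9, 28, 30 here
      return 0;
    }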