diff --git a/src/nnet3/nnet-simple-component.cc b/src/nnet3/nnet-simple-component.cc index f05623c65b3..8bbe76840da 100644 --- a/src/nnet3/nnet-simple-component.cc +++ b/src/nnet3/nnet-simple-component.cc @@ -2798,11 +2798,19 @@ void NaturalGradientAffineComponent::ZeroStats() { } void NaturalGradientAffineComponent::Scale(BaseFloat scale) { - update_count_ *= scale; - max_change_scale_stats_ *= scale; - active_scaling_count_ *= scale; - linear_params_.Scale(scale); - bias_params_.Scale(scale); + if (scale == 0.0) { + update_count_ = 0.0; + max_change_scale_stats_ = 0.0; + active_scaling_count_ = 0.0; + linear_params_.SetZero(); + bias_params_.SetZero(); + } else { + update_count_ *= scale; + max_change_scale_stats_ *= scale; + active_scaling_count_ *= scale; + linear_params_.Scale(scale); + bias_params_.Scale(scale); + } } void NaturalGradientAffineComponent::Add(BaseFloat alpha, const Component &other_in) {