diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index fae1d5a572b..292caae3d2a 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -83,6 +83,7 @@ rtc_library("rnn_vad_lp_residual") { rtc_library("rnn_vad_pitch") { sources = [ + "pitch_info.h", "pitch_search.cc", "pitch_search.h", "pitch_search_internal.cc", @@ -93,7 +94,6 @@ rtc_library("rnn_vad_pitch") { ":rnn_vad_common", "../../../../api:array_view", "../../../../rtc_base:checks", - "../../../../rtc_base:gtest_prod", "../../../../rtc_base:safe_compare", "../../../../rtc_base:safe_conversions", ] diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc b/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc index 431c01fab39..f6a4f42fd60 100644 --- a/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc +++ b/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc @@ -20,7 +20,7 @@ namespace { constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT. static_assert(1 << kAutoCorrelationFftOrder > - kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz, + kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz, ""); } // namespace @@ -45,7 +45,7 @@ AutoCorrelationCalculator::~AutoCorrelationCalculator() = default; // pitch period. void AutoCorrelationCalculator::ComputeOnPitchBuffer( rtc::ArrayView pitch_buf, - rtc::ArrayView auto_corr) { + rtc::ArrayView auto_corr) { RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz); RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz); constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder; @@ -53,7 +53,7 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer( static_assert(kConvolutionLength == kFrameSize20ms12kHz, "Mismatch between pitch buffer size, frame size and maximum " "pitch period."); - static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength, + static_assert(kFftFrameSize > kNumInvertedLags12kHz + kConvolutionLength, "The FFT length is not sufficiently big to avoid cyclic " "convolution errors."); auto tmp = tmp_->GetView(); @@ -67,12 +67,13 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer( // Compute the FFT for the sliding frames chunk. The sliding frames are // defined as pitch_buf[i:i+kConvolutionLength] where i in - // [0, kNumLags12kHz). The chunk includes all of them, hence it is - // defined as pitch_buf[:kNumLags12kHz+kConvolutionLength]. + // [0, kNumInvertedLags12kHz). The chunk includes all of them, hence it is + // defined as pitch_buf[:kNumInvertedLags12kHz+kConvolutionLength]. std::copy(pitch_buf.begin(), - pitch_buf.begin() + kConvolutionLength + kNumLags12kHz, + pitch_buf.begin() + kConvolutionLength + kNumInvertedLags12kHz, tmp.begin()); - std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f); + std::fill(tmp.begin() + kNumInvertedLags12kHz + kConvolutionLength, tmp.end(), + 0.f); fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false); // Convolve in the frequency domain. @@ -83,7 +84,7 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer( // Extract the auto-correlation coefficients. std::copy(tmp.begin() + kConvolutionLength - 1, - tmp.begin() + kConvolutionLength + kNumLags12kHz - 1, + tmp.begin() + kConvolutionLength + kNumInvertedLags12kHz - 1, auto_corr.begin()); } diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation.h b/modules/audio_processing/agc2/rnn_vad/auto_correlation.h index d58558ca2e9..de7f453bc7a 100644 --- a/modules/audio_processing/agc2/rnn_vad/auto_correlation.h +++ b/modules/audio_processing/agc2/rnn_vad/auto_correlation.h @@ -34,7 +34,7 @@ class AutoCorrelationCalculator { // |auto_corr| indexes are inverted lags. void ComputeOnPitchBuffer( rtc::ArrayView pitch_buf, - rtc::ArrayView auto_corr); + rtc::ArrayView auto_corr); private: Pffft fft_; diff --git a/modules/audio_processing/agc2/rnn_vad/common.h b/modules/audio_processing/agc2/rnn_vad/common.h index 36b366ad1dd..d6deff15560 100644 --- a/modules/audio_processing/agc2/rnn_vad/common.h +++ b/modules/audio_processing/agc2/rnn_vad/common.h @@ -36,13 +36,7 @@ constexpr int kInitialMinPitch24kHz = 3 * kMinPitch24kHz; static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, ""); static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, ""); static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, ""); -// Number of (inverted) lags during the initial pitch search phase at 24 kHz. -constexpr int kInitialNumLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz; -// Number of (inverted) lags during the pitch search refinement phase at 24 kHz. -constexpr int kRefineNumLags24kHz = kMaxPitch24kHz + 1; -static_assert( - kRefineNumLags24kHz > kInitialNumLags24kHz, - "The refinement step must search the pitch in an extended pitch range."); +constexpr int kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz; // 12 kHz analysis. constexpr int kSampleRate12kHz = 12000; @@ -53,8 +47,8 @@ constexpr int kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2; constexpr int kMaxPitch12kHz = kMaxPitch24kHz / 2; static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, ""); // The inverted lags for the pitch interval [|kInitialMinPitch12kHz|, -// |kMaxPitch12kHz|] are in the range [0, |kNumLags12kHz|]. -constexpr int kNumLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz; +// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|]. +constexpr int kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz; // 48 kHz constants. constexpr int kMinPitch48kHz = kMinPitch24kHz * 2; diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction.cc b/modules/audio_processing/agc2/rnn_vad/features_extraction.cc index cdbbbc311d5..c207baeec09 100644 --- a/modules/audio_processing/agc2/rnn_vad/features_extraction.cc +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction.cc @@ -67,12 +67,13 @@ bool FeaturesExtractor::CheckSilenceComputeFeatures( ComputeLpResidual(lpc_coeffs, pitch_buf_24kHz_view_, lp_residual_view_); // Estimate pitch on the LP-residual and write the normalized pitch period // into the output vector (normalization based on training data stats). - pitch_period_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_); - feature_vector[kFeatureVectorSize - 2] = 0.01f * (pitch_period_48kHz_ - 300); + pitch_info_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_); + feature_vector[kFeatureVectorSize - 2] = + 0.01f * (pitch_info_48kHz_.period - 300); // Extract lagged frames (according to the estimated pitch period). - RTC_DCHECK_LE(pitch_period_48kHz_ / 2, kMaxPitch24kHz); + RTC_DCHECK_LE(pitch_info_48kHz_.period / 2, kMaxPitch24kHz); auto lagged_frame = pitch_buf_24kHz_view_.subview( - kMaxPitch24kHz - pitch_period_48kHz_ / 2, kFrameSize20ms24kHz); + kMaxPitch24kHz - pitch_info_48kHz_.period / 2, kFrameSize20ms24kHz); // Analyze reference and lagged frames checking if silence has been detected // and write the feature vector. return spectral_features_extractor_.CheckSilenceComputeFeatures( diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction.h b/modules/audio_processing/agc2/rnn_vad/features_extraction.h index e2c77d2cf8e..ce5cce1857c 100644 --- a/modules/audio_processing/agc2/rnn_vad/features_extraction.h +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction.h @@ -16,6 +16,7 @@ #include "api/array_view.h" #include "modules/audio_processing/agc2/biquad_filter.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" +#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search.h" #include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h" #include "modules/audio_processing/agc2/rnn_vad/spectral_features.h" @@ -52,7 +53,7 @@ class FeaturesExtractor { PitchEstimator pitch_estimator_; rtc::ArrayView reference_frame_view_; SpectralFeaturesExtractor spectral_features_extractor_; - int pitch_period_48kHz_; + PitchInfo pitch_info_48kHz_; }; } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_info.h b/modules/audio_processing/agc2/rnn_vad/pitch_info.h new file mode 100644 index 00000000000..c9fdd182b04 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/pitch_info.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_ + +namespace webrtc { +namespace rnn_vad { + +// Stores pitch period and gain information. The pitch gain measures the +// strength of the pitch (the higher, the stronger). +struct PitchInfo { + PitchInfo() : period(0), gain(0.f) {} + PitchInfo(int p, float g) : period(p), gain(g) {} + int period; + float gain; +}; + +} // namespace rnn_vad +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_ diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc index 9d4c5a2d817..85f67377e42 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc @@ -21,22 +21,22 @@ namespace rnn_vad { PitchEstimator::PitchEstimator() : pitch_buf_decimated_(kBufSize12kHz), pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz), - auto_corr_(kNumLags12kHz), - auto_corr_view_(auto_corr_.data(), kNumLags12kHz) { + auto_corr_(kNumInvertedLags12kHz), + auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) { RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size()); - RTC_DCHECK_EQ(kNumLags12kHz, auto_corr_view_.size()); + RTC_DCHECK_EQ(kNumInvertedLags12kHz, auto_corr_view_.size()); } PitchEstimator::~PitchEstimator() = default; -int PitchEstimator::Estimate( - rtc::ArrayView pitch_buffer) { +PitchInfo PitchEstimator::Estimate( + rtc::ArrayView pitch_buf) { // Perform the initial pitch search at 12 kHz. - Decimate2x(pitch_buffer, pitch_buf_decimated_view_); + Decimate2x(pitch_buf, pitch_buf_decimated_view_); auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_, auto_corr_view_); - CandidatePitchPeriods pitch_candidates_inverted_lags = - ComputePitchPeriod12kHz(pitch_buf_decimated_view_, auto_corr_view_); + CandidatePitchPeriods pitch_candidates_inverted_lags = FindBestPitchPeriods( + auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz); // Refine the pitch period estimation. // The refinement is done using the pitch buffer that contains 24 kHz samples. // Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12 @@ -44,14 +44,12 @@ int PitchEstimator::Estimate( pitch_candidates_inverted_lags.best *= 2; pitch_candidates_inverted_lags.second_best *= 2; const int pitch_inv_lag_48kHz = - ComputePitchPeriod48kHz(pitch_buffer, pitch_candidates_inverted_lags); + RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inverted_lags); // Look for stronger harmonics to find the final pitch period and its gain. RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz); - last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz( - pitch_buffer, - /*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_inv_lag_48kHz, - last_pitch_48kHz_); - return last_pitch_48kHz_.period; + last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain( + pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_); + return last_pitch_48kHz_; } } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.h b/modules/audio_processing/agc2/rnn_vad/pitch_search.h index 1e6b9ad7068..74133d07388 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.h @@ -17,8 +17,8 @@ #include "api/array_view.h" #include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" +#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" -#include "rtc_base/gtest_prod_util.h" namespace webrtc { namespace rnn_vad { @@ -30,21 +30,17 @@ class PitchEstimator { PitchEstimator(const PitchEstimator&) = delete; PitchEstimator& operator=(const PitchEstimator&) = delete; ~PitchEstimator(); - // Returns the estimated pitch period at 48 kHz. - int Estimate(rtc::ArrayView pitch_buffer); + // Estimates the pitch period and gain. Returns the pitch estimation data for + // 48 kHz. + PitchInfo Estimate(rtc::ArrayView pitch_buf); private: - FRIEND_TEST_ALL_PREFIXES(RnnVadTest, PitchSearchWithinTolerance); - float GetLastPitchStrengthForTesting() const { - return last_pitch_48kHz_.strength; - } - - PitchInfo last_pitch_48kHz_{}; + PitchInfo last_pitch_48kHz_; AutoCorrelationCalculator auto_corr_calculator_; std::vector pitch_buf_decimated_; rtc::ArrayView pitch_buf_decimated_view_; std::vector auto_corr_; - rtc::ArrayView auto_corr_view_; + rtc::ArrayView auto_corr_view_; }; } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc index 8179dbd965b..d782a18d2ff 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc @@ -26,88 +26,94 @@ namespace webrtc { namespace rnn_vad { namespace { -float ComputeAutoCorrelation( - int inverted_lag, - rtc::ArrayView pitch_buffer) { - RTC_DCHECK_LT(inverted_lag, kBufSize24kHz); - RTC_DCHECK_LT(inverted_lag, kRefineNumLags24kHz); - static_assert(kMaxPitch24kHz < kBufSize24kHz, ""); +// Converts a lag to an inverted lag (only for 24kHz). +int GetInvertedLag(int lag) { + RTC_DCHECK_LE(lag, kMaxPitch24kHz); + return kMaxPitch24kHz - lag; +} + +float ComputeAutoCorrelationCoeff(rtc::ArrayView pitch_buf, + int inv_lag, + int max_pitch_period) { + RTC_DCHECK_LT(inv_lag, pitch_buf.size()); + RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); + RTC_DCHECK_LE(inv_lag, max_pitch_period); // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. - return std::inner_product(pitch_buffer.begin() + kMaxPitch24kHz, - pitch_buffer.end(), - pitch_buffer.begin() + inverted_lag, 0.f); + return std::inner_product(pitch_buf.begin() + max_pitch_period, + pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f); } -// Given an auto-correlation coefficient `curr_auto_correlation` and its -// neighboring values `prev_auto_correlation` and `next_auto_correlation` -// computes a pseudo-interpolation offset to be applied to the pitch period -// associated to `curr`. The output is a lag in {-1, 0, +1}. -// TODO(bugs.webrtc.org/9076): Consider removing this method. -// `GetPitchPseudoInterpolationOffset()` it is relevant only if the spectral -// analysis works at a sample rate that is twice as that of the pitch buffer; -// In particular, it is not relevant for the estimated pitch period feature fed -// into the RNN. -int GetPitchPseudoInterpolationOffset(float prev_auto_correlation, - float curr_auto_correlation, - float next_auto_correlation) { - if ((next_auto_correlation - prev_auto_correlation) > - 0.7f * (curr_auto_correlation - prev_auto_correlation)) { - return 1; // |next_auto_correlation| is the largest auto-correlation - // coefficient. - } else if ((prev_auto_correlation - next_auto_correlation) > - 0.7f * (curr_auto_correlation - next_auto_correlation)) { - return -1; // |prev_auto_correlation| is the largest auto-correlation - // coefficient. +// Given the auto-correlation coefficients for a lag and its neighbors, computes +// a pseudo-interpolation offset to be applied to the pitch period associated to +// the central auto-correlation coefficient |lag_auto_corr|. The output is a lag +// in {-1, 0, +1}. +// TODO(bugs.webrtc.org/9076): Consider removing pseudo-i since it +// is relevant only if the spectral analysis works at a sample rate that is +// twice as that of the pitch buffer (not so important instead for the estimated +// pitch period feature fed into the RNN). +int GetPitchPseudoInterpolationOffset(float prev_auto_corr, + float lag_auto_corr, + float next_auto_corr) { + const float& a = prev_auto_corr; + const float& b = lag_auto_corr; + const float& c = next_auto_corr; + + int offset = 0; + if ((c - a) > 0.7f * (b - a)) { + offset = 1; // |c| is the largest auto-correlation coefficient. + } else if ((a - c) > 0.7f * (b - c)) { + offset = -1; // |a| is the largest auto-correlation coefficient. } - return 0; + return offset; } // Refines a pitch period |lag| encoded as lag with pseudo-interpolation. The // output sample rate is twice as that of |lag|. int PitchPseudoInterpolationLagPitchBuf( int lag, - rtc::ArrayView pitch_buffer) { + rtc::ArrayView pitch_buf) { int offset = 0; // Cannot apply pseudo-interpolation at the boundaries. if (lag > 0 && lag < kMaxPitch24kHz) { - const int inverted_lag = kMaxPitch24kHz - lag; offset = GetPitchPseudoInterpolationOffset( - ComputeAutoCorrelation(inverted_lag + 1, pitch_buffer), - ComputeAutoCorrelation(inverted_lag, pitch_buffer), - ComputeAutoCorrelation(inverted_lag - 1, pitch_buffer)); + ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1), + kMaxPitch24kHz), + ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag), + kMaxPitch24kHz), + ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1), + kMaxPitch24kHz)); } return 2 * lag + offset; } -// Refines a pitch period |inverted_lag| encoded as inverted lag with +// Refines a pitch period |inv_lag| encoded as inverted lag with // pseudo-interpolation. The output sample rate is twice as that of -// |inverted_lag|. +// |inv_lag|. int PitchPseudoInterpolationInvLagAutoCorr( - int inverted_lag, - rtc::ArrayView auto_correlation) { + int inv_lag, + rtc::ArrayView auto_corr) { int offset = 0; // Cannot apply pseudo-interpolation at the boundaries. - if (inverted_lag > 0 && inverted_lag < kInitialNumLags24kHz - 1) { + if (inv_lag > 0 && inv_lag < rtc::dchecked_cast(auto_corr.size()) - 1) { offset = GetPitchPseudoInterpolationOffset( - auto_correlation[inverted_lag + 1], auto_correlation[inverted_lag], - auto_correlation[inverted_lag - 1]); + auto_corr[inv_lag + 1], auto_corr[inv_lag], auto_corr[inv_lag - 1]); } // TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should - // be subtracted since |inverted_lag| is an inverted lag but offset is a lag. - return 2 * inverted_lag + offset; + // be subtracted since |inv_lag| is an inverted lag but offset is a lag. + return 2 * inv_lag + offset; } -// Integer multipliers used in ComputeExtendedPitchPeriod48kHz() when +// Integer multipliers used in CheckLowerPitchPeriodsAndComputePitchGain() when // looking for sub-harmonics. // The values have been chosen to serve the following algorithm. Given the // initial pitch period T, we examine whether one of its harmonics is the true // fundamental frequency. We consider T/k with k in {2, ..., 15}. For each of -// these harmonics, in addition to the pitch strength of itself, we choose one +// these harmonics, in addition to the pitch gain of itself, we choose one // multiple of its pitch period, n*T/k, to validate it (by averaging their pitch -// strengths). The multiplier n is chosen so that n*T/k is used only one time -// over all k. When for example k = 4, we should also expect a peak at 3*T/4. -// When k = 8 instead we don't want to look at 2*T/8, since we have already -// checked T/4 before. Instead, we look at T*3/8. +// gains). The multiplier n is chosen so that n*T/k is used only one time over +// all k. When for example k = 4, we should also expect a peak at 3*T/4. When +// k = 8 instead we don't want to look at 2*T/8, since we have already checked +// T/4 before. Instead, we look at T*3/8. // The array can be generate in Python as follows: // from fractions import Fraction // # Smallest positive integer not in X. @@ -124,168 +130,92 @@ int PitchPseudoInterpolationInvLagAutoCorr( constexpr std::array kSubHarmonicMultipliers = { {3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}}; -struct Range { - int min; - int max; -}; - -// Creates a pitch period interval centered in `inverted_lag` with hard-coded -// radius. Clipping is applied so that the interval is always valid for a 24 kHz -// pitch buffer. -Range CreateInvertedLagRange(int inverted_lag) { - constexpr int kRadius = 2; - return {std::max(inverted_lag - kRadius, 0), - std::min(inverted_lag + kRadius, kInitialNumLags24kHz - 1)}; -} +// Initial pitch period candidate thresholds for ComputePitchGainThreshold() for +// a sample rate of 24 kHz. Computed as [5*k*k for k in range(16)]. +constexpr std::array kInitialPitchPeriodThresholds = { + {20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}}; -// Computes the auto correlation coefficients for the inverted lags in the -// closed interval `inverted_lags`. -void ComputeAutoCorrelation( - Range inverted_lags, - rtc::ArrayView pitch_buffer, - rtc::ArrayView auto_correlation) { - RTC_DCHECK_GE(inverted_lags.min, 0); - RTC_DCHECK_LT(inverted_lags.max, auto_correlation.size()); - for (int inverted_lag = inverted_lags.min; inverted_lag <= inverted_lags.max; - ++inverted_lag) { - auto_correlation[inverted_lag] = - ComputeAutoCorrelation(inverted_lag, pitch_buffer); - } -} +} // namespace -int FindBestPitchPeriods24kHz( - rtc::ArrayView auto_correlation, - rtc::ArrayView pitch_buffer) { - static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, ""); - static_assert(kMaxPitch24kHz < kBufSize24kHz, ""); - // Initialize the sliding 20 ms frame energy. - // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. - float denominator = std::inner_product( - pitch_buffer.begin(), pitch_buffer.begin() + kFrameSize20ms24kHz + 1, - pitch_buffer.begin(), 1.f); - // Search best pitch by looking at the scaled auto-correlation. - int best_inverted_lag = 0; // Pitch period. - float best_numerator = -1.f; // Pitch strength numerator. - float best_denominator = 0.f; // Pitch strength denominator. - for (int inverted_lag = 0; inverted_lag < kInitialNumLags24kHz; - ++inverted_lag) { - // A pitch candidate must have positive correlation. - if (auto_correlation[inverted_lag] > 0.f) { - const float numerator = - auto_correlation[inverted_lag] * auto_correlation[inverted_lag]; - // Compare numerator/denominator ratios without using divisions. - if (numerator * best_denominator > best_numerator * denominator) { - best_inverted_lag = inverted_lag; - best_numerator = numerator; - best_denominator = denominator; - } - } - // Update |denominator| for the next inverted lag. - static_assert(kInitialNumLags24kHz + kFrameSize20ms24kHz < kBufSize24kHz, - ""); - const float y_old = pitch_buffer[inverted_lag]; - const float y_new = pitch_buffer[inverted_lag + kFrameSize20ms24kHz]; - denominator -= y_old * y_old; - denominator += y_new * y_new; - denominator = std::max(0.f, denominator); +void Decimate2x(rtc::ArrayView src, + rtc::ArrayView dst) { + // TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter. + static_assert(2 * dst.size() == src.size(), ""); + for (int i = 0; rtc::SafeLt(i, dst.size()); ++i) { + dst[i] = src[2 * i]; } - return best_inverted_lag; } -// Returns an alternative pitch period for `pitch_period` given a `multiplier` -// and a `divisor` of the period. -constexpr int GetAlternativePitchPeriod(int pitch_period, - int multiplier, - int divisor) { - RTC_DCHECK_GT(divisor, 0); - // Same as `round(multiplier * pitch_period / divisor)`. - return (2 * multiplier * pitch_period + divisor) / (2 * divisor); -} - -// Returns true if the alternative pitch period is stronger than the initial one -// given the last estimated pitch and the value of `period_divisor` used to -// compute the alternative pitch period via `GetAlternativePitchPeriod()`. -bool IsAlternativePitchStrongerThanInitial(PitchInfo last, - PitchInfo initial, - PitchInfo alternative, - int period_divisor) { - // Initial pitch period candidate thresholds for a sample rate of 24 kHz. - // Computed as [5*k*k for k in range(16)]. - constexpr std::array kInitialPitchPeriodThresholds = { - {20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}}; - static_assert( - kInitialPitchPeriodThresholds.size() == kSubHarmonicMultipliers.size(), - ""); - RTC_DCHECK_GE(last.period, 0); - RTC_DCHECK_GE(initial.period, 0); - RTC_DCHECK_GE(alternative.period, 0); - RTC_DCHECK_GE(period_divisor, 2); - // Compute a term that lowers the threshold when |alternative.period| is close - // to the last estimated period |last.period| - i.e., pitch tracking. - float lower_threshold_term = 0.f; - if (std::abs(alternative.period - last.period) <= 1) { - // The candidate pitch period is within 1 sample from the last one. - // Make the candidate at |alternative.period| very easy to be accepted. - lower_threshold_term = last.strength; - } else if (std::abs(alternative.period - last.period) == 2 && - initial.period > - kInitialPitchPeriodThresholds[period_divisor - 2]) { - // The candidate pitch period is 2 samples far from the last one and the - // period |initial.period| (from which |alternative.period| has been - // derived) is greater than a threshold. Make |alternative.period| easy to - // be accepted. - lower_threshold_term = 0.5f * last.strength; +float ComputePitchGainThreshold(int candidate_pitch_period, + int pitch_period_ratio, + int initial_pitch_period, + float initial_pitch_gain, + int prev_pitch_period, + float prev_pitch_gain) { + // Map arguments to more compact aliases. + const int& t1 = candidate_pitch_period; + const int& k = pitch_period_ratio; + const int& t0 = initial_pitch_period; + const float& g0 = initial_pitch_gain; + const int& t_prev = prev_pitch_period; + const float& g_prev = prev_pitch_gain; + + // Validate input. + RTC_DCHECK_GE(t1, 0); + RTC_DCHECK_GE(k, 2); + RTC_DCHECK_GE(t0, 0); + RTC_DCHECK_GE(t_prev, 0); + + // Compute a term that lowers the threshold when |t1| is close to the last + // estimated period |t_prev| - i.e., pitch tracking. + float lower_threshold_term = 0; + if (abs(t1 - t_prev) <= 1) { + // The candidate pitch period is within 1 sample from the previous one. + // Make the candidate at |t1| very easy to be accepted. + lower_threshold_term = g_prev; + } else if (abs(t1 - t_prev) == 2 && + t0 > kInitialPitchPeriodThresholds[k - 2]) { + // The candidate pitch period is 2 samples far from the previous one and the + // period |t0| (from which |t1| has been derived) is greater than a + // threshold. Make |t1| easy to be accepted. + lower_threshold_term = 0.5f * g_prev; } - // Set the threshold based on the strength of the initial estimate - // |initial.period|. Also reduce the chance of false positives caused by a - // bias towards high frequencies (originating from short-term correlations). - float threshold = - std::max(0.3f, 0.7f * initial.strength - lower_threshold_term); - if (alternative.period < 3 * kMinPitch24kHz) { + // Set the threshold based on the gain of the initial estimate |t0|. Also + // reduce the chance of false positives caused by a bias towards high + // frequencies (originating from short-term correlations). + float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term); + if (t1 < 3 * kMinPitch24kHz) { // High frequency. - threshold = std::max(0.4f, 0.85f * initial.strength - lower_threshold_term); - } else if (alternative.period < 2 * kMinPitch24kHz) { + threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term); + } else if (t1 < 2 * kMinPitch24kHz) { // Even higher frequency. - threshold = std::max(0.5f, 0.9f * initial.strength - lower_threshold_term); + threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term); } - return alternative.strength > threshold; + return threshold; } -} // namespace - -void Decimate2x(rtc::ArrayView src, - rtc::ArrayView dst) { - // TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter. - static_assert(2 * kBufSize12kHz == kBufSize24kHz, ""); - for (int i = 0; i < kBufSize12kHz; ++i) { - dst[i] = src[2 * i]; - } -} - -void ComputeSlidingFrameSquareEnergies24kHz( - rtc::ArrayView pitch_buffer, - rtc::ArrayView yy_values) { - float yy = ComputeAutoCorrelation(kMaxPitch24kHz, pitch_buffer); +void ComputeSlidingFrameSquareEnergies( + rtc::ArrayView pitch_buf, + rtc::ArrayView yy_values) { + float yy = + ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz); yy_values[0] = yy; - static_assert(kMaxPitch24kHz - (kRefineNumLags24kHz - 1) >= 0, ""); - static_assert(kMaxPitch24kHz - 1 + kFrameSize20ms24kHz < kBufSize24kHz, ""); - for (int lag = 1; lag < kRefineNumLags24kHz; ++lag) { - const int inverted_lag = kMaxPitch24kHz - lag; - const float y_old = pitch_buffer[inverted_lag + kFrameSize20ms24kHz]; - const float y_new = pitch_buffer[inverted_lag]; - yy -= y_old * y_old; - yy += y_new * y_new; + for (int i = 1; rtc::SafeLt(i, yy_values.size()); ++i) { + RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz); + RTC_DCHECK_LE(i, kMaxPitch24kHz); + const float old_coeff = pitch_buf[kMaxPitch24kHz + kFrameSize20ms24kHz - i]; + const float new_coeff = pitch_buf[kMaxPitch24kHz - i]; + yy -= old_coeff * old_coeff; + yy += new_coeff * new_coeff; yy = std::max(0.f, yy); - yy_values[lag] = yy; + yy_values[i] = yy; } } -CandidatePitchPeriods ComputePitchPeriod12kHz( - rtc::ArrayView pitch_buffer, - rtc::ArrayView auto_correlation) { - static_assert(kMaxPitch12kHz > kNumLags12kHz, ""); - static_assert(kMaxPitch12kHz < kBufSize12kHz, ""); - +CandidatePitchPeriods FindBestPitchPeriods( + rtc::ArrayView auto_corr, + rtc::ArrayView pitch_buf, + int max_pitch_period) { // Stores a pitch candidate period and strength information. struct PitchCandidate { // Pitch period encoded as inverted lag. @@ -301,22 +231,28 @@ CandidatePitchPeriods ComputePitchPeriod12kHz( } }; + RTC_DCHECK_GT(max_pitch_period, auto_corr.size()); + RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); + const int frame_size = + rtc::dchecked_cast(pitch_buf.size()) - max_pitch_period; + RTC_DCHECK_GT(frame_size, 0); // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. - float denominator = std::inner_product( - pitch_buffer.begin(), pitch_buffer.begin() + kFrameSize20ms12kHz + 1, - pitch_buffer.begin(), 1.f); + float yy = + std::inner_product(pitch_buf.begin(), pitch_buf.begin() + frame_size + 1, + pitch_buf.begin(), 1.f); // Search best and second best pitches by looking at the scaled // auto-correlation. + PitchCandidate candidate; PitchCandidate best; PitchCandidate second_best; second_best.period_inverted_lag = 1; - for (int inverted_lag = 0; inverted_lag < kNumLags12kHz; ++inverted_lag) { + for (int inv_lag = 0; inv_lag < rtc::dchecked_cast(auto_corr.size()); + ++inv_lag) { // A pitch candidate must have positive correlation. - if (auto_correlation[inverted_lag] > 0.f) { - PitchCandidate candidate{ - inverted_lag, - auto_correlation[inverted_lag] * auto_correlation[inverted_lag], - denominator}; + if (auto_corr[inv_lag] > 0) { + candidate.period_inverted_lag = inv_lag; + candidate.strength_numerator = auto_corr[inv_lag] * auto_corr[inv_lag]; + candidate.strength_denominator = yy; if (candidate.HasStrongerPitchThan(second_best)) { if (candidate.HasStrongerPitchThan(best)) { second_best = best; @@ -327,144 +263,144 @@ CandidatePitchPeriods ComputePitchPeriod12kHz( } } // Update |squared_energy_y| for the next inverted lag. - const float y_old = pitch_buffer[inverted_lag]; - const float y_new = pitch_buffer[inverted_lag + kFrameSize20ms12kHz]; - denominator -= y_old * y_old; - denominator += y_new * y_new; - denominator = std::max(0.f, denominator); + const float old_coeff = pitch_buf[inv_lag]; + const float new_coeff = pitch_buf[inv_lag + frame_size]; + yy -= old_coeff * old_coeff; + yy += new_coeff * new_coeff; + yy = std::max(0.f, yy); } return {best.period_inverted_lag, second_best.period_inverted_lag}; } -int ComputePitchPeriod48kHz( - rtc::ArrayView pitch_buffer, - CandidatePitchPeriods pitch_candidates) { +int RefinePitchPeriod48kHz( + rtc::ArrayView pitch_buf, + CandidatePitchPeriods pitch_candidates_inverted_lags) { // Compute the auto-correlation terms only for neighbors of the given pitch // candidates (similar to what is done in ComputePitchAutoCorrelation(), but // for a few lag values). - std::array auto_correlation{}; - const Range r1 = CreateInvertedLagRange(pitch_candidates.best); - const Range r2 = CreateInvertedLagRange(pitch_candidates.second_best); - RTC_DCHECK_LE(r1.min, r1.max); - RTC_DCHECK_LE(r2.min, r2.max); - if (r1.min <= r2.min && r1.max + 1 >= r2.min) { - // Overlapping or adjacent ranges (`r1` precedes `r2`). - RTC_DCHECK_LE(r1.max, r2.max); - ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation); - } else if (r1.min > r2.min && r2.max + 1 >= r1.min) { - // Overlapping or adjacent ranges (`r2` precedes `r1`). - RTC_DCHECK_LE(r2.max, r1.max); - ComputeAutoCorrelation({r2.min, r1.max}, pitch_buffer, auto_correlation); - } else { - // Disjoint ranges. - ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation); - ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation); + std::array auto_correlation; + auto_correlation.fill( + 0.f); // Zeros become ignored lags in FindBestPitchPeriods(). + auto is_neighbor = [](int i, int j) { + return ((i > j) ? (i - j) : (j - i)) <= 2; + }; + // TODO(https://crbug.com/webrtc/10480): Optimize by removing the loop. + for (int inverted_lag = 0; rtc::SafeLt(inverted_lag, auto_correlation.size()); + ++inverted_lag) { + if (is_neighbor(inverted_lag, pitch_candidates_inverted_lags.best) || + is_neighbor(inverted_lag, pitch_candidates_inverted_lags.second_best)) + auto_correlation[inverted_lag] = + ComputeAutoCorrelationCoeff(pitch_buf, inverted_lag, kMaxPitch24kHz); } // Find best pitch at 24 kHz. - const int pitch_candidate_24kHz = - FindBestPitchPeriods24kHz(auto_correlation, pitch_buffer); + const CandidatePitchPeriods pitch_candidates_24kHz = + FindBestPitchPeriods(auto_correlation, pitch_buf, kMaxPitch24kHz); // Pseudo-interpolation. - return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidate_24kHz, + return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidates_24kHz.best, auto_correlation); } -PitchInfo ComputeExtendedPitchPeriod48kHz( - rtc::ArrayView pitch_buffer, +PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( + rtc::ArrayView pitch_buf, int initial_pitch_period_48kHz, - PitchInfo last_pitch_48kHz) { + PitchInfo prev_pitch_48kHz) { RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz); RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz); - // Stores information for a refined pitch candidate. struct RefinedPitchCandidate { - int period; - float strength; - // Additional strength data used for the final estimation of the strength. + RefinedPitchCandidate() {} + RefinedPitchCandidate(int period_24kHz, float gain, float xy, float yy) + : period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {} + int period_24kHz; + // Pitch strength information. + float gain; + // Additional pitch strength information used for the final estimation of + // pitch gain. float xy; // Cross-correlation. float yy; // Auto-correlation. }; // Initialize. - std::array yy_values; - // TODO(bugs.webrtc.org/9076): Reuse values from FindBestPitchPeriods24kHz(). - ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, yy_values); + std::array yy_values; + ComputeSlidingFrameSquareEnergies(pitch_buf, + {yy_values.data(), yy_values.size()}); const float xx = yy_values[0]; - const auto pitch_strength = [](float xy, float yy, float xx) { - RTC_DCHECK_GE(xx * yy, 0.f); + // Helper lambdas. + const auto pitch_gain = [](float xy, float yy, float xx) { + RTC_DCHECK_LE(0.f, xx * yy); return xy / std::sqrt(1.f + xx * yy); }; - // Initial pitch candidate. + // Initial pitch candidate gain. RefinedPitchCandidate best_pitch; - best_pitch.period = + best_pitch.period_24kHz = std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1); - best_pitch.xy = - ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period, pitch_buffer); - best_pitch.yy = yy_values[best_pitch.period]; - best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.yy, xx); - - // 24 kHz version of the last estimated pitch and copy of the initial - // estimation. - const PitchInfo last_pitch{last_pitch_48kHz.period / 2, - last_pitch_48kHz.strength}; - const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength}; - - // Find `max_period_divisor` such that the result of - // `GetAlternativePitchPeriod(initial_pitch_period, 1, max_period_divisor)` - // equals `kMinPitch24kHz`. - const int max_period_divisor = - (2 * initial_pitch.period) / (2 * kMinPitch24kHz - 1); - for (int period_divisor = 2; period_divisor <= max_period_divisor; - ++period_divisor) { - PitchInfo alternative_pitch; - alternative_pitch.period = GetAlternativePitchPeriod( - initial_pitch.period, /*multiplier=*/1, period_divisor); - RTC_DCHECK_GE(alternative_pitch.period, kMinPitch24kHz); - // When looking at |alternative_pitch.period|, we also look at one of its + best_pitch.xy = ComputeAutoCorrelationCoeff( + pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz); + best_pitch.yy = yy_values[best_pitch.period_24kHz]; + best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx); + + // Store the initial pitch period information. + const int initial_pitch_period = best_pitch.period_24kHz; + const float initial_pitch_gain = best_pitch.gain; + + // Given the initial pitch estimation, check lower periods (i.e., harmonics). + const auto alternative_period = [](int period, int k, int n) -> int { + RTC_DCHECK_GT(k, 0); + return (2 * n * period + k) / (2 * k); // Same as round(n*period/k). + }; + // |max_k| such that alternative_period(initial_pitch_period, max_k, 1) equals + // kMinPitch24kHz. + const int max_k = (2 * initial_pitch_period) / (2 * kMinPitch24kHz - 1); + for (int k = 2; k <= max_k; ++k) { + int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1); + RTC_DCHECK_GE(candidate_pitch_period, kMinPitch24kHz); + // When looking at |candidate_pitch_period|, we also look at one of its // sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look. - // |period_divisor| == 2 is a special case since |dual_alternative_period| - // might be greater than the maximum pitch period. - int dual_alternative_period = GetAlternativePitchPeriod( - initial_pitch.period, kSubHarmonicMultipliers[period_divisor - 2], - period_divisor); - RTC_DCHECK_GT(dual_alternative_period, 0); - if (period_divisor == 2 && dual_alternative_period > kMaxPitch24kHz) { - dual_alternative_period = initial_pitch.period; + // |k| == 2 is a special case since |candidate_pitch_secondary_period| might + // be greater than the maximum pitch period. + int candidate_pitch_secondary_period = alternative_period( + initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]); + RTC_DCHECK_GT(candidate_pitch_secondary_period, 0); + if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHz) { + candidate_pitch_secondary_period = initial_pitch_period; } - RTC_DCHECK_NE(alternative_pitch.period, dual_alternative_period) + RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period) << "The lower pitch period and the additional sub-harmonic must not " "coincide."; // Compute an auto-correlation score for the primary pitch candidate - // |alternative_pitch.period| by also looking at its possible sub-harmonic - // |dual_alternative_period|. - float xy_primary_period = ComputeAutoCorrelation( - kMaxPitch24kHz - alternative_pitch.period, pitch_buffer); - float xy_secondary_period = ComputeAutoCorrelation( - kMaxPitch24kHz - dual_alternative_period, pitch_buffer); + // |candidate_pitch_period| by also looking at its possible sub-harmonic + // |candidate_pitch_secondary_period|. + float xy_primary_period = ComputeAutoCorrelationCoeff( + pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHz); + float xy_secondary_period = ComputeAutoCorrelationCoeff( + pitch_buf, GetInvertedLag(candidate_pitch_secondary_period), + kMaxPitch24kHz); float xy = 0.5f * (xy_primary_period + xy_secondary_period); - float yy = 0.5f * (yy_values[alternative_pitch.period] + - yy_values[dual_alternative_period]); - alternative_pitch.strength = pitch_strength(xy, yy, xx); + float yy = 0.5f * (yy_values[candidate_pitch_period] + + yy_values[candidate_pitch_secondary_period]); + float candidate_pitch_gain = pitch_gain(xy, yy, xx); // Maybe update best period. - if (IsAlternativePitchStrongerThanInitial( - last_pitch, initial_pitch, alternative_pitch, period_divisor)) { - best_pitch = {alternative_pitch.period, alternative_pitch.strength, xy, - yy}; + float threshold = ComputePitchGainThreshold( + candidate_pitch_period, k, initial_pitch_period, initial_pitch_gain, + prev_pitch_48kHz.period / 2, prev_pitch_48kHz.gain); + if (candidate_pitch_gain > threshold) { + best_pitch = {candidate_pitch_period, candidate_pitch_gain, xy, yy}; } } - // Final pitch strength and period. + // Final pitch gain and period. best_pitch.xy = std::max(0.f, best_pitch.xy); RTC_DCHECK_LE(0.f, best_pitch.yy); - float final_pitch_strength = (best_pitch.yy <= best_pitch.xy) - ? 1.f - : best_pitch.xy / (best_pitch.yy + 1.f); - final_pitch_strength = std::min(best_pitch.strength, final_pitch_strength); + float final_pitch_gain = (best_pitch.yy <= best_pitch.xy) + ? 1.f + : best_pitch.xy / (best_pitch.yy + 1.f); + final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain); int final_pitch_period_48kHz = std::max( kMinPitch48kHz, - PitchPseudoInterpolationLagPitchBuf(best_pitch.period, pitch_buffer)); + PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf)); - return {final_pitch_period_48kHz, final_pitch_strength}; + return {final_pitch_period_48kHz, final_pitch_gain}; } } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h index b16a2f438da..cab62865235 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h @@ -18,6 +18,7 @@ #include "api/array_view.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" +#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" namespace webrtc { namespace rnn_vad { @@ -26,78 +27,56 @@ namespace rnn_vad { void Decimate2x(rtc::ArrayView src, rtc::ArrayView dst); -// Key concepts and keywords used below in this file. -// -// The pitch estimation relies on a pitch buffer, which is an array-like data -// structured designed as follows: -// -// |....A....|.....B.....| -// -// The part on the left, named `A` contains the oldest samples, whereas `B` -// contains the most recent ones. The size of `A` corresponds to the maximum -// pitch period, that of `B` to the analysis frame size (e.g., 16 ms and 20 ms -// respectively). -// -// Pitch estimation is essentially based on the analysis of two 20 ms frames -// extracted from the pitch buffer. One frame, called `x`, is kept fixed and -// corresponds to `B` - i.e., the most recent 20 ms. The other frame, called -// `y`, is extracted from different parts of the buffer instead. -// -// The offset between `x` and `y` corresponds to a specific pitch period. -// For instance, if `y` is positioned at the beginning of the pitch buffer, then -// the cross-correlation between `x` and `y` can be used as an indication of the -// strength for the maximum pitch. -// -// Such an offset can be encoded in two ways: -// - As a lag, which is the index in the pitch buffer for the first item in `y` -// - As an inverted lag, which is the number of samples from the beginning of -// `x` and the end of `y` -// -// |---->| lag -// |....A....|.....B.....| -// |<--| inverted lag -// |.....y.....| `y` 20 ms frame -// -// The inverted lag has the advantage of being directly proportional to the -// corresponding pitch period. +// Computes a gain threshold for a candidate pitch period given the initial and +// the previous pitch period and gain estimates and the pitch period ratio used +// to derive the candidate pitch period from the initial period. +float ComputePitchGainThreshold(int candidate_pitch_period, + int pitch_period_ratio, + int initial_pitch_period, + float initial_pitch_gain, + int prev_pitch_period, + float prev_pitch_gain); -// Computes the sum of squared samples for every sliding frame `y` in the pitch -// buffer. The indexes of `yy_values` are lags. -void ComputeSlidingFrameSquareEnergies24kHz( - rtc::ArrayView pitch_buffer, - rtc::ArrayView yy_values); +// Computes the sum of squared samples for every sliding frame in the pitch +// buffer. |yy_values| indexes are lags. +// +// The pitch buffer is structured as depicted below: +// |.........|...........| +// a b +// The part on the left, named "a" contains the oldest samples, whereas "b" the +// most recent ones. The size of "a" corresponds to the maximum pitch period, +// that of "b" to the frame size (e.g., 16 ms and 20 ms respectively). +void ComputeSlidingFrameSquareEnergies( + rtc::ArrayView pitch_buf, + rtc::ArrayView yy_values); -// Top-2 pitch period candidates. Unit: number of samples - i.e., inverted lags. +// Top-2 pitch period candidates. struct CandidatePitchPeriods { int best; int second_best; }; -// Computes the candidate pitch periods at 12 kHz given a view on the 12 kHz -// pitch buffer and the auto-correlation values (having inverted lags as -// indexes). -CandidatePitchPeriods ComputePitchPeriod12kHz( - rtc::ArrayView pitch_buffer, - rtc::ArrayView auto_correlation); +// Computes the candidate pitch periods given the auto-correlation coefficients +// stored according to ComputePitchAutoCorrelation() (i.e., using inverted +// lags). The return periods are inverted lags. +CandidatePitchPeriods FindBestPitchPeriods( + rtc::ArrayView auto_corr, + rtc::ArrayView pitch_buf, + int max_pitch_period); -// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer -// and the pitch period candidates at 24 kHz (encoded as inverted lag). -int ComputePitchPeriod48kHz( - rtc::ArrayView pitch_buffer, - CandidatePitchPeriods pitch_candidates_24kHz); - -struct PitchInfo { - int period; - float strength; -}; +// Refines the pitch period estimation given the pitch buffer |pitch_buf| and +// the initial pitch period estimation |pitch_candidates_inverted_lags|. +// Returns an inverted lag at 48 kHz. +int RefinePitchPeriod48kHz( + rtc::ArrayView pitch_buf, + CandidatePitchPeriods pitch_candidates_inverted_lags); -// Computes the pitch period at 48 kHz searching in an extended pitch range -// given a view on the 24 kHz pitch buffer, the initial 48 kHz estimation -// (computed by `ComputePitchPeriod48kHz()`) and the last estimated pitch. -PitchInfo ComputeExtendedPitchPeriod48kHz( - rtc::ArrayView pitch_buffer, +// Refines the pitch period estimation and compute the pitch gain. Returns the +// refined pitch estimation data at 48 kHz. +PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( + rtc::ArrayView pitch_buf, int initial_pitch_period_48kHz, - PitchInfo last_pitch_48kHz); + PitchInfo prev_pitch_48kHz); } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index 7acb046db14..fdbee68357e 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -31,77 +31,138 @@ constexpr float kTestPitchGainsHigh = 0.75f; } // namespace +class ComputePitchGainThresholdTest + : public ::testing::Test, + public ::testing::WithParamInterface> {}; + +// Checks that the computed pitch gain is within tolerance given test input +// data. +TEST_P(ComputePitchGainThresholdTest, WithinTolerance) { + const auto params = GetParam(); + const int candidate_pitch_period = std::get<0>(params); + const int pitch_period_ratio = std::get<1>(params); + const int initial_pitch_period = std::get<2>(params); + const float initial_pitch_gain = std::get<3>(params); + const int prev_pitch_period = std::get<4>(params); + const float prev_pitch_gain = std::get<5>(params); + const float threshold = std::get<6>(params); + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + EXPECT_NEAR( + threshold, + ComputePitchGainThreshold(candidate_pitch_period, pitch_period_ratio, + initial_pitch_period, initial_pitch_gain, + prev_pitch_period, prev_pitch_gain), + 5e-7f); + } +} + +INSTANTIATE_TEST_SUITE_P( + RnnVadTest, + ComputePitchGainThresholdTest, + ::testing::Values( + std::make_tuple(31, 7, 219, 0.45649201f, 199, 0.604747f, 0.40000001f), + std::make_tuple(113, + 2, + 226, + 0.20967799f, + 219, + 0.40392199f, + 0.30000001f), + std::make_tuple(63, 2, 126, 0.210788f, 364, 0.098519f, 0.40000001f), + std::make_tuple(30, 5, 152, 0.82356697f, 149, 0.55535901f, 0.700032f), + std::make_tuple(76, 2, 151, 0.79522997f, 151, 0.82356697f, 0.675946f), + std::make_tuple(31, 5, 153, 0.85069299f, 150, 0.79073799f, 0.72308898f), + std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f))); + // Checks that the frame-wise sliding square energy function produces output // within tolerance given test input data. -TEST(RnnVadTest, ComputeSlidingFrameSquareEnergies24kHzWithinTolerance) { +TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesWithinTolerance) { PitchTestData test_data; std::array computed_output; - // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. - // FloatingPointExceptionObserver fpe_observer; - ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(), - computed_output); + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + ComputeSlidingFrameSquareEnergies(test_data.GetPitchBufView(), + computed_output); + } auto square_energies_view = test_data.GetPitchBufSquareEnergiesView(); ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()}, computed_output, 3e-2f); } // Checks that the estimated pitch period is bit-exact given test input data. -TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) { +TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) { PitchTestData test_data; std::array pitch_buf_decimated; Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); CandidatePitchPeriods pitch_candidates; - // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. - // FloatingPointExceptionObserver fpe_observer; - auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView(); - pitch_candidates = - ComputePitchPeriod12kHz(pitch_buf_decimated, auto_corr_view); + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView(); + pitch_candidates = FindBestPitchPeriods(auto_corr_view, pitch_buf_decimated, + kMaxPitch12kHz); + } EXPECT_EQ(pitch_candidates.best, 140); EXPECT_EQ(pitch_candidates.second_best, 142); } // Checks that the refined pitch period is bit-exact given test input data. -TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) { +TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) { PitchTestData test_data; // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; - EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), - /*pitch_candidates=*/{280, 284}), + EXPECT_EQ(RefinePitchPeriod48kHz(test_data.GetPitchBufView(), + /*pitch_candidates=*/{280, 284}), 560); - EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), - /*pitch_candidates=*/{260, 284}), + EXPECT_EQ(RefinePitchPeriod48kHz(test_data.GetPitchBufView(), + /*pitch_candidates=*/{260, 284}), 568); } -class ComputeExtendedPitchPeriod48kHzTest +class CheckLowerPitchPeriodsAndComputePitchGainTest : public ::testing::Test, - public ::testing::WithParamInterface< - std::tuple> { - protected: - int GetInitialPitchPeriod() const { return std::get<0>(GetParam()); } - int GetLastPitchPeriod() const { return std::get<1>(GetParam()); } - float GetLastPitchStrength() const { return std::get<2>(GetParam()); } - int GetExpectedPitchPeriod() const { return std::get<3>(GetParam()); } - float GetExpectedPitchStrength() const { return std::get<4>(GetParam()); } -}; + public ::testing::WithParamInterface> {}; // Checks that the computed pitch period is bit-exact and that the computed -// pitch strength is within tolerance given test input data. -TEST_P(ComputeExtendedPitchPeriod48kHzTest, +// pitch gain is within tolerance given test input data. +TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, PeriodBitExactnessGainWithinTolerance) { + const auto params = GetParam(); + const int initial_pitch_period = std::get<0>(params); + const int prev_pitch_period = std::get<1>(params); + const float prev_pitch_gain = std::get<2>(params); + const int expected_pitch_period = std::get<3>(params); + const float expected_pitch_gain = std::get<4>(params); PitchTestData test_data; - // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. - // FloatingPointExceptionObserver fpe_observer; - const auto computed_output = ComputeExtendedPitchPeriod48kHz( - test_data.GetPitchBufView(), GetInitialPitchPeriod(), - {GetLastPitchPeriod(), GetLastPitchStrength()}); - EXPECT_EQ(GetExpectedPitchPeriod(), computed_output.period); - EXPECT_NEAR(GetExpectedPitchStrength(), computed_output.strength, 1e-6f); + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + const auto computed_output = CheckLowerPitchPeriodsAndComputePitchGain( + test_data.GetPitchBufView(), initial_pitch_period, + {prev_pitch_period, prev_pitch_gain}); + EXPECT_EQ(expected_pitch_period, computed_output.period); + EXPECT_NEAR(expected_pitch_gain, computed_output.gain, 1e-6f); + } } INSTANTIATE_TEST_SUITE_P( RnnVadTest, - ComputeExtendedPitchPeriod48kHzTest, + CheckLowerPitchPeriodsAndComputePitchGainTest, ::testing::Values(std::make_tuple(kTestPitchPeriodsLow, kTestPitchPeriodsLow, kTestPitchGainsLow, diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc index c57c8c24dbd..fdecb928079 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc @@ -13,6 +13,7 @@ #include #include +#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. @@ -21,14 +22,15 @@ namespace webrtc { namespace rnn_vad { +namespace test { // Checks that the computed pitch period is bit-exact and that the computed // pitch gain is within tolerance given test input data. TEST(RnnVadTest, PitchSearchWithinTolerance) { - auto lp_residual_reader = test::CreateLpResidualAndPitchPeriodGainReader(); + auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader(); const int num_frames = std::min(lp_residual_reader.second, 300); // Max 3 s. std::vector lp_residual(kBufSize24kHz); - float expected_pitch_period, expected_pitch_strength; + float expected_pitch_period, expected_pitch_gain; PitchEstimator pitch_estimator; { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. @@ -37,15 +39,15 @@ TEST(RnnVadTest, PitchSearchWithinTolerance) { SCOPED_TRACE(i); lp_residual_reader.first->ReadChunk(lp_residual); lp_residual_reader.first->ReadValue(&expected_pitch_period); - lp_residual_reader.first->ReadValue(&expected_pitch_strength); - int pitch_period = + lp_residual_reader.first->ReadValue(&expected_pitch_gain); + PitchInfo pitch_info = pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz}); - EXPECT_EQ(expected_pitch_period, pitch_period); - EXPECT_NEAR(expected_pitch_strength, - pitch_estimator.GetLastPitchStrengthForTesting(), 1e-5f); + EXPECT_EQ(expected_pitch_period, pitch_info.period); + EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f); } } } +} // namespace test } // namespace rnn_vad } // namespace webrtc