Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
Revert "RNN VAD: pitch search optimizations (part 1)"
Browse files Browse the repository at this point in the history
This reverts commit 9da3e17.

Reason for revert: bug in ComputePitchPeriod48kHz()

Original change's description:
> RNN VAD: pitch search optimizations (part 1)
>
> TL;DR this CL improves efficiency and includes several code
> readability improvements mainly triggered by the comments to
> patch set #10.
>
> Highlights:
> - Split `FindBestPitchPeriods()` into 12 and 24 kHz versions
>   to hard-code the input size and simplify the 24 kHz version
> - Loop in `ComputePitchPeriod48kHz()` (new name for
>   `RefinePitchPeriod48kHz()`) removed since the lags for which
>   we need to compute the auto correlation are a few
> - `ComputePitchGainThreshold()` was only used in unit tests; it's been
>   moved into the anon ns and the test removed
>
> This CL makes `ComputePitchPeriod48kHz()` is about 10% faster (measured
> with https://webrtc-review.googlesource.com/c/src/+/191320/4/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc).
> The realtime factor has improved by about +14%.
>
> Benchmarked as follows:
> ```
> out/release/modules_unittests \
>   --gtest_filter=*RnnVadTest.DISABLED_RnnVadPerformance* \
>   --gtest_also_run_disabled_tests --logs
> ```
>
> Results:
>
>       | baseline             | this CL
> ------+----------------------+------------------------
> run 1 | 24.0231 +/- 0.591016 | 23.568 +/- 0.990788
>       | 370.06x              | 377.207x
> ------+----------------------+------------------------
> run 2 | 24.0485 +/- 0.957498 | 23.3714 +/- 0.857523
>       | 369.67x              | 380.379x
> ------+----------------------+------------------------
> run 2 | 25.4091 +/- 2.6123   | 23.709 +/- 1.04477
>       | 349.875x             | 374.963x
>
> Bug: webrtc:10480
> Change-Id: I9a3e9164b2442114b928de506c92a547c273882f
> Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/191320
> Reviewed-by: Per Åhgren <peah@webrtc.org>
> Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
> Cr-Commit-Position: refs/heads/master@{#32568}

TBR=alessiob@webrtc.org,peah@webrtc.org

No-Presubmit: true
No-Tree-Checks: true
No-Try: true
Bug: webrtc:10480
Change-Id: I2a91f4f29566f872a7dfa220b31c6c625ed075db
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/192660
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32581}
  • Loading branch information
alebzk authored and Commit Bot committed Nov 10, 2020
1 parent e6a731f commit 1b6b958
Show file tree
Hide file tree
Showing 13 changed files with 450 additions and 452 deletions.
2 changes: 1 addition & 1 deletion modules/audio_processing/agc2/rnn_vad/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ rtc_library("rnn_vad_lp_residual") {

rtc_library("rnn_vad_pitch") {
sources = [
"pitch_info.h",
"pitch_search.cc",
"pitch_search.h",
"pitch_search_internal.cc",
Expand All @@ -93,7 +94,6 @@ rtc_library("rnn_vad_pitch") {
":rnn_vad_common",
"../../../../api:array_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:gtest_prod",
"../../../../rtc_base:safe_compare",
"../../../../rtc_base:safe_conversions",
]
Expand Down
17 changes: 9 additions & 8 deletions modules/audio_processing/agc2/rnn_vad/auto_correlation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace {

constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
static_assert(1 << kAutoCorrelationFftOrder >
kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
"");

} // namespace
Expand All @@ -45,15 +45,15 @@ AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
// pitch period.
void AutoCorrelationCalculator::ComputeOnPitchBuffer(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
rtc::ArrayView<float, kNumLags12kHz> auto_corr) {
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder;
constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
static_assert(kConvolutionLength == kFrameSize20ms12kHz,
"Mismatch between pitch buffer size, frame size and maximum "
"pitch period.");
static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength,
static_assert(kFftFrameSize > kNumInvertedLags12kHz + kConvolutionLength,
"The FFT length is not sufficiently big to avoid cyclic "
"convolution errors.");
auto tmp = tmp_->GetView();
Expand All @@ -67,12 +67,13 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer(

// Compute the FFT for the sliding frames chunk. The sliding frames are
// defined as pitch_buf[i:i+kConvolutionLength] where i in
// [0, kNumLags12kHz). The chunk includes all of them, hence it is
// defined as pitch_buf[:kNumLags12kHz+kConvolutionLength].
// [0, kNumInvertedLags12kHz). The chunk includes all of them, hence it is
// defined as pitch_buf[:kNumInvertedLags12kHz+kConvolutionLength].
std::copy(pitch_buf.begin(),
pitch_buf.begin() + kConvolutionLength + kNumLags12kHz,
pitch_buf.begin() + kConvolutionLength + kNumInvertedLags12kHz,
tmp.begin());
std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f);
std::fill(tmp.begin() + kNumInvertedLags12kHz + kConvolutionLength, tmp.end(),
0.f);
fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);

// Convolve in the frequency domain.
Expand All @@ -83,7 +84,7 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer(

// Extract the auto-correlation coefficients.
std::copy(tmp.begin() + kConvolutionLength - 1,
tmp.begin() + kConvolutionLength + kNumLags12kHz - 1,
tmp.begin() + kConvolutionLength + kNumInvertedLags12kHz - 1,
auto_corr.begin());
}

Expand Down
2 changes: 1 addition & 1 deletion modules/audio_processing/agc2/rnn_vad/auto_correlation.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class AutoCorrelationCalculator {
// |auto_corr| indexes are inverted lags.
void ComputeOnPitchBuffer(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
rtc::ArrayView<float, kNumLags12kHz> auto_corr);
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr);

private:
Pffft fft_;
Expand Down
12 changes: 3 additions & 9 deletions modules/audio_processing/agc2/rnn_vad/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,7 @@ constexpr int kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, "");
static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, "");
static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
// Number of (inverted) lags during the initial pitch search phase at 24 kHz.
constexpr int kInitialNumLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
// Number of (inverted) lags during the pitch search refinement phase at 24 kHz.
constexpr int kRefineNumLags24kHz = kMaxPitch24kHz + 1;
static_assert(
kRefineNumLags24kHz > kInitialNumLags24kHz,
"The refinement step must search the pitch in an extended pitch range.");
constexpr int kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;

// 12 kHz analysis.
constexpr int kSampleRate12kHz = 12000;
Expand All @@ -53,8 +47,8 @@ constexpr int kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
constexpr int kMaxPitch12kHz = kMaxPitch24kHz / 2;
static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|,
// |kMaxPitch12kHz|] are in the range [0, |kNumLags12kHz|].
constexpr int kNumLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|].
constexpr int kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;

// 48 kHz constants.
constexpr int kMinPitch48kHz = kMinPitch24kHz * 2;
Expand Down
9 changes: 5 additions & 4 deletions modules/audio_processing/agc2/rnn_vad/features_extraction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,13 @@ bool FeaturesExtractor::CheckSilenceComputeFeatures(
ComputeLpResidual(lpc_coeffs, pitch_buf_24kHz_view_, lp_residual_view_);
// Estimate pitch on the LP-residual and write the normalized pitch period
// into the output vector (normalization based on training data stats).
pitch_period_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
feature_vector[kFeatureVectorSize - 2] = 0.01f * (pitch_period_48kHz_ - 300);
pitch_info_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
feature_vector[kFeatureVectorSize - 2] =
0.01f * (pitch_info_48kHz_.period - 300);
// Extract lagged frames (according to the estimated pitch period).
RTC_DCHECK_LE(pitch_period_48kHz_ / 2, kMaxPitch24kHz);
RTC_DCHECK_LE(pitch_info_48kHz_.period / 2, kMaxPitch24kHz);
auto lagged_frame = pitch_buf_24kHz_view_.subview(
kMaxPitch24kHz - pitch_period_48kHz_ / 2, kFrameSize20ms24kHz);
kMaxPitch24kHz - pitch_info_48kHz_.period / 2, kFrameSize20ms24kHz);
// Analyze reference and lagged frames checking if silence has been detected
// and write the feature vector.
return spectral_features_extractor_.CheckSilenceComputeFeatures(
Expand Down
3 changes: 2 additions & 1 deletion modules/audio_processing/agc2/rnn_vad/features_extraction.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "api/array_view.h"
#include "modules/audio_processing/agc2/biquad_filter.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"
#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
Expand Down Expand Up @@ -52,7 +53,7 @@ class FeaturesExtractor {
PitchEstimator pitch_estimator_;
rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame_view_;
SpectralFeaturesExtractor spectral_features_extractor_;
int pitch_period_48kHz_;
PitchInfo pitch_info_48kHz_;
};

} // namespace rnn_vad
Expand Down
29 changes: 29 additions & 0 deletions modules/audio_processing/agc2/rnn_vad/pitch_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/

#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_

namespace webrtc {
namespace rnn_vad {

// Stores pitch period and gain information. The pitch gain measures the
// strength of the pitch (the higher, the stronger).
struct PitchInfo {
PitchInfo() : period(0), gain(0.f) {}
PitchInfo(int p, float g) : period(p), gain(g) {}
int period;
float gain;
};

} // namespace rnn_vad
} // namespace webrtc

#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_
26 changes: 12 additions & 14 deletions modules/audio_processing/agc2/rnn_vad/pitch_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,37 +21,35 @@ namespace rnn_vad {
PitchEstimator::PitchEstimator()
: pitch_buf_decimated_(kBufSize12kHz),
pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz),
auto_corr_(kNumLags12kHz),
auto_corr_view_(auto_corr_.data(), kNumLags12kHz) {
auto_corr_(kNumInvertedLags12kHz),
auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) {
RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size());
RTC_DCHECK_EQ(kNumLags12kHz, auto_corr_view_.size());
RTC_DCHECK_EQ(kNumInvertedLags12kHz, auto_corr_view_.size());
}

PitchEstimator::~PitchEstimator() = default;

int PitchEstimator::Estimate(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
PitchInfo PitchEstimator::Estimate(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
// Perform the initial pitch search at 12 kHz.
Decimate2x(pitch_buffer, pitch_buf_decimated_view_);
Decimate2x(pitch_buf, pitch_buf_decimated_view_);
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
auto_corr_view_);
CandidatePitchPeriods pitch_candidates_inverted_lags =
ComputePitchPeriod12kHz(pitch_buf_decimated_view_, auto_corr_view_);
CandidatePitchPeriods pitch_candidates_inverted_lags = FindBestPitchPeriods(
auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz);
// Refine the pitch period estimation.
// The refinement is done using the pitch buffer that contains 24 kHz samples.
// Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
// to 24 kHz.
pitch_candidates_inverted_lags.best *= 2;
pitch_candidates_inverted_lags.second_best *= 2;
const int pitch_inv_lag_48kHz =
ComputePitchPeriod48kHz(pitch_buffer, pitch_candidates_inverted_lags);
RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inverted_lags);
// Look for stronger harmonics to find the final pitch period and its gain.
RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz);
last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz(
pitch_buffer,
/*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_inv_lag_48kHz,
last_pitch_48kHz_);
return last_pitch_48kHz_.period;
last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain(
pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_);
return last_pitch_48kHz_;
}

} // namespace rnn_vad
Expand Down
16 changes: 6 additions & 10 deletions modules/audio_processing/agc2/rnn_vad/pitch_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include "rtc_base/gtest_prod_util.h"

namespace webrtc {
namespace rnn_vad {
Expand All @@ -30,21 +30,17 @@ class PitchEstimator {
PitchEstimator(const PitchEstimator&) = delete;
PitchEstimator& operator=(const PitchEstimator&) = delete;
~PitchEstimator();
// Returns the estimated pitch period at 48 kHz.
int Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer);
// Estimates the pitch period and gain. Returns the pitch estimation data for
// 48 kHz.
PitchInfo Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf);

private:
FRIEND_TEST_ALL_PREFIXES(RnnVadTest, PitchSearchWithinTolerance);
float GetLastPitchStrengthForTesting() const {
return last_pitch_48kHz_.strength;
}

PitchInfo last_pitch_48kHz_{};
PitchInfo last_pitch_48kHz_;
AutoCorrelationCalculator auto_corr_calculator_;
std::vector<float> pitch_buf_decimated_;
rtc::ArrayView<float, kBufSize12kHz> pitch_buf_decimated_view_;
std::vector<float> auto_corr_;
rtc::ArrayView<float, kNumLags12kHz> auto_corr_view_;
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr_view_;
};

} // namespace rnn_vad
Expand Down
Loading

0 comments on commit 1b6b958

Please sign in to comment.