From ed0c7256be1e31481d51d75ac16ff1fcc47c8bac Mon Sep 17 00:00:00 2001 From: cschuldt Date: Tue, 7 Dec 2021 09:11:52 +0100 Subject: Optimize MatchedFilter. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changing to an index for-loop (instead of using std::max_element & std::distance) tracking even & odd elements separately allows the compiler to produce code with less pipeline stall. Bug: None Change-Id: Iaa3e820a3a3b61e2eb276f0dac9106c848db1891 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/240061 Reviewed-by: Per Åhgren Commit-Queue: Christian Schuldt Cr-Commit-Position: refs/heads/main@{#35729} --- modules/audio_processing/aec3/matched_filter.cc | 47 ++++++++++++++++++---- modules/audio_processing/aec3/matched_filter.h | 3 ++ .../aec3/matched_filter_unittest.cc | 22 ++++++++++ 3 files changed, 65 insertions(+), 7 deletions(-) (limited to 'modules/audio_processing/aec3') diff --git a/modules/audio_processing/aec3/matched_filter.cc b/modules/audio_processing/aec3/matched_filter.cc index 794381cc8a..faca933856 100644 --- a/modules/audio_processing/aec3/matched_filter.cc +++ b/modules/audio_processing/aec3/matched_filter.cc @@ -308,6 +308,41 @@ void MatchedFilterCore(size_t x_start_index, } } +size_t MaxSquarePeakIndex(rtc::ArrayView h) { + if (h.size() < 2) { + return 0; + } + float max_element1 = h[0] * h[0]; + float max_element2 = h[1] * h[1]; + size_t lag_estimate1 = 0; + size_t lag_estimate2 = 1; + const size_t last_index = h.size() - 1; + // Keeping track of even & odd max elements separately typically allows the + // compiler to produce more efficient code. 
+ for (size_t k = 2; k < last_index; k += 2) { + float element1 = h[k] * h[k]; + float element2 = h[k + 1] * h[k + 1]; + if (element1 > max_element1) { + max_element1 = element1; + lag_estimate1 = k; + } + if (element2 > max_element2) { + max_element2 = element2; + lag_estimate2 = k + 1; + } + } + if (max_element2 > max_element1) { + max_element1 = max_element2; + lag_estimate1 = lag_estimate2; + } + // In case of odd h size, we have not yet checked the last element. + float last_element = h[last_index] * h[last_index]; + if (last_element > max_element1) { + return last_index; + } + return lag_estimate1; +} + } // namespace aec3 MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, @@ -400,17 +435,15 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer, } // Compute anchor for the matched filter error. - const float error_sum_anchor = - std::inner_product(y.begin(), y.end(), y.begin(), 0.f); + float error_sum_anchor = 0.0f; + for (size_t k = 0; k < y.size(); ++k) { + error_sum_anchor += y[k] * y[k]; + } // Estimate the lag in the matched filter as the distance to the portion in // the filter that contributes the most to the matched filter output. This // is detected as the peak of the matched filter. - const size_t lag_estimate = std::distance( - filters_[n].begin(), - std::max_element( - filters_[n].begin(), filters_[n].end(), - [](float a, float b) -> bool { return a * a < b * b; })); + const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]); // Update the lag estimates for the matched filter. lag_estimates_[n] = LagEstimate( diff --git a/modules/audio_processing/aec3/matched_filter.h b/modules/audio_processing/aec3/matched_filter.h index c6410ab4ee..dd4a678394 100644 --- a/modules/audio_processing/aec3/matched_filter.h +++ b/modules/audio_processing/aec3/matched_filter.h @@ -74,6 +74,9 @@ void MatchedFilterCore(size_t x_start_index, bool* filters_updated, float* error_sum); +// Find largest peak of squared values in array. 
+size_t MaxSquarePeakIndex(rtc::ArrayView h); + } // namespace aec3 // Produces recursively updated cross-correlation estimates for several signal diff --git a/modules/audio_processing/aec3/matched_filter_unittest.cc b/modules/audio_processing/aec3/matched_filter_unittest.cc index 37b51fa624..8abfb69a7a 100644 --- a/modules/audio_processing/aec3/matched_filter_unittest.cc +++ b/modules/audio_processing/aec3/matched_filter_unittest.cc @@ -176,6 +176,28 @@ TEST(MatchedFilter, TestAvx2Optimizations) { #endif +// Verifies that the (optimized) function MaxSquarePeakIndex() produces output +// equal to the corresponding std-functions. +TEST(MatchedFilter, MaxSquarePeakIndex) { + Random random_generator(42U); + constexpr int kMaxLength = 128; + constexpr int kNumIterationsPerLength = 256; + for (int length = 1; length < kMaxLength; ++length) { + std::vector y(length); + for (int i = 0; i < kNumIterationsPerLength; ++i) { + RandomizeSampleVector(&random_generator, y); + + size_t lag_from_function = MaxSquarePeakIndex(y); + size_t lag_from_std = std::distance( + y.begin(), + std::max_element(y.begin(), y.end(), [](float a, float b) -> bool { + return a * a < b * b; + })); + EXPECT_EQ(lag_from_function, lag_from_std); + } + } +} + // Verifies that the matched filter produces proper lag estimates for // artificially // delayed signals. -- cgit v1.2.3