diff options
author | Lionel Koenig <lionelk@webrtc.org> | 2022-06-28 15:37:13 +0200 |
---|---|---|
committer | WebRTC LUCI CQ <webrtc-scoped@luci-project-accounts.iam.gserviceaccount.com> | 2022-06-28 15:16:03 +0000 |
commit | 8783c678a5cb74dda890e76092e1d767bb179d8c (patch) | |
tree | 2c4764a6f0b3d725bd3b5461215059c8a33d1494 /modules/audio_processing/aec3 | |
parent | 7534ebd2bf59212cce5e010dd6ed9b3bc944818e (diff) | |
download | webrtc-8783c678a5cb74dda890e76092e1d767bb179d8c.tar.gz |
delay estimator: Look for early reverberation
Look for the first echo (and not only the strongest one) on the same matched
filter.
This change is bit-exact with the previous version when `detect_pre_echo` is false.
Author: Jesús de Vicente Peña <devicentepena@webrtc.org>
Bug: webrtc:14205
Change-Id: I6782eaa1d690b0df78d00f6d425a85c951b2ca9d
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/266321
Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
Commit-Queue: Lionel Koenig <lionelk@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#37360}
Diffstat (limited to 'modules/audio_processing/aec3')
12 files changed, 950 insertions, 392 deletions
diff --git a/modules/audio_processing/aec3/BUILD.gn b/modules/audio_processing/aec3/BUILD.gn index 70d049549a..4de2d00f3e 100644 --- a/modules/audio_processing/aec3/BUILD.gn +++ b/modules/audio_processing/aec3/BUILD.gn @@ -226,6 +226,7 @@ rtc_source_set("matched_filter") { "../../../api:array_view", "../../../rtc_base/system:arch", ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } rtc_source_set("vector_math") { diff --git a/modules/audio_processing/aec3/echo_canceller3.cc b/modules/audio_processing/aec3/echo_canceller3.cc index 8e306aca56..1404a9987d 100644 --- a/modules/audio_processing/aec3/echo_canceller3.cc +++ b/modules/audio_processing/aec3/echo_canceller3.cc @@ -377,6 +377,14 @@ EchoCanceller3Config AdjustConfig(const EchoCanceller3Config& config) { false; } + if (field_trial::IsEnabled("WebRTC-Aec3DelayEstimatorDetectPreEcho")) { + adjusted_cfg.delay.detect_pre_echo = true; + } + + if (field_trial::IsDisabled("WebRTC-Aec3DelayEstimatorDetectPreEcho")) { + adjusted_cfg.delay.detect_pre_echo = false; + } + if (field_trial::IsEnabled("WebRTC-Aec3SensitiveDominantNearendActivation")) { adjusted_cfg.suppressor.dominant_nearend_detection.enr_threshold = 0.5f; } else if (field_trial::IsEnabled( diff --git a/modules/audio_processing/aec3/echo_path_delay_estimator.cc b/modules/audio_processing/aec3/echo_path_delay_estimator.cc index e64c4493f6..fc83ca2f89 100644 --- a/modules/audio_processing/aec3/echo_path_delay_estimator.cc +++ b/modules/audio_processing/aec3/echo_path_delay_estimator.cc @@ -43,10 +43,11 @@ EchoPathDelayEstimator::EchoPathDelayEstimator( : config.render_levels.poor_excitation_render_limit, config.delay.delay_estimate_smoothing, config.delay.delay_estimate_smoothing_delay_found, - config.delay.delay_candidate_detection_threshold), + config.delay.delay_candidate_detection_threshold, + config.delay.detect_pre_echo), matched_filter_lag_aggregator_(data_dumper_, matched_filter_.GetMaxFilterLag(), - 
config.delay.delay_selection_thresholds) { + config.delay) { RTC_DCHECK(data_dumper); RTC_DCHECK(down_sampling_factor_ > 0); } @@ -75,13 +76,14 @@ absl::optional<DelayEstimate> EchoPathDelayEstimator::EstimateDelay( absl::optional<DelayEstimate> aggregated_matched_filter_lag = matched_filter_lag_aggregator_.Aggregate( - matched_filter_.GetLagEstimates()); + matched_filter_.GetBestLagEstimate()); // Run clockdrift detection. if (aggregated_matched_filter_lag && (*aggregated_matched_filter_lag).quality == DelayEstimate::Quality::kRefined) - clockdrift_detector_.Update((*aggregated_matched_filter_lag).delay); + clockdrift_detector_.Update( + matched_filter_lag_aggregator_.GetDelayAtHighestPeak()); // TODO(peah): Move this logging outside of this class once EchoCanceller3 // development is done. diff --git a/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc b/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc index 13c9c1122e..810b0ae185 100644 --- a/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc +++ b/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc @@ -78,6 +78,7 @@ TEST(EchoPathDelayEstimator, DelayEstimation) { constexpr size_t kDownSamplingFactors[] = {2, 4, 8}; for (auto down_sampling_factor : kDownSamplingFactors) { EchoCanceller3Config config; + config.delay.delay_headroom_samples = 0; config.delay.down_sampling_factor = down_sampling_factor; config.delay.num_filters = 10; for (size_t delay_samples : {30, 64, 150, 200, 800, 4000}) { @@ -111,12 +112,13 @@ TEST(EchoPathDelayEstimator, DelayEstimation) { } if (estimated_delay_samples) { - // Allow estimated delay to be off by one sample in the down-sampled - // domain. + // Allow estimated delay to be off by a block as internally the delay is + // quantized with an error up to a block. 
size_t delay_ds = delay_samples / down_sampling_factor; size_t estimated_delay_ds = estimated_delay_samples->delay / down_sampling_factor; - EXPECT_NEAR(delay_ds, estimated_delay_ds, 1); + EXPECT_NEAR(delay_ds, estimated_delay_ds, + kBlockSize / down_sampling_factor); } else { ADD_FAILURE(); } diff --git a/modules/audio_processing/aec3/matched_filter.cc b/modules/audio_processing/aec3/matched_filter.cc index faca933856..c5e394ad2f 100644 --- a/modules/audio_processing/aec3/matched_filter.cc +++ b/modules/audio_processing/aec3/matched_filter.cc @@ -24,16 +24,147 @@ #include <iterator> #include <numeric> +#include "absl/types/optional.h" +#include "api/array_view.h" #include "modules/audio_processing/aec3/downsampled_render_buffer.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" +namespace { + +// Subsample rate used for computing the accumulated error. +// The implementation of some core functions depends on this constant being +// equal to 4. 
+constexpr int kAccumulatedErrorSubSampleRate = 4; + +void UpdateAccumulatedError( + const rtc::ArrayView<const float> instantaneous_accumulated_error, + const rtc::ArrayView<float> accumulated_error, + float one_over_error_sum_anchor) { + for (size_t k = 0; k < instantaneous_accumulated_error.size(); ++k) { + float error_norm = + instantaneous_accumulated_error[k] * one_over_error_sum_anchor; + if (error_norm < accumulated_error[k]) { + accumulated_error[k] = error_norm; + } else { + accumulated_error[k] += 0.01f * (error_norm - accumulated_error[k]); + } + } +} + +size_t ComputePreEchoLag(const rtc::ArrayView<float> accumulated_error, + size_t lag, + size_t alignment_shift_winner) { + size_t pre_echo_lag_estimate = lag - alignment_shift_winner; + size_t maximum_pre_echo_lag = + std::min(pre_echo_lag_estimate / kAccumulatedErrorSubSampleRate, + accumulated_error.size()); + for (size_t k = 1; k < maximum_pre_echo_lag; ++k) { + if (accumulated_error[k] < 0.5f * accumulated_error[k - 1] && + accumulated_error[k] < 0.5f) { + pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1; + break; + } + } + return pre_echo_lag_estimate + alignment_shift_winner; +} + +} // namespace + namespace webrtc { namespace aec3 { #if defined(WEBRTC_HAS_NEON) +inline float SumAllElements(float32x4_t elements) { + float32x2_t sum = vpadd_f32(vget_low_f32(elements), vget_high_f32(elements)); + sum = vpadd_f32(sum, sum); + return vget_lane_f32(sum, 0); +} + +void MatchedFilterCoreWithAccumulatedError_NEON( + size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { + const int h_size = static_cast<int>(h.size()); + const int x_size = static_cast<int>(x.size()); + RTC_DCHECK_EQ(0, h_size % 4); + std::fill(accumulated_error.begin(), 
accumulated_error.end(), 0.0f); + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + RTC_DCHECK_GT(x_size, x_start_index); + // Compute loop chunk sizes until, and after, the wraparound of the circular + // buffer for x. + const int chunk1 = + std::min(h_size, static_cast<int>(x_size - x_start_index)); + if (chunk1 != h_size) { + const int chunk2 = h_size - chunk1; + std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin()); + std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1); + } + const float* x_p = + chunk1 != h_size ? scratch_memory.data() : &x[x_start_index]; + const float* h_p = &h[0]; + float* accumulated_error_p = &accumulated_error[0]; + // Initialize values for the accumulation. + float32x4_t x2_sum_128 = vdupq_n_f32(0); + float x2_sum = 0.f; + float s = 0; + // Perform 128 bit vector operations. + const int limit_by_4 = h_size >> 2; + for (int k = limit_by_4; k > 0; + --k, h_p += 4, x_p += 4, accumulated_error_p++) { + // Load the data into 128 bit vectors. + const float32x4_t x_k = vld1q_f32(x_p); + const float32x4_t h_k = vld1q_f32(h_p); + // Compute and accumulate x * x. + x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k); + // Compute x * h + float32x4_t hk_xk_128 = vmulq_f32(h_k, x_k); + s += SumAllElements(hk_xk_128); + const float e = s - y[i]; + accumulated_error_p[0] += e * e; + } + // Combine the accumulated vector and scalar values. + x2_sum += SumAllElements(x2_sum_128); + // Compute the matched filter error. + float e = y[i] - s; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + // Update the matched filter estimate in an NLMS manner. 
+ if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + const float32x4_t alpha_128 = vmovq_n_f32(alpha); + // filter = filter + smoothing * (y - filter * x) * x / x * x. + float* h_p = &h[0]; + x_p = chunk1 != h_size ? scratch_memory.data() : &x[x_start_index]; + // Perform 128 bit vector operations. + const int limit_by_4 = h_size >> 2; + for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + // Load the data into 128 bit vectors. + float32x4_t h_k = vld1q_f32(h_p); + const float32x4_t x_k = vld1q_f32(x_p); + // Compute h = h + alpha * x. + h_k = vmlaq_f32(h_k, alpha_128, x_k); + // Store the result. + vst1q_f32(h_p, h_k); + } + *filters_updated = true; + } + x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1; + } +} + void MatchedFilterCore_NEON(size_t x_start_index, float x2_sum_threshold, float smoothing, @@ -41,11 +172,20 @@ void MatchedFilterCore_NEON(size_t x_start_index, rtc::ArrayView<const float> y, rtc::ArrayView<float> h, bool* filters_updated, - float* error_sum) { + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { const int h_size = static_cast<int>(h.size()); const int x_size = static_cast<int>(x.size()); RTC_DCHECK_EQ(0, h_size % 4); + if (compute_accumulated_error) { + return MatchedFilterCoreWithAccumulatedError_NEON( + x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated, + error_sum, accumulated_error, scratch_memory); + } + // Process for all samples in the sub-block. for (size_t i = 0; i < y.size(); ++i) { // Apply the matched filter as filter * x, and compute x * x. @@ -90,10 +230,8 @@ void MatchedFilterCore_NEON(size_t x_start_index, } // Combine the accumulated vector and scalar values. 
- float* v = reinterpret_cast<float*>(&x2_sum_128); - x2_sum += v[0] + v[1] + v[2] + v[3]; - v = reinterpret_cast<float*>(&s_128); - s += v[0] + v[1] + v[2] + v[3]; + s += SumAllElements(s_128); + x2_sum += SumAllElements(x2_sum_128); // Compute the matched filter error. float e = y[i] - s; @@ -144,6 +282,103 @@ void MatchedFilterCore_NEON(size_t x_start_index, #if defined(WEBRTC_ARCH_X86_FAMILY) +void MatchedFilterCore_AccumulatedError_SSE2( + size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { + const int h_size = static_cast<int>(h.size()); + const int x_size = static_cast<int>(x.size()); + RTC_DCHECK_EQ(0, h_size % 8); + std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f); + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + RTC_DCHECK_GT(x_size, x_start_index); + const int chunk1 = + std::min(h_size, static_cast<int>(x_size - x_start_index)); + if (chunk1 != h_size) { + const int chunk2 = h_size - chunk1; + std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin()); + std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1); + } + const float* x_p = + chunk1 != h_size ? scratch_memory.data() : &x[x_start_index]; + const float* h_p = &h[0]; + float* a_p = &accumulated_error[0]; + __m128 s_inst_128; + __m128 s_inst_128_4; + __m128 x2_sum_128 = _mm_set1_ps(0); + __m128 x2_sum_128_4 = _mm_set1_ps(0); + __m128 e_128; + float* const s_p = reinterpret_cast<float*>(&s_inst_128); + float* const s_4_p = reinterpret_cast<float*>(&s_inst_128_4); + float* const e_p = reinterpret_cast<float*>(&e_128); + float x2_sum = 0.0f; + float s_acum = 0; + // Perform 128 bit vector operations. 
+ const int limit_by_8 = h_size >> 3; + for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8, a_p += 2) { + // Load the data into 128 bit vectors. + const __m128 x_k = _mm_loadu_ps(x_p); + const __m128 h_k = _mm_loadu_ps(h_p); + const __m128 x_k_4 = _mm_loadu_ps(x_p + 4); + const __m128 h_k_4 = _mm_loadu_ps(h_p + 4); + const __m128 xx = _mm_mul_ps(x_k, x_k); + const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4); + // Compute and accumulate x * x and h * x. + x2_sum_128 = _mm_add_ps(x2_sum_128, xx); + x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4); + s_inst_128 = _mm_mul_ps(h_k, x_k); + s_inst_128_4 = _mm_mul_ps(h_k_4, x_k_4); + s_acum += s_p[0] + s_p[1] + s_p[2] + s_p[3]; + e_p[0] = s_acum - y[i]; + s_acum += s_4_p[0] + s_4_p[1] + s_4_p[2] + s_4_p[3]; + e_p[1] = s_acum - y[i]; + a_p[0] += e_p[0] * e_p[0]; + a_p[1] += e_p[1] * e_p[1]; + } + // Combine the accumulated vector and scalar values. + x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4); + float* v = reinterpret_cast<float*>(&x2_sum_128); + x2_sum += v[0] + v[1] + v[2] + v[3]; + // Compute the matched filter error. + float e = y[i] - s_acum; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + const __m128 alpha_128 = _mm_set1_ps(alpha); + // filter = filter + smoothing * (y - filter * x) * x / x * x. + float* h_p = &h[0]; + const float* x_p = + chunk1 != h_size ? scratch_memory.data() : &x[x_start_index]; + // Perform 128 bit vector operations. + const int limit_by_4 = h_size >> 2; + for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + // Load the data into 128 bit vectors. + __m128 h_k = _mm_loadu_ps(h_p); + const __m128 x_k = _mm_loadu_ps(x_p); + // Compute h = h + alpha * x. + const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k); + h_k = _mm_add_ps(h_k, alpha_x); + // Store the result. 
+ _mm_storeu_ps(h_p, h_k); + } + *filters_updated = true; + } + x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1; + } +} + void MatchedFilterCore_SSE2(size_t x_start_index, float x2_sum_threshold, float smoothing, @@ -151,19 +386,24 @@ void MatchedFilterCore_SSE2(size_t x_start_index, rtc::ArrayView<const float> y, rtc::ArrayView<float> h, bool* filters_updated, - float* error_sum) { + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { + if (compute_accumulated_error) { + return MatchedFilterCore_AccumulatedError_SSE2( + x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated, + error_sum, accumulated_error, scratch_memory); + } const int h_size = static_cast<int>(h.size()); const int x_size = static_cast<int>(x.size()); RTC_DCHECK_EQ(0, h_size % 4); - // Process for all samples in the sub-block. for (size_t i = 0; i < y.size(); ++i) { // Apply the matched filter as filter * x, and compute x * x. - RTC_DCHECK_GT(x_size, x_start_index); const float* x_p = &x[x_start_index]; const float* h_p = &h[0]; - // Initialize values for the accumulation. __m128 s_128 = _mm_set1_ps(0); __m128 s_128_4 = _mm_set1_ps(0); @@ -171,12 +411,10 @@ void MatchedFilterCore_SSE2(size_t x_start_index, __m128 x2_sum_128_4 = _mm_set1_ps(0); float x2_sum = 0.f; float s = 0; - // Compute loop chunk sizes until, and after, the wraparound of the circular // buffer for x. const int chunk1 = std::min(h_size, static_cast<int>(x_size - x_start_index)); - // Perform the loop in two chunks. const int chunk2 = h_size - chunk1; for (int limit : {chunk1, chunk2}) { @@ -198,17 +436,14 @@ void MatchedFilterCore_SSE2(size_t x_start_index, s_128 = _mm_add_ps(s_128, hx); s_128_4 = _mm_add_ps(s_128_4, hx_4); } - // Perform non-vector operations for any remaining items. 
for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) { const float x_k = *x_p; x2_sum += x_k * x_k; s += *h_p * x_k; } - x_p = &x[0]; } - // Combine the accumulated vector and scalar values. x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4); float* v = reinterpret_cast<float*>(&x2_sum_128); @@ -216,22 +451,18 @@ void MatchedFilterCore_SSE2(size_t x_start_index, s_128 = _mm_add_ps(s_128, s_128_4); v = reinterpret_cast<float*>(&s_128); s += v[0] + v[1] + v[2] + v[3]; - // Compute the matched filter error. float e = y[i] - s; const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; (*error_sum) += e * e; - // Update the matched filter estimate in an NLMS manner. if (x2_sum > x2_sum_threshold && !saturation) { RTC_DCHECK_LT(0.f, x2_sum); const float alpha = smoothing * e / x2_sum; const __m128 alpha_128 = _mm_set1_ps(alpha); - // filter = filter + smoothing * (y - filter * x) * x / x * x. float* h_p = &h[0]; x_p = &x[x_start_index]; - // Perform the loop in two chunks. for (int limit : {chunk1, chunk2}) { // Perform 128 bit vector operations. @@ -244,22 +475,17 @@ void MatchedFilterCore_SSE2(size_t x_start_index, // Compute h = h + alpha * x. const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k); h_k = _mm_add_ps(h_k, alpha_x); - // Store the result. _mm_storeu_ps(h_p, h_k); } - // Perform non-vector operations for any remaining items. for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) { *h_p += alpha * *x_p; } - x_p = &x[0]; } - *filters_updated = true; } - x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1; } } @@ -272,17 +498,35 @@ void MatchedFilterCore(size_t x_start_index, rtc::ArrayView<const float> y, rtc::ArrayView<float> h, bool* filters_updated, - float* error_sum) { + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error) { + if (compute_accumulated_error) { + std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f); + } + // Process for all samples in the sub-block. 
for (size_t i = 0; i < y.size(); ++i) { // Apply the matched filter as filter * x, and compute x * x. float x2_sum = 0.f; float s = 0; size_t x_index = x_start_index; - for (size_t k = 0; k < h.size(); ++k) { - x2_sum += x[x_index] * x[x_index]; - s += h[k] * x[x_index]; - x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; + if (compute_accumulated_error) { + for (size_t k = 0; k < h.size(); ++k) { + x2_sum += x[x_index] * x[x_index]; + s += h[k] * x[x_index]; + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; + if ((k + 1 & 0b11) == 0) { + int idx = k >> 2; + accumulated_error[idx] += (y[i] - s) * (y[i] - s); + } + } + } else { + for (size_t k = 0; k < h.size(); ++k) { + x2_sum += x[x_index] * x[x_index]; + s += h[k] * x[x_index]; + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; + } } // Compute the matched filter error. @@ -354,7 +598,8 @@ MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, float excitation_limit, float smoothing_fast, float smoothing_slow, - float matching_filter_threshold) + float matching_filter_threshold, + bool detect_pre_echo) : data_dumper_(data_dumper), optimization_(optimization), sub_block_size_(sub_block_size), @@ -362,16 +607,31 @@ MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, filters_( num_matched_filters, std::vector<float>(window_size_sub_blocks * sub_block_size_, 0.f)), - lag_estimates_(num_matched_filters), filters_offsets_(num_matched_filters, 0), excitation_limit_(excitation_limit), smoothing_fast_(smoothing_fast), smoothing_slow_(smoothing_slow), - matching_filter_threshold_(matching_filter_threshold) { + matching_filter_threshold_(matching_filter_threshold), + detect_pre_echo_(detect_pre_echo) { RTC_DCHECK(data_dumper); RTC_DCHECK_LT(0, window_size_sub_blocks); RTC_DCHECK((kBlockSize % sub_block_size) == 0); RTC_DCHECK((sub_block_size % 4) == 0); + static_assert(kAccumulatedErrorSubSampleRate == 4); + if (detect_pre_echo_) { + accumulated_error_ = std::vector<std::vector<float>>( + 
num_matched_filters, + std::vector<float>(window_size_sub_blocks * sub_block_size_ / + kAccumulatedErrorSubSampleRate, + 1.0f)); + + instantaneous_accumulated_error_ = + std::vector<float>(window_size_sub_blocks * sub_block_size_ / + kAccumulatedErrorSubSampleRate, + 0.0f); + scratch_memory_ = + std::vector<float>(window_size_sub_blocks * sub_block_size_); + } } MatchedFilter::~MatchedFilter() = default; @@ -381,9 +641,12 @@ void MatchedFilter::Reset() { std::fill(f.begin(), f.end(), 0.f); } - for (auto& l : lag_estimates_) { - l = MatchedFilter::LagEstimate(); + for (auto& e : accumulated_error_) { + std::fill(e.begin(), e.end(), 1.0f); } + + winner_lag_ = absl::nullopt; + reported_lag_estimate_ = absl::nullopt; } void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer, @@ -398,11 +661,25 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer, const float x2_sum_threshold = filters_[0].size() * excitation_limit_ * excitation_limit_; + // Compute anchor for the matched filter error. + float error_sum_anchor = 0.0f; + for (size_t k = 0; k < y.size(); ++k) { + error_sum_anchor += y[k] * y[k]; + } + // Apply all matched filters. 
+ float winner_error_sum = error_sum_anchor; + winner_lag_ = absl::nullopt; + reported_lag_estimate_ = absl::nullopt; size_t alignment_shift = 0; - for (size_t n = 0; n < filters_.size(); ++n) { + absl::optional<size_t> previous_lag_estimate; + const int num_filters = static_cast<int>(filters_.size()); + int winner_index = -1; + for (int n = 0; n < num_filters; ++n) { float error_sum = 0.f; bool filters_updated = false; + const bool compute_pre_echo = + detect_pre_echo_ && n == last_detected_best_lag_filter_; size_t x_start_index = (render_buffer.read + alignment_shift + sub_block_size_ - 1) % @@ -411,85 +688,79 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer, switch (optimization_) { #if defined(WEBRTC_ARCH_X86_FAMILY) case Aec3Optimization::kSse2: - aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold, smoothing, - render_buffer.buffer, y, filters_[n], - &filters_updated, &error_sum); + aec3::MatchedFilterCore_SSE2( + x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y, + filters_[n], &filters_updated, &error_sum, compute_pre_echo, + instantaneous_accumulated_error_, scratch_memory_); break; case Aec3Optimization::kAvx2: - aec3::MatchedFilterCore_AVX2(x_start_index, x2_sum_threshold, smoothing, - render_buffer.buffer, y, filters_[n], - &filters_updated, &error_sum); + aec3::MatchedFilterCore_AVX2( + x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y, + filters_[n], &filters_updated, &error_sum, compute_pre_echo, + instantaneous_accumulated_error_, scratch_memory_); break; #endif #if defined(WEBRTC_HAS_NEON) case Aec3Optimization::kNeon: - aec3::MatchedFilterCore_NEON(x_start_index, x2_sum_threshold, smoothing, - render_buffer.buffer, y, filters_[n], - &filters_updated, &error_sum); + aec3::MatchedFilterCore_NEON( + x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y, + filters_[n], &filters_updated, &error_sum, compute_pre_echo, + instantaneous_accumulated_error_, 
scratch_memory_); break; #endif default: aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y, filters_[n], - &filters_updated, &error_sum); - } - - // Compute anchor for the matched filter error. - float error_sum_anchor = 0.0f; - for (size_t k = 0; k < y.size(); ++k) { - error_sum_anchor += y[k] * y[k]; + &filters_updated, &error_sum, compute_pre_echo, + instantaneous_accumulated_error_); } // Estimate the lag in the matched filter as the distance to the portion in // the filter that contributes the most to the matched filter output. This // is detected as the peak of the matched filter. const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]); - - // Update the lag estimates for the matched filter. - lag_estimates_[n] = LagEstimate( - error_sum_anchor - error_sum, - (lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) && - error_sum < matching_filter_threshold_ * error_sum_anchor), - lag_estimate + alignment_shift, filters_updated); - - RTC_DCHECK_GE(10, filters_.size()); - switch (n) { - case 0: - data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]); - break; - case 1: - data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]); - break; - case 2: - data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]); - break; - case 3: - data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]); - break; - case 4: - data_dumper_->DumpRaw("aec3_correlator_4_h", filters_[4]); - break; - case 5: - data_dumper_->DumpRaw("aec3_correlator_5_h", filters_[5]); - break; - case 6: - data_dumper_->DumpRaw("aec3_correlator_6_h", filters_[6]); - break; - case 7: - data_dumper_->DumpRaw("aec3_correlator_7_h", filters_[7]); - break; - case 8: - data_dumper_->DumpRaw("aec3_correlator_8_h", filters_[8]); - break; - case 9: - data_dumper_->DumpRaw("aec3_correlator_9_h", filters_[9]); - break; - default: - RTC_DCHECK_NOTREACHED(); + const bool reliable = + lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) && + 
error_sum < matching_filter_threshold_ * error_sum_anchor; + + // Find the best estimate + const size_t lag = lag_estimate + alignment_shift; + if (filters_updated && reliable && error_sum < winner_error_sum) { + winner_error_sum = error_sum; + winner_index = n; + // In case that 2 matched filters return the same winner candidate + // (overlap region), the one with the smaller index is chosen in order + // to search for pre-echoes. + if (previous_lag_estimate && previous_lag_estimate == lag) { + winner_lag_ = previous_lag_estimate; + winner_index = n - 1; + } else { + winner_lag_ = lag; + } } - + previous_lag_estimate = lag; alignment_shift += filter_intra_lag_shift_; } + + if (winner_index != -1) { + RTC_DCHECK(winner_lag_.has_value()); + reported_lag_estimate_ = + LagEstimate(winner_lag_.value(), /*pre_echo_lag=*/winner_lag_.value()); + if (detect_pre_echo_ && last_detected_best_lag_filter_ == winner_index) { + if (error_sum_anchor > 30.0f * 30.0f * y.size()) { + UpdateAccumulatedError(instantaneous_accumulated_error_, + accumulated_error_[winner_index], + 1.0f / error_sum_anchor); + } + reported_lag_estimate_->pre_echo_lag = ComputePreEchoLag( + accumulated_error_[winner_index], winner_lag_.value(), + winner_index * filter_intra_lag_shift_ /*alignment_shift_winner*/); + } + last_detected_best_lag_filter_ = winner_index; + } + if (ApmDataDumper::IsAvailable()) { + Dump(); + } } void MatchedFilter::LogFilterProperties(int sample_rate_hz, @@ -510,4 +781,27 @@ void MatchedFilter::LogFilterProperties(int sample_rate_hz, } } +void MatchedFilter::Dump() { + for (size_t n = 0; n < filters_.size(); ++n) { + const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]); + std::string dumper_filter = "aec3_correlator_" + std::to_string(n) + "_h"; + data_dumper_->DumpRaw(dumper_filter.c_str(), filters_[n]); + std::string dumper_lag = "aec3_correlator_lag_" + std::to_string(n); + data_dumper_->DumpRaw(dumper_lag.c_str(), + lag_estimate + n * filter_intra_lag_shift_); + 
if (detect_pre_echo_) { + std::string dumper_error = + "aec3_correlator_error_" + std::to_string(n) + "_h"; + data_dumper_->DumpRaw(dumper_error.c_str(), accumulated_error_[n]); + + size_t pre_echo_lag = ComputePreEchoLag( + accumulated_error_[n], lag_estimate + n * filter_intra_lag_shift_, + n * filter_intra_lag_shift_); + std::string dumper_pre_lag = + "aec3_correlator_pre_echo_lag_" + std::to_string(n); + data_dumper_->DumpRaw(dumper_pre_lag.c_str(), pre_echo_lag); + } + } +} + } // namespace webrtc diff --git a/modules/audio_processing/aec3/matched_filter.h b/modules/audio_processing/aec3/matched_filter.h index dd4a678394..760d5e39fd 100644 --- a/modules/audio_processing/aec3/matched_filter.h +++ b/modules/audio_processing/aec3/matched_filter.h @@ -15,6 +15,7 @@ #include <vector> +#include "absl/types/optional.h" #include "api/array_view.h" #include "modules/audio_processing/aec3/aec3_common.h" #include "rtc_base/system/arch.h" @@ -36,7 +37,10 @@ void MatchedFilterCore_NEON(size_t x_start_index, rtc::ArrayView<const float> y, rtc::ArrayView<float> h, bool* filters_updated, - float* error_sum); + float* error_sum, + bool compute_accumulation_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory); #endif @@ -50,7 +54,10 @@ void MatchedFilterCore_SSE2(size_t x_start_index, rtc::ArrayView<const float> y, rtc::ArrayView<float> h, bool* filters_updated, - float* error_sum); + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory); // Filter core for the matched filter that is optimized for AVX2. 
void MatchedFilterCore_AVX2(size_t x_start_index, @@ -60,7 +67,10 @@ void MatchedFilterCore_AVX2(size_t x_start_index, rtc::ArrayView<const float> y, rtc::ArrayView<float> h, bool* filters_updated, - float* error_sum); + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory); #endif @@ -72,7 +82,9 @@ void MatchedFilterCore(size_t x_start_index, rtc::ArrayView<const float> y, rtc::ArrayView<float> h, bool* filters_updated, - float* error_sum); + float* error_sum, + bool compute_accumulation_error, + rtc::ArrayView<float> accumulated_error); // Find largest peak of squared values in array. size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h); @@ -87,13 +99,10 @@ class MatchedFilter { // shift. struct LagEstimate { LagEstimate() = default; - LagEstimate(float accuracy, bool reliable, size_t lag, bool updated) - : accuracy(accuracy), reliable(reliable), lag(lag), updated(updated) {} - - float accuracy = 0.f; - bool reliable = false; + LagEstimate(size_t lag, size_t pre_echo_lag) + : lag(lag), pre_echo_lag(pre_echo_lag) {} size_t lag = 0; - bool updated = false; + size_t pre_echo_lag = 0; }; MatchedFilter(ApmDataDumper* data_dumper, @@ -105,7 +114,8 @@ class MatchedFilter { float excitation_limit, float smoothing_fast, float smoothing_slow, - float matching_filter_threshold); + float matching_filter_threshold, + bool detect_pre_echo); MatchedFilter() = delete; MatchedFilter(const MatchedFilter&) = delete; @@ -122,8 +132,8 @@ class MatchedFilter { void Reset(); // Returns the current lag estimates. - rtc::ArrayView<const MatchedFilter::LagEstimate> GetLagEstimates() const { - return lag_estimates_; + absl::optional<const MatchedFilter::LagEstimate> GetBestLagEstimate() const { + return reported_lag_estimate_; } // Returns the maximum filter lag. 
@@ -137,17 +147,25 @@ class MatchedFilter { size_t downsampling_factor) const; private: + void Dump(); + ApmDataDumper* const data_dumper_; const Aec3Optimization optimization_; const size_t sub_block_size_; const size_t filter_intra_lag_shift_; std::vector<std::vector<float>> filters_; - std::vector<LagEstimate> lag_estimates_; + std::vector<std::vector<float>> accumulated_error_; + std::vector<float> instantaneous_accumulated_error_; + std::vector<float> scratch_memory_; + absl::optional<MatchedFilter::LagEstimate> reported_lag_estimate_; + absl::optional<size_t> winner_lag_; + int last_detected_best_lag_filter_ = -1; std::vector<size_t> filters_offsets_; const float excitation_limit_; const float smoothing_fast_; const float smoothing_slow_; const float matching_filter_threshold_; + const bool detect_pre_echo_; }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/matched_filter_avx2.cc b/modules/audio_processing/aec3/matched_filter_avx2.cc index 8b7010f1dc..8c2ffcbd1e 100644 --- a/modules/audio_processing/aec3/matched_filter_avx2.cc +++ b/modules/audio_processing/aec3/matched_filter_avx2.cc @@ -8,15 +8,134 @@ * be found in the AUTHORS file in the root of the source tree. 
*/ -#include "modules/audio_processing/aec3/matched_filter.h" - #include <immintrin.h> +#include "modules/audio_processing/aec3/matched_filter.h" #include "rtc_base/checks.h" namespace webrtc { namespace aec3 { +// Let ha denote the horizontal of a, and hb the horizontal sum of b +// returns [ha, hb, ha, hb] +inline __m128 hsum_ab(__m256 a, __m256 b) { + __m256 s_256 = _mm256_hadd_ps(a, b); + const __m256i mask = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + s_256 = _mm256_permutevar8x32_ps(s_256, mask); + __m128 s = _mm_hadd_ps(_mm256_extractf128_ps(s_256, 0), + _mm256_extractf128_ps(s_256, 1)); + s = _mm_hadd_ps(s, s); + return s; +} + +void MatchedFilterCore_AccumulatedError_AVX2( + size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { + const int h_size = static_cast<int>(h.size()); + const int x_size = static_cast<int>(x.size()); + RTC_DCHECK_EQ(0, h_size % 16); + std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f); + + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + RTC_DCHECK_GT(x_size, x_start_index); + const int chunk1 = + std::min(h_size, static_cast<int>(x_size - x_start_index)); + if (chunk1 != h_size) { + const int chunk2 = h_size - chunk1; + std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin()); + std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1); + } + const float* x_p = + chunk1 != h_size ? 
scratch_memory.data() : &x[x_start_index]; + const float* h_p = &h[0]; + float* a_p = &accumulated_error[0]; + __m256 s_inst_hadd_256; + __m256 s_inst_256; + __m256 s_inst_256_8; + __m256 x2_sum_256 = _mm256_set1_ps(0); + __m256 x2_sum_256_8 = _mm256_set1_ps(0); + __m128 e_128; + float x2_sum = 0.0f; + float s_acum = 0; + const int limit_by_16 = h_size >> 4; + for (int k = limit_by_16; k > 0; --k, h_p += 16, x_p += 16, a_p += 4) { + // Load the data into 256 bit vectors. + __m256 x_k = _mm256_loadu_ps(x_p); + __m256 h_k = _mm256_loadu_ps(h_p); + __m256 x_k_8 = _mm256_loadu_ps(x_p + 8); + __m256 h_k_8 = _mm256_loadu_ps(h_p + 8); + // Compute and accumulate x * x and h * x. + x2_sum_256 = _mm256_fmadd_ps(x_k, x_k, x2_sum_256); + x2_sum_256_8 = _mm256_fmadd_ps(x_k_8, x_k_8, x2_sum_256_8); + s_inst_256 = _mm256_mul_ps(h_k, x_k); + s_inst_256_8 = _mm256_mul_ps(h_k_8, x_k_8); + s_inst_hadd_256 = _mm256_hadd_ps(s_inst_256, s_inst_256_8); + s_inst_hadd_256 = _mm256_hadd_ps(s_inst_hadd_256, s_inst_hadd_256); + s_acum += s_inst_hadd_256[0]; + e_128[0] = s_acum - y[i]; + s_acum += s_inst_hadd_256[4]; + e_128[1] = s_acum - y[i]; + s_acum += s_inst_hadd_256[1]; + e_128[2] = s_acum - y[i]; + s_acum += s_inst_hadd_256[5]; + e_128[3] = s_acum - y[i]; + + __m128 accumulated_error = _mm_load_ps(a_p); + accumulated_error = _mm_fmadd_ps(e_128, e_128, accumulated_error); + _mm_storeu_ps(a_p, accumulated_error); + } + // Sum components together. + x2_sum_256 = _mm256_add_ps(x2_sum_256, x2_sum_256_8); + __m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0), + _mm256_extractf128_ps(x2_sum_256, 1)); + // Combine the accumulated vector and scalar values. + float* v = reinterpret_cast<float*>(&x2_sum_128); + x2_sum += v[0] + v[1] + v[2] + v[3]; + + // Compute the matched filter error. + float e = y[i] - s_acum; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + + // Update the matched filter estimate in an NLMS manner. 
+ if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + const __m256 alpha_256 = _mm256_set1_ps(alpha); + + // filter = filter + smoothing * (y - filter * x) * x / x * x. + float* h_p = &h[0]; + const float* x_p = + chunk1 != h_size ? scratch_memory.data() : &x[x_start_index]; + // Perform 256 bit vector operations. + const int limit_by_8 = h_size >> 3; + for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) { + // Load the data into 256 bit vectors. + __m256 h_k = _mm256_loadu_ps(h_p); + __m256 x_k = _mm256_loadu_ps(x_p); + // Compute h = h + alpha * x. + h_k = _mm256_fmadd_ps(x_k, alpha_256, h_k); + + // Store the result. + _mm256_storeu_ps(h_p, h_k); + } + *filters_updated = true; + } + + x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1; + } +} + void MatchedFilterCore_AVX2(size_t x_start_index, float x2_sum_threshold, float smoothing, @@ -24,7 +143,15 @@ void MatchedFilterCore_AVX2(size_t x_start_index, rtc::ArrayView<const float> y, rtc::ArrayView<float> h, bool* filters_updated, - float* error_sum) { + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { + if (compute_accumulated_error) { + return MatchedFilterCore_AccumulatedError_AVX2( + x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated, + error_sum, accumulated_error, scratch_memory); + } const int h_size = static_cast<int>(h.size()); const int x_size = static_cast<int>(x.size()); RTC_DCHECK_EQ(0, h_size % 8); @@ -81,15 +208,9 @@ void MatchedFilterCore_AVX2(size_t x_start_index, // Sum components together. 
x2_sum_256 = _mm256_add_ps(x2_sum_256, x2_sum_256_8); s_256 = _mm256_add_ps(s_256, s_256_8); - __m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0), - _mm256_extractf128_ps(x2_sum_256, 1)); - __m128 s_128 = _mm_add_ps(_mm256_extractf128_ps(s_256, 0), - _mm256_extractf128_ps(s_256, 1)); - // Combine the accumulated vector and scalar values. - float* v = reinterpret_cast<float*>(&x2_sum_128); - x2_sum += v[0] + v[1] + v[2] + v[3]; - v = reinterpret_cast<float*>(&s_128); - s += v[0] + v[1] + v[2] + v[3]; + __m128 sum = hsum_ab(x2_sum_256, s_256); + x2_sum += sum[0]; + s += sum[1]; // Compute the matched filter error. float e = y[i] - s; diff --git a/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc b/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc index 603a864b34..17f517a001 100644 --- a/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc +++ b/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc @@ -14,84 +14,148 @@ #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_minmax.h" namespace webrtc { +namespace { +int GetDownSamplingBlockSizeLog2(int down_sampling_factor) { + int down_sampling_factor_log2 = 0; + down_sampling_factor >>= 1; + while (down_sampling_factor > 0) { + down_sampling_factor_log2++; + down_sampling_factor >>= 1; + } + return static_cast<int>(kBlockSizeLog2) > down_sampling_factor_log2 + ? 
static_cast<int>(kBlockSizeLog2) - down_sampling_factor_log2 + : 0; +} +} // namespace MatchedFilterLagAggregator::MatchedFilterLagAggregator( ApmDataDumper* data_dumper, size_t max_filter_lag, - const EchoCanceller3Config::Delay::DelaySelectionThresholds& thresholds) + const EchoCanceller3Config::Delay& delay_config) : data_dumper_(data_dumper), - histogram_(max_filter_lag + 1, 0), - thresholds_(thresholds) { + thresholds_(delay_config.delay_selection_thresholds), + headroom_(static_cast<int>(delay_config.delay_headroom_samples / + delay_config.down_sampling_factor)), + highest_peak_aggregator_(max_filter_lag) { + if (delay_config.detect_pre_echo) { + pre_echo_lag_aggregator_ = std::make_unique<PreEchoLagAggregator>( + max_filter_lag, delay_config.down_sampling_factor); + } RTC_DCHECK(data_dumper); RTC_DCHECK_LE(thresholds_.initial, thresholds_.converged); - histogram_data_.fill(0); } MatchedFilterLagAggregator::~MatchedFilterLagAggregator() = default; void MatchedFilterLagAggregator::Reset(bool hard_reset) { - std::fill(histogram_.begin(), histogram_.end(), 0); - histogram_data_.fill(0); - histogram_data_index_ = 0; + highest_peak_aggregator_.Reset(); + if (pre_echo_lag_aggregator_ != nullptr) { + pre_echo_lag_aggregator_->Reset(); + } if (hard_reset) { significant_candidate_found_ = false; } } absl::optional<DelayEstimate> MatchedFilterLagAggregator::Aggregate( - rtc::ArrayView<const MatchedFilter::LagEstimate> lag_estimates) { - // Choose the strongest lag estimate as the best one. 
- float best_accuracy = 0.f; - int best_lag_estimate_index = -1; - for (size_t k = 0; k < lag_estimates.size(); ++k) { - if (lag_estimates[k].updated && lag_estimates[k].reliable) { - if (lag_estimates[k].accuracy > best_accuracy) { - best_accuracy = lag_estimates[k].accuracy; - best_lag_estimate_index = static_cast<int>(k); - } + const absl::optional<const MatchedFilter::LagEstimate>& lag_estimate) { + if (lag_estimate && pre_echo_lag_aggregator_) { + pre_echo_lag_aggregator_->Dump(data_dumper_); + pre_echo_lag_aggregator_->Aggregate( + std::max(0, static_cast<int>(lag_estimate->pre_echo_lag) - headroom_)); + } + + if (lag_estimate) { + highest_peak_aggregator_.Aggregate( + std::max(0, static_cast<int>(lag_estimate->lag) - headroom_)); + rtc::ArrayView<const int> histogram = highest_peak_aggregator_.histogram(); + int candidate = highest_peak_aggregator_.candidate(); + significant_candidate_found_ = significant_candidate_found_ || + histogram[candidate] > thresholds_.converged; + if (histogram[candidate] > thresholds_.converged || + (histogram[candidate] > thresholds_.initial && + !significant_candidate_found_)) { + DelayEstimate::Quality quality = significant_candidate_found_ + ? DelayEstimate::Quality::kRefined + : DelayEstimate::Quality::kCoarse; + int reported_delay = pre_echo_lag_aggregator_ != nullptr + ? pre_echo_lag_aggregator_->pre_echo_candidate() + : candidate; + return DelayEstimate(quality, reported_delay); } } - // TODO(peah): Remove this logging once all development is done. 
- data_dumper_->DumpRaw("aec3_echo_path_delay_estimator_best_index", - best_lag_estimate_index); - data_dumper_->DumpRaw("aec3_echo_path_delay_estimator_histogram", histogram_); + return absl::nullopt; +} - if (best_lag_estimate_index != -1) { - RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]); - RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]); - --histogram_[histogram_data_[histogram_data_index_]]; +MatchedFilterLagAggregator::HighestPeakAggregator::HighestPeakAggregator( + size_t max_filter_lag) + : histogram_(max_filter_lag + 1, 0) { + histogram_data_.fill(0); +} - histogram_data_[histogram_data_index_] = - lag_estimates[best_lag_estimate_index].lag; +void MatchedFilterLagAggregator::HighestPeakAggregator::Reset() { + std::fill(histogram_.begin(), histogram_.end(), 0); + histogram_data_.fill(0); + histogram_data_index_ = 0; +} - RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]); - RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]); - ++histogram_[histogram_data_[histogram_data_index_]]; +void MatchedFilterLagAggregator::HighestPeakAggregator::Aggregate(int lag) { + RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]); + RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]); + --histogram_[histogram_data_[histogram_data_index_]]; + histogram_data_[histogram_data_index_] = lag; + RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]); + RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]); + ++histogram_[histogram_data_[histogram_data_index_]]; + histogram_data_index_ = (histogram_data_index_ + 1) % histogram_data_.size(); + candidate_ = + std::distance(histogram_.begin(), + std::max_element(histogram_.begin(), histogram_.end())); +} - histogram_data_index_ = - (histogram_data_index_ + 1) % histogram_data_.size(); +MatchedFilterLagAggregator::PreEchoLagAggregator::PreEchoLagAggregator( + size_t max_filter_lag, + size_t down_sampling_factor) + : 
block_size_log2_(GetDownSamplingBlockSizeLog2(down_sampling_factor)), + histogram_( + ((max_filter_lag + 1) * down_sampling_factor) >> kBlockSizeLog2, + 0) { + Reset(); +} - const int candidate = - std::distance(histogram_.begin(), - std::max_element(histogram_.begin(), histogram_.end())); +void MatchedFilterLagAggregator::PreEchoLagAggregator::Reset() { + std::fill(histogram_.begin(), histogram_.end(), 0); + histogram_data_.fill(0); + histogram_data_index_ = 0; + pre_echo_candidate_ = 0; +} - significant_candidate_found_ = - significant_candidate_found_ || - histogram_[candidate] > thresholds_.converged; - if (histogram_[candidate] > thresholds_.converged || - (histogram_[candidate] > thresholds_.initial && - !significant_candidate_found_)) { - DelayEstimate::Quality quality = significant_candidate_found_ - ? DelayEstimate::Quality::kRefined - : DelayEstimate::Quality::kCoarse; - return DelayEstimate(quality, candidate); - } +void MatchedFilterLagAggregator::PreEchoLagAggregator::Aggregate( + int pre_echo_lag) { + int pre_echo_block_size = pre_echo_lag >> block_size_log2_; + RTC_DCHECK(pre_echo_block_size >= 0 && + pre_echo_block_size < static_cast<int>(histogram_.size())); + pre_echo_block_size = + rtc::SafeClamp(pre_echo_block_size, 0, histogram_.size() - 1); + if (histogram_[histogram_data_[histogram_data_index_]] > 0) { + --histogram_[histogram_data_[histogram_data_index_]]; } + histogram_data_[histogram_data_index_] = pre_echo_block_size; + ++histogram_[histogram_data_[histogram_data_index_]]; + histogram_data_index_ = (histogram_data_index_ + 1) % histogram_data_.size(); + int pre_echo_candidate_block_size = + std::distance(histogram_.begin(), + std::max_element(histogram_.begin(), histogram_.end())); + pre_echo_candidate_ = (pre_echo_candidate_block_size << block_size_log2_); +} - return absl::nullopt; +void MatchedFilterLagAggregator::PreEchoLagAggregator::Dump( + ApmDataDumper* const data_dumper) { + data_dumper->DumpRaw("aec3_pre_echo_delay_candidate", 
pre_echo_candidate_); } } // namespace webrtc diff --git a/modules/audio_processing/aec3/matched_filter_lag_aggregator.h b/modules/audio_processing/aec3/matched_filter_lag_aggregator.h index 612bd5d942..c0598bf226 100644 --- a/modules/audio_processing/aec3/matched_filter_lag_aggregator.h +++ b/modules/audio_processing/aec3/matched_filter_lag_aggregator.h @@ -26,10 +26,9 @@ class ApmDataDumper; // reliable combined lag estimate. class MatchedFilterLagAggregator { public: - MatchedFilterLagAggregator( - ApmDataDumper* data_dumper, - size_t max_filter_lag, - const EchoCanceller3Config::Delay::DelaySelectionThresholds& thresholds); + MatchedFilterLagAggregator(ApmDataDumper* data_dumper, + size_t max_filter_lag, + const EchoCanceller3Config::Delay& delay_config); MatchedFilterLagAggregator() = delete; MatchedFilterLagAggregator(const MatchedFilterLagAggregator&) = delete; @@ -43,18 +42,55 @@ class MatchedFilterLagAggregator { // Aggregates the provided lag estimates. absl::optional<DelayEstimate> Aggregate( - rtc::ArrayView<const MatchedFilter::LagEstimate> lag_estimates); + const absl::optional<const MatchedFilter::LagEstimate>& lag_estimate); // Returns whether a reliable delay estimate has been found. bool ReliableDelayFound() const { return significant_candidate_found_; } + // Returns the delay candidate that is computed by looking at the highest peak + // on the matched filters. 
+ int GetDelayAtHighestPeak() const { + return highest_peak_aggregator_.candidate(); + } + private: + class PreEchoLagAggregator { + public: + PreEchoLagAggregator(size_t max_filter_lag, size_t down_sampling_factor); + void Reset(); + void Aggregate(int pre_echo_lag); + int pre_echo_candidate() const { return pre_echo_candidate_; } + void Dump(ApmDataDumper* const data_dumper); + + private: + const int block_size_log2_; + std::array<int, 250> histogram_data_; + std::vector<int> histogram_; + int histogram_data_index_ = 0; + int pre_echo_candidate_ = 0; + }; + + class HighestPeakAggregator { + public: + explicit HighestPeakAggregator(size_t max_filter_lag); + void Reset(); + void Aggregate(int lag); + int candidate() const { return candidate_; } + rtc::ArrayView<const int> histogram() const { return histogram_; } + + private: + std::vector<int> histogram_; + std::array<int, 250> histogram_data_; + int histogram_data_index_ = 0; + int candidate_ = -1; + }; + ApmDataDumper* const data_dumper_; - std::vector<int> histogram_; - std::array<int, 250> histogram_data_; - int histogram_data_index_ = 0; bool significant_candidate_found_ = false; const EchoCanceller3Config::Delay::DelaySelectionThresholds thresholds_; + const int headroom_; + HighestPeakAggregator highest_peak_aggregator_; + std::unique_ptr<PreEchoLagAggregator> pre_echo_lag_aggregator_; }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc b/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc index 8e2a12e6c5..6804102584 100644 --- a/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc +++ b/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc @@ -27,69 +27,31 @@ constexpr size_t kNumLagsBeforeDetection = 26; } // namespace -// Verifies that the most accurate lag estimate is chosen. 
-TEST(MatchedFilterLagAggregator, MostAccurateLagChosen) { - constexpr size_t kLag1 = 5; - constexpr size_t kLag2 = 10; - ApmDataDumper data_dumper(0); - EchoCanceller3Config config; - std::vector<MatchedFilter::LagEstimate> lag_estimates(2); - MatchedFilterLagAggregator aggregator( - &data_dumper, std::max(kLag1, kLag2), - config.delay.delay_selection_thresholds); - lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag1, true); - lag_estimates[1] = MatchedFilter::LagEstimate(0.5f, true, kLag2, true); - - for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) { - aggregator.Aggregate(lag_estimates); - } - - absl::optional<DelayEstimate> aggregated_lag = - aggregator.Aggregate(lag_estimates); - EXPECT_TRUE(aggregated_lag); - EXPECT_EQ(kLag1, aggregated_lag->delay); - - lag_estimates[0] = MatchedFilter::LagEstimate(0.5f, true, kLag1, true); - lag_estimates[1] = MatchedFilter::LagEstimate(1.f, true, kLag2, true); - - for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) { - aggregated_lag = aggregator.Aggregate(lag_estimates); - EXPECT_TRUE(aggregated_lag); - EXPECT_EQ(kLag1, aggregated_lag->delay); - } - - aggregated_lag = aggregator.Aggregate(lag_estimates); - aggregated_lag = aggregator.Aggregate(lag_estimates); - EXPECT_TRUE(aggregated_lag); - EXPECT_EQ(kLag2, aggregated_lag->delay); -} - // Verifies that varying lag estimates causes lag estimates to not be deemed // reliable. 
TEST(MatchedFilterLagAggregator, LagEstimateInvarianceRequiredForAggregatedLag) { ApmDataDumper data_dumper(0); EchoCanceller3Config config; - std::vector<MatchedFilter::LagEstimate> lag_estimates(1); - MatchedFilterLagAggregator aggregator( - &data_dumper, 100, config.delay.delay_selection_thresholds); + MatchedFilterLagAggregator aggregator(&data_dumper, /*max_filter_lag=*/100, + config.delay); absl::optional<DelayEstimate> aggregated_lag; for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) { - lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, 10, true); - aggregated_lag = aggregator.Aggregate(lag_estimates); + aggregated_lag = aggregator.Aggregate( + MatchedFilter::LagEstimate(/*lag=*/10, /*pre_echo_lag=*/10)); } EXPECT_TRUE(aggregated_lag); for (size_t k = 0; k < kNumLagsBeforeDetection * 100; ++k) { - lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, k % 100, true); - aggregated_lag = aggregator.Aggregate(lag_estimates); + aggregated_lag = aggregator.Aggregate( + MatchedFilter::LagEstimate(/*lag=*/k % 100, /*pre_echo_lag=*/k % 100)); } EXPECT_FALSE(aggregated_lag); for (size_t k = 0; k < kNumLagsBeforeDetection * 100; ++k) { - lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, k % 100, true); - aggregated_lag = aggregator.Aggregate(lag_estimates); + aggregated_lag = aggregator.Aggregate( + MatchedFilter::LagEstimate(/*lag=*/k % 100, /*pre_echo_lag=*/k % 100)); EXPECT_FALSE(aggregated_lag); } } @@ -101,13 +63,11 @@ TEST(MatchedFilterLagAggregator, constexpr size_t kLag = 5; ApmDataDumper data_dumper(0); EchoCanceller3Config config; - std::vector<MatchedFilter::LagEstimate> lag_estimates(1); - MatchedFilterLagAggregator aggregator( - &data_dumper, kLag, config.delay.delay_selection_thresholds); + MatchedFilterLagAggregator aggregator(&data_dumper, /*max_filter_lag=*/kLag, + config.delay); for (size_t k = 0; k < kNumLagsBeforeDetection * 10; ++k) { - lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag, false); - 
absl::optional<DelayEstimate> aggregated_lag = - aggregator.Aggregate(lag_estimates); + absl::optional<DelayEstimate> aggregated_lag = aggregator.Aggregate( + MatchedFilter::LagEstimate(/*lag=*/kLag, /*pre_echo_lag=*/kLag)); EXPECT_FALSE(aggregated_lag); EXPECT_EQ(kLag, aggregated_lag->delay); } @@ -122,20 +82,19 @@ TEST(MatchedFilterLagAggregator, DISABLED_PersistentAggregatedLag) { ApmDataDumper data_dumper(0); EchoCanceller3Config config; std::vector<MatchedFilter::LagEstimate> lag_estimates(1); - MatchedFilterLagAggregator aggregator( - &data_dumper, std::max(kLag1, kLag2), - config.delay.delay_selection_thresholds); + MatchedFilterLagAggregator aggregator(&data_dumper, std::max(kLag1, kLag2), + config.delay); absl::optional<DelayEstimate> aggregated_lag; for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) { - lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag1, true); - aggregated_lag = aggregator.Aggregate(lag_estimates); + aggregated_lag = aggregator.Aggregate( + MatchedFilter::LagEstimate(/*lag=*/kLag1, /*pre_echo_lag=*/kLag1)); } EXPECT_TRUE(aggregated_lag); EXPECT_EQ(kLag1, aggregated_lag->delay); for (size_t k = 0; k < kNumLagsBeforeDetection * 40; ++k) { - lag_estimates[0] = MatchedFilter::LagEstimate(1.f, false, kLag2, true); - aggregated_lag = aggregator.Aggregate(lag_estimates); + aggregated_lag = aggregator.Aggregate( + MatchedFilter::LagEstimate(/*lag=*/kLag2, /*pre_echo_lag=*/kLag2)); EXPECT_TRUE(aggregated_lag); EXPECT_EQ(kLag1, aggregated_lag->delay); } @@ -146,9 +105,7 @@ TEST(MatchedFilterLagAggregator, DISABLED_PersistentAggregatedLag) { // Verifies the check for non-null data dumper. 
TEST(MatchedFilterLagAggregatorDeathTest, NullDataDumper) { EchoCanceller3Config config; - EXPECT_DEATH(MatchedFilterLagAggregator( - nullptr, 10, config.delay.delay_selection_thresholds), - ""); + EXPECT_DEATH(MatchedFilterLagAggregator(nullptr, 10, config.delay), ""); } #endif diff --git a/modules/audio_processing/aec3/matched_filter_unittest.cc b/modules/audio_processing/aec3/matched_filter_unittest.cc index 9924256f0c..b080308191 100644 --- a/modules/audio_processing/aec3/matched_filter_unittest.cc +++ b/modules/audio_processing/aec3/matched_filter_unittest.cc @@ -47,12 +47,15 @@ constexpr size_t kAlignmentShiftSubBlocks = kWindowSizeSubBlocks * 3 / 4; } // namespace +class MatchedFilterTest : public ::testing::TestWithParam<bool> {}; + #if defined(WEBRTC_HAS_NEON) // Verifies that the optimized methods for NEON are similar to their reference // counterparts. -TEST(MatchedFilter, TestNeonOptimizations) { +TEST_P(MatchedFilterTest, TestNeonOptimizations) { Random random_generator(42U); constexpr float kSmoothing = 0.7f; + const bool kComputeAccumulatederror = GetParam(); for (auto down_sampling_factor : kDownSamplingFactors) { const size_t sub_block_size = kBlockSize / down_sampling_factor; @@ -61,6 +64,10 @@ TEST(MatchedFilter, TestNeonOptimizations) { std::vector<float> y(sub_block_size); std::vector<float> h_NEON(512); std::vector<float> h(512); + std::vector<float> accumulated_error(512); + std::vector<float> accumulated_error_NEON(512); + std::vector<float> scratch_memory(512); + int x_index = 0; for (int k = 0; k < 1000; ++k) { RandomizeSampleVector(&random_generator, y); @@ -71,10 +78,13 @@ TEST(MatchedFilter, TestNeonOptimizations) { float error_sum_NEON = 0.f; MatchedFilterCore_NEON(x_index, h.size() * 150.f * 150.f, kSmoothing, x, - y, h_NEON, &filters_updated_NEON, &error_sum_NEON); + y, h_NEON, &filters_updated_NEON, &error_sum_NEON, + kComputeAccumulatederror, accumulated_error_NEON, + scratch_memory); MatchedFilterCore(x_index, h.size() * 150.f * 
150.f, kSmoothing, x, y, h, - &filters_updated, &error_sum); + &filters_updated, &error_sum, kComputeAccumulatederror, + accumulated_error); EXPECT_EQ(filters_updated, filters_updated_NEON); EXPECT_NEAR(error_sum, error_sum_NEON, error_sum / 100000.f); @@ -83,6 +93,17 @@ TEST(MatchedFilter, TestNeonOptimizations) { EXPECT_NEAR(h[j], h_NEON[j], 0.00001f); } + if (kComputeAccumulatederror) { + for (size_t j = 0; j < accumulated_error.size(); ++j) { + float difference = + std::abs(accumulated_error[j] - accumulated_error_NEON[j]); + float relative_difference = accumulated_error[j] > 0 + ? difference / accumulated_error[j] + : difference; + EXPECT_NEAR(relative_difference, 0.0f, 0.02f); + } + } + x_index = (x_index + sub_block_size) % x.size(); } } @@ -92,7 +113,8 @@ TEST(MatchedFilter, TestNeonOptimizations) { #if defined(WEBRTC_ARCH_X86_FAMILY) // Verifies that the optimized methods for SSE2 are bitexact to their reference // counterparts. -TEST(MatchedFilter, TestSse2Optimizations) { +TEST_P(MatchedFilterTest, TestSse2Optimizations) { + const bool kComputeAccumulatederror = GetParam(); bool use_sse2 = (GetCPUInfo(kSSE2) != 0); if (use_sse2) { Random random_generator(42U); @@ -104,6 +126,9 @@ TEST(MatchedFilter, TestSse2Optimizations) { std::vector<float> y(sub_block_size); std::vector<float> h_SSE2(512); std::vector<float> h(512); + std::vector<float> accumulated_error(512 / 4); + std::vector<float> accumulated_error_SSE2(512 / 4); + std::vector<float> scratch_memory(512); int x_index = 0; for (int k = 0; k < 1000; ++k) { RandomizeSampleVector(&random_generator, y); @@ -115,10 +140,12 @@ TEST(MatchedFilter, TestSse2Optimizations) { MatchedFilterCore_SSE2(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y, h_SSE2, &filters_updated_SSE2, - &error_sum_SSE2); + &error_sum_SSE2, kComputeAccumulatederror, + accumulated_error_SSE2, scratch_memory); MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y, - h, &filters_updated, &error_sum); + h, 
&filters_updated, &error_sum, + kComputeAccumulatederror, accumulated_error); EXPECT_EQ(filters_updated, filters_updated_SSE2); EXPECT_NEAR(error_sum, error_sum_SSE2, error_sum / 100000.f); @@ -127,14 +154,24 @@ TEST(MatchedFilter, TestSse2Optimizations) { EXPECT_NEAR(h[j], h_SSE2[j], 0.00001f); } + for (size_t j = 0; j < accumulated_error.size(); ++j) { + float difference = + std::abs(accumulated_error[j] - accumulated_error_SSE2[j]); + float relative_difference = accumulated_error[j] > 0 + ? difference / accumulated_error[j] + : difference; + EXPECT_NEAR(relative_difference, 0.0f, 0.00001f); + } + x_index = (x_index + sub_block_size) % x.size(); } } } } -TEST(MatchedFilter, TestAvx2Optimizations) { +TEST_P(MatchedFilterTest, TestAvx2Optimizations) { bool use_avx2 = (GetCPUInfo(kAVX2) != 0); + const bool kComputeAccumulatederror = GetParam(); if (use_avx2) { Random random_generator(42U); constexpr float kSmoothing = 0.7f; @@ -145,29 +182,36 @@ TEST(MatchedFilter, TestAvx2Optimizations) { std::vector<float> y(sub_block_size); std::vector<float> h_AVX2(512); std::vector<float> h(512); + std::vector<float> accumulated_error(512 / 4); + std::vector<float> accumulated_error_AVX2(512 / 4); + std::vector<float> scratch_memory(512); int x_index = 0; for (int k = 0; k < 1000; ++k) { RandomizeSampleVector(&random_generator, y); - bool filters_updated = false; float error_sum = 0.f; bool filters_updated_AVX2 = false; float error_sum_AVX2 = 0.f; - MatchedFilterCore_AVX2(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y, h_AVX2, &filters_updated_AVX2, - &error_sum_AVX2); - + &error_sum_AVX2, kComputeAccumulatederror, + accumulated_error_AVX2, scratch_memory); MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y, - h, &filters_updated, &error_sum); - + h, &filters_updated, &error_sum, + kComputeAccumulatederror, accumulated_error); EXPECT_EQ(filters_updated, filters_updated_AVX2); EXPECT_NEAR(error_sum, error_sum_AVX2, error_sum / 100000.f); - for (size_t j = 
0; j < h.size(); ++j) { EXPECT_NEAR(h[j], h_AVX2[j], 0.00001f); } - + for (size_t j = 0; j < accumulated_error.size(); j += 4) { + float difference = + std::abs(accumulated_error[j] - accumulated_error_AVX2[j]); + float relative_difference = accumulated_error[j] > 0 + ? difference / accumulated_error[j] + : difference; + EXPECT_NEAR(relative_difference, 0.0f, 0.00001f); + } x_index = (x_index + sub_block_size) % x.size(); } } @@ -199,9 +243,9 @@ TEST(MatchedFilter, MaxSquarePeakIndex) { } // Verifies that the matched filter produces proper lag estimates for -// artificially -// delayed signals. -TEST(MatchedFilter, LagEstimation) { +// artificially delayed signals. +TEST_P(MatchedFilterTest, LagEstimation) { + const bool kDetectPreEcho = GetParam(); Random random_generator(42U); constexpr size_t kNumChannels = 1; constexpr int kSampleRateHz = 48000; @@ -222,12 +266,12 @@ TEST(MatchedFilter, LagEstimation) { Decimator capture_decimator(down_sampling_factor); DelayBuffer<float> signal_delay_buffer(down_sampling_factor * delay_samples); - MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size, - kWindowSizeSubBlocks, kNumMatchedFilters, - kAlignmentShiftSubBlocks, 150, - config.delay.delay_estimate_smoothing, - config.delay.delay_estimate_smoothing_delay_found, - config.delay.delay_candidate_detection_threshold); + MatchedFilter filter( + &data_dumper, DetectOptimization(), sub_block_size, + kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, + 150, config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, kDetectPreEcho); std::unique_ptr<RenderDelayBuffer> render_delay_buffer( RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels)); @@ -254,62 +298,97 @@ TEST(MatchedFilter, LagEstimation) { downsampled_capture_data.data(), sub_block_size); capture_decimator.Decimate(capture[0], downsampled_capture); 
filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(), - downsampled_capture, false); + downsampled_capture, /*use_slow_smoothing=*/false); } // Obtain the lag estimates. - auto lag_estimates = filter.GetLagEstimates(); - - // Find which lag estimate should be the most accurate. - absl::optional<size_t> expected_most_accurate_lag_estimate; - size_t alignment_shift_sub_blocks = 0; - for (size_t k = 0; k < config.delay.num_filters; ++k) { - if ((alignment_shift_sub_blocks + 3 * kWindowSizeSubBlocks / 4) * - sub_block_size > - delay_samples) { - expected_most_accurate_lag_estimate = k > 0 ? k - 1 : 0; - break; - } - alignment_shift_sub_blocks += kAlignmentShiftSubBlocks; - } - ASSERT_TRUE(expected_most_accurate_lag_estimate); - - // Verify that the expected most accurate lag estimate is the most - // accurate estimate. - for (size_t k = 0; k < kNumMatchedFilters; ++k) { - if (k != *expected_most_accurate_lag_estimate && - k != (*expected_most_accurate_lag_estimate + 1)) { - EXPECT_TRUE( - lag_estimates[*expected_most_accurate_lag_estimate].accuracy > - lag_estimates[k].accuracy || - !lag_estimates[k].reliable || - !lag_estimates[*expected_most_accurate_lag_estimate].reliable); - } - } + auto lag_estimate = filter.GetBestLagEstimate(); + EXPECT_TRUE(lag_estimate.has_value()); - // Verify that all lag estimates are updated as expected for signals - // containing strong noise. - for (auto& le : lag_estimates) { - EXPECT_TRUE(le.updated); + // Verify that the expected most accurate lag estimate is correct. + if (lag_estimate.has_value()) { + EXPECT_EQ(delay_samples, lag_estimate->lag); + EXPECT_EQ(delay_samples, lag_estimate->pre_echo_lag); } + } + } +} - // Verify that the expected most accurate lag estimate is reliable. - EXPECT_TRUE( - lag_estimates[*expected_most_accurate_lag_estimate].reliable || - lag_estimates[std::min(*expected_most_accurate_lag_estimate + 1, - lag_estimates.size() - 1)] - .reliable); +// Test the pre echo estimation. 
+TEST_P(MatchedFilterTest, PreEchoEstimation) { + const bool kDetectPreEcho = GetParam(); + Random random_generator(42U); + constexpr size_t kNumChannels = 1; + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); - // Verify that the expected most accurate lag estimate is correct. - if (lag_estimates[*expected_most_accurate_lag_estimate].reliable) { - EXPECT_TRUE(delay_samples == - lag_estimates[*expected_most_accurate_lag_estimate].lag); + for (auto down_sampling_factor : kDownSamplingFactors) { + const size_t sub_block_size = kBlockSize / down_sampling_factor; + + Block render(kNumBands, kNumChannels); + std::vector<std::vector<float>> capture( + 1, std::vector<float>(kBlockSize, 0.f)); + std::vector<float> capture_with_pre_echo(kBlockSize, 0.f); + ApmDataDumper data_dumper(0); + // data_dumper.SetActivated(true); + size_t pre_echo_delay_samples = 20e-3 * 16000 / down_sampling_factor; + size_t echo_delay_samples = 50e-3 * 16000 / down_sampling_factor; + EchoCanceller3Config config; + config.delay.down_sampling_factor = down_sampling_factor; + config.delay.num_filters = kNumMatchedFilters; + Decimator capture_decimator(down_sampling_factor); + DelayBuffer<float> signal_echo_delay_buffer(down_sampling_factor * + echo_delay_samples); + DelayBuffer<float> signal_pre_echo_delay_buffer(down_sampling_factor * + pre_echo_delay_samples); + MatchedFilter filter( + &data_dumper, DetectOptimization(), sub_block_size, + kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150, + config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, kDetectPreEcho); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels)); + // Analyze the correlation between render and capture. 
+ for (size_t k = 0; k < (600 + echo_delay_samples / sub_block_size); ++k) { + for (size_t band = 0; band < kNumBands; ++band) { + for (size_t channel = 0; channel < kNumChannels; ++channel) { + RandomizeSampleVector(&random_generator, render.View(band, channel)); + } + } + signal_echo_delay_buffer.Delay(render.View(0, 0), capture[0]); + signal_pre_echo_delay_buffer.Delay(render.View(0, 0), + capture_with_pre_echo); + for (size_t k = 0; k < capture[0].size(); ++k) { + constexpr float gain_pre_echo = 0.8f; + capture[0][k] += gain_pre_echo * capture_with_pre_echo[k]; + } + render_delay_buffer->Insert(render); + if (k == 0) { + render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + std::array<float, kBlockSize> downsampled_capture_data; + rtc::ArrayView<float> downsampled_capture(downsampled_capture_data.data(), + sub_block_size); + capture_decimator.Decimate(capture[0], downsampled_capture); + filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(), + downsampled_capture, /*use_slow_smoothing=*/false); + } + // Obtain the lag estimates. + auto lag_estimate = filter.GetBestLagEstimate(); + EXPECT_TRUE(lag_estimate.has_value()); + // Verify that the expected most accurate lag estimate is correct. + if (lag_estimate.has_value()) { + EXPECT_EQ(echo_delay_samples, lag_estimate->lag); + if (kDetectPreEcho) { + // The pre echo delay is estimated in a subsampled domain and a larger + // error is allowed. + EXPECT_NEAR(pre_echo_delay_samples, lag_estimate->pre_echo_lag, 4); } else { - EXPECT_TRUE( - delay_samples == - lag_estimates[std::min(*expected_most_accurate_lag_estimate + 1, - lag_estimates.size() - 1)] - .lag); + // The pre echo delay fallback to the highest mached filter peak when + // its detection is disabled. 
+ EXPECT_EQ(echo_delay_samples, lag_estimate->pre_echo_lag); } } } @@ -317,7 +396,8 @@ TEST(MatchedFilter, LagEstimation) { // Verifies that the matched filter does not produce reliable and accurate // estimates for uncorrelated render and capture signals. -TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) { +TEST_P(MatchedFilterTest, LagNotReliableForUncorrelatedRenderAndCapture) { + const bool kDetectPreEcho = GetParam(); constexpr size_t kNumChannels = 1; constexpr int kSampleRateHz = 48000; constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); @@ -335,12 +415,12 @@ TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) { ApmDataDumper data_dumper(0); std::unique_ptr<RenderDelayBuffer> render_delay_buffer( RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels)); - MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size, - kWindowSizeSubBlocks, kNumMatchedFilters, - kAlignmentShiftSubBlocks, 150, - config.delay.delay_estimate_smoothing, - config.delay.delay_estimate_smoothing_delay_found, - config.delay.delay_candidate_detection_threshold); + MatchedFilter filter( + &data_dumper, DetectOptimization(), sub_block_size, + kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150, + config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, kDetectPreEcho); // Analyze the correlation between render and capture. for (size_t k = 0; k < 100; ++k) { @@ -352,20 +432,17 @@ TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) { false); } - // Obtain the lag estimates. - auto lag_estimates = filter.GetLagEstimates(); - EXPECT_EQ(kNumMatchedFilters, lag_estimates.size()); - - // Verify that no lag estimates are reliable. - for (auto& le : lag_estimates) { - EXPECT_FALSE(le.reliable); - } + // Obtain the best lag estimate and Verify that no lag estimates are + // reliable. 
+ auto best_lag_estimates = filter.GetBestLagEstimate(); + EXPECT_FALSE(best_lag_estimates.has_value()); } } // Verifies that the matched filter does not produce updated lag estimates for // render signals of low level. -TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) { +TEST_P(MatchedFilterTest, LagNotUpdatedForLowLevelRender) { + const bool kDetectPreEcho = GetParam(); Random random_generator(42U); constexpr size_t kNumChannels = 1; constexpr int kSampleRateHz = 48000; @@ -374,19 +451,17 @@ TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) { for (auto down_sampling_factor : kDownSamplingFactors) { const size_t sub_block_size = kBlockSize / down_sampling_factor; - std::vector<std::vector<std::vector<float>>> render( - kNumBands, std::vector<std::vector<float>>( - kNumChannels, std::vector<float>(kBlockSize, 0.f))); + Block render(kNumBands, kNumChannels); std::vector<std::vector<float>> capture( 1, std::vector<float>(kBlockSize, 0.f)); ApmDataDumper data_dumper(0); EchoCanceller3Config config; - MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size, - kWindowSizeSubBlocks, kNumMatchedFilters, - kAlignmentShiftSubBlocks, 150, - config.delay.delay_estimate_smoothing, - config.delay.delay_estimate_smoothing_delay_found, - config.delay.delay_candidate_detection_threshold); + MatchedFilter filter( + &data_dumper, DetectOptimization(), sub_block_size, + kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150, + config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, kDetectPreEcho); std::unique_ptr<RenderDelayBuffer> render_delay_buffer( RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz, kNumChannels)); @@ -394,11 +469,11 @@ TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) { // Analyze the correlation between render and capture. 
for (size_t k = 0; k < 100; ++k) { - RandomizeSampleVector(&random_generator, render[0][0]); - for (auto& render_k : render[0][0]) { + RandomizeSampleVector(&random_generator, render.View(0, 0)); + for (auto& render_k : render.View(0, 0)) { render_k *= 149.f / 32767.f; } - std::copy(render[0][0].begin(), render[0][0].end(), capture[0].begin()); + std::copy(render.begin(0, 0), render.end(0, 0), capture[0].begin()); std::array<float, kBlockSize> downsampled_capture_data; rtc::ArrayView<float> downsampled_capture(downsampled_capture_data.data(), sub_block_size); @@ -407,86 +482,76 @@ TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) { downsampled_capture, false); } - // Obtain the lag estimates. - auto lag_estimates = filter.GetLagEstimates(); - EXPECT_EQ(kNumMatchedFilters, lag_estimates.size()); - - // Verify that no lag estimates are updated and that no lag estimates are - // reliable. - for (auto& le : lag_estimates) { - EXPECT_FALSE(le.updated); - EXPECT_FALSE(le.reliable); - } + // Verify that no lag estimate has been produced. + auto lag_estimate = filter.GetBestLagEstimate(); + EXPECT_FALSE(lag_estimate.has_value()); } } -// Verifies that the correct number of lag estimates are produced for a certain -// number of alignment shifts. 
-TEST(MatchedFilter, NumberOfLagEstimates) { - ApmDataDumper data_dumper(0); - EchoCanceller3Config config; - for (auto down_sampling_factor : kDownSamplingFactors) { - const size_t sub_block_size = kBlockSize / down_sampling_factor; - for (size_t num_matched_filters = 0; num_matched_filters < 10; - ++num_matched_filters) { - MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size, - 32, num_matched_filters, 1, 150, - config.delay.delay_estimate_smoothing, - config.delay.delay_estimate_smoothing_delay_found, - config.delay.delay_candidate_detection_threshold); - EXPECT_EQ(num_matched_filters, filter.GetLagEstimates().size()); - } - } -} +INSTANTIATE_TEST_SUITE_P(_, MatchedFilterTest, testing::Values(true, false)); #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) +class MatchedFilterDeathTest : public ::testing::TestWithParam<bool> {}; + // Verifies the check for non-zero windows size. -TEST(MatchedFilterDeathTest, ZeroWindowSize) { +TEST_P(MatchedFilterDeathTest, ZeroWindowSize) { + const bool kDetectPreEcho = GetParam(); ApmDataDumper data_dumper(0); EchoCanceller3Config config; EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 16, 0, 1, 1, 150, config.delay.delay_estimate_smoothing, config.delay.delay_estimate_smoothing_delay_found, - config.delay.delay_candidate_detection_threshold), + config.delay.delay_candidate_detection_threshold, + kDetectPreEcho), ""); } // Verifies the check for non-null data dumper. 
-TEST(MatchedFilterDeathTest, NullDataDumper) { +TEST_P(MatchedFilterDeathTest, NullDataDumper) { + const bool kDetectPreEcho = GetParam(); EchoCanceller3Config config; EXPECT_DEATH(MatchedFilter(nullptr, DetectOptimization(), 16, 1, 1, 1, 150, config.delay.delay_estimate_smoothing, config.delay.delay_estimate_smoothing_delay_found, - config.delay.delay_candidate_detection_threshold), + config.delay.delay_candidate_detection_threshold, + kDetectPreEcho), ""); } // Verifies the check for that the sub block size is a multiple of 4. // TODO(peah): Activate the unittest once the required code has been landed. -TEST(MatchedFilterDeathTest, DISABLED_BlockSizeMultipleOf4) { +TEST_P(MatchedFilterDeathTest, DISABLED_BlockSizeMultipleOf4) { + const bool kDetectPreEcho = GetParam(); ApmDataDumper data_dumper(0); EchoCanceller3Config config; EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 15, 1, 1, 1, 150, config.delay.delay_estimate_smoothing, config.delay.delay_estimate_smoothing_delay_found, - config.delay.delay_candidate_detection_threshold), + config.delay.delay_candidate_detection_threshold, + kDetectPreEcho), ""); } // Verifies the check for that there is an integer number of sub blocks that add // up to a block size. // TODO(peah): Activate the unittest once the required code has been landed. 
-TEST(MatchedFilterDeathTest, DISABLED_SubBlockSizeAddsUpToBlockSize) { +TEST_P(MatchedFilterDeathTest, DISABLED_SubBlockSizeAddsUpToBlockSize) { + const bool kDetectPreEcho = GetParam(); ApmDataDumper data_dumper(0); EchoCanceller3Config config; EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 12, 1, 1, 1, 150, config.delay.delay_estimate_smoothing, config.delay.delay_estimate_smoothing_delay_found, - config.delay.delay_candidate_detection_threshold), + config.delay.delay_candidate_detection_threshold, + kDetectPreEcho), ""); } +INSTANTIATE_TEST_SUITE_P(_, + MatchedFilterDeathTest, + testing::Values(true, false)); + #endif } // namespace aec3 diff --git a/modules/audio_processing/aec3/render_delay_controller.cc b/modules/audio_processing/aec3/render_delay_controller.cc index aa3d440c33..826c38a1e9 100644 --- a/modules/audio_processing/aec3/render_delay_controller.cc +++ b/modules/audio_processing/aec3/render_delay_controller.cc @@ -53,7 +53,6 @@ class RenderDelayControllerImpl final : public RenderDelayController { static std::atomic<int> instance_count_; std::unique_ptr<ApmDataDumper> data_dumper_; const int hysteresis_limit_blocks_; - const int delay_headroom_samples_; absl::optional<DelayEstimate> delay_; EchoPathDelayEstimator delay_estimator_; RenderDelayControllerMetrics metrics_; @@ -66,15 +65,9 @@ class RenderDelayControllerImpl final : public RenderDelayController { DelayEstimate ComputeBufferDelay( const absl::optional<DelayEstimate>& current_delay, int hysteresis_limit_blocks, - int delay_headroom_samples, DelayEstimate estimated_delay) { - // Subtract delay headroom. - const int delay_with_headroom_samples = std::max( - static_cast<int>(estimated_delay.delay) - delay_headroom_samples, 0); - // Compute the buffer delay increase required to achieve the desired latency. - size_t new_delay_blocks = delay_with_headroom_samples >> kBlockSizeLog2; - + size_t new_delay_blocks = estimated_delay.delay >> kBlockSizeLog2; // Add hysteresis. 
if (current_delay) { size_t current_delay_blocks = current_delay->delay; @@ -83,7 +76,6 @@ DelayEstimate ComputeBufferDelay( new_delay_blocks = current_delay_blocks; } } - DelayEstimate new_delay = estimated_delay; new_delay.delay = new_delay_blocks; return new_delay; @@ -98,7 +90,6 @@ RenderDelayControllerImpl::RenderDelayControllerImpl( : data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)), hysteresis_limit_blocks_( static_cast<int>(config.delay.hysteresis_limit_blocks)), - delay_headroom_samples_(config.delay.delay_headroom_samples), delay_estimator_(data_dumper_.get(), config, num_capture_channels), last_delay_estimate_quality_(DelayEstimate::Quality::kCoarse) { RTC_DCHECK(ValidFullBandRate(sample_rate_hz)); @@ -158,9 +149,8 @@ absl::optional<DelayEstimate> RenderDelayControllerImpl::GetDelay( const bool use_hysteresis = last_delay_estimate_quality_ == DelayEstimate::Quality::kRefined && delay_samples_->quality == DelayEstimate::Quality::kRefined; - delay_ = ComputeBufferDelay(delay_, - use_hysteresis ? hysteresis_limit_blocks_ : 0, - delay_headroom_samples_, *delay_samples_); + delay_ = ComputeBufferDelay( + delay_, use_hysteresis ? hysteresis_limit_blocks_ : 0, *delay_samples_); last_delay_estimate_quality_ = delay_samples_->quality; } |