aboutsummaryrefslogtreecommitdiff
path: root/modules/audio_processing/aec3
diff options
context:
space:
mode:
authorLionel Koenig <lionelk@webrtc.org>2022-06-28 15:37:13 +0200
committerWebRTC LUCI CQ <webrtc-scoped@luci-project-accounts.iam.gserviceaccount.com>2022-06-28 15:16:03 +0000
commit8783c678a5cb74dda890e76092e1d767bb179d8c (patch)
tree2c4764a6f0b3d725bd3b5461215059c8a33d1494 /modules/audio_processing/aec3
parent7534ebd2bf59212cce5e010dd6ed9b3bc944818e (diff)
downloadwebrtc-8783c678a5cb74dda890e76092e1d767bb179d8c.tar.gz
delay estimator: Look for early reverberation
Look for first echo (and not only the strongest one) on the same matched filter. This change is bit exact with previous version when `pre_echo` is false. Author: Jesús de Vicente Peña <devicentepena@webrtc.org> Bug: webrtc:14205 Change-Id: I6782eaa1d690b0df78d00f6d425a85c951b2ca9d Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/266321 Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org> Commit-Queue: Lionel Koenig <lionelk@webrtc.org> Cr-Commit-Position: refs/heads/main@{#37360}
Diffstat (limited to 'modules/audio_processing/aec3')
-rw-r--r--modules/audio_processing/aec3/BUILD.gn1
-rw-r--r--modules/audio_processing/aec3/echo_canceller3.cc8
-rw-r--r--modules/audio_processing/aec3/echo_path_delay_estimator.cc10
-rw-r--r--modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc8
-rw-r--r--modules/audio_processing/aec3/matched_filter.cc480
-rw-r--r--modules/audio_processing/aec3/matched_filter.h46
-rw-r--r--modules/audio_processing/aec3/matched_filter_avx2.cc145
-rw-r--r--modules/audio_processing/aec3/matched_filter_lag_aggregator.cc158
-rw-r--r--modules/audio_processing/aec3/matched_filter_lag_aggregator.h52
-rw-r--r--modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc81
-rw-r--r--modules/audio_processing/aec3/matched_filter_unittest.cc337
-rw-r--r--modules/audio_processing/aec3/render_delay_controller.cc16
12 files changed, 950 insertions, 392 deletions
diff --git a/modules/audio_processing/aec3/BUILD.gn b/modules/audio_processing/aec3/BUILD.gn
index 70d049549a..4de2d00f3e 100644
--- a/modules/audio_processing/aec3/BUILD.gn
+++ b/modules/audio_processing/aec3/BUILD.gn
@@ -226,6 +226,7 @@ rtc_source_set("matched_filter") {
"../../../api:array_view",
"../../../rtc_base/system:arch",
]
+ absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ]
}
rtc_source_set("vector_math") {
diff --git a/modules/audio_processing/aec3/echo_canceller3.cc b/modules/audio_processing/aec3/echo_canceller3.cc
index 8e306aca56..1404a9987d 100644
--- a/modules/audio_processing/aec3/echo_canceller3.cc
+++ b/modules/audio_processing/aec3/echo_canceller3.cc
@@ -377,6 +377,14 @@ EchoCanceller3Config AdjustConfig(const EchoCanceller3Config& config) {
false;
}
+ if (field_trial::IsEnabled("WebRTC-Aec3DelayEstimatorDetectPreEcho")) {
+ adjusted_cfg.delay.detect_pre_echo = true;
+ }
+
+ if (field_trial::IsDisabled("WebRTC-Aec3DelayEstimatorDetectPreEcho")) {
+ adjusted_cfg.delay.detect_pre_echo = false;
+ }
+
if (field_trial::IsEnabled("WebRTC-Aec3SensitiveDominantNearendActivation")) {
adjusted_cfg.suppressor.dominant_nearend_detection.enr_threshold = 0.5f;
} else if (field_trial::IsEnabled(
diff --git a/modules/audio_processing/aec3/echo_path_delay_estimator.cc b/modules/audio_processing/aec3/echo_path_delay_estimator.cc
index e64c4493f6..fc83ca2f89 100644
--- a/modules/audio_processing/aec3/echo_path_delay_estimator.cc
+++ b/modules/audio_processing/aec3/echo_path_delay_estimator.cc
@@ -43,10 +43,11 @@ EchoPathDelayEstimator::EchoPathDelayEstimator(
: config.render_levels.poor_excitation_render_limit,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
- config.delay.delay_candidate_detection_threshold),
+ config.delay.delay_candidate_detection_threshold,
+ config.delay.detect_pre_echo),
matched_filter_lag_aggregator_(data_dumper_,
matched_filter_.GetMaxFilterLag(),
- config.delay.delay_selection_thresholds) {
+ config.delay) {
RTC_DCHECK(data_dumper);
RTC_DCHECK(down_sampling_factor_ > 0);
}
@@ -75,13 +76,14 @@ absl::optional<DelayEstimate> EchoPathDelayEstimator::EstimateDelay(
absl::optional<DelayEstimate> aggregated_matched_filter_lag =
matched_filter_lag_aggregator_.Aggregate(
- matched_filter_.GetLagEstimates());
+ matched_filter_.GetBestLagEstimate());
// Run clockdrift detection.
if (aggregated_matched_filter_lag &&
(*aggregated_matched_filter_lag).quality ==
DelayEstimate::Quality::kRefined)
- clockdrift_detector_.Update((*aggregated_matched_filter_lag).delay);
+ clockdrift_detector_.Update(
+ matched_filter_lag_aggregator_.GetDelayAtHighestPeak());
// TODO(peah): Move this logging outside of this class once EchoCanceller3
// development is done.
diff --git a/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc b/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc
index 13c9c1122e..810b0ae185 100644
--- a/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc
+++ b/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc
@@ -78,6 +78,7 @@ TEST(EchoPathDelayEstimator, DelayEstimation) {
constexpr size_t kDownSamplingFactors[] = {2, 4, 8};
for (auto down_sampling_factor : kDownSamplingFactors) {
EchoCanceller3Config config;
+ config.delay.delay_headroom_samples = 0;
config.delay.down_sampling_factor = down_sampling_factor;
config.delay.num_filters = 10;
for (size_t delay_samples : {30, 64, 150, 200, 800, 4000}) {
@@ -111,12 +112,13 @@ TEST(EchoPathDelayEstimator, DelayEstimation) {
}
if (estimated_delay_samples) {
- // Allow estimated delay to be off by one sample in the down-sampled
- // domain.
+ // Allow estimated delay to be off by a block as internally the delay is
+ // quantized with an error up to a block.
size_t delay_ds = delay_samples / down_sampling_factor;
size_t estimated_delay_ds =
estimated_delay_samples->delay / down_sampling_factor;
- EXPECT_NEAR(delay_ds, estimated_delay_ds, 1);
+ EXPECT_NEAR(delay_ds, estimated_delay_ds,
+ kBlockSize / down_sampling_factor);
} else {
ADD_FAILURE();
}
diff --git a/modules/audio_processing/aec3/matched_filter.cc b/modules/audio_processing/aec3/matched_filter.cc
index faca933856..c5e394ad2f 100644
--- a/modules/audio_processing/aec3/matched_filter.cc
+++ b/modules/audio_processing/aec3/matched_filter.cc
@@ -24,16 +24,147 @@
#include <iterator>
#include <numeric>
+#include "absl/types/optional.h"
+#include "api/array_view.h"
#include "modules/audio_processing/aec3/downsampled_render_buffer.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
+namespace {
+
+// Subsample rate used for computing the accumulated error.
+// The implementation of some core functions depends on this constant being
+// equal to 4.
+constexpr int kAccumulatedErrorSubSampleRate = 4;
+
+void UpdateAccumulatedError(
+ const rtc::ArrayView<const float> instantaneous_accumulated_error,
+ const rtc::ArrayView<float> accumulated_error,
+ float one_over_error_sum_anchor) {
+ for (size_t k = 0; k < instantaneous_accumulated_error.size(); ++k) {
+ float error_norm =
+ instantaneous_accumulated_error[k] * one_over_error_sum_anchor;
+ if (error_norm < accumulated_error[k]) {
+ accumulated_error[k] = error_norm;
+ } else {
+ accumulated_error[k] += 0.01f * (error_norm - accumulated_error[k]);
+ }
+ }
+}
+
+size_t ComputePreEchoLag(const rtc::ArrayView<float> accumulated_error,
+ size_t lag,
+ size_t alignment_shift_winner) {
+ size_t pre_echo_lag_estimate = lag - alignment_shift_winner;
+ size_t maximum_pre_echo_lag =
+ std::min(pre_echo_lag_estimate / kAccumulatedErrorSubSampleRate,
+ accumulated_error.size());
+ for (size_t k = 1; k < maximum_pre_echo_lag; ++k) {
+ if (accumulated_error[k] < 0.5f * accumulated_error[k - 1] &&
+ accumulated_error[k] < 0.5f) {
+ pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1;
+ break;
+ }
+ }
+ return pre_echo_lag_estimate + alignment_shift_winner;
+}
+
+} // namespace
+
namespace webrtc {
namespace aec3 {
#if defined(WEBRTC_HAS_NEON)
+inline float SumAllElements(float32x4_t elements) {
+ float32x2_t sum = vpadd_f32(vget_low_f32(elements), vget_high_f32(elements));
+ sum = vpadd_f32(sum, sum);
+ return vget_lane_f32(sum, 0);
+}
+
+void MatchedFilterCoreWithAccumulatedError_NEON(
+ size_t x_start_index,
+ float x2_sum_threshold,
+ float smoothing,
+ rtc::ArrayView<const float> x,
+ rtc::ArrayView<const float> y,
+ rtc::ArrayView<float> h,
+ bool* filters_updated,
+ float* error_sum,
+ rtc::ArrayView<float> accumulated_error,
+ rtc::ArrayView<float> scratch_memory) {
+ const int h_size = static_cast<int>(h.size());
+ const int x_size = static_cast<int>(x.size());
+ RTC_DCHECK_EQ(0, h_size % 4);
+ std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
+ // Process for all samples in the sub-block.
+ for (size_t i = 0; i < y.size(); ++i) {
+ // Apply the matched filter as filter * x, and compute x * x.
+ RTC_DCHECK_GT(x_size, x_start_index);
+ // Compute loop chunk sizes until, and after, the wraparound of the circular
+ // buffer for x.
+ const int chunk1 =
+ std::min(h_size, static_cast<int>(x_size - x_start_index));
+ if (chunk1 != h_size) {
+ const int chunk2 = h_size - chunk1;
+ std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
+ std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
+ }
+ const float* x_p =
+ chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
+ const float* h_p = &h[0];
+ float* accumulated_error_p = &accumulated_error[0];
+ // Initialize values for the accumulation.
+ float32x4_t x2_sum_128 = vdupq_n_f32(0);
+ float x2_sum = 0.f;
+ float s = 0;
+ // Perform 128 bit vector operations.
+ const int limit_by_4 = h_size >> 2;
+ for (int k = limit_by_4; k > 0;
+ --k, h_p += 4, x_p += 4, accumulated_error_p++) {
+ // Load the data into 128 bit vectors.
+ const float32x4_t x_k = vld1q_f32(x_p);
+ const float32x4_t h_k = vld1q_f32(h_p);
+ // Compute and accumulate x * x.
+ x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
+ // Compute x * h
+ float32x4_t hk_xk_128 = vmulq_f32(h_k, x_k);
+ s += SumAllElements(hk_xk_128);
+ const float e = s - y[i];
+ accumulated_error_p[0] += e * e;
+ }
+ // Combine the accumulated vector and scalar values.
+ x2_sum += SumAllElements(x2_sum_128);
+ // Compute the matched filter error.
+ float e = y[i] - s;
+ const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
+ (*error_sum) += e * e;
+ // Update the matched filter estimate in an NLMS manner.
+ if (x2_sum > x2_sum_threshold && !saturation) {
+ RTC_DCHECK_LT(0.f, x2_sum);
+ const float alpha = smoothing * e / x2_sum;
+ const float32x4_t alpha_128 = vmovq_n_f32(alpha);
+ // filter = filter + smoothing * (y - filter * x) * x / x * x.
+ float* h_p = &h[0];
+ x_p = chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
+ // Perform 128 bit vector operations.
+ const int limit_by_4 = h_size >> 2;
+ for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
+ // Load the data into 128 bit vectors.
+ float32x4_t h_k = vld1q_f32(h_p);
+ const float32x4_t x_k = vld1q_f32(x_p);
+ // Compute h = h + alpha * x.
+ h_k = vmlaq_f32(h_k, alpha_128, x_k);
+ // Store the result.
+ vst1q_f32(h_p, h_k);
+ }
+ *filters_updated = true;
+ }
+ x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
+ }
+}
+
void MatchedFilterCore_NEON(size_t x_start_index,
float x2_sum_threshold,
float smoothing,
@@ -41,11 +172,20 @@ void MatchedFilterCore_NEON(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
- float* error_sum) {
+ float* error_sum,
+ bool compute_accumulated_error,
+ rtc::ArrayView<float> accumulated_error,
+ rtc::ArrayView<float> scratch_memory) {
const int h_size = static_cast<int>(h.size());
const int x_size = static_cast<int>(x.size());
RTC_DCHECK_EQ(0, h_size % 4);
+ if (compute_accumulated_error) {
+ return MatchedFilterCoreWithAccumulatedError_NEON(
+ x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
+ error_sum, accumulated_error, scratch_memory);
+ }
+
// Process for all samples in the sub-block.
for (size_t i = 0; i < y.size(); ++i) {
// Apply the matched filter as filter * x, and compute x * x.
@@ -90,10 +230,8 @@ void MatchedFilterCore_NEON(size_t x_start_index,
}
// Combine the accumulated vector and scalar values.
- float* v = reinterpret_cast<float*>(&x2_sum_128);
- x2_sum += v[0] + v[1] + v[2] + v[3];
- v = reinterpret_cast<float*>(&s_128);
- s += v[0] + v[1] + v[2] + v[3];
+ s += SumAllElements(s_128);
+ x2_sum += SumAllElements(x2_sum_128);
// Compute the matched filter error.
float e = y[i] - s;
@@ -144,6 +282,103 @@ void MatchedFilterCore_NEON(size_t x_start_index,
#if defined(WEBRTC_ARCH_X86_FAMILY)
+void MatchedFilterCore_AccumulatedError_SSE2(
+ size_t x_start_index,
+ float x2_sum_threshold,
+ float smoothing,
+ rtc::ArrayView<const float> x,
+ rtc::ArrayView<const float> y,
+ rtc::ArrayView<float> h,
+ bool* filters_updated,
+ float* error_sum,
+ rtc::ArrayView<float> accumulated_error,
+ rtc::ArrayView<float> scratch_memory) {
+ const int h_size = static_cast<int>(h.size());
+ const int x_size = static_cast<int>(x.size());
+ RTC_DCHECK_EQ(0, h_size % 8);
+ std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
+ // Process for all samples in the sub-block.
+ for (size_t i = 0; i < y.size(); ++i) {
+ // Apply the matched filter as filter * x, and compute x * x.
+ RTC_DCHECK_GT(x_size, x_start_index);
+ const int chunk1 =
+ std::min(h_size, static_cast<int>(x_size - x_start_index));
+ if (chunk1 != h_size) {
+ const int chunk2 = h_size - chunk1;
+ std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
+ std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
+ }
+ const float* x_p =
+ chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
+ const float* h_p = &h[0];
+ float* a_p = &accumulated_error[0];
+ __m128 s_inst_128;
+ __m128 s_inst_128_4;
+ __m128 x2_sum_128 = _mm_set1_ps(0);
+ __m128 x2_sum_128_4 = _mm_set1_ps(0);
+ __m128 e_128;
+ float* const s_p = reinterpret_cast<float*>(&s_inst_128);
+ float* const s_4_p = reinterpret_cast<float*>(&s_inst_128_4);
+ float* const e_p = reinterpret_cast<float*>(&e_128);
+ float x2_sum = 0.0f;
+ float s_acum = 0;
+ // Perform 128 bit vector operations.
+ const int limit_by_8 = h_size >> 3;
+ for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8, a_p += 2) {
+ // Load the data into 128 bit vectors.
+ const __m128 x_k = _mm_loadu_ps(x_p);
+ const __m128 h_k = _mm_loadu_ps(h_p);
+ const __m128 x_k_4 = _mm_loadu_ps(x_p + 4);
+ const __m128 h_k_4 = _mm_loadu_ps(h_p + 4);
+ const __m128 xx = _mm_mul_ps(x_k, x_k);
+ const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4);
+ // Compute and accumulate x * x and h * x.
+ x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
+ x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4);
+ s_inst_128 = _mm_mul_ps(h_k, x_k);
+ s_inst_128_4 = _mm_mul_ps(h_k_4, x_k_4);
+ s_acum += s_p[0] + s_p[1] + s_p[2] + s_p[3];
+ e_p[0] = s_acum - y[i];
+ s_acum += s_4_p[0] + s_4_p[1] + s_4_p[2] + s_4_p[3];
+ e_p[1] = s_acum - y[i];
+ a_p[0] += e_p[0] * e_p[0];
+ a_p[1] += e_p[1] * e_p[1];
+ }
+ // Combine the accumulated vector and scalar values.
+ x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
+ float* v = reinterpret_cast<float*>(&x2_sum_128);
+ x2_sum += v[0] + v[1] + v[2] + v[3];
+ // Compute the matched filter error.
+ float e = y[i] - s_acum;
+ const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
+ (*error_sum) += e * e;
+ // Update the matched filter estimate in an NLMS manner.
+ if (x2_sum > x2_sum_threshold && !saturation) {
+ RTC_DCHECK_LT(0.f, x2_sum);
+ const float alpha = smoothing * e / x2_sum;
+ const __m128 alpha_128 = _mm_set1_ps(alpha);
+ // filter = filter + smoothing * (y - filter * x) * x / x * x.
+ float* h_p = &h[0];
+ const float* x_p =
+ chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
+ // Perform 128 bit vector operations.
+ const int limit_by_4 = h_size >> 2;
+ for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
+ // Load the data into 128 bit vectors.
+ __m128 h_k = _mm_loadu_ps(h_p);
+ const __m128 x_k = _mm_loadu_ps(x_p);
+ // Compute h = h + alpha * x.
+ const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
+ h_k = _mm_add_ps(h_k, alpha_x);
+ // Store the result.
+ _mm_storeu_ps(h_p, h_k);
+ }
+ *filters_updated = true;
+ }
+ x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
+ }
+}
+
void MatchedFilterCore_SSE2(size_t x_start_index,
float x2_sum_threshold,
float smoothing,
@@ -151,19 +386,24 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
- float* error_sum) {
+ float* error_sum,
+ bool compute_accumulated_error,
+ rtc::ArrayView<float> accumulated_error,
+ rtc::ArrayView<float> scratch_memory) {
+ if (compute_accumulated_error) {
+ return MatchedFilterCore_AccumulatedError_SSE2(
+ x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
+ error_sum, accumulated_error, scratch_memory);
+ }
const int h_size = static_cast<int>(h.size());
const int x_size = static_cast<int>(x.size());
RTC_DCHECK_EQ(0, h_size % 4);
-
// Process for all samples in the sub-block.
for (size_t i = 0; i < y.size(); ++i) {
// Apply the matched filter as filter * x, and compute x * x.
-
RTC_DCHECK_GT(x_size, x_start_index);
const float* x_p = &x[x_start_index];
const float* h_p = &h[0];
-
// Initialize values for the accumulation.
__m128 s_128 = _mm_set1_ps(0);
__m128 s_128_4 = _mm_set1_ps(0);
@@ -171,12 +411,10 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
__m128 x2_sum_128_4 = _mm_set1_ps(0);
float x2_sum = 0.f;
float s = 0;
-
// Compute loop chunk sizes until, and after, the wraparound of the circular
// buffer for x.
const int chunk1 =
std::min(h_size, static_cast<int>(x_size - x_start_index));
-
// Perform the loop in two chunks.
const int chunk2 = h_size - chunk1;
for (int limit : {chunk1, chunk2}) {
@@ -198,17 +436,14 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
s_128 = _mm_add_ps(s_128, hx);
s_128_4 = _mm_add_ps(s_128_4, hx_4);
}
-
// Perform non-vector operations for any remaining items.
for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) {
const float x_k = *x_p;
x2_sum += x_k * x_k;
s += *h_p * x_k;
}
-
x_p = &x[0];
}
-
// Combine the accumulated vector and scalar values.
x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
float* v = reinterpret_cast<float*>(&x2_sum_128);
@@ -216,22 +451,18 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
s_128 = _mm_add_ps(s_128, s_128_4);
v = reinterpret_cast<float*>(&s_128);
s += v[0] + v[1] + v[2] + v[3];
-
// Compute the matched filter error.
float e = y[i] - s;
const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
(*error_sum) += e * e;
-
// Update the matched filter estimate in an NLMS manner.
if (x2_sum > x2_sum_threshold && !saturation) {
RTC_DCHECK_LT(0.f, x2_sum);
const float alpha = smoothing * e / x2_sum;
const __m128 alpha_128 = _mm_set1_ps(alpha);
-
// filter = filter + smoothing * (y - filter * x) * x / x * x.
float* h_p = &h[0];
x_p = &x[x_start_index];
-
// Perform the loop in two chunks.
for (int limit : {chunk1, chunk2}) {
// Perform 128 bit vector operations.
@@ -244,22 +475,17 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
// Compute h = h + alpha * x.
const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
h_k = _mm_add_ps(h_k, alpha_x);
-
// Store the result.
_mm_storeu_ps(h_p, h_k);
}
-
// Perform non-vector operations for any remaining items.
for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
*h_p += alpha * *x_p;
}
-
x_p = &x[0];
}
-
*filters_updated = true;
}
-
x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
}
}
@@ -272,17 +498,35 @@ void MatchedFilterCore(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
- float* error_sum) {
+ float* error_sum,
+ bool compute_accumulated_error,
+ rtc::ArrayView<float> accumulated_error) {
+ if (compute_accumulated_error) {
+ std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
+ }
+
// Process for all samples in the sub-block.
for (size_t i = 0; i < y.size(); ++i) {
// Apply the matched filter as filter * x, and compute x * x.
float x2_sum = 0.f;
float s = 0;
size_t x_index = x_start_index;
- for (size_t k = 0; k < h.size(); ++k) {
- x2_sum += x[x_index] * x[x_index];
- s += h[k] * x[x_index];
- x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
+ if (compute_accumulated_error) {
+ for (size_t k = 0; k < h.size(); ++k) {
+ x2_sum += x[x_index] * x[x_index];
+ s += h[k] * x[x_index];
+ x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
+ if ((k + 1 & 0b11) == 0) {
+ int idx = k >> 2;
+ accumulated_error[idx] += (y[i] - s) * (y[i] - s);
+ }
+ }
+ } else {
+ for (size_t k = 0; k < h.size(); ++k) {
+ x2_sum += x[x_index] * x[x_index];
+ s += h[k] * x[x_index];
+ x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
+ }
}
// Compute the matched filter error.
@@ -354,7 +598,8 @@ MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
float excitation_limit,
float smoothing_fast,
float smoothing_slow,
- float matching_filter_threshold)
+ float matching_filter_threshold,
+ bool detect_pre_echo)
: data_dumper_(data_dumper),
optimization_(optimization),
sub_block_size_(sub_block_size),
@@ -362,16 +607,31 @@ MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
filters_(
num_matched_filters,
std::vector<float>(window_size_sub_blocks * sub_block_size_, 0.f)),
- lag_estimates_(num_matched_filters),
filters_offsets_(num_matched_filters, 0),
excitation_limit_(excitation_limit),
smoothing_fast_(smoothing_fast),
smoothing_slow_(smoothing_slow),
- matching_filter_threshold_(matching_filter_threshold) {
+ matching_filter_threshold_(matching_filter_threshold),
+ detect_pre_echo_(detect_pre_echo) {
RTC_DCHECK(data_dumper);
RTC_DCHECK_LT(0, window_size_sub_blocks);
RTC_DCHECK((kBlockSize % sub_block_size) == 0);
RTC_DCHECK((sub_block_size % 4) == 0);
+ static_assert(kAccumulatedErrorSubSampleRate == 4);
+ if (detect_pre_echo_) {
+ accumulated_error_ = std::vector<std::vector<float>>(
+ num_matched_filters,
+ std::vector<float>(window_size_sub_blocks * sub_block_size_ /
+ kAccumulatedErrorSubSampleRate,
+ 1.0f));
+
+ instantaneous_accumulated_error_ =
+ std::vector<float>(window_size_sub_blocks * sub_block_size_ /
+ kAccumulatedErrorSubSampleRate,
+ 0.0f);
+ scratch_memory_ =
+ std::vector<float>(window_size_sub_blocks * sub_block_size_);
+ }
}
MatchedFilter::~MatchedFilter() = default;
@@ -381,9 +641,12 @@ void MatchedFilter::Reset() {
std::fill(f.begin(), f.end(), 0.f);
}
- for (auto& l : lag_estimates_) {
- l = MatchedFilter::LagEstimate();
+ for (auto& e : accumulated_error_) {
+ std::fill(e.begin(), e.end(), 1.0f);
}
+
+ winner_lag_ = absl::nullopt;
+ reported_lag_estimate_ = absl::nullopt;
}
void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
@@ -398,11 +661,25 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
const float x2_sum_threshold =
filters_[0].size() * excitation_limit_ * excitation_limit_;
+ // Compute anchor for the matched filter error.
+ float error_sum_anchor = 0.0f;
+ for (size_t k = 0; k < y.size(); ++k) {
+ error_sum_anchor += y[k] * y[k];
+ }
+
// Apply all matched filters.
+ float winner_error_sum = error_sum_anchor;
+ winner_lag_ = absl::nullopt;
+ reported_lag_estimate_ = absl::nullopt;
size_t alignment_shift = 0;
- for (size_t n = 0; n < filters_.size(); ++n) {
+ absl::optional<size_t> previous_lag_estimate;
+ const int num_filters = static_cast<int>(filters_.size());
+ int winner_index = -1;
+ for (int n = 0; n < num_filters; ++n) {
float error_sum = 0.f;
bool filters_updated = false;
+ const bool compute_pre_echo =
+ detect_pre_echo_ && n == last_detected_best_lag_filter_;
size_t x_start_index =
(render_buffer.read + alignment_shift + sub_block_size_ - 1) %
@@ -411,85 +688,79 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
switch (optimization_) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
case Aec3Optimization::kSse2:
- aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold, smoothing,
- render_buffer.buffer, y, filters_[n],
- &filters_updated, &error_sum);
+ aec3::MatchedFilterCore_SSE2(
+ x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
+ filters_[n], &filters_updated, &error_sum, compute_pre_echo,
+ instantaneous_accumulated_error_, scratch_memory_);
break;
case Aec3Optimization::kAvx2:
- aec3::MatchedFilterCore_AVX2(x_start_index, x2_sum_threshold, smoothing,
- render_buffer.buffer, y, filters_[n],
- &filters_updated, &error_sum);
+ aec3::MatchedFilterCore_AVX2(
+ x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
+ filters_[n], &filters_updated, &error_sum, compute_pre_echo,
+ instantaneous_accumulated_error_, scratch_memory_);
break;
#endif
#if defined(WEBRTC_HAS_NEON)
case Aec3Optimization::kNeon:
- aec3::MatchedFilterCore_NEON(x_start_index, x2_sum_threshold, smoothing,
- render_buffer.buffer, y, filters_[n],
- &filters_updated, &error_sum);
+ aec3::MatchedFilterCore_NEON(
+ x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
+ filters_[n], &filters_updated, &error_sum, compute_pre_echo,
+ instantaneous_accumulated_error_, scratch_memory_);
break;
#endif
default:
aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing,
render_buffer.buffer, y, filters_[n],
- &filters_updated, &error_sum);
- }
-
- // Compute anchor for the matched filter error.
- float error_sum_anchor = 0.0f;
- for (size_t k = 0; k < y.size(); ++k) {
- error_sum_anchor += y[k] * y[k];
+ &filters_updated, &error_sum, compute_pre_echo,
+ instantaneous_accumulated_error_);
}
// Estimate the lag in the matched filter as the distance to the portion in
// the filter that contributes the most to the matched filter output. This
// is detected as the peak of the matched filter.
const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
-
- // Update the lag estimates for the matched filter.
- lag_estimates_[n] = LagEstimate(
- error_sum_anchor - error_sum,
- (lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) &&
- error_sum < matching_filter_threshold_ * error_sum_anchor),
- lag_estimate + alignment_shift, filters_updated);
-
- RTC_DCHECK_GE(10, filters_.size());
- switch (n) {
- case 0:
- data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]);
- break;
- case 1:
- data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]);
- break;
- case 2:
- data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]);
- break;
- case 3:
- data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]);
- break;
- case 4:
- data_dumper_->DumpRaw("aec3_correlator_4_h", filters_[4]);
- break;
- case 5:
- data_dumper_->DumpRaw("aec3_correlator_5_h", filters_[5]);
- break;
- case 6:
- data_dumper_->DumpRaw("aec3_correlator_6_h", filters_[6]);
- break;
- case 7:
- data_dumper_->DumpRaw("aec3_correlator_7_h", filters_[7]);
- break;
- case 8:
- data_dumper_->DumpRaw("aec3_correlator_8_h", filters_[8]);
- break;
- case 9:
- data_dumper_->DumpRaw("aec3_correlator_9_h", filters_[9]);
- break;
- default:
- RTC_DCHECK_NOTREACHED();
+ const bool reliable =
+ lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) &&
+ error_sum < matching_filter_threshold_ * error_sum_anchor;
+
+ // Find the best estimate
+ const size_t lag = lag_estimate + alignment_shift;
+ if (filters_updated && reliable && error_sum < winner_error_sum) {
+ winner_error_sum = error_sum;
+ winner_index = n;
+ // In case that 2 matched filters return the same winner candidate
+ // (overlap region), the one with the smaller index is chosen in order
+ // to search for pre-echoes.
+ if (previous_lag_estimate && previous_lag_estimate == lag) {
+ winner_lag_ = previous_lag_estimate;
+ winner_index = n - 1;
+ } else {
+ winner_lag_ = lag;
+ }
}
-
+ previous_lag_estimate = lag;
alignment_shift += filter_intra_lag_shift_;
}
+
+ if (winner_index != -1) {
+ RTC_DCHECK(winner_lag_.has_value());
+ reported_lag_estimate_ =
+ LagEstimate(winner_lag_.value(), /*pre_echo_lag=*/winner_lag_.value());
+ if (detect_pre_echo_ && last_detected_best_lag_filter_ == winner_index) {
+ if (error_sum_anchor > 30.0f * 30.0f * y.size()) {
+ UpdateAccumulatedError(instantaneous_accumulated_error_,
+ accumulated_error_[winner_index],
+ 1.0f / error_sum_anchor);
+ }
+ reported_lag_estimate_->pre_echo_lag = ComputePreEchoLag(
+ accumulated_error_[winner_index], winner_lag_.value(),
+ winner_index * filter_intra_lag_shift_ /*alignment_shift_winner*/);
+ }
+ last_detected_best_lag_filter_ = winner_index;
+ }
+ if (ApmDataDumper::IsAvailable()) {
+ Dump();
+ }
}
void MatchedFilter::LogFilterProperties(int sample_rate_hz,
@@ -510,4 +781,27 @@ void MatchedFilter::LogFilterProperties(int sample_rate_hz,
}
}
+void MatchedFilter::Dump() {
+ for (size_t n = 0; n < filters_.size(); ++n) {
+ const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
+ std::string dumper_filter = "aec3_correlator_" + std::to_string(n) + "_h";
+ data_dumper_->DumpRaw(dumper_filter.c_str(), filters_[n]);
+ std::string dumper_lag = "aec3_correlator_lag_" + std::to_string(n);
+ data_dumper_->DumpRaw(dumper_lag.c_str(),
+ lag_estimate + n * filter_intra_lag_shift_);
+ if (detect_pre_echo_) {
+ std::string dumper_error =
+ "aec3_correlator_error_" + std::to_string(n) + "_h";
+ data_dumper_->DumpRaw(dumper_error.c_str(), accumulated_error_[n]);
+
+ size_t pre_echo_lag = ComputePreEchoLag(
+ accumulated_error_[n], lag_estimate + n * filter_intra_lag_shift_,
+ n * filter_intra_lag_shift_);
+ std::string dumper_pre_lag =
+ "aec3_correlator_pre_echo_lag_" + std::to_string(n);
+ data_dumper_->DumpRaw(dumper_pre_lag.c_str(), pre_echo_lag);
+ }
+ }
+}
+
} // namespace webrtc
diff --git a/modules/audio_processing/aec3/matched_filter.h b/modules/audio_processing/aec3/matched_filter.h
index dd4a678394..760d5e39fd 100644
--- a/modules/audio_processing/aec3/matched_filter.h
+++ b/modules/audio_processing/aec3/matched_filter.h
@@ -15,6 +15,7 @@
#include <vector>
+#include "absl/types/optional.h"
#include "api/array_view.h"
#include "modules/audio_processing/aec3/aec3_common.h"
#include "rtc_base/system/arch.h"
@@ -36,7 +37,10 @@ void MatchedFilterCore_NEON(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
- float* error_sum);
+ float* error_sum,
+ bool compute_accumulation_error,
+ rtc::ArrayView<float> accumulated_error,
+ rtc::ArrayView<float> scratch_memory);
#endif
@@ -50,7 +54,10 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
- float* error_sum);
+ float* error_sum,
+ bool compute_accumulated_error,
+ rtc::ArrayView<float> accumulated_error,
+ rtc::ArrayView<float> scratch_memory);
// Filter core for the matched filter that is optimized for AVX2.
void MatchedFilterCore_AVX2(size_t x_start_index,
@@ -60,7 +67,10 @@ void MatchedFilterCore_AVX2(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
- float* error_sum);
+ float* error_sum,
+ bool compute_accumulated_error,
+ rtc::ArrayView<float> accumulated_error,
+ rtc::ArrayView<float> scratch_memory);
#endif
@@ -72,7 +82,9 @@ void MatchedFilterCore(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
- float* error_sum);
+ float* error_sum,
+ bool compute_accumulation_error,
+ rtc::ArrayView<float> accumulated_error);
// Find largest peak of squared values in array.
size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h);
@@ -87,13 +99,10 @@ class MatchedFilter {
// shift.
struct LagEstimate {
LagEstimate() = default;
- LagEstimate(float accuracy, bool reliable, size_t lag, bool updated)
- : accuracy(accuracy), reliable(reliable), lag(lag), updated(updated) {}
-
- float accuracy = 0.f;
- bool reliable = false;
+ LagEstimate(size_t lag, size_t pre_echo_lag)
+ : lag(lag), pre_echo_lag(pre_echo_lag) {}
size_t lag = 0;
- bool updated = false;
+ size_t pre_echo_lag = 0;
};
MatchedFilter(ApmDataDumper* data_dumper,
@@ -105,7 +114,8 @@ class MatchedFilter {
float excitation_limit,
float smoothing_fast,
float smoothing_slow,
- float matching_filter_threshold);
+ float matching_filter_threshold,
+ bool detect_pre_echo);
MatchedFilter() = delete;
MatchedFilter(const MatchedFilter&) = delete;
@@ -122,8 +132,8 @@ class MatchedFilter {
void Reset();
// Returns the current lag estimates.
- rtc::ArrayView<const MatchedFilter::LagEstimate> GetLagEstimates() const {
- return lag_estimates_;
+ absl::optional<const MatchedFilter::LagEstimate> GetBestLagEstimate() const {
+ return reported_lag_estimate_;
}
// Returns the maximum filter lag.
@@ -137,17 +147,25 @@ class MatchedFilter {
size_t downsampling_factor) const;
private:
+ void Dump();
+
ApmDataDumper* const data_dumper_;
const Aec3Optimization optimization_;
const size_t sub_block_size_;
const size_t filter_intra_lag_shift_;
std::vector<std::vector<float>> filters_;
- std::vector<LagEstimate> lag_estimates_;
+ std::vector<std::vector<float>> accumulated_error_;
+ std::vector<float> instantaneous_accumulated_error_;
+ std::vector<float> scratch_memory_;
+ absl::optional<MatchedFilter::LagEstimate> reported_lag_estimate_;
+ absl::optional<size_t> winner_lag_;
+ int last_detected_best_lag_filter_ = -1;
std::vector<size_t> filters_offsets_;
const float excitation_limit_;
const float smoothing_fast_;
const float smoothing_slow_;
const float matching_filter_threshold_;
+ const bool detect_pre_echo_;
};
} // namespace webrtc
diff --git a/modules/audio_processing/aec3/matched_filter_avx2.cc b/modules/audio_processing/aec3/matched_filter_avx2.cc
index 8b7010f1dc..8c2ffcbd1e 100644
--- a/modules/audio_processing/aec3/matched_filter_avx2.cc
+++ b/modules/audio_processing/aec3/matched_filter_avx2.cc
@@ -8,15 +8,134 @@
* be found in the AUTHORS file in the root of the source tree.
*/
-#include "modules/audio_processing/aec3/matched_filter.h"
-
#include <immintrin.h>
+#include "modules/audio_processing/aec3/matched_filter.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace aec3 {
+// Let ha denote the horizontal of a, and hb the horizontal sum of b
+// returns [ha, hb, ha, hb]
+inline __m128 hsum_ab(__m256 a, __m256 b) {
+ __m256 s_256 = _mm256_hadd_ps(a, b);
+ const __m256i mask = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
+ s_256 = _mm256_permutevar8x32_ps(s_256, mask);
+ __m128 s = _mm_hadd_ps(_mm256_extractf128_ps(s_256, 0),
+ _mm256_extractf128_ps(s_256, 1));
+ s = _mm_hadd_ps(s, s);
+ return s;
+}
+
+void MatchedFilterCore_AccumulatedError_AVX2(
+ size_t x_start_index,
+ float x2_sum_threshold,
+ float smoothing,
+ rtc::ArrayView<const float> x,
+ rtc::ArrayView<const float> y,
+ rtc::ArrayView<float> h,
+ bool* filters_updated,
+ float* error_sum,
+ rtc::ArrayView<float> accumulated_error,
+ rtc::ArrayView<float> scratch_memory) {
+ const int h_size = static_cast<int>(h.size());
+ const int x_size = static_cast<int>(x.size());
+ RTC_DCHECK_EQ(0, h_size % 16);
+ std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
+
+ // Process for all samples in the sub-block.
+ for (size_t i = 0; i < y.size(); ++i) {
+ // Apply the matched filter as filter * x, and compute x * x.
+ RTC_DCHECK_GT(x_size, x_start_index);
+ const int chunk1 =
+ std::min(h_size, static_cast<int>(x_size - x_start_index));
+ if (chunk1 != h_size) {
+ const int chunk2 = h_size - chunk1;
+ std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
+ std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
+ }
+ const float* x_p =
+ chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
+ const float* h_p = &h[0];
+ float* a_p = &accumulated_error[0];
+ __m256 s_inst_hadd_256;
+ __m256 s_inst_256;
+ __m256 s_inst_256_8;
+ __m256 x2_sum_256 = _mm256_set1_ps(0);
+ __m256 x2_sum_256_8 = _mm256_set1_ps(0);
+ __m128 e_128;
+ float x2_sum = 0.0f;
+ float s_acum = 0;
+ const int limit_by_16 = h_size >> 4;
+ for (int k = limit_by_16; k > 0; --k, h_p += 16, x_p += 16, a_p += 4) {
+ // Load the data into 256 bit vectors.
+ __m256 x_k = _mm256_loadu_ps(x_p);
+ __m256 h_k = _mm256_loadu_ps(h_p);
+ __m256 x_k_8 = _mm256_loadu_ps(x_p + 8);
+ __m256 h_k_8 = _mm256_loadu_ps(h_p + 8);
+ // Compute and accumulate x * x and h * x.
+ x2_sum_256 = _mm256_fmadd_ps(x_k, x_k, x2_sum_256);
+ x2_sum_256_8 = _mm256_fmadd_ps(x_k_8, x_k_8, x2_sum_256_8);
+ s_inst_256 = _mm256_mul_ps(h_k, x_k);
+ s_inst_256_8 = _mm256_mul_ps(h_k_8, x_k_8);
+ s_inst_hadd_256 = _mm256_hadd_ps(s_inst_256, s_inst_256_8);
+ s_inst_hadd_256 = _mm256_hadd_ps(s_inst_hadd_256, s_inst_hadd_256);
+ s_acum += s_inst_hadd_256[0];
+ e_128[0] = s_acum - y[i];
+ s_acum += s_inst_hadd_256[4];
+ e_128[1] = s_acum - y[i];
+ s_acum += s_inst_hadd_256[1];
+ e_128[2] = s_acum - y[i];
+ s_acum += s_inst_hadd_256[5];
+ e_128[3] = s_acum - y[i];
+
+ __m128 accumulated_error = _mm_load_ps(a_p);
+ accumulated_error = _mm_fmadd_ps(e_128, e_128, accumulated_error);
+ _mm_storeu_ps(a_p, accumulated_error);
+ }
+ // Sum components together.
+ x2_sum_256 = _mm256_add_ps(x2_sum_256, x2_sum_256_8);
+ __m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0),
+ _mm256_extractf128_ps(x2_sum_256, 1));
+ // Combine the accumulated vector and scalar values.
+ float* v = reinterpret_cast<float*>(&x2_sum_128);
+ x2_sum += v[0] + v[1] + v[2] + v[3];
+
+ // Compute the matched filter error.
+ float e = y[i] - s_acum;
+ const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
+ (*error_sum) += e * e;
+
+ // Update the matched filter estimate in an NLMS manner.
+ if (x2_sum > x2_sum_threshold && !saturation) {
+ RTC_DCHECK_LT(0.f, x2_sum);
+ const float alpha = smoothing * e / x2_sum;
+ const __m256 alpha_256 = _mm256_set1_ps(alpha);
+
+ // filter = filter + smoothing * (y - filter * x) * x / x * x.
+ float* h_p = &h[0];
+ const float* x_p =
+ chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
+ // Perform 256 bit vector operations.
+ const int limit_by_8 = h_size >> 3;
+ for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) {
+ // Load the data into 256 bit vectors.
+ __m256 h_k = _mm256_loadu_ps(h_p);
+ __m256 x_k = _mm256_loadu_ps(x_p);
+ // Compute h = h + alpha * x.
+ h_k = _mm256_fmadd_ps(x_k, alpha_256, h_k);
+
+ // Store the result.
+ _mm256_storeu_ps(h_p, h_k);
+ }
+ *filters_updated = true;
+ }
+
+ x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
+ }
+}
+
void MatchedFilterCore_AVX2(size_t x_start_index,
float x2_sum_threshold,
float smoothing,
@@ -24,7 +143,15 @@ void MatchedFilterCore_AVX2(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
- float* error_sum) {
+ float* error_sum,
+ bool compute_accumulated_error,
+ rtc::ArrayView<float> accumulated_error,
+ rtc::ArrayView<float> scratch_memory) {
+ if (compute_accumulated_error) {
+ return MatchedFilterCore_AccumulatedError_AVX2(
+ x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
+ error_sum, accumulated_error, scratch_memory);
+ }
const int h_size = static_cast<int>(h.size());
const int x_size = static_cast<int>(x.size());
RTC_DCHECK_EQ(0, h_size % 8);
@@ -81,15 +208,9 @@ void MatchedFilterCore_AVX2(size_t x_start_index,
// Sum components together.
x2_sum_256 = _mm256_add_ps(x2_sum_256, x2_sum_256_8);
s_256 = _mm256_add_ps(s_256, s_256_8);
- __m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0),
- _mm256_extractf128_ps(x2_sum_256, 1));
- __m128 s_128 = _mm_add_ps(_mm256_extractf128_ps(s_256, 0),
- _mm256_extractf128_ps(s_256, 1));
- // Combine the accumulated vector and scalar values.
- float* v = reinterpret_cast<float*>(&x2_sum_128);
- x2_sum += v[0] + v[1] + v[2] + v[3];
- v = reinterpret_cast<float*>(&s_128);
- s += v[0] + v[1] + v[2] + v[3];
+ __m128 sum = hsum_ab(x2_sum_256, s_256);
+ x2_sum += sum[0];
+ s += sum[1];
// Compute the matched filter error.
float e = y[i] - s;
diff --git a/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc b/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc
index 603a864b34..17f517a001 100644
--- a/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc
+++ b/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc
@@ -14,84 +14,148 @@
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
+#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
+namespace {
+int GetDownSamplingBlockSizeLog2(int down_sampling_factor) {
+ int down_sampling_factor_log2 = 0;
+ down_sampling_factor >>= 1;
+ while (down_sampling_factor > 0) {
+ down_sampling_factor_log2++;
+ down_sampling_factor >>= 1;
+ }
+ return static_cast<int>(kBlockSizeLog2) > down_sampling_factor_log2
+ ? static_cast<int>(kBlockSizeLog2) - down_sampling_factor_log2
+ : 0;
+}
+} // namespace
MatchedFilterLagAggregator::MatchedFilterLagAggregator(
ApmDataDumper* data_dumper,
size_t max_filter_lag,
- const EchoCanceller3Config::Delay::DelaySelectionThresholds& thresholds)
+ const EchoCanceller3Config::Delay& delay_config)
: data_dumper_(data_dumper),
- histogram_(max_filter_lag + 1, 0),
- thresholds_(thresholds) {
+ thresholds_(delay_config.delay_selection_thresholds),
+ headroom_(static_cast<int>(delay_config.delay_headroom_samples /
+ delay_config.down_sampling_factor)),
+ highest_peak_aggregator_(max_filter_lag) {
+ if (delay_config.detect_pre_echo) {
+ pre_echo_lag_aggregator_ = std::make_unique<PreEchoLagAggregator>(
+ max_filter_lag, delay_config.down_sampling_factor);
+ }
RTC_DCHECK(data_dumper);
RTC_DCHECK_LE(thresholds_.initial, thresholds_.converged);
- histogram_data_.fill(0);
}
MatchedFilterLagAggregator::~MatchedFilterLagAggregator() = default;
void MatchedFilterLagAggregator::Reset(bool hard_reset) {
- std::fill(histogram_.begin(), histogram_.end(), 0);
- histogram_data_.fill(0);
- histogram_data_index_ = 0;
+ highest_peak_aggregator_.Reset();
+ if (pre_echo_lag_aggregator_ != nullptr) {
+ pre_echo_lag_aggregator_->Reset();
+ }
if (hard_reset) {
significant_candidate_found_ = false;
}
}
absl::optional<DelayEstimate> MatchedFilterLagAggregator::Aggregate(
- rtc::ArrayView<const MatchedFilter::LagEstimate> lag_estimates) {
- // Choose the strongest lag estimate as the best one.
- float best_accuracy = 0.f;
- int best_lag_estimate_index = -1;
- for (size_t k = 0; k < lag_estimates.size(); ++k) {
- if (lag_estimates[k].updated && lag_estimates[k].reliable) {
- if (lag_estimates[k].accuracy > best_accuracy) {
- best_accuracy = lag_estimates[k].accuracy;
- best_lag_estimate_index = static_cast<int>(k);
- }
+ const absl::optional<const MatchedFilter::LagEstimate>& lag_estimate) {
+ if (lag_estimate && pre_echo_lag_aggregator_) {
+ pre_echo_lag_aggregator_->Dump(data_dumper_);
+ pre_echo_lag_aggregator_->Aggregate(
+ std::max(0, static_cast<int>(lag_estimate->pre_echo_lag) - headroom_));
+ }
+
+ if (lag_estimate) {
+ highest_peak_aggregator_.Aggregate(
+ std::max(0, static_cast<int>(lag_estimate->lag) - headroom_));
+ rtc::ArrayView<const int> histogram = highest_peak_aggregator_.histogram();
+ int candidate = highest_peak_aggregator_.candidate();
+ significant_candidate_found_ = significant_candidate_found_ ||
+ histogram[candidate] > thresholds_.converged;
+ if (histogram[candidate] > thresholds_.converged ||
+ (histogram[candidate] > thresholds_.initial &&
+ !significant_candidate_found_)) {
+ DelayEstimate::Quality quality = significant_candidate_found_
+ ? DelayEstimate::Quality::kRefined
+ : DelayEstimate::Quality::kCoarse;
+ int reported_delay = pre_echo_lag_aggregator_ != nullptr
+ ? pre_echo_lag_aggregator_->pre_echo_candidate()
+ : candidate;
+ return DelayEstimate(quality, reported_delay);
}
}
- // TODO(peah): Remove this logging once all development is done.
- data_dumper_->DumpRaw("aec3_echo_path_delay_estimator_best_index",
- best_lag_estimate_index);
- data_dumper_->DumpRaw("aec3_echo_path_delay_estimator_histogram", histogram_);
+ return absl::nullopt;
+}
- if (best_lag_estimate_index != -1) {
- RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]);
- RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]);
- --histogram_[histogram_data_[histogram_data_index_]];
+MatchedFilterLagAggregator::HighestPeakAggregator::HighestPeakAggregator(
+ size_t max_filter_lag)
+ : histogram_(max_filter_lag + 1, 0) {
+ histogram_data_.fill(0);
+}
- histogram_data_[histogram_data_index_] =
- lag_estimates[best_lag_estimate_index].lag;
+void MatchedFilterLagAggregator::HighestPeakAggregator::Reset() {
+ std::fill(histogram_.begin(), histogram_.end(), 0);
+ histogram_data_.fill(0);
+ histogram_data_index_ = 0;
+}
- RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]);
- RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]);
- ++histogram_[histogram_data_[histogram_data_index_]];
+void MatchedFilterLagAggregator::HighestPeakAggregator::Aggregate(int lag) {
+ RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]);
+ RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]);
+ --histogram_[histogram_data_[histogram_data_index_]];
+ histogram_data_[histogram_data_index_] = lag;
+ RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]);
+ RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]);
+ ++histogram_[histogram_data_[histogram_data_index_]];
+ histogram_data_index_ = (histogram_data_index_ + 1) % histogram_data_.size();
+ candidate_ =
+ std::distance(histogram_.begin(),
+ std::max_element(histogram_.begin(), histogram_.end()));
+}
- histogram_data_index_ =
- (histogram_data_index_ + 1) % histogram_data_.size();
+MatchedFilterLagAggregator::PreEchoLagAggregator::PreEchoLagAggregator(
+ size_t max_filter_lag,
+ size_t down_sampling_factor)
+ : block_size_log2_(GetDownSamplingBlockSizeLog2(down_sampling_factor)),
+ histogram_(
+ ((max_filter_lag + 1) * down_sampling_factor) >> kBlockSizeLog2,
+ 0) {
+ Reset();
+}
- const int candidate =
- std::distance(histogram_.begin(),
- std::max_element(histogram_.begin(), histogram_.end()));
+void MatchedFilterLagAggregator::PreEchoLagAggregator::Reset() {
+ std::fill(histogram_.begin(), histogram_.end(), 0);
+ histogram_data_.fill(0);
+ histogram_data_index_ = 0;
+ pre_echo_candidate_ = 0;
+}
- significant_candidate_found_ =
- significant_candidate_found_ ||
- histogram_[candidate] > thresholds_.converged;
- if (histogram_[candidate] > thresholds_.converged ||
- (histogram_[candidate] > thresholds_.initial &&
- !significant_candidate_found_)) {
- DelayEstimate::Quality quality = significant_candidate_found_
- ? DelayEstimate::Quality::kRefined
- : DelayEstimate::Quality::kCoarse;
- return DelayEstimate(quality, candidate);
- }
+void MatchedFilterLagAggregator::PreEchoLagAggregator::Aggregate(
+ int pre_echo_lag) {
+ int pre_echo_block_size = pre_echo_lag >> block_size_log2_;
+ RTC_DCHECK(pre_echo_block_size >= 0 &&
+ pre_echo_block_size < static_cast<int>(histogram_.size()));
+ pre_echo_block_size =
+ rtc::SafeClamp(pre_echo_block_size, 0, histogram_.size() - 1);
+ if (histogram_[histogram_data_[histogram_data_index_]] > 0) {
+ --histogram_[histogram_data_[histogram_data_index_]];
}
+ histogram_data_[histogram_data_index_] = pre_echo_block_size;
+ ++histogram_[histogram_data_[histogram_data_index_]];
+ histogram_data_index_ = (histogram_data_index_ + 1) % histogram_data_.size();
+ int pre_echo_candidate_block_size =
+ std::distance(histogram_.begin(),
+ std::max_element(histogram_.begin(), histogram_.end()));
+ pre_echo_candidate_ = (pre_echo_candidate_block_size << block_size_log2_);
+}
- return absl::nullopt;
+void MatchedFilterLagAggregator::PreEchoLagAggregator::Dump(
+ ApmDataDumper* const data_dumper) {
+ data_dumper->DumpRaw("aec3_pre_echo_delay_candidate", pre_echo_candidate_);
}
} // namespace webrtc
diff --git a/modules/audio_processing/aec3/matched_filter_lag_aggregator.h b/modules/audio_processing/aec3/matched_filter_lag_aggregator.h
index 612bd5d942..c0598bf226 100644
--- a/modules/audio_processing/aec3/matched_filter_lag_aggregator.h
+++ b/modules/audio_processing/aec3/matched_filter_lag_aggregator.h
@@ -26,10 +26,9 @@ class ApmDataDumper;
// reliable combined lag estimate.
class MatchedFilterLagAggregator {
public:
- MatchedFilterLagAggregator(
- ApmDataDumper* data_dumper,
- size_t max_filter_lag,
- const EchoCanceller3Config::Delay::DelaySelectionThresholds& thresholds);
+ MatchedFilterLagAggregator(ApmDataDumper* data_dumper,
+ size_t max_filter_lag,
+ const EchoCanceller3Config::Delay& delay_config);
MatchedFilterLagAggregator() = delete;
MatchedFilterLagAggregator(const MatchedFilterLagAggregator&) = delete;
@@ -43,18 +42,55 @@ class MatchedFilterLagAggregator {
// Aggregates the provided lag estimates.
absl::optional<DelayEstimate> Aggregate(
- rtc::ArrayView<const MatchedFilter::LagEstimate> lag_estimates);
+ const absl::optional<const MatchedFilter::LagEstimate>& lag_estimate);
// Returns whether a reliable delay estimate has been found.
bool ReliableDelayFound() const { return significant_candidate_found_; }
+ // Returns the delay candidate that is computed by looking at the highest peak
+ // on the matched filters.
+ int GetDelayAtHighestPeak() const {
+ return highest_peak_aggregator_.candidate();
+ }
+
private:
+ class PreEchoLagAggregator {
+ public:
+ PreEchoLagAggregator(size_t max_filter_lag, size_t down_sampling_factor);
+ void Reset();
+ void Aggregate(int pre_echo_lag);
+ int pre_echo_candidate() const { return pre_echo_candidate_; }
+ void Dump(ApmDataDumper* const data_dumper);
+
+ private:
+ const int block_size_log2_;
+ std::array<int, 250> histogram_data_;
+ std::vector<int> histogram_;
+ int histogram_data_index_ = 0;
+ int pre_echo_candidate_ = 0;
+ };
+
+ class HighestPeakAggregator {
+ public:
+ explicit HighestPeakAggregator(size_t max_filter_lag);
+ void Reset();
+ void Aggregate(int lag);
+ int candidate() const { return candidate_; }
+ rtc::ArrayView<const int> histogram() const { return histogram_; }
+
+ private:
+ std::vector<int> histogram_;
+ std::array<int, 250> histogram_data_;
+ int histogram_data_index_ = 0;
+ int candidate_ = -1;
+ };
+
ApmDataDumper* const data_dumper_;
- std::vector<int> histogram_;
- std::array<int, 250> histogram_data_;
- int histogram_data_index_ = 0;
bool significant_candidate_found_ = false;
const EchoCanceller3Config::Delay::DelaySelectionThresholds thresholds_;
+ const int headroom_;
+ HighestPeakAggregator highest_peak_aggregator_;
+ std::unique_ptr<PreEchoLagAggregator> pre_echo_lag_aggregator_;
};
} // namespace webrtc
diff --git a/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc b/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc
index 8e2a12e6c5..6804102584 100644
--- a/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc
+++ b/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc
@@ -27,69 +27,31 @@ constexpr size_t kNumLagsBeforeDetection = 26;
} // namespace
-// Verifies that the most accurate lag estimate is chosen.
-TEST(MatchedFilterLagAggregator, MostAccurateLagChosen) {
- constexpr size_t kLag1 = 5;
- constexpr size_t kLag2 = 10;
- ApmDataDumper data_dumper(0);
- EchoCanceller3Config config;
- std::vector<MatchedFilter::LagEstimate> lag_estimates(2);
- MatchedFilterLagAggregator aggregator(
- &data_dumper, std::max(kLag1, kLag2),
- config.delay.delay_selection_thresholds);
- lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag1, true);
- lag_estimates[1] = MatchedFilter::LagEstimate(0.5f, true, kLag2, true);
-
- for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
- aggregator.Aggregate(lag_estimates);
- }
-
- absl::optional<DelayEstimate> aggregated_lag =
- aggregator.Aggregate(lag_estimates);
- EXPECT_TRUE(aggregated_lag);
- EXPECT_EQ(kLag1, aggregated_lag->delay);
-
- lag_estimates[0] = MatchedFilter::LagEstimate(0.5f, true, kLag1, true);
- lag_estimates[1] = MatchedFilter::LagEstimate(1.f, true, kLag2, true);
-
- for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
- aggregated_lag = aggregator.Aggregate(lag_estimates);
- EXPECT_TRUE(aggregated_lag);
- EXPECT_EQ(kLag1, aggregated_lag->delay);
- }
-
- aggregated_lag = aggregator.Aggregate(lag_estimates);
- aggregated_lag = aggregator.Aggregate(lag_estimates);
- EXPECT_TRUE(aggregated_lag);
- EXPECT_EQ(kLag2, aggregated_lag->delay);
-}
-
// Verifies that varying lag estimates causes lag estimates to not be deemed
// reliable.
TEST(MatchedFilterLagAggregator,
LagEstimateInvarianceRequiredForAggregatedLag) {
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
- std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
- MatchedFilterLagAggregator aggregator(
- &data_dumper, 100, config.delay.delay_selection_thresholds);
+ MatchedFilterLagAggregator aggregator(&data_dumper, /*max_filter_lag=*/100,
+ config.delay);
absl::optional<DelayEstimate> aggregated_lag;
for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
- lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, 10, true);
- aggregated_lag = aggregator.Aggregate(lag_estimates);
+ aggregated_lag = aggregator.Aggregate(
+ MatchedFilter::LagEstimate(/*lag=*/10, /*pre_echo_lag=*/10));
}
EXPECT_TRUE(aggregated_lag);
for (size_t k = 0; k < kNumLagsBeforeDetection * 100; ++k) {
- lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, k % 100, true);
- aggregated_lag = aggregator.Aggregate(lag_estimates);
+ aggregated_lag = aggregator.Aggregate(
+ MatchedFilter::LagEstimate(/*lag=*/k % 100, /*pre_echo_lag=*/k % 100));
}
EXPECT_FALSE(aggregated_lag);
for (size_t k = 0; k < kNumLagsBeforeDetection * 100; ++k) {
- lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, k % 100, true);
- aggregated_lag = aggregator.Aggregate(lag_estimates);
+ aggregated_lag = aggregator.Aggregate(
+ MatchedFilter::LagEstimate(/*lag=*/k % 100, /*pre_echo_lag=*/k % 100));
EXPECT_FALSE(aggregated_lag);
}
}
@@ -101,13 +63,11 @@ TEST(MatchedFilterLagAggregator,
constexpr size_t kLag = 5;
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
- std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
- MatchedFilterLagAggregator aggregator(
- &data_dumper, kLag, config.delay.delay_selection_thresholds);
+ MatchedFilterLagAggregator aggregator(&data_dumper, /*max_filter_lag=*/kLag,
+ config.delay);
for (size_t k = 0; k < kNumLagsBeforeDetection * 10; ++k) {
- lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag, false);
- absl::optional<DelayEstimate> aggregated_lag =
- aggregator.Aggregate(lag_estimates);
+ absl::optional<DelayEstimate> aggregated_lag = aggregator.Aggregate(
+ MatchedFilter::LagEstimate(/*lag=*/kLag, /*pre_echo_lag=*/kLag));
EXPECT_FALSE(aggregated_lag);
EXPECT_EQ(kLag, aggregated_lag->delay);
}
@@ -122,20 +82,19 @@ TEST(MatchedFilterLagAggregator, DISABLED_PersistentAggregatedLag) {
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
- MatchedFilterLagAggregator aggregator(
- &data_dumper, std::max(kLag1, kLag2),
- config.delay.delay_selection_thresholds);
+ MatchedFilterLagAggregator aggregator(&data_dumper, std::max(kLag1, kLag2),
+ config.delay);
absl::optional<DelayEstimate> aggregated_lag;
for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
- lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag1, true);
- aggregated_lag = aggregator.Aggregate(lag_estimates);
+ aggregated_lag = aggregator.Aggregate(
+ MatchedFilter::LagEstimate(/*lag=*/kLag1, /*pre_echo_lag=*/kLag1));
}
EXPECT_TRUE(aggregated_lag);
EXPECT_EQ(kLag1, aggregated_lag->delay);
for (size_t k = 0; k < kNumLagsBeforeDetection * 40; ++k) {
- lag_estimates[0] = MatchedFilter::LagEstimate(1.f, false, kLag2, true);
- aggregated_lag = aggregator.Aggregate(lag_estimates);
+ aggregated_lag = aggregator.Aggregate(
+ MatchedFilter::LagEstimate(/*lag=*/kLag2, /*pre_echo_lag=*/kLag2));
EXPECT_TRUE(aggregated_lag);
EXPECT_EQ(kLag1, aggregated_lag->delay);
}
@@ -146,9 +105,7 @@ TEST(MatchedFilterLagAggregator, DISABLED_PersistentAggregatedLag) {
// Verifies the check for non-null data dumper.
TEST(MatchedFilterLagAggregatorDeathTest, NullDataDumper) {
EchoCanceller3Config config;
- EXPECT_DEATH(MatchedFilterLagAggregator(
- nullptr, 10, config.delay.delay_selection_thresholds),
- "");
+ EXPECT_DEATH(MatchedFilterLagAggregator(nullptr, 10, config.delay), "");
}
#endif
diff --git a/modules/audio_processing/aec3/matched_filter_unittest.cc b/modules/audio_processing/aec3/matched_filter_unittest.cc
index 9924256f0c..b080308191 100644
--- a/modules/audio_processing/aec3/matched_filter_unittest.cc
+++ b/modules/audio_processing/aec3/matched_filter_unittest.cc
@@ -47,12 +47,15 @@ constexpr size_t kAlignmentShiftSubBlocks = kWindowSizeSubBlocks * 3 / 4;
} // namespace
+class MatchedFilterTest : public ::testing::TestWithParam<bool> {};
+
#if defined(WEBRTC_HAS_NEON)
// Verifies that the optimized methods for NEON are similar to their reference
// counterparts.
-TEST(MatchedFilter, TestNeonOptimizations) {
+TEST_P(MatchedFilterTest, TestNeonOptimizations) {
Random random_generator(42U);
constexpr float kSmoothing = 0.7f;
+ const bool kComputeAccumulatederror = GetParam();
for (auto down_sampling_factor : kDownSamplingFactors) {
const size_t sub_block_size = kBlockSize / down_sampling_factor;
@@ -61,6 +64,10 @@ TEST(MatchedFilter, TestNeonOptimizations) {
std::vector<float> y(sub_block_size);
std::vector<float> h_NEON(512);
std::vector<float> h(512);
+ std::vector<float> accumulated_error(512);
+ std::vector<float> accumulated_error_NEON(512);
+ std::vector<float> scratch_memory(512);
+
int x_index = 0;
for (int k = 0; k < 1000; ++k) {
RandomizeSampleVector(&random_generator, y);
@@ -71,10 +78,13 @@ TEST(MatchedFilter, TestNeonOptimizations) {
float error_sum_NEON = 0.f;
MatchedFilterCore_NEON(x_index, h.size() * 150.f * 150.f, kSmoothing, x,
- y, h_NEON, &filters_updated_NEON, &error_sum_NEON);
+ y, h_NEON, &filters_updated_NEON, &error_sum_NEON,
+ kComputeAccumulatederror, accumulated_error_NEON,
+ scratch_memory);
MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y, h,
- &filters_updated, &error_sum);
+ &filters_updated, &error_sum, kComputeAccumulatederror,
+ accumulated_error);
EXPECT_EQ(filters_updated, filters_updated_NEON);
EXPECT_NEAR(error_sum, error_sum_NEON, error_sum / 100000.f);
@@ -83,6 +93,17 @@ TEST(MatchedFilter, TestNeonOptimizations) {
EXPECT_NEAR(h[j], h_NEON[j], 0.00001f);
}
+ if (kComputeAccumulatederror) {
+ for (size_t j = 0; j < accumulated_error.size(); ++j) {
+ float difference =
+ std::abs(accumulated_error[j] - accumulated_error_NEON[j]);
+ float relative_difference = accumulated_error[j] > 0
+ ? difference / accumulated_error[j]
+ : difference;
+ EXPECT_NEAR(relative_difference, 0.0f, 0.02f);
+ }
+ }
+
x_index = (x_index + sub_block_size) % x.size();
}
}
@@ -92,7 +113,8 @@ TEST(MatchedFilter, TestNeonOptimizations) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
// Verifies that the optimized methods for SSE2 are bitexact to their reference
// counterparts.
-TEST(MatchedFilter, TestSse2Optimizations) {
+TEST_P(MatchedFilterTest, TestSse2Optimizations) {
+ const bool kComputeAccumulatederror = GetParam();
bool use_sse2 = (GetCPUInfo(kSSE2) != 0);
if (use_sse2) {
Random random_generator(42U);
@@ -104,6 +126,9 @@ TEST(MatchedFilter, TestSse2Optimizations) {
std::vector<float> y(sub_block_size);
std::vector<float> h_SSE2(512);
std::vector<float> h(512);
+ std::vector<float> accumulated_error(512 / 4);
+ std::vector<float> accumulated_error_SSE2(512 / 4);
+ std::vector<float> scratch_memory(512);
int x_index = 0;
for (int k = 0; k < 1000; ++k) {
RandomizeSampleVector(&random_generator, y);
@@ -115,10 +140,12 @@ TEST(MatchedFilter, TestSse2Optimizations) {
MatchedFilterCore_SSE2(x_index, h.size() * 150.f * 150.f, kSmoothing, x,
y, h_SSE2, &filters_updated_SSE2,
- &error_sum_SSE2);
+ &error_sum_SSE2, kComputeAccumulatederror,
+ accumulated_error_SSE2, scratch_memory);
MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y,
- h, &filters_updated, &error_sum);
+ h, &filters_updated, &error_sum,
+ kComputeAccumulatederror, accumulated_error);
EXPECT_EQ(filters_updated, filters_updated_SSE2);
EXPECT_NEAR(error_sum, error_sum_SSE2, error_sum / 100000.f);
@@ -127,14 +154,24 @@ TEST(MatchedFilter, TestSse2Optimizations) {
EXPECT_NEAR(h[j], h_SSE2[j], 0.00001f);
}
+ for (size_t j = 0; j < accumulated_error.size(); ++j) {
+ float difference =
+ std::abs(accumulated_error[j] - accumulated_error_SSE2[j]);
+ float relative_difference = accumulated_error[j] > 0
+ ? difference / accumulated_error[j]
+ : difference;
+ EXPECT_NEAR(relative_difference, 0.0f, 0.00001f);
+ }
+
x_index = (x_index + sub_block_size) % x.size();
}
}
}
}
-TEST(MatchedFilter, TestAvx2Optimizations) {
+TEST_P(MatchedFilterTest, TestAvx2Optimizations) {
bool use_avx2 = (GetCPUInfo(kAVX2) != 0);
+ const bool kComputeAccumulatederror = GetParam();
if (use_avx2) {
Random random_generator(42U);
constexpr float kSmoothing = 0.7f;
@@ -145,29 +182,36 @@ TEST(MatchedFilter, TestAvx2Optimizations) {
std::vector<float> y(sub_block_size);
std::vector<float> h_AVX2(512);
std::vector<float> h(512);
+ std::vector<float> accumulated_error(512 / 4);
+ std::vector<float> accumulated_error_AVX2(512 / 4);
+ std::vector<float> scratch_memory(512);
int x_index = 0;
for (int k = 0; k < 1000; ++k) {
RandomizeSampleVector(&random_generator, y);
-
bool filters_updated = false;
float error_sum = 0.f;
bool filters_updated_AVX2 = false;
float error_sum_AVX2 = 0.f;
-
MatchedFilterCore_AVX2(x_index, h.size() * 150.f * 150.f, kSmoothing, x,
y, h_AVX2, &filters_updated_AVX2,
- &error_sum_AVX2);
-
+ &error_sum_AVX2, kComputeAccumulatederror,
+ accumulated_error_AVX2, scratch_memory);
MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y,
- h, &filters_updated, &error_sum);
-
+ h, &filters_updated, &error_sum,
+ kComputeAccumulatederror, accumulated_error);
EXPECT_EQ(filters_updated, filters_updated_AVX2);
EXPECT_NEAR(error_sum, error_sum_AVX2, error_sum / 100000.f);
-
for (size_t j = 0; j < h.size(); ++j) {
EXPECT_NEAR(h[j], h_AVX2[j], 0.00001f);
}
-
+ for (size_t j = 0; j < accumulated_error.size(); j += 4) {
+ float difference =
+ std::abs(accumulated_error[j] - accumulated_error_AVX2[j]);
+ float relative_difference = accumulated_error[j] > 0
+ ? difference / accumulated_error[j]
+ : difference;
+ EXPECT_NEAR(relative_difference, 0.0f, 0.00001f);
+ }
x_index = (x_index + sub_block_size) % x.size();
}
}
@@ -199,9 +243,9 @@ TEST(MatchedFilter, MaxSquarePeakIndex) {
}
// Verifies that the matched filter produces proper lag estimates for
-// artificially
-// delayed signals.
-TEST(MatchedFilter, LagEstimation) {
+// artificially delayed signals.
+TEST_P(MatchedFilterTest, LagEstimation) {
+ const bool kDetectPreEcho = GetParam();
Random random_generator(42U);
constexpr size_t kNumChannels = 1;
constexpr int kSampleRateHz = 48000;
@@ -222,12 +266,12 @@ TEST(MatchedFilter, LagEstimation) {
Decimator capture_decimator(down_sampling_factor);
DelayBuffer<float> signal_delay_buffer(down_sampling_factor *
delay_samples);
- MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
- kWindowSizeSubBlocks, kNumMatchedFilters,
- kAlignmentShiftSubBlocks, 150,
- config.delay.delay_estimate_smoothing,
- config.delay.delay_estimate_smoothing_delay_found,
- config.delay.delay_candidate_detection_threshold);
+ MatchedFilter filter(
+ &data_dumper, DetectOptimization(), sub_block_size,
+ kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks,
+ 150, config.delay.delay_estimate_smoothing,
+ config.delay.delay_estimate_smoothing_delay_found,
+ config.delay.delay_candidate_detection_threshold, kDetectPreEcho);
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
@@ -254,62 +298,97 @@ TEST(MatchedFilter, LagEstimation) {
downsampled_capture_data.data(), sub_block_size);
capture_decimator.Decimate(capture[0], downsampled_capture);
filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(),
- downsampled_capture, false);
+ downsampled_capture, /*use_slow_smoothing=*/false);
}
// Obtain the lag estimates.
- auto lag_estimates = filter.GetLagEstimates();
-
- // Find which lag estimate should be the most accurate.
- absl::optional<size_t> expected_most_accurate_lag_estimate;
- size_t alignment_shift_sub_blocks = 0;
- for (size_t k = 0; k < config.delay.num_filters; ++k) {
- if ((alignment_shift_sub_blocks + 3 * kWindowSizeSubBlocks / 4) *
- sub_block_size >
- delay_samples) {
- expected_most_accurate_lag_estimate = k > 0 ? k - 1 : 0;
- break;
- }
- alignment_shift_sub_blocks += kAlignmentShiftSubBlocks;
- }
- ASSERT_TRUE(expected_most_accurate_lag_estimate);
-
- // Verify that the expected most accurate lag estimate is the most
- // accurate estimate.
- for (size_t k = 0; k < kNumMatchedFilters; ++k) {
- if (k != *expected_most_accurate_lag_estimate &&
- k != (*expected_most_accurate_lag_estimate + 1)) {
- EXPECT_TRUE(
- lag_estimates[*expected_most_accurate_lag_estimate].accuracy >
- lag_estimates[k].accuracy ||
- !lag_estimates[k].reliable ||
- !lag_estimates[*expected_most_accurate_lag_estimate].reliable);
- }
- }
+ auto lag_estimate = filter.GetBestLagEstimate();
+ EXPECT_TRUE(lag_estimate.has_value());
- // Verify that all lag estimates are updated as expected for signals
- // containing strong noise.
- for (auto& le : lag_estimates) {
- EXPECT_TRUE(le.updated);
+ // Verify that the expected most accurate lag estimate is correct.
+ if (lag_estimate.has_value()) {
+ EXPECT_EQ(delay_samples, lag_estimate->lag);
+ EXPECT_EQ(delay_samples, lag_estimate->pre_echo_lag);
}
+ }
+ }
+}
- // Verify that the expected most accurate lag estimate is reliable.
- EXPECT_TRUE(
- lag_estimates[*expected_most_accurate_lag_estimate].reliable ||
- lag_estimates[std::min(*expected_most_accurate_lag_estimate + 1,
- lag_estimates.size() - 1)]
- .reliable);
+// Test the pre echo estimation.
+TEST_P(MatchedFilterTest, PreEchoEstimation) {
+ const bool kDetectPreEcho = GetParam();
+ Random random_generator(42U);
+ constexpr size_t kNumChannels = 1;
+ constexpr int kSampleRateHz = 48000;
+ constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
- // Verify that the expected most accurate lag estimate is correct.
- if (lag_estimates[*expected_most_accurate_lag_estimate].reliable) {
- EXPECT_TRUE(delay_samples ==
- lag_estimates[*expected_most_accurate_lag_estimate].lag);
+ for (auto down_sampling_factor : kDownSamplingFactors) {
+ const size_t sub_block_size = kBlockSize / down_sampling_factor;
+
+ Block render(kNumBands, kNumChannels);
+ std::vector<std::vector<float>> capture(
+ 1, std::vector<float>(kBlockSize, 0.f));
+ std::vector<float> capture_with_pre_echo(kBlockSize, 0.f);
+ ApmDataDumper data_dumper(0);
+ // data_dumper.SetActivated(true);
+ size_t pre_echo_delay_samples = 20e-3 * 16000 / down_sampling_factor;
+ size_t echo_delay_samples = 50e-3 * 16000 / down_sampling_factor;
+ EchoCanceller3Config config;
+ config.delay.down_sampling_factor = down_sampling_factor;
+ config.delay.num_filters = kNumMatchedFilters;
+ Decimator capture_decimator(down_sampling_factor);
+ DelayBuffer<float> signal_echo_delay_buffer(down_sampling_factor *
+ echo_delay_samples);
+ DelayBuffer<float> signal_pre_echo_delay_buffer(down_sampling_factor *
+ pre_echo_delay_samples);
+ MatchedFilter filter(
+ &data_dumper, DetectOptimization(), sub_block_size,
+ kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150,
+ config.delay.delay_estimate_smoothing,
+ config.delay.delay_estimate_smoothing_delay_found,
+ config.delay.delay_candidate_detection_threshold, kDetectPreEcho);
+ std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
+ RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
+ // Analyze the correlation between render and capture.
+ for (size_t k = 0; k < (600 + echo_delay_samples / sub_block_size); ++k) {
+ for (size_t band = 0; band < kNumBands; ++band) {
+ for (size_t channel = 0; channel < kNumChannels; ++channel) {
+ RandomizeSampleVector(&random_generator, render.View(band, channel));
+ }
+ }
+ signal_echo_delay_buffer.Delay(render.View(0, 0), capture[0]);
+ signal_pre_echo_delay_buffer.Delay(render.View(0, 0),
+ capture_with_pre_echo);
+ for (size_t k = 0; k < capture[0].size(); ++k) {
+ constexpr float gain_pre_echo = 0.8f;
+ capture[0][k] += gain_pre_echo * capture_with_pre_echo[k];
+ }
+ render_delay_buffer->Insert(render);
+ if (k == 0) {
+ render_delay_buffer->Reset();
+ }
+ render_delay_buffer->PrepareCaptureProcessing();
+ std::array<float, kBlockSize> downsampled_capture_data;
+ rtc::ArrayView<float> downsampled_capture(downsampled_capture_data.data(),
+ sub_block_size);
+ capture_decimator.Decimate(capture[0], downsampled_capture);
+ filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(),
+ downsampled_capture, /*use_slow_smoothing=*/false);
+ }
+ // Obtain the lag estimates.
+ auto lag_estimate = filter.GetBestLagEstimate();
+ EXPECT_TRUE(lag_estimate.has_value());
+ // Verify that the expected most accurate lag estimate is correct.
+ if (lag_estimate.has_value()) {
+ EXPECT_EQ(echo_delay_samples, lag_estimate->lag);
+ if (kDetectPreEcho) {
+ // The pre echo delay is estimated in a subsampled domain and a larger
+ // error is allowed.
+ EXPECT_NEAR(pre_echo_delay_samples, lag_estimate->pre_echo_lag, 4);
} else {
- EXPECT_TRUE(
- delay_samples ==
- lag_estimates[std::min(*expected_most_accurate_lag_estimate + 1,
- lag_estimates.size() - 1)]
- .lag);
+ // The pre echo delay fallback to the highest mached filter peak when
+ // its detection is disabled.
+ EXPECT_EQ(echo_delay_samples, lag_estimate->pre_echo_lag);
}
}
}
@@ -317,7 +396,8 @@ TEST(MatchedFilter, LagEstimation) {
// Verifies that the matched filter does not produce reliable and accurate
// estimates for uncorrelated render and capture signals.
-TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) {
+TEST_P(MatchedFilterTest, LagNotReliableForUncorrelatedRenderAndCapture) {
+ const bool kDetectPreEcho = GetParam();
constexpr size_t kNumChannels = 1;
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
@@ -335,12 +415,12 @@ TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) {
ApmDataDumper data_dumper(0);
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
- MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
- kWindowSizeSubBlocks, kNumMatchedFilters,
- kAlignmentShiftSubBlocks, 150,
- config.delay.delay_estimate_smoothing,
- config.delay.delay_estimate_smoothing_delay_found,
- config.delay.delay_candidate_detection_threshold);
+ MatchedFilter filter(
+ &data_dumper, DetectOptimization(), sub_block_size,
+ kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150,
+ config.delay.delay_estimate_smoothing,
+ config.delay.delay_estimate_smoothing_delay_found,
+ config.delay.delay_candidate_detection_threshold, kDetectPreEcho);
// Analyze the correlation between render and capture.
for (size_t k = 0; k < 100; ++k) {
@@ -352,20 +432,17 @@ TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) {
false);
}
- // Obtain the lag estimates.
- auto lag_estimates = filter.GetLagEstimates();
- EXPECT_EQ(kNumMatchedFilters, lag_estimates.size());
-
- // Verify that no lag estimates are reliable.
- for (auto& le : lag_estimates) {
- EXPECT_FALSE(le.reliable);
- }
+ // Obtain the best lag estimate and Verify that no lag estimates are
+ // reliable.
+ auto best_lag_estimates = filter.GetBestLagEstimate();
+ EXPECT_FALSE(best_lag_estimates.has_value());
}
}
// Verifies that the matched filter does not produce updated lag estimates for
// render signals of low level.
-TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) {
+TEST_P(MatchedFilterTest, LagNotUpdatedForLowLevelRender) {
+ const bool kDetectPreEcho = GetParam();
Random random_generator(42U);
constexpr size_t kNumChannels = 1;
constexpr int kSampleRateHz = 48000;
@@ -374,19 +451,17 @@ TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) {
for (auto down_sampling_factor : kDownSamplingFactors) {
const size_t sub_block_size = kBlockSize / down_sampling_factor;
- std::vector<std::vector<std::vector<float>>> render(
- kNumBands, std::vector<std::vector<float>>(
- kNumChannels, std::vector<float>(kBlockSize, 0.f)));
+ Block render(kNumBands, kNumChannels);
std::vector<std::vector<float>> capture(
1, std::vector<float>(kBlockSize, 0.f));
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
- MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
- kWindowSizeSubBlocks, kNumMatchedFilters,
- kAlignmentShiftSubBlocks, 150,
- config.delay.delay_estimate_smoothing,
- config.delay.delay_estimate_smoothing_delay_found,
- config.delay.delay_candidate_detection_threshold);
+ MatchedFilter filter(
+ &data_dumper, DetectOptimization(), sub_block_size,
+ kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150,
+ config.delay.delay_estimate_smoothing,
+ config.delay.delay_estimate_smoothing_delay_found,
+ config.delay.delay_candidate_detection_threshold, kDetectPreEcho);
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz,
kNumChannels));
@@ -394,11 +469,11 @@ TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) {
// Analyze the correlation between render and capture.
for (size_t k = 0; k < 100; ++k) {
- RandomizeSampleVector(&random_generator, render[0][0]);
- for (auto& render_k : render[0][0]) {
+ RandomizeSampleVector(&random_generator, render.View(0, 0));
+ for (auto& render_k : render.View(0, 0)) {
render_k *= 149.f / 32767.f;
}
- std::copy(render[0][0].begin(), render[0][0].end(), capture[0].begin());
+ std::copy(render.begin(0, 0), render.end(0, 0), capture[0].begin());
std::array<float, kBlockSize> downsampled_capture_data;
rtc::ArrayView<float> downsampled_capture(downsampled_capture_data.data(),
sub_block_size);
@@ -407,86 +482,76 @@ TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) {
downsampled_capture, false);
}
- // Obtain the lag estimates.
- auto lag_estimates = filter.GetLagEstimates();
- EXPECT_EQ(kNumMatchedFilters, lag_estimates.size());
-
- // Verify that no lag estimates are updated and that no lag estimates are
- // reliable.
- for (auto& le : lag_estimates) {
- EXPECT_FALSE(le.updated);
- EXPECT_FALSE(le.reliable);
- }
+ // Verify that no lag estimate has been produced.
+ auto lag_estimate = filter.GetBestLagEstimate();
+ EXPECT_FALSE(lag_estimate.has_value());
}
}
-// Verifies that the correct number of lag estimates are produced for a certain
-// number of alignment shifts.
-TEST(MatchedFilter, NumberOfLagEstimates) {
- ApmDataDumper data_dumper(0);
- EchoCanceller3Config config;
- for (auto down_sampling_factor : kDownSamplingFactors) {
- const size_t sub_block_size = kBlockSize / down_sampling_factor;
- for (size_t num_matched_filters = 0; num_matched_filters < 10;
- ++num_matched_filters) {
- MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
- 32, num_matched_filters, 1, 150,
- config.delay.delay_estimate_smoothing,
- config.delay.delay_estimate_smoothing_delay_found,
- config.delay.delay_candidate_detection_threshold);
- EXPECT_EQ(num_matched_filters, filter.GetLagEstimates().size());
- }
- }
-}
+INSTANTIATE_TEST_SUITE_P(_, MatchedFilterTest, testing::Values(true, false));
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
+class MatchedFilterDeathTest : public ::testing::TestWithParam<bool> {};
+
// Verifies the check for non-zero windows size.
-TEST(MatchedFilterDeathTest, ZeroWindowSize) {
+TEST_P(MatchedFilterDeathTest, ZeroWindowSize) {
+ const bool kDetectPreEcho = GetParam();
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 16, 0, 1, 1,
150, config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
- config.delay.delay_candidate_detection_threshold),
+ config.delay.delay_candidate_detection_threshold,
+ kDetectPreEcho),
"");
}
// Verifies the check for non-null data dumper.
-TEST(MatchedFilterDeathTest, NullDataDumper) {
+TEST_P(MatchedFilterDeathTest, NullDataDumper) {
+ const bool kDetectPreEcho = GetParam();
EchoCanceller3Config config;
EXPECT_DEATH(MatchedFilter(nullptr, DetectOptimization(), 16, 1, 1, 1, 150,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
- config.delay.delay_candidate_detection_threshold),
+ config.delay.delay_candidate_detection_threshold,
+ kDetectPreEcho),
"");
}
// Verifies the check for that the sub block size is a multiple of 4.
// TODO(peah): Activate the unittest once the required code has been landed.
-TEST(MatchedFilterDeathTest, DISABLED_BlockSizeMultipleOf4) {
+TEST_P(MatchedFilterDeathTest, DISABLED_BlockSizeMultipleOf4) {
+ const bool kDetectPreEcho = GetParam();
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 15, 1, 1, 1,
150, config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
- config.delay.delay_candidate_detection_threshold),
+ config.delay.delay_candidate_detection_threshold,
+ kDetectPreEcho),
"");
}
// Verifies the check for that there is an integer number of sub blocks that add
// up to a block size.
// TODO(peah): Activate the unittest once the required code has been landed.
-TEST(MatchedFilterDeathTest, DISABLED_SubBlockSizeAddsUpToBlockSize) {
+TEST_P(MatchedFilterDeathTest, DISABLED_SubBlockSizeAddsUpToBlockSize) {
+ const bool kDetectPreEcho = GetParam();
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 12, 1, 1, 1,
150, config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
- config.delay.delay_candidate_detection_threshold),
+ config.delay.delay_candidate_detection_threshold,
+ kDetectPreEcho),
"");
}
+INSTANTIATE_TEST_SUITE_P(_,
+ MatchedFilterDeathTest,
+ testing::Values(true, false));
+
#endif
} // namespace aec3
diff --git a/modules/audio_processing/aec3/render_delay_controller.cc b/modules/audio_processing/aec3/render_delay_controller.cc
index aa3d440c33..826c38a1e9 100644
--- a/modules/audio_processing/aec3/render_delay_controller.cc
+++ b/modules/audio_processing/aec3/render_delay_controller.cc
@@ -53,7 +53,6 @@ class RenderDelayControllerImpl final : public RenderDelayController {
static std::atomic<int> instance_count_;
std::unique_ptr<ApmDataDumper> data_dumper_;
const int hysteresis_limit_blocks_;
- const int delay_headroom_samples_;
absl::optional<DelayEstimate> delay_;
EchoPathDelayEstimator delay_estimator_;
RenderDelayControllerMetrics metrics_;
@@ -66,15 +65,9 @@ class RenderDelayControllerImpl final : public RenderDelayController {
DelayEstimate ComputeBufferDelay(
const absl::optional<DelayEstimate>& current_delay,
int hysteresis_limit_blocks,
- int delay_headroom_samples,
DelayEstimate estimated_delay) {
- // Subtract delay headroom.
- const int delay_with_headroom_samples = std::max(
- static_cast<int>(estimated_delay.delay) - delay_headroom_samples, 0);
-
// Compute the buffer delay increase required to achieve the desired latency.
- size_t new_delay_blocks = delay_with_headroom_samples >> kBlockSizeLog2;
-
+ size_t new_delay_blocks = estimated_delay.delay >> kBlockSizeLog2;
// Add hysteresis.
if (current_delay) {
size_t current_delay_blocks = current_delay->delay;
@@ -83,7 +76,6 @@ DelayEstimate ComputeBufferDelay(
new_delay_blocks = current_delay_blocks;
}
}
-
DelayEstimate new_delay = estimated_delay;
new_delay.delay = new_delay_blocks;
return new_delay;
@@ -98,7 +90,6 @@ RenderDelayControllerImpl::RenderDelayControllerImpl(
: data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)),
hysteresis_limit_blocks_(
static_cast<int>(config.delay.hysteresis_limit_blocks)),
- delay_headroom_samples_(config.delay.delay_headroom_samples),
delay_estimator_(data_dumper_.get(), config, num_capture_channels),
last_delay_estimate_quality_(DelayEstimate::Quality::kCoarse) {
RTC_DCHECK(ValidFullBandRate(sample_rate_hz));
@@ -158,9 +149,8 @@ absl::optional<DelayEstimate> RenderDelayControllerImpl::GetDelay(
const bool use_hysteresis =
last_delay_estimate_quality_ == DelayEstimate::Quality::kRefined &&
delay_samples_->quality == DelayEstimate::Quality::kRefined;
- delay_ = ComputeBufferDelay(delay_,
- use_hysteresis ? hysteresis_limit_blocks_ : 0,
- delay_headroom_samples_, *delay_samples_);
+ delay_ = ComputeBufferDelay(
+ delay_, use_hysteresis ? hysteresis_limit_blocks_ : 0, *delay_samples_);
last_delay_estimate_quality_ = delay_samples_->quality;
}