diff options
Diffstat (limited to 'modules/audio_processing/aec3/matched_filter.h')
-rw-r--r-- | modules/audio_processing/aec3/matched_filter.h | 46 |
1 files changed, 32 insertions, 14 deletions
diff --git a/modules/audio_processing/aec3/matched_filter.h b/modules/audio_processing/aec3/matched_filter.h index dd4a678394..760d5e39fd 100644 --- a/modules/audio_processing/aec3/matched_filter.h +++ b/modules/audio_processing/aec3/matched_filter.h @@ -15,6 +15,7 @@ #include <vector> +#include "absl/types/optional.h" #include "api/array_view.h" #include "modules/audio_processing/aec3/aec3_common.h" #include "rtc_base/system/arch.h" @@ -36,7 +37,10 @@ void MatchedFilterCore_NEON(size_t x_start_index, rtc::ArrayView<const float> y, rtc::ArrayView<float> h, bool* filters_updated, - float* error_sum); + float* error_sum, + bool compute_accumulation_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory); #endif @@ -50,7 +54,10 @@ void MatchedFilterCore_SSE2(size_t x_start_index, rtc::ArrayView<const float> y, rtc::ArrayView<float> h, bool* filters_updated, - float* error_sum); + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory); // Filter core for the matched filter that is optimized for AVX2. void MatchedFilterCore_AVX2(size_t x_start_index, @@ -60,7 +67,10 @@ void MatchedFilterCore_AVX2(size_t x_start_index, rtc::ArrayView<const float> y, rtc::ArrayView<float> h, bool* filters_updated, - float* error_sum); + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory); #endif @@ -72,7 +82,9 @@ void MatchedFilterCore(size_t x_start_index, rtc::ArrayView<const float> y, rtc::ArrayView<float> h, bool* filters_updated, - float* error_sum); + float* error_sum, + bool compute_accumulation_error, + rtc::ArrayView<float> accumulated_error); // Find largest peak of squared values in array. size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h); @@ -87,13 +99,10 @@ class MatchedFilter { // shift. struct LagEstimate { LagEstimate() = default; - LagEstimate(float accuracy, bool reliable, size_t lag, bool updated) - : accuracy(accuracy), reliable(reliable), lag(lag), updated(updated) {} - - float accuracy = 0.f; - bool reliable = false; + LagEstimate(size_t lag, size_t pre_echo_lag) + : lag(lag), pre_echo_lag(pre_echo_lag) {} size_t lag = 0; - bool updated = false; + size_t pre_echo_lag = 0; }; MatchedFilter(ApmDataDumper* data_dumper, @@ -105,7 +114,8 @@ class MatchedFilter { float excitation_limit, float smoothing_fast, float smoothing_slow, - float matching_filter_threshold); + float matching_filter_threshold, + bool detect_pre_echo); MatchedFilter() = delete; MatchedFilter(const MatchedFilter&) = delete; @@ -122,8 +132,8 @@ class MatchedFilter { void Reset(); // Returns the current lag estimates. - rtc::ArrayView<const MatchedFilter::LagEstimate> GetLagEstimates() const { - return lag_estimates_; + absl::optional<const MatchedFilter::LagEstimate> GetBestLagEstimate() const { + return reported_lag_estimate_; } // Returns the maximum filter lag. @@ -137,17 +147,25 @@ class MatchedFilter { size_t downsampling_factor) const; private: + void Dump(); + ApmDataDumper* const data_dumper_; const Aec3Optimization optimization_; const size_t sub_block_size_; const size_t filter_intra_lag_shift_; std::vector<std::vector<float>> filters_; - std::vector<LagEstimate> lag_estimates_; + std::vector<std::vector<float>> accumulated_error_; + std::vector<float> instantaneous_accumulated_error_; + std::vector<float> scratch_memory_; + absl::optional<MatchedFilter::LagEstimate> reported_lag_estimate_; + absl::optional<size_t> winner_lag_; + int last_detected_best_lag_filter_ = -1; std::vector<size_t> filters_offsets_; const float excitation_limit_; const float smoothing_fast_; const float smoothing_slow_; const float matching_filter_threshold_; + const bool detect_pre_echo_; }; } // namespace webrtc |