diff options
Diffstat (limited to 'webrtc/modules/audio_processing/intelligibility')
8 files changed, 1608 insertions, 0 deletions
diff --git a/webrtc/modules/audio_processing/intelligibility/Android.mk b/webrtc/modules/audio_processing/intelligibility/Android.mk new file mode 100644 index 0000000000..f8824a6492 --- /dev/null +++ b/webrtc/modules/audio_processing/intelligibility/Android.mk @@ -0,0 +1,46 @@ +# Copyright (c) 2015 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +LOCAL_PATH := $(call my-dir) + +include $(CLEAR_VARS) + +include $(LOCAL_PATH)/../../../../android-webrtc.mk + +LOCAL_ARM_MODE := arm +LOCAL_MODULE_CLASS := STATIC_LIBRARIES +LOCAL_MODULE := libwebrtc_intell +LOCAL_MODULE_TAGS := optional +LOCAL_CPP_EXTENSION := .cc +LOCAL_SRC_FILES := \ + intelligibility_enhancer.cc \ + intelligibility_utils.cc \ + +# Flags passed to both C and C++ files. +LOCAL_CFLAGS := \ + $(MY_WEBRTC_COMMON_DEFS) + +LOCAL_CFLAGS_arm := $(MY_WEBRTC_COMMON_DEFS_arm) +LOCAL_CFLAGS_x86 := $(MY_WEBRTC_COMMON_DEFS_x86) +LOCAL_CFLAGS_mips := $(MY_WEBRTC_COMMON_DEFS_mips) +LOCAL_CFLAGS_arm64 := $(MY_WEBRTC_COMMON_DEFS_arm64) +LOCAL_CFLAGS_x86_64 := $(MY_WEBRTC_COMMON_DEFS_x86_64) +LOCAL_CFLAGS_mips64 := $(MY_WEBRTC_COMMON_DEFS_mips64) + +# Include paths placed before CFLAGS/CPPFLAGS +LOCAL_C_INCLUDES := \ + $(LOCAL_PATH) \ + $(LOCAL_PATH)/../../../.. 
\ + +ifdef WEBRTC_STL +LOCAL_NDK_STL_VARIANT := $(WEBRTC_STL) +LOCAL_SDK_VERSION := 14 +LOCAL_MODULE := $(LOCAL_MODULE)_$(WEBRTC_STL) +endif + +include $(BUILD_STATIC_LIBRARY) diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc new file mode 100644 index 0000000000..d014ce060c --- /dev/null +++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc @@ -0,0 +1,381 @@ +/* + * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// +// Implements core class for intelligibility enhancer. +// +// Details of the model and algorithm can be found in the original paper: +// http://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6882788 +// + +#include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h" + +#include <math.h> +#include <stdlib.h> +#include <algorithm> +#include <numeric> + +#include "webrtc/base/checks.h" +#include "webrtc/common_audio/include/audio_util.h" +#include "webrtc/common_audio/window_generator.h" + +namespace webrtc { + +namespace { + +const size_t kErbResolution = 2; +const int kWindowSizeMs = 2; +const int kChunkSizeMs = 10; // Size provided by APM. +const float kClipFreq = 200.0f; +const float kConfigRho = 0.02f; // Default production and interpretation SNR. +const float kKbdAlpha = 1.5f; +const float kLambdaBot = -1.0f; // Extreme values in bisection +const float kLambdaTop = -10e-18f; // search for lamda. 
+ +} // namespace + +using std::complex; +using std::max; +using std::min; +using VarianceType = intelligibility::VarianceArray::StepType; + +IntelligibilityEnhancer::TransformCallback::TransformCallback( + IntelligibilityEnhancer* parent, + IntelligibilityEnhancer::AudioSource source) + : parent_(parent), source_(source) { +} + +void IntelligibilityEnhancer::TransformCallback::ProcessAudioBlock( + const complex<float>* const* in_block, + int in_channels, + size_t frames, + int /* out_channels */, + complex<float>* const* out_block) { + RTC_DCHECK_EQ(parent_->freqs_, frames); + for (int i = 0; i < in_channels; ++i) { + parent_->DispatchAudio(source_, in_block[i], out_block[i]); + } +} + +IntelligibilityEnhancer::IntelligibilityEnhancer() + : IntelligibilityEnhancer(IntelligibilityEnhancer::Config()) { +} + +IntelligibilityEnhancer::IntelligibilityEnhancer(const Config& config) + : freqs_(RealFourier::ComplexLength( + RealFourier::FftOrder(config.sample_rate_hz * kWindowSizeMs / 1000))), + window_size_(static_cast<size_t>(1 << RealFourier::FftOrder(freqs_))), + chunk_length_( + static_cast<size_t>(config.sample_rate_hz * kChunkSizeMs / 1000)), + bank_size_(GetBankSize(config.sample_rate_hz, kErbResolution)), + sample_rate_hz_(config.sample_rate_hz), + erb_resolution_(kErbResolution), + num_capture_channels_(config.num_capture_channels), + num_render_channels_(config.num_render_channels), + analysis_rate_(config.analysis_rate), + active_(true), + clear_variance_(freqs_, + config.var_type, + config.var_window_size, + config.var_decay_rate), + noise_variance_(freqs_, + config.var_type, + config.var_window_size, + config.var_decay_rate), + filtered_clear_var_(new float[bank_size_]), + filtered_noise_var_(new float[bank_size_]), + filter_bank_(bank_size_), + center_freqs_(new float[bank_size_]), + rho_(new float[bank_size_]), + gains_eq_(new float[bank_size_]), + gain_applier_(freqs_, config.gain_change_limit), + temp_render_out_buffer_(chunk_length_, 
num_render_channels_), + temp_capture_out_buffer_(chunk_length_, num_capture_channels_), + kbd_window_(new float[window_size_]), + render_callback_(this, AudioSource::kRenderStream), + capture_callback_(this, AudioSource::kCaptureStream), + block_count_(0), + analysis_step_(0) { + RTC_DCHECK_LE(config.rho, 1.0f); + + CreateErbBank(); + + // Assumes all rho equal. + for (size_t i = 0; i < bank_size_; ++i) { + rho_[i] = config.rho * config.rho; + } + + float freqs_khz = kClipFreq / 1000.0f; + size_t erb_index = static_cast<size_t>(ceilf( + 11.17f * logf((freqs_khz + 0.312f) / (freqs_khz + 14.6575f)) + 43.0f)); + start_freq_ = std::max(static_cast<size_t>(1), erb_index * erb_resolution_); + + WindowGenerator::KaiserBesselDerived(kKbdAlpha, window_size_, + kbd_window_.get()); + render_mangler_.reset(new LappedTransform( + num_render_channels_, num_render_channels_, chunk_length_, + kbd_window_.get(), window_size_, window_size_ / 2, &render_callback_)); + capture_mangler_.reset(new LappedTransform( + num_capture_channels_, num_capture_channels_, chunk_length_, + kbd_window_.get(), window_size_, window_size_ / 2, &capture_callback_)); +} + +void IntelligibilityEnhancer::ProcessRenderAudio(float* const* audio, + int sample_rate_hz, + int num_channels) { + RTC_CHECK_EQ(sample_rate_hz_, sample_rate_hz); + RTC_CHECK_EQ(num_render_channels_, num_channels); + + if (active_) { + render_mangler_->ProcessChunk(audio, temp_render_out_buffer_.channels()); + } + + if (active_) { + for (int i = 0; i < num_render_channels_; ++i) { + memcpy(audio[i], temp_render_out_buffer_.channels()[i], + chunk_length_ * sizeof(**audio)); + } + } +} + +void IntelligibilityEnhancer::AnalyzeCaptureAudio(float* const* audio, + int sample_rate_hz, + int num_channels) { + RTC_CHECK_EQ(sample_rate_hz_, sample_rate_hz); + RTC_CHECK_EQ(num_capture_channels_, num_channels); + + capture_mangler_->ProcessChunk(audio, temp_capture_out_buffer_.channels()); +} + +void IntelligibilityEnhancer::DispatchAudio( + 
IntelligibilityEnhancer::AudioSource source, + const complex<float>* in_block, + complex<float>* out_block) { + switch (source) { + case kRenderStream: + ProcessClearBlock(in_block, out_block); + break; + case kCaptureStream: + ProcessNoiseBlock(in_block, out_block); + break; + } +} + +void IntelligibilityEnhancer::ProcessClearBlock(const complex<float>* in_block, + complex<float>* out_block) { + if (block_count_ < 2) { + memset(out_block, 0, freqs_ * sizeof(*out_block)); + ++block_count_; + return; + } + + // TODO(ekm): Use VAD to |Step| and |AnalyzeClearBlock| only if necessary. + if (true) { + clear_variance_.Step(in_block, false); + if (block_count_ % analysis_rate_ == analysis_rate_ - 1) { + const float power_target = std::accumulate( + clear_variance_.variance(), clear_variance_.variance() + freqs_, 0.f); + AnalyzeClearBlock(power_target); + ++analysis_step_; + } + ++block_count_; + } + + if (active_) { + gain_applier_.Apply(in_block, out_block); + } +} + +void IntelligibilityEnhancer::AnalyzeClearBlock(float power_target) { + FilterVariance(clear_variance_.variance(), filtered_clear_var_.get()); + FilterVariance(noise_variance_.variance(), filtered_noise_var_.get()); + + SolveForGainsGivenLambda(kLambdaTop, start_freq_, gains_eq_.get()); + const float power_top = + DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_); + SolveForGainsGivenLambda(kLambdaBot, start_freq_, gains_eq_.get()); + const float power_bot = + DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_); + if (power_target >= power_bot && power_target <= power_top) { + SolveForLambda(power_target, power_bot, power_top); + UpdateErbGains(); + } // Else experiencing variance underflow, so do nothing. +} + +void IntelligibilityEnhancer::SolveForLambda(float power_target, + float power_bot, + float power_top) { + const float kConvergeThresh = 0.001f; // TODO(ekmeyerson): Find best values + const int kMaxIters = 100; // for these, based on experiments. 
+ + const float reciprocal_power_target = 1.f / power_target; + float lambda_bot = kLambdaBot; + float lambda_top = kLambdaTop; + float power_ratio = 2.0f; // Ratio of achieved power to target power. + int iters = 0; + while (std::fabs(power_ratio - 1.0f) > kConvergeThresh && + iters <= kMaxIters) { + const float lambda = lambda_bot + (lambda_top - lambda_bot) / 2.0f; + SolveForGainsGivenLambda(lambda, start_freq_, gains_eq_.get()); + const float power = + DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_); + if (power < power_target) { + lambda_bot = lambda; + } else { + lambda_top = lambda; + } + power_ratio = std::fabs(power * reciprocal_power_target); + ++iters; + } +} + +void IntelligibilityEnhancer::UpdateErbGains() { + // (ERB gain) = filterbank' * (freq gain) + float* gains = gain_applier_.target(); + for (size_t i = 0; i < freqs_; ++i) { + gains[i] = 0.0f; + for (size_t j = 0; j < bank_size_; ++j) { + gains[i] = fmaf(filter_bank_[j][i], gains_eq_[j], gains[i]); + } + } +} + +void IntelligibilityEnhancer::ProcessNoiseBlock(const complex<float>* in_block, + complex<float>* /*out_block*/) { + noise_variance_.Step(in_block); +} + +size_t IntelligibilityEnhancer::GetBankSize(int sample_rate, + size_t erb_resolution) { + float freq_limit = sample_rate / 2000.0f; + size_t erb_scale = static_cast<size_t>(ceilf( + 11.17f * logf((freq_limit + 0.312f) / (freq_limit + 14.6575f)) + 43.0f)); + return erb_scale * erb_resolution; +} + +void IntelligibilityEnhancer::CreateErbBank() { + size_t lf = 1, rf = 4; + + for (size_t i = 0; i < bank_size_; ++i) { + float abs_temp = fabsf((i + 1.0f) / static_cast<float>(erb_resolution_)); + center_freqs_[i] = 676170.4f / (47.06538f - expf(0.08950404f * abs_temp)); + center_freqs_[i] -= 14678.49f; + } + float last_center_freq = center_freqs_[bank_size_ - 1]; + for (size_t i = 0; i < bank_size_; ++i) { + center_freqs_[i] *= 0.5f * sample_rate_hz_ / last_center_freq; + } + + for (size_t i = 0; i < bank_size_; ++i) { + 
filter_bank_[i].resize(freqs_); + } + + for (size_t i = 1; i <= bank_size_; ++i) { + size_t lll, ll, rr, rrr; + static const size_t kOne = 1; // Avoids repeated static_cast<>s below. + lll = static_cast<size_t>(round( + center_freqs_[max(kOne, i - lf) - 1] * freqs_ / + (0.5f * sample_rate_hz_))); + ll = static_cast<size_t>(round( + center_freqs_[max(kOne, i) - 1] * freqs_ / (0.5f * sample_rate_hz_))); + lll = min(freqs_, max(lll, kOne)) - 1; + ll = min(freqs_, max(ll, kOne)) - 1; + + rrr = static_cast<size_t>(round( + center_freqs_[min(bank_size_, i + rf) - 1] * freqs_ / + (0.5f * sample_rate_hz_))); + rr = static_cast<size_t>(round( + center_freqs_[min(bank_size_, i + 1) - 1] * freqs_ / + (0.5f * sample_rate_hz_))); + rrr = min(freqs_, max(rrr, kOne)) - 1; + rr = min(freqs_, max(rr, kOne)) - 1; + + float step, element; + + step = 1.0f / (ll - lll); + element = 0.0f; + for (size_t j = lll; j <= ll; ++j) { + filter_bank_[i - 1][j] = element; + element += step; + } + step = 1.0f / (rrr - rr); + element = 1.0f; + for (size_t j = rr; j <= rrr; ++j) { + filter_bank_[i - 1][j] = element; + element -= step; + } + for (size_t j = ll; j <= rr; ++j) { + filter_bank_[i - 1][j] = 1.0f; + } + } + + float sum; + for (size_t i = 0; i < freqs_; ++i) { + sum = 0.0f; + for (size_t j = 0; j < bank_size_; ++j) { + sum += filter_bank_[j][i]; + } + for (size_t j = 0; j < bank_size_; ++j) { + filter_bank_[j][i] /= sum; + } + } +} + +void IntelligibilityEnhancer::SolveForGainsGivenLambda(float lambda, + size_t start_freq, + float* sols) { + bool quadratic = (kConfigRho < 1.0f); + const float* var_x0 = filtered_clear_var_.get(); + const float* var_n0 = filtered_noise_var_.get(); + + for (size_t n = 0; n < start_freq; ++n) { + sols[n] = 1.0f; + } + + // Analytic solution for optimal gains. See paper for derivation. 
+ for (size_t n = start_freq - 1; n < bank_size_; ++n) { + float alpha0, beta0, gamma0; + gamma0 = 0.5f * rho_[n] * var_x0[n] * var_n0[n] + + lambda * var_x0[n] * var_n0[n] * var_n0[n]; + beta0 = lambda * var_x0[n] * (2 - rho_[n]) * var_x0[n] * var_n0[n]; + if (quadratic) { + alpha0 = lambda * var_x0[n] * (1 - rho_[n]) * var_x0[n] * var_x0[n]; + sols[n] = + (-beta0 - sqrtf(beta0 * beta0 - 4 * alpha0 * gamma0)) / (2 * alpha0); + } else { + sols[n] = -gamma0 / beta0; + } + sols[n] = fmax(0, sols[n]); + } +} + +void IntelligibilityEnhancer::FilterVariance(const float* var, float* result) { + RTC_DCHECK_GT(freqs_, 0u); + for (size_t i = 0; i < bank_size_; ++i) { + result[i] = DotProduct(&filter_bank_[i][0], var, freqs_); + } +} + +float IntelligibilityEnhancer::DotProduct(const float* a, + const float* b, + size_t length) { + float ret = 0.0f; + + for (size_t i = 0; i < length; ++i) { + ret = fmaf(a[i], b[i], ret); + } + return ret; +} + +bool IntelligibilityEnhancer::active() const { + return active_; +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h new file mode 100644 index 0000000000..1e9e35ac2a --- /dev/null +++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// +// Specifies core class for intelligbility enhancement. 
+// + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_INTELLIGIBILITY_INTELLIGIBILITY_ENHANCER_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_INTELLIGIBILITY_INTELLIGIBILITY_ENHANCER_H_ + +#include <complex> +#include <vector> + +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/common_audio/lapped_transform.h" +#include "webrtc/common_audio/channel_buffer.h" +#include "webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h" + +namespace webrtc { + +// Speech intelligibility enhancement module. Reads render and capture +// audio streams and modifies the render stream with a set of gains per +// frequency bin to enhance speech against the noise background. +// Note: assumes speech and noise streams are already separated. +class IntelligibilityEnhancer { + public: + struct Config { + // |var_*| are parameters for the VarianceArray constructor for the + // clear speech stream. + // TODO(bercic): the |var_*|, |*_rate| and |gain_limit| parameters should + // probably go away once fine tuning is done. + Config() + : sample_rate_hz(16000), + num_capture_channels(1), + num_render_channels(1), + var_type(intelligibility::VarianceArray::kStepDecaying), + var_decay_rate(0.9f), + var_window_size(10), + analysis_rate(800), + gain_change_limit(0.1f), + rho(0.02f) {} + int sample_rate_hz; + int num_capture_channels; + int num_render_channels; + intelligibility::VarianceArray::StepType var_type; + float var_decay_rate; + size_t var_window_size; + int analysis_rate; + float gain_change_limit; + float rho; + }; + + explicit IntelligibilityEnhancer(const Config& config); + IntelligibilityEnhancer(); // Initialize with default config. + + // Reads and processes chunk of noise stream in time domain. + void AnalyzeCaptureAudio(float* const* audio, + int sample_rate_hz, + int num_channels); + + // Reads chunk of speech in time domain and updates with modified signal. 
+ void ProcessRenderAudio(float* const* audio, + int sample_rate_hz, + int num_channels); + bool active() const; + + private: + enum AudioSource { + kRenderStream = 0, // Clear speech stream. + kCaptureStream, // Noise stream. + }; + + // Provides access point to the frequency domain. + class TransformCallback : public LappedTransform::Callback { + public: + TransformCallback(IntelligibilityEnhancer* parent, AudioSource source); + + // All in frequency domain, receives input |in_block|, applies + // intelligibility enhancement, and writes result to |out_block|. + void ProcessAudioBlock(const std::complex<float>* const* in_block, + int in_channels, + size_t frames, + int out_channels, + std::complex<float>* const* out_block) override; + + private: + IntelligibilityEnhancer* parent_; + AudioSource source_; + }; + friend class TransformCallback; + FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, TestErbCreation); + FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, TestSolveForGains); + + // Sends streams to ProcessClearBlock or ProcessNoiseBlock based on source. + void DispatchAudio(AudioSource source, + const std::complex<float>* in_block, + std::complex<float>* out_block); + + // Updates variance computation and analysis with |in_block_|, + // and writes modified speech to |out_block|. + void ProcessClearBlock(const std::complex<float>* in_block, + std::complex<float>* out_block); + + // Computes and sets modified gains. + void AnalyzeClearBlock(float power_target); + + // Bisection search for optimal |lambda|. + void SolveForLambda(float power_target, float power_bot, float power_top); + + // Transforms freq gains to ERB gains. + void UpdateErbGains(); + + // Updates variance calculation for noise input with |in_block|. + void ProcessNoiseBlock(const std::complex<float>* in_block, + std::complex<float>* out_block); + + // Returns number of ERB filters. + static size_t GetBankSize(int sample_rate, size_t erb_resolution); + + // Initializes ERB filterbank. 
+ void CreateErbBank(); + + // Analytically solves quadratic for optimal gains given |lambda|. + // Negative gains are set to 0. Stores the results in |sols|. + void SolveForGainsGivenLambda(float lambda, size_t start_freq, float* sols); + + // Computes variance across ERB filters from freq variance |var|. + // Stores in |result|. + void FilterVariance(const float* var, float* result); + + // Returns dot product of vectors specified by size |length| arrays |a|,|b|. + static float DotProduct(const float* a, const float* b, size_t length); + + const size_t freqs_; // Num frequencies in frequency domain. + const size_t window_size_; // Window size in samples; also the block size. + const size_t chunk_length_; // Chunk size in samples. + const size_t bank_size_; // Num ERB filters. + const int sample_rate_hz_; + const int erb_resolution_; + const int num_capture_channels_; + const int num_render_channels_; + const int analysis_rate_; // Num blocks before gains recalculated. + + const bool active_; // Whether render gains are being updated. + // TODO(ekm): Add logic for updating |active_|. + + intelligibility::VarianceArray clear_variance_; + intelligibility::VarianceArray noise_variance_; + rtc::scoped_ptr<float[]> filtered_clear_var_; + rtc::scoped_ptr<float[]> filtered_noise_var_; + std::vector<std::vector<float>> filter_bank_; + rtc::scoped_ptr<float[]> center_freqs_; + size_t start_freq_; + rtc::scoped_ptr<float[]> rho_; // Production and interpretation SNR. + // for each ERB band. + rtc::scoped_ptr<float[]> gains_eq_; // Pre-filter modified gains. + intelligibility::GainApplier gain_applier_; + + // Destination buffers used to reassemble blocked chunks before overwriting + // the original input array with modifications. 
+ ChannelBuffer<float> temp_render_out_buffer_; + ChannelBuffer<float> temp_capture_out_buffer_; + + rtc::scoped_ptr<float[]> kbd_window_; + TransformCallback render_callback_; + TransformCallback capture_callback_; + rtc::scoped_ptr<LappedTransform> render_mangler_; + rtc::scoped_ptr<LappedTransform> capture_mangler_; + int block_count_; + int analysis_step_; +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_INTELLIGIBILITY_INTELLIGIBILITY_ENHANCER_H_ diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc new file mode 100644 index 0000000000..ce146deaf5 --- /dev/null +++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2015 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// +// Unit tests for intelligibility enhancer. +// + +#include <math.h> +#include <stdlib.h> +#include <algorithm> +#include <vector> + +#include "testing/gtest/include/gtest/gtest.h" +#include "webrtc/base/arraysize.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/common_audio/signal_processing/include/signal_processing_library.h" +#include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h" + +namespace webrtc { + +namespace { + +// Target output for ERB create test. Generated with matlab. 
+const float kTestCenterFreqs[] = { + 13.169f, 26.965f, 41.423f, 56.577f, 72.461f, 89.113f, 106.57f, 124.88f, + 144.08f, 164.21f, 185.34f, 207.5f, 230.75f, 255.16f, 280.77f, 307.66f, + 335.9f, 365.56f, 396.71f, 429.44f, 463.84f, 500.f}; +const float kTestFilterBank[][2] = {{0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.2f}, + {0, 0.2f}, + {0, 0.2f}, + {0, 0.2f}, + {0, 0.2f}}; +static_assert(arraysize(kTestCenterFreqs) == arraysize(kTestFilterBank), + "Test filterbank badly initialized."); + +// Target output for gain solving test. Generated with matlab. +const size_t kTestStartFreq = 12; // Lowest integral frequency for ERBs. +const float kTestZeroVar[] = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; +static_assert(arraysize(kTestCenterFreqs) == arraysize(kTestZeroVar), + "Variance test data badly initialized."); +const float kTestNonZeroVarLambdaTop[] = { + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 0.f, 0.f, 0.0351f, 0.0636f, 0.0863f, + 0.1037f, 0.1162f, 0.1236f, 0.1251f, 0.1189f, 0.0993f}; +static_assert(arraysize(kTestCenterFreqs) == + arraysize(kTestNonZeroVarLambdaTop), + "Variance test data badly initialized."); +const float kMaxTestError = 0.005f; + +// Enhancer initialization parameters. 
+const int kSamples = 2000; +const int kSampleRate = 1000; +const int kNumChannels = 1; +const int kFragmentSize = kSampleRate / 100; + +} // namespace + +using std::vector; +using intelligibility::VarianceArray; + +class IntelligibilityEnhancerTest : public ::testing::Test { + protected: + IntelligibilityEnhancerTest() + : clear_data_(kSamples), noise_data_(kSamples), orig_data_(kSamples) { + config_.sample_rate_hz = kSampleRate; + enh_.reset(new IntelligibilityEnhancer(config_)); + } + + bool CheckUpdate(VarianceArray::StepType step_type) { + config_.sample_rate_hz = kSampleRate; + config_.var_type = step_type; + enh_.reset(new IntelligibilityEnhancer(config_)); + float* clear_cursor = &clear_data_[0]; + float* noise_cursor = &noise_data_[0]; + for (int i = 0; i < kSamples; i += kFragmentSize) { + enh_->AnalyzeCaptureAudio(&noise_cursor, kSampleRate, kNumChannels); + enh_->ProcessRenderAudio(&clear_cursor, kSampleRate, kNumChannels); + clear_cursor += kFragmentSize; + noise_cursor += kFragmentSize; + } + for (int i = 0; i < kSamples; i++) { + if (std::fabs(clear_data_[i] - orig_data_[i]) > kMaxTestError) { + return true; + } + } + return false; + } + + IntelligibilityEnhancer::Config config_; + rtc::scoped_ptr<IntelligibilityEnhancer> enh_; + vector<float> clear_data_; + vector<float> noise_data_; + vector<float> orig_data_; +}; + +// For each class of generated data, tests that render stream is +// updated when it should be for each variance update method. 
+TEST_F(IntelligibilityEnhancerTest, TestRenderUpdate) { + vector<VarianceArray::StepType> step_types; + step_types.push_back(VarianceArray::kStepInfinite); + step_types.push_back(VarianceArray::kStepDecaying); + step_types.push_back(VarianceArray::kStepWindowed); + step_types.push_back(VarianceArray::kStepBlocked); + step_types.push_back(VarianceArray::kStepBlockBasedMovingAverage); + std::fill(noise_data_.begin(), noise_data_.end(), 0.0f); + std::fill(orig_data_.begin(), orig_data_.end(), 0.0f); + for (auto step_type : step_types) { + std::fill(clear_data_.begin(), clear_data_.end(), 0.0f); + EXPECT_FALSE(CheckUpdate(step_type)); + } + std::srand(1); + auto float_rand = []() { return std::rand() * 2.f / RAND_MAX - 1; }; + std::generate(noise_data_.begin(), noise_data_.end(), float_rand); + for (auto step_type : step_types) { + EXPECT_FALSE(CheckUpdate(step_type)); + } + for (auto step_type : step_types) { + std::generate(clear_data_.begin(), clear_data_.end(), float_rand); + orig_data_ = clear_data_; + EXPECT_TRUE(CheckUpdate(step_type)); + } +} + +// Tests ERB bank creation, comparing against matlab output. +TEST_F(IntelligibilityEnhancerTest, TestErbCreation) { + ASSERT_EQ(arraysize(kTestCenterFreqs), enh_->bank_size_); + for (size_t i = 0; i < enh_->bank_size_; ++i) { + EXPECT_NEAR(kTestCenterFreqs[i], enh_->center_freqs_[i], kMaxTestError); + ASSERT_EQ(arraysize(kTestFilterBank[0]), enh_->freqs_); + for (size_t j = 0; j < enh_->freqs_; ++j) { + EXPECT_NEAR(kTestFilterBank[i][j], enh_->filter_bank_[i][j], + kMaxTestError); + } + } +} + +// Tests analytic solution for optimal gains, comparing +// against matlab output. 
+TEST_F(IntelligibilityEnhancerTest, TestSolveForGains) { + ASSERT_EQ(kTestStartFreq, enh_->start_freq_); + vector<float> sols(enh_->bank_size_); + float lambda = -0.001f; + for (size_t i = 0; i < enh_->bank_size_; i++) { + enh_->filtered_clear_var_[i] = 0.0f; + enh_->filtered_noise_var_[i] = 0.0f; + enh_->rho_[i] = 0.02f; + } + enh_->SolveForGainsGivenLambda(lambda, enh_->start_freq_, &sols[0]); + for (size_t i = 0; i < enh_->bank_size_; i++) { + EXPECT_NEAR(kTestZeroVar[i], sols[i], kMaxTestError); + } + for (size_t i = 0; i < enh_->bank_size_; i++) { + enh_->filtered_clear_var_[i] = static_cast<float>(i + 1); + enh_->filtered_noise_var_[i] = static_cast<float>(enh_->bank_size_ - i); + } + enh_->SolveForGainsGivenLambda(lambda, enh_->start_freq_, &sols[0]); + for (size_t i = 0; i < enh_->bank_size_; i++) { + EXPECT_NEAR(kTestNonZeroVarLambdaTop[i], sols[i], kMaxTestError); + } + lambda = -1.0; + enh_->SolveForGainsGivenLambda(lambda, enh_->start_freq_, &sols[0]); + for (size_t i = 0; i < enh_->bank_size_; i++) { + EXPECT_NEAR(kTestZeroVar[i], sols[i], kMaxTestError); + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.cc new file mode 100644 index 0000000000..7da9b957a4 --- /dev/null +++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.cc @@ -0,0 +1,314 @@ +/* + * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// +// Implements helper functions and classes for intelligibility enhancement. 
+// + +#include "webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h" + +#include <math.h> +#include <stdlib.h> +#include <string.h> +#include <algorithm> + +using std::complex; +using std::min; + +namespace webrtc { + +namespace intelligibility { + +float UpdateFactor(float target, float current, float limit) { + float delta = fabsf(target - current); + float sign = copysign(1.0f, target - current); + return current + sign * fminf(delta, limit); +} + +float AddDitherIfZero(float value) { + return value == 0.f ? std::rand() * 0.01f / RAND_MAX : value; +} + +complex<float> zerofudge(complex<float> c) { + return complex<float>(AddDitherIfZero(c.real()), AddDitherIfZero(c.imag())); +} + +complex<float> NewMean(complex<float> mean, complex<float> data, size_t count) { + return mean + (data - mean) / static_cast<float>(count); +} + +void AddToMean(complex<float> data, size_t count, complex<float>* mean) { + (*mean) = NewMean(*mean, data, count); +} + + +static const size_t kWindowBlockSize = 10; + +VarianceArray::VarianceArray(size_t num_freqs, + StepType type, + size_t window_size, + float decay) + : running_mean_(new complex<float>[num_freqs]()), + running_mean_sq_(new complex<float>[num_freqs]()), + sub_running_mean_(new complex<float>[num_freqs]()), + sub_running_mean_sq_(new complex<float>[num_freqs]()), + variance_(new float[num_freqs]()), + conj_sum_(new float[num_freqs]()), + num_freqs_(num_freqs), + window_size_(window_size), + decay_(decay), + history_cursor_(0), + count_(0), + array_mean_(0.0f), + buffer_full_(false) { + history_.reset(new rtc::scoped_ptr<complex<float>[]>[num_freqs_]()); + for (size_t i = 0; i < num_freqs_; ++i) { + history_[i].reset(new complex<float>[window_size_]()); + } + subhistory_.reset(new rtc::scoped_ptr<complex<float>[]>[num_freqs_]()); + for (size_t i = 0; i < num_freqs_; ++i) { + subhistory_[i].reset(new complex<float>[window_size_]()); + } + subhistory_sq_.reset(new 
rtc::scoped_ptr<complex<float>[]>[num_freqs_]()); + for (size_t i = 0; i < num_freqs_; ++i) { + subhistory_sq_[i].reset(new complex<float>[window_size_]()); + } + switch (type) { + case kStepInfinite: + step_func_ = &VarianceArray::InfiniteStep; + break; + case kStepDecaying: + step_func_ = &VarianceArray::DecayStep; + break; + case kStepWindowed: + step_func_ = &VarianceArray::WindowedStep; + break; + case kStepBlocked: + step_func_ = &VarianceArray::BlockedStep; + break; + case kStepBlockBasedMovingAverage: + step_func_ = &VarianceArray::BlockBasedMovingAverage; + break; + } +} + +// Compute the variance with Welford's algorithm, adding some fudge to +// the input in case of all-zeroes. +void VarianceArray::InfiniteStep(const complex<float>* data, bool skip_fudge) { + array_mean_ = 0.0f; + ++count_; + for (size_t i = 0; i < num_freqs_; ++i) { + complex<float> sample = data[i]; + if (!skip_fudge) { + sample = zerofudge(sample); + } + if (count_ == 1) { + running_mean_[i] = sample; + variance_[i] = 0.0f; + } else { + float old_sum = conj_sum_[i]; + complex<float> old_mean = running_mean_[i]; + running_mean_[i] = + old_mean + (sample - old_mean) / static_cast<float>(count_); + conj_sum_[i] = + (old_sum + std::conj(sample - old_mean) * (sample - running_mean_[i])) + .real(); + variance_[i] = + conj_sum_[i] / (count_ - 1); + } + array_mean_ += (variance_[i] - array_mean_) / (i + 1); + } +} + +// Compute the variance from the beginning, with exponential decaying of the +// series data. 
+void VarianceArray::DecayStep(const complex<float>* data, bool /*dummy*/) { + array_mean_ = 0.0f; + ++count_; + for (size_t i = 0; i < num_freqs_; ++i) { + complex<float> sample = data[i]; + sample = zerofudge(sample); + + if (count_ == 1) { + running_mean_[i] = sample; + running_mean_sq_[i] = sample * std::conj(sample); + variance_[i] = 0.0f; + } else { + complex<float> prev = running_mean_[i]; + complex<float> prev2 = running_mean_sq_[i]; + running_mean_[i] = decay_ * prev + (1.0f - decay_) * sample; + running_mean_sq_[i] = + decay_ * prev2 + (1.0f - decay_) * sample * std::conj(sample); + variance_[i] = (running_mean_sq_[i] - + running_mean_[i] * std::conj(running_mean_[i])).real(); + } + + array_mean_ += (variance_[i] - array_mean_) / (i + 1); + } +} + +// Windowed variance computation. On each step, the variances for the +// window are recomputed from scratch, using Welford's algorithm. +void VarianceArray::WindowedStep(const complex<float>* data, bool /*dummy*/) { + size_t num = min(count_ + 1, window_size_); + array_mean_ = 0.0f; + for (size_t i = 0; i < num_freqs_; ++i) { + complex<float> mean; + float conj_sum = 0.0f; + + history_[i][history_cursor_] = data[i]; + + mean = history_[i][history_cursor_]; + variance_[i] = 0.0f; + for (size_t j = 1; j < num; ++j) { + complex<float> sample = + zerofudge(history_[i][(history_cursor_ + j) % window_size_]); + sample = history_[i][(history_cursor_ + j) % window_size_]; + float old_sum = conj_sum; + complex<float> old_mean = mean; + + mean = old_mean + (sample - old_mean) / static_cast<float>(j + 1); + conj_sum = + (old_sum + std::conj(sample - old_mean) * (sample - mean)).real(); + variance_[i] = conj_sum / (j); + } + array_mean_ += (variance_[i] - array_mean_) / (i + 1); + } + history_cursor_ = (history_cursor_ + 1) % window_size_; + ++count_; +} + +// Variance with a window of blocks. Within each block, the variances are +// recomputed from scratch at every stp, using |Var(X) = E(X^2) - E^2(X)|. 
+// Once a block is filled with kWindowBlockSize samples, it is added to the +// history window and a new block is started. The variances for the window +// are recomputed from scratch at each of these transitions. +void VarianceArray::BlockedStep(const complex<float>* data, bool /*dummy*/) { + size_t blocks = min(window_size_, history_cursor_ + 1); + for (size_t i = 0; i < num_freqs_; ++i) { + AddToMean(data[i], count_ + 1, &sub_running_mean_[i]); + AddToMean(data[i] * std::conj(data[i]), count_ + 1, + &sub_running_mean_sq_[i]); + subhistory_[i][history_cursor_ % window_size_] = sub_running_mean_[i]; + subhistory_sq_[i][history_cursor_ % window_size_] = sub_running_mean_sq_[i]; + + variance_[i] = + (NewMean(running_mean_sq_[i], sub_running_mean_sq_[i], blocks) - + NewMean(running_mean_[i], sub_running_mean_[i], blocks) * + std::conj(NewMean(running_mean_[i], sub_running_mean_[i], blocks))) + .real(); + if (count_ == kWindowBlockSize - 1) { + sub_running_mean_[i] = complex<float>(0.0f, 0.0f); + sub_running_mean_sq_[i] = complex<float>(0.0f, 0.0f); + running_mean_[i] = complex<float>(0.0f, 0.0f); + running_mean_sq_[i] = complex<float>(0.0f, 0.0f); + for (size_t j = 0; j < min(window_size_, history_cursor_); ++j) { + AddToMean(subhistory_[i][j], j + 1, &running_mean_[i]); + AddToMean(subhistory_sq_[i][j], j + 1, &running_mean_sq_[i]); + } + ++history_cursor_; + } + } + ++count_; + if (count_ == kWindowBlockSize) { + count_ = 0; + } +} + +// Recomputes variances for each window from scratch based on previous window. +void VarianceArray::BlockBasedMovingAverage(const std::complex<float>* data, + bool /*dummy*/) { + // TODO(ekmeyerson) To mitigate potential divergence, add counter so that + // after every so often sums are computed scratch by summing over all + // elements instead of subtracting oldest and adding newest. 
+ for (size_t i = 0; i < num_freqs_; ++i) { + sub_running_mean_[i] += data[i]; + sub_running_mean_sq_[i] += data[i] * std::conj(data[i]); + } + ++count_; + + // TODO(ekmeyerson) Make kWindowBlockSize nonconstant to allow + // experimentation with different block size,window size pairs. + if (count_ >= kWindowBlockSize) { + count_ = 0; + + for (size_t i = 0; i < num_freqs_; ++i) { + running_mean_[i] -= subhistory_[i][history_cursor_]; + running_mean_sq_[i] -= subhistory_sq_[i][history_cursor_]; + + float scale = 1.f / kWindowBlockSize; + subhistory_[i][history_cursor_] = sub_running_mean_[i] * scale; + subhistory_sq_[i][history_cursor_] = sub_running_mean_sq_[i] * scale; + + sub_running_mean_[i] = std::complex<float>(0.0f, 0.0f); + sub_running_mean_sq_[i] = std::complex<float>(0.0f, 0.0f); + + running_mean_[i] += subhistory_[i][history_cursor_]; + running_mean_sq_[i] += subhistory_sq_[i][history_cursor_]; + + scale = 1.f / (buffer_full_ ? window_size_ : history_cursor_ + 1); + variance_[i] = std::real(running_mean_sq_[i] * scale - + running_mean_[i] * scale * + std::conj(running_mean_[i]) * scale); + } + + ++history_cursor_; + if (history_cursor_ >= window_size_) { + buffer_full_ = true; + history_cursor_ = 0; + } + } +} + +void VarianceArray::Clear() { + memset(running_mean_.get(), 0, sizeof(*running_mean_.get()) * num_freqs_); + memset(running_mean_sq_.get(), 0, + sizeof(*running_mean_sq_.get()) * num_freqs_); + memset(variance_.get(), 0, sizeof(*variance_.get()) * num_freqs_); + memset(conj_sum_.get(), 0, sizeof(*conj_sum_.get()) * num_freqs_); + history_cursor_ = 0; + count_ = 0; + array_mean_ = 0.0f; +} + +void VarianceArray::ApplyScale(float scale) { + array_mean_ = 0.0f; + for (size_t i = 0; i < num_freqs_; ++i) { + variance_[i] *= scale * scale; + array_mean_ += (variance_[i] - array_mean_) / (i + 1); + } +} + +GainApplier::GainApplier(size_t freqs, float change_limit) + : num_freqs_(freqs), + change_limit_(change_limit), + target_(new float[freqs]()), + 
current_(new float[freqs]()) { + for (size_t i = 0; i < freqs; ++i) { + target_[i] = 1.0f; + current_[i] = 1.0f; + } +} + +void GainApplier::Apply(const complex<float>* in_block, + complex<float>* out_block) { + for (size_t i = 0; i < num_freqs_; ++i) { + float factor = sqrtf(fabsf(current_[i])); + if (!std::isnormal(factor)) { + factor = 1.0f; + } + out_block[i] = factor * in_block[i]; + current_[i] = UpdateFactor(target_[i], current_[i], change_limit_); + } +} + +} // namespace intelligibility + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h new file mode 100644 index 0000000000..4ac1167147 --- /dev/null +++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// +// Specifies helper classes for intelligibility enhancement. +// + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_INTELLIGIBILITY_INTELLIGIBILITY_UTILS_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_INTELLIGIBILITY_INTELLIGIBILITY_UTILS_H_ + +#include <complex> + +#include "webrtc/base/scoped_ptr.h" + +namespace webrtc { + +namespace intelligibility { + +// Return |current| changed towards |target|, with the change being at most +// |limit|. +float UpdateFactor(float target, float current, float limit); + +// Apply a small fudge to degenerate complex values. The numbers in the array +// were chosen randomly, so that even a series of all zeroes has some small +// variability. 
std::complex<float> zerofudge(std::complex<float> c);

// Incremental mean computation. Return the mean of the series with the
// mean |mean| with added |data|. |count| is the length of the series
// including |data|.
std::complex<float> NewMean(std::complex<float> mean,
                            std::complex<float> data,
                            size_t count);

// Updates |mean| in place with added |data|; |count| is the length of the
// series including |data|.
void AddToMean(std::complex<float> data,
               size_t count,
               std::complex<float>* mean);

// Internal helper for computing the variances of a stream of arrays.
// The result is an array of variances per position: the i-th variance
// is the variance of the stream of data on the i-th positions in the
// input arrays.
// There are five methods of computation:
//  * kStepInfinite computes variances from the beginning onwards
//  * kStepDecaying uses a recursive exponential decay formula with a
//    settable forgetting factor
//  * kStepWindowed computes variances within a moving window
//  * kStepBlocked is similar to kStepWindowed, but history is kept
//    as a rolling window of blocks: multiple input elements are used for
//    one block and the history then consists of the variances of these blocks
//    with the same effect as kStepWindowed, but less storage, so the window
//    can be longer
//  * kStepBlockBasedMovingAverage also keeps a rolling window of blocks, but
//    maintains running sums over the window that are updated only when a
//    block completes, by subtracting the oldest block and adding the newest
class VarianceArray {
 public:
  enum StepType {
    kStepInfinite = 0,
    kStepDecaying,
    kStepWindowed,
    kStepBlocked,
    kStepBlockBasedMovingAverage
  };

  // Construct an instance for the given input array length (|freqs|) and
  // computation algorithm (|type|), with the appropriate parameters.
  // |window_size| is the number of samples for kStepWindowed and
  // the number of blocks for kStepBlocked and kStepBlockBasedMovingAverage.
  // |decay| is the forgetting factor for kStepDecaying.
  VarianceArray(size_t freqs, StepType type, size_t window_size, float decay);

  // Add a new data point to the series and compute the new variances.
  // |data| must point to an array of |freqs| elements.
  // TODO(bercic) |skip_fudge| is a flag for kStepWindowed and kStepDecaying,
  // whether they should skip adding some small dummy values to the input
  // to prevent problems with all-zero inputs. Can probably be removed.
  // NOTE(review): in the implementation only the kStepInfinite step function
  // reads |skip_fudge|; the other step functions ignore the argument.
  void Step(const std::complex<float>* data, bool skip_fudge = false) {
    (this->*step_func_)(data, skip_fudge);
  }
  // Reset variances to zero and forget all history.
  void Clear();
  // Scale the input data by |scale|. Effectively multiply variances
  // by |scale^2|.
  void ApplyScale(float scale);

  // The current set of variances (array of |freqs| elements).
  const float* variance() const { return variance_.get(); }

  // The mean value of the current set of variances.
  float array_mean() const { return array_mean_; }

 private:
  // One step function per StepType; the constructor stores the selected one
  // in |step_func_|.
  void InfiniteStep(const std::complex<float>* data, bool dummy);
  void DecayStep(const std::complex<float>* data, bool dummy);
  void WindowedStep(const std::complex<float>* data, bool dummy);
  void BlockedStep(const std::complex<float>* data, bool dummy);
  void BlockBasedMovingAverage(const std::complex<float>* data, bool dummy);

  // TODO(ekmeyerson): Switch the following running means
  // and histories from rtc::scoped_ptr to std::vector.

  // The current average X and X^2.
  rtc::scoped_ptr<std::complex<float>[]> running_mean_;
  rtc::scoped_ptr<std::complex<float>[]> running_mean_sq_;

  // Average X and X^2 for the current block in kStepBlocked.
  rtc::scoped_ptr<std::complex<float>[]> sub_running_mean_;
  rtc::scoped_ptr<std::complex<float>[]> sub_running_mean_sq_;

  // Sample history for the rolling window in kStepWindowed and block-wise
  // histories for kStepBlocked.
  rtc::scoped_ptr<rtc::scoped_ptr<std::complex<float>[]>[]> history_;
  rtc::scoped_ptr<rtc::scoped_ptr<std::complex<float>[]>[]> subhistory_;
  rtc::scoped_ptr<rtc::scoped_ptr<std::complex<float>[]>[]> subhistory_sq_;

  // The current set of variances and sums for Welford's algorithm.
  rtc::scoped_ptr<float[]> variance_;
  rtc::scoped_ptr<float[]> conj_sum_;

  // Length of each input array.
  const size_t num_freqs_;
  // Window length, in samples or blocks depending on the step type.
  const size_t window_size_;
  // Forgetting factor for kStepDecaying.
  const float decay_;
  // Write position (or completed-block counter) in the histories.
  size_t history_cursor_;
  // Samples seen so far, or samples in the current block, depending on the
  // step type.
  size_t count_;
  float array_mean_;
  // Set by kStepBlockBasedMovingAverage once its block history has wrapped.
  bool buffer_full_;
  // Step function selected in the constructor.
  void (VarianceArray::*step_func_)(const std::complex<float>*, bool);
};

// Helper class for smoothing gain changes. On each application step, the
// currently used gains are changed towards a set of settable target gains,
// constrained by a limit on the magnitude of the changes.
class GainApplier {
 public:
  // |freqs| is the number of gains; |change_limit| bounds the change of each
  // gain per Apply() call.
  GainApplier(size_t freqs, float change_limit);

  // Copy |in_block| to |out_block|, multiplied by the current set of gains,
  // and step the current set of gains towards the target set.
  void Apply(const std::complex<float>* in_block,
             std::complex<float>* out_block);

  // Return the current target gain set. Modify this array to set the targets.
  float* target() const { return target_.get(); }

 private:
  const size_t num_freqs_;
  const float change_limit_;
  rtc::scoped_ptr<float[]> target_;
  rtc::scoped_ptr<float[]> current_;
};

}  // namespace intelligibility

}  // namespace webrtc

#endif  // WEBRTC_MODULES_AUDIO_PROCESSING_INTELLIGIBILITY_INTELLIGIBILITY_UTILS_H_
// diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_utils_unittest.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils_unittest.cc
// new file mode 100644
// index 0000000000..9caa2eb0a1
// --- /dev/null
// +++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils_unittest.cc
// @@ -0,0 +1,180 @@
/*
 *  Copyright (c) 2015 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// +// Unit tests for intelligibility utils. +// + +#include <math.h> +#include <complex> +#include <iostream> +#include <vector> + +#include "testing/gtest/include/gtest/gtest.h" +#include "webrtc/base/arraysize.h" +#include "webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h" + +using std::complex; +using std::vector; + +namespace webrtc { + +namespace intelligibility { + +vector<vector<complex<float>>> GenerateTestData(int freqs, int samples) { + vector<vector<complex<float>>> data(samples); + for (int i = 0; i < samples; i++) { + for (int j = 0; j < freqs; j++) { + const float val = 0.99f / ((i + 1) * (j + 1)); + data[i].push_back(complex<float>(val, val)); + } + } + return data; +} + +// Tests UpdateFactor. +TEST(IntelligibilityUtilsTest, TestUpdateFactor) { + EXPECT_EQ(0, intelligibility::UpdateFactor(0, 0, 0)); + EXPECT_EQ(4, intelligibility::UpdateFactor(4, 2, 3)); + EXPECT_EQ(3, intelligibility::UpdateFactor(4, 2, 1)); + EXPECT_EQ(2, intelligibility::UpdateFactor(2, 4, 3)); + EXPECT_EQ(3, intelligibility::UpdateFactor(2, 4, 1)); +} + +// Tests zerofudge. +TEST(IntelligibilityUtilsTest, TestCplx) { + complex<float> t0(1.f, 0.f); + t0 = intelligibility::zerofudge(t0); + EXPECT_NE(t0.imag(), 0.f); + EXPECT_NE(t0.real(), 0.f); +} + +// Tests NewMean and AddToMean. +TEST(IntelligibilityUtilsTest, TestMeanUpdate) { + const complex<float> data[] = {{3, 8}, {7, 6}, {2, 1}, {8, 9}, {0, 6}}; + const complex<float> means[] = {{3, 8}, {5, 7}, {4, 5}, {5, 6}, {4, 6}}; + complex<float> mean(3, 8); + for (size_t i = 0; i < arraysize(data); i++) { + EXPECT_EQ(means[i], NewMean(mean, data[i], i + 1)); + AddToMean(data[i], i + 1, &mean); + EXPECT_EQ(means[i], mean); + } +} + +// Tests VarianceArray, for all variance step types. 
+TEST(IntelligibilityUtilsTest, TestVarianceArray) { + const int kFreqs = 10; + const int kSamples = 100; + const int kWindowSize = 10; // Should pass for all kWindowSize > 1. + const float kDecay = 0.5f; + vector<VarianceArray::StepType> step_types; + step_types.push_back(VarianceArray::kStepInfinite); + step_types.push_back(VarianceArray::kStepDecaying); + step_types.push_back(VarianceArray::kStepWindowed); + step_types.push_back(VarianceArray::kStepBlocked); + step_types.push_back(VarianceArray::kStepBlockBasedMovingAverage); + const vector<vector<complex<float>>> test_data( + GenerateTestData(kFreqs, kSamples)); + for (auto step_type : step_types) { + VarianceArray variance_array(kFreqs, step_type, kWindowSize, kDecay); + EXPECT_EQ(0, variance_array.variance()[0]); + EXPECT_EQ(0, variance_array.array_mean()); + variance_array.ApplyScale(2.0f); + EXPECT_EQ(0, variance_array.variance()[0]); + EXPECT_EQ(0, variance_array.array_mean()); + + // Makes sure Step is doing something. + variance_array.Step(&test_data[0][0]); + for (int i = 1; i < kSamples; i++) { + variance_array.Step(&test_data[i][0]); + EXPECT_GE(variance_array.array_mean(), 0.0f); + EXPECT_LE(variance_array.array_mean(), 1.0f); + for (int j = 0; j < kFreqs; j++) { + EXPECT_GE(variance_array.variance()[j], 0.0f); + EXPECT_LE(variance_array.variance()[j], 1.0f); + } + } + variance_array.Clear(); + EXPECT_EQ(0, variance_array.variance()[0]); + EXPECT_EQ(0, variance_array.array_mean()); + } +} + +// Tests exact computation on synthetic data. +TEST(IntelligibilityUtilsTest, TestMovingBlockAverage) { + // Exact, not unbiased estimates. 
+ const float kTestVarianceBufferNotFull = 16.5f; + const float kTestVarianceBufferFull1 = 66.5f; + const float kTestVarianceBufferFull2 = 333.375f; + const int kFreqs = 2; + const int kSamples = 50; + const int kWindowSize = 2; + const float kDecay = 0.5f; + const float kMaxError = 0.0001f; + + VarianceArray variance_array( + kFreqs, VarianceArray::kStepBlockBasedMovingAverage, kWindowSize, kDecay); + + vector<vector<complex<float>>> test_data(kSamples); + for (int i = 0; i < kSamples; i++) { + for (int j = 0; j < kFreqs; j++) { + if (i < 30) { + test_data[i].push_back(complex<float>(static_cast<float>(kSamples - i), + static_cast<float>(i + 1))); + } else { + test_data[i].push_back(complex<float>(0.f, 0.f)); + } + } + } + + for (int i = 0; i < kSamples; i++) { + variance_array.Step(&test_data[i][0]); + for (int j = 0; j < kFreqs; j++) { + if (i < 9) { // In utils, kWindowBlockSize = 10. + EXPECT_EQ(0, variance_array.variance()[j]); + } else if (i < 19) { + EXPECT_NEAR(kTestVarianceBufferNotFull, variance_array.variance()[j], + kMaxError); + } else if (i < 39) { + EXPECT_NEAR(kTestVarianceBufferFull1, variance_array.variance()[j], + kMaxError); + } else if (i < 49) { + EXPECT_NEAR(kTestVarianceBufferFull2, variance_array.variance()[j], + kMaxError); + } else { + EXPECT_EQ(0, variance_array.variance()[j]); + } + } + } +} + +// Tests gain applier. 
+TEST(IntelligibilityUtilsTest, TestGainApplier) { + const int kFreqs = 10; + const int kSamples = 100; + const float kChangeLimit = 0.1f; + GainApplier gain_applier(kFreqs, kChangeLimit); + const vector<vector<complex<float>>> in_data( + GenerateTestData(kFreqs, kSamples)); + vector<vector<complex<float>>> out_data(GenerateTestData(kFreqs, kSamples)); + for (int i = 0; i < kSamples; i++) { + gain_applier.Apply(&in_data[i][0], &out_data[i][0]); + for (int j = 0; j < kFreqs; j++) { + EXPECT_GT(out_data[i][j].real(), 0.0f); + EXPECT_LT(out_data[i][j].real(), 1.0f); + EXPECT_GT(out_data[i][j].imag(), 0.0f); + EXPECT_LT(out_data[i][j].imag(), 1.0f); + } + } +} + +} // namespace intelligibility + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc b/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc new file mode 100644 index 0000000000..27d0ab48bb --- /dev/null +++ b/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// +// Command line tool for speech intelligibility enhancement. Provides for +// running and testing intelligibility_enhancer as an independent process. +// Use --help for options. 
+// + +#include <stdint.h> +#include <stdlib.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <string> + +#include "gflags/gflags.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "webrtc/base/checks.h" +#include "webrtc/common_audio/real_fourier.h" +#include "webrtc/common_audio/wav_file.h" +#include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h" +#include "webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h" +#include "webrtc/system_wrappers/include/critical_section_wrapper.h" +#include "webrtc/test/testsupport/fileutils.h" + +using std::complex; +using webrtc::intelligibility::VarianceArray; + +namespace webrtc { +namespace { + +bool ValidateClearWindow(const char* flagname, int32_t value) { + return value > 0; +} + +DEFINE_int32(clear_type, + webrtc::intelligibility::VarianceArray::kStepDecaying, + "Variance algorithm for clear data."); +DEFINE_double(clear_alpha, 0.9, "Variance decay factor for clear data."); +DEFINE_int32(clear_window, + 475, + "Window size for windowed variance for clear data."); +const bool clear_window_dummy = + google::RegisterFlagValidator(&FLAGS_clear_window, &ValidateClearWindow); +DEFINE_int32(sample_rate, + 16000, + "Audio sample rate used in the input and output files."); +DEFINE_int32(ana_rate, + 800, + "Analysis rate; gains recalculated every N blocks."); +DEFINE_int32( + var_rate, + 2, + "Variance clear rate; history is forgotten every N gain recalculations."); +DEFINE_double(gain_limit, 1000.0, "Maximum gain change in one block."); + +DEFINE_string(clear_file, "speech.wav", "Input file with clear speech."); +DEFINE_string(noise_file, "noise.wav", "Input file with noise data."); +DEFINE_string(out_file, + "proc_enhanced.wav", + "Enhanced output. 
Use '-' to " + "play through aplay immediately."); + +const int kNumChannels = 1; + +// void function for gtest +void void_main(int argc, char* argv[]) { + google::SetUsageMessage( + "\n\nVariance algorithm types are:\n" + " 0 - infinite/normal,\n" + " 1 - exponentially decaying,\n" + " 2 - rolling window.\n" + "\nInput files must be little-endian 16-bit signed raw PCM.\n"); + google::ParseCommandLineFlags(&argc, &argv, true); + + size_t samples; // Number of samples in input PCM file + size_t fragment_size; // Number of samples to process at a time + // to simulate APM stream processing + + // Load settings and wav input. + + fragment_size = FLAGS_sample_rate / 100; // Mirror real time APM chunk size. + // Duplicates chunk_length_ in + // IntelligibilityEnhancer. + + struct stat in_stat, noise_stat; + ASSERT_EQ(stat(FLAGS_clear_file.c_str(), &in_stat), 0) + << "Empty speech file."; + ASSERT_EQ(stat(FLAGS_noise_file.c_str(), &noise_stat), 0) + << "Empty noise file."; + + samples = std::min(in_stat.st_size, noise_stat.st_size) / 2; + + WavReader in_file(FLAGS_clear_file); + std::vector<float> in_fpcm(samples); + in_file.ReadSamples(samples, &in_fpcm[0]); + + WavReader noise_file(FLAGS_noise_file); + std::vector<float> noise_fpcm(samples); + noise_file.ReadSamples(samples, &noise_fpcm[0]); + + // Run intelligibility enhancement. + IntelligibilityEnhancer::Config config; + config.sample_rate_hz = FLAGS_sample_rate; + config.var_type = static_cast<VarianceArray::StepType>(FLAGS_clear_type); + config.var_decay_rate = static_cast<float>(FLAGS_clear_alpha); + config.var_window_size = static_cast<size_t>(FLAGS_clear_window); + config.analysis_rate = FLAGS_ana_rate; + config.gain_change_limit = FLAGS_gain_limit; + IntelligibilityEnhancer enh(config); + + // Slice the input into smaller chunks, as the APM would do, and feed them + // through the enhancer. 
+ float* clear_cursor = &in_fpcm[0]; + float* noise_cursor = &noise_fpcm[0]; + + for (size_t i = 0; i < samples; i += fragment_size) { + enh.AnalyzeCaptureAudio(&noise_cursor, FLAGS_sample_rate, kNumChannels); + enh.ProcessRenderAudio(&clear_cursor, FLAGS_sample_rate, kNumChannels); + clear_cursor += fragment_size; + noise_cursor += fragment_size; + } + + if (FLAGS_out_file.compare("-") == 0) { + const std::string temp_out_filename = + test::TempFilename(test::WorkingDir(), "temp_wav_file"); + { + WavWriter out_file(temp_out_filename, FLAGS_sample_rate, kNumChannels); + out_file.WriteSamples(&in_fpcm[0], samples); + } + system(("aplay " + temp_out_filename).c_str()); + system(("rm " + temp_out_filename).c_str()); + } else { + WavWriter out_file(FLAGS_out_file, FLAGS_sample_rate, kNumChannels); + out_file.WriteSamples(&in_fpcm[0], samples); + } +} + +} // namespace +} // namespace webrtc + +int main(int argc, char* argv[]) { + webrtc::void_main(argc, argv); + return 0; +} |