tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

subtractor_unittest.cc (14115B)


      1 /*
      2 *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
      3 *
      4 *  Use of this source code is governed by a BSD-style license
      5 *  that can be found in the LICENSE file in the root of the source
      6 *  tree. An additional intellectual property rights grant can be found
      7 *  in the file PATENTS.  All contributing project authors may
      8 *  be found in the AUTHORS file in the root of the source tree.
      9 */
     10 
     11 #include "modules/audio_processing/aec3/subtractor.h"
     12 
     13 #include <algorithm>
     14 #include <array>
     15 #include <cstddef>
     16 #include <memory>
     17 #include <numeric>
     18 #include <optional>
     19 #include <string>
     20 #include <tuple>
     21 #include <vector>
     22 
     23 #include "api/array_view.h"
     24 #include "api/audio/echo_canceller3_config.h"
     25 #include "api/environment/environment.h"
     26 #include "api/environment/environment_factory.h"
     27 #include "modules/audio_processing/aec3/aec3_common.h"
     28 #include "modules/audio_processing/aec3/aec3_fft.h"
     29 #include "modules/audio_processing/aec3/aec_state.h"
     30 #include "modules/audio_processing/aec3/block.h"
     31 #include "modules/audio_processing/aec3/delay_estimate.h"
     32 #include "modules/audio_processing/aec3/echo_path_variability.h"
     33 #include "modules/audio_processing/aec3/render_delay_buffer.h"
     34 #include "modules/audio_processing/aec3/render_signal_analyzer.h"
     35 #include "modules/audio_processing/aec3/subtractor_output.h"
     36 #include "modules/audio_processing/test/echo_canceller_test_tools.h"
     37 #include "modules/audio_processing/utility/cascaded_biquad_filter.h"
     38 #include "rtc_base/checks.h"
     39 #include "rtc_base/random.h"
     40 #include "rtc_base/strings/string_builder.h"
     41 #include "test/gtest.h"
     42 
     43 namespace webrtc {
     44 namespace {
     45 
     46 std::vector<float> RunSubtractorTest(
     47    const Environment& env,
     48    size_t num_render_channels,
     49    size_t num_capture_channels,
     50    int num_blocks_to_process,
     51    int delay_samples,
     52    int refined_filter_length_blocks,
     53    int coarse_filter_length_blocks,
     54    bool uncorrelated_inputs,
     55    const std::vector<int>& blocks_with_echo_path_changes) {
     56  ApmDataDumper data_dumper(42);
     57  constexpr int kSampleRateHz = 48000;
     58  constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
     59  EchoCanceller3Config config;
     60  config.filter.refined.length_blocks = refined_filter_length_blocks;
     61  config.filter.coarse.length_blocks = coarse_filter_length_blocks;
     62 
     63  Subtractor subtractor(env, config, num_render_channels, num_capture_channels,
     64                        &data_dumper, DetectOptimization());
     65  std::optional<DelayEstimate> delay_estimate;
     66  Block x(kNumBands, num_render_channels);
     67  Block y(/*num_bands=*/1, num_capture_channels);
     68  std::array<float, kBlockSize> x_old;
     69  std::vector<SubtractorOutput> output(num_capture_channels);
     70  config.delay.default_delay = 1;
     71  std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
     72      RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels));
     73  RenderSignalAnalyzer render_signal_analyzer(config);
     74  Random random_generator(42U);
     75  Aec3Fft fft;
     76  std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels);
     77  std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined(
     78      num_capture_channels);
     79  std::array<float, kFftLengthBy2Plus1> E2_coarse;
     80  AecState aec_state(env, config, num_capture_channels);
     81  x_old.fill(0.f);
     82  for (auto& Y2_ch : Y2) {
     83    Y2_ch.fill(0.f);
     84  }
     85  for (auto& E2_refined_ch : E2_refined) {
     86    E2_refined_ch.fill(0.f);
     87  }
     88  E2_coarse.fill(0.f);
     89 
     90  std::vector<std::vector<std::unique_ptr<DelayBuffer<float>>>> delay_buffer(
     91      num_capture_channels);
     92  for (size_t capture_ch = 0; capture_ch < num_capture_channels; ++capture_ch) {
     93    delay_buffer[capture_ch].resize(num_render_channels);
     94    for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) {
     95      delay_buffer[capture_ch][render_ch] =
     96          std::make_unique<DelayBuffer<float>>(delay_samples);
     97    }
     98  }
     99 
    100  // [B,A] = butter(2,100/8000,'high')
    101  constexpr std::array<CascadedBiQuadFilter::BiQuadCoefficients, 1>
    102      kHighPassFilterCoefficients = {{
    103          {.b = {0.97261f, -1.94523f, 0.97261f}, .a = {-1.94448f, 0.94598f}},
    104      }};
    105  std::vector<std::unique_ptr<CascadedBiQuadFilter>> x_hp_filter(
    106      num_render_channels);
    107  for (size_t ch = 0; ch < num_render_channels; ++ch) {
    108    x_hp_filter[ch] = std::make_unique<CascadedBiQuadFilter>(
    109        ArrayView<const CascadedBiQuadFilter::BiQuadCoefficients>(
    110            kHighPassFilterCoefficients));
    111  }
    112  std::vector<std::unique_ptr<CascadedBiQuadFilter>> y_hp_filter(
    113      num_capture_channels);
    114  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
    115    y_hp_filter[ch] = std::make_unique<CascadedBiQuadFilter>(
    116        ArrayView<const CascadedBiQuadFilter::BiQuadCoefficients>(
    117            kHighPassFilterCoefficients));
    118  }
    119 
    120  for (int block_num = 0; block_num < num_blocks_to_process; ++block_num) {
    121    for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) {
    122      RandomizeSampleVector(&random_generator, x.View(/*band=*/0, render_ch));
    123    }
    124    if (uncorrelated_inputs) {
    125      for (size_t capture_ch = 0; capture_ch < num_capture_channels;
    126           ++capture_ch) {
    127        RandomizeSampleVector(&random_generator,
    128                              y.View(/*band=*/0, capture_ch));
    129      }
    130    } else {
    131      for (size_t capture_ch = 0; capture_ch < num_capture_channels;
    132           ++capture_ch) {
    133        ArrayView<float> y_view = y.View(/*band=*/0, capture_ch);
    134        for (size_t render_ch = 0; render_ch < num_render_channels;
    135             ++render_ch) {
    136          std::array<float, kBlockSize> y_channel;
    137          delay_buffer[capture_ch][render_ch]->Delay(
    138              x.View(/*band=*/0, render_ch), y_channel);
    139          for (size_t k = 0; k < kBlockSize; ++k) {
    140            y_view[k] += y_channel[k] / num_render_channels;
    141          }
    142        }
    143      }
    144    }
    145    for (size_t ch = 0; ch < num_render_channels; ++ch) {
    146      x_hp_filter[ch]->Process(x.View(/*band=*/0, ch));
    147    }
    148    for (size_t ch = 0; ch < num_capture_channels; ++ch) {
    149      y_hp_filter[ch]->Process(y.View(/*band=*/0, ch));
    150    }
    151 
    152    render_delay_buffer->Insert(x);
    153    if (block_num == 0) {
    154      render_delay_buffer->Reset();
    155    }
    156    render_delay_buffer->PrepareCaptureProcessing();
    157    render_signal_analyzer.Update(*render_delay_buffer->GetRenderBuffer(),
    158                                  aec_state.MinDirectPathFilterDelay());
    159 
    160    // Handle echo path changes.
    161    if (std::find(blocks_with_echo_path_changes.begin(),
    162                  blocks_with_echo_path_changes.end(),
    163                  block_num) != blocks_with_echo_path_changes.end()) {
    164      subtractor.HandleEchoPathChange(EchoPathVariability(
    165          true, EchoPathVariability::DelayAdjustment::kNewDetectedDelay,
    166          false));
    167    }
    168    subtractor.Process(*render_delay_buffer->GetRenderBuffer(), y,
    169                       render_signal_analyzer, aec_state, output);
    170 
    171    aec_state.HandleEchoPathChange(EchoPathVariability(
    172        false, EchoPathVariability::DelayAdjustment::kNone, false));
    173    aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(),
    174                     subtractor.FilterImpulseResponses(),
    175                     *render_delay_buffer->GetRenderBuffer(), E2_refined, Y2,
    176                     output);
    177  }
    178 
    179  std::vector<float> results(num_capture_channels);
    180  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
    181    const float output_power = std::inner_product(
    182        output[ch].e_refined.begin(), output[ch].e_refined.end(),
    183        output[ch].e_refined.begin(), 0.f);
    184    const float y_power =
    185        std::inner_product(y.begin(/*band=*/0, ch), y.end(/*band=*/0, ch),
    186                           y.begin(/*band=*/0, ch), 0.f);
    187    if (y_power == 0.f) {
    188      ADD_FAILURE();
    189      results[ch] = -1.f;
    190    }
    191    results[ch] = output_power / y_power;
    192  }
    193  return results;
    194 }
    195 
    196 std::string ProduceDebugText(size_t num_render_channels,
    197                             size_t num_capture_channels,
    198                             size_t delay,
    199                             int filter_length_blocks) {
    200  StringBuilder ss;
    201  ss << "delay: " << delay << ", ";
    202  ss << "filter_length_blocks:" << filter_length_blocks << ", ";
    203  ss << "num_render_channels:" << num_render_channels << ", ";
    204  ss << "num_capture_channels:" << num_capture_channels;
    205  return ss.Release();
    206 }
    207 
    208 }  // namespace
    209 
    210 #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
    211 
    212 // Verifies that the check for non data dumper works.
    213 TEST(SubtractorDeathTest, NullDataDumper) {
    214  EXPECT_DEATH(Subtractor(CreateEnvironment(), EchoCanceller3Config(), 1, 1,
    215                          nullptr, DetectOptimization()),
    216               "");
    217 }
    218 
    219 #endif
    220 
    221 // Verifies that the subtractor is able to converge on correlated data.
    222 TEST(Subtractor, Convergence) {
    223  const Environment env = CreateEnvironment();
    224  std::vector<int> blocks_with_echo_path_changes;
    225  for (size_t filter_length_blocks : {12, 20, 30}) {
    226    for (size_t delay_samples : {0, 64, 150, 200, 301}) {
    227      SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks));
    228      std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
    229          env, 1, 1, 2500, delay_samples, filter_length_blocks,
    230          filter_length_blocks, false, blocks_with_echo_path_changes);
    231 
    232      for (float echo_to_nearend_power : echo_to_nearend_powers) {
    233        EXPECT_GT(0.1f, echo_to_nearend_power);
    234      }
    235    }
    236  }
    237 }
    238 
    239 // Verifies that the subtractor is able to handle the case when the refined
    240 // filter is longer than the coarse filter.
    241 TEST(Subtractor, RefinedFilterLongerThanCoarseFilter) {
    242  const Environment env = CreateEnvironment();
    243  std::vector<int> blocks_with_echo_path_changes;
    244  std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
    245      env, 1, 1, 400, 64, 20, 15, false, blocks_with_echo_path_changes);
    246  for (float echo_to_nearend_power : echo_to_nearend_powers) {
    247    EXPECT_GT(0.5f, echo_to_nearend_power);
    248  }
    249 }
    250 
    251 // Verifies that the subtractor is able to handle the case when the coarse
    252 // filter is longer than the refined filter.
    253 TEST(Subtractor, CoarseFilterLongerThanRefinedFilter) {
    254  const Environment env = CreateEnvironment();
    255  std::vector<int> blocks_with_echo_path_changes;
    256  std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
    257      env, 1, 1, 400, 64, 15, 20, false, blocks_with_echo_path_changes);
    258  for (float echo_to_nearend_power : echo_to_nearend_powers) {
    259    EXPECT_GT(0.5f, echo_to_nearend_power);
    260  }
    261 }
    262 
    263 // Verifies that the subtractor does not converge on uncorrelated signals.
    264 TEST(Subtractor, NonConvergenceOnUncorrelatedSignals) {
    265  const Environment env = CreateEnvironment();
    266  std::vector<int> blocks_with_echo_path_changes;
    267  for (size_t filter_length_blocks : {12, 20, 30}) {
    268    for (size_t delay_samples : {0, 64, 150, 200, 301}) {
    269      SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks));
    270 
    271      std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
    272          env, 1, 1, 3000, delay_samples, filter_length_blocks,
    273          filter_length_blocks, true, blocks_with_echo_path_changes);
    274      for (float echo_to_nearend_power : echo_to_nearend_powers) {
    275        EXPECT_NEAR(1.f, echo_to_nearend_power, 0.1);
    276      }
    277    }
    278  }
    279 }
    280 
    281 class SubtractorMultiChannelUpToEightRender
    282    : public ::testing::Test,
    283      public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
    284 
// Parameter tuples are (num_render_channels, num_capture_channels). Builds
// with NDEBUG defined run the larger set {1, 2, 8} x {1, 2, 4}; other builds
// use only {1, 2} x {1, 2}, presumably to keep slower debug runs short --
// TODO confirm.
#if defined(NDEBUG)
INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel,
                         SubtractorMultiChannelUpToEightRender,
                         ::testing::Combine(::testing::Values(1, 2, 8),
                                            ::testing::Values(1, 2, 4)));
#else
INSTANTIATE_TEST_SUITE_P(DebugMultiChannel,
                         SubtractorMultiChannelUpToEightRender,
                         ::testing::Combine(::testing::Values(1, 2),
                                            ::testing::Values(1, 2)));
#endif
    296 
    297 // Verifies that the subtractor is able to converge on correlated data.
    298 TEST_P(SubtractorMultiChannelUpToEightRender, Convergence) {
    299  const size_t num_render_channels = std::get<0>(GetParam());
    300  const size_t num_capture_channels = std::get<1>(GetParam());
    301  const Environment env = CreateEnvironment();
    302 
    303  std::vector<int> blocks_with_echo_path_changes;
    304  size_t num_blocks_to_process = 2500 * num_render_channels;
    305  std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
    306      env, num_render_channels, num_capture_channels, num_blocks_to_process, 64,
    307      20, 20, false, blocks_with_echo_path_changes);
    308 
    309  for (float echo_to_nearend_power : echo_to_nearend_powers) {
    310    EXPECT_GT(0.1f, echo_to_nearend_power);
    311  }
    312 }
    313 
    314 class SubtractorMultiChannelUpToFourRender
    315    : public ::testing::Test,
    316      public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
    317 
// Parameter tuples are (num_render_channels, num_capture_channels). Builds
// with NDEBUG defined run {1, 2, 4} x {1, 2, 4}; other builds use only
// {1, 2} x {1, 2}, presumably to keep slower debug runs short -- TODO confirm.
#if defined(NDEBUG)
INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel,
                         SubtractorMultiChannelUpToFourRender,
                         ::testing::Combine(::testing::Values(1, 2, 4),
                                            ::testing::Values(1, 2, 4)));
#else
INSTANTIATE_TEST_SUITE_P(DebugMultiChannel,
                         SubtractorMultiChannelUpToFourRender,
                         ::testing::Combine(::testing::Values(1, 2),
                                            ::testing::Values(1, 2)));
#endif
    329 
    330 // Verifies that the subtractor does not converge on uncorrelated signals.
    331 TEST_P(SubtractorMultiChannelUpToFourRender,
    332       NonConvergenceOnUncorrelatedSignals) {
    333  const size_t num_render_channels = std::get<0>(GetParam());
    334  const size_t num_capture_channels = std::get<1>(GetParam());
    335  const Environment env = CreateEnvironment();
    336 
    337  std::vector<int> blocks_with_echo_path_changes;
    338  size_t num_blocks_to_process = 5000 * num_render_channels;
    339  std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
    340      env, num_render_channels, num_capture_channels, num_blocks_to_process, 64,
    341      20, 20, true, blocks_with_echo_path_changes);
    342  for (float echo_to_nearend_power : echo_to_nearend_powers) {
    343    EXPECT_LT(.8f, echo_to_nearend_power);
    344    EXPECT_NEAR(1.f, echo_to_nearend_power, 0.25f);
    345  }
    346 }
    347 }  // namespace webrtc