tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

aec_state_unittest.cc (12109B)


      1 /*
      2 *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
      3 *
      4 *  Use of this source code is governed by a BSD-style license
      5 *  that can be found in the LICENSE file in the root of the source
      6 *  tree. An additional intellectual property rights grant can be found
      7 *  in the file PATENTS.  All contributing project authors may
      8 *  be found in the AUTHORS file in the root of the source tree.
      9 */
     10 
     11 #include "modules/audio_processing/aec3/aec_state.h"
     12 
     13 #include <algorithm>
     14 #include <array>
     15 #include <cstddef>
     16 #include <memory>
     17 #include <optional>
     18 #include <tuple>
     19 #include <vector>
     20 
     21 #include "api/audio/echo_canceller3_config.h"
     22 #include "api/environment/environment_factory.h"
     23 #include "modules/audio_processing/aec3/aec3_common.h"
     24 #include "modules/audio_processing/aec3/aec3_fft.h"
     25 #include "modules/audio_processing/aec3/block.h"
     26 #include "modules/audio_processing/aec3/delay_estimate.h"
     27 #include "modules/audio_processing/aec3/echo_path_variability.h"
     28 #include "modules/audio_processing/aec3/render_delay_buffer.h"
     29 #include "modules/audio_processing/aec3/subtractor_output.h"
     30 #include "modules/audio_processing/logging/apm_data_dumper.h"
     31 #include "test/gtest.h"
     32 
     33 namespace webrtc {
     34 namespace {
     35 
     36 void RunNormalUsageTest(size_t num_render_channels,
     37                        size_t num_capture_channels) {
     38  // TODO(bugs.webrtc.org/10913): Test with different content in different
     39  // channels.
     40  constexpr int kSampleRateHz = 48000;
     41  constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
     42  ApmDataDumper data_dumper(42);
     43  EchoCanceller3Config config;
     44  AecState state(CreateEnvironment(), config, num_capture_channels);
     45  std::optional<DelayEstimate> delay_estimate =
     46      DelayEstimate(DelayEstimate::Quality::kRefined, 10);
     47  std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
     48      RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels));
     49  std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined(
     50      num_capture_channels);
     51  std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels);
     52  Block x(kNumBands, num_render_channels);
     53  EchoPathVariability echo_path_variability(
     54      false, EchoPathVariability::DelayAdjustment::kNone, false);
     55  std::vector<std::array<float, kBlockSize>> y(num_capture_channels);
     56  std::vector<SubtractorOutput> subtractor_output(num_capture_channels);
     57  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
     58    subtractor_output[ch].Reset();
     59    subtractor_output[ch].s_refined.fill(100.f);
     60    subtractor_output[ch].e_refined.fill(100.f);
     61    y[ch].fill(1000.f);
     62    E2_refined[ch].fill(0.f);
     63    Y2[ch].fill(0.f);
     64  }
     65  Aec3Fft fft;
     66  std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>>
     67      converged_filter_frequency_response(
     68          num_capture_channels,
     69          std::vector<std::array<float, kFftLengthBy2Plus1>>(10));
     70  for (auto& v_ch : converged_filter_frequency_response) {
     71    for (auto& v : v_ch) {
     72      v.fill(0.01f);
     73    }
     74  }
     75  std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>>
     76      diverged_filter_frequency_response = converged_filter_frequency_response;
     77  converged_filter_frequency_response[0][2].fill(100.f);
     78  converged_filter_frequency_response[0][2][0] = 1.f;
     79  std::vector<std::vector<float>> impulse_response(
     80      num_capture_channels,
     81      std::vector<float>(
     82          GetTimeDomainLength(config.filter.refined.length_blocks), 0.f));
     83 
     84  // Verify that linear AEC usability is true when the filter is converged
     85  for (size_t band = 0; band < kNumBands; ++band) {
     86    for (size_t ch = 0; ch < num_render_channels; ++ch) {
     87      std::fill(x.begin(band, ch), x.end(band, ch), 101.f);
     88    }
     89  }
     90  for (int k = 0; k < 3000; ++k) {
     91    render_delay_buffer->Insert(x);
     92    for (size_t ch = 0; ch < num_capture_channels; ++ch) {
     93      subtractor_output[ch].ComputeMetrics(y[ch]);
     94    }
     95    state.Update(delay_estimate, converged_filter_frequency_response,
     96                 impulse_response, *render_delay_buffer->GetRenderBuffer(),
     97                 E2_refined, Y2, subtractor_output);
     98  }
     99  EXPECT_TRUE(state.UsableLinearEstimate());
    100 
    101  // Verify that linear AEC usability becomes false after an echo path
    102  // change is reported
    103  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
    104    subtractor_output[ch].ComputeMetrics(y[ch]);
    105  }
    106  state.HandleEchoPathChange(EchoPathVariability(
    107      false, EchoPathVariability::DelayAdjustment::kNewDetectedDelay, false));
    108  state.Update(delay_estimate, converged_filter_frequency_response,
    109               impulse_response, *render_delay_buffer->GetRenderBuffer(),
    110               E2_refined, Y2, subtractor_output);
    111  EXPECT_FALSE(state.UsableLinearEstimate());
    112 
    113  // Verify that the active render detection works as intended.
    114  for (size_t ch = 0; ch < num_render_channels; ++ch) {
    115    std::fill(x.begin(0, ch), x.end(0, ch), 101.f);
    116  }
    117  render_delay_buffer->Insert(x);
    118  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
    119    subtractor_output[ch].ComputeMetrics(y[ch]);
    120  }
    121  state.HandleEchoPathChange(EchoPathVariability(
    122      true, EchoPathVariability::DelayAdjustment::kNewDetectedDelay, false));
    123  state.Update(delay_estimate, converged_filter_frequency_response,
    124               impulse_response, *render_delay_buffer->GetRenderBuffer(),
    125               E2_refined, Y2, subtractor_output);
    126  EXPECT_FALSE(state.ActiveRender());
    127 
    128  for (int k = 0; k < 1000; ++k) {
    129    render_delay_buffer->Insert(x);
    130    for (size_t ch = 0; ch < num_capture_channels; ++ch) {
    131      subtractor_output[ch].ComputeMetrics(y[ch]);
    132    }
    133    state.Update(delay_estimate, converged_filter_frequency_response,
    134                 impulse_response, *render_delay_buffer->GetRenderBuffer(),
    135                 E2_refined, Y2, subtractor_output);
    136  }
    137  EXPECT_TRUE(state.ActiveRender());
    138 
    139  // Verify that the ERL is properly estimated
    140  for (int band = 0; band < x.NumBands(); ++band) {
    141    for (int channel = 0; channel < x.NumChannels(); ++channel) {
    142      std::fill(x.begin(band, channel), x.end(band, channel), 0.0f);
    143    }
    144  }
    145 
    146  for (size_t ch = 0; ch < num_render_channels; ++ch) {
    147    x.View(/*band=*/0, ch)[0] = 5000.f;
    148  }
    149  for (size_t k = 0;
    150       k < render_delay_buffer->GetRenderBuffer()->GetFftBuffer().size(); ++k) {
    151    render_delay_buffer->Insert(x);
    152    if (k == 0) {
    153      render_delay_buffer->Reset();
    154    }
    155    render_delay_buffer->PrepareCaptureProcessing();
    156  }
    157 
    158  for (auto& Y2_ch : Y2) {
    159    Y2_ch.fill(10.f * 10000.f * 10000.f);
    160  }
    161  for (size_t k = 0; k < 1000; ++k) {
    162    for (size_t ch = 0; ch < num_capture_channels; ++ch) {
    163      subtractor_output[ch].ComputeMetrics(y[ch]);
    164    }
    165    state.Update(delay_estimate, converged_filter_frequency_response,
    166                 impulse_response, *render_delay_buffer->GetRenderBuffer(),
    167                 E2_refined, Y2, subtractor_output);
    168  }
    169 
    170  ASSERT_TRUE(state.UsableLinearEstimate());
    171  const std::array<float, kFftLengthBy2Plus1>& erl = state.Erl();
    172  EXPECT_EQ(erl[0], erl[1]);
    173  for (size_t k = 1; k < erl.size() - 1; ++k) {
    174    EXPECT_NEAR(k % 2 == 0 ? 10.f : 1000.f, erl[k], 0.1);
    175  }
    176  EXPECT_EQ(erl[erl.size() - 2], erl[erl.size() - 1]);
    177 
    178  // Verify that the ERLE is properly estimated
    179  for (auto& E2_refined_ch : E2_refined) {
    180    E2_refined_ch.fill(1.f * 10000.f * 10000.f);
    181  }
    182  for (auto& Y2_ch : Y2) {
    183    Y2_ch.fill(10.f * E2_refined[0][0]);
    184  }
    185  for (size_t k = 0; k < 1000; ++k) {
    186    for (size_t ch = 0; ch < num_capture_channels; ++ch) {
    187      subtractor_output[ch].ComputeMetrics(y[ch]);
    188    }
    189    state.Update(delay_estimate, converged_filter_frequency_response,
    190                 impulse_response, *render_delay_buffer->GetRenderBuffer(),
    191                 E2_refined, Y2, subtractor_output);
    192  }
    193  ASSERT_TRUE(state.UsableLinearEstimate());
    194  {
    195    // Note that the render spectrum is built so it does not have energy in
    196    // the odd bands but just in the even bands.
    197    const auto& erle = state.Erle(/*onset_compensated=*/true)[0];
    198    EXPECT_EQ(erle[0], erle[1]);
    199    constexpr size_t kLowFrequencyLimit = 32;
    200    for (size_t k = 2; k < kLowFrequencyLimit; k = k + 2) {
    201      EXPECT_NEAR(4.f, erle[k], 0.1);
    202    }
    203    for (size_t k = kLowFrequencyLimit; k < erle.size() - 1; k = k + 2) {
    204      EXPECT_NEAR(1.5f, erle[k], 0.1);
    205    }
    206    EXPECT_EQ(erle[erle.size() - 2], erle[erle.size() - 1]);
    207  }
    208  for (auto& E2_refined_ch : E2_refined) {
    209    E2_refined_ch.fill(1.f * 10000.f * 10000.f);
    210  }
    211  for (auto& Y2_ch : Y2) {
    212    Y2_ch.fill(5.f * E2_refined[0][0]);
    213  }
    214  for (size_t k = 0; k < 1000; ++k) {
    215    for (size_t ch = 0; ch < num_capture_channels; ++ch) {
    216      subtractor_output[ch].ComputeMetrics(y[ch]);
    217    }
    218    state.Update(delay_estimate, converged_filter_frequency_response,
    219                 impulse_response, *render_delay_buffer->GetRenderBuffer(),
    220                 E2_refined, Y2, subtractor_output);
    221  }
    222 
    223  ASSERT_TRUE(state.UsableLinearEstimate());
    224  {
    225    const auto& erle = state.Erle(/*onset_compensated=*/true)[0];
    226    EXPECT_EQ(erle[0], erle[1]);
    227    constexpr size_t kLowFrequencyLimit = 32;
    228    for (size_t k = 1; k < kLowFrequencyLimit; ++k) {
    229      EXPECT_NEAR(k % 2 == 0 ? 4.f : 1.f, erle[k], 0.1);
    230    }
    231    for (size_t k = kLowFrequencyLimit; k < erle.size() - 1; ++k) {
    232      EXPECT_NEAR(k % 2 == 0 ? 1.5f : 1.f, erle[k], 0.1);
    233    }
    234    EXPECT_EQ(erle[erle.size() - 2], erle[erle.size() - 1]);
    235  }
    236 }
    237 
    238 }  // namespace
    239 
    240 class AecStateMultiChannel
    241    : public ::testing::Test,
    242      public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
    243 
    244 INSTANTIATE_TEST_SUITE_P(MultiChannel,
    245                         AecStateMultiChannel,
    246                         ::testing::Combine(::testing::Values(1, 2, 8),
    247                                            ::testing::Values(1, 2, 8)));
    248 
    249 // Verify the general functionality of AecState
    250 TEST_P(AecStateMultiChannel, NormalUsage) {
    251  const size_t num_render_channels = std::get<0>(GetParam());
    252  const size_t num_capture_channels = std::get<1>(GetParam());
    253  RunNormalUsageTest(num_render_channels, num_capture_channels);
    254 }
    255 
    256 // Verifies the delay for a converged filter is correctly identified.
    257 TEST(AecState, ConvergedFilterDelay) {
    258  constexpr int kFilterLengthBlocks = 10;
    259  constexpr size_t kNumCaptureChannels = 1;
    260  EchoCanceller3Config config;
    261  AecState state(CreateEnvironment(), config, kNumCaptureChannels);
    262  std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
    263      RenderDelayBuffer::Create(config, 48000, 1));
    264  std::optional<DelayEstimate> delay_estimate;
    265  std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined(
    266      kNumCaptureChannels);
    267  std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(kNumCaptureChannels);
    268  std::array<float, kBlockSize> x;
    269  EchoPathVariability echo_path_variability(
    270      false, EchoPathVariability::DelayAdjustment::kNone, false);
    271  std::vector<SubtractorOutput> subtractor_output(kNumCaptureChannels);
    272  for (auto& output : subtractor_output) {
    273    output.Reset();
    274    output.s_refined.fill(100.f);
    275  }
    276  std::array<float, kBlockSize> y;
    277  x.fill(0.f);
    278  y.fill(0.f);
    279 
    280  std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>>
    281      frequency_response(kNumCaptureChannels,
    282                         std::vector<std::array<float, kFftLengthBy2Plus1>>(
    283                             kFilterLengthBlocks));
    284  for (auto& v_ch : frequency_response) {
    285    for (auto& v : v_ch) {
    286      v.fill(0.01f);
    287    }
    288  }
    289 
    290  std::vector<std::vector<float>> impulse_response(
    291      kNumCaptureChannels,
    292      std::vector<float>(
    293          GetTimeDomainLength(config.filter.refined.length_blocks), 0.f));
    294 
    295  // Verify that the filter delay for a converged filter is properly
    296  // identified.
    297  for (int k = 0; k < kFilterLengthBlocks; ++k) {
    298    for (auto& ir : impulse_response) {
    299      std::fill(ir.begin(), ir.end(), 0.f);
    300      ir[k * kBlockSize + 1] = 1.f;
    301    }
    302 
    303    state.HandleEchoPathChange(echo_path_variability);
    304    subtractor_output[0].ComputeMetrics(y);
    305    state.Update(delay_estimate, frequency_response, impulse_response,
    306                 *render_delay_buffer->GetRenderBuffer(), E2_refined, Y2,
    307                 subtractor_output);
    308  }
    309 }
    310 
    311 }  // namespace webrtc