tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

test_utils.cc (5425B)


      1 /*
      2 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
      3 *
      4 *  Use of this source code is governed by a BSD-style license
      5 *  that can be found in the LICENSE file in the root of the source
      6 *  tree. An additional intellectual property rights grant can be found
      7 *  in the file PATENTS.  All contributing project authors may
      8 *  be found in the AUTHORS file in the root of the source tree.
      9 */
     10 
     11 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
     12 
     13 #include <algorithm>
     14 #include <cstdint>
     15 #include <fstream>
     16 #include <ios>
     17 #include <memory>
     18 #include <string>
     19 #include <type_traits>
     20 #include <utility>
     21 #include <vector>
     22 
     23 #include "absl/strings/string_view.h"
     24 #include "api/array_view.h"
     25 #include "modules/audio_processing/agc2/rnn_vad/common.h"
     26 #include "rtc_base/checks.h"
     27 #include "rtc_base/numerics/safe_compare.h"
     28 #include "test/gtest.h"
     29 #include "test/testsupport/file_utils.h"
     30 
     31 namespace webrtc {
     32 namespace rnn_vad {
     33 namespace {
     34 
     35 // File reader for binary files that contain a sequence of values with
     36 // arithmetic type `T`. The values of type `T` that are read are cast to float.
     37 template <typename T>
     38 class FloatFileReader : public FileReader {
     39 public:
     40  static_assert(std::is_arithmetic<T>::value, "");
     41  explicit FloatFileReader(absl::string_view filename)
     42      : is_(std::string(filename), std::ios::binary | std::ios::ate),
     43        size_(is_.tellg() / sizeof(T)) {
     44    RTC_CHECK(is_);
     45    SeekBeginning();
     46  }
     47  FloatFileReader(const FloatFileReader&) = delete;
     48  FloatFileReader& operator=(const FloatFileReader&) = delete;
     49  ~FloatFileReader() override = default;
     50 
     51  int size() const override { return size_; }
     52  bool ReadChunk(ArrayView<float> dst) override {
     53    const std::streamsize bytes_to_read = dst.size() * sizeof(T);
     54    if (std::is_same<T, float>::value) {
     55      is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
     56    } else {
     57      buffer_.resize(dst.size());
     58      is_.read(reinterpret_cast<char*>(buffer_.data()), bytes_to_read);
     59      std::transform(buffer_.begin(), buffer_.end(), dst.begin(),
     60                     [](const T& v) -> float { return static_cast<float>(v); });
     61    }
     62    return is_.gcount() == bytes_to_read;
     63  }
     64  bool ReadValue(float& dst) override { return ReadChunk({&dst, 1}); }
     65  void SeekForward(int hop) override {
     66    is_.seekg(hop * sizeof(T), std::ifstream::cur);
     67  }
     68  void SeekBeginning() override { is_.seekg(0, std::ifstream::beg); }
     69 
     70 private:
     71  std::ifstream is_;
     72  const int size_;
     73  std::vector<T> buffer_;
     74 };
     75 
     76 }  // namespace
     77 
     78 using test::ResourcePath;
     79 
     80 void ExpectEqualFloatArray(ArrayView<const float> expected,
     81                           ArrayView<const float> computed) {
     82  ASSERT_EQ(expected.size(), computed.size());
     83  for (int i = 0; SafeLt(i, expected.size()); ++i) {
     84    SCOPED_TRACE(i);
     85    EXPECT_FLOAT_EQ(expected[i], computed[i]);
     86  }
     87 }
     88 
     89 void ExpectNearAbsolute(ArrayView<const float> expected,
     90                        ArrayView<const float> computed,
     91                        float tolerance) {
     92  ASSERT_EQ(expected.size(), computed.size());
     93  for (int i = 0; SafeLt(i, expected.size()); ++i) {
     94    SCOPED_TRACE(i);
     95    EXPECT_NEAR(expected[i], computed[i], tolerance);
     96  }
     97 }
     98 
     99 std::unique_ptr<FileReader> CreatePcmSamplesReader() {
    100  return std::make_unique<FloatFileReader<int16_t>>(
    101      /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/samples",
    102                                      "pcm"));
    103 }
    104 
    105 ChunksFileReader CreatePitchBuffer24kHzReader() {
    106  auto reader = std::make_unique<FloatFileReader<float>>(
    107      /*filename=*/test::ResourcePath(
    108          "audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"));
    109  const int num_chunks = CheckedDivExact(reader->size(), kBufSize24kHz);
    110  return {.chunk_size = kBufSize24kHz,
    111          .num_chunks = num_chunks,
    112          .reader = std::move(reader)};
    113 }
    114 
    115 ChunksFileReader CreateLpResidualAndPitchInfoReader() {
    116  constexpr int kPitchInfoSize = 2;  // Pitch period and strength.
    117  constexpr int kChunkSize = kBufSize24kHz + kPitchInfoSize;
    118  auto reader = std::make_unique<FloatFileReader<float>>(
    119      /*filename=*/test::ResourcePath(
    120          "audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"));
    121  const int num_chunks = CheckedDivExact(reader->size(), kChunkSize);
    122  return {.chunk_size = kChunkSize,
    123          .num_chunks = num_chunks,
    124          .reader = std::move(reader)};
    125 }
    126 
    127 std::unique_ptr<FileReader> CreateGruInputReader() {
    128  return std::make_unique<FloatFileReader<float>>(
    129      /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/gru_in",
    130                                      "dat"));
    131 }
    132 
    133 std::unique_ptr<FileReader> CreateVadProbsReader() {
    134  return std::make_unique<FloatFileReader<float>>(
    135      /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob",
    136                                      "dat"));
    137 }
    138 
    139 PitchTestData::PitchTestData() {
    140  FloatFileReader<float> reader(
    141      /*filename=*/ResourcePath(
    142          "audio_processing/agc2/rnn_vad/pitch_search_int", "dat"));
    143  reader.ReadChunk(pitch_buffer_24k_);
    144  reader.ReadChunk(square_energies_24k_);
    145  reader.ReadChunk(auto_correlation_12k_);
    146  // Reverse the order of the squared energy values.
    147  // Required after the WebRTC CL 191703 which switched to forward computation.
    148  std::reverse(square_energies_24k_.begin(), square_energies_24k_.end());
    149 }
    150 
// Destructor defaulted out of line in this translation unit (presumably to
// follow the project's style for non-trivial classes; TODO confirm).
PitchTestData::~PitchTestData() = default;
    152 
    153 }  // namespace rnn_vad
    154 }  // namespace webrtc