test_utils.cc (5425B)
1 /* 2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" 12 13 #include <algorithm> 14 #include <cstdint> 15 #include <fstream> 16 #include <ios> 17 #include <memory> 18 #include <string> 19 #include <type_traits> 20 #include <utility> 21 #include <vector> 22 23 #include "absl/strings/string_view.h" 24 #include "api/array_view.h" 25 #include "modules/audio_processing/agc2/rnn_vad/common.h" 26 #include "rtc_base/checks.h" 27 #include "rtc_base/numerics/safe_compare.h" 28 #include "test/gtest.h" 29 #include "test/testsupport/file_utils.h" 30 31 namespace webrtc { 32 namespace rnn_vad { 33 namespace { 34 35 // File reader for binary files that contain a sequence of values with 36 // arithmetic type `T`. The values of type `T` that are read are cast to float. 37 template <typename T> 38 class FloatFileReader : public FileReader { 39 public: 40 static_assert(std::is_arithmetic<T>::value, ""); 41 explicit FloatFileReader(absl::string_view filename) 42 : is_(std::string(filename), std::ios::binary | std::ios::ate), 43 size_(is_.tellg() / sizeof(T)) { 44 RTC_CHECK(is_); 45 SeekBeginning(); 46 } 47 FloatFileReader(const FloatFileReader&) = delete; 48 FloatFileReader& operator=(const FloatFileReader&) = delete; 49 ~FloatFileReader() override = default; 50 51 int size() const override { return size_; } 52 bool ReadChunk(ArrayView<float> dst) override { 53 const std::streamsize bytes_to_read = dst.size() * sizeof(T); 54 if (std::is_same<T, float>::value) { 55 is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read); 56 } else { 57 buffer_.resize(dst.size()); 58 is_.read(reinterpret_cast<char*>(buffer_.data()), bytes_to_read); 59 std::transform(buffer_.begin(), buffer_.end(), dst.begin(), 60 [](const T& v) -> float { return static_cast<float>(v); }); 61 } 62 return is_.gcount() == bytes_to_read; 63 } 64 bool ReadValue(float& dst) override { return ReadChunk({&dst, 1}); } 65 void SeekForward(int hop) override { 66 is_.seekg(hop * sizeof(T), std::ifstream::cur); 67 } 68 void SeekBeginning() override { is_.seekg(0, std::ifstream::beg); } 69 70 private: 71 std::ifstream is_; 72 const int size_; 73 std::vector<T> buffer_; 74 }; 75 76 } // namespace 77 78 using test::ResourcePath; 79 80 void ExpectEqualFloatArray(ArrayView<const float> expected, 81 ArrayView<const float> computed) { 82 ASSERT_EQ(expected.size(), computed.size()); 83 for (int i = 0; SafeLt(i, expected.size()); ++i) { 84 SCOPED_TRACE(i); 85 EXPECT_FLOAT_EQ(expected[i], computed[i]); 86 } 87 } 88 89 void ExpectNearAbsolute(ArrayView<const float> expected, 90 ArrayView<const float> computed, 91 float tolerance) { 92 ASSERT_EQ(expected.size(), computed.size()); 93 for (int i = 0; SafeLt(i, expected.size()); ++i) { 94 SCOPED_TRACE(i); 95 EXPECT_NEAR(expected[i], computed[i], tolerance); 96 } 97 } 98 99 std::unique_ptr<FileReader> CreatePcmSamplesReader() { 100 return std::make_unique<FloatFileReader<int16_t>>( 101 /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/samples", 102 "pcm")); 103 } 104 105 ChunksFileReader CreatePitchBuffer24kHzReader() { 106 auto reader = std::make_unique<FloatFileReader<float>>( 107 /*filename=*/test::ResourcePath( 108 "audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat")); 109 const int num_chunks = CheckedDivExact(reader->size(), kBufSize24kHz); 110 return {.chunk_size = kBufSize24kHz, 111 .num_chunks = num_chunks, 112 .reader = std::move(reader)}; 113 } 114 115 ChunksFileReader CreateLpResidualAndPitchInfoReader() { 116 constexpr int kPitchInfoSize = 2; // Pitch period and strength. 117 constexpr int kChunkSize = kBufSize24kHz + kPitchInfoSize; 118 auto reader = std::make_unique<FloatFileReader<float>>( 119 /*filename=*/test::ResourcePath( 120 "audio_processing/agc2/rnn_vad/pitch_lp_res", "dat")); 121 const int num_chunks = CheckedDivExact(reader->size(), kChunkSize); 122 return {.chunk_size = kChunkSize, 123 .num_chunks = num_chunks, 124 .reader = std::move(reader)}; 125 } 126 127 std::unique_ptr<FileReader> CreateGruInputReader() { 128 return std::make_unique<FloatFileReader<float>>( 129 /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/gru_in", 130 "dat")); 131 } 132 133 std::unique_ptr<FileReader> CreateVadProbsReader() { 134 return std::make_unique<FloatFileReader<float>>( 135 /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", 136 "dat")); 137 } 138 139 PitchTestData::PitchTestData() { 140 FloatFileReader<float> reader( 141 /*filename=*/ResourcePath( 142 "audio_processing/agc2/rnn_vad/pitch_search_int", "dat")); 143 reader.ReadChunk(pitch_buffer_24k_); 144 reader.ReadChunk(square_energies_24k_); 145 reader.ReadChunk(auto_correlation_12k_); 146 // Reverse the order of the squared energy values. 147 // Required after the WebRTC CL 191703 which switched to forward computation. 148 std::reverse(square_energies_24k_.begin(), square_energies_24k_.end()); 149 } 150 151 PitchTestData::~PitchTestData() = default; 152 153 } // namespace rnn_vad 154 } // namespace webrtc