time_stretch.cc (8972B)
1 /* 2 * Copyright (c) 2012 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_coding/neteq/time_stretch.h" 12 13 #include <algorithm> // min, max 14 #include <cstddef> 15 #include <cstdint> 16 #include <memory> 17 18 #include "common_audio/signal_processing/dot_product_with_scale.h" 19 #include "common_audio/signal_processing/include/signal_processing_library.h" 20 #include "common_audio/signal_processing/include/spl_inl.h" 21 #include "modules/audio_coding/neteq/background_noise.h" 22 #include "modules/audio_coding/neteq/cross_correlation.h" 23 #include "modules/audio_coding/neteq/dsp_helper.h" 24 #include "rtc_base/checks.h" 25 #include "rtc_base/numerics/safe_conversions.h" 26 27 namespace webrtc { 28 29 TimeStretch::ReturnCodes TimeStretch::Process(const int16_t* input, 30 size_t input_len, 31 bool fast_mode, 32 AudioMultiVector* output, 33 size_t* length_change_samples) { 34 // Pre-calculate common multiplication with `fs_mult_`. 35 size_t fs_mult_120 = 36 static_cast<size_t>(fs_mult_ * 120); // Corresponds to 15 ms. 37 38 const int16_t* signal; 39 std::unique_ptr<int16_t[]> signal_array; 40 size_t signal_len; 41 if (num_channels_ == 1) { 42 signal = input; 43 signal_len = input_len; 44 } else { 45 // We want `signal` to be only the first channel of `input`, which is 46 // interleaved. Thus, we take the first sample, skip forward `num_channels` 47 // samples, and continue like that. 48 signal_len = input_len / num_channels_; 49 signal_array.reset(new int16_t[signal_len]); 50 signal = signal_array.get(); 51 size_t j = kRefChannel; 52 for (size_t i = 0; i < signal_len; ++i) { 53 signal_array[i] = input[j]; 54 j += num_channels_; 55 } 56 } 57 58 // Find maximum absolute value of input signal. 59 max_input_value_ = WebRtcSpl_MaxAbsValueW16(signal, signal_len); 60 61 // Downsample to 4 kHz sample rate and calculate auto-correlation. 62 DspHelper::DownsampleTo4kHz(signal, signal_len, kDownsampledLen, 63 sample_rate_hz_, true /* compensate delay*/, 64 downsampled_input_); 65 AutoCorrelation(); 66 67 // Find the strongest correlation peak. 68 static const size_t kNumPeaks = 1; 69 size_t peak_index; 70 int16_t peak_value; 71 DspHelper::PeakDetection(auto_correlation_, kCorrelationLen, kNumPeaks, 72 fs_mult_, &peak_index, &peak_value); 73 // Assert that `peak_index` stays within boundaries. 74 RTC_DCHECK_LE(peak_index, (2 * kCorrelationLen - 1) * fs_mult_); 75 76 // Compensate peak_index for displaced starting position. The displacement 77 // happens in AutoCorrelation(). Here, `kMinLag` is in the down-sampled 4 kHz 78 // domain, while the `peak_index` is in the original sample rate; hence, the 79 // multiplication by fs_mult_ * 2. 80 peak_index += kMinLag * fs_mult_ * 2; 81 // Assert that `peak_index` stays within boundaries. 82 RTC_DCHECK_GE(peak_index, static_cast<size_t>(20 * fs_mult_)); 83 RTC_DCHECK_LE(peak_index, 84 20 * fs_mult_ + (2 * kCorrelationLen - 1) * fs_mult_); 85 86 // Calculate scaling to ensure that `peak_index` samples can be square-summed 87 // without overflowing. 88 int scaling = 31 - WebRtcSpl_NormW32(max_input_value_ * max_input_value_) - 89 WebRtcSpl_NormW32(static_cast<int32_t>(peak_index)); 90 scaling = std::max(0, scaling); 91 92 // `vec1` starts at 15 ms minus one pitch period. 93 const int16_t* vec1 = &signal[fs_mult_120 - peak_index]; 94 // `vec2` start at 15 ms. 95 const int16_t* vec2 = &signal[fs_mult_120]; 96 // Calculate energies for `vec1` and `vec2`, assuming they both contain 97 // `peak_index` samples. 98 int32_t vec1_energy = 99 WebRtcSpl_DotProductWithScale(vec1, vec1, peak_index, scaling); 100 int32_t vec2_energy = 101 WebRtcSpl_DotProductWithScale(vec2, vec2, peak_index, scaling); 102 103 // Calculate cross-correlation between `vec1` and `vec2`. 104 int32_t cross_corr = 105 WebRtcSpl_DotProductWithScale(vec1, vec2, peak_index, scaling); 106 107 // Check if the signal seems to be active speech or not (simple VAD). 108 bool active_speech = 109 SpeechDetection(vec1_energy, vec2_energy, peak_index, scaling); 110 111 int16_t best_correlation; 112 if (!active_speech) { 113 SetParametersForPassiveSpeech(signal_len, &best_correlation, &peak_index); 114 } else { 115 // Calculate correlation: 116 // cross_corr / sqrt(vec1_energy * vec2_energy). 117 118 // Start with calculating scale values. 119 int energy1_scale = std::max(0, 16 - WebRtcSpl_NormW32(vec1_energy)); 120 int energy2_scale = std::max(0, 16 - WebRtcSpl_NormW32(vec2_energy)); 121 122 // Make sure total scaling is even (to simplify scale factor after sqrt). 123 if ((energy1_scale + energy2_scale) & 1) { 124 // The sum is odd. 125 energy1_scale += 1; 126 } 127 128 // Scale energies to int16_t. 129 int16_t vec1_energy_int16 = 130 static_cast<int16_t>(vec1_energy >> energy1_scale); 131 int16_t vec2_energy_int16 = 132 static_cast<int16_t>(vec2_energy >> energy2_scale); 133 134 // Calculate square-root of energy product. 135 int16_t sqrt_energy_prod = 136 WebRtcSpl_SqrtFloor(vec1_energy_int16 * vec2_energy_int16); 137 138 // Calculate cross_corr / sqrt(en1*en2) in Q14. 139 int temp_scale = 14 - (energy1_scale + energy2_scale) / 2; 140 cross_corr = WEBRTC_SPL_SHIFT_W32(cross_corr, temp_scale); 141 cross_corr = std::max(0, cross_corr); // Don't use if negative. 142 best_correlation = WebRtcSpl_DivW32W16(cross_corr, sqrt_energy_prod); 143 // Make sure `best_correlation` is no larger than 1 in Q14. 144 best_correlation = std::min(static_cast<int16_t>(16384), best_correlation); 145 } 146 147 // Check accelerate criteria and stretch the signal. 148 ReturnCodes return_value = 149 CheckCriteriaAndStretch(input, input_len, peak_index, best_correlation, 150 active_speech, fast_mode, output); 151 switch (return_value) { 152 case kSuccess: 153 *length_change_samples = peak_index; 154 break; 155 case kSuccessLowEnergy: 156 *length_change_samples = peak_index; 157 break; 158 case kNoStretch: 159 case kError: 160 *length_change_samples = 0; 161 break; 162 } 163 return return_value; 164 } 165 166 void TimeStretch::AutoCorrelation() { 167 // Calculate correlation from lag kMinLag to lag kMaxLag in 4 kHz domain. 168 int32_t auto_corr[kCorrelationLen]; 169 CrossCorrelationWithAutoShift( 170 &downsampled_input_[kMaxLag], &downsampled_input_[kMaxLag - kMinLag], 171 kCorrelationLen, kMaxLag - kMinLag, -1, auto_corr); 172 173 // Normalize correlation to 14 bits and write to `auto_correlation_`. 174 int32_t max_corr = WebRtcSpl_MaxAbsValueW32(auto_corr, kCorrelationLen); 175 int scaling = std::max(0, 17 - WebRtcSpl_NormW32(max_corr)); 176 WebRtcSpl_VectorBitShiftW32ToW16(auto_correlation_, kCorrelationLen, 177 auto_corr, scaling); 178 } 179 180 bool TimeStretch::SpeechDetection(int32_t vec1_energy, 181 int32_t vec2_energy, 182 size_t peak_index, 183 int scaling) const { 184 // Check if the signal seems to be active speech or not (simple VAD). 185 // If (vec1_energy + vec2_energy) / (2 * peak_index) <= 186 // 8 * background_noise_energy, then we say that the signal contains no 187 // active speech. 188 // Rewrite the inequality as: 189 // (vec1_energy + vec2_energy) / 16 <= peak_index * background_noise_energy. 190 // The two sides of the inequality will be denoted `left_side` and 191 // `right_side`. 192 int32_t left_side = saturated_cast<int32_t>( 193 (static_cast<int64_t>(vec1_energy) + vec2_energy) / 16); 194 int32_t right_side; 195 if (background_noise_.initialized()) { 196 right_side = background_noise_.Energy(kRefChannel); 197 } else { 198 // If noise parameters have not been estimated, use a fixed threshold. 199 right_side = 75000; 200 } 201 int right_scale = 16 - WebRtcSpl_NormW32(right_side); 202 right_scale = std::max(0, right_scale); 203 left_side = left_side >> right_scale; 204 right_side = dchecked_cast<int32_t>(peak_index) * (right_side >> right_scale); 205 206 // Scale `left_side` properly before comparing with `right_side`. 207 // (`scaling` is the scale factor before energy calculation, thus the scale 208 // factor for the energy is 2 * scaling.) 209 if (WebRtcSpl_NormW32(left_side) < 2 * scaling) { 210 // Cannot scale only `left_side`, must scale `right_side` too. 211 int temp_scale = WebRtcSpl_NormW32(left_side); 212 left_side = left_side << temp_scale; 213 right_side = right_side >> (2 * scaling - temp_scale); 214 } else { 215 left_side = left_side << 2 * scaling; 216 } 217 return left_side > right_side; 218 } 219 220 } // namespace webrtc