av1_softmax_test.cc (4023B)
1 /* 2 * Copyright (c) 2021, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include <memory> 13 #include <new> 14 #include <tuple> 15 16 #include "aom/aom_integer.h" 17 #include "aom_ports/aom_timer.h" 18 #include "av1/encoder/ml.h" 19 #include "config/aom_config.h" 20 #include "config/aom_dsp_rtcd.h" 21 #include "config/av1_rtcd.h" 22 #include "gtest/gtest.h" 23 #include "test/acm_random.h" 24 #include "test/register_state_check.h" 25 #include "test/util.h" 26 27 namespace { 28 using FastSoftmaxFn = void (*)(const float *const input, float *output); 29 using FastSoftmaxTestParams = std::tuple<const FastSoftmaxFn, int>; 30 31 // Error thresholds for functional equivalence 32 constexpr float kRelEpsilon = 5e-2f; 33 constexpr float kAbsEpsilon = 5e-3f; 34 35 class FastSoftmaxTest : public ::testing::TestWithParam<FastSoftmaxTestParams> { 36 public: 37 FastSoftmaxTest() : target_fn_(GET_PARAM(0)), num_classes_(GET_PARAM(1)) {} 38 void SetUp() override { 39 ref_buf_.reset(new (std::nothrow) float[num_classes_]()); 40 ASSERT_NE(ref_buf_, nullptr); 41 dst_buf_.reset(new (std::nothrow) float[num_classes_]()); 42 ASSERT_NE(dst_buf_, nullptr); 43 input_.reset(new (std::nothrow) float[num_classes_]()); 44 ASSERT_NE(input_, nullptr); 45 } 46 void RunSoftmaxTest(); 47 void RunSoftmaxSpeedTest(const int run_times); 48 void FillInputBuf(); 49 50 private: 51 const FastSoftmaxFn target_fn_; 52 const int num_classes_; 53 std::unique_ptr<float[]> ref_buf_, dst_buf_, input_; 54 libaom_test::ACMRandom rng_; 55 }; 56 57 void FastSoftmaxTest::FillInputBuf() { 58 for (int idx = 0; idx < num_classes_; idx++) { 59 input_[idx] = ((float)rng_.Rand31() - (1 << 30)) / (1u << 30); 60 } 61 } 62 63 void FastSoftmaxTest::RunSoftmaxTest() { 64 av1_nn_softmax(input_.get(), ref_buf_.get(), num_classes_); 65 target_fn_(input_.get(), dst_buf_.get()); 66 67 for (int idx = 0; idx < num_classes_; idx++) { 68 if (ref_buf_[idx] < kAbsEpsilon) { 69 ASSERT_LE(dst_buf_[idx], kAbsEpsilon) 70 << "Reference output was near-zero, test output was not" << std::endl; 71 } else { 72 const float error = dst_buf_[idx] - ref_buf_[idx]; 73 const float relative_error = fabsf(error / ref_buf_[idx]); 74 ASSERT_LE(relative_error, kRelEpsilon) 75 << "Excessive relative error between reference and test output" 76 << std::endl; 77 ASSERT_LE(error, kAbsEpsilon) 78 << "Excessive absolute error between reference and test output" 79 << std::endl; 80 } 81 } 82 } 83 84 void FastSoftmaxTest::RunSoftmaxSpeedTest(const int run_times) { 85 aom_usec_timer timer; 86 aom_usec_timer_start(&timer); 87 for (int idx = 0; idx < run_times; idx++) { 88 target_fn_(input_.get(), dst_buf_.get()); 89 } 90 aom_usec_timer_mark(&timer); 91 const int64_t time = aom_usec_timer_elapsed(&timer); 92 std::cout << "Test with " << num_classes_ << " classes took " << time 93 << " us." << std::endl; 94 } 95 96 TEST_P(FastSoftmaxTest, RandomValues) { 97 FillInputBuf(); 98 RunSoftmaxTest(); 99 } 100 101 TEST_P(FastSoftmaxTest, DISABLED_Speed) { 102 constexpr int kNumTimes = 1000000; 103 RunSoftmaxSpeedTest(kNumTimes); 104 } 105 106 void AnchorSoftmax16Fn(const float *input, float *output) { 107 av1_nn_softmax(input, output, 16); 108 } 109 110 const FastSoftmaxTestParams kArrayParams_c[] = { 111 FastSoftmaxTestParams(AnchorSoftmax16Fn, 16), 112 FastSoftmaxTestParams(av1_nn_fast_softmax_16_c, 16) 113 }; 114 INSTANTIATE_TEST_SUITE_P(C, FastSoftmaxTest, 115 ::testing::ValuesIn(kArrayParams_c)); 116 117 #if HAVE_SSE3 && !CONFIG_EXCLUDE_SIMD_MISMATCH 118 INSTANTIATE_TEST_SUITE_P( 119 SSE3, FastSoftmaxTest, 120 ::testing::Values(FastSoftmaxTestParams(av1_nn_fast_softmax_16_sse3, 16))); 121 #endif 122 } // namespace