tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

av1_external_partition_test.cc (26231B)


      1 /*
      2 * Copyright (c) 2021, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
#include <cstdio>
#include <fstream>
#include <new>
#include <sstream>
#include <string>

#include "aom/aom_codec.h"
#include "aom/aom_external_partition.h"
#include "av1/common/blockd.h"
#include "av1/encoder/encodeframe_utils.h"
#include "gtest/gtest.h"
#include "test/codec_factory.h"
#include "test/encode_test_driver.h"
#include "test/y4m_video_source.h"
#include "test/util.h"
     26 
     27 #if CONFIG_AV1_ENCODER
     28 #if !CONFIG_REALTIME_ONLY
     29 namespace {
     30 
     31 constexpr int kFrameNum = 8;
     32 constexpr int kVersion = 1;
     33 
     34 struct TestData {
     35  int version = kVersion;
     36 };
     37 
     38 struct ToyModel {
     39  TestData *data;
     40  aom_ext_part_config_t config;
     41  aom_ext_part_funcs_t funcs;
     42  int mi_row;
     43  int mi_col;
     44  int frame_width;
     45  int frame_height;
     46  BLOCK_SIZE block_size;
     47 };
     48 
     49 // Note:
     50 // if CONFIG_PARTITION_SEARCH_ORDER = 0, we test APIs designed for the baseline
     51 // encoder's DFS partition search workflow.
     52 // if CONFIG_PARTITION_SEARCH_ORDER = 1, we test APIs designed for the new
     53 // ML model's partition search workflow.
     54 #if CONFIG_PARTITION_SEARCH_ORDER
     55 aom_ext_part_status_t ext_part_create_model(
     56    void *priv, const aom_ext_part_config_t *part_config,
     57    aom_ext_part_model_t *ext_part_model) {
     58  TestData *received_data = reinterpret_cast<TestData *>(priv);
     59  EXPECT_EQ(received_data->version, kVersion);
     60  ToyModel *toy_model = new (std::nothrow) ToyModel;
     61  if (toy_model == nullptr) {
     62    EXPECT_NE(toy_model, nullptr);
     63    return AOM_EXT_PART_ERROR;
     64  }
     65  toy_model->data = received_data;
     66  *ext_part_model = toy_model;
     67  EXPECT_EQ(part_config->superblock_size, BLOCK_64X64);
     68  return AOM_EXT_PART_OK;
     69 }
     70 
     71 aom_ext_part_status_t ext_part_send_features(
     72    aom_ext_part_model_t ext_part_model,
     73    const aom_partition_features_t *part_features) {
     74  ToyModel *toy_model = static_cast<ToyModel *>(ext_part_model);
     75  toy_model->mi_row = part_features->mi_row;
     76  toy_model->mi_col = part_features->mi_col;
     77  toy_model->frame_width = part_features->frame_width;
     78  toy_model->frame_height = part_features->frame_height;
     79  toy_model->block_size = static_cast<BLOCK_SIZE>(part_features->block_size);
     80  return AOM_EXT_PART_OK;
     81 }
     82 
     83 // The model provide the whole decision tree to the encoder.
     84 aom_ext_part_status_t ext_part_get_partition_decision_whole_tree(
     85    aom_ext_part_model_t ext_part_model,
     86    aom_partition_decision_t *ext_part_decision) {
     87  ToyModel *toy_model = static_cast<ToyModel *>(ext_part_model);
     88  // A toy model that always asks the encoder to encode with
     89  // 4x4 blocks (the smallest).
     90  ext_part_decision->is_final_decision = 1;
     91  // Note: super block size is fixed to BLOCK_64X64 for the
     92  // input video. It is determined inside the encoder, see the
     93  // check in "ext_part_create_model".
     94  const int is_last_sb_col =
     95      toy_model->mi_col * 4 + 64 > toy_model->frame_width;
     96  const int is_last_sb_row =
     97      toy_model->mi_row * 4 + 64 > toy_model->frame_height;
     98  if (is_last_sb_row && is_last_sb_col) {
     99    // 64x64: 1 node
    100    // 32x32: 4 nodes (only the first one will further split)
    101    // 16x16: 4 nodes
    102    // 8x8:   4 * 4 nodes
    103    // 4x4:   4 * 4 * 4 nodes
    104    const int num_blocks = 1 + 4 + 4 + 4 * 4 + 4 * 4 * 4;
    105    const int num_4x4_blocks = 4 * 4 * 4;
    106    ext_part_decision->num_nodes = num_blocks;
    107    // 64x64
    108    ext_part_decision->partition_decision[0] = PARTITION_SPLIT;
    109    // 32x32, only the first one will split, the other three are
    110    // out of frame boundary.
    111    ext_part_decision->partition_decision[1] = PARTITION_SPLIT;
    112    ext_part_decision->partition_decision[2] = PARTITION_NONE;
    113    ext_part_decision->partition_decision[3] = PARTITION_NONE;
    114    ext_part_decision->partition_decision[4] = PARTITION_NONE;
    115    // The rest blocks inside the top-left 32x32 block.
    116    for (int i = 5; i < num_blocks - num_4x4_blocks; ++i) {
    117      ext_part_decision->partition_decision[i] = PARTITION_SPLIT;
    118    }
    119    for (int i = num_blocks - num_4x4_blocks; i < num_blocks; ++i) {
    120      ext_part_decision->partition_decision[i] = PARTITION_NONE;
    121    }
    122  } else if (is_last_sb_row) {
    123    // 64x64: 1 node
    124    // 32x32: 4 nodes (only the first two will further split)
    125    // 16x16: 2 * 4 nodes
    126    // 8x8:   2 * 4 * 4 nodes
    127    // 4x4:   2 * 4 * 4 * 4 nodes
    128    const int num_blocks = 1 + 4 + 2 * 4 + 2 * 4 * 4 + 2 * 4 * 4 * 4;
    129    const int num_4x4_blocks = 2 * 4 * 4 * 4;
    130    ext_part_decision->num_nodes = num_blocks;
    131    // 64x64
    132    ext_part_decision->partition_decision[0] = PARTITION_SPLIT;
    133    // 32x32, only the first two will split, the other two are out
    134    // of frame boundary.
    135    ext_part_decision->partition_decision[1] = PARTITION_SPLIT;
    136    ext_part_decision->partition_decision[2] = PARTITION_SPLIT;
    137    ext_part_decision->partition_decision[3] = PARTITION_NONE;
    138    ext_part_decision->partition_decision[4] = PARTITION_NONE;
    139    // The rest blocks.
    140    for (int i = 5; i < num_blocks - num_4x4_blocks; ++i) {
    141      ext_part_decision->partition_decision[i] = PARTITION_SPLIT;
    142    }
    143    for (int i = num_blocks - num_4x4_blocks; i < num_blocks; ++i) {
    144      ext_part_decision->partition_decision[i] = PARTITION_NONE;
    145    }
    146  } else if (is_last_sb_col) {
    147    // 64x64: 1 node
    148    // 32x32: 4 nodes (only the top-left and bottom-left will further split)
    149    // 16x16: 2 * 4 nodes
    150    // 8x8:   2 * 4 * 4 nodes
    151    // 4x4:   2 * 4 * 4 * 4 nodes
    152    const int num_blocks = 1 + 4 + 2 * 4 + 2 * 4 * 4 + 2 * 4 * 4 * 4;
    153    const int num_4x4_blocks = 2 * 4 * 4 * 4;
    154    ext_part_decision->num_nodes = num_blocks;
    155    // 64x64
    156    ext_part_decision->partition_decision[0] = PARTITION_SPLIT;
    157    // 32x32, only the top-left and bottom-left will split, the other two are
    158    // out of frame boundary.
    159    ext_part_decision->partition_decision[1] = PARTITION_SPLIT;
    160    ext_part_decision->partition_decision[2] = PARTITION_NONE;
    161    ext_part_decision->partition_decision[3] = PARTITION_SPLIT;
    162    ext_part_decision->partition_decision[4] = PARTITION_NONE;
    163    // The rest blocks.
    164    for (int i = 5; i < num_blocks - num_4x4_blocks; ++i) {
    165      ext_part_decision->partition_decision[i] = PARTITION_SPLIT;
    166    }
    167    for (int i = num_blocks - num_4x4_blocks; i < num_blocks; ++i) {
    168      ext_part_decision->partition_decision[i] = PARTITION_NONE;
    169    }
    170  } else {
    171    // 64x64: 1 node
    172    // 32x32: 4 nodes
    173    // 16x16: 4 * 4 nodes
    174    // 8x8:   4 * 4 * 4 nodes
    175    // 4x4:   4 * 4 * 4 * 4 nodes
    176    const int num_blocks = 1 + 4 + 4 * 4 + 4 * 4 * 4 + 4 * 4 * 4 * 4;
    177    const int num_4x4_blocks = 4 * 4 * 4 * 4;
    178    ext_part_decision->num_nodes = num_blocks;
    179    for (int i = 0; i < num_blocks - num_4x4_blocks; ++i) {
    180      ext_part_decision->partition_decision[i] = PARTITION_SPLIT;
    181    }
    182    for (int i = num_blocks - num_4x4_blocks; i < num_blocks; ++i) {
    183      ext_part_decision->partition_decision[i] = PARTITION_NONE;
    184    }
    185  }
    186 
    187  return AOM_EXT_PART_OK;
    188 }
    189 
    190 aom_ext_part_status_t ext_part_get_partition_decision_recursive(
    191    aom_ext_part_model_t ext_part_model,
    192    aom_partition_decision_t *ext_part_decision) {
    193  ext_part_decision->current_decision = PARTITION_NONE;
    194  ext_part_decision->is_final_decision = 1;
    195  ToyModel *toy_model = static_cast<ToyModel *>(ext_part_model);
    196  // Note: super block size is fixed to BLOCK_64X64 for the
    197  // input video. It is determined inside the encoder, see the
    198  // check in "ext_part_create_model".
    199  const int is_last_sb_col =
    200      toy_model->mi_col * 4 + 64 > toy_model->frame_width;
    201  const int is_last_sb_row =
    202      toy_model->mi_row * 4 + 64 > toy_model->frame_height;
    203  if (is_last_sb_row && is_last_sb_col) {
    204    if (block_size_wide[toy_model->block_size] == 64) {
    205      ext_part_decision->current_decision = PARTITION_SPLIT;
    206    } else {
    207      ext_part_decision->current_decision = PARTITION_NONE;
    208    }
    209  } else if (is_last_sb_row) {
    210    if (block_size_wide[toy_model->block_size] == 64) {
    211      ext_part_decision->current_decision = PARTITION_SPLIT;
    212    } else {
    213      ext_part_decision->current_decision = PARTITION_NONE;
    214    }
    215  } else if (is_last_sb_col) {
    216    if (block_size_wide[toy_model->block_size] == 64) {
    217      ext_part_decision->current_decision = PARTITION_SPLIT;
    218    } else {
    219      ext_part_decision->current_decision = PARTITION_NONE;
    220    }
    221  } else {
    222    ext_part_decision->current_decision = PARTITION_NONE;
    223  }
    224  return AOM_EXT_PART_OK;
    225 }
    226 
    227 aom_ext_part_status_t ext_part_send_partition_stats(
    228    aom_ext_part_model_t ext_part_model,
    229    const aom_partition_stats_t *ext_part_stats) {
    230  (void)ext_part_model;
    231  (void)ext_part_stats;
    232  return AOM_EXT_PART_OK;
    233 }
    234 
    235 aom_ext_part_status_t ext_part_delete_model(
    236    aom_ext_part_model_t ext_part_model) {
    237  ToyModel *toy_model = static_cast<ToyModel *>(ext_part_model);
    238  EXPECT_EQ(toy_model->data->version, kVersion);
    239  delete toy_model;
    240  return AOM_EXT_PART_OK;
    241 }
    242 
    243 class ExternalPartitionTestAPI
    244    : public ::libaom_test::CodecTestWith2Params<libaom_test::TestMode, int>,
    245      public ::libaom_test::EncoderTest {
    246 protected:
    247  ExternalPartitionTestAPI()
    248      : EncoderTest(GET_PARAM(0)), encoding_mode_(GET_PARAM(1)),
    249        cpu_used_(GET_PARAM(2)), psnr_(0.0), nframes_(0) {}
    250  ~ExternalPartitionTestAPI() override {}
    251 
    252  void SetUp() override {
    253    InitializeConfig(encoding_mode_);
    254    const aom_rational timebase = { 1, 30 };
    255    cfg_.g_timebase = timebase;
    256    cfg_.rc_end_usage = AOM_VBR;
    257    cfg_.g_threads = 1;
    258    cfg_.g_lag_in_frames = 4;
    259    cfg_.rc_target_bitrate = 400;
    260    init_flags_ = AOM_CODEC_USE_PSNR;
    261  }
    262 
    263  bool DoDecode() const override { return false; }
    264 
    265  void BeginPassHook(unsigned int) override {
    266    psnr_ = 0.0;
    267    nframes_ = 0;
    268  }
    269 
    270  void PSNRPktHook(const aom_codec_cx_pkt_t *pkt) override {
    271    psnr_ += pkt->data.psnr.psnr[0];
    272    nframes_++;
    273  }
    274 
    275  double GetAveragePsnr() const {
    276    if (nframes_) return psnr_ / nframes_;
    277    return 0.0;
    278  }
    279 
    280  void SetExternalPartition(bool use_external_partition) {
    281    use_external_partition_ = use_external_partition;
    282  }
    283 
    284  void SetPartitionControlMode(int mode) { partition_control_mode_ = mode; }
    285 
    286  void SetDecisionMode(aom_ext_part_decision_mode_t mode) {
    287    decision_mode_ = mode;
    288  }
    289 
    290  void PreEncodeFrameHook(::libaom_test::VideoSource *video,
    291                          ::libaom_test::Encoder *encoder) override {
    292    if (video->frame() == 0) {
    293      if (decision_mode_ == AOM_EXT_PART_WHOLE_TREE) {
    294        aom_ext_part_funcs_t ext_part_funcs;
    295        ext_part_funcs.priv = reinterpret_cast<void *>(&test_data_);
    296        ext_part_funcs.decision_mode = AOM_EXT_PART_WHOLE_TREE;
    297        ext_part_funcs.create_model = ext_part_create_model;
    298        ext_part_funcs.send_features = ext_part_send_features;
    299        ext_part_funcs.get_partition_decision =
    300            ext_part_get_partition_decision_whole_tree;
    301        ext_part_funcs.send_partition_stats = ext_part_send_partition_stats;
    302        ext_part_funcs.delete_model = ext_part_delete_model;
    303 
    304        encoder->Control(AOME_SET_CPUUSED, cpu_used_);
    305        encoder->Control(AOME_SET_ENABLEAUTOALTREF, 1);
    306        if (use_external_partition_) {
    307          encoder->Control(AV1E_SET_EXTERNAL_PARTITION, &ext_part_funcs);
    308        }
    309        if (partition_control_mode_ == -1) {
    310          encoder->Control(AV1E_SET_MAX_PARTITION_SIZE, 128);
    311          encoder->Control(AV1E_SET_MIN_PARTITION_SIZE, 4);
    312        } else {
    313          switch (partition_control_mode_) {
    314            case 1:
    315              encoder->Control(AV1E_SET_MAX_PARTITION_SIZE, 64);
    316              encoder->Control(AV1E_SET_MIN_PARTITION_SIZE, 64);
    317              break;
    318            case 2:
    319              encoder->Control(AV1E_SET_MAX_PARTITION_SIZE, 4);
    320              encoder->Control(AV1E_SET_MIN_PARTITION_SIZE, 4);
    321              break;
    322            default: assert(0 && "Invalid partition control mode."); break;
    323          }
    324        }
    325      } else if (decision_mode_ == AOM_EXT_PART_RECURSIVE) {
    326        aom_ext_part_funcs_t ext_part_funcs;
    327        ext_part_funcs.priv = reinterpret_cast<void *>(&test_data_);
    328        ext_part_funcs.decision_mode = AOM_EXT_PART_RECURSIVE;
    329        ext_part_funcs.create_model = ext_part_create_model;
    330        ext_part_funcs.send_features = ext_part_send_features;
    331        ext_part_funcs.get_partition_decision =
    332            ext_part_get_partition_decision_recursive;
    333        ext_part_funcs.send_partition_stats = ext_part_send_partition_stats;
    334        ext_part_funcs.delete_model = ext_part_delete_model;
    335 
    336        encoder->Control(AOME_SET_CPUUSED, cpu_used_);
    337        encoder->Control(AOME_SET_ENABLEAUTOALTREF, 1);
    338        if (use_external_partition_) {
    339          encoder->Control(AV1E_SET_EXTERNAL_PARTITION, &ext_part_funcs);
    340        }
    341        if (partition_control_mode_ == -1) {
    342          encoder->Control(AV1E_SET_MAX_PARTITION_SIZE, 128);
    343          encoder->Control(AV1E_SET_MIN_PARTITION_SIZE, 4);
    344        } else {
    345          switch (partition_control_mode_) {
    346            case 1:
    347              encoder->Control(AV1E_SET_MAX_PARTITION_SIZE, 64);
    348              encoder->Control(AV1E_SET_MIN_PARTITION_SIZE, 64);
    349              break;
    350            case 2:
    351              encoder->Control(AV1E_SET_MAX_PARTITION_SIZE, 4);
    352              encoder->Control(AV1E_SET_MIN_PARTITION_SIZE, 4);
    353              break;
    354            default: assert(0 && "Invalid partition control mode."); break;
    355          }
    356        }
    357      } else {
    358        assert(0 && "Invalid decision mode.");
    359      }
    360    }
    361  }
    362 
    363 private:
    364  libaom_test::TestMode encoding_mode_;
    365  int cpu_used_;
    366  double psnr_;
    367  unsigned int nframes_;
    368  bool use_external_partition_ = false;
    369  TestData test_data_;
    370  int partition_control_mode_ = -1;
    371  aom_ext_part_decision_mode_t decision_mode_;
    372 };
    373 
    374 // Encode twice and expect the same psnr value.
    375 // The first run is a normal encoding run with restricted partition types,
    376 // i.e., we use control flags to force the encoder to encode with the
    377 // 4x4 block size.
    378 // The second run is to get partition decisions from a toy model that we
    379 // built, which will asks the encoder to encode with the 4x4 blocks.
    380 // We expect the encoding results are the same.
    381 TEST_P(ExternalPartitionTestAPI, WholePartitionTree4x4Block) {
    382  ::libaom_test::Y4mVideoSource video("paris_352_288_30.y4m", 0, kFrameNum);
    383  SetExternalPartition(false);
    384  SetPartitionControlMode(2);
    385  SetDecisionMode(AOM_EXT_PART_WHOLE_TREE);
    386  ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
    387  const double psnr = GetAveragePsnr();
    388 
    389  SetExternalPartition(true);
    390  SetPartitionControlMode(2);
    391  SetDecisionMode(AOM_EXT_PART_WHOLE_TREE);
    392  ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
    393  const double psnr2 = GetAveragePsnr();
    394 
    395  EXPECT_DOUBLE_EQ(psnr, psnr2);
    396 }
    397 
    398 TEST_P(ExternalPartitionTestAPI, RecursivePartition) {
    399  ::libaom_test::Y4mVideoSource video("paris_352_288_30.y4m", 0, kFrameNum);
    400  SetExternalPartition(false);
    401  SetPartitionControlMode(1);
    402  SetDecisionMode(AOM_EXT_PART_RECURSIVE);
    403  ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
    404  const double psnr = GetAveragePsnr();
    405 
    406  SetExternalPartition(true);
    407  SetPartitionControlMode(1);
    408  SetDecisionMode(AOM_EXT_PART_RECURSIVE);
    409  ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
    410  const double psnr2 = GetAveragePsnr();
    411 
    412  const double psnr_thresh = 0.02;
    413  EXPECT_NEAR(psnr, psnr2, psnr_thresh);
    414 }
    415 
    416 AV1_INSTANTIATE_TEST_SUITE(ExternalPartitionTestAPI,
    417                           ::testing::Values(::libaom_test::kTwoPassGood),
    418                           ::testing::Values(4));  // cpu_used
    419 
    420 #else   // !CONFIG_PARTITION_SEARCH_ORDER
    421 // Feature files written during encoding, as defined in partition_strategy.c.
    422 std::string feature_file_names[] = {
    423  "feature_before_partition_none",
    424  "feature_before_partition_none_prune_rect",
    425  "feature_after_partition_none_prune",
    426  "feature_after_partition_none_terminate",
    427  "feature_after_partition_split_terminate",
    428  "feature_after_partition_split_prune_rect",
    429  "feature_after_partition_rect",
    430  "feature_after_partition_ab",
    431 };
    432 
    433 // Files written here in the test, where the feature data is received
    434 // from the API.
    435 std::string test_feature_file_names[] = {
    436  "test_feature_before_partition_none",
    437  "test_feature_before_partition_none_prune_rect",
    438  "test_feature_after_partition_none_prune",
    439  "test_feature_after_partition_none_terminate",
    440  "test_feature_after_partition_split_terminate",
    441  "test_feature_after_partition_split_prune_rect",
    442  "test_feature_after_partition_rect",
    443  "test_feature_after_partition_ab",
    444 };
    445 
    446 static void write_features_to_file(const float *features,
    447                                   const int feature_size, const int id) {
    448  if (!WRITE_FEATURE_TO_FILE) return;
    449  char filename[256];
    450  snprintf(filename, sizeof(filename), "%s",
    451           test_feature_file_names[id].c_str());
    452  FILE *pfile = fopen(filename, "a");
    453  ASSERT_NE(pfile, nullptr);
    454  for (int i = 0; i < feature_size; ++i) {
    455    fprintf(pfile, "%.6f", features[i]);
    456    if (i < feature_size - 1) fprintf(pfile, ",");
    457  }
    458  fprintf(pfile, "\n");
    459  fclose(pfile);
    460 }
    461 
    462 aom_ext_part_status_t ext_part_create_model(
    463    void *priv, const aom_ext_part_config_t *part_config,
    464    aom_ext_part_model_t *ext_part_model) {
    465  TestData *received_data = reinterpret_cast<TestData *>(priv);
    466  EXPECT_EQ(received_data->version, kVersion);
    467  ToyModel *toy_model = new (std::nothrow) ToyModel;
    468  if (toy_model == nullptr) {
    469    EXPECT_NE(toy_model, nullptr);
    470    return AOM_EXT_PART_ERROR;
    471  }
    472  toy_model->data = received_data;
    473  *ext_part_model = toy_model;
    474  EXPECT_EQ(part_config->superblock_size, BLOCK_64X64);
    475  return AOM_EXT_PART_OK;
    476 }
    477 
    478 aom_ext_part_status_t ext_part_create_model_test(
    479    void *priv, const aom_ext_part_config_t *part_config,
    480    aom_ext_part_model_t *ext_part_model) {
    481  (void)priv;
    482  (void)ext_part_model;
    483  EXPECT_EQ(part_config->superblock_size, BLOCK_64X64);
    484  // Return status indicates it's a encoder test. It lets the encoder
    485  // set a flag and write partition features to text files.
    486  return AOM_EXT_PART_TEST;
    487 }
    488 
    489 aom_ext_part_status_t ext_part_send_features(
    490    aom_ext_part_model_t ext_part_model,
    491    const aom_partition_features_t *part_features) {
    492  (void)ext_part_model;
    493  (void)part_features;
    494  return AOM_EXT_PART_OK;
    495 }
    496 
    497 aom_ext_part_status_t ext_part_send_features_test(
    498    aom_ext_part_model_t ext_part_model,
    499    const aom_partition_features_t *part_features) {
    500  (void)ext_part_model;
    501  if (part_features->id == AOM_EXT_PART_FEATURE_BEFORE_NONE) {
    502    write_features_to_file(part_features->before_part_none.f,
    503                           AOM_EXT_PART_SIZE_DIRECT_SPLIT, 0);
    504  } else if (part_features->id == AOM_EXT_PART_FEATURE_BEFORE_NONE_PART2) {
    505    write_features_to_file(part_features->before_part_none.f_part2,
    506                           AOM_EXT_PART_SIZE_PRUNE_PART, 1);
    507  } else if (part_features->id == AOM_EXT_PART_FEATURE_AFTER_NONE) {
    508    write_features_to_file(part_features->after_part_none.f,
    509                           AOM_EXT_PART_SIZE_PRUNE_NONE, 2);
    510  } else if (part_features->id == AOM_EXT_PART_FEATURE_AFTER_NONE_PART2) {
    511    write_features_to_file(part_features->after_part_none.f_terminate,
    512                           AOM_EXT_PART_SIZE_TERM_NONE, 3);
    513  } else if (part_features->id == AOM_EXT_PART_FEATURE_AFTER_SPLIT) {
    514    write_features_to_file(part_features->after_part_split.f_terminate,
    515                           AOM_EXT_PART_SIZE_TERM_SPLIT, 4);
    516  } else if (part_features->id == AOM_EXT_PART_FEATURE_AFTER_SPLIT_PART2) {
    517    write_features_to_file(part_features->after_part_split.f_prune_rect,
    518                           AOM_EXT_PART_SIZE_PRUNE_RECT, 5);
    519  } else if (part_features->id == AOM_EXT_PART_FEATURE_AFTER_RECT) {
    520    write_features_to_file(part_features->after_part_rect.f,
    521                           AOM_EXT_PART_SIZE_PRUNE_AB, 6);
    522  } else if (part_features->id == AOM_EXT_PART_FEATURE_AFTER_AB) {
    523    write_features_to_file(part_features->after_part_ab.f,
    524                           AOM_EXT_PART_SIZE_PRUNE_4_WAY, 7);
    525  }
    526  return AOM_EXT_PART_TEST;
    527 }
    528 
    529 aom_ext_part_status_t ext_part_get_partition_decision(
    530    aom_ext_part_model_t ext_part_model,
    531    aom_partition_decision_t *ext_part_decision) {
    532  (void)ext_part_model;
    533  (void)ext_part_decision;
    534  // Return an invalid decision such that the encoder doesn't take any
    535  // partition decision from the ml model.
    536  return AOM_EXT_PART_ERROR;
    537 }
    538 
    539 aom_ext_part_status_t ext_part_send_partition_stats(
    540    aom_ext_part_model_t ext_part_model,
    541    const aom_partition_stats_t *ext_part_stats) {
    542  (void)ext_part_model;
    543  (void)ext_part_stats;
    544  return AOM_EXT_PART_OK;
    545 }
    546 
    547 aom_ext_part_status_t ext_part_delete_model(
    548    aom_ext_part_model_t ext_part_model) {
    549  ToyModel *toy_model = static_cast<ToyModel *>(ext_part_model);
    550  EXPECT_EQ(toy_model->data->version, kVersion);
    551  delete toy_model;
    552  return AOM_EXT_PART_OK;
    553 }
    554 
    555 class ExternalPartitionTestDfsAPI
    556    : public ::libaom_test::CodecTestWith2Params<libaom_test::TestMode, int>,
    557      public ::libaom_test::EncoderTest {
    558 protected:
    559  ExternalPartitionTestDfsAPI()
    560      : EncoderTest(GET_PARAM(0)), encoding_mode_(GET_PARAM(1)),
    561        cpu_used_(GET_PARAM(2)), psnr_(0.0), nframes_(0) {}
    562  ~ExternalPartitionTestDfsAPI() override = default;
    563 
    564  void SetUp() override {
    565    InitializeConfig(encoding_mode_);
    566    const aom_rational timebase = { 1, 30 };
    567    cfg_.g_timebase = timebase;
    568    cfg_.rc_end_usage = AOM_VBR;
    569    cfg_.g_threads = 1;
    570    cfg_.g_lag_in_frames = 4;
    571    cfg_.rc_target_bitrate = 400;
    572    init_flags_ = AOM_CODEC_USE_PSNR;
    573  }
    574 
    575  bool DoDecode() const override { return false; }
    576 
    577  void BeginPassHook(unsigned int) override {
    578    psnr_ = 0.0;
    579    nframes_ = 0;
    580  }
    581 
    582  void PSNRPktHook(const aom_codec_cx_pkt_t *pkt) override {
    583    psnr_ += pkt->data.psnr.psnr[0];
    584    nframes_++;
    585  }
    586 
    587  double GetAveragePsnr() const {
    588    if (nframes_) return psnr_ / nframes_;
    589    return 0.0;
    590  }
    591 
    592  void SetExternalPartition(bool use_external_partition) {
    593    use_external_partition_ = use_external_partition;
    594  }
    595 
    596  void SetTestSendFeatures(int test_send_features) {
    597    test_send_features_ = test_send_features;
    598  }
    599 
    600  void PreEncodeFrameHook(::libaom_test::VideoSource *video,
    601                          ::libaom_test::Encoder *encoder) override {
    602    if (video->frame() == 0) {
    603      aom_ext_part_funcs_t ext_part_funcs;
    604      ext_part_funcs.priv = reinterpret_cast<void *>(&test_data_);
    605      if (use_external_partition_) {
    606        ext_part_funcs.create_model = ext_part_create_model;
    607        ext_part_funcs.send_features = ext_part_send_features;
    608      }
    609      if (test_send_features_ == 1) {
    610        ext_part_funcs.create_model = ext_part_create_model;
    611        ext_part_funcs.send_features = ext_part_send_features_test;
    612      } else if (test_send_features_ == 0) {
    613        ext_part_funcs.create_model = ext_part_create_model_test;
    614        ext_part_funcs.send_features = ext_part_send_features;
    615      }
    616      ext_part_funcs.get_partition_decision = ext_part_get_partition_decision;
    617      ext_part_funcs.send_partition_stats = ext_part_send_partition_stats;
    618      ext_part_funcs.delete_model = ext_part_delete_model;
    619 
    620      encoder->Control(AOME_SET_CPUUSED, cpu_used_);
    621      encoder->Control(AOME_SET_ENABLEAUTOALTREF, 1);
    622      if (use_external_partition_) {
    623        encoder->Control(AV1E_SET_EXTERNAL_PARTITION, &ext_part_funcs);
    624      }
    625    }
    626  }
    627 
    628 private:
    629  libaom_test::TestMode encoding_mode_;
    630  int cpu_used_;
    631  double psnr_;
    632  unsigned int nframes_;
    633  bool use_external_partition_ = false;
    634  int test_send_features_ = -1;
    635  TestData test_data_;
    636 };
    637 
    638 // Encode twice and expect the same psnr value.
    639 // The first run is the baseline without external partition.
    640 // The second run is to get partition decisions from the toy model we defined.
    641 // Here, we let the partition decision return invalid for all stages.
    642 // In this case, the external partition doesn't alter the original encoder
    643 // behavior. So we expect the same encoding results.
    644 TEST_P(ExternalPartitionTestDfsAPI, EncodeMatch) {
    645  ::libaom_test::Y4mVideoSource video("paris_352_288_30.y4m", 0, kFrameNum);
    646  SetExternalPartition(false);
    647  ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
    648  const double psnr = GetAveragePsnr();
    649 
    650  SetExternalPartition(true);
    651  ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
    652  const double psnr2 = GetAveragePsnr();
    653 
    654  EXPECT_DOUBLE_EQ(psnr, psnr2);
    655 }
    656 
    657 // Encode twice to compare generated feature files.
    658 // The first run let the encoder write partition features to file.
    659 // The second run calls send partition features function to send features to
    660 // the external model, and we write them to file.
    661 // The generated files should match each other.
    662 TEST_P(ExternalPartitionTestDfsAPI, SendFeatures) {
    663  ::libaom_test::Y4mVideoSource video("paris_352_288_30.y4m", 0, kFrameNum);
    664  SetExternalPartition(true);
    665  SetTestSendFeatures(0);
    666  ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
    667 
    668  SetExternalPartition(true);
    669  SetTestSendFeatures(1);
    670  ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
    671  if (!WRITE_FEATURE_TO_FILE) return;
    672 
    673  // Compare feature files by reading them into strings.
    674  for (int i = 0; i < 8; ++i) {
    675    std::ifstream base_file(feature_file_names[i]);
    676    ASSERT_TRUE(base_file.good());
    677    std::stringstream base_stream;
    678    base_stream << base_file.rdbuf();
    679    std::string base_string = base_stream.str();
    680 
    681    std::ifstream test_file(test_feature_file_names[i]);
    682    ASSERT_TRUE(test_file.good());
    683    std::stringstream test_stream;
    684    test_stream << test_file.rdbuf();
    685    std::string test_string = test_stream.str();
    686 
    687    EXPECT_STREQ(base_string.c_str(), test_string.c_str());
    688  }
    689 
    690  // Remove files.
    691  std::string command("rm -f feature_* test_feature_*");
    692  system(command.c_str());
    693 }
    694 
    695 AV1_INSTANTIATE_TEST_SUITE(ExternalPartitionTestDfsAPI,
    696                           ::testing::Values(::libaom_test::kTwoPassGood),
    697                           ::testing::Values(4));  // cpu_used
    698 #endif  // CONFIG_PARTITION_SEARCH_ORDER
    699 
    700 }  // namespace
    701 #endif  // !CONFIG_REALTIME_ONLY
    702 #endif  // CONFIG_AV1_ENCODER