tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

tls_connect.cc (38269B)


      1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim: set ts=2 et sw=2 tw=80: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
      7 #include "tls_connect.h"
      8 #include "sslexp.h"
      9 extern "C" {
     10 #include "libssl_internals.h"
     11 }
     12 
     13 #include <iostream>
     14 
     15 #include "databuffer.h"
     16 #include "gtest_utils.h"
     17 #include "nss_scoped_ptrs.h"
     18 #include "sslproto.h"
     19 
     20 extern std::string g_working_dir_path;
     21 
     22 namespace nss_test {
     23 
// gtest parameter generators over the protocol variant: stream (TLS) only,
// datagram (DTLS) only, or both.  Each generator is built with ValuesIn()
// over a file-local static array.
static const SSLProtocolVariant kTlsVariantsStreamArr[] = {ssl_variant_stream};
::testing::internal::ParamGenerator<SSLProtocolVariant>
    TlsConnectTestBase::kTlsVariantsStream =
        ::testing::ValuesIn(kTlsVariantsStreamArr);
static const SSLProtocolVariant kTlsVariantsDatagramArr[] = {
    ssl_variant_datagram};
::testing::internal::ParamGenerator<SSLProtocolVariant>
    TlsConnectTestBase::kTlsVariantsDatagram =
        ::testing::ValuesIn(kTlsVariantsDatagramArr);
static const SSLProtocolVariant kTlsVariantsAllArr[] = {ssl_variant_stream,
                                                        ssl_variant_datagram};
::testing::internal::ParamGenerator<SSLProtocolVariant>
    TlsConnectTestBase::kTlsVariantsAll =
        ::testing::ValuesIn(kTlsVariantsAllArr);
     38 
// gtest parameter generators over protocol versions.  The single-version and
// fixed-range sets (kTlsV10 ... kTlsV11V12) list versions oldest-first.
static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10 =
    ::testing::ValuesIn(kTlsV10Arr);
static const uint16_t kTlsV11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11 =
    ::testing::ValuesIn(kTlsV11Arr);
static const uint16_t kTlsV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12 =
    ::testing::ValuesIn(kTlsV12Arr);
static const uint16_t kTlsV10V11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0,
                                         SSL_LIBRARY_VERSION_TLS_1_1};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10V11 =
    ::testing::ValuesIn(kTlsV10V11Arr);
static const uint16_t kTlsV10ToV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0,
                                           SSL_LIBRARY_VERSION_TLS_1_1,
                                           SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10ToV12 =
    ::testing::ValuesIn(kTlsV10ToV12Arr);
static const uint16_t kTlsV11V12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1,
                                         SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11V12 =
    ::testing::ValuesIn(kTlsV11V12Arr);

// Open-ended "X and newer" sets list versions newest-first.  TLS 1.3 is left
// out of these when the build defines NSS_DISABLE_TLS_1_3.
static const uint16_t kTlsV11PlusArr[] = {
#ifndef NSS_DISABLE_TLS_1_3
    SSL_LIBRARY_VERSION_TLS_1_3,
#endif
    SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11Plus =
    ::testing::ValuesIn(kTlsV11PlusArr);
static const uint16_t kTlsV12PlusArr[] = {
#ifndef NSS_DISABLE_TLS_1_3
    SSL_LIBRARY_VERSION_TLS_1_3,
#endif
    SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12Plus =
    ::testing::ValuesIn(kTlsV12PlusArr);
// NOTE(review): unlike the sets above, this one lists TLS 1.3 without an
// NSS_DISABLE_TLS_1_3 guard.
static const uint16_t kTlsV13Arr[] = {SSL_LIBRARY_VERSION_TLS_1_3};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV13 =
    ::testing::ValuesIn(kTlsV13Arr);
// Every supported version, TLS 1.3 (if enabled) down to TLS 1.0.
static const uint16_t kTlsVAllArr[] = {
#ifndef NSS_DISABLE_TLS_1_3
    SSL_LIBRARY_VERSION_TLS_1_3,
#endif
    SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1,
    SSL_LIBRARY_VERSION_TLS_1_0};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsVAll =
    ::testing::ValuesIn(kTlsVAllArr);
     87 
     88 std::string VersionString(uint16_t version) {
     89  switch (version) {
     90    case 0:
     91      return "(no version)";
     92    case SSL_LIBRARY_VERSION_3_0:
     93      return "1.0";
     94    case SSL_LIBRARY_VERSION_TLS_1_0:
     95      return "1.0";
     96    case SSL_LIBRARY_VERSION_TLS_1_1:
     97      return "1.1";
     98    case SSL_LIBRARY_VERSION_TLS_1_2:
     99      return "1.2";
    100    case SSL_LIBRARY_VERSION_TLS_1_3:
    101      return "1.3";
    102    default:
    103      std::cerr << "Invalid version: " << version << std::endl;
    104      EXPECT_TRUE(false);
    105      return "";
    106  }
    107 }
    108 
    109 // The default anti-replay window for tests.  Tests that rely on a different
    110 // value call ResetAntiReplay directly.
    111 static PRTime kAntiReplayWindow = 100 * PR_USEC_PER_SEC;
    112 
    113 TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant,
    114                                       uint16_t version)
    115    : variant_(variant),
    116      client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)),
    117      server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)),
    118      client_model_(nullptr),
    119      server_model_(nullptr),
    120      version_(version),
    121      expected_resumption_mode_(RESUME_NONE),
    122      expected_resumptions_(0),
    123      session_ids_(),
    124      expect_extended_master_secret_(false),
    125      expect_early_data_accepted_(false),
    126      skip_version_checks_(false) {
    127  std::string v;
    128  if (variant_ == ssl_variant_datagram &&
    129      version_ == SSL_LIBRARY_VERSION_TLS_1_1) {
    130    v = "1.0";
    131  } else {
    132    v = VersionString(version_);
    133  }
    134  std::cerr << "Version: " << variant_ << " " << v << std::endl;
    135 }
    136 
    137 TlsConnectTestBase::~TlsConnectTestBase() {}
    138 
    139 // Check the group of each of the supported groups
    140 void TlsConnectTestBase::CheckGroups(
    141    const DataBuffer& groups, std::function<void(SSLNamedGroup)> check_group) {
    142  DuplicateGroupChecker group_set;
    143  uint32_t tmp = 0;
    144  EXPECT_TRUE(groups.Read(0, 2, &tmp));
    145  EXPECT_EQ(groups.len() - 2, static_cast<size_t>(tmp));
    146  for (size_t i = 2; i < groups.len(); i += 2) {
    147    EXPECT_TRUE(groups.Read(i, 2, &tmp));
    148    SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp);
    149    group_set.AddAndCheckGroup(group);
    150    check_group(group);
    151  }
    152 }
    153 
    154 // Check the group of each of the shares
    155 void TlsConnectTestBase::CheckShares(
    156    const DataBuffer& shares, std::function<void(SSLNamedGroup)> check_group) {
    157  DuplicateGroupChecker group_set;
    158  uint32_t tmp = 0;
    159  EXPECT_TRUE(shares.Read(0, 2, &tmp));
    160  EXPECT_EQ(shares.len() - 2, static_cast<size_t>(tmp));
    161  size_t i;
    162  for (i = 2; i < shares.len(); i += 4 + tmp) {
    163    ASSERT_TRUE(shares.Read(i, 2, &tmp));
    164    SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp);
    165    group_set.AddAndCheckGroup(group);
    166    check_group(group);
    167    ASSERT_TRUE(shares.Read(i + 2, 2, &tmp));
    168  }
    169  EXPECT_EQ(shares.len(), i);
    170 }
    171 
// Checks the current epochs on both agents.  Each agent is given its peer's
// epoch first and its own epoch second.  NOTE(review): presumably this is the
// (read, write) order that TlsAgent::CheckEpochs expects; confirm there.
void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch,
                                     uint16_t server_epoch) const {
  client_->CheckEpochs(server_epoch, client_epoch);
  server_->CheckEpochs(client_epoch, server_epoch);
}
    177 
    178 void TlsConnectTestBase::ClearStats() {
    179  // Clear statistics.
    180  SSL3Statistics* stats = SSL_GetStatistics();
    181  memset(stats, 0, sizeof(*stats));
    182 }
    183 
// Throws away server-side session state.  It shuts down the server
// session-ID cache and clears the self-encrypt key.  It then configures a
// fresh cache with the same parameters SetUp() uses.
void TlsConnectTestBase::ClearServerCache() {
  SSL_ShutdownServerSessionIDCache();
  SSLInt_ClearSelfEncryptKey();
  SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
}
    189 
    190 void TlsConnectTestBase::SaveAlgorithmPolicy() {
    191  saved_policies_.clear();
    192  for (auto it = algorithms_.begin(); it != algorithms_.end(); ++it) {
    193    uint32_t policy;
    194    SECStatus rv = NSS_GetAlgorithmPolicy(*it, &policy);
    195    ASSERT_EQ(SECSuccess, rv);
    196    saved_policies_.push_back(std::make_tuple(*it, policy));
    197  }
    198  saved_options_.clear();
    199  for (auto it : options_) {
    200    int32_t option;
    201    SECStatus rv = NSS_OptionGet(it, &option);
    202    ASSERT_EQ(SECSuccess, rv);
    203    saved_options_.push_back(std::make_tuple(it, option));
    204  }
    205 }
    206 
    207 void TlsConnectTestBase::RestoreAlgorithmPolicy() {
    208  for (auto it = saved_policies_.begin(); it != saved_policies_.end(); ++it) {
    209    auto algorithm = std::get<0>(*it);
    210    auto policy = std::get<1>(*it);
    211    SECStatus rv = NSS_SetAlgorithmPolicy(
    212        algorithm, policy, NSS_USE_POLICY_IN_SSL | NSS_USE_ALG_IN_SSL_KX);
    213    ASSERT_EQ(SECSuccess, rv);
    214  }
    215  for (auto it = saved_options_.begin(); it != saved_options_.end(); ++it) {
    216    auto option_id = std::get<0>(*it);
    217    auto option = std::get<1>(*it);
    218    SECStatus rv = NSS_OptionSet(option_id, option);
    219    ASSERT_EQ(SECSuccess, rv);
    220  }
    221 }
    222 
    223 PRTime TlsConnectTestBase::TimeFunc(void* arg) {
    224  return *reinterpret_cast<PRTime*>(arg);
    225 }
    226 
// Per-test setup.  Order matters:
// - The server session-ID cache is configured and the self-encrypt key cleared.
// - |now_| is captured before ResetAntiReplay() uses it.
// - Statistics are zeroed.
// - NSS policy is snapshotted for TearDown().
// - Init() runs last to link the agents and apply version_.
void TlsConnectTestBase::SetUp() {
  SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
  SSLInt_ClearSelfEncryptKey();
  now_ = PR_Now();
  ResetAntiReplay(kAntiReplayWindow);
  ClearStats();
  SaveAlgorithmPolicy();
  Init();
}
    236 
    237 void TlsConnectTestBase::TearDown() {
    238  client_ = nullptr;
    239  server_ = nullptr;
    240 
    241  SSL_ClearSessionCache();
    242  SSLInt_ClearSelfEncryptKey();
    243  SSL_ShutdownServerSessionIDCache();
    244  RestoreAlgorithmPolicy();
    245 }
    246 
    247 void TlsConnectTestBase::Init() {
    248  client_->SetPeer(server_);
    249  server_->SetPeer(client_);
    250 
    251  if (version_) {
    252    ConfigureVersion(version_);
    253  }
    254 }
    255 
    256 void TlsConnectTestBase::ResetAntiReplay(PRTime window) {
    257  SSLAntiReplayContext* p_anti_replay = nullptr;
    258  EXPECT_EQ(SECSuccess,
    259            SSL_CreateAntiReplayContext(now_, window, 1, 3, &p_anti_replay));
    260  EXPECT_NE(nullptr, p_anti_replay);
    261  anti_replay_.reset(p_anti_replay);
    262 }
    263 
    264 ScopedSECItem TlsConnectTestBase::MakeEcKeyParams(SSLNamedGroup group) {
    265  auto groupDef = ssl_LookupNamedGroup(group);
    266  EXPECT_NE(nullptr, groupDef);
    267 
    268  auto oidData = SECOID_FindOIDByTag(groupDef->oidTag);
    269  EXPECT_NE(nullptr, oidData);
    270  ScopedSECItem params(
    271      SECITEM_AllocItem(nullptr, nullptr, (2 + oidData->oid.len)));
    272  EXPECT_TRUE(!!params);
    273  params->data[0] = SEC_ASN1_OBJECT_ID;
    274  params->data[1] = oidData->oid.len;
    275  memcpy(params->data + 2, oidData->oid.data, oidData->oid.len);
    276  return params;
    277 }
    278 
    279 void TlsConnectTestBase::GenerateEchConfig(
    280    HpkeKemId kem_id, const std::vector<HpkeSymmetricSuite>& cipher_suites,
    281    const std::string& public_name, uint16_t max_name_len, DataBuffer& record,
    282    ScopedSECKEYPublicKey& pubKey, ScopedSECKEYPrivateKey& privKey) {
    283  bool gen_keys = !pubKey && !privKey;
    284 
    285  SECKEYPublicKey* pub = nullptr;
    286  SECKEYPrivateKey* priv = nullptr;
    287 
    288  if (gen_keys) {
    289    ScopedSECItem ecParams = MakeEcKeyParams(ssl_grp_ec_curve25519);
    290    priv = SECKEY_CreateECPrivateKey(ecParams.get(), &pub, nullptr);
    291  } else {
    292    priv = privKey.get();
    293    pub = pubKey.get();
    294  }
    295  ASSERT_NE(nullptr, priv);
    296  PRUint8 encoded[1024];
    297  unsigned int encoded_len = 0;
    298  SECStatus rv = SSL_EncodeEchConfigId(
    299      77, public_name.c_str(), max_name_len, kem_id, pub, cipher_suites.data(),
    300      cipher_suites.size(), encoded, &encoded_len, sizeof(encoded));
    301  EXPECT_EQ(SECSuccess, rv);
    302  EXPECT_GT(encoded_len, 0U);
    303 
    304  if (gen_keys) {
    305    pubKey.reset(pub);
    306    privKey.reset(priv);
    307  }
    308  record.Truncate(0);
    309  record.Write(0, encoded, encoded_len);
    310 }
    311 
    312 void TlsConnectTestBase::SetupEch(std::shared_ptr<TlsAgent>& client,
    313                                  std::shared_ptr<TlsAgent>& server,
    314                                  HpkeKemId kem_id, bool expect_ech,
    315                                  bool set_client_config,
    316                                  bool set_server_config, int max_name_len) {
    317  EXPECT_TRUE(set_server_config || set_client_config);
    318  ScopedSECKEYPublicKey pub;
    319  ScopedSECKEYPrivateKey priv;
    320  DataBuffer record;
    321  static const std::vector<HpkeSymmetricSuite> kDefaultSuites = {
    322      {HpkeKdfHkdfSha256, HpkeAeadChaCha20Poly1305},
    323      {HpkeKdfHkdfSha256, HpkeAeadAes128Gcm}};
    324 
    325  GenerateEchConfig(kem_id, kDefaultSuites, "public.name", max_name_len, record,
    326                    pub, priv);
    327  ASSERT_NE(0U, record.len());
    328  SECStatus rv;
    329  if (set_server_config) {
    330    rv = SSL_SetServerEchConfigs(server->ssl_fd(), pub.get(), priv.get(),
    331                                 record.data(), record.len());
    332    ASSERT_EQ(SECSuccess, rv);
    333  }
    334  if (set_client_config) {
    335    rv = SSL_SetClientEchConfigs(client->ssl_fd(), record.data(), record.len());
    336    ASSERT_EQ(SECSuccess, rv);
    337  }
    338 
    339  /* Filter expect_ech, which typically defaults to true. Parameterized tests
    340   * running DTLS or TLS < 1.3 should expect only a non-ECH result. */
    341  bool expect = expect_ech && variant_ != ssl_variant_datagram &&
    342                version_ >= SSL_LIBRARY_VERSION_TLS_1_3 && set_client_config &&
    343                set_server_config;
    344  client->ExpectEch(expect);
    345  server->ExpectEch(expect);
    346 }
    347 
    348 void TlsConnectTestBase::Reset() {
    349  // Take a copy of the names because they are about to disappear.
    350  std::string server_name = server_->name();
    351  std::string client_name = client_->name();
    352  Reset(server_name, client_name);
    353 }
    354 
    355 void TlsConnectTestBase::Reset(const std::string& server_name,
    356                               const std::string& client_name) {
    357  auto token = client_->GetResumptionToken();
    358  client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_));
    359  client_->SetResumptionToken(token);
    360  server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_));
    361  if (skip_version_checks_) {
    362    client_->SkipVersionChecks();
    363    server_->SkipVersionChecks();
    364  }
    365 
    366  std::cerr << "Reset server:" << server_name << ", client:" << client_name
    367            << std::endl;
    368  Init();
    369 }
    370 
    371 void TlsConnectTestBase::MakeNewServer() {
    372  auto replacement = std::make_shared<TlsAgent>(
    373      server_->name(), TlsAgent::SERVER, server_->variant());
    374  server_ = replacement;
    375  if (version_) {
    376    server_->SetVersionRange(version_, version_);
    377  }
    378  client_->SetPeer(server_);
    379  server_->SetPeer(client_);
    380  server_->StartConnect();
    381 }
    382 
    383 void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected,
    384                                          uint8_t num_resumptions) {
    385  expected_resumption_mode_ = expected;
    386  if (expected != RESUME_NONE) {
    387    client_->ExpectResumption();
    388    server_->ExpectResumption();
    389    expected_resumptions_ = num_resumptions;
    390  }
    391  EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE);
    392 }
    393 
    394 void TlsConnectTestBase::EnsureTlsSetup() {
    395  EXPECT_TRUE(server_->EnsureTlsSetup(
    396      server_model_ ? server_model_->ssl_fd().get() : nullptr));
    397  EXPECT_TRUE(client_->EnsureTlsSetup(
    398      client_model_ ? client_model_->ssl_fd().get() : nullptr));
    399  server_->SetAntiReplayContext(anti_replay_);
    400  EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(client_->ssl_fd(),
    401                                        TlsConnectTestBase::TimeFunc, &now_));
    402  EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(server_->ssl_fd(),
    403                                        TlsConnectTestBase::TimeFunc, &now_));
    404 }
    405 
    406 void TlsConnectTestBase::Handshake() {
    407  client_->SetServerKeyBits(server_->server_key_bits());
    408  client_->Handshake();
    409  server_->Handshake();
    410 
    411  ASSERT_TRUE_WAIT((client_->state() != TlsAgent::STATE_CONNECTING) &&
    412                       (server_->state() != TlsAgent::STATE_CONNECTING),
    413                   5000);
    414 }
    415 
// Enables extended master secret on both agents.  It also records that the
// post-handshake checks (CheckExtendedMasterSecret, run from CheckConnected)
// should expect it.
void TlsConnectTestBase::EnableExtendedMasterSecret() {
  client_->EnableExtendedMasterSecret();
  server_->EnableExtendedMasterSecret();
  ExpectExtendedMasterSecret(true);
}
    421 
// Runs a handshake that is expected to succeed, in these steps:
// - Set up and start both agents.
// - Let the client apply a stored resumption token (presumably only if one
//   exists; see TlsAgent::MaybeSetResumptionToken).
// - Drive the handshake.
// - Run every post-handshake check in CheckConnected().
void TlsConnectTestBase::Connect() {
  StartConnect();
  client_->MaybeSetResumptionToken();
  Handshake();
  CheckConnected();
}
    428 
// Sets up both sockets (EnsureTlsSetup) and starts the server, then the
// client.  It does not drive the handshake; call Handshake() for that.
void TlsConnectTestBase::StartConnect() {
  EnsureTlsSetup();
  server_->StartConnect();
  client_->StartConnect();
}
    434 
    435 void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
    436  EnsureTlsSetup();
    437  client_->EnableSingleCipher(cipher_suite);
    438 
    439  Connect();
    440  SendReceive();
    441 
    442  // Check that we used the right cipher suite.
    443  uint16_t actual;
    444  EXPECT_TRUE(client_->cipher_suite(&actual));
    445  EXPECT_EQ(cipher_suite, actual);
    446  EXPECT_TRUE(server_->cipher_suite(&actual));
    447  EXPECT_EQ(cipher_suite, actual);
    448 }
    449 
// Post-handshake checks for a connection that should have succeeded:
// - versions agree and, unless skipped, equal the lower of the two maxima;
// - both agents are CONNECTED and agree on the cipher suite;
// - for pre-1.3 connections, 32-byte matching session IDs are recorded;
// - extended master secret, early data and resumption expectations hold;
// - each agent's CheckSecretsDestroyed() passes.
void TlsConnectTestBase::CheckConnected() {
  // Have the client read handshake twice to make sure we get the
  // NST and the ACK.
  if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 &&
      variant_ == ssl_variant_datagram) {
    client_->Handshake();
    client_->Handshake();
    auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd());
    // Verify that we dropped the client's retransmission cipher specs
    // (exactly 2 must remain).
    EXPECT_EQ(2, suites) << "Client has the wrong number of suites";
    if (suites != 2) {
      // Dump the remaining specs to help diagnose the failure.
      SSLInt_PrintCipherSpecs("client", client_->ssl_fd());
    }
  }
  EXPECT_EQ(client_->version(), server_->version());
  if (!skip_version_checks_) {
    // Check the version is as expected
    EXPECT_EQ(std::min(client_->max_version(), server_->max_version()),
              client_->version());
  }

  EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
  EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());

  uint16_t cipher_suite1, cipher_suite2;
  ASSERT_TRUE(client_->cipher_suite(&cipher_suite1));
  ASSERT_TRUE(server_->cipher_suite(&cipher_suite2));
  EXPECT_EQ(cipher_suite1, cipher_suite2);

  std::cerr << "Connected with version " << client_->version()
            << " cipher suite " << client_->cipher_suite_name() << std::endl;

  if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) {
    // Check and store session ids.  CheckResumption() compares the last two
    // stored ids.
    std::vector<uint8_t> sid_c1 = client_->session_id();
    EXPECT_EQ(32U, sid_c1.size());
    std::vector<uint8_t> sid_s1 = server_->session_id();
    EXPECT_EQ(32U, sid_s1.size());
    EXPECT_EQ(sid_c1, sid_s1);
    session_ids_.push_back(sid_c1);
  }

  CheckExtendedMasterSecret();
  CheckEarlyDataAccepted();
  CheckResumption(expected_resumption_mode_);
  client_->CheckSecretsDestroyed();
  server_->CheckSecretsDestroyed();
}
    498 
    499 void TlsConnectTestBase::CheckEarlyDataLimit(
    500    const std::shared_ptr<TlsAgent>& agent, size_t expected_size) {
    501  SSLPreliminaryChannelInfo preinfo;
    502  SECStatus rv =
    503      SSL_GetPreliminaryChannelInfo(agent->ssl_fd(), &preinfo, sizeof(preinfo));
    504  EXPECT_EQ(SECSuccess, rv);
    505  EXPECT_EQ(expected_size, static_cast<size_t>(preinfo.maxEarlyDataSize));
    506 }
    507 
    508 SSLKEAType TlsConnectTestBase::GetDefaultKEA(void) const {
    509  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    510    return ssl_kea_ecdh_hybrid;
    511  }
    512  return ssl_kea_ecdh;
    513 }
    514 
// Default authentication type assumed by the CheckKeys() overloads: always
// RSA signing (presumably matching the kServerRsa server agent created by the
// constructor).
SSLAuthType TlsConnectTestBase::GetDefaultAuth(void) const {
  return ssl_auth_rsa_sign;
}
    518 
    519 SSLNamedGroup TlsConnectTestBase::GetDefaultGroupFromKEA(
    520    SSLKEAType kea_type) const {
    521  SSLNamedGroup group;
    522  switch (kea_type) {
    523    case ssl_kea_ecdh_hybrid:
    524      group = ssl_grp_kem_mlkem768x25519;
    525      break;
    526    case ssl_kea_ecdh:
    527      group = ssl_grp_ec_curve25519;
    528      break;
    529    case ssl_kea_dh:
    530      group = ssl_grp_ffdhe_2048;
    531      break;
    532    case ssl_kea_rsa:
    533      group = ssl_grp_none;
    534      break;
    535    default:
    536      EXPECT_TRUE(false) << "unexpected KEA";
    537      group = ssl_grp_none;
    538      break;
    539  }
    540  return group;
    541 }
    542 
    543 SSLSignatureScheme TlsConnectTestBase::GetDefaultSchemeFromAuth(
    544    SSLAuthType auth_type) const {
    545  SSLSignatureScheme scheme;
    546  switch (auth_type) {
    547    case ssl_auth_rsa_decrypt:
    548      scheme = ssl_sig_none;
    549      break;
    550    case ssl_auth_rsa_sign:
    551      if (version_ >= SSL_LIBRARY_VERSION_TLS_1_2) {
    552        scheme = ssl_sig_rsa_pss_rsae_sha256;
    553      } else {
    554        scheme = ssl_sig_rsa_pkcs1_sha256;
    555      }
    556      break;
    557    case ssl_auth_rsa_pss:
    558      scheme = ssl_sig_rsa_pss_rsae_sha256;
    559      break;
    560    case ssl_auth_ecdsa:
    561      scheme = ssl_sig_ecdsa_secp256r1_sha256;
    562      break;
    563    case ssl_auth_dsa:
    564      scheme = ssl_sig_dsa_sha1;
    565      break;
    566    default:
    567      EXPECT_TRUE(false) << "unexpected auth type";
    568      scheme = static_cast<SSLSignatureScheme>(0x0100);
    569      break;
    570  }
    571  return scheme;
    572 }
    573 
    574 void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
    575                                   SSLAuthType auth_type,
    576                                   SSLSignatureScheme sig_scheme) const {
    577  if (kea_group != ssl_grp_none) {
    578    client_->CheckKEA(kea_type, kea_group);
    579    server_->CheckKEA(kea_type, kea_group);
    580  }
    581  server_->CheckAuthType(auth_type, sig_scheme);
    582  client_->CheckAuthType(auth_type, sig_scheme);
    583 }
    584 
    585 void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type,
    586                                   SSLNamedGroup kea_group) const {
    587  SSLAuthType auth_type = GetDefaultAuth();
    588  SSLSignatureScheme scheme = GetDefaultSchemeFromAuth(auth_type);
    589  CheckKeys(kea_type, kea_group, auth_type, scheme);
    590 }
    591 
    592 void TlsConnectTestBase::CheckKeys(SSLAuthType auth_type,
    593                                   SSLSignatureScheme sig_scheme) const {
    594  SSLKEAType kea_type = GetDefaultKEA();
    595  SSLNamedGroup group = GetDefaultGroupFromKEA(kea_type);
    596  CheckKeys(kea_type, group, auth_type, sig_scheme);
    597 }
    598 
    599 void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type,
    600                                   SSLAuthType auth_type) const {
    601  SSLNamedGroup group = GetDefaultGroupFromKEA(kea_type);
    602  SSLSignatureScheme scheme = GetDefaultSchemeFromAuth(auth_type);
    603 
    604  CheckKeys(kea_type, group, auth_type, scheme);
    605 }
    606 
    607 void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type) const {
    608  SSLAuthType auth_type = GetDefaultAuth();
    609  CheckKeys(kea_type, auth_type);
    610 }
    611 
    612 void TlsConnectTestBase::CheckKeys(SSLAuthType auth_type) const {
    613  SSLKEAType kea_type = GetDefaultKEA();
    614  CheckKeys(kea_type, auth_type);
    615 }
    616 
    617 void TlsConnectTestBase::CheckKeys() const {
    618  SSLKEAType kea_type = GetDefaultKEA();
    619  SSLAuthType auth_type = GetDefaultAuth();
    620  CheckKeys(kea_type, auth_type);
    621 }
    622 
// Key checks for a resumed connection:
// - the usual CheckKeys() checks with the resumed group |kea_group|;
// - a resumption mode was actually expected;
// - both agents pass CheckOriginalKEA(|original_kea_group|).  Presumably this
//   is the group of the connection being resumed; see TlsAgent.
void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type,
                                             SSLNamedGroup kea_group,
                                             SSLNamedGroup original_kea_group,
                                             SSLAuthType auth_type,
                                             SSLSignatureScheme sig_scheme) {
  CheckKeys(kea_type, kea_group, auth_type, sig_scheme);
  EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE);
  client_->CheckOriginalKEA(original_kea_group);
  server_->CheckOriginalKEA(original_kea_group);
}
    633 
// Runs a handshake that must fail on both sides.  After Handshake() returns,
// both agents are required (fatally) to be in STATE_ERROR.
void TlsConnectTestBase::ConnectExpectFail() {
  StartConnect();
  Handshake();
  ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state());
  ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state());
}
    640 
    641 void TlsConnectTestBase::ExpectAlert(std::shared_ptr<TlsAgent>& sender,
    642                                     uint8_t alert) {
    643  EnsureTlsSetup();
    644  auto receiver = (sender == client_) ? server_ : client_;
    645  sender->ExpectSendAlert(alert);
    646  receiver->ExpectReceiveAlert(alert);
    647 }
    648 
// Arms the alert expectations (|sender| sends |alert|, the peer receives it).
// It then requires the handshake to fail on both sides.
void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender,
                                            uint8_t alert) {
  ExpectAlert(sender, alert);
  ConnectExpectFail();
}
    654 
    655 void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) {
    656  StartConnect();
    657  client_->SetServerKeyBits(server_->server_key_bits());
    658  client_->Handshake();
    659  server_->Handshake();
    660 
    661  auto failing_agent = server_;
    662  if (failing_side == TlsAgent::CLIENT) {
    663    failing_agent = client_;
    664  }
    665  ASSERT_TRUE_WAIT(failing_agent->state() == TlsAgent::STATE_ERROR, 5000);
    666 }
    667 
// Records |version| as the fixture's version and restricts both agents to
// exactly that version (a [version, version] range).
void TlsConnectTestBase::ConfigureVersion(uint16_t version) {
  version_ = version;
  client_->SetVersionRange(version, version);
  server_->SetVersionRange(version, version);
}
    673 
// Tells both agents which version they should end up negotiating.  How it is
// checked is defined by TlsAgent::SetExpectedVersion.
void TlsConnectTestBase::SetExpectedVersion(uint16_t version) {
  client_->SetExpectedVersion(version);
  server_->SetExpectedVersion(version);
}
    678 
// Installs the same external PSK (key, label, hash, 0-RTT suite) on both
// agents and marks both as expecting PSK use.
void TlsConnectTestBase::AddPsk(const ScopedPK11SymKey& psk, std::string label,
                                SSLHashType hash, uint16_t zeroRttSuite) {
  client_->AddPsk(psk, label, hash, zeroRttSuite);
  server_->AddPsk(psk, label, hash, zeroRttSuite);
  client_->ExpectPsk();
  server_->ExpectPsk();
}
    686 
// Sets up both sockets and then disables every cipher suite on both agents.
// Callers re-enable the suites they want afterwards.
void TlsConnectTestBase::DisableAllCiphers() {
  EnsureTlsSetup();
  client_->DisableAllCiphers();
  server_->DisableAllCiphers();
}
    692 
// Restricts both agents to cipher suites whose key exchange is ssl_kea_rsa.
void TlsConnectTestBase::EnableOnlyStaticRsaCiphers() {
  DisableAllCiphers();

  client_->EnableCiphersByKeyExchange(ssl_kea_rsa);
  server_->EnableCiphersByKeyExchange(ssl_kea_rsa);
}
    699 
    700 void TlsConnectTestBase::EnableOnlyDheCiphers() {
    701  if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
    702    DisableAllCiphers();
    703    client_->EnableCiphersByKeyExchange(ssl_kea_dh);
    704    server_->EnableCiphersByKeyExchange(ssl_kea_dh);
    705  } else {
    706    client_->ConfigNamedGroups(kFFDHEGroups);
    707    server_->ConfigNamedGroups(kFFDHEGroups);
    708  }
    709 }
    710 
    711 void TlsConnectTestBase::EnableSomeEcdhCiphers() {
    712  if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
    713    client_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa);
    714    client_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa);
    715    server_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa);
    716    server_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa);
    717  } else {
    718    client_->ConfigNamedGroups(kECDHEGroups);
    719    server_->ConfigNamedGroups(kECDHEGroups);
    720  }
    721 }
    722 
// Loads the kServerRsaDecrypt certificate and private key and extracts the
// certificate's public key.  It installs that pair via
// SSL_SetSessionTicketKeyPair.  ConfigureSessionCache() calls this when the
// server mode includes RESUME_TICKET.
void TlsConnectTestBase::ConfigureSelfEncrypt() {
  ScopedCERTCertificate cert;
  ScopedSECKEYPrivateKey privKey;
  ASSERT_TRUE(
      TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey));

  ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get()));
  ASSERT_TRUE(pubKey);

  EXPECT_EQ(SECSuccess,
            SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get()));
}
    735 
    736 void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client,
    737                                               SessionResumptionMode server) {
    738  client_->ConfigureSessionCache(client);
    739  server_->ConfigureSessionCache(server);
    740  if ((server & RESUME_TICKET) != 0) {
    741    ConfigureSelfEncrypt();
    742  }
    743 }
    744 
    745 void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
    746  EXPECT_NE(RESUME_BOTH, expected);
    747 
    748  int resume_count = expected ? expected_resumptions_ : 0;
    749  int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0;
    750 
    751  // Note: hch == server counter; hsh == client counter.
    752  SSL3Statistics* stats = SSL_GetStatistics();
    753  EXPECT_EQ(resume_count, stats->hch_sid_cache_hits);
    754  EXPECT_EQ(resume_count, stats->hsh_sid_cache_hits);
    755 
    756  EXPECT_EQ(stateless_count, stats->hch_sid_stateless_resumes);
    757  EXPECT_EQ(stateless_count, stats->hsh_sid_stateless_resumes);
    758 
    759  if (expected != RESUME_NONE) {
    760    if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 &&
    761        client_->GetResumptionToken().size() == 0) {
    762      // Check that the last two session ids match.
    763      ASSERT_EQ(1U + expected_resumptions_, session_ids_.size());
    764      EXPECT_EQ(session_ids_[session_ids_.size() - 1],
    765                session_ids_[session_ids_.size() - 2]);
    766    } else {
    767      // We've either chosen TLS 1.3 or are using an external resumption token,
    768      // both of which only use tickets.
    769      EXPECT_TRUE(expected & RESUME_TICKET);
    770    }
    771  }
    772 }
    773 
    774 static SECStatus NextProtoCallbackServer(void* arg, PRFileDesc* fd,
    775                                         const unsigned char* protos,
    776                                         unsigned int protos_len,
    777                                         unsigned char* protoOut,
    778                                         unsigned int* protoOutLen,
    779                                         unsigned int protoMaxLen) {
    780  EXPECT_EQ(protoMaxLen, 255U);
    781  TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    782  // Check that agent->alpn_value_to_use_ is in protos.
    783  if (protos_len < 1) {
    784    return SECFailure;
    785  }
    786  for (size_t i = 0; i < protos_len;) {
    787    size_t l = protos[i];
    788    EXPECT_LT(i + l, protos_len);
    789    if (i + l >= protos_len) {
    790      return SECFailure;
    791    }
    792    std::string protos_s(reinterpret_cast<const char*>(protos + i + 1), l);
    793    if (protos_s == agent->alpn_value_to_use_) {
    794      size_t s_len = agent->alpn_value_to_use_.size();
    795      EXPECT_LE(s_len, 255U);
    796      memcpy(protoOut, &agent->alpn_value_to_use_[0], s_len);
    797      *protoOutLen = s_len;
    798      return SECSuccess;
    799    }
    800    i += l + 1;
    801  }
    802  return SECFailure;
    803 }
    804 
    805 void TlsConnectTestBase::EnableAlpn() {
    806  client_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
    807  server_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
    808 }
    809 
    810 void TlsConnectTestBase::EnableAlpnWithCallback(
    811    const std::vector<uint8_t>& client_vals, std::string server_choice) {
    812  EnsureTlsSetup();
    813  server_->alpn_value_to_use_ = server_choice;
    814  EXPECT_EQ(SECSuccess,
    815            SSL_SetNextProtoNego(client_->ssl_fd(), client_vals.data(),
    816                                 client_vals.size()));
    817  SECStatus rv = SSL_SetNextProtoCallback(
    818      server_->ssl_fd(), NextProtoCallbackServer, server_.get());
    819  EXPECT_EQ(SECSuccess, rv);
    820 }
    821 
    822 void TlsConnectTestBase::EnableAlpn(const std::vector<uint8_t>& vals) {
    823  client_->EnableAlpn(vals.data(), vals.size());
    824  server_->EnableAlpn(vals.data(), vals.size());
    825 }
    826 
    827 void TlsConnectTestBase::EnsureModelSockets() {
    828  // Make sure models agents are available.
    829  if (!client_model_) {
    830    ASSERT_EQ(server_model_, nullptr);
    831    client_model_.reset(
    832        new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_));
    833    server_model_.reset(
    834        new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_));
    835    if (skip_version_checks_) {
    836      client_model_->SkipVersionChecks();
    837      server_model_->SkipVersionChecks();
    838    }
    839  }
    840 }
    841 
// Checks that both agents report |val| as the ALPN result. The client is
// checked with state SSL_NEXT_PROTO_SELECTED and the server with
// SSL_NEXT_PROTO_NEGOTIATED.
void TlsConnectTestBase::CheckAlpn(const std::string& val) {
  client_->CheckAlpn(SSL_NEXT_PROTO_SELECTED, val);
  server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, val);
}
    846 
    847 void TlsConnectTestBase::EnableSrtp() {
    848  client_->EnableSrtp();
    849  server_->EnableSrtp();
    850 }
    851 
    852 void TlsConnectTestBase::CheckSrtp() const {
    853  client_->CheckSrtp();
    854  server_->CheckSrtp();
    855 }
    856 
    857 void TlsConnectTestBase::SendReceive(size_t total) {
    858  ASSERT_GT(total, client_->received_bytes());
    859  ASSERT_GT(total, server_->received_bytes());
    860  client_->SendData(total - server_->received_bytes());
    861  server_->SendData(total - client_->received_bytes());
    862  Receive(total);  // Receive() is cumulative
    863 }
    864 
// Do a first connection so we can do 0-RTT on the second one.
// The first connection runs as follows:
// - rolls the anti-replay window over;
// - enables ticket caching (client RESUME_BOTH, server RESUME_TICKET);
// - pins TLS 1.3 and enables 0-RTT on the server;
// - connects and exchanges data so the client absorbs the ticket;
// - checks the keys.
// It then calls Reset() and StartConnect(), leaving the second connection
// for the caller to drive.
void TlsConnectTestBase::SetupForZeroRtt() {
  // Force rollover of the anti-replay window.
  // If we don't do this, then all 0-RTT attempts will be rejected.
  RolloverAntiReplay();

  ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
  ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
  server_->Set0RttEnabled(true);  // So we signal that we allow 0-RTT.
  Connect();
  SendReceive();  // Need to read so that we absorb the session ticket.
  CheckKeys();

  Reset();
  StartConnect();
}
    881 
// Do a first connection so we can do resumption.
// Ticket caching is enabled (client RESUME_BOTH, server RESUME_TICKET). The
// connection completes and exchanges data so the client absorbs the session
// ticket, and the keys are checked. Reset() then prepares for the resuming
// connection.
void TlsConnectTestBase::SetupForResume() {
  EnsureTlsSetup();
  ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
  Connect();
  SendReceive();  // Need to read so that we absorb the session ticket.
  CheckKeys();

  Reset();
}
    892 
    893 void TlsConnectTestBase::ZeroRttSendReceive(
    894    bool expect_writable, bool expect_readable,
    895    std::function<bool()> post_clienthello_check) {
    896  const char* k0RttData = "ABCDEF";
    897  const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
    898 
    899  client_->Handshake();  // Send ClientHello.
    900  if (post_clienthello_check) {
    901    if (!post_clienthello_check()) return;
    902  }
    903  PRInt32 rv =
    904      PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen);  // 0-RTT write.
    905  if (expect_writable) {
    906    EXPECT_EQ(k0RttDataLen, rv);
    907  } else {
    908    EXPECT_EQ(SECFailure, rv);
    909  }
    910  server_->Handshake();  // Consume ClientHello
    911 
    912  std::vector<uint8_t> buf(k0RttDataLen);
    913  rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen);  // 0-RTT read
    914  if (expect_readable) {
    915    std::cerr << "0-RTT read " << rv << " bytes\n";
    916    EXPECT_EQ(k0RttDataLen, rv);
    917  } else {
    918    EXPECT_EQ(SECFailure, rv);
    919    EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError())
    920        << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
    921  }
    922 
    923  // Do a second read. This should fail.
    924  rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen);
    925  EXPECT_EQ(SECFailure, rv);
    926  EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
    927 }
    928 
    929 void TlsConnectTestBase::Receive(size_t amount) {
    930  WAIT_(client_->received_bytes() == amount &&
    931            server_->received_bytes() == amount,
    932        2000);
    933  ASSERT_EQ(amount, client_->received_bytes());
    934  ASSERT_EQ(amount, server_->received_bytes());
    935 }
    936 
// Records the extended-master-secret expectation that
// CheckExtendedMasterSecret() later verifies on both agents.
void TlsConnectTestBase::ExpectExtendedMasterSecret(bool expected) {
  expect_extended_master_secret_ = expected;
}
    940 
    941 void TlsConnectTestBase::CheckExtendedMasterSecret() {
    942  client_->CheckExtendedMasterSecret(expect_extended_master_secret_);
    943  server_->CheckExtendedMasterSecret(expect_extended_master_secret_);
    944 }
    945 
// Records whether CheckEarlyDataAccepted() should expect early data to have
// been accepted.
void TlsConnectTestBase::ExpectEarlyDataAccepted(bool expected) {
  expect_early_data_accepted_ = expected;
}
    949 
    950 void TlsConnectTestBase::CheckEarlyDataAccepted() {
    951  client_->CheckEarlyDataAccepted(expect_early_data_accepted_);
    952  server_->CheckEarlyDataAccepted(expect_early_data_accepted_);
    953 }
    954 
// Forwards to the server agent's EnableECDHEServerKeyReuse().
void TlsConnectTestBase::EnableECDHEServerKeyReuse() {
  server_->EnableECDHEServerKeyReuse();
}
    958 
    959 void TlsConnectTestBase::SkipVersionChecks() {
    960  skip_version_checks_ = true;
    961  client_->SkipVersionChecks();
    962  server_->SkipVersionChecks();
    963 }
    964 
    965 // Shift the DTLS timers, to the minimum time necessary to let the next timer
    966 // run on either client or server.  This allows tests to skip waiting without
    967 // having timers run out of order.
    968 void TlsConnectTestBase::ShiftDtlsTimers() {
    969  PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT;
    970  PRIntervalTime time;
    971  SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time);
    972  if (rv == SECSuccess) {
    973    time_shift = time;
    974  }
    975  rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time);
    976  if (rv == SECSuccess &&
    977      (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) {
    978    time_shift = time;
    979  }
    980 
    981  if (time_shift != PR_INTERVAL_NO_TIMEOUT) {
    982    AdvanceTime(PR_IntervalToMicroseconds(time_shift));
    983    EXPECT_EQ(SECSuccess,
    984              SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift));
    985    EXPECT_EQ(SECSuccess,
    986              SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift));
    987  }
    988 }
    989 
// Moves the fixture's clock (now_) forward by |time_shift|. ShiftDtlsTimers()
// passes a value from PR_IntervalToMicroseconds, so callers use microseconds.
void TlsConnectTestBase::AdvanceTime(PRTime time_shift) { now_ += time_shift; }
    991 
// Advance time by a full anti-replay window.
// Moves the fixture clock forward by kAntiReplayWindow. SetupForZeroRtt()
// calls this first because otherwise all 0-RTT attempts would be rejected.
void TlsConnectTestBase::RolloverAntiReplay() {
  AdvanceTime(kAntiReplayWindow);
}
    996 
// Parameterized fixture constructors. Each passes two values to
// TlsConnectTestBase: the protocol variant (presumably SSLProtocolVariant)
// and a TLS version.
// - Fixtures whose gtest parameter is a tuple take both values from
//   std::get<0> and std::get<1> of GetParam().
// - Fixtures whose parameter is the variant alone pin the version
//   (TLS 1.2 or TLS 1.3).
TlsConnectGeneric::TlsConnectGeneric()
    : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}

TlsConnectPre12::TlsConnectPre12()
    : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}

TlsConnectTls12::TlsConnectTls12()
    : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_2) {}

TlsConnectTls12Plus::TlsConnectTls12Plus()
    : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}

TlsConnectTls13::TlsConnectTls13()
    : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}

// The third tuple element initializes external_cache_.
TlsConnectGenericResumption::TlsConnectGenericResumption()
    : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())),
      external_cache_(std::get<2>(GetParam())) {}

TlsConnectTls13ResumptionToken::TlsConnectTls13ResumptionToken()
    : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}

TlsConnectGenericResumptionToken::TlsConnectGenericResumptionToken()
    : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
   1021 
   1022 void TlsKeyExchangeTest::EnsureKeyShareSetup() {
   1023  EnsureTlsSetup();
   1024  groups_capture_ =
   1025      std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn);
   1026  shares_capture_ =
   1027      std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
   1028  shares_capture2_ = std::make_shared<TlsExtensionCapture>(
   1029      client_, ssl_tls13_key_share_xtn, true);
   1030  std::vector<std::shared_ptr<PacketFilter>> captures = {
   1031      groups_capture_, shares_capture_, shares_capture2_};
   1032  client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures));
   1033  capture_hrr_ = MakeTlsFilter<TlsHandshakeRecorder>(
   1034      server_, kTlsHandshakeHelloRetryRequest);
   1035 }
   1036 
   1037 void TlsKeyExchangeTest::ConfigNamedGroups(
   1038    const std::vector<SSLNamedGroup>& groups) {
   1039  client_->ConfigNamedGroups(groups);
   1040  server_->ConfigNamedGroups(groups);
   1041 }
   1042 
   1043 std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails(
   1044    const std::shared_ptr<TlsExtensionCapture>& capture) {
   1045  EXPECT_TRUE(capture->captured());
   1046  const DataBuffer& ext = capture->extension();
   1047 
   1048  uint32_t tmp = 0;
   1049  EXPECT_TRUE(ext.Read(0, 2, &tmp));
   1050  EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
   1051  EXPECT_TRUE(ext.len() % 2 == 0);
   1052 
   1053  std::vector<SSLNamedGroup> groups;
   1054  for (size_t i = 1; i < ext.len() / 2; i += 1) {
   1055    EXPECT_TRUE(ext.Read(2 * i, 2, &tmp));
   1056    groups.push_back(static_cast<SSLNamedGroup>(tmp));
   1057  }
   1058  return groups;
   1059 }
   1060 
   1061 std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails(
   1062    const std::shared_ptr<TlsExtensionCapture>& capture) {
   1063  EXPECT_TRUE(capture->captured());
   1064  const DataBuffer& ext = capture->extension();
   1065 
   1066  uint32_t tmp = 0;
   1067  EXPECT_TRUE(ext.Read(0, 2, &tmp));
   1068  EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
   1069 
   1070  std::vector<SSLNamedGroup> shares;
   1071  size_t i = 2;
   1072  while (i < ext.len()) {
   1073    EXPECT_TRUE(ext.Read(i, 2, &tmp));
   1074    shares.push_back(static_cast<SSLNamedGroup>(tmp));
   1075    EXPECT_TRUE(ext.Read(i + 2, 2, &tmp));
   1076    i += 4 + tmp;
   1077  }
   1078  EXPECT_EQ(ext.len(), i);
   1079  return shares;
   1080 }
   1081 
   1082 void TlsKeyExchangeTest::CheckKEXDetails(
   1083    const std::vector<SSLNamedGroup>& expected_groups,
   1084    const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) {
   1085  std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_);
   1086  EXPECT_EQ(expected_groups, groups);
   1087 
   1088  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
   1089    ASSERT_LT(0U, expected_shares.size());
   1090    std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_);
   1091    EXPECT_EQ(expected_shares, shares);
   1092  } else {
   1093    EXPECT_FALSE(shares_capture_->captured());
   1094  }
   1095 
   1096  EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0);
   1097 }
   1098 
   1099 void TlsKeyExchangeTest::CheckKEXDetails(
   1100    const std::vector<SSLNamedGroup>& expected_groups,
   1101    const std::vector<SSLNamedGroup>& expected_shares) {
   1102  CheckKEXDetails(expected_groups, expected_shares, false);
   1103 }
   1104 
   1105 void TlsKeyExchangeTest::CheckKEXDetails(
   1106    const std::vector<SSLNamedGroup>& expected_groups,
   1107    const std::vector<SSLNamedGroup>& expected_shares,
   1108    SSLNamedGroup expected_share2) {
   1109  CheckKEXDetails(expected_groups, expected_shares, true);
   1110 
   1111  for (auto it : expected_shares) {
   1112    EXPECT_NE(expected_share2, it);
   1113  }
   1114  std::vector<SSLNamedGroup> expected_shares2 = {expected_share2};
   1115  EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_));
   1116 }
   1117 }  // namespace nss_test