tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

tls_agent.cc (50779B)


      1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim: set ts=2 et sw=2 tw=80: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
      7 #include "tls_agent.h"
      8 #include "databuffer.h"
      9 #include "keyhi.h"
     10 #include "pk11func.h"
     11 #include "ssl.h"
     12 #include "sslerr.h"
     13 #include "sslexp.h"
     14 #include "sslproto.h"
     15 #include "tls_filter.h"
     16 #include "tls_parser.h"
     17 
     18 extern "C" {
     19 // This is not something that should make you happy.
     20 #include "libssl_internals.h"
     21 }
     22 
     23 #define GTEST_HAS_RTTI 0
     24 #include "gtest/gtest.h"
     25 #include "gtest_utils.h"
     26 #include "nss_scoped_ptrs.h"
     27 
     28 extern std::string g_working_dir_path;
     29 
     30 namespace nss_test {
     31 
// Printable names for the agent states, in the order "INIT", "CONNECTING",
// "CONNECTED", "ERROR".
// NOTE(review): presumably indexed by the State enum declared in the header;
// confirm that the enum order matches.
const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"};

// Names of the credentials the tests refer to.
// NOTE(review): these look like certificate nicknames in the test database
// (LoadCertificate() looks certificates up by nickname) -- confirm.
const std::string TlsAgent::kClient = "client";    // both sign and encrypt
const std::string TlsAgent::kRsa2048 = "rsa2048";  // bigger
const std::string TlsAgent::kRsa8192 = "rsa8192";  // biggest allowed
const std::string TlsAgent::kServerRsa = "rsa";    // both sign and encrypt
const std::string TlsAgent::kServerRsaSign = "rsa_sign";
const std::string TlsAgent::kServerRsaPss = "rsa_pss";
const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt";
const std::string TlsAgent::kServerEcdsa256 = "ecdsa256";
const std::string TlsAgent::kServerEcdsa384 = "ecdsa384";
const std::string TlsAgent::kServerEcdsa521 = "ecdsa521";
const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa";
const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa";
const std::string TlsAgent::kServerDsa = "dsa";
const std::string TlsAgent::kDelegatorEcdsa256 = "delegator_ecdsa256";
const std::string TlsAgent::kDelegatorRsae2048 = "delegator_rsae2048";
const std::string TlsAgent::kDelegatorRsaPss2048 = "delegator_rsa_pss2048";

// Fixed byte string named as a canned TLS 1.3 ServerHello. Its consumer is
// not in this part of the file.
static const uint8_t kCannedTls13ServerHello[] = {
    0x03, 0x03, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3,
    0xf0, 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b,
    0xdf, 0xe5, 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76,
    0x08, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24,
    0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03,
    0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9,
    0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08,
    0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04};
     60 
// Creates an agent named |nm| that plays role |rl| over protocol variant
// |var|.
//
// No SSL socket is created here: ssl_fd_ starts out null and the state is
// STATE_INIT. The expected sent/received alerts start at close_notify with
// warning level, which is exactly the combination the destructor treats as
// "nothing outstanding".
TlsAgent::TlsAgent(const std::string& nm, Role rl, SSLProtocolVariant var)
    : name_(nm),
      variant_(var),
      role_(rl),
      server_key_bits_(0),
      // NOTE(review): role_str() runs during member initialisation; this
      // presumably relies on the members it reads being declared before
      // adapter_ in the header -- confirm.
      adapter_(new DummyPrSocket(role_str(), var)),
      ssl_fd_(nullptr),
      state_(STATE_INIT),
      timer_handle_(nullptr),
      falsestart_enabled_(false),
      expected_version_(0),
      expected_cipher_suite_(0),
      expect_client_auth_(false),
      expect_ech_(false),
      expect_psk_(ssl_psk_none),
      can_falsestart_hook_called_(false),
      sni_hook_called_(false),
      auth_certificate_hook_called_(false),
      expected_received_alert_(kTlsAlertCloseNotify),
      expected_received_alert_level_(kTlsAlertWarning),
      expected_sent_alert_(kTlsAlertCloseNotify),
      expected_sent_alert_level_(kTlsAlertWarning),
      handshake_callback_called_(false),
      resumption_callback_called_(false),
      error_code_(0),
      send_ctr_(0),
      recv_ctr_(0),
      expect_readwrite_error_(false),
      handshake_callback_(),
      auth_certificate_callback_(),
      sni_callback_(),
      skip_version_checks_(false),
      resumption_token_(),
      policy_() {
  // Start with zeroed channel and cipher-suite info.
  memset(&info_, 0, sizeof(info_));
  memset(&csinfo_, 0, sizeof(csinfo_));
  // Seed the version range with the library default for this variant.
  SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_);
  EXPECT_EQ(SECSuccess, rv);
}
    100 
    101 TlsAgent::~TlsAgent() {
    102  if (timer_handle_) {
    103    timer_handle_->Cancel();
    104  }
    105 
    106  if (adapter_) {
    107    Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
    108  }
    109 
    110  // Add failures manually, if any, so we don't throw in a destructor.
    111  if (expected_received_alert_ != kTlsAlertCloseNotify ||
    112      expected_received_alert_level_ != kTlsAlertWarning) {
    113    ADD_FAILURE() << "Wrong expected_received_alert status: " << role_str();
    114  }
    115  if (expected_sent_alert_ != kTlsAlertCloseNotify ||
    116      expected_sent_alert_level_ != kTlsAlertWarning) {
    117    ADD_FAILURE() << "Wrong expected_sent_alert status: " << role_str();
    118  }
    119 }
    120 
    121 void TlsAgent::SetState(State s) {
    122  if (state_ == s) return;
    123 
    124  LOG("Changing state from " << state_ << " to " << s);
    125  state_ = s;
    126 }
    127 
    128 /*static*/ bool TlsAgent::LoadCertificate(const std::string& name,
    129                                          ScopedCERTCertificate* cert,
    130                                          ScopedSECKEYPrivateKey* priv) {
    131  cert->reset(PK11_FindCertFromNickname(name.c_str(), nullptr));
    132  EXPECT_NE(nullptr, cert);
    133  if (!cert) return false;
    134  EXPECT_NE(nullptr, cert->get());
    135  if (!cert->get()) return false;
    136 
    137  priv->reset(PK11_FindKeyByAnyCert(cert->get(), nullptr));
    138  EXPECT_NE(nullptr, priv);
    139  if (!priv) return false;
    140  EXPECT_NE(nullptr, priv->get());
    141  if (!priv->get()) return false;
    142 
    143  return true;
    144 }
    145 
    146 // Loads a key pair from the certificate identified by |id|.
    147 /*static*/ bool TlsAgent::LoadKeyPairFromCert(const std::string& name,
    148                                              ScopedSECKEYPublicKey* pub,
    149                                              ScopedSECKEYPrivateKey* priv) {
    150  ScopedCERTCertificate cert;
    151  if (!TlsAgent::LoadCertificate(name, &cert, priv)) {
    152    return false;
    153  }
    154 
    155  pub->reset(SECKEY_ExtractPublicKey(&cert->subjectPublicKeyInfo));
    156  if (!pub->get()) {
    157    return false;
    158  }
    159 
    160  return true;
    161 }
    162 
    163 void TlsAgent::DelegateCredential(const std::string& name,
    164                                  const ScopedSECKEYPublicKey& dc_pub,
    165                                  SSLSignatureScheme dc_cert_verify_alg,
    166                                  PRUint32 dc_valid_for, PRTime now,
    167                                  SECItem* dc) {
    168  ScopedCERTCertificate cert;
    169  ScopedSECKEYPrivateKey cert_priv;
    170  EXPECT_TRUE(TlsAgent::LoadCertificate(name, &cert, &cert_priv))
    171      << "Could not load delegate certificate: " << name
    172      << "; test db corrupt?";
    173 
    174  EXPECT_EQ(SECSuccess,
    175            SSL_DelegateCredential(cert.get(), cert_priv.get(), dc_pub.get(),
    176                                   dc_cert_verify_alg, dc_valid_for, now, dc));
    177 }
    178 
// Turns on the SSL_ENABLE_DELEGATED_CREDENTIALS option. Returns early with a
// fatal test failure if the TLS socket cannot be set up.
void TlsAgent::EnableDelegatedCredentials() {
  ASSERT_TRUE(EnsureTlsSetup());
  SetOption(SSL_ENABLE_DELEGATED_CREDENTIALS, PR_TRUE);
}
    183 
    184 void TlsAgent::AddDelegatedCredential(const std::string& dc_name,
    185                                      SSLSignatureScheme dc_cert_verify_alg,
    186                                      PRUint32 dc_valid_for, PRTime now) {
    187  ASSERT_TRUE(EnsureTlsSetup());
    188 
    189  ScopedSECKEYPublicKey pub;
    190  ScopedSECKEYPrivateKey priv;
    191  EXPECT_TRUE(TlsAgent::LoadKeyPairFromCert(dc_name, &pub, &priv));
    192 
    193  StackSECItem dc;
    194  TlsAgent::DelegateCredential(name_, pub, dc_cert_verify_alg, dc_valid_for,
    195                               now, &dc);
    196 
    197  SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr,
    198                                       nullptr,       &dc,     priv.get()};
    199  EXPECT_TRUE(ConfigServerCert(name_, true, &extra_data));
    200 }
    201 
    202 bool TlsAgent::ConfigServerCert(const std::string& id, bool updateKeyBits,
    203                                const SSLExtraServerCertData* serverCertData) {
    204  ScopedCERTCertificate cert;
    205  ScopedSECKEYPrivateKey priv;
    206  if (!TlsAgent::LoadCertificate(id, &cert, &priv)) {
    207    return false;
    208  }
    209 
    210  if (updateKeyBits) {
    211    ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get()));
    212    EXPECT_NE(nullptr, pub.get());
    213    if (!pub.get()) return false;
    214    server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get());
    215  }
    216 
    217  SECStatus rv =
    218      SSL_ConfigSecureServer(ssl_fd(), nullptr, nullptr, ssl_kea_null);
    219  EXPECT_EQ(SECFailure, rv);
    220  rv = SSL_ConfigServerCert(ssl_fd(), cert.get(), priv.get(), serverCertData,
    221                            serverCertData ? sizeof(*serverCertData) : 0);
    222  return rv == SECSuccess;
    223 }
    224 
// Creates and configures the SSL socket for this agent if that has not been
// done yet. Returns true if the socket already exists or was set up
// successfully, false if any setup step failed.
//
// |modelSocket| is forwarded to SSL_ImportFD()/DTLS_ImportFD(). This function
// is called without an argument elsewhere in this file, so the declaration
// evidently supplies a default.
bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
  // Don't set up twice
  if (ssl_fd_) return true;
  // NOTE(review): presumably applies policy_/option_ for the duration of this
  // scope -- confirm against NssManagePolicy.
  NssManagePolicy policyManage(policy_, option_);

  ScopedPRFileDesc dummy_fd(adapter_->CreateFD());
  EXPECT_NE(nullptr, dummy_fd);
  if (!dummy_fd) {
    return false;
  }
  // Stream variant gets a TLS socket, anything else a DTLS socket.
  if (adapter_->variant() == ssl_variant_stream) {
    ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get()));
  } else {
    ssl_fd_.reset(DTLS_ImportFD(modelSocket, dummy_fd.get()));
  }

  EXPECT_NE(nullptr, ssl_fd_);
  if (!ssl_fd_) {
    return false;
  }
  dummy_fd.release();  // Now subsumed by ssl_fd_.

  SECStatus rv;
  if (!skip_version_checks_) {
    rv = SSL_VersionRangeSet(ssl_fd(), &vrange_);
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
  }

  // Install a freshly created, empty trust anchor list. Unlike the other
  // steps, a failure here returns false without recording a test failure.
  ScopedCERTCertList anchors(CERT_NewCertList());
  rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get());
  if (rv != SECSuccess) return false;

  if (role_ == SERVER) {
    // A certificate configuration failure is reported but does not abort
    // the setup.
    EXPECT_TRUE(ConfigServerCert(name_, true));

    rv = SSL_SNISocketConfigHook(ssl_fd(), SniHook, this);
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;

    rv = SSL_SetMaxEarlyDataSize(ssl_fd(), 1024);
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
  } else {
    // Clients always connect to a peer called "server".
    rv = SSL_SetURL(ssl_fd(), "server");
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
  }

  // Register the callbacks; each receives |this| as its context argument.
  rv = SSL_AuthCertificateHook(ssl_fd(), AuthCertificateHook, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  rv = SSL_HandshakeCallback(ssl_fd(), HandshakeCallback, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  // All these tests depend on having this disabled to start with.
  SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_FALSE);

  return true;
}
    295 
    296 bool TlsAgent::MaybeSetResumptionToken() {
    297  if (!resumption_token_.empty()) {
    298    LOG("setting external resumption token");
    299    SECStatus rv = SSL_SetResumptionToken(ssl_fd(), resumption_token_.data(),
    300                                          resumption_token_.size());
    301 
    302    // rv is SECFailure with error set to SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR
    303    // if the resumption token was bad (expired/malformed/etc.).
    304    if (expect_psk_ == ssl_psk_resume) {
    305      // Only in case we expect resumption this has to be successful. We might
    306      // not expect resumption due to some reason but the token is totally fine.
    307      EXPECT_EQ(SECSuccess, rv);
    308    }
    309    if (rv != SECSuccess) {
    310      EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError());
    311      resumption_token_.clear();
    312      EXPECT_FALSE(expect_psk_ == ssl_psk_resume);
    313      if (expect_psk_ == ssl_psk_resume) return false;
    314    }
    315  }
    316 
    317  return true;
    318 }
    319 
// Attaches the anti-replay context |ctx| to this agent's socket and expects
// the call to succeed.
void TlsAgent::SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx) {
  EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(ssl_fd(), ctx.get()));
}
    323 
    324 // Defaults to a Sync callback returning success
    325 void TlsAgent::SetupClientAuth(ClientAuthCallbackType callbackType,
    326                               bool callbackSuccess) {
    327  EXPECT_TRUE(EnsureTlsSetup());
    328  ASSERT_EQ(CLIENT, role_);
    329 
    330  client_auth_callback_type_ = callbackType;
    331  client_auth_callback_success_ = callbackSuccess;
    332 
    333  if (callbackType == ClientAuthCallbackType::kNone && !callbackSuccess) {
    334    // Don't set a callback for this case.
    335    return;
    336  }
    337  EXPECT_EQ(SECSuccess,
    338            SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook,
    339                                      reinterpret_cast<void*>(this)));
    340 }
    341 
    342 void CheckCertReqAgainstDefaultCAs(const CERTDistNames* caNames) {
    343  ScopedCERTDistNames expected(CERT_GetSSLCACerts(nullptr));
    344 
    345  ASSERT_EQ(expected->nnames, caNames->nnames);
    346 
    347  for (size_t i = 0; i < static_cast<size_t>(expected->nnames); ++i) {
    348    EXPECT_EQ(SECEqual,
    349              SECITEM_CompareItem(&(expected->names[i]), &(caNames->names[i])));
    350  }
    351 }
    352 
// Complete processing of Client Certificate Selection.
// A no-op if the agent is using synchronous client cert selection
// (any callback type other than kAsyncDelay or kAsyncImmediate).
// Otherwise, calls SSL_ClientCertCallbackComplete.
// kAsyncDelay triggers a call to SSL_ForceHandshake prior to completion to
// ensure that the socket is correctly blocked.
//
// Each asynchronous completion also increments client_auth_callback_fired_.
void TlsAgent::ClientAuthCallbackComplete() {
  ASSERT_EQ(CLIENT, role_);

  if (client_auth_callback_type_ != ClientAuthCallbackType::kAsyncDelay &&
      client_auth_callback_type_ != ClientAuthCallbackType::kAsyncImmediate) {
    return;
  }
  client_auth_callback_fired_++;
  // The hook must have run and be waiting for this completion.
  EXPECT_TRUE(client_auth_callback_awaiting_);

  std::cerr << "client: calling SSL_ClientCertCallbackComplete with status "
            << (client_auth_callback_success_ ? "success" : "failed")
            << std::endl;

  client_auth_callback_awaiting_ = false;

  if (client_auth_callback_type_ == ClientAuthCallbackType::kAsyncDelay) {
    std::cerr
        << "Running Handshake prior to running SSL_ClientCertCallbackComplete"
        << std::endl;
    // While the selection is outstanding the handshake must report that it
    // would block.
    SECStatus rv = SSL_ForceHandshake(ssl_fd());
    EXPECT_EQ(rv, SECFailure);
    EXPECT_EQ(PORT_GetError(), PR_WOULD_BLOCK_ERROR);
  }

  ScopedCERTCertificate cert;
  ScopedSECKEYPrivateKey priv;
  if (client_auth_callback_success_) {
    ASSERT_TRUE(TlsAgent::LoadCertificate(name(), &cert, &priv));
    // release() hands the raw pointers to the library call.
    // NOTE(review): presumably the library takes ownership of both -- confirm.
    EXPECT_EQ(SECSuccess,
              SSL_ClientCertCallbackComplete(ssl_fd(), SECSuccess,
                                             priv.release(), cert.release()));
  } else {
    EXPECT_EQ(SECSuccess, SSL_ClientCertCallbackComplete(ssl_fd(), SECFailure,
                                                         nullptr, nullptr));
  }
}
    395 
// Client certificate selection hook registered by SetupClientAuth().
//
// |self| is the TlsAgent that registered the hook. |fd| is unused, and
// |caNames| is only referenced by a disabled check. Every invocation
// increments client_auth_callback_fired_.
//
// Returns:
//   SECWouldBlock - for the asynchronous callback types; the agent is marked
//                   as awaiting ClientAuthCallbackComplete().
//   SECFailure    - if the agent is configured to fail, or its certificate
//                   cannot be loaded.
//   SECSuccess    - with |*clientCert| and |*clientKey| set to pointers
//                   released from their scoped owners.
SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
                                          CERTDistNames* caNames,
                                          CERTCertificate** clientCert,
                                          SECKEYPrivateKey** clientKey) {
  TlsAgent* agent = reinterpret_cast<TlsAgent*>(self);
  EXPECT_EQ(CLIENT, agent->role_);
  agent->client_auth_callback_fired_++;

  switch (agent->client_auth_callback_type_) {
    case ClientAuthCallbackType::kAsyncDelay:
    case ClientAuthCallbackType::kAsyncImmediate:
      std::cerr << "Waiting for complete call" << std::endl;
      agent->client_auth_callback_awaiting_ = true;
      return SECWouldBlock;
    case ClientAuthCallbackType::kSync:
    case ClientAuthCallbackType::kNone:
      // Handle the sync case. None && Success is treated as Sync and Success.
      if (!agent->client_auth_callback_success_) {
        return SECFailure;
      }
      ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd()));
      EXPECT_TRUE(peerCert) << "Client should be able to see the server cert";

      // See bug 1573945
      // CheckCertReqAgainstDefaultCAs(caNames);

      ScopedCERTCertificate cert;
      ScopedSECKEYPrivateKey priv;
      if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) {
        return SECFailure;
      }

      // Give up ownership: the out-parameters now carry the raw pointers.
      *clientCert = cert.release();
      *clientKey = priv.release();
      return SECSuccess;
  }
  /* This is unreachable, but some old compilers can't tell that. */
  PORT_Assert(0);
  PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
  return SECFailure;
}
    437 
    438 // Increments by 1 for each callback
    439 bool TlsAgent::CheckClientAuthCallbacksCompleted(uint8_t expected) {
    440  EXPECT_EQ(CLIENT, role_);
    441  return expected == client_auth_callback_fired_;
    442 }
    443 
    444 bool TlsAgent::GetPeerChainLength(size_t* count) {
    445  SECItemArray* chain = nullptr;
    446  SECStatus rv = SSL_PeerCertificateChainDER(ssl_fd(), &chain);
    447  if (rv != SECSuccess) return false;
    448 
    449  *count = chain->len;
    450 
    451  SECITEM_FreeArray(chain, true);
    452 
    453  return true;
    454 }
    455 
    456 void TlsAgent::CheckPeerChainFunctionConsistency() {
    457  SECItemArray* derChain = nullptr;
    458  SECStatus rv = SSL_PeerCertificateChainDER(ssl_fd(), &derChain);
    459  PRErrorCode err1 = PR_GetError();
    460  CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd());
    461  PRErrorCode err2 = PR_GetError();
    462  if (rv != SECSuccess) {
    463    ASSERT_EQ(nullptr, chain);
    464    ASSERT_EQ(nullptr, derChain);
    465    ASSERT_EQ(err1, SSL_ERROR_NO_CERTIFICATE);
    466    ASSERT_EQ(err2, SSL_ERROR_NO_CERTIFICATE);
    467    return;
    468  }
    469  ASSERT_NE(nullptr, chain);
    470  ASSERT_NE(nullptr, derChain);
    471 
    472  unsigned int count = 0;
    473  for (PRCList* cursor = PR_NEXT_LINK(&chain->list);
    474       count < derChain->len && cursor != &chain->list;
    475       cursor = PR_NEXT_LINK(cursor)) {
    476    CERTCertListNode* node = (CERTCertListNode*)cursor;
    477    EXPECT_TRUE(
    478        SECITEM_ItemsAreEqual(&node->cert->derCert, &derChain->items[count]));
    479    ++count;
    480  }
    481  ASSERT_EQ(count, derChain->len);
    482 
    483  SECITEM_FreeArray(derChain, true);
    484  CERT_DestroyCertList(chain);
    485 }
    486 
// Expects the cipher suite recorded in csinfo_ to equal |suite|.
void TlsAgent::CheckCipherSuite(uint16_t suite) {
  EXPECT_EQ(csinfo_.cipherSuite, suite);
}
    490 
    491 void TlsAgent::RequestClientAuth(bool requireAuth) {
    492  ASSERT_EQ(SERVER, role_);
    493 
    494  SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE);
    495  SetOption(SSL_REQUIRE_CERTIFICATE, requireAuth ? PR_TRUE : PR_FALSE);
    496 
    497  EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook(
    498                            ssl_fd(), &TlsAgent::ClientAuthenticated, this));
    499  expect_client_auth_ = true;
    500 }
    501 
    502 void TlsAgent::StartConnect(PRFileDesc* model) {
    503  EXPECT_TRUE(EnsureTlsSetup(model));
    504 
    505  SECStatus rv;
    506  rv = SSL_ResetHandshake(ssl_fd(), role_ == SERVER ? PR_TRUE : PR_FALSE);
    507  EXPECT_EQ(SECSuccess, rv);
    508  SetState(STATE_CONNECTING);
    509 }
    510 
    511 void TlsAgent::DisableAllCiphers() {
    512  for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
    513    SECStatus rv =
    514        SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_FALSE);
    515    EXPECT_EQ(SECSuccess, rv);
    516  }
    517 }
    518 
// Not actually all groups, just the ones that we are actually willing
// to use. The xyber768d00 entry is present only when NSS_DISABLE_KYBER is
// not defined; the same holds for every list below that contains it.
const std::vector<SSLNamedGroup> kAllDHEGroups = {
    ssl_grp_ec_curve25519,
    ssl_grp_ec_secp256r1,
    ssl_grp_ec_secp384r1,
    ssl_grp_ec_secp521r1,
    ssl_grp_ffdhe_2048,
    ssl_grp_ffdhe_3072,
    ssl_grp_ffdhe_4096,
    ssl_grp_ffdhe_6144,
    ssl_grp_ffdhe_8192,
#ifndef NSS_DISABLE_KYBER
    ssl_grp_kem_xyber768d00,
#endif
    ssl_grp_kem_mlkem768x25519,
    ssl_grp_kem_secp256r1mlkem768,
    ssl_grp_kem_secp384r1mlkem1024,
};

// The EC and FFDHE groups of kAllDHEGroups without any of the
// ssl_grp_kem_* entries.
const std::vector<SSLNamedGroup> kNonPQDHEGroups = {
    ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
    ssl_grp_ec_secp521r1,  ssl_grp_ffdhe_2048,   ssl_grp_ffdhe_3072,
    ssl_grp_ffdhe_4096,    ssl_grp_ffdhe_6144,   ssl_grp_ffdhe_8192,
};

// Groups configured for ssl_kea_ecdh and for the ECDSA/ECDH auth types:
// the EC curves plus the ssl_grp_kem_* groups.
const std::vector<SSLNamedGroup> kECDHEGroups = {
    ssl_grp_ec_curve25519,          ssl_grp_ec_secp256r1,
    ssl_grp_ec_secp384r1,           ssl_grp_ec_secp521r1,
#ifndef NSS_DISABLE_KYBER
    ssl_grp_kem_xyber768d00,
#endif
    ssl_grp_kem_mlkem768x25519,     ssl_grp_kem_secp256r1mlkem768,
    ssl_grp_kem_secp384r1mlkem1024,
};

// Groups configured for ssl_kea_dh: the finite-field groups only.
const std::vector<SSLNamedGroup> kFFDHEGroups = {
    ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096,
    ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192};

// Defined because the big DHE groups are ridiculously slow.
// Leaves out secp521r1 and the FFDHE groups above 3072 bits.
const std::vector<SSLNamedGroup> kFasterDHEGroups = {
    ssl_grp_ec_curve25519,
    ssl_grp_ec_secp256r1,
    ssl_grp_ec_secp384r1,
    ssl_grp_ffdhe_2048,
    ssl_grp_ffdhe_3072,
#ifndef NSS_DISABLE_KYBER
    ssl_grp_kem_xyber768d00,
#endif
    ssl_grp_kem_mlkem768x25519,
    ssl_grp_kem_secp256r1mlkem768,
    ssl_grp_kem_secp384r1mlkem1024,
};

// Groups configured for ssl_kea_ecdh_hybrid: the ssl_grp_kem_* groups only.
const std::vector<SSLNamedGroup> kEcdhHybridGroups = {
#ifndef NSS_DISABLE_KYBER
    ssl_grp_kem_xyber768d00,
#endif
    ssl_grp_kem_mlkem768x25519,
    ssl_grp_kem_secp256r1mlkem768,
    ssl_grp_kem_secp384r1mlkem1024,
};
    582 
    583 void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) {
    584  EXPECT_TRUE(EnsureTlsSetup());
    585 
    586  for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
    587    SSLCipherSuiteInfo csinfo;
    588 
    589    SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo,
    590                                          sizeof(csinfo));
    591    ASSERT_EQ(SECSuccess, rv);
    592    EXPECT_EQ(sizeof(csinfo), csinfo.length);
    593 
    594    if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) {
    595      rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE);
    596      EXPECT_EQ(SECSuccess, rv);
    597    }
    598  }
    599 }
    600 
    601 void TlsAgent::EnableGroupsByKeyExchange(SSLKEAType kea) {
    602  switch (kea) {
    603    case ssl_kea_dh:
    604      ConfigNamedGroups(kFFDHEGroups);
    605      break;
    606    case ssl_kea_ecdh:
    607      ConfigNamedGroups(kECDHEGroups);
    608      break;
    609    case ssl_kea_ecdh_hybrid:
    610      ConfigNamedGroups(kEcdhHybridGroups);
    611      break;
    612    default:
    613      break;
    614  }
    615 }
    616 
    617 void TlsAgent::EnableGroupsByAuthType(SSLAuthType authType) {
    618  if (authType == ssl_auth_ecdh_rsa || authType == ssl_auth_ecdh_ecdsa ||
    619      authType == ssl_auth_ecdsa || authType == ssl_auth_tls13_any) {
    620    ConfigNamedGroups(kECDHEGroups);
    621  }
    622 }
    623 
    624 void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) {
    625  EXPECT_TRUE(EnsureTlsSetup());
    626 
    627  for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
    628    SSLCipherSuiteInfo csinfo;
    629 
    630    SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo,
    631                                          sizeof(csinfo));
    632    ASSERT_EQ(SECSuccess, rv);
    633 
    634    if ((csinfo.authType == authType) ||
    635        (csinfo.keaType == ssl_kea_tls13_any)) {
    636      rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE);
    637      EXPECT_EQ(SECSuccess, rv);
    638    }
    639  }
    640 }
    641 
    642 void TlsAgent::EnableSingleCipher(uint16_t cipher) {
    643  DisableAllCiphers();
    644  SECStatus rv = SSL_CipherPrefSet(ssl_fd(), cipher, PR_TRUE);
    645  EXPECT_EQ(SECSuccess, rv);
    646 }
    647 
    648 void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) {
    649  EXPECT_TRUE(EnsureTlsSetup());
    650  SECStatus rv = SSL_NamedGroupConfig(ssl_fd(), &groups[0], groups.size());
    651  EXPECT_EQ(SECSuccess, rv);
    652 }
    653 
// Sets the SSL_ENABLE_0RTT_DATA option to PR_TRUE or PR_FALSE according to
// |en|.
void TlsAgent::Set0RttEnabled(bool en) {
  SetOption(SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
}
    657 
    658 void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
    659  vrange_.min = minver;
    660  vrange_.max = maxver;
    661 
    662  if (ssl_fd()) {
    663    SECStatus rv = SSL_VersionRangeSet(ssl_fd(), &vrange_);
    664    EXPECT_EQ(SECSuccess, rv);
    665  }
    666 }
    667 
    668 SECStatus ResumptionTokenCallback(PRFileDesc* fd,
    669                                  const PRUint8* resumptionToken,
    670                                  unsigned int len, void* ctx) {
    671  EXPECT_NE(nullptr, resumptionToken);
    672  if (!resumptionToken) {
    673    return SECFailure;
    674  }
    675 
    676  std::vector<uint8_t> new_token(resumptionToken, resumptionToken + len);
    677  reinterpret_cast<TlsAgent*>(ctx)->SetResumptionToken(new_token);
    678  reinterpret_cast<TlsAgent*>(ctx)->SetResumptionCallbackCalled();
    679  return SECSuccess;
    680 }
    681 
    682 void TlsAgent::SetResumptionTokenCallback() {
    683  EXPECT_TRUE(EnsureTlsSetup());
    684  SECStatus rv =
    685      SSL_SetResumptionTokenCallback(ssl_fd(), ResumptionTokenCallback, this);
    686  EXPECT_EQ(SECSuccess, rv);
    687 }
    688 
// Copies the stored version range into |*minver| and |*maxver|. This reads
// the agent's own record (vrange_), not the socket.
void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) {
  *minver = vrange_.min;
  *maxver = vrange_.max;
}
    693 
// Records |ver| as the protocol version the checks should expect.
void TlsAgent::SetExpectedVersion(uint16_t ver) { expected_version_ = ver; }
    695 
// Overrides the recorded server key strength with |bits|.
void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; }
    697 
// Sets the flag recording that a read/write error is expected.
void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; }
    699 
// Sets skip_version_checks_, which makes EnsureTlsSetup() skip applying the
// stored version range to the socket.
void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; }
    701 
    702 void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes,
    703                                   size_t count) {
    704  EXPECT_TRUE(EnsureTlsSetup());
    705  EXPECT_LE(count, SSL_SignatureMaxCount());
    706  EXPECT_EQ(SECSuccess,
    707            SSL_SignatureSchemePrefSet(ssl_fd(), schemes,
    708                                       static_cast<unsigned int>(count)));
    709  EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, 0))
    710      << "setting no schemes should fail and do nothing";
    711 
    712  std::vector<SSLSignatureScheme> configuredSchemes(count);
    713  unsigned int configuredCount;
    714  EXPECT_EQ(SECFailure,
    715            SSL_SignatureSchemePrefGet(ssl_fd(), nullptr, &configuredCount, 1))
    716      << "get schemes, schemes is nullptr";
    717  EXPECT_EQ(SECFailure,
    718            SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0],
    719                                       &configuredCount, 0))
    720      << "get schemes, too little space";
    721  EXPECT_EQ(SECFailure,
    722            SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], nullptr,
    723                                       configuredSchemes.size()))
    724      << "get schemes, countOut is nullptr";
    725 
    726  EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet(
    727                            ssl_fd(), &configuredSchemes[0], &configuredCount,
    728                            configuredSchemes.size()));
    729  // SignatureSchemePrefSet drops unsupported algorithms silently, so the
    730  // number that are configured might be fewer.
    731  EXPECT_LE(configuredCount, count);
    732  unsigned int i = 0;
    733  for (unsigned int j = 0; j < count && i < configuredCount; ++j) {
    734    if (i < configuredCount && schemes[j] == configuredSchemes[i]) {
    735      ++i;
    736    }
    737  }
    738  EXPECT_EQ(i, configuredCount) << "schemes in use were all set";
    739 }
    740 
// Checks the negotiated key exchange against expectations.
//
// The agent must be connected and info_.keaType must equal |kea|. Unless
// |kea_group| is ssl_grp_ffdhe_custom, info_.keaKeyBits must equal |kea_size|
// and info_.keaGroup must equal |kea_group|.
//
// A |kea_size| of 0 means "derive the size from the group". Groups without an
// entry in the table below fall through to the default label, which uses
// server_key_bits_ for RSA key exchange and is a test failure otherwise.
void TlsAgent::CheckKEA(SSLKEAType kea, SSLNamedGroup kea_group,
                        size_t kea_size) const {
  EXPECT_EQ(STATE_CONNECTED, state_);
  EXPECT_EQ(kea, info_.keaType);
  if (kea_size == 0) {
    switch (kea_group) {
      case ssl_grp_ec_curve25519:
      case ssl_grp_kem_xyber768d00:
      case ssl_grp_kem_mlkem768x25519:
        kea_size = 255;
        break;
      case ssl_grp_kem_secp256r1mlkem768:
      case ssl_grp_ec_secp256r1:
        kea_size = 256;
        break;
      case ssl_grp_kem_secp384r1mlkem1024:
      case ssl_grp_ec_secp384r1:
        kea_size = 384;
        break;
      case ssl_grp_ffdhe_2048:
        kea_size = 2048;
        break;
      case ssl_grp_ffdhe_3072:
        kea_size = 3072;
        break;
      case ssl_grp_ffdhe_custom:
        // Size stays 0; custom groups are not compared below.
        break;
      default:
        if (kea == ssl_kea_rsa) {
          kea_size = server_key_bits_;
        } else {
          EXPECT_TRUE(false) << "need to update group sizes";
        }
    }
  }
  if (kea_group != ssl_grp_ffdhe_custom) {
    EXPECT_EQ(kea_size, info_.keaKeyBits);
    EXPECT_EQ(kea_group, info_.keaGroup);
  }
}
    781 
    782 void TlsAgent::CheckOriginalKEA(SSLNamedGroup kea_group) const {
    783  if (kea_group != ssl_grp_ffdhe_custom) {
    784    EXPECT_EQ(kea_group, info_.originalKeaGroup);
    785  }
    786 }
    787 
    788 void TlsAgent::CheckAuthType(SSLAuthType auth,
    789                             SSLSignatureScheme sig_scheme) const {
    790  EXPECT_EQ(STATE_CONNECTED, state_);
    791  EXPECT_EQ(auth, info_.authType);
    792  if (auth != ssl_auth_psk) {
    793    EXPECT_EQ(server_key_bits_, info_.authKeyBits);
    794  }
    795  if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) {
    796    switch (auth) {
    797      case ssl_auth_rsa_sign:
    798        sig_scheme = ssl_sig_rsa_pkcs1_sha1md5;
    799        break;
    800      case ssl_auth_ecdsa:
    801        sig_scheme = ssl_sig_ecdsa_sha1;
    802        break;
    803      default:
    804        break;
    805    }
    806  }
    807  EXPECT_EQ(sig_scheme, info_.signatureScheme);
    808 
    809  if (info_.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) {
    810    return;
    811  }
    812 
    813  // Check authAlgorithm, which is the old value for authType.  This is a second
    814  // switch statement because default label is different.
    815  switch (auth) {
    816    case ssl_auth_rsa_sign:
    817    case ssl_auth_rsa_pss:
    818      EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
    819          << "authAlgorithm for RSA is always decrypt";
    820      break;
    821    case ssl_auth_ecdh_rsa:
    822      EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
    823          << "authAlgorithm for ECDH_RSA is RSA decrypt (i.e., wrong)";
    824      break;
    825    case ssl_auth_ecdh_ecdsa:
    826      EXPECT_EQ(ssl_auth_ecdsa, csinfo_.authAlgorithm)
    827          << "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)";
    828      break;
    829    default:
    830      EXPECT_EQ(auth, csinfo_.authAlgorithm)
    831          << "authAlgorithm is (usually) the same as authType";
    832      break;
    833  }
    834 }
    835 
    836 void TlsAgent::EnableFalseStart() {
    837  EXPECT_TRUE(EnsureTlsSetup());
    838 
    839  falsestart_enabled_ = true;
    840  EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback(
    841                            ssl_fd(), CanFalseStartCallback, this));
    842  SetOption(SSL_ENABLE_FALSE_START, PR_TRUE);
    843 }
    844 
    845 void TlsAgent::ExpectEch(bool expected) { expect_ech_ = expected; }
    846 
    847 void TlsAgent::ExpectPsk(SSLPskType psk) { expect_psk_ = psk; }
    848 
    849 void TlsAgent::ExpectResumption() { expect_psk_ = ssl_psk_resume; }
    850 
    851 void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
    852  EXPECT_TRUE(EnsureTlsSetup());
    853  EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len));
    854 }
    855 
    856 void TlsAgent::AddPsk(const ScopedPK11SymKey& psk, std::string label,
    857                      SSLHashType hash, uint16_t zeroRttSuite) {
    858  EXPECT_TRUE(EnsureTlsSetup());
    859  EXPECT_EQ(SECSuccess, SSL_AddExternalPsk0Rtt(
    860                            ssl_fd(), psk.get(),
    861                            reinterpret_cast<const uint8_t*>(label.data()),
    862                            label.length(), hash, zeroRttSuite, 1000));
    863 }
    864 
    865 void TlsAgent::RemovePsk(std::string label) {
    866  EXPECT_EQ(SECSuccess,
    867            SSL_RemoveExternalPsk(
    868                ssl_fd(), reinterpret_cast<const uint8_t*>(label.data()),
    869                label.length()));
    870 }
    871 
    872 void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
    873                         const std::string& expected) const {
    874  SSLNextProtoState alpn_state;
    875  char chosen[10];
    876  unsigned int chosen_len;
    877  SECStatus rv = SSL_GetNextProto(ssl_fd(), &alpn_state,
    878                                  reinterpret_cast<unsigned char*>(chosen),
    879                                  &chosen_len, sizeof(chosen));
    880  EXPECT_EQ(SECSuccess, rv);
    881  EXPECT_EQ(expected_state, alpn_state);
    882  if (alpn_state == SSL_NEXT_PROTO_NO_SUPPORT) {
    883    EXPECT_EQ("", expected);
    884  } else {
    885    EXPECT_NE("", expected);
    886    EXPECT_EQ(expected, std::string(chosen, chosen_len));
    887  }
    888 }
    889 
    890 void TlsAgent::CheckEpochs(uint16_t expected_read,
    891                           uint16_t expected_write) const {
    892  uint16_t read_epoch = 0;
    893  uint16_t write_epoch = 0;
    894  EXPECT_EQ(SECSuccess,
    895            SSL_GetCurrentEpoch(ssl_fd(), &read_epoch, &write_epoch));
    896  EXPECT_EQ(expected_read, read_epoch) << role_str() << " read epoch";
    897  EXPECT_EQ(expected_write, write_epoch) << role_str() << " write epoch";
    898 }
    899 
    900 void TlsAgent::EnableSrtp() {
    901  EXPECT_TRUE(EnsureTlsSetup());
    902  const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80,
    903                              SRTP_AES128_CM_HMAC_SHA1_32};
    904  EXPECT_EQ(SECSuccess,
    905            SSL_SetSRTPCiphers(ssl_fd(), ciphers, PR_ARRAY_SIZE(ciphers)));
    906 }
    907 
    908 void TlsAgent::CheckSrtp() const {
    909  uint16_t actual;
    910  EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd(), &actual));
    911  EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
    912 }
    913 
    914 void TlsAgent::CheckErrorCode(int32_t expected) const {
    915  EXPECT_EQ(STATE_ERROR, state_);
    916  EXPECT_EQ(expected, error_code_)
    917      << "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
    918      << PORT_ErrorToName(expected) << std::endl;
    919 }
    920 
    921 static uint8_t GetExpectedAlertLevel(uint8_t alert) {
    922  if (alert == kTlsAlertCloseNotify) {
    923    return kTlsAlertWarning;
    924  }
    925  return kTlsAlertFatal;
    926 }
    927 
    928 void TlsAgent::ExpectReceiveAlert(uint8_t alert, uint8_t level) {
    929  expected_received_alert_ = alert;
    930  if (level == 0) {
    931    expected_received_alert_level_ = GetExpectedAlertLevel(alert);
    932  } else {
    933    expected_received_alert_level_ = level;
    934  }
    935 }
    936 
    937 void TlsAgent::ExpectSendAlert(uint8_t alert, uint8_t level) {
    938  expected_sent_alert_ = alert;
    939  if (level == 0) {
    940    expected_sent_alert_level_ = GetExpectedAlertLevel(alert);
    941  } else {
    942    expected_sent_alert_level_ = level;
    943  }
    944 }
    945 
    946 void TlsAgent::CheckAlert(bool sent, const SSLAlert* alert) {
    947  LOG(((alert->level == kTlsAlertWarning) ? "Warning" : "Fatal")
    948      << " alert " << (sent ? "sent" : "received") << ": "
    949      << static_cast<int>(alert->description));
    950 
    951  auto& expected = sent ? expected_sent_alert_ : expected_received_alert_;
    952  auto& expected_level =
    953      sent ? expected_sent_alert_level_ : expected_received_alert_level_;
    954  /* Silently pass close_notify in case the test has already ended. */
    955  if (expected == kTlsAlertCloseNotify && expected_level == kTlsAlertWarning &&
    956      alert->description == expected && alert->level == expected_level) {
    957    return;
    958  }
    959 
    960  EXPECT_EQ(expected, alert->description);
    961  EXPECT_EQ(expected_level, alert->level);
    962  expected = kTlsAlertCloseNotify;
    963  expected_level = kTlsAlertWarning;
    964 }
    965 
    966 void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const {
    967  ASSERT_EQ(0, error_code_);
    968  WAIT_(error_code_ != 0, delay);
    969  EXPECT_EQ(expected, error_code_)
    970      << "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
    971      << PORT_ErrorToName(expected) << std::endl;
    972 }
    973 
    974 void TlsAgent::CheckPreliminaryInfo() {
    975  SSLPreliminaryChannelInfo preinfo;
    976  EXPECT_EQ(SECSuccess,
    977            SSL_GetPreliminaryChannelInfo(ssl_fd(), &preinfo, sizeof(preinfo)));
    978  EXPECT_EQ(sizeof(preinfo), preinfo.length);
    979  EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version);
    980 
    981  // A version of 0 is invalid and indicates no expectation.  This value is
    982  // initialized to 0 so that tests that don't explicitly set an expected
    983  // version can negotiate a version.
    984  if (!expected_version_) {
    985    expected_version_ = preinfo.protocolVersion;
    986  }
    987  EXPECT_EQ(expected_version_, preinfo.protocolVersion);
    988 
    989  // As with the version; 0 is the null cipher suite (and also invalid).
    990  if (!expected_cipher_suite_) {
    991    expected_cipher_suite_ = preinfo.cipherSuite;
    992  }
    993  EXPECT_EQ(expected_cipher_suite_, preinfo.cipherSuite);
    994 }
    995 
    996 // Check that all the expected callbacks have been called.
    997 void TlsAgent::CheckCallbacks() const {
    998  // If false start happens, the handshake is reported as being complete at the
    999  // point that false start happens.
   1000  if (expect_psk_ == ssl_psk_resume || !falsestart_enabled_) {
   1001    EXPECT_TRUE(handshake_callback_called_);
   1002  }
   1003 
   1004  // These callbacks shouldn't fire if we are resuming, except on TLS 1.3.
   1005  if (role_ == SERVER) {
   1006    PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd(), ssl_server_name_xtn);
   1007    EXPECT_EQ(((expect_psk_ != ssl_psk_resume && have_sni) ||
   1008               expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3),
   1009              sni_hook_called_);
   1010  } else {
   1011    EXPECT_EQ(expect_psk_ == ssl_psk_none, auth_certificate_hook_called_);
   1012    // Note that this isn't unconditionally called, even with false start on.
   1013    // But the callback is only skipped if a cipher that is ridiculously weak
   1014    // (80 bits) is chosen.  Don't test that: plan to remove bad ciphers.
   1015    EXPECT_EQ(falsestart_enabled_ && expect_psk_ != ssl_psk_resume,
   1016              can_falsestart_hook_called_);
   1017  }
   1018 }
   1019 
   1020 void TlsAgent::ResetPreliminaryInfo() {
   1021  expected_version_ = 0;
   1022  expected_cipher_suite_ = 0;
   1023 }
   1024 
   1025 void TlsAgent::UpdatePreliminaryChannelInfo() {
   1026  SECStatus rv =
   1027      SSL_GetPreliminaryChannelInfo(ssl_fd(), &pre_info_, sizeof(pre_info_));
   1028  EXPECT_EQ(SECSuccess, rv);
   1029  EXPECT_EQ(sizeof(pre_info_), pre_info_.length);
   1030 }
   1031 
   1032 void TlsAgent::ValidateCipherSpecs() {
   1033  PRInt32 cipherSpecs = SSLInt_CountCipherSpecs(ssl_fd());
   1034  // We use one ciphersuite in each direction.
   1035  PRInt32 expected = 2;
   1036  if (variant_ == ssl_variant_datagram) {
   1037    // For DTLS 1.3, the client retains the cipher spec for early data and the
   1038    // handshake so that it can retransmit EndOfEarlyData and its final flight.
   1039    // It also retains the handshake read cipher spec so that it can read ACKs
   1040    // from the server. The server retains the handshake read cipher spec so it
   1041    // can read the client's retransmitted Finished.
   1042    if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
   1043      if (role_ == CLIENT) {
   1044        expected = info_.earlyDataAccepted ? 5 : 4;
   1045      } else {
   1046        expected = 3;
   1047      }
   1048    } else {
   1049      // For DTLS 1.1 and 1.2, the last endpoint to send maintains a cipher spec
   1050      // until the holddown timer runs down.
   1051      if (expect_psk_ == ssl_psk_resume) {
   1052        if (role_ == CLIENT) {
   1053          expected = 3;
   1054        }
   1055      } else {
   1056        if (role_ == SERVER) {
   1057          expected = 3;
   1058        }
   1059      }
   1060    }
   1061  }
   1062  // This function will be run before the handshake completes if false start is
   1063  // enabled.  In that case, the client will still be reading cleartext, but
   1064  // will have a spec prepared for reading ciphertext.  With DTLS, the client
   1065  // will also have a spec retained for retransmission of handshake messages.
   1066  if (role_ == CLIENT && falsestart_enabled_ && !handshake_callback_called_) {
   1067    EXPECT_GT(SSL_LIBRARY_VERSION_TLS_1_3, expected_version_);
   1068    expected = (variant_ == ssl_variant_datagram) ? 4 : 3;
   1069  }
   1070  EXPECT_EQ(expected, cipherSpecs);
   1071  if (expected != cipherSpecs) {
   1072    SSLInt_PrintCipherSpecs(role_str().c_str(), ssl_fd());
   1073  }
   1074 }
   1075 
   1076 void TlsAgent::Connected() {
   1077  if (state_ == STATE_CONNECTED) {
   1078    return;
   1079  }
   1080 
   1081  LOG("Handshake success");
   1082  CheckPreliminaryInfo();
   1083  CheckCallbacks();
   1084 
   1085  SECStatus rv = SSL_GetChannelInfo(ssl_fd(), &info_, sizeof(info_));
   1086  EXPECT_EQ(SECSuccess, rv);
   1087  EXPECT_EQ(sizeof(info_), info_.length);
   1088 
   1089  EXPECT_EQ(expect_psk_ == ssl_psk_resume, info_.resumed == PR_TRUE);
   1090  EXPECT_EQ(expect_psk_, info_.pskType);
   1091  EXPECT_EQ(expect_ech_, info_.echAccepted);
   1092 
   1093  // Preliminary values are exposed through callbacks during the handshake.
   1094  // If either expected values were set or the callbacks were called, check
   1095  // that the final values are correct.
   1096  UpdatePreliminaryChannelInfo();
   1097  EXPECT_EQ(expected_version_, info_.protocolVersion);
   1098  EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite);
   1099 
   1100  rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_));
   1101  EXPECT_EQ(SECSuccess, rv);
   1102  EXPECT_EQ(sizeof(csinfo_), csinfo_.length);
   1103 
   1104  ValidateCipherSpecs();
   1105 
   1106  SetState(STATE_CONNECTED);
   1107 }
   1108 
   1109 void TlsAgent::CheckClientAuthCompleted(uint8_t handshakes) {
   1110  EXPECT_FALSE(client_auth_callback_awaiting_);
   1111  switch (client_auth_callback_type_) {
   1112    case ClientAuthCallbackType::kNone:
   1113      if (!client_auth_callback_success_) {
   1114        EXPECT_TRUE(CheckClientAuthCallbacksCompleted(0));
   1115        break;
   1116      }
   1117    case ClientAuthCallbackType::kSync:
   1118      EXPECT_TRUE(CheckClientAuthCallbacksCompleted(handshakes));
   1119      break;
   1120    case ClientAuthCallbackType::kAsyncDelay:
   1121    case ClientAuthCallbackType::kAsyncImmediate:
   1122      EXPECT_TRUE(CheckClientAuthCallbacksCompleted(2 * handshakes));
   1123      break;
   1124  }
   1125 }
   1126 
   1127 void TlsAgent::EnableExtendedMasterSecret() {
   1128  SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
   1129 }
   1130 
   1131 void TlsAgent::CheckExtendedMasterSecret(bool expected) {
   1132  if (version() >= SSL_LIBRARY_VERSION_TLS_1_3) {
   1133    expected = PR_TRUE;
   1134  }
   1135  ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE)
   1136      << "unexpected extended master secret state for " << name_;
   1137 }
   1138 
   1139 void TlsAgent::CheckEarlyDataAccepted(bool expected) {
   1140  if (version() < SSL_LIBRARY_VERSION_TLS_1_3) {
   1141    expected = false;
   1142  }
   1143  ASSERT_EQ(expected, info_.earlyDataAccepted != PR_FALSE)
   1144      << "unexpected early data state for " << name_;
   1145 }
   1146 
   1147 void TlsAgent::CheckSecretsDestroyed() {
   1148  ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd()));
   1149 }
   1150 
   1151 void TlsAgent::SetDowngradeCheckVersion(uint16_t ver) {
   1152  ASSERT_TRUE(EnsureTlsSetup());
   1153 
   1154  SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), ver);
   1155  ASSERT_EQ(SECSuccess, rv);
   1156 }
   1157 
   1158 void TlsAgent::Handshake() {
   1159  LOGV("Handshake");
   1160  SECStatus rv = SSL_ForceHandshake(ssl_fd());
   1161  if (client_auth_callback_awaiting_) {
   1162    ClientAuthCallbackComplete();
   1163    rv = SSL_ForceHandshake(ssl_fd());
   1164  }
   1165  if (rv == SECSuccess) {
   1166    Connected();
   1167    Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
   1168                             &TlsAgent::ReadableCallback);
   1169    return;
   1170  }
   1171 
   1172  int32_t err = PR_GetError();
   1173  if (err == PR_WOULD_BLOCK_ERROR) {
   1174    LOGV("Would have blocked");
   1175    if (variant_ == ssl_variant_datagram) {
   1176      if (timer_handle_) {
   1177        timer_handle_->Cancel();
   1178        timer_handle_ = nullptr;
   1179      }
   1180 
   1181      PRIntervalTime timeout;
   1182      rv = DTLS_GetHandshakeTimeout(ssl_fd(), &timeout);
   1183      if (rv == SECSuccess) {
   1184        Poller::Instance()->SetTimer(
   1185            timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_);
   1186      }
   1187    }
   1188    Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
   1189                             &TlsAgent::ReadableCallback);
   1190    return;
   1191  }
   1192 
   1193  if (err != 0) {
   1194    LOG("Handshake failed with error " << PORT_ErrorToName(err) << ": "
   1195                                       << PORT_ErrorToString(err));
   1196  }
   1197 
   1198  error_code_ = err;
   1199  SetState(STATE_ERROR);
   1200 }
   1201 
   1202 void TlsAgent::PrepareForRenegotiate() {
   1203  EXPECT_EQ(STATE_CONNECTED, state_);
   1204 
   1205  SetState(STATE_CONNECTING);
   1206 }
   1207 
   1208 void TlsAgent::StartRenegotiate() {
   1209  PrepareForRenegotiate();
   1210 
   1211  SECStatus rv = SSL_ReHandshake(ssl_fd(), PR_TRUE);
   1212  EXPECT_EQ(SECSuccess, rv);
   1213 }
   1214 
   1215 void TlsAgent::SendDirect(const DataBuffer& buf) {
   1216  LOG("Send Direct " << buf);
   1217  auto peer = adapter_->peer().lock();
   1218  if (peer) {
   1219    peer->PacketReceived(buf);
   1220  } else {
   1221    LOG("Send Direct peer absent");
   1222  }
   1223 }
   1224 
   1225 void TlsAgent::SendRecordDirect(const TlsRecord& record) {
   1226  DataBuffer buf;
   1227 
   1228  auto rv = record.header.Write(&buf, 0, record.buffer);
   1229  EXPECT_EQ(record.header.header_length() + record.buffer.len(), rv);
   1230  SendDirect(buf);
   1231 }
   1232 
   1233 static bool ErrorIsFatal(PRErrorCode code) {
   1234  return code != PR_WOULD_BLOCK_ERROR && code != SSL_ERROR_RX_SHORT_DTLS_READ;
   1235 }
   1236 
   1237 void TlsAgent::SendData(size_t bytes, size_t blocksize) {
   1238  uint8_t block[16385];  // One larger than the maximum record size.
   1239 
   1240  ASSERT_LE(blocksize, sizeof(block));
   1241 
   1242  while (bytes) {
   1243    size_t tosend = std::min(blocksize, bytes);
   1244 
   1245    for (size_t i = 0; i < tosend; ++i) {
   1246      block[i] = 0xff & send_ctr_;
   1247      ++send_ctr_;
   1248    }
   1249 
   1250    SendBuffer(DataBuffer(block, tosend));
   1251    bytes -= tosend;
   1252  }
   1253 }
   1254 
   1255 void TlsAgent::SendBuffer(const DataBuffer& buf) {
   1256  LOGV("Writing " << buf.len() << " bytes");
   1257  int32_t rv = PR_Write(ssl_fd(), buf.data(), buf.len());
   1258  if (expect_readwrite_error_) {
   1259    EXPECT_GT(0, rv);
   1260    EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_);
   1261    error_code_ = PR_GetError();
   1262    expect_readwrite_error_ = false;
   1263  } else {
   1264    ASSERT_EQ(buf.len(), static_cast<size_t>(rv));
   1265  }
   1266 }
   1267 
   1268 bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
   1269                                   uint64_t seq, uint8_t ct,
   1270                                   const DataBuffer& buf) {
   1271  // Ensure that we are doing TLS 1.3.
   1272  EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3);
   1273  if (variant_ != ssl_variant_datagram) {
   1274    ADD_FAILURE();
   1275    return false;
   1276  }
   1277 
   1278  LOGV("Encrypting " << buf.len() << " bytes");
   1279  uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
   1280                      kCtDtlsCiphertextLengthPresent;
   1281  TlsRecordHeader header(variant_, expected_version_, dtls13_ct, seq);
   1282  TlsRecordHeader out_header(header);
   1283  DataBuffer padded = buf;
   1284  padded.Write(padded.len(), ct, 1);
   1285  DataBuffer ciphertext;
   1286  if (!spec->Protect(header, padded, &ciphertext, &out_header)) {
   1287    return false;
   1288  }
   1289 
   1290  DataBuffer record;
   1291  auto rv = out_header.Write(&record, 0, ciphertext);
   1292  EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv);
   1293  SendDirect(record);
   1294  return true;
   1295 }
   1296 
   1297 void TlsAgent::ReadBytes(size_t amount) {
   1298  uint8_t block[16384];
   1299 
   1300  size_t remaining = amount;
   1301  while (remaining > 0) {
   1302    int32_t rv = PR_Read(ssl_fd(), block, (std::min)(amount, sizeof(block)));
   1303    LOGV("ReadBytes " << rv);
   1304 
   1305    if (rv > 0) {
   1306      size_t count = static_cast<size_t>(rv);
   1307      for (size_t i = 0; i < count; ++i) {
   1308        ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
   1309        recv_ctr_++;
   1310      }
   1311      remaining -= rv;
   1312    } else {
   1313      PRErrorCode err = 0;
   1314      if (rv < 0) {
   1315        err = PR_GetError();
   1316        if (err != 0) {
   1317          LOG("Read error " << PORT_ErrorToName(err) << ": "
   1318                            << PORT_ErrorToString(err));
   1319        }
   1320        if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) {
   1321          if (ErrorIsFatal(err)) {
   1322            SetState(STATE_ERROR);
   1323          }
   1324          error_code_ = err;
   1325          expect_readwrite_error_ = false;
   1326        }
   1327      }
   1328      if (err != 0 && ErrorIsFatal(err)) {
   1329        // If we hit a fatal error, we're done.
   1330        remaining = 0;
   1331      }
   1332      break;
   1333    }
   1334  }
   1335 
   1336  // If closed, then don't bother waiting around.
   1337  if (remaining) {
   1338    LOGV("Re-arming");
   1339    Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
   1340                             &TlsAgent::ReadableCallback);
   1341  }
   1342 }
   1343 
   1344 void TlsAgent::ResetSentBytes(size_t bytes) { send_ctr_ = bytes; }
   1345 
   1346 void TlsAgent::SetOption(int32_t option, int value) {
   1347  ASSERT_TRUE(EnsureTlsSetup());
   1348  EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), option, value));
   1349 }
   1350 
   1351 void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
   1352  SetOption(SSL_NO_CACHE, mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE);
   1353  SetOption(SSL_ENABLE_SESSION_TICKETS,
   1354            mode & RESUME_TICKET ? PR_TRUE : PR_FALSE);
   1355 }
   1356 
   1357 void TlsAgent::EnableECDHEServerKeyReuse() {
   1358  ASSERT_EQ(TlsAgent::SERVER, role_);
   1359  SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_TRUE);
   1360 }
   1361 
   1362 static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"};
   1363 ::testing::internal::ParamGenerator<std::string>
   1364    TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr);
   1365 
   1366 void TlsAgentTestBase::SetUp() {
   1367  SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
   1368 }
   1369 
   1370 void TlsAgentTestBase::TearDown() {
   1371  agent_ = nullptr;
   1372  SSL_ClearSessionCache();
   1373  SSL_ShutdownServerSessionIDCache();
   1374 }
   1375 
   1376 void TlsAgentTestBase::Reset(const std::string& server_name) {
   1377  agent_.reset(
   1378      new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name,
   1379                   role_, variant_));
   1380  if (version_) {
   1381    agent_->SetVersionRange(version_, version_);
   1382  }
   1383  agent_->adapter()->SetPeer(sink_adapter_);
   1384  agent_->StartConnect();
   1385 }
   1386 
   1387 void TlsAgentTestBase::EnsureInit() {
   1388  if (!agent_) {
   1389    Reset();
   1390  }
   1391  const std::vector<SSLNamedGroup> groups = {
   1392      ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
   1393      ssl_grp_ffdhe_2048};
   1394  agent_->ConfigNamedGroups(groups);
   1395 }
   1396 
   1397 void TlsAgentTestBase::ExpectAlert(uint8_t alert) {
   1398  EnsureInit();
   1399  agent_->ExpectSendAlert(alert);
   1400 }
   1401 
   1402 void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
   1403                                      TlsAgent::State expected_state,
   1404                                      int32_t error_code) {
   1405  std::cerr << "Process message: " << buffer << std::endl;
   1406  EnsureInit();
   1407  agent_->adapter()->PacketReceived(buffer);
   1408  agent_->Handshake();
   1409 
   1410  ASSERT_EQ(expected_state, agent_->state());
   1411 
   1412  if (expected_state == TlsAgent::STATE_ERROR) {
   1413    ASSERT_EQ(error_code, agent_->error_code());
   1414  }
   1415 }
   1416 
   1417 void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type,
   1418                                  uint16_t version, const uint8_t* buf,
   1419                                  size_t len, DataBuffer* out,
   1420                                  uint64_t sequence_number) {
   1421  // Fixup the content type for DTLSCiphertext
   1422  if (variant == ssl_variant_datagram &&
   1423      version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
   1424      type == ssl_ct_application_data) {
   1425    type = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
   1426           kCtDtlsCiphertextLengthPresent;
   1427  }
   1428 
   1429  size_t index = 0;
   1430  if (variant == ssl_variant_stream) {
   1431    index = out->Write(index, type, 1);
   1432    index = out->Write(index, version, 2);
   1433  } else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
   1434             (type & kCtDtlsCiphertextMask) == kCtDtlsCiphertext) {
   1435    uint32_t epoch = (sequence_number >> 48) & 0x3;
   1436    index = out->Write(index, type | epoch, 1);
   1437    uint32_t seqno = sequence_number & ((1ULL << 16) - 1);
   1438    index = out->Write(index, seqno, 2);
   1439  } else {
   1440    index = out->Write(index, type, 1);
   1441    index = out->Write(index, TlsVersionToDtlsVersion(version), 2);
   1442    index = out->Write(index, sequence_number >> 32, 4);
   1443    index = out->Write(index, sequence_number & PR_UINT32_MAX, 4);
   1444  }
   1445  index = out->Write(index, len, 2);
   1446  out->Write(index, buf, len);
   1447 }
   1448 
   1449 void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version,
   1450                                  const uint8_t* buf, size_t len,
   1451                                  DataBuffer* out, uint64_t seq_num) const {
   1452  MakeRecord(variant_, type, version, buf, len, out, seq_num);
   1453 }
   1454 
   1455 void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type,
   1456                                            const uint8_t* data, size_t hs_len,
   1457                                            DataBuffer* out,
   1458                                            uint64_t seq_num) const {
   1459  return MakeHandshakeMessageFragment(hs_type, data, hs_len, out, seq_num, 0,
   1460                                      0);
   1461 }
   1462 
   1463 void TlsAgentTestBase::MakeHandshakeMessageFragment(
   1464    uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out,
   1465    uint64_t seq_num, uint32_t fragment_offset,
   1466    uint32_t fragment_length) const {
   1467  size_t index = 0;
   1468  if (!fragment_length) fragment_length = hs_len;
   1469  index = out->Write(index, hs_type, 1);  // Handshake record type.
   1470  index = out->Write(index, hs_len, 3);   // Handshake length
   1471  if (variant_ == ssl_variant_datagram) {
   1472    index = out->Write(index, seq_num, 2);
   1473    index = out->Write(index, fragment_offset, 3);
   1474    index = out->Write(index, fragment_length, 3);
   1475  }
   1476  if (data) {
   1477    index = out->Write(index, data, fragment_length);
   1478  } else {
   1479    for (size_t i = 0; i < fragment_length; ++i) {
   1480      index = out->Write(index, 1, 1);
   1481    }
   1482  }
   1483 }
   1484 
   1485 void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type,
   1486                                                  size_t hs_len,
   1487                                                  DataBuffer* out) {
   1488  size_t index = 0;
   1489  index = out->Write(index, ssl_ct_handshake, 1);  // Content Type
   1490  index = out->Write(index, 3, 1);                 // Version high
   1491  index = out->Write(index, 1, 1);                 // Version low
   1492  index = out->Write(index, 4 + hs_len, 2);        // Length
   1493 
   1494  index = out->Write(index, hs_type, 1);  // Handshake record type.
   1495  index = out->Write(index, hs_len, 3);   // Handshake length
   1496  for (size_t i = 0; i < hs_len; ++i) {
   1497    index = out->Write(index, 1, 1);
   1498  }
   1499 }
   1500 
   1501 DataBuffer TlsAgentTestBase::MakeCannedTls13ServerHello() {
   1502  DataBuffer sh(kCannedTls13ServerHello, sizeof(kCannedTls13ServerHello));
   1503  if (variant_ == ssl_variant_datagram) {
   1504    sh.Write(0, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 2);
   1505    // The version should be at the end.
   1506    uint32_t v;
   1507    EXPECT_TRUE(sh.Read(sh.len() - 2, 2, &v));
   1508    EXPECT_EQ(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_3), v);
   1509    sh.Write(sh.len() - 2, SSL_LIBRARY_VERSION_DTLS_1_3_WIRE, 2);
   1510  }
   1511  return sh;
   1512 }
   1513 
   1514 }  // namespace nss_test