tls_agent.cc (50779B)
1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 2 /* vim: set ts=2 et sw=2 tw=80: */ 3 /* This Source Code Form is subject to the terms of the Mozilla Public 4 * License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 * You can obtain one at http://mozilla.org/MPL/2.0/. */ 6 7 #include "tls_agent.h" 8 #include "databuffer.h" 9 #include "keyhi.h" 10 #include "pk11func.h" 11 #include "ssl.h" 12 #include "sslerr.h" 13 #include "sslexp.h" 14 #include "sslproto.h" 15 #include "tls_filter.h" 16 #include "tls_parser.h" 17 18 extern "C" { 19 // This is not something that should make you happy. 20 #include "libssl_internals.h" 21 } 22 23 #define GTEST_HAS_RTTI 0 24 #include "gtest/gtest.h" 25 #include "gtest_utils.h" 26 #include "nss_scoped_ptrs.h" 27 28 extern std::string g_working_dir_path; 29 30 namespace nss_test { 31 32 const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"}; 33 34 const std::string TlsAgent::kClient = "client"; // both sign and encrypt 35 const std::string TlsAgent::kRsa2048 = "rsa2048"; // bigger 36 const std::string TlsAgent::kRsa8192 = "rsa8192"; // biggest allowed 37 const std::string TlsAgent::kServerRsa = "rsa"; // both sign and encrypt 38 const std::string TlsAgent::kServerRsaSign = "rsa_sign"; 39 const std::string TlsAgent::kServerRsaPss = "rsa_pss"; 40 const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt"; 41 const std::string TlsAgent::kServerEcdsa256 = "ecdsa256"; 42 const std::string TlsAgent::kServerEcdsa384 = "ecdsa384"; 43 const std::string TlsAgent::kServerEcdsa521 = "ecdsa521"; 44 const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa"; 45 const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa"; 46 const std::string TlsAgent::kServerDsa = "dsa"; 47 const std::string TlsAgent::kDelegatorEcdsa256 = "delegator_ecdsa256"; 48 const std::string TlsAgent::kDelegatorRsae2048 = "delegator_rsae2048"; 49 const std::string TlsAgent::kDelegatorRsaPss2048 = "delegator_rsa_pss2048"; 50 51 static const uint8_t kCannedTls13ServerHello[] = { 52 0x03, 0x03, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3, 53 0xf0, 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b, 54 0xdf, 0xe5, 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76, 55 0x08, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24, 56 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03, 57 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9, 58 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08, 59 0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04}; 60 61 TlsAgent::TlsAgent(const std::string& nm, Role rl, SSLProtocolVariant var) 62 : name_(nm), 63 variant_(var), 64 role_(rl), 65 server_key_bits_(0), 66 adapter_(new DummyPrSocket(role_str(), var)), 67 ssl_fd_(nullptr), 68 state_(STATE_INIT), 69 timer_handle_(nullptr), 70 falsestart_enabled_(false), 71 expected_version_(0), 72 expected_cipher_suite_(0), 73 expect_client_auth_(false), 74 expect_ech_(false), 75 expect_psk_(ssl_psk_none), 76 can_falsestart_hook_called_(false), 77 sni_hook_called_(false), 78 auth_certificate_hook_called_(false), 79 expected_received_alert_(kTlsAlertCloseNotify), 80 expected_received_alert_level_(kTlsAlertWarning), 81 expected_sent_alert_(kTlsAlertCloseNotify), 82 expected_sent_alert_level_(kTlsAlertWarning), 83 handshake_callback_called_(false), 84 resumption_callback_called_(false), 85 error_code_(0), 86 send_ctr_(0), 87 recv_ctr_(0), 88 expect_readwrite_error_(false), 89 handshake_callback_(), 90 auth_certificate_callback_(), 91 sni_callback_(), 92 skip_version_checks_(false), 93 resumption_token_(), 94 policy_() { 95 memset(&info_, 0, sizeof(info_)); 96 memset(&csinfo_, 0, sizeof(csinfo_)); 97 SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_); 98 EXPECT_EQ(SECSuccess, rv); 99 } 100 101 TlsAgent::~TlsAgent() { 102 if (timer_handle_) { 103 timer_handle_->Cancel(); 104 } 105 106 if (adapter_) { 107 Poller::Instance()->Cancel(READABLE_EVENT, adapter_); 108 } 109 110 // Add failures manually, if any, so we don't throw in a destructor. 111 if (expected_received_alert_ != kTlsAlertCloseNotify || 112 expected_received_alert_level_ != kTlsAlertWarning) { 113 ADD_FAILURE() << "Wrong expected_received_alert status: " << role_str(); 114 } 115 if (expected_sent_alert_ != kTlsAlertCloseNotify || 116 expected_sent_alert_level_ != kTlsAlertWarning) { 117 ADD_FAILURE() << "Wrong expected_sent_alert status: " << role_str(); 118 } 119 } 120 121 void TlsAgent::SetState(State s) { 122 if (state_ == s) return; 123 124 LOG("Changing state from " << state_ << " to " << s); 125 state_ = s; 126 } 127 128 /*static*/ bool TlsAgent::LoadCertificate(const std::string& name, 129 ScopedCERTCertificate* cert, 130 ScopedSECKEYPrivateKey* priv) { 131 cert->reset(PK11_FindCertFromNickname(name.c_str(), nullptr)); 132 EXPECT_NE(nullptr, cert); 133 if (!cert) return false; 134 EXPECT_NE(nullptr, cert->get()); 135 if (!cert->get()) return false; 136 137 priv->reset(PK11_FindKeyByAnyCert(cert->get(), nullptr)); 138 EXPECT_NE(nullptr, priv); 139 if (!priv) return false; 140 EXPECT_NE(nullptr, priv->get()); 141 if (!priv->get()) return false; 142 143 return true; 144 } 145 146 // Loads a key pair from the certificate identified by |id|. 147 /*static*/ bool TlsAgent::LoadKeyPairFromCert(const std::string& name, 148 ScopedSECKEYPublicKey* pub, 149 ScopedSECKEYPrivateKey* priv) { 150 ScopedCERTCertificate cert; 151 if (!TlsAgent::LoadCertificate(name, &cert, priv)) { 152 return false; 153 } 154 155 pub->reset(SECKEY_ExtractPublicKey(&cert->subjectPublicKeyInfo)); 156 if (!pub->get()) { 157 return false; 158 } 159 160 return true; 161 } 162 163 void TlsAgent::DelegateCredential(const std::string& name, 164 const ScopedSECKEYPublicKey& dc_pub, 165 SSLSignatureScheme dc_cert_verify_alg, 166 PRUint32 dc_valid_for, PRTime now, 167 SECItem* dc) { 168 ScopedCERTCertificate cert; 169 ScopedSECKEYPrivateKey cert_priv; 170 EXPECT_TRUE(TlsAgent::LoadCertificate(name, &cert, &cert_priv)) 171 << "Could not load delegate certificate: " << name 172 << "; test db corrupt?"; 173 174 EXPECT_EQ(SECSuccess, 175 SSL_DelegateCredential(cert.get(), cert_priv.get(), dc_pub.get(), 176 dc_cert_verify_alg, dc_valid_for, now, dc)); 177 } 178 179 void TlsAgent::EnableDelegatedCredentials() { 180 ASSERT_TRUE(EnsureTlsSetup()); 181 SetOption(SSL_ENABLE_DELEGATED_CREDENTIALS, PR_TRUE); 182 } 183 184 void TlsAgent::AddDelegatedCredential(const std::string& dc_name, 185 SSLSignatureScheme dc_cert_verify_alg, 186 PRUint32 dc_valid_for, PRTime now) { 187 ASSERT_TRUE(EnsureTlsSetup()); 188 189 ScopedSECKEYPublicKey pub; 190 ScopedSECKEYPrivateKey priv; 191 EXPECT_TRUE(TlsAgent::LoadKeyPairFromCert(dc_name, &pub, &priv)); 192 193 StackSECItem dc; 194 TlsAgent::DelegateCredential(name_, pub, dc_cert_verify_alg, dc_valid_for, 195 now, &dc); 196 197 SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr, 198 nullptr, &dc, priv.get()}; 199 EXPECT_TRUE(ConfigServerCert(name_, true, &extra_data)); 200 } 201 202 bool TlsAgent::ConfigServerCert(const std::string& id, bool updateKeyBits, 203 const SSLExtraServerCertData* serverCertData) { 204 ScopedCERTCertificate cert; 205 ScopedSECKEYPrivateKey priv; 206 if (!TlsAgent::LoadCertificate(id, &cert, &priv)) { 207 return false; 208 } 209 210 if (updateKeyBits) { 211 ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get())); 212 EXPECT_NE(nullptr, pub.get()); 213 if (!pub.get()) return false; 214 server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get()); 215 } 216 217 SECStatus rv = 218 SSL_ConfigSecureServer(ssl_fd(), nullptr, nullptr, ssl_kea_null); 219 EXPECT_EQ(SECFailure, rv); 220 rv = SSL_ConfigServerCert(ssl_fd(), cert.get(), priv.get(), serverCertData, 221 serverCertData ? sizeof(*serverCertData) : 0); 222 return rv == SECSuccess; 223 } 224 225 bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { 226 // Don't set up twice 227 if (ssl_fd_) return true; 228 NssManagePolicy policyManage(policy_, option_); 229 230 ScopedPRFileDesc dummy_fd(adapter_->CreateFD()); 231 EXPECT_NE(nullptr, dummy_fd); 232 if (!dummy_fd) { 233 return false; 234 } 235 if (adapter_->variant() == ssl_variant_stream) { 236 ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get())); 237 } else { 238 ssl_fd_.reset(DTLS_ImportFD(modelSocket, dummy_fd.get())); 239 } 240 241 EXPECT_NE(nullptr, ssl_fd_); 242 if (!ssl_fd_) { 243 return false; 244 } 245 dummy_fd.release(); // Now subsumed by ssl_fd_. 246 247 SECStatus rv; 248 if (!skip_version_checks_) { 249 rv = SSL_VersionRangeSet(ssl_fd(), &vrange_); 250 EXPECT_EQ(SECSuccess, rv); 251 if (rv != SECSuccess) return false; 252 } 253 254 ScopedCERTCertList anchors(CERT_NewCertList()); 255 rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get()); 256 if (rv != SECSuccess) return false; 257 258 if (role_ == SERVER) { 259 EXPECT_TRUE(ConfigServerCert(name_, true)); 260 261 rv = SSL_SNISocketConfigHook(ssl_fd(), SniHook, this); 262 EXPECT_EQ(SECSuccess, rv); 263 if (rv != SECSuccess) return false; 264 265 rv = SSL_SetMaxEarlyDataSize(ssl_fd(), 1024); 266 EXPECT_EQ(SECSuccess, rv); 267 if (rv != SECSuccess) return false; 268 } else { 269 rv = SSL_SetURL(ssl_fd(), "server"); 270 EXPECT_EQ(SECSuccess, rv); 271 if (rv != SECSuccess) return false; 272 } 273 274 rv = SSL_AuthCertificateHook(ssl_fd(), AuthCertificateHook, this); 275 EXPECT_EQ(SECSuccess, rv); 276 if (rv != SECSuccess) return false; 277 278 rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this); 279 EXPECT_EQ(SECSuccess, rv); 280 if (rv != SECSuccess) return false; 281 282 rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this); 283 EXPECT_EQ(SECSuccess, rv); 284 if (rv != SECSuccess) return false; 285 286 rv = SSL_HandshakeCallback(ssl_fd(), HandshakeCallback, this); 287 EXPECT_EQ(SECSuccess, rv); 288 if (rv != SECSuccess) return false; 289 290 // All these tests depend on having this disabled to start with. 291 SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_FALSE); 292 293 return true; 294 } 295 296 bool TlsAgent::MaybeSetResumptionToken() { 297 if (!resumption_token_.empty()) { 298 LOG("setting external resumption token"); 299 SECStatus rv = SSL_SetResumptionToken(ssl_fd(), resumption_token_.data(), 300 resumption_token_.size()); 301 302 // rv is SECFailure with error set to SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR 303 // if the resumption token was bad (expired/malformed/etc.). 304 if (expect_psk_ == ssl_psk_resume) { 305 // Only in case we expect resumption this has to be successful. We might 306 // not expect resumption due to some reason but the token is totally fine. 307 EXPECT_EQ(SECSuccess, rv); 308 } 309 if (rv != SECSuccess) { 310 EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError()); 311 resumption_token_.clear(); 312 EXPECT_FALSE(expect_psk_ == ssl_psk_resume); 313 if (expect_psk_ == ssl_psk_resume) return false; 314 } 315 } 316 317 return true; 318 } 319 320 void TlsAgent::SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx) { 321 EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(ssl_fd(), ctx.get())); 322 } 323 324 // Defaults to a Sync callback returning success 325 void TlsAgent::SetupClientAuth(ClientAuthCallbackType callbackType, 326 bool callbackSuccess) { 327 EXPECT_TRUE(EnsureTlsSetup()); 328 ASSERT_EQ(CLIENT, role_); 329 330 client_auth_callback_type_ = callbackType; 331 client_auth_callback_success_ = callbackSuccess; 332 333 if (callbackType == ClientAuthCallbackType::kNone && !callbackSuccess) { 334 // Don't set a callback for this case. 335 return; 336 } 337 EXPECT_EQ(SECSuccess, 338 SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook, 339 reinterpret_cast<void*>(this))); 340 } 341 342 void CheckCertReqAgainstDefaultCAs(const CERTDistNames* caNames) { 343 ScopedCERTDistNames expected(CERT_GetSSLCACerts(nullptr)); 344 345 ASSERT_EQ(expected->nnames, caNames->nnames); 346 347 for (size_t i = 0; i < static_cast<size_t>(expected->nnames); ++i) { 348 EXPECT_EQ(SECEqual, 349 SECITEM_CompareItem(&(expected->names[i]), &(caNames->names[i]))); 350 } 351 } 352 353 // Complete processing of Client Certificate Selection 354 // A No-op if the agent is using synchronous client cert selection. 355 // Otherwise, calls SSL_ClientCertCallbackComplete. 356 // kAsyncDelay triggers a call to SSL_ForceHandshake prior to completion to 357 // ensure that the socket is correctly blocked. 358 void TlsAgent::ClientAuthCallbackComplete() { 359 ASSERT_EQ(CLIENT, role_); 360 361 if (client_auth_callback_type_ != ClientAuthCallbackType::kAsyncDelay && 362 client_auth_callback_type_ != ClientAuthCallbackType::kAsyncImmediate) { 363 return; 364 } 365 client_auth_callback_fired_++; 366 EXPECT_TRUE(client_auth_callback_awaiting_); 367 368 std::cerr << "client: calling SSL_ClientCertCallbackComplete with status " 369 << (client_auth_callback_success_ ? "success" : "failed") 370 << std::endl; 371 372 client_auth_callback_awaiting_ = false; 373 374 if (client_auth_callback_type_ == ClientAuthCallbackType::kAsyncDelay) { 375 std::cerr 376 << "Running Handshake prior to running SSL_ClientCertCallbackComplete" 377 << std::endl; 378 SECStatus rv = SSL_ForceHandshake(ssl_fd()); 379 EXPECT_EQ(rv, SECFailure); 380 EXPECT_EQ(PORT_GetError(), PR_WOULD_BLOCK_ERROR); 381 } 382 383 ScopedCERTCertificate cert; 384 ScopedSECKEYPrivateKey priv; 385 if (client_auth_callback_success_) { 386 ASSERT_TRUE(TlsAgent::LoadCertificate(name(), &cert, &priv)); 387 EXPECT_EQ(SECSuccess, 388 SSL_ClientCertCallbackComplete(ssl_fd(), SECSuccess, 389 priv.release(), cert.release())); 390 } else { 391 EXPECT_EQ(SECSuccess, SSL_ClientCertCallbackComplete(ssl_fd(), SECFailure, 392 nullptr, nullptr)); 393 } 394 } 395 396 SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd, 397 CERTDistNames* caNames, 398 CERTCertificate** clientCert, 399 SECKEYPrivateKey** clientKey) { 400 TlsAgent* agent = reinterpret_cast<TlsAgent*>(self); 401 EXPECT_EQ(CLIENT, agent->role_); 402 agent->client_auth_callback_fired_++; 403 404 switch (agent->client_auth_callback_type_) { 405 case ClientAuthCallbackType::kAsyncDelay: 406 case ClientAuthCallbackType::kAsyncImmediate: 407 std::cerr << "Waiting for complete call" << std::endl; 408 agent->client_auth_callback_awaiting_ = true; 409 return SECWouldBlock; 410 case ClientAuthCallbackType::kSync: 411 case ClientAuthCallbackType::kNone: 412 // Handle the sync case. None && Success is treated as Sync and Success. 413 if (!agent->client_auth_callback_success_) { 414 return SECFailure; 415 } 416 ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd())); 417 EXPECT_TRUE(peerCert) << "Client should be able to see the server cert"; 418 419 // See bug 1573945 420 // CheckCertReqAgainstDefaultCAs(caNames); 421 422 ScopedCERTCertificate cert; 423 ScopedSECKEYPrivateKey priv; 424 if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) { 425 return SECFailure; 426 } 427 428 *clientCert = cert.release(); 429 *clientKey = priv.release(); 430 return SECSuccess; 431 } 432 /* This is unreachable, but some old compilers can't tell that. */ 433 PORT_Assert(0); 434 PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); 435 return SECFailure; 436 } 437 438 // Increments by 1 for each callback 439 bool TlsAgent::CheckClientAuthCallbacksCompleted(uint8_t expected) { 440 EXPECT_EQ(CLIENT, role_); 441 return expected == client_auth_callback_fired_; 442 } 443 444 bool TlsAgent::GetPeerChainLength(size_t* count) { 445 SECItemArray* chain = nullptr; 446 SECStatus rv = SSL_PeerCertificateChainDER(ssl_fd(), &chain); 447 if (rv != SECSuccess) return false; 448 449 *count = chain->len; 450 451 SECITEM_FreeArray(chain, true); 452 453 return true; 454 } 455 456 void TlsAgent::CheckPeerChainFunctionConsistency() { 457 SECItemArray* derChain = nullptr; 458 SECStatus rv = SSL_PeerCertificateChainDER(ssl_fd(), &derChain); 459 PRErrorCode err1 = PR_GetError(); 460 CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd()); 461 PRErrorCode err2 = PR_GetError(); 462 if (rv != SECSuccess) { 463 ASSERT_EQ(nullptr, chain); 464 ASSERT_EQ(nullptr, derChain); 465 ASSERT_EQ(err1, SSL_ERROR_NO_CERTIFICATE); 466 ASSERT_EQ(err2, SSL_ERROR_NO_CERTIFICATE); 467 return; 468 } 469 ASSERT_NE(nullptr, chain); 470 ASSERT_NE(nullptr, derChain); 471 472 unsigned int count = 0; 473 for (PRCList* cursor = PR_NEXT_LINK(&chain->list); 474 count < derChain->len && cursor != &chain->list; 475 cursor = PR_NEXT_LINK(cursor)) { 476 CERTCertListNode* node = (CERTCertListNode*)cursor; 477 EXPECT_TRUE( 478 SECITEM_ItemsAreEqual(&node->cert->derCert, &derChain->items[count])); 479 ++count; 480 } 481 ASSERT_EQ(count, derChain->len); 482 483 SECITEM_FreeArray(derChain, true); 484 CERT_DestroyCertList(chain); 485 } 486 487 void TlsAgent::CheckCipherSuite(uint16_t suite) { 488 EXPECT_EQ(csinfo_.cipherSuite, suite); 489 } 490 491 void TlsAgent::RequestClientAuth(bool requireAuth) { 492 ASSERT_EQ(SERVER, role_); 493 494 SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE); 495 SetOption(SSL_REQUIRE_CERTIFICATE, requireAuth ? PR_TRUE : PR_FALSE); 496 497 EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook( 498 ssl_fd(), &TlsAgent::ClientAuthenticated, this)); 499 expect_client_auth_ = true; 500 } 501 502 void TlsAgent::StartConnect(PRFileDesc* model) { 503 EXPECT_TRUE(EnsureTlsSetup(model)); 504 505 SECStatus rv; 506 rv = SSL_ResetHandshake(ssl_fd(), role_ == SERVER ? PR_TRUE : PR_FALSE); 507 EXPECT_EQ(SECSuccess, rv); 508 SetState(STATE_CONNECTING); 509 } 510 511 void TlsAgent::DisableAllCiphers() { 512 for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { 513 SECStatus rv = 514 SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_FALSE); 515 EXPECT_EQ(SECSuccess, rv); 516 } 517 } 518 519 // Not actually all groups, just the ones that we are actually willing 520 // to use. 521 const std::vector<SSLNamedGroup> kAllDHEGroups = { 522 ssl_grp_ec_curve25519, 523 ssl_grp_ec_secp256r1, 524 ssl_grp_ec_secp384r1, 525 ssl_grp_ec_secp521r1, 526 ssl_grp_ffdhe_2048, 527 ssl_grp_ffdhe_3072, 528 ssl_grp_ffdhe_4096, 529 ssl_grp_ffdhe_6144, 530 ssl_grp_ffdhe_8192, 531 #ifndef NSS_DISABLE_KYBER 532 ssl_grp_kem_xyber768d00, 533 #endif 534 ssl_grp_kem_mlkem768x25519, 535 ssl_grp_kem_secp256r1mlkem768, 536 ssl_grp_kem_secp384r1mlkem1024, 537 }; 538 539 const std::vector<SSLNamedGroup> kNonPQDHEGroups = { 540 ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, 541 ssl_grp_ec_secp521r1, ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, 542 ssl_grp_ffdhe_4096, ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192, 543 }; 544 545 const std::vector<SSLNamedGroup> kECDHEGroups = { 546 ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, 547 ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1, 548 #ifndef NSS_DISABLE_KYBER 549 ssl_grp_kem_xyber768d00, 550 #endif 551 ssl_grp_kem_mlkem768x25519, ssl_grp_kem_secp256r1mlkem768, 552 ssl_grp_kem_secp384r1mlkem1024, 553 }; 554 555 const std::vector<SSLNamedGroup> kFFDHEGroups = { 556 ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096, 557 ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192}; 558 559 // Defined because the big DHE groups are ridiculously slow. 560 const std::vector<SSLNamedGroup> kFasterDHEGroups = { 561 ssl_grp_ec_curve25519, 562 ssl_grp_ec_secp256r1, 563 ssl_grp_ec_secp384r1, 564 ssl_grp_ffdhe_2048, 565 ssl_grp_ffdhe_3072, 566 #ifndef NSS_DISABLE_KYBER 567 ssl_grp_kem_xyber768d00, 568 #endif 569 ssl_grp_kem_mlkem768x25519, 570 ssl_grp_kem_secp256r1mlkem768, 571 ssl_grp_kem_secp384r1mlkem1024, 572 }; 573 574 const std::vector<SSLNamedGroup> kEcdhHybridGroups = { 575 #ifndef NSS_DISABLE_KYBER 576 ssl_grp_kem_xyber768d00, 577 #endif 578 ssl_grp_kem_mlkem768x25519, 579 ssl_grp_kem_secp256r1mlkem768, 580 ssl_grp_kem_secp384r1mlkem1024, 581 }; 582 583 void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) { 584 EXPECT_TRUE(EnsureTlsSetup()); 585 586 for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { 587 SSLCipherSuiteInfo csinfo; 588 589 SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo, 590 sizeof(csinfo)); 591 ASSERT_EQ(SECSuccess, rv); 592 EXPECT_EQ(sizeof(csinfo), csinfo.length); 593 594 if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) { 595 rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE); 596 EXPECT_EQ(SECSuccess, rv); 597 } 598 } 599 } 600 601 void TlsAgent::EnableGroupsByKeyExchange(SSLKEAType kea) { 602 switch (kea) { 603 case ssl_kea_dh: 604 ConfigNamedGroups(kFFDHEGroups); 605 break; 606 case ssl_kea_ecdh: 607 ConfigNamedGroups(kECDHEGroups); 608 break; 609 case ssl_kea_ecdh_hybrid: 610 ConfigNamedGroups(kEcdhHybridGroups); 611 break; 612 default: 613 break; 614 } 615 } 616 617 void TlsAgent::EnableGroupsByAuthType(SSLAuthType authType) { 618 if (authType == ssl_auth_ecdh_rsa || authType == ssl_auth_ecdh_ecdsa || 619 authType == ssl_auth_ecdsa || authType == ssl_auth_tls13_any) { 620 ConfigNamedGroups(kECDHEGroups); 621 } 622 } 623 624 void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) { 625 EXPECT_TRUE(EnsureTlsSetup()); 626 627 for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { 628 SSLCipherSuiteInfo csinfo; 629 630 SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo, 631 sizeof(csinfo)); 632 ASSERT_EQ(SECSuccess, rv); 633 634 if ((csinfo.authType == authType) || 635 (csinfo.keaType == ssl_kea_tls13_any)) { 636 rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE); 637 EXPECT_EQ(SECSuccess, rv); 638 } 639 } 640 } 641 642 void TlsAgent::EnableSingleCipher(uint16_t cipher) { 643 DisableAllCiphers(); 644 SECStatus rv = SSL_CipherPrefSet(ssl_fd(), cipher, PR_TRUE); 645 EXPECT_EQ(SECSuccess, rv); 646 } 647 648 void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) { 649 EXPECT_TRUE(EnsureTlsSetup()); 650 SECStatus rv = SSL_NamedGroupConfig(ssl_fd(), &groups[0], groups.size()); 651 EXPECT_EQ(SECSuccess, rv); 652 } 653 654 void TlsAgent::Set0RttEnabled(bool en) { 655 SetOption(SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE); 656 } 657 658 void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) { 659 vrange_.min = minver; 660 vrange_.max = maxver; 661 662 if (ssl_fd()) { 663 SECStatus rv = SSL_VersionRangeSet(ssl_fd(), &vrange_); 664 EXPECT_EQ(SECSuccess, rv); 665 } 666 } 667 668 SECStatus ResumptionTokenCallback(PRFileDesc* fd, 669 const PRUint8* resumptionToken, 670 unsigned int len, void* ctx) { 671 EXPECT_NE(nullptr, resumptionToken); 672 if (!resumptionToken) { 673 return SECFailure; 674 } 675 676 std::vector<uint8_t> new_token(resumptionToken, resumptionToken + len); 677 reinterpret_cast<TlsAgent*>(ctx)->SetResumptionToken(new_token); 678 reinterpret_cast<TlsAgent*>(ctx)->SetResumptionCallbackCalled(); 679 return SECSuccess; 680 } 681 682 void TlsAgent::SetResumptionTokenCallback() { 683 EXPECT_TRUE(EnsureTlsSetup()); 684 SECStatus rv = 685 SSL_SetResumptionTokenCallback(ssl_fd(), ResumptionTokenCallback, this); 686 EXPECT_EQ(SECSuccess, rv); 687 } 688 689 void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) { 690 *minver = vrange_.min; 691 *maxver = vrange_.max; 692 } 693 694 void TlsAgent::SetExpectedVersion(uint16_t ver) { expected_version_ = ver; } 695 696 void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; } 697 698 void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; } 699 700 void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; } 701 702 void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes, 703 size_t count) { 704 EXPECT_TRUE(EnsureTlsSetup()); 705 EXPECT_LE(count, SSL_SignatureMaxCount()); 706 EXPECT_EQ(SECSuccess, 707 SSL_SignatureSchemePrefSet(ssl_fd(), schemes, 708 static_cast<unsigned int>(count))); 709 EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, 0)) 710 << "setting no schemes should fail and do nothing"; 711 712 std::vector<SSLSignatureScheme> configuredSchemes(count); 713 unsigned int configuredCount; 714 EXPECT_EQ(SECFailure, 715 SSL_SignatureSchemePrefGet(ssl_fd(), nullptr, &configuredCount, 1)) 716 << "get schemes, schemes is nullptr"; 717 EXPECT_EQ(SECFailure, 718 SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], 719 &configuredCount, 0)) 720 << "get schemes, too little space"; 721 EXPECT_EQ(SECFailure, 722 SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], nullptr, 723 configuredSchemes.size())) 724 << "get schemes, countOut is nullptr"; 725 726 EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet( 727 ssl_fd(), &configuredSchemes[0], &configuredCount, 728 configuredSchemes.size())); 729 // SignatureSchemePrefSet drops unsupported algorithms silently, so the 730 // number that are configured might be fewer. 731 EXPECT_LE(configuredCount, count); 732 unsigned int i = 0; 733 for (unsigned int j = 0; j < count && i < configuredCount; ++j) { 734 if (i < configuredCount && schemes[j] == configuredSchemes[i]) { 735 ++i; 736 } 737 } 738 EXPECT_EQ(i, configuredCount) << "schemes in use were all set"; 739 } 740 741 void TlsAgent::CheckKEA(SSLKEAType kea, SSLNamedGroup kea_group, 742 size_t kea_size) const { 743 EXPECT_EQ(STATE_CONNECTED, state_); 744 EXPECT_EQ(kea, info_.keaType); 745 if (kea_size == 0) { 746 switch (kea_group) { 747 case ssl_grp_ec_curve25519: 748 case ssl_grp_kem_xyber768d00: 749 case ssl_grp_kem_mlkem768x25519: 750 kea_size = 255; 751 break; 752 case ssl_grp_kem_secp256r1mlkem768: 753 case ssl_grp_ec_secp256r1: 754 kea_size = 256; 755 break; 756 case ssl_grp_kem_secp384r1mlkem1024: 757 case ssl_grp_ec_secp384r1: 758 kea_size = 384; 759 break; 760 case ssl_grp_ffdhe_2048: 761 kea_size = 2048; 762 break; 763 case ssl_grp_ffdhe_3072: 764 kea_size = 3072; 765 break; 766 case ssl_grp_ffdhe_custom: 767 break; 768 default: 769 if (kea == ssl_kea_rsa) { 770 kea_size = server_key_bits_; 771 } else { 772 EXPECT_TRUE(false) << "need to update group sizes"; 773 } 774 } 775 } 776 if (kea_group != ssl_grp_ffdhe_custom) { 777 EXPECT_EQ(kea_size, info_.keaKeyBits); 778 EXPECT_EQ(kea_group, info_.keaGroup); 779 } 780 } 781 782 void TlsAgent::CheckOriginalKEA(SSLNamedGroup kea_group) const { 783 if (kea_group != ssl_grp_ffdhe_custom) { 784 EXPECT_EQ(kea_group, info_.originalKeaGroup); 785 } 786 } 787 788 void TlsAgent::CheckAuthType(SSLAuthType auth, 789 SSLSignatureScheme sig_scheme) const { 790 EXPECT_EQ(STATE_CONNECTED, state_); 791 EXPECT_EQ(auth, info_.authType); 792 if (auth != ssl_auth_psk) { 793 EXPECT_EQ(server_key_bits_, info_.authKeyBits); 794 } 795 if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) { 796 switch (auth) { 797 case ssl_auth_rsa_sign: 798 sig_scheme = ssl_sig_rsa_pkcs1_sha1md5; 799 break; 800 case ssl_auth_ecdsa: 801 sig_scheme = ssl_sig_ecdsa_sha1; 802 break; 803 default: 804 break; 805 } 806 } 807 EXPECT_EQ(sig_scheme, info_.signatureScheme); 808 809 if (info_.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) { 810 return; 811 } 812 813 // Check authAlgorithm, which is the old value for authType. This is a second 814 // switch statement because default label is different. 815 switch (auth) { 816 case ssl_auth_rsa_sign: 817 case ssl_auth_rsa_pss: 818 EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm) 819 << "authAlgorithm for RSA is always decrypt"; 820 break; 821 case ssl_auth_ecdh_rsa: 822 EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm) 823 << "authAlgorithm for ECDH_RSA is RSA decrypt (i.e., wrong)"; 824 break; 825 case ssl_auth_ecdh_ecdsa: 826 EXPECT_EQ(ssl_auth_ecdsa, csinfo_.authAlgorithm) 827 << "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)"; 828 break; 829 default: 830 EXPECT_EQ(auth, csinfo_.authAlgorithm) 831 << "authAlgorithm is (usually) the same as authType"; 832 break; 833 } 834 } 835 836 void TlsAgent::EnableFalseStart() { 837 EXPECT_TRUE(EnsureTlsSetup()); 838 839 falsestart_enabled_ = true; 840 EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback( 841 ssl_fd(), CanFalseStartCallback, this)); 842 SetOption(SSL_ENABLE_FALSE_START, PR_TRUE); 843 } 844 845 void TlsAgent::ExpectEch(bool expected) { expect_ech_ = expected; } 846 847 void TlsAgent::ExpectPsk(SSLPskType psk) { expect_psk_ = psk; } 848 849 void TlsAgent::ExpectResumption() { expect_psk_ = ssl_psk_resume; } 850 851 void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) { 852 EXPECT_TRUE(EnsureTlsSetup()); 853 EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len)); 854 } 855 856 void TlsAgent::AddPsk(const ScopedPK11SymKey& psk, std::string label, 857 SSLHashType hash, uint16_t zeroRttSuite) { 858 EXPECT_TRUE(EnsureTlsSetup()); 859 EXPECT_EQ(SECSuccess, SSL_AddExternalPsk0Rtt( 860 ssl_fd(), psk.get(), 861 reinterpret_cast<const uint8_t*>(label.data()), 862 label.length(), hash, zeroRttSuite, 1000)); 863 } 864 865 void TlsAgent::RemovePsk(std::string label) { 866 EXPECT_EQ(SECSuccess, 867 SSL_RemoveExternalPsk( 868 ssl_fd(), reinterpret_cast<const uint8_t*>(label.data()), 869 label.length())); 870 } 871 872 void TlsAgent::CheckAlpn(SSLNextProtoState expected_state, 873 const std::string& expected) const { 874 SSLNextProtoState alpn_state; 875 char chosen[10]; 876 unsigned int chosen_len; 877 SECStatus rv = SSL_GetNextProto(ssl_fd(), &alpn_state, 878 reinterpret_cast<unsigned char*>(chosen), 879 &chosen_len, sizeof(chosen)); 880 EXPECT_EQ(SECSuccess, rv); 881 EXPECT_EQ(expected_state, alpn_state); 882 if (alpn_state == SSL_NEXT_PROTO_NO_SUPPORT) { 883 EXPECT_EQ("", expected); 884 } else { 885 EXPECT_NE("", expected); 886 EXPECT_EQ(expected, std::string(chosen, chosen_len)); 887 } 888 } 889 890 void TlsAgent::CheckEpochs(uint16_t expected_read, 891 uint16_t expected_write) const { 892 uint16_t read_epoch = 0; 893 uint16_t write_epoch = 0; 894 EXPECT_EQ(SECSuccess, 895 SSL_GetCurrentEpoch(ssl_fd(), &read_epoch, &write_epoch)); 896 EXPECT_EQ(expected_read, read_epoch) << role_str() << " read epoch"; 897 EXPECT_EQ(expected_write, write_epoch) << role_str() << " write epoch"; 898 } 899 900 void TlsAgent::EnableSrtp() { 901 EXPECT_TRUE(EnsureTlsSetup()); 902 const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80, 903 SRTP_AES128_CM_HMAC_SHA1_32}; 904 EXPECT_EQ(SECSuccess, 905 SSL_SetSRTPCiphers(ssl_fd(), ciphers, PR_ARRAY_SIZE(ciphers))); 906 } 907 908 void TlsAgent::CheckSrtp() const { 909 uint16_t actual; 910 EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd(), &actual)); 911 EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual); 912 } 913 914 void TlsAgent::CheckErrorCode(int32_t expected) const { 915 EXPECT_EQ(STATE_ERROR, state_); 916 EXPECT_EQ(expected, error_code_) 917 << "Got error code " << PORT_ErrorToName(error_code_) << " expecting " 918 << PORT_ErrorToName(expected) << std::endl; 919 } 920 921 static uint8_t GetExpectedAlertLevel(uint8_t alert) { 922 if (alert == kTlsAlertCloseNotify) { 923 return kTlsAlertWarning; 924 } 925 return kTlsAlertFatal; 926 } 927 928 void TlsAgent::ExpectReceiveAlert(uint8_t alert, uint8_t level) { 929 expected_received_alert_ = alert; 930 if (level == 0) { 931 expected_received_alert_level_ = GetExpectedAlertLevel(alert); 932 } else { 933 expected_received_alert_level_ = level; 934 } 935 } 936 937 void TlsAgent::ExpectSendAlert(uint8_t alert, uint8_t level) { 938 expected_sent_alert_ = alert; 939 if (level == 0) { 940 expected_sent_alert_level_ = GetExpectedAlertLevel(alert); 941 } else { 942 expected_sent_alert_level_ = level; 943 } 944 } 945 946 void TlsAgent::CheckAlert(bool sent, const SSLAlert* alert) { 947 LOG(((alert->level == kTlsAlertWarning) ? "Warning" : "Fatal") 948 << " alert " << (sent ? "sent" : "received") << ": " 949 << static_cast<int>(alert->description)); 950 951 auto& expected = sent ? expected_sent_alert_ : expected_received_alert_; 952 auto& expected_level = 953 sent ? expected_sent_alert_level_ : expected_received_alert_level_; 954 /* Silently pass close_notify in case the test has already ended. */ 955 if (expected == kTlsAlertCloseNotify && expected_level == kTlsAlertWarning && 956 alert->description == expected && alert->level == expected_level) { 957 return; 958 } 959 960 EXPECT_EQ(expected, alert->description); 961 EXPECT_EQ(expected_level, alert->level); 962 expected = kTlsAlertCloseNotify; 963 expected_level = kTlsAlertWarning; 964 } 965 966 void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const { 967 ASSERT_EQ(0, error_code_); 968 WAIT_(error_code_ != 0, delay); 969 EXPECT_EQ(expected, error_code_) 970 << "Got error code " << PORT_ErrorToName(error_code_) << " expecting " 971 << PORT_ErrorToName(expected) << std::endl; 972 } 973 974 void TlsAgent::CheckPreliminaryInfo() { 975 SSLPreliminaryChannelInfo preinfo; 976 EXPECT_EQ(SECSuccess, 977 SSL_GetPreliminaryChannelInfo(ssl_fd(), &preinfo, sizeof(preinfo))); 978 EXPECT_EQ(sizeof(preinfo), preinfo.length); 979 EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version); 980 981 // A version of 0 is invalid and indicates no expectation. This value is 982 // initialized to 0 so that tests that don't explicitly set an expected 983 // version can negotiate a version. 984 if (!expected_version_) { 985 expected_version_ = preinfo.protocolVersion; 986 } 987 EXPECT_EQ(expected_version_, preinfo.protocolVersion); 988 989 // As with the version; 0 is the null cipher suite (and also invalid). 990 if (!expected_cipher_suite_) { 991 expected_cipher_suite_ = preinfo.cipherSuite; 992 } 993 EXPECT_EQ(expected_cipher_suite_, preinfo.cipherSuite); 994 } 995 996 // Check that all the expected callbacks have been called. 997 void TlsAgent::CheckCallbacks() const { 998 // If false start happens, the handshake is reported as being complete at the 999 // point that false start happens. 1000 if (expect_psk_ == ssl_psk_resume || !falsestart_enabled_) { 1001 EXPECT_TRUE(handshake_callback_called_); 1002 } 1003 1004 // These callbacks shouldn't fire if we are resuming, except on TLS 1.3. 1005 if (role_ == SERVER) { 1006 PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd(), ssl_server_name_xtn); 1007 EXPECT_EQ(((expect_psk_ != ssl_psk_resume && have_sni) || 1008 expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3), 1009 sni_hook_called_); 1010 } else { 1011 EXPECT_EQ(expect_psk_ == ssl_psk_none, auth_certificate_hook_called_); 1012 // Note that this isn't unconditionally called, even with false start on. 1013 // But the callback is only skipped if a cipher that is ridiculously weak 1014 // (80 bits) is chosen. Don't test that: plan to remove bad ciphers. 1015 EXPECT_EQ(falsestart_enabled_ && expect_psk_ != ssl_psk_resume, 1016 can_falsestart_hook_called_); 1017 } 1018 } 1019 1020 void TlsAgent::ResetPreliminaryInfo() { 1021 expected_version_ = 0; 1022 expected_cipher_suite_ = 0; 1023 } 1024 1025 void TlsAgent::UpdatePreliminaryChannelInfo() { 1026 SECStatus rv = 1027 SSL_GetPreliminaryChannelInfo(ssl_fd(), &pre_info_, sizeof(pre_info_)); 1028 EXPECT_EQ(SECSuccess, rv); 1029 EXPECT_EQ(sizeof(pre_info_), pre_info_.length); 1030 } 1031 1032 void TlsAgent::ValidateCipherSpecs() { 1033 PRInt32 cipherSpecs = SSLInt_CountCipherSpecs(ssl_fd()); 1034 // We use one ciphersuite in each direction. 1035 PRInt32 expected = 2; 1036 if (variant_ == ssl_variant_datagram) { 1037 // For DTLS 1.3, the client retains the cipher spec for early data and the 1038 // handshake so that it can retransmit EndOfEarlyData and its final flight. 1039 // It also retains the handshake read cipher spec so that it can read ACKs 1040 // from the server. The server retains the handshake read cipher spec so it 1041 // can read the client's retransmitted Finished. 1042 if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { 1043 if (role_ == CLIENT) { 1044 expected = info_.earlyDataAccepted ? 5 : 4; 1045 } else { 1046 expected = 3; 1047 } 1048 } else { 1049 // For DTLS 1.1 and 1.2, the last endpoint to send maintains a cipher spec 1050 // until the holddown timer runs down. 1051 if (expect_psk_ == ssl_psk_resume) { 1052 if (role_ == CLIENT) { 1053 expected = 3; 1054 } 1055 } else { 1056 if (role_ == SERVER) { 1057 expected = 3; 1058 } 1059 } 1060 } 1061 } 1062 // This function will be run before the handshake completes if false start is 1063 // enabled. In that case, the client will still be reading cleartext, but 1064 // will have a spec prepared for reading ciphertext. With DTLS, the client 1065 // will also have a spec retained for retransmission of handshake messages. 1066 if (role_ == CLIENT && falsestart_enabled_ && !handshake_callback_called_) { 1067 EXPECT_GT(SSL_LIBRARY_VERSION_TLS_1_3, expected_version_); 1068 expected = (variant_ == ssl_variant_datagram) ? 4 : 3; 1069 } 1070 EXPECT_EQ(expected, cipherSpecs); 1071 if (expected != cipherSpecs) { 1072 SSLInt_PrintCipherSpecs(role_str().c_str(), ssl_fd()); 1073 } 1074 } 1075 1076 void TlsAgent::Connected() { 1077 if (state_ == STATE_CONNECTED) { 1078 return; 1079 } 1080 1081 LOG("Handshake success"); 1082 CheckPreliminaryInfo(); 1083 CheckCallbacks(); 1084 1085 SECStatus rv = SSL_GetChannelInfo(ssl_fd(), &info_, sizeof(info_)); 1086 EXPECT_EQ(SECSuccess, rv); 1087 EXPECT_EQ(sizeof(info_), info_.length); 1088 1089 EXPECT_EQ(expect_psk_ == ssl_psk_resume, info_.resumed == PR_TRUE); 1090 EXPECT_EQ(expect_psk_, info_.pskType); 1091 EXPECT_EQ(expect_ech_, info_.echAccepted); 1092 1093 // Preliminary values are exposed through callbacks during the handshake. 1094 // If either expected values were set or the callbacks were called, check 1095 // that the final values are correct. 1096 UpdatePreliminaryChannelInfo(); 1097 EXPECT_EQ(expected_version_, info_.protocolVersion); 1098 EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite); 1099 1100 rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_)); 1101 EXPECT_EQ(SECSuccess, rv); 1102 EXPECT_EQ(sizeof(csinfo_), csinfo_.length); 1103 1104 ValidateCipherSpecs(); 1105 1106 SetState(STATE_CONNECTED); 1107 } 1108 1109 void TlsAgent::CheckClientAuthCompleted(uint8_t handshakes) { 1110 EXPECT_FALSE(client_auth_callback_awaiting_); 1111 switch (client_auth_callback_type_) { 1112 case ClientAuthCallbackType::kNone: 1113 if (!client_auth_callback_success_) { 1114 EXPECT_TRUE(CheckClientAuthCallbacksCompleted(0)); 1115 break; 1116 } 1117 case ClientAuthCallbackType::kSync: 1118 EXPECT_TRUE(CheckClientAuthCallbacksCompleted(handshakes)); 1119 break; 1120 case ClientAuthCallbackType::kAsyncDelay: 1121 case ClientAuthCallbackType::kAsyncImmediate: 1122 EXPECT_TRUE(CheckClientAuthCallbacksCompleted(2 * handshakes)); 1123 break; 1124 } 1125 } 1126 1127 void TlsAgent::EnableExtendedMasterSecret() { 1128 SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE); 1129 } 1130 1131 void TlsAgent::CheckExtendedMasterSecret(bool expected) { 1132 if (version() >= SSL_LIBRARY_VERSION_TLS_1_3) { 1133 expected = PR_TRUE; 1134 } 1135 ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE) 1136 << "unexpected extended master secret state for " << name_; 1137 } 1138 1139 void TlsAgent::CheckEarlyDataAccepted(bool expected) { 1140 if (version() < SSL_LIBRARY_VERSION_TLS_1_3) { 1141 expected = false; 1142 } 1143 ASSERT_EQ(expected, info_.earlyDataAccepted != PR_FALSE) 1144 << "unexpected early data state for " << name_; 1145 } 1146 1147 void TlsAgent::CheckSecretsDestroyed() { 1148 ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd())); 1149 } 1150 1151 void TlsAgent::SetDowngradeCheckVersion(uint16_t ver) { 1152 ASSERT_TRUE(EnsureTlsSetup()); 1153 1154 SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), ver); 1155 ASSERT_EQ(SECSuccess, rv); 1156 } 1157 1158 void TlsAgent::Handshake() { 1159 LOGV("Handshake"); 1160 SECStatus rv = SSL_ForceHandshake(ssl_fd()); 1161 if (client_auth_callback_awaiting_) { 1162 ClientAuthCallbackComplete(); 1163 rv = SSL_ForceHandshake(ssl_fd()); 1164 } 1165 if (rv == SECSuccess) { 1166 Connected(); 1167 Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, 1168 &TlsAgent::ReadableCallback); 1169 return; 1170 } 1171 1172 int32_t err = PR_GetError(); 1173 if (err == PR_WOULD_BLOCK_ERROR) { 1174 LOGV("Would have blocked"); 1175 if (variant_ == ssl_variant_datagram) { 1176 if (timer_handle_) { 1177 timer_handle_->Cancel(); 1178 timer_handle_ = nullptr; 1179 } 1180 1181 PRIntervalTime timeout; 1182 rv = DTLS_GetHandshakeTimeout(ssl_fd(), &timeout); 1183 if (rv == SECSuccess) { 1184 Poller::Instance()->SetTimer( 1185 timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_); 1186 } 1187 } 1188 Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, 1189 &TlsAgent::ReadableCallback); 1190 return; 1191 } 1192 1193 if (err != 0) { 1194 LOG("Handshake failed with error " << PORT_ErrorToName(err) << ": " 1195 << PORT_ErrorToString(err)); 1196 } 1197 1198 error_code_ = err; 1199 SetState(STATE_ERROR); 1200 } 1201 1202 void TlsAgent::PrepareForRenegotiate() { 1203 EXPECT_EQ(STATE_CONNECTED, state_); 1204 1205 SetState(STATE_CONNECTING); 1206 } 1207 1208 void TlsAgent::StartRenegotiate() { 1209 PrepareForRenegotiate(); 1210 1211 SECStatus rv = SSL_ReHandshake(ssl_fd(), PR_TRUE); 1212 EXPECT_EQ(SECSuccess, rv); 1213 } 1214 1215 void TlsAgent::SendDirect(const DataBuffer& buf) { 1216 LOG("Send Direct " << buf); 1217 auto peer = adapter_->peer().lock(); 1218 if (peer) { 1219 peer->PacketReceived(buf); 1220 } else { 1221 LOG("Send Direct peer absent"); 1222 } 1223 } 1224 1225 void TlsAgent::SendRecordDirect(const TlsRecord& record) { 1226 DataBuffer buf; 1227 1228 auto rv = record.header.Write(&buf, 0, record.buffer); 1229 EXPECT_EQ(record.header.header_length() + record.buffer.len(), rv); 1230 SendDirect(buf); 1231 } 1232 1233 static bool ErrorIsFatal(PRErrorCode code) { 1234 return code != PR_WOULD_BLOCK_ERROR && code != SSL_ERROR_RX_SHORT_DTLS_READ; 1235 } 1236 1237 void TlsAgent::SendData(size_t bytes, size_t blocksize) { 1238 uint8_t block[16385]; // One larger than the maximum record size. 1239 1240 ASSERT_LE(blocksize, sizeof(block)); 1241 1242 while (bytes) { 1243 size_t tosend = std::min(blocksize, bytes); 1244 1245 for (size_t i = 0; i < tosend; ++i) { 1246 block[i] = 0xff & send_ctr_; 1247 ++send_ctr_; 1248 } 1249 1250 SendBuffer(DataBuffer(block, tosend)); 1251 bytes -= tosend; 1252 } 1253 } 1254 1255 void TlsAgent::SendBuffer(const DataBuffer& buf) { 1256 LOGV("Writing " << buf.len() << " bytes"); 1257 int32_t rv = PR_Write(ssl_fd(), buf.data(), buf.len()); 1258 if (expect_readwrite_error_) { 1259 EXPECT_GT(0, rv); 1260 EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_); 1261 error_code_ = PR_GetError(); 1262 expect_readwrite_error_ = false; 1263 } else { 1264 ASSERT_EQ(buf.len(), static_cast<size_t>(rv)); 1265 } 1266 } 1267 1268 bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec, 1269 uint64_t seq, uint8_t ct, 1270 const DataBuffer& buf) { 1271 // Ensure that we are doing TLS 1.3. 1272 EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3); 1273 if (variant_ != ssl_variant_datagram) { 1274 ADD_FAILURE(); 1275 return false; 1276 } 1277 1278 LOGV("Encrypting " << buf.len() << " bytes"); 1279 uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | 1280 kCtDtlsCiphertextLengthPresent; 1281 TlsRecordHeader header(variant_, expected_version_, dtls13_ct, seq); 1282 TlsRecordHeader out_header(header); 1283 DataBuffer padded = buf; 1284 padded.Write(padded.len(), ct, 1); 1285 DataBuffer ciphertext; 1286 if (!spec->Protect(header, padded, &ciphertext, &out_header)) { 1287 return false; 1288 } 1289 1290 DataBuffer record; 1291 auto rv = out_header.Write(&record, 0, ciphertext); 1292 EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv); 1293 SendDirect(record); 1294 return true; 1295 } 1296 1297 void TlsAgent::ReadBytes(size_t amount) { 1298 uint8_t block[16384]; 1299 1300 size_t remaining = amount; 1301 while (remaining > 0) { 1302 int32_t rv = PR_Read(ssl_fd(), block, (std::min)(amount, sizeof(block))); 1303 LOGV("ReadBytes " << rv); 1304 1305 if (rv > 0) { 1306 size_t count = static_cast<size_t>(rv); 1307 for (size_t i = 0; i < count; ++i) { 1308 ASSERT_EQ(recv_ctr_ & 0xff, block[i]); 1309 recv_ctr_++; 1310 } 1311 remaining -= rv; 1312 } else { 1313 PRErrorCode err = 0; 1314 if (rv < 0) { 1315 err = PR_GetError(); 1316 if (err != 0) { 1317 LOG("Read error " << PORT_ErrorToName(err) << ": " 1318 << PORT_ErrorToString(err)); 1319 } 1320 if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) { 1321 if (ErrorIsFatal(err)) { 1322 SetState(STATE_ERROR); 1323 } 1324 error_code_ = err; 1325 expect_readwrite_error_ = false; 1326 } 1327 } 1328 if (err != 0 && ErrorIsFatal(err)) { 1329 // If we hit a fatal error, we're done. 1330 remaining = 0; 1331 } 1332 break; 1333 } 1334 } 1335 1336 // If closed, then don't bother waiting around. 1337 if (remaining) { 1338 LOGV("Re-arming"); 1339 Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, 1340 &TlsAgent::ReadableCallback); 1341 } 1342 } 1343 1344 void TlsAgent::ResetSentBytes(size_t bytes) { send_ctr_ = bytes; } 1345 1346 void TlsAgent::SetOption(int32_t option, int value) { 1347 ASSERT_TRUE(EnsureTlsSetup()); 1348 EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), option, value)); 1349 } 1350 1351 void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) { 1352 SetOption(SSL_NO_CACHE, mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE); 1353 SetOption(SSL_ENABLE_SESSION_TICKETS, 1354 mode & RESUME_TICKET ? PR_TRUE : PR_FALSE); 1355 } 1356 1357 void TlsAgent::EnableECDHEServerKeyReuse() { 1358 ASSERT_EQ(TlsAgent::SERVER, role_); 1359 SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_TRUE); 1360 } 1361 1362 static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"}; 1363 ::testing::internal::ParamGenerator<std::string> 1364 TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr); 1365 1366 void TlsAgentTestBase::SetUp() { 1367 SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); 1368 } 1369 1370 void TlsAgentTestBase::TearDown() { 1371 agent_ = nullptr; 1372 SSL_ClearSessionCache(); 1373 SSL_ShutdownServerSessionIDCache(); 1374 } 1375 1376 void TlsAgentTestBase::Reset(const std::string& server_name) { 1377 agent_.reset( 1378 new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name, 1379 role_, variant_)); 1380 if (version_) { 1381 agent_->SetVersionRange(version_, version_); 1382 } 1383 agent_->adapter()->SetPeer(sink_adapter_); 1384 agent_->StartConnect(); 1385 } 1386 1387 void TlsAgentTestBase::EnsureInit() { 1388 if (!agent_) { 1389 Reset(); 1390 } 1391 const std::vector<SSLNamedGroup> groups = { 1392 ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, 1393 ssl_grp_ffdhe_2048}; 1394 agent_->ConfigNamedGroups(groups); 1395 } 1396 1397 void TlsAgentTestBase::ExpectAlert(uint8_t alert) { 1398 EnsureInit(); 1399 agent_->ExpectSendAlert(alert); 1400 } 1401 1402 void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer, 1403 TlsAgent::State expected_state, 1404 int32_t error_code) { 1405 std::cerr << "Process message: " << buffer << std::endl; 1406 EnsureInit(); 1407 agent_->adapter()->PacketReceived(buffer); 1408 agent_->Handshake(); 1409 1410 ASSERT_EQ(expected_state, agent_->state()); 1411 1412 if (expected_state == TlsAgent::STATE_ERROR) { 1413 ASSERT_EQ(error_code, agent_->error_code()); 1414 } 1415 } 1416 1417 void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type, 1418 uint16_t version, const uint8_t* buf, 1419 size_t len, DataBuffer* out, 1420 uint64_t sequence_number) { 1421 // Fixup the content type for DTLSCiphertext 1422 if (variant == ssl_variant_datagram && 1423 version >= SSL_LIBRARY_VERSION_TLS_1_3 && 1424 type == ssl_ct_application_data) { 1425 type = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | 1426 kCtDtlsCiphertextLengthPresent; 1427 } 1428 1429 size_t index = 0; 1430 if (variant == ssl_variant_stream) { 1431 index = out->Write(index, type, 1); 1432 index = out->Write(index, version, 2); 1433 } else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 && 1434 (type & kCtDtlsCiphertextMask) == kCtDtlsCiphertext) { 1435 uint32_t epoch = (sequence_number >> 48) & 0x3; 1436 index = out->Write(index, type | epoch, 1); 1437 uint32_t seqno = sequence_number & ((1ULL << 16) - 1); 1438 index = out->Write(index, seqno, 2); 1439 } else { 1440 index = out->Write(index, type, 1); 1441 index = out->Write(index, TlsVersionToDtlsVersion(version), 2); 1442 index = out->Write(index, sequence_number >> 32, 4); 1443 index = out->Write(index, sequence_number & PR_UINT32_MAX, 4); 1444 } 1445 index = out->Write(index, len, 2); 1446 out->Write(index, buf, len); 1447 } 1448 1449 void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version, 1450 const uint8_t* buf, size_t len, 1451 DataBuffer* out, uint64_t seq_num) const { 1452 MakeRecord(variant_, type, version, buf, len, out, seq_num); 1453 } 1454 1455 void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type, 1456 const uint8_t* data, size_t hs_len, 1457 DataBuffer* out, 1458 uint64_t seq_num) const { 1459 return MakeHandshakeMessageFragment(hs_type, data, hs_len, out, seq_num, 0, 1460 0); 1461 } 1462 1463 void TlsAgentTestBase::MakeHandshakeMessageFragment( 1464 uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out, 1465 uint64_t seq_num, uint32_t fragment_offset, 1466 uint32_t fragment_length) const { 1467 size_t index = 0; 1468 if (!fragment_length) fragment_length = hs_len; 1469 index = out->Write(index, hs_type, 1); // Handshake record type. 1470 index = out->Write(index, hs_len, 3); // Handshake length 1471 if (variant_ == ssl_variant_datagram) { 1472 index = out->Write(index, seq_num, 2); 1473 index = out->Write(index, fragment_offset, 3); 1474 index = out->Write(index, fragment_length, 3); 1475 } 1476 if (data) { 1477 index = out->Write(index, data, fragment_length); 1478 } else { 1479 for (size_t i = 0; i < fragment_length; ++i) { 1480 index = out->Write(index, 1, 1); 1481 } 1482 } 1483 } 1484 1485 void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type, 1486 size_t hs_len, 1487 DataBuffer* out) { 1488 size_t index = 0; 1489 index = out->Write(index, ssl_ct_handshake, 1); // Content Type 1490 index = out->Write(index, 3, 1); // Version high 1491 index = out->Write(index, 1, 1); // Version low 1492 index = out->Write(index, 4 + hs_len, 2); // Length 1493 1494 index = out->Write(index, hs_type, 1); // Handshake record type. 1495 index = out->Write(index, hs_len, 3); // Handshake length 1496 for (size_t i = 0; i < hs_len; ++i) { 1497 index = out->Write(index, 1, 1); 1498 } 1499 } 1500 1501 DataBuffer TlsAgentTestBase::MakeCannedTls13ServerHello() { 1502 DataBuffer sh(kCannedTls13ServerHello, sizeof(kCannedTls13ServerHello)); 1503 if (variant_ == ssl_variant_datagram) { 1504 sh.Write(0, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 2); 1505 // The version should be at the end. 1506 uint32_t v; 1507 EXPECT_TRUE(sh.Read(sh.len() - 2, 2, &v)); 1508 EXPECT_EQ(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_3), v); 1509 sh.Write(sh.len() - 2, SSL_LIBRARY_VERSION_DTLS_1_3_WIRE, 2); 1510 } 1511 return sh; 1512 } 1513 1514 } // namespace nss_test