tls_connect.cc (38269B)
1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 2 /* vim: set ts=2 et sw=2 tw=80: */ 3 /* This Source Code Form is subject to the terms of the Mozilla Public 4 * License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 * You can obtain one at http://mozilla.org/MPL/2.0/. */ 6 7 #include "tls_connect.h" 8 #include "sslexp.h" 9 extern "C" { 10 #include "libssl_internals.h" 11 } 12 13 #include <iostream> 14 15 #include "databuffer.h" 16 #include "gtest_utils.h" 17 #include "nss_scoped_ptrs.h" 18 #include "sslproto.h" 19 20 extern std::string g_working_dir_path; 21 22 namespace nss_test { 23 24 static const SSLProtocolVariant kTlsVariantsStreamArr[] = {ssl_variant_stream}; 25 ::testing::internal::ParamGenerator<SSLProtocolVariant> 26 TlsConnectTestBase::kTlsVariantsStream = 27 ::testing::ValuesIn(kTlsVariantsStreamArr); 28 static const SSLProtocolVariant kTlsVariantsDatagramArr[] = { 29 ssl_variant_datagram}; 30 ::testing::internal::ParamGenerator<SSLProtocolVariant> 31 TlsConnectTestBase::kTlsVariantsDatagram = 32 ::testing::ValuesIn(kTlsVariantsDatagramArr); 33 static const SSLProtocolVariant kTlsVariantsAllArr[] = {ssl_variant_stream, 34 ssl_variant_datagram}; 35 ::testing::internal::ParamGenerator<SSLProtocolVariant> 36 TlsConnectTestBase::kTlsVariantsAll = 37 ::testing::ValuesIn(kTlsVariantsAllArr); 38 39 static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0}; 40 ::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10 = 41 ::testing::ValuesIn(kTlsV10Arr); 42 static const uint16_t kTlsV11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1}; 43 ::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11 = 44 ::testing::ValuesIn(kTlsV11Arr); 45 static const uint16_t kTlsV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_2}; 46 ::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12 = 47 ::testing::ValuesIn(kTlsV12Arr); 48 static const uint16_t kTlsV10V11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0, 49 SSL_LIBRARY_VERSION_TLS_1_1}; 50 ::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10V11 = 51 ::testing::ValuesIn(kTlsV10V11Arr); 52 static const uint16_t kTlsV10ToV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0, 53 SSL_LIBRARY_VERSION_TLS_1_1, 54 SSL_LIBRARY_VERSION_TLS_1_2}; 55 ::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10ToV12 = 56 ::testing::ValuesIn(kTlsV10ToV12Arr); 57 static const uint16_t kTlsV11V12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1, 58 SSL_LIBRARY_VERSION_TLS_1_2}; 59 ::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11V12 = 60 ::testing::ValuesIn(kTlsV11V12Arr); 61 62 static const uint16_t kTlsV11PlusArr[] = { 63 #ifndef NSS_DISABLE_TLS_1_3 64 SSL_LIBRARY_VERSION_TLS_1_3, 65 #endif 66 SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1}; 67 ::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11Plus = 68 ::testing::ValuesIn(kTlsV11PlusArr); 69 static const uint16_t kTlsV12PlusArr[] = { 70 #ifndef NSS_DISABLE_TLS_1_3 71 SSL_LIBRARY_VERSION_TLS_1_3, 72 #endif 73 SSL_LIBRARY_VERSION_TLS_1_2}; 74 ::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12Plus = 75 ::testing::ValuesIn(kTlsV12PlusArr); 76 static const uint16_t kTlsV13Arr[] = {SSL_LIBRARY_VERSION_TLS_1_3}; 77 ::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV13 = 78 ::testing::ValuesIn(kTlsV13Arr); 79 static const uint16_t kTlsVAllArr[] = { 80 #ifndef NSS_DISABLE_TLS_1_3 81 SSL_LIBRARY_VERSION_TLS_1_3, 82 #endif 83 SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1, 84 SSL_LIBRARY_VERSION_TLS_1_0}; 85 ::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsVAll = 86 ::testing::ValuesIn(kTlsVAllArr); 87 88 std::string VersionString(uint16_t version) { 89 switch (version) { 90 case 0: 91 return "(no version)"; 92 case SSL_LIBRARY_VERSION_3_0: 93 return "1.0"; 94 case SSL_LIBRARY_VERSION_TLS_1_0: 95 return "1.0"; 96 case SSL_LIBRARY_VERSION_TLS_1_1: 97 return "1.1"; 98 case SSL_LIBRARY_VERSION_TLS_1_2: 99 return "1.2"; 100 case SSL_LIBRARY_VERSION_TLS_1_3: 101 return "1.3"; 102 default: 103 std::cerr << "Invalid version: " << version << std::endl; 104 EXPECT_TRUE(false); 105 return ""; 106 } 107 } 108 109 // The default anti-replay window for tests. Tests that rely on a different 110 // value call ResetAntiReplay directly. 111 static PRTime kAntiReplayWindow = 100 * PR_USEC_PER_SEC; 112 113 TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant, 114 uint16_t version) 115 : variant_(variant), 116 client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)), 117 server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)), 118 client_model_(nullptr), 119 server_model_(nullptr), 120 version_(version), 121 expected_resumption_mode_(RESUME_NONE), 122 expected_resumptions_(0), 123 session_ids_(), 124 expect_extended_master_secret_(false), 125 expect_early_data_accepted_(false), 126 skip_version_checks_(false) { 127 std::string v; 128 if (variant_ == ssl_variant_datagram && 129 version_ == SSL_LIBRARY_VERSION_TLS_1_1) { 130 v = "1.0"; 131 } else { 132 v = VersionString(version_); 133 } 134 std::cerr << "Version: " << variant_ << " " << v << std::endl; 135 } 136 137 TlsConnectTestBase::~TlsConnectTestBase() {} 138 139 // Check the group of each of the supported groups 140 void TlsConnectTestBase::CheckGroups( 141 const DataBuffer& groups, std::function<void(SSLNamedGroup)> check_group) { 142 DuplicateGroupChecker group_set; 143 uint32_t tmp = 0; 144 EXPECT_TRUE(groups.Read(0, 2, &tmp)); 145 EXPECT_EQ(groups.len() - 2, static_cast<size_t>(tmp)); 146 for (size_t i = 2; i < groups.len(); i += 2) { 147 EXPECT_TRUE(groups.Read(i, 2, &tmp)); 148 SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp); 149 group_set.AddAndCheckGroup(group); 150 check_group(group); 151 } 152 } 153 154 // Check the group of each of the shares 155 void TlsConnectTestBase::CheckShares( 156 const DataBuffer& shares, std::function<void(SSLNamedGroup)> check_group) { 157 DuplicateGroupChecker group_set; 158 uint32_t tmp = 0; 159 EXPECT_TRUE(shares.Read(0, 2, &tmp)); 160 EXPECT_EQ(shares.len() - 2, static_cast<size_t>(tmp)); 161 size_t i; 162 for (i = 2; i < shares.len(); i += 4 + tmp) { 163 ASSERT_TRUE(shares.Read(i, 2, &tmp)); 164 SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp); 165 group_set.AddAndCheckGroup(group); 166 check_group(group); 167 ASSERT_TRUE(shares.Read(i + 2, 2, &tmp)); 168 } 169 EXPECT_EQ(shares.len(), i); 170 } 171 172 void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch, 173 uint16_t server_epoch) const { 174 client_->CheckEpochs(server_epoch, client_epoch); 175 server_->CheckEpochs(client_epoch, server_epoch); 176 } 177 178 void TlsConnectTestBase::ClearStats() { 179 // Clear statistics. 180 SSL3Statistics* stats = SSL_GetStatistics(); 181 memset(stats, 0, sizeof(*stats)); 182 } 183 184 void TlsConnectTestBase::ClearServerCache() { 185 SSL_ShutdownServerSessionIDCache(); 186 SSLInt_ClearSelfEncryptKey(); 187 SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); 188 } 189 190 void TlsConnectTestBase::SaveAlgorithmPolicy() { 191 saved_policies_.clear(); 192 for (auto it = algorithms_.begin(); it != algorithms_.end(); ++it) { 193 uint32_t policy; 194 SECStatus rv = NSS_GetAlgorithmPolicy(*it, &policy); 195 ASSERT_EQ(SECSuccess, rv); 196 saved_policies_.push_back(std::make_tuple(*it, policy)); 197 } 198 saved_options_.clear(); 199 for (auto it : options_) { 200 int32_t option; 201 SECStatus rv = NSS_OptionGet(it, &option); 202 ASSERT_EQ(SECSuccess, rv); 203 saved_options_.push_back(std::make_tuple(it, option)); 204 } 205 } 206 207 void TlsConnectTestBase::RestoreAlgorithmPolicy() { 208 for (auto it = saved_policies_.begin(); it != saved_policies_.end(); ++it) { 209 auto algorithm = std::get<0>(*it); 210 auto policy = std::get<1>(*it); 211 SECStatus rv = NSS_SetAlgorithmPolicy( 212 algorithm, policy, NSS_USE_POLICY_IN_SSL | NSS_USE_ALG_IN_SSL_KX); 213 ASSERT_EQ(SECSuccess, rv); 214 } 215 for (auto it = saved_options_.begin(); it != saved_options_.end(); ++it) { 216 auto option_id = std::get<0>(*it); 217 auto option = std::get<1>(*it); 218 SECStatus rv = NSS_OptionSet(option_id, option); 219 ASSERT_EQ(SECSuccess, rv); 220 } 221 } 222 223 PRTime TlsConnectTestBase::TimeFunc(void* arg) { 224 return *reinterpret_cast<PRTime*>(arg); 225 } 226 227 void TlsConnectTestBase::SetUp() { 228 SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); 229 SSLInt_ClearSelfEncryptKey(); 230 now_ = PR_Now(); 231 ResetAntiReplay(kAntiReplayWindow); 232 ClearStats(); 233 SaveAlgorithmPolicy(); 234 Init(); 235 } 236 237 void TlsConnectTestBase::TearDown() { 238 client_ = nullptr; 239 server_ = nullptr; 240 241 SSL_ClearSessionCache(); 242 SSLInt_ClearSelfEncryptKey(); 243 SSL_ShutdownServerSessionIDCache(); 244 RestoreAlgorithmPolicy(); 245 } 246 247 void TlsConnectTestBase::Init() { 248 client_->SetPeer(server_); 249 server_->SetPeer(client_); 250 251 if (version_) { 252 ConfigureVersion(version_); 253 } 254 } 255 256 void TlsConnectTestBase::ResetAntiReplay(PRTime window) { 257 SSLAntiReplayContext* p_anti_replay = nullptr; 258 EXPECT_EQ(SECSuccess, 259 SSL_CreateAntiReplayContext(now_, window, 1, 3, &p_anti_replay)); 260 EXPECT_NE(nullptr, p_anti_replay); 261 anti_replay_.reset(p_anti_replay); 262 } 263 264 ScopedSECItem TlsConnectTestBase::MakeEcKeyParams(SSLNamedGroup group) { 265 auto groupDef = ssl_LookupNamedGroup(group); 266 EXPECT_NE(nullptr, groupDef); 267 268 auto oidData = SECOID_FindOIDByTag(groupDef->oidTag); 269 EXPECT_NE(nullptr, oidData); 270 ScopedSECItem params( 271 SECITEM_AllocItem(nullptr, nullptr, (2 + oidData->oid.len))); 272 EXPECT_TRUE(!!params); 273 params->data[0] = SEC_ASN1_OBJECT_ID; 274 params->data[1] = oidData->oid.len; 275 memcpy(params->data + 2, oidData->oid.data, oidData->oid.len); 276 return params; 277 } 278 279 void TlsConnectTestBase::GenerateEchConfig( 280 HpkeKemId kem_id, const std::vector<HpkeSymmetricSuite>& cipher_suites, 281 const std::string& public_name, uint16_t max_name_len, DataBuffer& record, 282 ScopedSECKEYPublicKey& pubKey, ScopedSECKEYPrivateKey& privKey) { 283 bool gen_keys = !pubKey && !privKey; 284 285 SECKEYPublicKey* pub = nullptr; 286 SECKEYPrivateKey* priv = nullptr; 287 288 if (gen_keys) { 289 ScopedSECItem ecParams = MakeEcKeyParams(ssl_grp_ec_curve25519); 290 priv = SECKEY_CreateECPrivateKey(ecParams.get(), &pub, nullptr); 291 } else { 292 priv = privKey.get(); 293 pub = pubKey.get(); 294 } 295 ASSERT_NE(nullptr, priv); 296 PRUint8 encoded[1024]; 297 unsigned int encoded_len = 0; 298 SECStatus rv = SSL_EncodeEchConfigId( 299 77, public_name.c_str(), max_name_len, kem_id, pub, cipher_suites.data(), 300 cipher_suites.size(), encoded, &encoded_len, sizeof(encoded)); 301 EXPECT_EQ(SECSuccess, rv); 302 EXPECT_GT(encoded_len, 0U); 303 304 if (gen_keys) { 305 pubKey.reset(pub); 306 privKey.reset(priv); 307 } 308 record.Truncate(0); 309 record.Write(0, encoded, encoded_len); 310 } 311 312 void TlsConnectTestBase::SetupEch(std::shared_ptr<TlsAgent>& client, 313 std::shared_ptr<TlsAgent>& server, 314 HpkeKemId kem_id, bool expect_ech, 315 bool set_client_config, 316 bool set_server_config, int max_name_len) { 317 EXPECT_TRUE(set_server_config || set_client_config); 318 ScopedSECKEYPublicKey pub; 319 ScopedSECKEYPrivateKey priv; 320 DataBuffer record; 321 static const std::vector<HpkeSymmetricSuite> kDefaultSuites = { 322 {HpkeKdfHkdfSha256, HpkeAeadChaCha20Poly1305}, 323 {HpkeKdfHkdfSha256, HpkeAeadAes128Gcm}}; 324 325 GenerateEchConfig(kem_id, kDefaultSuites, "public.name", max_name_len, record, 326 pub, priv); 327 ASSERT_NE(0U, record.len()); 328 SECStatus rv; 329 if (set_server_config) { 330 rv = SSL_SetServerEchConfigs(server->ssl_fd(), pub.get(), priv.get(), 331 record.data(), record.len()); 332 ASSERT_EQ(SECSuccess, rv); 333 } 334 if (set_client_config) { 335 rv = SSL_SetClientEchConfigs(client->ssl_fd(), record.data(), record.len()); 336 ASSERT_EQ(SECSuccess, rv); 337 } 338 339 /* Filter expect_ech, which typically defaults to true. Parameterized tests 340 * running DTLS or TLS < 1.3 should expect only a non-ECH result. */ 341 bool expect = expect_ech && variant_ != ssl_variant_datagram && 342 version_ >= SSL_LIBRARY_VERSION_TLS_1_3 && set_client_config && 343 set_server_config; 344 client->ExpectEch(expect); 345 server->ExpectEch(expect); 346 } 347 348 void TlsConnectTestBase::Reset() { 349 // Take a copy of the names because they are about to disappear. 350 std::string server_name = server_->name(); 351 std::string client_name = client_->name(); 352 Reset(server_name, client_name); 353 } 354 355 void TlsConnectTestBase::Reset(const std::string& server_name, 356 const std::string& client_name) { 357 auto token = client_->GetResumptionToken(); 358 client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_)); 359 client_->SetResumptionToken(token); 360 server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_)); 361 if (skip_version_checks_) { 362 client_->SkipVersionChecks(); 363 server_->SkipVersionChecks(); 364 } 365 366 std::cerr << "Reset server:" << server_name << ", client:" << client_name 367 << std::endl; 368 Init(); 369 } 370 371 void TlsConnectTestBase::MakeNewServer() { 372 auto replacement = std::make_shared<TlsAgent>( 373 server_->name(), TlsAgent::SERVER, server_->variant()); 374 server_ = replacement; 375 if (version_) { 376 server_->SetVersionRange(version_, version_); 377 } 378 client_->SetPeer(server_); 379 server_->SetPeer(client_); 380 server_->StartConnect(); 381 } 382 383 void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected, 384 uint8_t num_resumptions) { 385 expected_resumption_mode_ = expected; 386 if (expected != RESUME_NONE) { 387 client_->ExpectResumption(); 388 server_->ExpectResumption(); 389 expected_resumptions_ = num_resumptions; 390 } 391 EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE); 392 } 393 394 void TlsConnectTestBase::EnsureTlsSetup() { 395 EXPECT_TRUE(server_->EnsureTlsSetup( 396 server_model_ ? server_model_->ssl_fd().get() : nullptr)); 397 EXPECT_TRUE(client_->EnsureTlsSetup( 398 client_model_ ? client_model_->ssl_fd().get() : nullptr)); 399 server_->SetAntiReplayContext(anti_replay_); 400 EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(client_->ssl_fd(), 401 TlsConnectTestBase::TimeFunc, &now_)); 402 EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(server_->ssl_fd(), 403 TlsConnectTestBase::TimeFunc, &now_)); 404 } 405 406 void TlsConnectTestBase::Handshake() { 407 client_->SetServerKeyBits(server_->server_key_bits()); 408 client_->Handshake(); 409 server_->Handshake(); 410 411 ASSERT_TRUE_WAIT((client_->state() != TlsAgent::STATE_CONNECTING) && 412 (server_->state() != TlsAgent::STATE_CONNECTING), 413 5000); 414 } 415 416 void TlsConnectTestBase::EnableExtendedMasterSecret() { 417 client_->EnableExtendedMasterSecret(); 418 server_->EnableExtendedMasterSecret(); 419 ExpectExtendedMasterSecret(true); 420 } 421 422 void TlsConnectTestBase::Connect() { 423 StartConnect(); 424 client_->MaybeSetResumptionToken(); 425 Handshake(); 426 CheckConnected(); 427 } 428 429 void TlsConnectTestBase::StartConnect() { 430 EnsureTlsSetup(); 431 server_->StartConnect(); 432 client_->StartConnect(); 433 } 434 435 void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { 436 EnsureTlsSetup(); 437 client_->EnableSingleCipher(cipher_suite); 438 439 Connect(); 440 SendReceive(); 441 442 // Check that we used the right cipher suite. 443 uint16_t actual; 444 EXPECT_TRUE(client_->cipher_suite(&actual)); 445 EXPECT_EQ(cipher_suite, actual); 446 EXPECT_TRUE(server_->cipher_suite(&actual)); 447 EXPECT_EQ(cipher_suite, actual); 448 } 449 450 void TlsConnectTestBase::CheckConnected() { 451 // Have the client read handshake twice to make sure we get the 452 // NST and the ACK. 453 if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 && 454 variant_ == ssl_variant_datagram) { 455 client_->Handshake(); 456 client_->Handshake(); 457 auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd()); 458 // Verify that we dropped the client's retransmission cipher suites. 459 EXPECT_EQ(2, suites) << "Client has the wrong number of suites"; 460 if (suites != 2) { 461 SSLInt_PrintCipherSpecs("client", client_->ssl_fd()); 462 } 463 } 464 EXPECT_EQ(client_->version(), server_->version()); 465 if (!skip_version_checks_) { 466 // Check the version is as expected 467 EXPECT_EQ(std::min(client_->max_version(), server_->max_version()), 468 client_->version()); 469 } 470 471 EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); 472 EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); 473 474 uint16_t cipher_suite1, cipher_suite2; 475 ASSERT_TRUE(client_->cipher_suite(&cipher_suite1)); 476 ASSERT_TRUE(server_->cipher_suite(&cipher_suite2)); 477 EXPECT_EQ(cipher_suite1, cipher_suite2); 478 479 std::cerr << "Connected with version " << client_->version() 480 << " cipher suite " << client_->cipher_suite_name() << std::endl; 481 482 if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) { 483 // Check and store session ids. 484 std::vector<uint8_t> sid_c1 = client_->session_id(); 485 EXPECT_EQ(32U, sid_c1.size()); 486 std::vector<uint8_t> sid_s1 = server_->session_id(); 487 EXPECT_EQ(32U, sid_s1.size()); 488 EXPECT_EQ(sid_c1, sid_s1); 489 session_ids_.push_back(sid_c1); 490 } 491 492 CheckExtendedMasterSecret(); 493 CheckEarlyDataAccepted(); 494 CheckResumption(expected_resumption_mode_); 495 client_->CheckSecretsDestroyed(); 496 server_->CheckSecretsDestroyed(); 497 } 498 499 void TlsConnectTestBase::CheckEarlyDataLimit( 500 const std::shared_ptr<TlsAgent>& agent, size_t expected_size) { 501 SSLPreliminaryChannelInfo preinfo; 502 SECStatus rv = 503 SSL_GetPreliminaryChannelInfo(agent->ssl_fd(), &preinfo, sizeof(preinfo)); 504 EXPECT_EQ(SECSuccess, rv); 505 EXPECT_EQ(expected_size, static_cast<size_t>(preinfo.maxEarlyDataSize)); 506 } 507 508 SSLKEAType TlsConnectTestBase::GetDefaultKEA(void) const { 509 if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { 510 return ssl_kea_ecdh_hybrid; 511 } 512 return ssl_kea_ecdh; 513 } 514 515 SSLAuthType TlsConnectTestBase::GetDefaultAuth(void) const { 516 return ssl_auth_rsa_sign; 517 } 518 519 SSLNamedGroup TlsConnectTestBase::GetDefaultGroupFromKEA( 520 SSLKEAType kea_type) const { 521 SSLNamedGroup group; 522 switch (kea_type) { 523 case ssl_kea_ecdh_hybrid: 524 group = ssl_grp_kem_mlkem768x25519; 525 break; 526 case ssl_kea_ecdh: 527 group = ssl_grp_ec_curve25519; 528 break; 529 case ssl_kea_dh: 530 group = ssl_grp_ffdhe_2048; 531 break; 532 case ssl_kea_rsa: 533 group = ssl_grp_none; 534 break; 535 default: 536 EXPECT_TRUE(false) << "unexpected KEA"; 537 group = ssl_grp_none; 538 break; 539 } 540 return group; 541 } 542 543 SSLSignatureScheme TlsConnectTestBase::GetDefaultSchemeFromAuth( 544 SSLAuthType auth_type) const { 545 SSLSignatureScheme scheme; 546 switch (auth_type) { 547 case ssl_auth_rsa_decrypt: 548 scheme = ssl_sig_none; 549 break; 550 case ssl_auth_rsa_sign: 551 if (version_ >= SSL_LIBRARY_VERSION_TLS_1_2) { 552 scheme = ssl_sig_rsa_pss_rsae_sha256; 553 } else { 554 scheme = ssl_sig_rsa_pkcs1_sha256; 555 } 556 break; 557 case ssl_auth_rsa_pss: 558 scheme = ssl_sig_rsa_pss_rsae_sha256; 559 break; 560 case ssl_auth_ecdsa: 561 scheme = ssl_sig_ecdsa_secp256r1_sha256; 562 break; 563 case ssl_auth_dsa: 564 scheme = ssl_sig_dsa_sha1; 565 break; 566 default: 567 EXPECT_TRUE(false) << "unexpected auth type"; 568 scheme = static_cast<SSLSignatureScheme>(0x0100); 569 break; 570 } 571 return scheme; 572 } 573 574 void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, 575 SSLAuthType auth_type, 576 SSLSignatureScheme sig_scheme) const { 577 if (kea_group != ssl_grp_none) { 578 client_->CheckKEA(kea_type, kea_group); 579 server_->CheckKEA(kea_type, kea_group); 580 } 581 server_->CheckAuthType(auth_type, sig_scheme); 582 client_->CheckAuthType(auth_type, sig_scheme); 583 } 584 585 void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, 586 SSLNamedGroup kea_group) const { 587 SSLAuthType auth_type = GetDefaultAuth(); 588 SSLSignatureScheme scheme = GetDefaultSchemeFromAuth(auth_type); 589 CheckKeys(kea_type, kea_group, auth_type, scheme); 590 } 591 592 void TlsConnectTestBase::CheckKeys(SSLAuthType auth_type, 593 SSLSignatureScheme sig_scheme) const { 594 SSLKEAType kea_type = GetDefaultKEA(); 595 SSLNamedGroup group = GetDefaultGroupFromKEA(kea_type); 596 CheckKeys(kea_type, group, auth_type, sig_scheme); 597 } 598 599 void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, 600 SSLAuthType auth_type) const { 601 SSLNamedGroup group = GetDefaultGroupFromKEA(kea_type); 602 SSLSignatureScheme scheme = GetDefaultSchemeFromAuth(auth_type); 603 604 CheckKeys(kea_type, group, auth_type, scheme); 605 } 606 607 void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type) const { 608 SSLAuthType auth_type = GetDefaultAuth(); 609 CheckKeys(kea_type, auth_type); 610 } 611 612 void TlsConnectTestBase::CheckKeys(SSLAuthType auth_type) const { 613 SSLKEAType kea_type = GetDefaultKEA(); 614 CheckKeys(kea_type, auth_type); 615 } 616 617 void TlsConnectTestBase::CheckKeys() const { 618 SSLKEAType kea_type = GetDefaultKEA(); 619 SSLAuthType auth_type = GetDefaultAuth(); 620 CheckKeys(kea_type, auth_type); 621 } 622 623 void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type, 624 SSLNamedGroup kea_group, 625 SSLNamedGroup original_kea_group, 626 SSLAuthType auth_type, 627 SSLSignatureScheme sig_scheme) { 628 CheckKeys(kea_type, kea_group, auth_type, sig_scheme); 629 EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE); 630 client_->CheckOriginalKEA(original_kea_group); 631 server_->CheckOriginalKEA(original_kea_group); 632 } 633 634 void TlsConnectTestBase::ConnectExpectFail() { 635 StartConnect(); 636 Handshake(); 637 ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state()); 638 ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state()); 639 } 640 641 void TlsConnectTestBase::ExpectAlert(std::shared_ptr<TlsAgent>& sender, 642 uint8_t alert) { 643 EnsureTlsSetup(); 644 auto receiver = (sender == client_) ? server_ : client_; 645 sender->ExpectSendAlert(alert); 646 receiver->ExpectReceiveAlert(alert); 647 } 648 649 void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, 650 uint8_t alert) { 651 ExpectAlert(sender, alert); 652 ConnectExpectFail(); 653 } 654 655 void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) { 656 StartConnect(); 657 client_->SetServerKeyBits(server_->server_key_bits()); 658 client_->Handshake(); 659 server_->Handshake(); 660 661 auto failing_agent = server_; 662 if (failing_side == TlsAgent::CLIENT) { 663 failing_agent = client_; 664 } 665 ASSERT_TRUE_WAIT(failing_agent->state() == TlsAgent::STATE_ERROR, 5000); 666 } 667 668 void TlsConnectTestBase::ConfigureVersion(uint16_t version) { 669 version_ = version; 670 client_->SetVersionRange(version, version); 671 server_->SetVersionRange(version, version); 672 } 673 674 void TlsConnectTestBase::SetExpectedVersion(uint16_t version) { 675 client_->SetExpectedVersion(version); 676 server_->SetExpectedVersion(version); 677 } 678 679 void TlsConnectTestBase::AddPsk(const ScopedPK11SymKey& psk, std::string label, 680 SSLHashType hash, uint16_t zeroRttSuite) { 681 client_->AddPsk(psk, label, hash, zeroRttSuite); 682 server_->AddPsk(psk, label, hash, zeroRttSuite); 683 client_->ExpectPsk(); 684 server_->ExpectPsk(); 685 } 686 687 void TlsConnectTestBase::DisableAllCiphers() { 688 EnsureTlsSetup(); 689 client_->DisableAllCiphers(); 690 server_->DisableAllCiphers(); 691 } 692 693 void TlsConnectTestBase::EnableOnlyStaticRsaCiphers() { 694 DisableAllCiphers(); 695 696 client_->EnableCiphersByKeyExchange(ssl_kea_rsa); 697 server_->EnableCiphersByKeyExchange(ssl_kea_rsa); 698 } 699 700 void TlsConnectTestBase::EnableOnlyDheCiphers() { 701 if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { 702 DisableAllCiphers(); 703 client_->EnableCiphersByKeyExchange(ssl_kea_dh); 704 server_->EnableCiphersByKeyExchange(ssl_kea_dh); 705 } else { 706 client_->ConfigNamedGroups(kFFDHEGroups); 707 server_->ConfigNamedGroups(kFFDHEGroups); 708 } 709 } 710 711 void TlsConnectTestBase::EnableSomeEcdhCiphers() { 712 if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { 713 client_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa); 714 client_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa); 715 server_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa); 716 server_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa); 717 } else { 718 client_->ConfigNamedGroups(kECDHEGroups); 719 server_->ConfigNamedGroups(kECDHEGroups); 720 } 721 } 722 723 void TlsConnectTestBase::ConfigureSelfEncrypt() { 724 ScopedCERTCertificate cert; 725 ScopedSECKEYPrivateKey privKey; 726 ASSERT_TRUE( 727 TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey)); 728 729 ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); 730 ASSERT_TRUE(pubKey); 731 732 EXPECT_EQ(SECSuccess, 733 SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); 734 } 735 736 void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client, 737 SessionResumptionMode server) { 738 client_->ConfigureSessionCache(client); 739 server_->ConfigureSessionCache(server); 740 if ((server & RESUME_TICKET) != 0) { 741 ConfigureSelfEncrypt(); 742 } 743 } 744 745 void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { 746 EXPECT_NE(RESUME_BOTH, expected); 747 748 int resume_count = expected ? expected_resumptions_ : 0; 749 int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0; 750 751 // Note: hch == server counter; hsh == client counter. 752 SSL3Statistics* stats = SSL_GetStatistics(); 753 EXPECT_EQ(resume_count, stats->hch_sid_cache_hits); 754 EXPECT_EQ(resume_count, stats->hsh_sid_cache_hits); 755 756 EXPECT_EQ(stateless_count, stats->hch_sid_stateless_resumes); 757 EXPECT_EQ(stateless_count, stats->hsh_sid_stateless_resumes); 758 759 if (expected != RESUME_NONE) { 760 if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 && 761 client_->GetResumptionToken().size() == 0) { 762 // Check that the last two session ids match. 763 ASSERT_EQ(1U + expected_resumptions_, session_ids_.size()); 764 EXPECT_EQ(session_ids_[session_ids_.size() - 1], 765 session_ids_[session_ids_.size() - 2]); 766 } else { 767 // We've either chosen TLS 1.3 or are using an external resumption token, 768 // both of which only use tickets. 769 EXPECT_TRUE(expected & RESUME_TICKET); 770 } 771 } 772 } 773 774 static SECStatus NextProtoCallbackServer(void* arg, PRFileDesc* fd, 775 const unsigned char* protos, 776 unsigned int protos_len, 777 unsigned char* protoOut, 778 unsigned int* protoOutLen, 779 unsigned int protoMaxLen) { 780 EXPECT_EQ(protoMaxLen, 255U); 781 TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); 782 // Check that agent->alpn_value_to_use_ is in protos. 783 if (protos_len < 1) { 784 return SECFailure; 785 } 786 for (size_t i = 0; i < protos_len;) { 787 size_t l = protos[i]; 788 EXPECT_LT(i + l, protos_len); 789 if (i + l >= protos_len) { 790 return SECFailure; 791 } 792 std::string protos_s(reinterpret_cast<const char*>(protos + i + 1), l); 793 if (protos_s == agent->alpn_value_to_use_) { 794 size_t s_len = agent->alpn_value_to_use_.size(); 795 EXPECT_LE(s_len, 255U); 796 memcpy(protoOut, &agent->alpn_value_to_use_[0], s_len); 797 *protoOutLen = s_len; 798 return SECSuccess; 799 } 800 i += l + 1; 801 } 802 return SECFailure; 803 } 804 805 void TlsConnectTestBase::EnableAlpn() { 806 client_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); 807 server_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); 808 } 809 810 void TlsConnectTestBase::EnableAlpnWithCallback( 811 const std::vector<uint8_t>& client_vals, std::string server_choice) { 812 EnsureTlsSetup(); 813 server_->alpn_value_to_use_ = server_choice; 814 EXPECT_EQ(SECSuccess, 815 SSL_SetNextProtoNego(client_->ssl_fd(), client_vals.data(), 816 client_vals.size())); 817 SECStatus rv = SSL_SetNextProtoCallback( 818 server_->ssl_fd(), NextProtoCallbackServer, server_.get()); 819 EXPECT_EQ(SECSuccess, rv); 820 } 821 822 void TlsConnectTestBase::EnableAlpn(const std::vector<uint8_t>& vals) { 823 client_->EnableAlpn(vals.data(), vals.size()); 824 server_->EnableAlpn(vals.data(), vals.size()); 825 } 826 827 void TlsConnectTestBase::EnsureModelSockets() { 828 // Make sure models agents are available. 829 if (!client_model_) { 830 ASSERT_EQ(server_model_, nullptr); 831 client_model_.reset( 832 new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)); 833 server_model_.reset( 834 new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)); 835 if (skip_version_checks_) { 836 client_model_->SkipVersionChecks(); 837 server_model_->SkipVersionChecks(); 838 } 839 } 840 } 841 842 void TlsConnectTestBase::CheckAlpn(const std::string& val) { 843 client_->CheckAlpn(SSL_NEXT_PROTO_SELECTED, val); 844 server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, val); 845 } 846 847 void TlsConnectTestBase::EnableSrtp() { 848 client_->EnableSrtp(); 849 server_->EnableSrtp(); 850 } 851 852 void TlsConnectTestBase::CheckSrtp() const { 853 client_->CheckSrtp(); 854 server_->CheckSrtp(); 855 } 856 857 void TlsConnectTestBase::SendReceive(size_t total) { 858 ASSERT_GT(total, client_->received_bytes()); 859 ASSERT_GT(total, server_->received_bytes()); 860 client_->SendData(total - server_->received_bytes()); 861 server_->SendData(total - client_->received_bytes()); 862 Receive(total); // Receive() is cumulative 863 } 864 865 // Do a first connection so we can do 0-RTT on the second one. 866 void TlsConnectTestBase::SetupForZeroRtt() { 867 // Force rollover of the anti-replay window. 868 // If we don't do this, then all 0-RTT attempts will be rejected. 869 RolloverAntiReplay(); 870 871 ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); 872 ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); 873 server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT. 874 Connect(); 875 SendReceive(); // Need to read so that we absorb the session ticket. 876 CheckKeys(); 877 878 Reset(); 879 StartConnect(); 880 } 881 882 // Do a first connection so we can do resumption 883 void TlsConnectTestBase::SetupForResume() { 884 EnsureTlsSetup(); 885 ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); 886 Connect(); 887 SendReceive(); // Need to read so that we absorb the session ticket. 888 CheckKeys(); 889 890 Reset(); 891 } 892 893 void TlsConnectTestBase::ZeroRttSendReceive( 894 bool expect_writable, bool expect_readable, 895 std::function<bool()> post_clienthello_check) { 896 const char* k0RttData = "ABCDEF"; 897 const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); 898 899 client_->Handshake(); // Send ClientHello. 900 if (post_clienthello_check) { 901 if (!post_clienthello_check()) return; 902 } 903 PRInt32 rv = 904 PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write. 905 if (expect_writable) { 906 EXPECT_EQ(k0RttDataLen, rv); 907 } else { 908 EXPECT_EQ(SECFailure, rv); 909 } 910 server_->Handshake(); // Consume ClientHello 911 912 std::vector<uint8_t> buf(k0RttDataLen); 913 rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read 914 if (expect_readable) { 915 std::cerr << "0-RTT read " << rv << " bytes\n"; 916 EXPECT_EQ(k0RttDataLen, rv); 917 } else { 918 EXPECT_EQ(SECFailure, rv); 919 EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()) 920 << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); 921 } 922 923 // Do a second read. This should fail. 924 rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); 925 EXPECT_EQ(SECFailure, rv); 926 EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); 927 } 928 929 void TlsConnectTestBase::Receive(size_t amount) { 930 WAIT_(client_->received_bytes() == amount && 931 server_->received_bytes() == amount, 932 2000); 933 ASSERT_EQ(amount, client_->received_bytes()); 934 ASSERT_EQ(amount, server_->received_bytes()); 935 } 936 937 void TlsConnectTestBase::ExpectExtendedMasterSecret(bool expected) { 938 expect_extended_master_secret_ = expected; 939 } 940 941 void TlsConnectTestBase::CheckExtendedMasterSecret() { 942 client_->CheckExtendedMasterSecret(expect_extended_master_secret_); 943 server_->CheckExtendedMasterSecret(expect_extended_master_secret_); 944 } 945 946 void TlsConnectTestBase::ExpectEarlyDataAccepted(bool expected) { 947 expect_early_data_accepted_ = expected; 948 } 949 950 void TlsConnectTestBase::CheckEarlyDataAccepted() { 951 client_->CheckEarlyDataAccepted(expect_early_data_accepted_); 952 server_->CheckEarlyDataAccepted(expect_early_data_accepted_); 953 } 954 955 void TlsConnectTestBase::EnableECDHEServerKeyReuse() { 956 server_->EnableECDHEServerKeyReuse(); 957 } 958 959 void TlsConnectTestBase::SkipVersionChecks() { 960 skip_version_checks_ = true; 961 client_->SkipVersionChecks(); 962 server_->SkipVersionChecks(); 963 } 964 965 // Shift the DTLS timers, to the minimum time necessary to let the next timer 966 // run on either client or server. This allows tests to skip waiting without 967 // having timers run out of order. 968 void TlsConnectTestBase::ShiftDtlsTimers() { 969 PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT; 970 PRIntervalTime time; 971 SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time); 972 if (rv == SECSuccess) { 973 time_shift = time; 974 } 975 rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time); 976 if (rv == SECSuccess && 977 (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) { 978 time_shift = time; 979 } 980 981 if (time_shift != PR_INTERVAL_NO_TIMEOUT) { 982 AdvanceTime(PR_IntervalToMicroseconds(time_shift)); 983 EXPECT_EQ(SECSuccess, 984 SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift)); 985 EXPECT_EQ(SECSuccess, 986 SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift)); 987 } 988 } 989 990 void TlsConnectTestBase::AdvanceTime(PRTime time_shift) { now_ += time_shift; } 991 992 // Advance time by a full anti-replay window. 993 void TlsConnectTestBase::RolloverAntiReplay() { 994 AdvanceTime(kAntiReplayWindow); 995 } 996 997 TlsConnectGeneric::TlsConnectGeneric() 998 : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} 999 1000 TlsConnectPre12::TlsConnectPre12() 1001 : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} 1002 1003 TlsConnectTls12::TlsConnectTls12() 1004 : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_2) {} 1005 1006 TlsConnectTls12Plus::TlsConnectTls12Plus() 1007 : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} 1008 1009 TlsConnectTls13::TlsConnectTls13() 1010 : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} 1011 1012 TlsConnectGenericResumption::TlsConnectGenericResumption() 1013 : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())), 1014 external_cache_(std::get<2>(GetParam())) {} 1015 1016 TlsConnectTls13ResumptionToken::TlsConnectTls13ResumptionToken() 1017 : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} 1018 1019 TlsConnectGenericResumptionToken::TlsConnectGenericResumptionToken() 1020 : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} 1021 1022 void TlsKeyExchangeTest::EnsureKeyShareSetup() { 1023 EnsureTlsSetup(); 1024 groups_capture_ = 1025 std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn); 1026 shares_capture_ = 1027 std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn); 1028 shares_capture2_ = std::make_shared<TlsExtensionCapture>( 1029 client_, ssl_tls13_key_share_xtn, true); 1030 std::vector<std::shared_ptr<PacketFilter>> captures = { 1031 groups_capture_, shares_capture_, shares_capture2_}; 1032 client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures)); 1033 capture_hrr_ = MakeTlsFilter<TlsHandshakeRecorder>( 1034 server_, kTlsHandshakeHelloRetryRequest); 1035 } 1036 1037 void TlsKeyExchangeTest::ConfigNamedGroups( 1038 const std::vector<SSLNamedGroup>& groups) { 1039 client_->ConfigNamedGroups(groups); 1040 server_->ConfigNamedGroups(groups); 1041 } 1042 1043 std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails( 1044 const std::shared_ptr<TlsExtensionCapture>& capture) { 1045 EXPECT_TRUE(capture->captured()); 1046 const DataBuffer& ext = capture->extension(); 1047 1048 uint32_t tmp = 0; 1049 EXPECT_TRUE(ext.Read(0, 2, &tmp)); 1050 EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); 1051 EXPECT_TRUE(ext.len() % 2 == 0); 1052 1053 std::vector<SSLNamedGroup> groups; 1054 for (size_t i = 1; i < ext.len() / 2; i += 1) { 1055 EXPECT_TRUE(ext.Read(2 * i, 2, &tmp)); 1056 groups.push_back(static_cast<SSLNamedGroup>(tmp)); 1057 } 1058 return groups; 1059 } 1060 1061 std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails( 1062 const std::shared_ptr<TlsExtensionCapture>& capture) { 1063 EXPECT_TRUE(capture->captured()); 1064 const DataBuffer& ext = capture->extension(); 1065 1066 uint32_t tmp = 0; 1067 EXPECT_TRUE(ext.Read(0, 2, &tmp)); 1068 EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); 1069 1070 std::vector<SSLNamedGroup> shares; 1071 size_t i = 2; 1072 while (i < ext.len()) { 1073 EXPECT_TRUE(ext.Read(i, 2, &tmp)); 1074 shares.push_back(static_cast<SSLNamedGroup>(tmp)); 1075 EXPECT_TRUE(ext.Read(i + 2, 2, &tmp)); 1076 i += 4 + tmp; 1077 } 1078 EXPECT_EQ(ext.len(), i); 1079 return shares; 1080 } 1081 1082 void TlsKeyExchangeTest::CheckKEXDetails( 1083 const std::vector<SSLNamedGroup>& expected_groups, 1084 const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) { 1085 std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_); 1086 EXPECT_EQ(expected_groups, groups); 1087 1088 if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { 1089 ASSERT_LT(0U, expected_shares.size()); 1090 std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_); 1091 EXPECT_EQ(expected_shares, shares); 1092 } else { 1093 EXPECT_FALSE(shares_capture_->captured()); 1094 } 1095 1096 EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0); 1097 } 1098 1099 void TlsKeyExchangeTest::CheckKEXDetails( 1100 const std::vector<SSLNamedGroup>& expected_groups, 1101 const std::vector<SSLNamedGroup>& expected_shares) { 1102 CheckKEXDetails(expected_groups, expected_shares, false); 1103 } 1104 1105 void TlsKeyExchangeTest::CheckKEXDetails( 1106 const std::vector<SSLNamedGroup>& expected_groups, 1107 const std::vector<SSLNamedGroup>& expected_shares, 1108 SSLNamedGroup expected_share2) { 1109 CheckKEXDetails(expected_groups, expected_shares, true); 1110 1111 for (auto it : expected_shares) { 1112 EXPECT_NE(expected_share2, it); 1113 } 1114 std::vector<SSLNamedGroup> expected_shares2 = {expected_share2}; 1115 EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_)); 1116 } 1117 } // namespace nss_test