tls_connect.h (14636B)
1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 2 /* vim: set ts=2 et sw=2 tw=80: */ 3 /* This Source Code Form is subject to the terms of the Mozilla Public 4 * License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 * You can obtain one at http://mozilla.org/MPL/2.0/. */ 6 7 #ifndef tls_connect_h_ 8 #define tls_connect_h_ 9 10 #include <tuple> 11 12 #include "sslproto.h" 13 #include "sslt.h" 14 #include "nss.h" 15 16 #include "tls_agent.h" 17 #include "tls_filter.h" 18 19 #define GTEST_HAS_RTTI 0 20 #include "gtest/gtest.h" 21 22 namespace nss_test { 23 24 extern std::string VersionString(uint16_t version); 25 26 // A generic TLS connection test base. 27 class TlsConnectTestBase : public ::testing::Test { 28 public: 29 static ::testing::internal::ParamGenerator<SSLProtocolVariant> 30 kTlsVariantsStream; 31 static ::testing::internal::ParamGenerator<SSLProtocolVariant> 32 kTlsVariantsDatagram; 33 static ::testing::internal::ParamGenerator<SSLProtocolVariant> 34 kTlsVariantsAll; 35 static ::testing::internal::ParamGenerator<uint16_t> kTlsV10; 36 static ::testing::internal::ParamGenerator<uint16_t> kTlsV11; 37 static ::testing::internal::ParamGenerator<uint16_t> kTlsV12; 38 static ::testing::internal::ParamGenerator<uint16_t> kTlsV10V11; 39 static ::testing::internal::ParamGenerator<uint16_t> kTlsV11V12; 40 static ::testing::internal::ParamGenerator<uint16_t> kTlsV10ToV12; 41 static ::testing::internal::ParamGenerator<uint16_t> kTlsV13; 42 static ::testing::internal::ParamGenerator<uint16_t> kTlsV11Plus; 43 static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus; 44 static ::testing::internal::ParamGenerator<uint16_t> kTlsVAll; 45 46 TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version); 47 virtual ~TlsConnectTestBase(); 48 49 virtual void SetUp(); 50 virtual void TearDown(); 51 52 PRTime now() const { return now_; } 53 54 // Initialize client and server. 55 void Init(); 56 // Clear the statistics. 57 void ClearStats(); 58 // Clear the server session cache. 59 void ClearServerCache(); 60 // Make sure TLS is configured for a connection. 61 virtual void EnsureTlsSetup(); 62 // Reset and keep the same certificate names 63 void Reset(); 64 // Reset, and update the certificate names on both peers 65 void Reset(const std::string& server_name, 66 const std::string& client_name = "client"); 67 // Replace the server. 68 void MakeNewServer(); 69 70 // Set up 71 void StartConnect(); 72 // Run the handshake. 73 void Handshake(); 74 // Connect and check that it works. 75 void Connect(); 76 // Check that the connection was successfully established. 77 void CheckConnected(); 78 // Connect and expect it to fail. 79 void ConnectExpectFail(); 80 void ExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert); 81 void ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert); 82 void ConnectExpectFailOneSide(TlsAgent::Role failingSide); 83 void ConnectWithCipherSuite(uint16_t cipher_suite); 84 void CheckEarlyDataLimit(const std::shared_ptr<TlsAgent>& agent, 85 size_t expected_size); 86 // Get the default KEA for our tls version 87 SSLKEAType GetDefaultKEA(void) const; 88 // Get the default auth for our tls version 89 SSLAuthType GetDefaultAuth(void) const; 90 // Find the default group for a given KEA 91 SSLNamedGroup GetDefaultGroupFromKEA(SSLKEAType kea_type) const; 92 // Find the default scheam for a given auth 93 SSLSignatureScheme GetDefaultSchemeFromAuth(SSLAuthType auth_type) const; 94 95 // Check that the keys used in the handshake match expectations. 96 void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, 97 SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const; 98 // These version guesses some of the values based on defaults 99 void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group) const; 100 void CheckKeys(SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const; 101 void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const; 102 void CheckKeys(SSLKEAType kea_type) const; 103 void CheckKeys(SSLAuthType auth_type) const; 104 void CheckKeys() const; 105 // Check that keys on resumed sessions. 106 void CheckKeysResumption(SSLKEAType kea_type, SSLNamedGroup kea_group, 107 SSLNamedGroup original_kea_group, 108 SSLAuthType auth_type, 109 SSLSignatureScheme sig_scheme); 110 void CheckGroups(const DataBuffer& groups, 111 std::function<void(SSLNamedGroup)> check_group); 112 void CheckShares(const DataBuffer& shares, 113 std::function<void(SSLNamedGroup)> check_group); 114 void CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const; 115 116 void ConfigureVersion(uint16_t version); 117 void SetExpectedVersion(uint16_t version); 118 uint16_t GetVersion(void) const { return version_; }; 119 // Expect resumption of a particular type. 120 void ExpectResumption(SessionResumptionMode expected, 121 uint8_t num_resumed = 1); 122 void DisableAllCiphers(); 123 void EnableOnlyStaticRsaCiphers(); 124 void EnableOnlyDheCiphers(); 125 void EnableSomeEcdhCiphers(); 126 void EnableExtendedMasterSecret(); 127 void ConfigureSelfEncrypt(); 128 void ConfigureSessionCache(SessionResumptionMode client, 129 SessionResumptionMode server); 130 void EnableAlpn(); 131 void EnableAlpnWithCallback(const std::vector<uint8_t>& client, 132 std::string server_choice); 133 void EnableAlpn(const std::vector<uint8_t>& vals); 134 void EnsureModelSockets(); 135 void CheckAlpn(const std::string& val); 136 void EnableSrtp(); 137 void CheckSrtp() const; 138 void SendReceive(size_t total = 50); 139 void AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash, 140 uint16_t zeroRttSuite = TLS_NULL_WITH_NULL_NULL); 141 void RemovePsk(std::string label); 142 void SetupForZeroRtt(); 143 void SetupForResume(); 144 void ZeroRttSendReceive( 145 bool expect_writable, bool expect_readable, 146 std::function<bool()> post_clienthello_check = nullptr); 147 void Receive(size_t amount); 148 void ExpectExtendedMasterSecret(bool expected); 149 void ExpectEarlyDataAccepted(bool expected); 150 void EnableECDHEServerKeyReuse(); 151 void SkipVersionChecks(); 152 153 // Move the DTLS timers for both endpoints to pop the next timer. 154 void ShiftDtlsTimers(); 155 void AdvanceTime(PRTime time_shift); 156 157 void ResetAntiReplay(PRTime window); 158 void RolloverAntiReplay(); 159 160 void SaveAlgorithmPolicy(); 161 void RestoreAlgorithmPolicy(); 162 163 static ScopedSECItem MakeEcKeyParams(SSLNamedGroup group); 164 static void GenerateEchConfig( 165 HpkeKemId kem_id, const std::vector<HpkeSymmetricSuite>& cipher_suites, 166 const std::string& public_name, uint16_t max_name_len, DataBuffer& record, 167 ScopedSECKEYPublicKey& pubKey, ScopedSECKEYPrivateKey& privKey); 168 void SetupEch(std::shared_ptr<TlsAgent>& client, 169 std::shared_ptr<TlsAgent>& server, 170 HpkeKemId kem_id = HpkeDhKemX25519Sha256, 171 bool expect_ech = true, bool set_client_config = true, 172 bool set_server_config = true, int maxConfigSize = 100); 173 174 protected: 175 SSLProtocolVariant variant_; 176 std::shared_ptr<TlsAgent> client_; 177 std::shared_ptr<TlsAgent> server_; 178 std::unique_ptr<TlsAgent> client_model_; 179 std::unique_ptr<TlsAgent> server_model_; 180 uint16_t version_; 181 SessionResumptionMode expected_resumption_mode_; 182 uint8_t expected_resumptions_; 183 std::vector<std::vector<uint8_t>> session_ids_; 184 ScopedSSLAntiReplayContext anti_replay_; 185 186 // A simple value of "a", "b". Note that the preferred value of "a" is placed 187 // at the end, because the NSS API follows the now defunct NPN specification, 188 // which places the preferred (and default) entry at the end of the list. 189 // NSS will move this final entry to the front when used with ALPN. 190 const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61}; 191 192 // A list of algorithm IDs whose policies need to be preserved 193 // around test cases. In particular, DSA is checked in 194 // ssl_extension_unittest.cc. 195 const std::vector<SECOidTag> algorithms_ = {SEC_OID_APPLY_SSL_POLICY, 196 SEC_OID_ANSIX9_DSA_SIGNATURE, 197 SEC_OID_CURVE25519, SEC_OID_SHA1}; 198 std::vector<std::tuple<SECOidTag, uint32_t>> saved_policies_; 199 const std::vector<PRInt32> options_ = { 200 NSS_RSA_MIN_KEY_SIZE, NSS_DH_MIN_KEY_SIZE, NSS_DSA_MIN_KEY_SIZE, 201 NSS_TLS_VERSION_MIN_POLICY, NSS_TLS_VERSION_MAX_POLICY}; 202 std::vector<std::tuple<PRInt32, uint32_t>> saved_options_; 203 204 private: 205 void CheckResumption(SessionResumptionMode expected); 206 void CheckExtendedMasterSecret(); 207 void CheckEarlyDataAccepted(); 208 static PRTime TimeFunc(void* arg); 209 210 bool expect_extended_master_secret_; 211 bool expect_early_data_accepted_; 212 bool skip_version_checks_; 213 PRTime now_; 214 215 // Track groups and make sure that there are no duplicates. 216 class DuplicateGroupChecker { 217 public: 218 void AddAndCheckGroup(SSLNamedGroup group) { 219 EXPECT_EQ(groups_.end(), groups_.find(group)) 220 << "Group " << group << " should not be duplicated"; 221 groups_.insert(group); 222 } 223 224 private: 225 std::set<SSLNamedGroup> groups_; 226 }; 227 }; 228 229 // A non-parametrized TLS test base. 230 class TlsConnectTest : public TlsConnectTestBase { 231 public: 232 TlsConnectTest() : TlsConnectTestBase(ssl_variant_stream, 0) {} 233 }; 234 235 // A non-parametrized DTLS-only test base. 236 class DtlsConnectTest : public TlsConnectTestBase { 237 public: 238 DtlsConnectTest() : TlsConnectTestBase(ssl_variant_datagram, 0) {} 239 }; 240 241 // A TLS-only test base. 242 class TlsConnectStream : public TlsConnectTestBase, 243 public ::testing::WithParamInterface<uint16_t> { 244 public: 245 TlsConnectStream() : TlsConnectTestBase(ssl_variant_stream, GetParam()) {} 246 }; 247 248 // A TLS-only test base for tests before 1.3 249 class TlsConnectStreamPre13 : public TlsConnectStream {}; 250 251 // A DTLS-only test base. 252 class TlsConnectDatagram : public TlsConnectTestBase, 253 public ::testing::WithParamInterface<uint16_t> { 254 public: 255 TlsConnectDatagram() : TlsConnectTestBase(ssl_variant_datagram, GetParam()) {} 256 }; 257 258 // A generic test class that can be either stream or datagram and a single 259 // version of TLS. This is configured in ssl_loopback_unittest.cc. 260 class TlsConnectGeneric : public TlsConnectTestBase, 261 public ::testing::WithParamInterface< 262 std::tuple<SSLProtocolVariant, uint16_t>> { 263 public: 264 TlsConnectGeneric(); 265 }; 266 267 class TlsConnectGenericResumption 268 : public TlsConnectTestBase, 269 public ::testing::WithParamInterface< 270 std::tuple<SSLProtocolVariant, uint16_t, bool>> { 271 private: 272 bool external_cache_; 273 274 public: 275 TlsConnectGenericResumption(); 276 277 virtual void EnsureTlsSetup() { 278 TlsConnectTestBase::EnsureTlsSetup(); 279 // Enable external resumption token cache. 280 if (external_cache_) { 281 client_->SetResumptionTokenCallback(); 282 } 283 } 284 285 bool use_external_cache() const { return external_cache_; } 286 }; 287 288 class TlsConnectTls13ResumptionToken 289 : public TlsConnectTestBase, 290 public ::testing::WithParamInterface<SSLProtocolVariant> { 291 public: 292 TlsConnectTls13ResumptionToken(); 293 294 virtual void EnsureTlsSetup() { 295 TlsConnectTestBase::EnsureTlsSetup(); 296 client_->SetResumptionTokenCallback(); 297 } 298 }; 299 300 class TlsConnectGenericResumptionToken 301 : public TlsConnectTestBase, 302 public ::testing::WithParamInterface< 303 std::tuple<SSLProtocolVariant, uint16_t>> { 304 public: 305 TlsConnectGenericResumptionToken(); 306 307 virtual void EnsureTlsSetup() { 308 TlsConnectTestBase::EnsureTlsSetup(); 309 client_->SetResumptionTokenCallback(); 310 } 311 }; 312 313 // A Pre TLS 1.2 generic test. 314 class TlsConnectPre12 : public TlsConnectTestBase, 315 public ::testing::WithParamInterface< 316 std::tuple<SSLProtocolVariant, uint16_t>> { 317 public: 318 TlsConnectPre12(); 319 }; 320 321 // A TLS 1.2 only generic test. 322 class TlsConnectTls12 323 : public TlsConnectTestBase, 324 public ::testing::WithParamInterface<SSLProtocolVariant> { 325 public: 326 TlsConnectTls12(); 327 }; 328 329 // A TLS 1.2 only stream test. 330 class TlsConnectStreamTls12 : public TlsConnectTestBase { 331 public: 332 TlsConnectStreamTls12() 333 : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_2) {} 334 }; 335 336 // A TLS 1.2+ generic test. 337 class TlsConnectTls12Plus : public TlsConnectTestBase, 338 public ::testing::WithParamInterface< 339 std::tuple<SSLProtocolVariant, uint16_t>> { 340 public: 341 TlsConnectTls12Plus(); 342 }; 343 344 // A TLS 1.3 only generic test. 345 class TlsConnectTls13 346 : public TlsConnectTestBase, 347 public ::testing::WithParamInterface<SSLProtocolVariant> { 348 public: 349 TlsConnectTls13(); 350 }; 351 352 // A TLS 1.3 only stream test. 353 class TlsConnectStreamTls13 : public TlsConnectTestBase { 354 public: 355 TlsConnectStreamTls13() 356 : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {} 357 }; 358 359 class TlsConnectDatagram13 : public TlsConnectTestBase { 360 public: 361 TlsConnectDatagram13() 362 : TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {} 363 }; 364 365 class TlsConnectDatagramPre13 : public TlsConnectDatagram { 366 public: 367 TlsConnectDatagramPre13() {} 368 }; 369 370 // A variant that is used only with Pre13. 371 class TlsConnectGenericPre13 : public TlsConnectGeneric {}; 372 373 class TlsKeyExchangeTest : public TlsConnectGeneric { 374 protected: 375 std::shared_ptr<TlsExtensionCapture> groups_capture_; 376 std::shared_ptr<TlsExtensionCapture> shares_capture_; 377 std::shared_ptr<TlsExtensionCapture> shares_capture2_; 378 std::shared_ptr<TlsHandshakeRecorder> capture_hrr_; 379 380 void EnsureKeyShareSetup(); 381 void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups); 382 std::vector<SSLNamedGroup> GetGroupDetails( 383 const std::shared_ptr<TlsExtensionCapture>& capture); 384 std::vector<SSLNamedGroup> GetShareDetails( 385 const std::shared_ptr<TlsExtensionCapture>& capture); 386 void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, 387 const std::vector<SSLNamedGroup>& expectedShares); 388 void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, 389 const std::vector<SSLNamedGroup>& expectedShares, 390 SSLNamedGroup expectedShare2); 391 392 private: 393 void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, 394 const std::vector<SSLNamedGroup>& expectedShares, 395 bool expect_hrr); 396 }; 397 398 class TlsKeyExchangeTest13 : public TlsKeyExchangeTest {}; 399 class TlsKeyExchangeTestPre13 : public TlsKeyExchangeTest {}; 400 401 } // namespace nss_test 402 403 #endif