tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

tls_connect.h (14636B)


      1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim: set ts=2 et sw=2 tw=80: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
      7 #ifndef tls_connect_h_
      8 #define tls_connect_h_
      9 
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <vector>

#include "sslproto.h"
#include "sslt.h"
#include "nss.h"

#include "tls_agent.h"
#include "tls_filter.h"

#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
     21 
     22 namespace nss_test {
     23 
     24 extern std::string VersionString(uint16_t version);
     25 
     26 // A generic TLS connection test base.
     27 class TlsConnectTestBase : public ::testing::Test {
     28 public:
     29  static ::testing::internal::ParamGenerator<SSLProtocolVariant>
     30      kTlsVariantsStream;
     31  static ::testing::internal::ParamGenerator<SSLProtocolVariant>
     32      kTlsVariantsDatagram;
     33  static ::testing::internal::ParamGenerator<SSLProtocolVariant>
     34      kTlsVariantsAll;
     35  static ::testing::internal::ParamGenerator<uint16_t> kTlsV10;
     36  static ::testing::internal::ParamGenerator<uint16_t> kTlsV11;
     37  static ::testing::internal::ParamGenerator<uint16_t> kTlsV12;
     38  static ::testing::internal::ParamGenerator<uint16_t> kTlsV10V11;
     39  static ::testing::internal::ParamGenerator<uint16_t> kTlsV11V12;
     40  static ::testing::internal::ParamGenerator<uint16_t> kTlsV10ToV12;
     41  static ::testing::internal::ParamGenerator<uint16_t> kTlsV13;
     42  static ::testing::internal::ParamGenerator<uint16_t> kTlsV11Plus;
     43  static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus;
     44  static ::testing::internal::ParamGenerator<uint16_t> kTlsVAll;
     45 
     46  TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version);
     47  virtual ~TlsConnectTestBase();
     48 
     49  virtual void SetUp();
     50  virtual void TearDown();
     51 
     52  PRTime now() const { return now_; }
     53 
     54  // Initialize client and server.
     55  void Init();
     56  // Clear the statistics.
     57  void ClearStats();
     58  // Clear the server session cache.
     59  void ClearServerCache();
     60  // Make sure TLS is configured for a connection.
     61  virtual void EnsureTlsSetup();
     62  // Reset and keep the same certificate names
     63  void Reset();
     64  // Reset, and update the certificate names on both peers
     65  void Reset(const std::string& server_name,
     66             const std::string& client_name = "client");
     67  // Replace the server.
     68  void MakeNewServer();
     69 
     70  // Set up
     71  void StartConnect();
     72  // Run the handshake.
     73  void Handshake();
     74  // Connect and check that it works.
     75  void Connect();
     76  // Check that the connection was successfully established.
     77  void CheckConnected();
     78  // Connect and expect it to fail.
     79  void ConnectExpectFail();
     80  void ExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert);
     81  void ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert);
     82  void ConnectExpectFailOneSide(TlsAgent::Role failingSide);
     83  void ConnectWithCipherSuite(uint16_t cipher_suite);
     84  void CheckEarlyDataLimit(const std::shared_ptr<TlsAgent>& agent,
     85                           size_t expected_size);
     86  // Get the default KEA for our tls version
     87  SSLKEAType GetDefaultKEA(void) const;
     88  // Get the default auth for our tls version
     89  SSLAuthType GetDefaultAuth(void) const;
     90  // Find the default group for a given KEA
     91  SSLNamedGroup GetDefaultGroupFromKEA(SSLKEAType kea_type) const;
     92  // Find the default scheam for a given auth
     93  SSLSignatureScheme GetDefaultSchemeFromAuth(SSLAuthType auth_type) const;
     94 
     95  // Check that the keys used in the handshake match expectations.
     96  void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
     97                 SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const;
     98  // These version guesses some of the values based on defaults
     99  void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group) const;
    100  void CheckKeys(SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const;
    101  void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const;
    102  void CheckKeys(SSLKEAType kea_type) const;
    103  void CheckKeys(SSLAuthType auth_type) const;
    104  void CheckKeys() const;
    105  // Check that keys on resumed sessions.
    106  void CheckKeysResumption(SSLKEAType kea_type, SSLNamedGroup kea_group,
    107                           SSLNamedGroup original_kea_group,
    108                           SSLAuthType auth_type,
    109                           SSLSignatureScheme sig_scheme);
    110  void CheckGroups(const DataBuffer& groups,
    111                   std::function<void(SSLNamedGroup)> check_group);
    112  void CheckShares(const DataBuffer& shares,
    113                   std::function<void(SSLNamedGroup)> check_group);
    114  void CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const;
    115 
    116  void ConfigureVersion(uint16_t version);
    117  void SetExpectedVersion(uint16_t version);
    118  uint16_t GetVersion(void) const { return version_; };
    119  // Expect resumption of a particular type.
    120  void ExpectResumption(SessionResumptionMode expected,
    121                        uint8_t num_resumed = 1);
    122  void DisableAllCiphers();
    123  void EnableOnlyStaticRsaCiphers();
    124  void EnableOnlyDheCiphers();
    125  void EnableSomeEcdhCiphers();
    126  void EnableExtendedMasterSecret();
    127  void ConfigureSelfEncrypt();
    128  void ConfigureSessionCache(SessionResumptionMode client,
    129                             SessionResumptionMode server);
    130  void EnableAlpn();
    131  void EnableAlpnWithCallback(const std::vector<uint8_t>& client,
    132                              std::string server_choice);
    133  void EnableAlpn(const std::vector<uint8_t>& vals);
    134  void EnsureModelSockets();
    135  void CheckAlpn(const std::string& val);
    136  void EnableSrtp();
    137  void CheckSrtp() const;
    138  void SendReceive(size_t total = 50);
    139  void AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash,
    140              uint16_t zeroRttSuite = TLS_NULL_WITH_NULL_NULL);
    141  void RemovePsk(std::string label);
    142  void SetupForZeroRtt();
    143  void SetupForResume();
    144  void ZeroRttSendReceive(
    145      bool expect_writable, bool expect_readable,
    146      std::function<bool()> post_clienthello_check = nullptr);
    147  void Receive(size_t amount);
    148  void ExpectExtendedMasterSecret(bool expected);
    149  void ExpectEarlyDataAccepted(bool expected);
    150  void EnableECDHEServerKeyReuse();
    151  void SkipVersionChecks();
    152 
    153  // Move the DTLS timers for both endpoints to pop the next timer.
    154  void ShiftDtlsTimers();
    155  void AdvanceTime(PRTime time_shift);
    156 
    157  void ResetAntiReplay(PRTime window);
    158  void RolloverAntiReplay();
    159 
    160  void SaveAlgorithmPolicy();
    161  void RestoreAlgorithmPolicy();
    162 
    163  static ScopedSECItem MakeEcKeyParams(SSLNamedGroup group);
    164  static void GenerateEchConfig(
    165      HpkeKemId kem_id, const std::vector<HpkeSymmetricSuite>& cipher_suites,
    166      const std::string& public_name, uint16_t max_name_len, DataBuffer& record,
    167      ScopedSECKEYPublicKey& pubKey, ScopedSECKEYPrivateKey& privKey);
    168  void SetupEch(std::shared_ptr<TlsAgent>& client,
    169                std::shared_ptr<TlsAgent>& server,
    170                HpkeKemId kem_id = HpkeDhKemX25519Sha256,
    171                bool expect_ech = true, bool set_client_config = true,
    172                bool set_server_config = true, int maxConfigSize = 100);
    173 
    174 protected:
    175  SSLProtocolVariant variant_;
    176  std::shared_ptr<TlsAgent> client_;
    177  std::shared_ptr<TlsAgent> server_;
    178  std::unique_ptr<TlsAgent> client_model_;
    179  std::unique_ptr<TlsAgent> server_model_;
    180  uint16_t version_;
    181  SessionResumptionMode expected_resumption_mode_;
    182  uint8_t expected_resumptions_;
    183  std::vector<std::vector<uint8_t>> session_ids_;
    184  ScopedSSLAntiReplayContext anti_replay_;
    185 
    186  // A simple value of "a", "b".  Note that the preferred value of "a" is placed
    187  // at the end, because the NSS API follows the now defunct NPN specification,
    188  // which places the preferred (and default) entry at the end of the list.
    189  // NSS will move this final entry to the front when used with ALPN.
    190  const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61};
    191 
    192  // A list of algorithm IDs whose policies need to be preserved
    193  // around test cases.  In particular, DSA is checked in
    194  // ssl_extension_unittest.cc.
    195  const std::vector<SECOidTag> algorithms_ = {SEC_OID_APPLY_SSL_POLICY,
    196                                              SEC_OID_ANSIX9_DSA_SIGNATURE,
    197                                              SEC_OID_CURVE25519, SEC_OID_SHA1};
    198  std::vector<std::tuple<SECOidTag, uint32_t>> saved_policies_;
    199  const std::vector<PRInt32> options_ = {
    200      NSS_RSA_MIN_KEY_SIZE, NSS_DH_MIN_KEY_SIZE, NSS_DSA_MIN_KEY_SIZE,
    201      NSS_TLS_VERSION_MIN_POLICY, NSS_TLS_VERSION_MAX_POLICY};
    202  std::vector<std::tuple<PRInt32, uint32_t>> saved_options_;
    203 
    204 private:
    205  void CheckResumption(SessionResumptionMode expected);
    206  void CheckExtendedMasterSecret();
    207  void CheckEarlyDataAccepted();
    208  static PRTime TimeFunc(void* arg);
    209 
    210  bool expect_extended_master_secret_;
    211  bool expect_early_data_accepted_;
    212  bool skip_version_checks_;
    213  PRTime now_;
    214 
    215  // Track groups and make sure that there are no duplicates.
    216  class DuplicateGroupChecker {
    217   public:
    218    void AddAndCheckGroup(SSLNamedGroup group) {
    219      EXPECT_EQ(groups_.end(), groups_.find(group))
    220          << "Group " << group << " should not be duplicated";
    221      groups_.insert(group);
    222    }
    223 
    224   private:
    225    std::set<SSLNamedGroup> groups_;
    226  };
    227 };
    228 
    229 // A non-parametrized TLS test base.
    230 class TlsConnectTest : public TlsConnectTestBase {
    231 public:
    232  TlsConnectTest() : TlsConnectTestBase(ssl_variant_stream, 0) {}
    233 };
    234 
    235 // A non-parametrized DTLS-only test base.
    236 class DtlsConnectTest : public TlsConnectTestBase {
    237 public:
    238  DtlsConnectTest() : TlsConnectTestBase(ssl_variant_datagram, 0) {}
    239 };
    240 
    241 // A TLS-only test base.
    242 class TlsConnectStream : public TlsConnectTestBase,
    243                         public ::testing::WithParamInterface<uint16_t> {
    244 public:
    245  TlsConnectStream() : TlsConnectTestBase(ssl_variant_stream, GetParam()) {}
    246 };
    247 
    248 // A TLS-only test base for tests before 1.3
    249 class TlsConnectStreamPre13 : public TlsConnectStream {};
    250 
    251 // A DTLS-only test base.
    252 class TlsConnectDatagram : public TlsConnectTestBase,
    253                           public ::testing::WithParamInterface<uint16_t> {
    254 public:
    255  TlsConnectDatagram() : TlsConnectTestBase(ssl_variant_datagram, GetParam()) {}
    256 };
    257 
    258 // A generic test class that can be either stream or datagram and a single
    259 // version of TLS.  This is configured in ssl_loopback_unittest.cc.
    260 class TlsConnectGeneric : public TlsConnectTestBase,
    261                          public ::testing::WithParamInterface<
    262                              std::tuple<SSLProtocolVariant, uint16_t>> {
    263 public:
    264  TlsConnectGeneric();
    265 };
    266 
    267 class TlsConnectGenericResumption
    268    : public TlsConnectTestBase,
    269      public ::testing::WithParamInterface<
    270          std::tuple<SSLProtocolVariant, uint16_t, bool>> {
    271 private:
    272  bool external_cache_;
    273 
    274 public:
    275  TlsConnectGenericResumption();
    276 
    277  virtual void EnsureTlsSetup() {
    278    TlsConnectTestBase::EnsureTlsSetup();
    279    // Enable external resumption token cache.
    280    if (external_cache_) {
    281      client_->SetResumptionTokenCallback();
    282    }
    283  }
    284 
    285  bool use_external_cache() const { return external_cache_; }
    286 };
    287 
    288 class TlsConnectTls13ResumptionToken
    289    : public TlsConnectTestBase,
    290      public ::testing::WithParamInterface<SSLProtocolVariant> {
    291 public:
    292  TlsConnectTls13ResumptionToken();
    293 
    294  virtual void EnsureTlsSetup() {
    295    TlsConnectTestBase::EnsureTlsSetup();
    296    client_->SetResumptionTokenCallback();
    297  }
    298 };
    299 
    300 class TlsConnectGenericResumptionToken
    301    : public TlsConnectTestBase,
    302      public ::testing::WithParamInterface<
    303          std::tuple<SSLProtocolVariant, uint16_t>> {
    304 public:
    305  TlsConnectGenericResumptionToken();
    306 
    307  virtual void EnsureTlsSetup() {
    308    TlsConnectTestBase::EnsureTlsSetup();
    309    client_->SetResumptionTokenCallback();
    310  }
    311 };
    312 
    313 // A Pre TLS 1.2 generic test.
    314 class TlsConnectPre12 : public TlsConnectTestBase,
    315                        public ::testing::WithParamInterface<
    316                            std::tuple<SSLProtocolVariant, uint16_t>> {
    317 public:
    318  TlsConnectPre12();
    319 };
    320 
    321 // A TLS 1.2 only generic test.
    322 class TlsConnectTls12
    323    : public TlsConnectTestBase,
    324      public ::testing::WithParamInterface<SSLProtocolVariant> {
    325 public:
    326  TlsConnectTls12();
    327 };
    328 
    329 // A TLS 1.2 only stream test.
    330 class TlsConnectStreamTls12 : public TlsConnectTestBase {
    331 public:
    332  TlsConnectStreamTls12()
    333      : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_2) {}
    334 };
    335 
    336 // A TLS 1.2+ generic test.
    337 class TlsConnectTls12Plus : public TlsConnectTestBase,
    338                            public ::testing::WithParamInterface<
    339                                std::tuple<SSLProtocolVariant, uint16_t>> {
    340 public:
    341  TlsConnectTls12Plus();
    342 };
    343 
    344 // A TLS 1.3 only generic test.
    345 class TlsConnectTls13
    346    : public TlsConnectTestBase,
    347      public ::testing::WithParamInterface<SSLProtocolVariant> {
    348 public:
    349  TlsConnectTls13();
    350 };
    351 
    352 // A TLS 1.3 only stream test.
    353 class TlsConnectStreamTls13 : public TlsConnectTestBase {
    354 public:
    355  TlsConnectStreamTls13()
    356      : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {}
    357 };
    358 
    359 class TlsConnectDatagram13 : public TlsConnectTestBase {
    360 public:
    361  TlsConnectDatagram13()
    362      : TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {}
    363 };
    364 
    365 class TlsConnectDatagramPre13 : public TlsConnectDatagram {
    366 public:
    367  TlsConnectDatagramPre13() {}
    368 };
    369 
    370 // A variant that is used only with Pre13.
    371 class TlsConnectGenericPre13 : public TlsConnectGeneric {};
    372 
    373 class TlsKeyExchangeTest : public TlsConnectGeneric {
    374 protected:
    375  std::shared_ptr<TlsExtensionCapture> groups_capture_;
    376  std::shared_ptr<TlsExtensionCapture> shares_capture_;
    377  std::shared_ptr<TlsExtensionCapture> shares_capture2_;
    378  std::shared_ptr<TlsHandshakeRecorder> capture_hrr_;
    379 
    380  void EnsureKeyShareSetup();
    381  void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
    382  std::vector<SSLNamedGroup> GetGroupDetails(
    383      const std::shared_ptr<TlsExtensionCapture>& capture);
    384  std::vector<SSLNamedGroup> GetShareDetails(
    385      const std::shared_ptr<TlsExtensionCapture>& capture);
    386  void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
    387                       const std::vector<SSLNamedGroup>& expectedShares);
    388  void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
    389                       const std::vector<SSLNamedGroup>& expectedShares,
    390                       SSLNamedGroup expectedShare2);
    391 
    392 private:
    393  void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
    394                       const std::vector<SSLNamedGroup>& expectedShares,
    395                       bool expect_hrr);
    396 };
    397 
    398 class TlsKeyExchangeTest13 : public TlsKeyExchangeTest {};
    399 class TlsKeyExchangeTestPre13 : public TlsKeyExchangeTest {};
    400 
    401 }  // namespace nss_test
    402 
    403 #endif