tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

tls_agent.h (21250B)


      1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim: set ts=2 et sw=2 tw=80: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
      7 #ifndef tls_agent_h_
      8 #define tls_agent_h_
      9 
     10 #include "prio.h"
     11 #include "ssl.h"
     12 #include "sslproto.h"
     13 
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
     16 
     17 #include "nss_policy.h"
     18 #include "test_io.h"
     19 
     20 #define GTEST_HAS_RTTI 0
     21 #include "gtest/gtest.h"
     22 #include "nss_scoped_ptrs.h"
     23 #include "scoped_ptrs_ssl.h"
     24 
// Verbosity switch defined outside this header; LOGV() only prints when it is
// true.
extern bool g_ssl_gtest_verbose;
     26 
     27 namespace nss_test {
     28 
// Writes |msg| to std::cerr, prefixed with role_str() and ": ", ending with
// std::endl.  Only usable where a role_str() member is in scope.  |msg| is
// intentionally left unparenthesized so callers can pass a stream chain,
// e.g. LOG("value " << x).
#define LOG(msg) std::cerr << role_str() << ": " << msg << std::endl
// Same as LOG(), but only emits when g_ssl_gtest_verbose is true.  The
// do { ... } while (false) wrapper makes the expansion a single statement,
// so it is safe inside unbraced if/else.
#define LOGV(msg)                      \
  do {                                 \
    if (g_ssl_gtest_verbose) LOG(msg); \
  } while (false)
     34 
     35 enum SessionResumptionMode {
     36  RESUME_NONE = 0,
     37  RESUME_SESSIONID = 1,
     38  RESUME_TICKET = 2,
     39  RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET
     40 };
     41 
     42 enum class ClientAuthCallbackType {
     43  kAsyncImmediate,
     44  kAsyncDelay,
     45  kSync,
     46  kNone,
     47 };
     48 
     49 class PacketFilter;
     50 class TlsAgent;
     51 class TlsCipherSpec;
     52 struct TlsRecord;
     53 
     54 const extern std::vector<SSLNamedGroup> kAllDHEGroups;
     55 const extern std::vector<SSLNamedGroup> kNonPQDHEGroups;
     56 const extern std::vector<SSLNamedGroup> kECDHEGroups;
     57 const extern std::vector<SSLNamedGroup> kFFDHEGroups;
     58 const extern std::vector<SSLNamedGroup> kFasterDHEGroups;
     59 const extern std::vector<SSLNamedGroup> kEcdhHybridGroups;
     60 
     61 // These functions are called from callbacks.  They use bare pointers because
     62 // TlsAgent sets up the callback and it doesn't know who owns it.
     63 typedef std::function<SECStatus(TlsAgent* agent, bool checksig, bool isServer)>
     64    AuthCertificateCallbackFunction;
     65 
     66 typedef std::function<void(TlsAgent* agent)> HandshakeCallbackFunction;
     67 
     68 typedef std::function<int32_t(TlsAgent* agent, const SECItem* srvNameArr,
     69                              PRUint32 srvNameArrSize)>
     70    SniCallbackFunction;
     71 
     72 class TlsAgent : public PollTarget {
     73 public:
     74  enum Role { CLIENT, SERVER };
     75  enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED, STATE_ERROR };
     76 
     77  static const std::string kClient;     // the client key is sign only
     78  static const std::string kRsa2048;    // bigger sign and encrypt for either
     79  static const std::string kRsa8192;    // biggest sign and encrypt for either
     80  static const std::string kServerRsa;  // both sign and encrypt
     81  static const std::string kServerRsaSign;
     82  static const std::string kServerRsaPss;
     83  static const std::string kServerRsaDecrypt;
     84  static const std::string kServerEcdsa256;
     85  static const std::string kServerEcdsa384;
     86  static const std::string kServerEcdsa521;
     87  static const std::string kServerEcdhEcdsa;
     88  static const std::string kServerEcdhRsa;
     89  static const std::string kServerDsa;
     90  static const std::string kDelegatorEcdsa256;    // draft-ietf-tls-subcerts
     91  static const std::string kDelegatorRsae2048;    // draft-ietf-tls-subcerts
     92  static const std::string kDelegatorRsaPss2048;  // draft-ietf-tls-subcerts
     93 
     94  TlsAgent(const std::string& name, Role role, SSLProtocolVariant variant);
     95  virtual ~TlsAgent();
     96 
     97  void SetPeer(std::shared_ptr<TlsAgent>& peer) {
     98    adapter_->SetPeer(peer->adapter_);
     99  }
    100 
    101  void SetFilter(std::shared_ptr<PacketFilter> filter) {
    102    adapter_->SetPacketFilter(filter);
    103  }
    104  void ClearFilter() { adapter_->SetPacketFilter(nullptr); }
    105 
    106  void StartConnect(PRFileDesc* model = nullptr);
    107  void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group,
    108                size_t kea_size = 0) const;
    109  void CheckOriginalKEA(SSLNamedGroup kea_group) const;
    110  void CheckAuthType(SSLAuthType auth_type,
    111                     SSLSignatureScheme sig_scheme) const;
    112 
    113  void DisableAllCiphers();
    114  void EnableCiphersByAuthType(SSLAuthType authType);
    115  void EnableCiphersByKeyExchange(SSLKEAType kea);
    116  void EnableGroupsByKeyExchange(SSLKEAType kea);
    117  void EnableGroupsByAuthType(SSLAuthType authType);
    118  void EnableSingleCipher(uint16_t cipher);
    119 
    120  void Handshake();
    121  // Marks the internal state as CONNECTING in anticipation of renegotiation.
    122  void PrepareForRenegotiate();
    123  // Prepares for renegotiation, then actually triggers it.
    124  void StartRenegotiate();
    125  void SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx);
    126 
    127  static bool LoadCertificate(const std::string& name,
    128                              ScopedCERTCertificate* cert,
    129                              ScopedSECKEYPrivateKey* priv);
    130  static bool LoadKeyPairFromCert(const std::string& name,
    131                                  ScopedSECKEYPublicKey* pub,
    132                                  ScopedSECKEYPrivateKey* priv);
    133 
    134  // Delegated credentials.
    135  //
    136  // Generate a delegated credential and sign it using the certificate
    137  // associated with |name|.
    138  static void DelegateCredential(const std::string& name,
    139                                 const ScopedSECKEYPublicKey& dcPub,
    140                                 SSLSignatureScheme dcCertVerifyAlg,
    141                                 PRUint32 dcValidFor, PRTime now, SECItem* dc);
    142  // Indicate support for the delegated credentials extension.
    143  void EnableDelegatedCredentials();
    144  // Generate and configure a delegated credential to use in the handshake with
    145  // clients that support this extension..
    146  void AddDelegatedCredential(const std::string& dc_name,
    147                              SSLSignatureScheme dcCertVerifyAlg,
    148                              PRUint32 dcValidFor, PRTime now);
    149  void UpdatePreliminaryChannelInfo();
    150 
    151  bool ConfigServerCert(const std::string& name, bool updateKeyBits = false,
    152                        const SSLExtraServerCertData* serverCertData = nullptr);
    153  bool ConfigServerCertWithChain(const std::string& name);
    154  bool EnsureTlsSetup(PRFileDesc* modelSocket = nullptr);
    155 
    156  void SetupClientAuth(
    157      ClientAuthCallbackType callbackType = ClientAuthCallbackType::kSync,
    158      bool callbackSuccess = true);
    159  void RequestClientAuth(bool requireAuth);
    160  void ClientAuthCallbackComplete();
    161  bool CheckClientAuthCallbacksCompleted(uint8_t expected);
    162  void CheckClientAuthCompleted(uint8_t handshakes = 1);
    163  void SetOption(int32_t option, int value);
    164  void ConfigureSessionCache(SessionResumptionMode mode);
    165  void Set0RttEnabled(bool en);
    166  void SetFallbackSCSVEnabled(bool en);
    167  void SetVersionRange(uint16_t minver, uint16_t maxver);
    168  void GetVersionRange(uint16_t* minver, uint16_t* maxver);
    169  void CheckPreliminaryInfo();
    170  void ResetPreliminaryInfo();
    171  void SetExpectedVersion(uint16_t version);
    172  void SetServerKeyBits(uint16_t bits);
    173  void ExpectReadWriteError();
    174  void EnableFalseStart();
    175  void ExpectEch(bool expected = true);
    176  bool GetEchExpected() const { return expect_ech_; }
    177  void ExpectPsk(SSLPskType psk = ssl_psk_external);
    178  void ExpectResumption();
    179  void SkipVersionChecks();
    180  void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count);
    181  void EnableAlpn(const uint8_t* val, size_t len);
    182  void CheckAlpn(SSLNextProtoState expected_state,
    183                 const std::string& expected = "") const;
    184  void EnableSrtp();
    185  void CheckSrtp() const;
    186  void CheckEpochs(uint16_t expected_read, uint16_t expected_write) const;
    187  void CheckErrorCode(int32_t expected) const;
    188  void WaitForErrorCode(int32_t expected, uint32_t delay) const;
    189  // Send data on the socket, encrypting it.
    190  void SendData(size_t bytes, size_t blocksize = 1024);
    191  void SendBuffer(const DataBuffer& buf);
    192  bool SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
    193                           uint64_t seq, uint8_t ct, const DataBuffer& buf);
    194  // Send data directly to the underlying socket, skipping the TLS layer.
    195  void SendDirect(const DataBuffer& buf);
    196  void SendRecordDirect(const TlsRecord& record);
    197  void AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash,
    198              uint16_t zeroRttSuite = TLS_NULL_WITH_NULL_NULL);
    199  void RemovePsk(std::string label);
    200  void ReadBytes(size_t max = 16384U);
    201  void ResetSentBytes(size_t bytes = 0);  // Hack to test drops.
    202  void EnableExtendedMasterSecret();
    203  void CheckExtendedMasterSecret(bool expected);
    204  void CheckEarlyDataAccepted(bool expected);
    205  void CheckEchAccepted(bool expected);
    206  void SetDowngradeCheckVersion(uint16_t version);
    207  void CheckSecretsDestroyed();
    208  void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
    209  void EnableECDHEServerKeyReuse();
    210  bool GetPeerChainLength(size_t* count);
    211  void CheckPeerChainFunctionConsistency();
    212  void CheckCipherSuite(uint16_t cipher_suite);
    213  void SetResumptionTokenCallback();
    214  bool MaybeSetResumptionToken();
    215  void SetResumptionToken(const std::vector<uint8_t>& resumption_token) {
    216    resumption_token_ = resumption_token;
    217  }
    218  const std::vector<uint8_t>& GetResumptionToken() const {
    219    return resumption_token_;
    220  }
    221  void GetTokenInfo(ScopedSSLResumptionTokenInfo& token) {
    222    SECStatus rv = SSL_GetResumptionTokenInfo(
    223        resumption_token_.data(), resumption_token_.size(), token.get(),
    224        sizeof(SSLResumptionTokenInfo));
    225    ASSERT_EQ(SECSuccess, rv);
    226  }
    227  void SetResumptionCallbackCalled() { resumption_callback_called_ = true; }
    228  bool resumption_callback_called() const {
    229    return resumption_callback_called_;
    230  }
    231 
    232  const std::string& name() const { return name_; }
    233 
    234  Role role() const { return role_; }
    235  std::string role_str() const { return role_ == SERVER ? "server" : "client"; }
    236 
    237  SSLProtocolVariant variant() const { return variant_; }
    238 
    239  State state() const { return state_; }
    240 
    241  const CERTCertificate* peer_cert() const {
    242    return SSL_PeerCertificate(ssl_fd_.get());
    243  }
    244 
    245  const char* state_str() const { return state_str(state()); }
    246 
    247  static const char* state_str(State state) { return states[state]; }
    248 
    249  NssManagedFileDesc ssl_fd() const {
    250    return NssManagedFileDesc(ssl_fd_.get(), policy_, option_);
    251  }
    252  std::shared_ptr<DummyPrSocket>& adapter() { return adapter_; }
    253 
    254  const SSLChannelInfo& info() const {
    255    EXPECT_EQ(STATE_CONNECTED, state_);
    256    return info_;
    257  }
    258 
    259  const SSLPreliminaryChannelInfo& pre_info() const { return pre_info_; }
    260 
    261  bool is_compressed() const {
    262    return info().compressionMethod != ssl_compression_null;
    263  }
    264  uint16_t server_key_bits() const { return server_key_bits_; }
    265  uint16_t min_version() const { return vrange_.min; }
    266  uint16_t max_version() const { return vrange_.max; }
    267  uint16_t version() const { return info().protocolVersion; }
    268 
    269  bool cipher_suite(uint16_t* suite) const {
    270    if (state_ != STATE_CONNECTED) return false;
    271 
    272    *suite = info_.cipherSuite;
    273    return true;
    274  }
    275 
    276  void expected_cipher_suite(uint16_t suite) { expected_cipher_suite_ = suite; }
    277 
    278  std::string cipher_suite_name() const {
    279    if (state_ != STATE_CONNECTED) return "UNKNOWN";
    280 
    281    return csinfo_.cipherSuiteName;
    282  }
    283 
    284  std::vector<uint8_t> session_id() const {
    285    return std::vector<uint8_t>(info_.sessionID,
    286                                info_.sessionID + info_.sessionIDLength);
    287  }
    288 
    289  bool auth_type(SSLAuthType* a) const {
    290    if (state_ != STATE_CONNECTED) return false;
    291 
    292    *a = info_.authType;
    293    return true;
    294  }
    295 
    296  bool kea_type(SSLKEAType* k) const {
    297    if (state_ != STATE_CONNECTED) return false;
    298 
    299    *k = info_.keaType;
    300    return true;
    301  }
    302 
    303  size_t received_bytes() const { return recv_ctr_; }
    304  PRErrorCode error_code() const { return error_code_; }
    305 
    306  bool can_falsestart_hook_called() const {
    307    return can_falsestart_hook_called_;
    308  }
    309 
    310  void SetHandshakeCallback(HandshakeCallbackFunction handshake_callback) {
    311    handshake_callback_ = handshake_callback;
    312  }
    313 
    314  void SetAuthCertificateCallback(
    315      AuthCertificateCallbackFunction auth_certificate_callback) {
    316    auth_certificate_callback_ = auth_certificate_callback;
    317  }
    318 
    319  void SetSniCallback(SniCallbackFunction sni_callback) {
    320    sni_callback_ = sni_callback;
    321  }
    322 
    323  void ExpectReceiveAlert(uint8_t alert, uint8_t level = 0);
    324  void ExpectSendAlert(uint8_t alert, uint8_t level = 0);
    325 
    326  std::string alpn_value_to_use_ = "";
    327  // set the given policy before this agent runs
    328  void SetPolicy(SECOidTag oid, PRUint32 set, PRUint32 clear) {
    329    policy_ = NssPolicy(oid, set, clear);
    330  }
    331  void SetNssOption(PRInt32 id, PRInt32 value) {
    332    option_ = NssOption(id, value);
    333  }
    334 
    335 private:
    336  const static char* states[];
    337 
    338  void SetState(State state);
    339  void ValidateCipherSpecs();
    340 
    341  // Dummy auth certificate hook.
    342  static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
    343                                       PRBool checksig, PRBool isServer) {
    344    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    345    agent->CheckPreliminaryInfo();
    346    agent->auth_certificate_hook_called_ = true;
    347    if (agent->auth_certificate_callback_) {
    348      return agent->auth_certificate_callback_(agent, checksig ? true : false,
    349                                               isServer ? true : false);
    350    }
    351    return SECSuccess;
    352  }
    353 
    354  // Client auth certificate hook.
    355  static SECStatus ClientAuthenticated(void* arg, PRFileDesc* fd,
    356                                       PRBool checksig, PRBool isServer) {
    357    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    358    EXPECT_TRUE(agent->expect_client_auth_);
    359    EXPECT_EQ(PR_TRUE, isServer);
    360    if (agent->auth_certificate_callback_) {
    361      return agent->auth_certificate_callback_(agent, checksig ? true : false,
    362                                               isServer ? true : false);
    363    }
    364    return SECSuccess;
    365  }
    366 
    367  static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd,
    368                                         CERTDistNames* caNames,
    369                                         CERTCertificate** cert,
    370                                         SECKEYPrivateKey** privKey);
    371 
    372  static void ReadableCallback(PollTarget* self, Event event) {
    373    TlsAgent* agent = static_cast<TlsAgent*>(self);
    374    if (event == TIMER_EVENT) {
    375      agent->timer_handle_ = nullptr;
    376    }
    377    agent->ReadableCallback_int();
    378  }
    379 
    380  void ReadableCallback_int() {
    381    LOGV("Readable");
    382    switch (state_) {
    383      case STATE_CONNECTING:
    384        Handshake();
    385        break;
    386      case STATE_CONNECTED:
    387        ReadBytes();
    388        break;
    389      default:
    390        break;
    391    }
    392  }
    393 
    394  static PRInt32 SniHook(PRFileDesc* fd, const SECItem* srvNameArr,
    395                         PRUint32 srvNameArrSize, void* arg) {
    396    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    397    agent->CheckPreliminaryInfo();
    398    agent->sni_hook_called_ = true;
    399    EXPECT_EQ(1UL, srvNameArrSize);
    400    if (agent->sni_callback_) {
    401      return agent->sni_callback_(agent, srvNameArr, srvNameArrSize);
    402    }
    403    return 0;  // First configuration.
    404  }
    405 
    406  static SECStatus CanFalseStartCallback(PRFileDesc* fd, void* arg,
    407                                         PRBool* canFalseStart) {
    408    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    409    agent->CheckPreliminaryInfo();
    410    EXPECT_TRUE(agent->falsestart_enabled_);
    411    EXPECT_FALSE(agent->can_falsestart_hook_called_);
    412    agent->can_falsestart_hook_called_ = true;
    413    *canFalseStart = true;
    414    return SECSuccess;
    415  }
    416 
    417  void CheckAlert(bool sent, const SSLAlert* alert);
    418 
    419  static void AlertReceivedCallback(const PRFileDesc* fd, void* arg,
    420                                    const SSLAlert* alert) {
    421    reinterpret_cast<TlsAgent*>(arg)->CheckAlert(false, alert);
    422  }
    423 
    424  static void AlertSentCallback(const PRFileDesc* fd, void* arg,
    425                                const SSLAlert* alert) {
    426    reinterpret_cast<TlsAgent*>(arg)->CheckAlert(true, alert);
    427  }
    428 
    429  static void HandshakeCallback(PRFileDesc* fd, void* arg) {
    430    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    431    agent->handshake_callback_called_ = true;
    432    agent->Connected();
    433    if (agent->handshake_callback_) {
    434      agent->handshake_callback_(agent);
    435    }
    436  }
    437 
    438  void DisableLameGroups();
    439  void ConfigStrongECGroups(bool en);
    440  void ConfigAllDHGroups(bool en);
    441  void CheckCallbacks() const;
    442  void Connected();
    443 
    444  const std::string name_;
    445  SSLProtocolVariant variant_;
    446  Role role_;
    447  uint16_t server_key_bits_;
    448  std::shared_ptr<DummyPrSocket> adapter_;
    449  ScopedPRFileDesc ssl_fd_;
    450  State state_;
    451  std::shared_ptr<Poller::Timer> timer_handle_;
    452  bool falsestart_enabled_;
    453  uint16_t expected_version_;
    454  uint16_t expected_cipher_suite_;
    455  bool expect_client_auth_;
    456  bool expect_ech_;
    457  SSLPskType expect_psk_;
    458  bool can_falsestart_hook_called_;
    459  bool sni_hook_called_;
    460  bool auth_certificate_hook_called_;
    461  uint8_t expected_received_alert_;
    462  uint8_t expected_received_alert_level_;
    463  uint8_t expected_sent_alert_;
    464  uint8_t expected_sent_alert_level_;
    465  bool handshake_callback_called_;
    466  bool resumption_callback_called_;
    467  SSLChannelInfo info_;
    468  SSLPreliminaryChannelInfo pre_info_;
    469  SSLCipherSuiteInfo csinfo_;
    470  SSLVersionRange vrange_;
    471  PRErrorCode error_code_;
    472  size_t send_ctr_;
    473  size_t recv_ctr_;
    474  bool expect_readwrite_error_;
    475  HandshakeCallbackFunction handshake_callback_;
    476  AuthCertificateCallbackFunction auth_certificate_callback_;
    477  SniCallbackFunction sni_callback_;
    478  bool skip_version_checks_;
    479  std::vector<uint8_t> resumption_token_;
    480  NssPolicy policy_;
    481  NssOption option_;
    482  ClientAuthCallbackType client_auth_callback_type_ =
    483      ClientAuthCallbackType::kNone;
    484  bool client_auth_callback_success_ = false;
    485  uint8_t client_auth_callback_fired_ = 0;
    486  bool client_auth_callback_awaiting_ = false;
    487 };
    488 
    489 inline std::ostream& operator<<(std::ostream& stream,
    490                                const TlsAgent::State& state) {
    491  return stream << TlsAgent::state_str(state);
    492 }
    493 
    494 class TlsAgentTestBase : public ::testing::Test {
    495 public:
    496  static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll;
    497 
    498  TlsAgentTestBase(TlsAgent::Role role, SSLProtocolVariant variant,
    499                   uint16_t version = 0)
    500      : agent_(nullptr),
    501        role_(role),
    502        variant_(variant),
    503        version_(version),
    504        sink_adapter_(new DummyPrSocket("sink", variant)) {}
    505  virtual ~TlsAgentTestBase() {}
    506 
    507  void SetUp();
    508  void TearDown();
    509 
    510  void ExpectAlert(uint8_t alert);
    511 
    512  static void MakeRecord(SSLProtocolVariant variant, uint8_t type,
    513                         uint16_t version, const uint8_t* buf, size_t len,
    514                         DataBuffer* out, uint64_t seq_num = 0);
    515  void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf,
    516                  size_t len, DataBuffer* out, uint64_t seq_num = 0) const;
    517  void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len,
    518                            DataBuffer* out, uint64_t seq_num = 0) const;
    519  void MakeHandshakeMessageFragment(uint8_t hs_type, const uint8_t* data,
    520                                    size_t hs_len, DataBuffer* out,
    521                                    uint64_t seq_num, uint32_t fragment_offset,
    522                                    uint32_t fragment_length) const;
    523  DataBuffer MakeCannedTls13ServerHello();
    524  static void MakeTrivialHandshakeRecord(uint8_t hs_type, size_t hs_len,
    525                                         DataBuffer* out);
    526  static inline TlsAgent::Role ToRole(const std::string& str) {
    527    return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER;
    528  }
    529 
    530  void Init(const std::string& server_name = TlsAgent::kServerRsa);
    531  void Reset(const std::string& server_name = TlsAgent::kServerRsa);
    532 
    533 protected:
    534  void EnsureInit();
    535  void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state,
    536                      int32_t error_code = 0);
    537 
    538  std::shared_ptr<TlsAgent> agent_;
    539  TlsAgent::Role role_;
    540  SSLProtocolVariant variant_;
    541  uint16_t version_;
    542  // This adapter is here just to accept packets from this agent.
    543  std::shared_ptr<DummyPrSocket> sink_adapter_;
    544 };
    545 
    546 class TlsAgentTest
    547    : public TlsAgentTestBase,
    548      public ::testing::WithParamInterface<
    549          std::tuple<std::string, SSLProtocolVariant, uint16_t>> {
    550 public:
    551  TlsAgentTest()
    552      : TlsAgentTestBase(ToRole(std::get<0>(GetParam())),
    553                         std::get<1>(GetParam()), std::get<2>(GetParam())) {}
    554 };
    555 
    556 class TlsAgentTestClient : public TlsAgentTestBase,
    557                           public ::testing::WithParamInterface<
    558                               std::tuple<SSLProtocolVariant, uint16_t>> {
    559 public:
    560  TlsAgentTestClient()
    561      : TlsAgentTestBase(TlsAgent::CLIENT, std::get<0>(GetParam()),
    562                         std::get<1>(GetParam())) {}
    563 };
    564 
    565 class TlsAgentTestClient13 : public TlsAgentTestClient {};
    566 
    567 class TlsAgentStreamTestClient13 : public TlsAgentTestClient {
    568 public:
    569  TlsAgentStreamTestClient13() { variant_ = ssl_variant_stream; }
    570 };
    571 
    572 class TlsAgentStreamTestClient : public TlsAgentTestBase {
    573 public:
    574  TlsAgentStreamTestClient()
    575      : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_stream) {}
    576 };
    577 
    578 class TlsAgentStreamTestServer : public TlsAgentTestBase {
    579 public:
    580  TlsAgentStreamTestServer()
    581      : TlsAgentTestBase(TlsAgent::SERVER, ssl_variant_stream) {}
    582 };
    583 
    584 class TlsAgentDgramTestClient : public TlsAgentTestBase {
    585 public:
    586  TlsAgentDgramTestClient()
    587      : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_datagram) {}
    588 };
    589 
    590 inline bool operator==(const SSLVersionRange& vr1, const SSLVersionRange& vr2) {
    591  return vr1.min == vr2.min && vr1.max == vr2.max;
    592 }
    593 
    594 }  // namespace nss_test
    595 
    596 #endif