tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

transport_unittests.cpp (38747B)


      1 /* vim: set ts=2 et sw=2 tw=80: */
      2 /* This Source Code Form is subject to the terms of the Mozilla Public
      3 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      4 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      5 
      6 // Original author: ekr@rtfm.com
      7 
      8 #include <algorithm>
      9 #include <functional>
     10 #include <iostream>
     11 #include <string>
     12 
     13 #ifdef XP_MACOSX
     14 // ensure that Apple Security kit enum goes before "sslproto.h"
     15 #  include <CoreFoundation/CFAvailability.h>
     16 #  include <Security/CipherSuite.h>
     17 #endif
     18 
     19 #include "dtlsidentity.h"
     20 #include "logging.h"
     21 #include "mediapacket.h"
     22 #include "mozilla/UniquePtr.h"
     23 #include "nricectx.h"
     24 #include "nricemediastream.h"
     25 #include "nsThreadUtils.h"
     26 #include "runnable_utils.h"
     27 #include "sigslot.h"
     28 #include "ssl.h"
     29 #include "sslexp.h"
     30 #include "sslproto.h"
     31 #include "transportflow.h"
     32 #include "transportlayer.h"
     33 #include "transportlayerdtls.h"
     34 #include "transportlayerice.h"
     35 #include "transportlayerlog.h"
     36 #include "transportlayerloopback.h"
     37 
     38 #define GTEST_HAS_RTTI 0
     39 #include "gtest/gtest.h"
     40 #include "gtest_utils.h"
     41 
     42 using namespace mozilla;
     43 MOZ_MTLOG_MODULE("mtransport")
     44 
     45 const uint8_t kTlsChangeCipherSpecType = 0x14;
     46 const uint8_t kTlsHandshakeType = 0x16;
     47 
     48 const uint8_t kTlsHandshakeCertificate = 0x0b;
     49 const uint8_t kTlsHandshakeServerKeyExchange = 0x0c;
     50 
     51 const uint8_t kTlsFakeChangeCipherSpec[] = {
     52    kTlsChangeCipherSpecType,  // Type
     53    0xfe,
     54    0xff,  // Version
     55    0x00,
     56    0x00,
     57    0x00,
     58    0x00,
     59    0x00,
     60    0x00,
     61    0x00,
     62    0x10,  // Fictitious sequence #
     63    0x00,
     64    0x01,  // Length
     65    0x01   // Value
     66 };
     67 
     68 // Layer class which can't be initialized.
     69 class TransportLayerDummy : public TransportLayer {
     70 public:
     71  TransportLayerDummy(bool allow_init, bool* destroyed)
     72      : allow_init_(allow_init), destroyed_(destroyed) {
     73    *destroyed_ = false;
     74  }
     75 
     76  virtual ~TransportLayerDummy() { *destroyed_ = true; }
     77 
     78  nsresult InitInternal() override {
     79    return allow_init_ ? NS_OK : NS_ERROR_FAILURE;
     80  }
     81 
     82  TransportResult SendPacket(MediaPacket& packet) override {
     83    MOZ_CRASH();  // Should never be called.
     84    return 0;
     85  }
     86 
     87  TRANSPORT_LAYER_ID("lossy")
     88 
     89 private:
     90  bool allow_init_;
     91  bool* destroyed_;
     92 };
     93 
     94 class Inspector {
     95 public:
     96  virtual ~Inspector() = default;
     97 
     98  virtual void Inspect(TransportLayer* layer, const unsigned char* data,
     99                       size_t len) = 0;
    100 };
    101 
    102 // Class to simulate various kinds of network lossage
    103 class TransportLayerLossy : public TransportLayer {
    104 public:
    105  TransportLayerLossy() : loss_mask_(0), packet_(0), inspector_(nullptr) {}
    106  ~TransportLayerLossy() = default;
    107 
    108  TransportResult SendPacket(MediaPacket& packet) override {
    109    MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "SendPacket(" << packet.len() << ")");
    110 
    111    if (loss_mask_ & (1 << (packet_ % 32))) {
    112      MOZ_MTLOG(ML_NOTICE, "Dropping packet");
    113      ++packet_;
    114      return packet.len();
    115    }
    116    if (inspector_) {
    117      inspector_->Inspect(this, packet.data(), packet.len());
    118    }
    119 
    120    ++packet_;
    121 
    122    return downward_->SendPacket(packet);
    123  }
    124 
    125  void SetLoss(uint32_t packet) { loss_mask_ |= (1 << (packet & 32)); }
    126 
    127  void SetInspector(UniquePtr<Inspector> inspector) {
    128    inspector_ = std::move(inspector);
    129  }
    130 
    131  void StateChange(TransportLayer* layer, State state) { TL_SET_STATE(state); }
    132 
    133  void PacketReceived(TransportLayer* layer, MediaPacket& packet) {
    134    SignalPacketReceived(this, packet);
    135  }
    136 
    137  TRANSPORT_LAYER_ID("lossy")
    138 
    139 protected:
    140  void WasInserted() override {
    141    downward_->SignalPacketReceived.connect(
    142        this, &TransportLayerLossy::PacketReceived);
    143    downward_->SignalStateChange.connect(this,
    144                                         &TransportLayerLossy::StateChange);
    145 
    146    TL_SET_STATE(downward_->state());
    147  }
    148 
    149 private:
    150  uint32_t loss_mask_;
    151  uint32_t packet_;
    152  UniquePtr<Inspector> inspector_;
    153 };
    154 
    155 // Process DTLS Records
    156 #define CHECK_LENGTH(expected)                \
    157  do {                                        \
    158    EXPECT_GE(remaining(), expected);         \
    159    if (remaining() < expected) return false; \
    160  } while (0)
    161 
    162 class TlsParser {
    163 public:
    164  TlsParser(const unsigned char* data, size_t len) : offset_(0) {
    165    buffer_.Copy(data, len);
    166  }
    167 
    168  bool Read(unsigned char* val) {
    169    if (remaining() < 1) {
    170      return false;
    171    }
    172    *val = *ptr();
    173    consume(1);
    174    return true;
    175  }
    176 
    177  // Read an integral type of specified width.
    178  bool Read(uint32_t* val, size_t len) {
    179    if (len > sizeof(uint32_t)) return false;
    180 
    181    *val = 0;
    182 
    183    for (size_t i = 0; i < len; ++i) {
    184      unsigned char tmp;
    185 
    186      if (!Read(&tmp)) return false;
    187 
    188      (*val) = ((*val) << 8) + tmp;
    189    }
    190 
    191    return true;
    192  }
    193 
    194  bool Read(unsigned char* val, size_t len) {
    195    if (remaining() < len) {
    196      return false;
    197    }
    198 
    199    if (val) {
    200      memcpy(val, ptr(), len);
    201    }
    202    consume(len);
    203 
    204    return true;
    205  }
    206 
    207 private:
    208  size_t remaining() const { return buffer_.len() - offset_; }
    209  const uint8_t* ptr() const { return buffer_.data() + offset_; }
    210  void consume(size_t len) { offset_ += len; }
    211 
    212  MediaPacket buffer_;
    213  size_t offset_;
    214 };
    215 
    216 class DtlsRecordParser {
    217 public:
    218  DtlsRecordParser(const unsigned char* data, size_t len) : offset_(0) {
    219    buffer_.Copy(data, len);
    220  }
    221 
    222  bool NextRecord(uint8_t* ct, UniquePtr<MediaPacket>* buffer) {
    223    if (!remaining()) return false;
    224 
    225    CHECK_LENGTH(13U);
    226    const uint8_t* ctp = reinterpret_cast<const uint8_t*>(ptr());
    227    consume(11);  // ct + version + length
    228 
    229    const uint16_t* tmp = reinterpret_cast<const uint16_t*>(ptr());
    230    size_t length = ntohs(*tmp);
    231    consume(2);
    232 
    233    CHECK_LENGTH(length);
    234    auto db = MakeUnique<MediaPacket>();
    235    db->Copy(ptr(), length);
    236    consume(length);
    237 
    238    *ct = *ctp;
    239    *buffer = std::move(db);
    240 
    241    return true;
    242  }
    243 
    244 private:
    245  size_t remaining() const { return buffer_.len() - offset_; }
    246  const uint8_t* ptr() const { return buffer_.data() + offset_; }
    247  void consume(size_t len) { offset_ += len; }
    248 
    249  MediaPacket buffer_;
    250  size_t offset_;
    251 };
    252 
    253 // Inspector that parses out DTLS records and passes
    254 // them on.
    255 class DtlsRecordInspector : public Inspector {
    256 public:
    257  virtual void Inspect(TransportLayer* layer, const unsigned char* data,
    258                       size_t len) {
    259    DtlsRecordParser parser(data, len);
    260 
    261    uint8_t ct;
    262    UniquePtr<MediaPacket> buf;
    263    while (parser.NextRecord(&ct, &buf)) {
    264      OnRecord(layer, ct, buf->data(), buf->len());
    265    }
    266  }
    267 
    268  virtual void OnRecord(TransportLayer* layer, uint8_t content_type,
    269                        const unsigned char* record, size_t len) = 0;
    270 };
    271 
    272 // Inspector that injects arbitrary packets based on
    273 // DTLS records of various types.
    274 class DtlsInspectorInjector : public DtlsRecordInspector {
    275 public:
    276  DtlsInspectorInjector(uint8_t packet_type, uint8_t handshake_type,
    277                        const unsigned char* data, size_t len)
    278      : packet_type_(packet_type), handshake_type_(handshake_type) {
    279    packet_.Copy(data, len);
    280  }
    281 
    282  virtual void OnRecord(TransportLayer* layer, uint8_t content_type,
    283                        const unsigned char* data, size_t len) {
    284    // Only inject once.
    285    if (!packet_.data()) {
    286      return;
    287    }
    288 
    289    // Check that the first byte is as requested.
    290    if (content_type != packet_type_) {
    291      return;
    292    }
    293 
    294    if (handshake_type_ != 0xff) {
    295      // Check that the packet is plausibly long enough.
    296      if (len < 1) {
    297        return;
    298      }
    299 
    300      // Check that the handshake type is as requested.
    301      if (data[0] != handshake_type_) {
    302        return;
    303      }
    304    }
    305 
    306    layer->SendPacket(packet_);
    307    packet_.Reset();
    308  }
    309 
    310 private:
    311  uint8_t packet_type_;
    312  uint8_t handshake_type_;
    313  MediaPacket packet_;
    314 };
    315 
    316 // Make a copy of the first instance of a message.
    317 class DtlsInspectorRecordHandshakeMessage : public DtlsRecordInspector {
    318 public:
    319  explicit DtlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
    320      : handshake_type_(handshake_type) {}
    321 
    322  virtual void OnRecord(TransportLayer* layer, uint8_t content_type,
    323                        const unsigned char* data, size_t len) {
    324    // Only do this once.
    325    if (buffer_.len()) {
    326      return;
    327    }
    328 
    329    // Check that the first byte is as requested.
    330    if (content_type != kTlsHandshakeType) {
    331      return;
    332    }
    333 
    334    TlsParser parser(data, len);
    335    unsigned char message_type;
    336    // Read the handshake message type.
    337    if (!parser.Read(&message_type)) {
    338      return;
    339    }
    340    if (message_type != handshake_type_) {
    341      return;
    342    }
    343 
    344    uint32_t length;
    345    if (!parser.Read(&length, 3)) {
    346      return;
    347    }
    348 
    349    uint32_t message_seq;
    350    if (!parser.Read(&message_seq, 2)) {
    351      return;
    352    }
    353 
    354    uint32_t fragment_offset;
    355    if (!parser.Read(&fragment_offset, 3)) {
    356      return;
    357    }
    358 
    359    uint32_t fragment_length;
    360    if (!parser.Read(&fragment_length, 3)) {
    361      return;
    362    }
    363 
    364    if ((fragment_offset != 0) || (fragment_length != length)) {
    365      // This shouldn't happen because all current tests where we
    366      // are using this code don't fragment.
    367      return;
    368    }
    369 
    370    UniquePtr<uint8_t[]> buffer(new uint8_t[length]);
    371    if (!parser.Read(buffer.get(), length)) {
    372      return;
    373    }
    374    buffer_.Take(std::move(buffer), length);
    375  }
    376 
    377  const MediaPacket& buffer() { return buffer_; }
    378 
    379 private:
    380  uint8_t handshake_type_;
    381  MediaPacket buffer_;
    382 };
    383 
    384 class TlsServerKeyExchangeECDHE {
    385 public:
    386  bool Parse(const unsigned char* data, size_t len) {
    387    TlsParser parser(data, len);
    388 
    389    uint8_t curve_type;
    390    if (!parser.Read(&curve_type)) {
    391      return false;
    392    }
    393 
    394    if (curve_type != 3) {  // named_curve
    395      return false;
    396    }
    397 
    398    uint32_t named_curve;
    399    if (!parser.Read(&named_curve, 2)) {
    400      return false;
    401    }
    402 
    403    uint32_t point_length;
    404    if (!parser.Read(&point_length, 1)) {
    405      return false;
    406    }
    407 
    408    UniquePtr<uint8_t[]> key(new uint8_t[point_length]);
    409    if (!parser.Read(key.get(), point_length)) {
    410      return false;
    411    }
    412    public_key_.Take(std::move(key), point_length);
    413 
    414    return true;
    415  }
    416 
    417  MediaPacket public_key_;
    418 };
    419 
    420 namespace {
    421 class TransportTestPeer : public sigslot::has_slots<> {
    422 public:
    423  TransportTestPeer(nsCOMPtr<nsIEventTarget> target, std::string name,
    424                    MtransportTestUtils* utils)
    425      : name_(name),
    426        offerer_(name == "P1"),
    427        target_(target),
    428        received_packets_(0),
    429        received_bytes_(0),
    430        flow_(new TransportFlow(name)),
    431        loopback_(new TransportLayerLoopback()),
    432        logging_(new TransportLayerLogging()),
    433        lossy_(new TransportLayerLossy()),
    434        dtls_(new TransportLayerDtls()),
    435        identity_(DtlsIdentity::Generate()),
    436        peer_(nullptr),
    437        gathering_complete_(false),
    438        digest_("sha-1"_ns),
    439        test_utils_(utils) {
    440    NrIceCtx::InitializeGlobals(NrIceCtx::GlobalConfig());
    441    ice_ctx_ = NrIceCtx::Create(name);
    442    std::vector<NrIceStunServer> stun_servers;
    443    UniquePtr<NrIceStunServer> server(NrIceStunServer::Create(
    444        std::string((char*)"stun.services.mozilla.com"), 3478));
    445    stun_servers.push_back(*server);
    446    EXPECT_TRUE(NS_SUCCEEDED(ice_ctx_->SetStunServers(stun_servers)));
    447 
    448    dtls_->SetIdentity(identity_);
    449    dtls_->SetRole(offerer_ ? TransportLayerDtls::SERVER
    450                            : TransportLayerDtls::CLIENT);
    451 
    452    nsresult res = identity_->ComputeFingerprint(&digest_);
    453    EXPECT_TRUE(NS_SUCCEEDED(res));
    454    EXPECT_EQ(20u, digest_.value_.size());
    455  }
    456 
    457  ~TransportTestPeer() {
    458    test_utils_->SyncDispatchToSTS(
    459        WrapRunnable(this, &TransportTestPeer::DestroyFlow));
    460  }
    461 
    462  void DestroyFlow() {
    463    disconnect_all();
    464    if (flow_) {
    465      loopback_->Disconnect();
    466      flow_ = nullptr;
    467    }
    468    ice_ctx_->Destroy();
    469    ice_ctx_ = nullptr;
    470    streams_.clear();
    471  }
    472 
    473  void DisconnectDestroyFlow() {
    474    test_utils_->SyncDispatchToSTS(NS_NewRunnableFunction(__func__, [this] {
    475      loopback_->Disconnect();
    476      disconnect_all();  // Disconnect from the signals;
    477      flow_ = nullptr;
    478    }));
    479  }
    480 
    481  void SetDtlsAllowAll() {
    482    nsresult res = dtls_->SetVerificationAllowAll();
    483    ASSERT_TRUE(NS_SUCCEEDED(res));
    484  }
    485 
    486  void SetAlpn(std::string str, bool withDefault, std::string extra = "") {
    487    std::set<std::string> alpn;
    488    alpn.insert(str);  // the one we want to select
    489    if (!extra.empty()) {
    490      alpn.insert(extra);
    491    }
    492    nsresult res = dtls_->SetAlpn(alpn, withDefault ? str : "");
    493    ASSERT_EQ(NS_OK, res);
    494  }
    495 
    496  const std::string& GetAlpn() const { return dtls_->GetNegotiatedAlpn(); }
    497 
    498  void SetDtlsPeer(TransportTestPeer* peer, int digests, unsigned int damage) {
    499    unsigned int mask = 1;
    500 
    501    for (int i = 0; i < digests; i++) {
    502      DtlsDigest digest_to_set(peer->digest_);
    503 
    504      if (damage & mask) digest_to_set.value_.data()[0]++;
    505 
    506      nsresult res = dtls_->SetVerificationDigest(digest_to_set);
    507 
    508      ASSERT_TRUE(NS_SUCCEEDED(res));
    509 
    510      mask <<= 1;
    511    }
    512  }
    513 
    514  void SetupSrtp() {
    515    std::vector<uint16_t> srtp_ciphers =
    516        TransportLayerDtls::GetDefaultSrtpCiphers();
    517    SetSrtpCiphers(srtp_ciphers);
    518  }
    519 
    520  void SetSrtpCiphers(std::vector<uint16_t>& srtp_ciphers) {
    521    ASSERT_TRUE(NS_SUCCEEDED(dtls_->SetSrtpCiphers(srtp_ciphers)));
    522  }
    523 
    524  void ConnectSocket_s(TransportTestPeer* peer) {
    525    nsresult res;
    526    res = loopback_->Init();
    527    ASSERT_EQ((nsresult)NS_OK, res);
    528 
    529    loopback_->Connect(peer->loopback_);
    530    ASSERT_EQ((nsresult)NS_OK, loopback_->Init());
    531    ASSERT_EQ((nsresult)NS_OK, logging_->Init());
    532    ASSERT_EQ((nsresult)NS_OK, lossy_->Init());
    533    ASSERT_EQ((nsresult)NS_OK, dtls_->Init());
    534    dtls_->Chain(lossy_);
    535    lossy_->Chain(logging_);
    536    logging_->Chain(loopback_);
    537 
    538    flow_->PushLayer(loopback_);
    539    flow_->PushLayer(logging_);
    540    flow_->PushLayer(lossy_);
    541    flow_->PushLayer(dtls_);
    542 
    543    if (dtls_->state() != TransportLayer::TS_ERROR) {
    544      // Don't execute these blocks if DTLS didn't initialize.
    545      TweakCiphers(dtls_->internal_fd());
    546      if (post_setup_) {
    547        post_setup_(dtls_->internal_fd());
    548      }
    549    }
    550 
    551    dtls_->SignalPacketReceived.connect(this,
    552                                        &TransportTestPeer::PacketReceived);
    553  }
    554 
    555  void TweakCiphers(PRFileDesc* fd) {
    556    for (unsigned short& enabled_cipersuite : enabled_cipersuites_) {
    557      SSL_CipherPrefSet(fd, enabled_cipersuite, PR_TRUE);
    558    }
    559    for (unsigned short& disabled_cipersuite : disabled_cipersuites_) {
    560      SSL_CipherPrefSet(fd, disabled_cipersuite, PR_FALSE);
    561    }
    562  }
    563 
    564  void ConnectSocket(TransportTestPeer* peer) {
    565    test_utils_->SyncDispatchToSTS(
    566        WrapRunnable(this, &TransportTestPeer::ConnectSocket_s, peer));
    567  }
    568 
    569  nsresult InitIce_s() {
    570    nsresult rv = ice_->Init();
    571    NS_ENSURE_SUCCESS(rv, rv);
    572    rv = dtls_->Init();
    573    NS_ENSURE_SUCCESS(rv, rv);
    574    dtls_->Chain(ice_);
    575    flow_->PushLayer(ice_);
    576    flow_->PushLayer(dtls_);
    577    return NS_OK;
    578  }
    579 
    580  void InitIce() {
    581    nsresult res;
    582 
    583    char name[100];
    584    snprintf(name, sizeof(name), "%s:stream%d", name_.c_str(),
    585             (int)streams_.size());
    586 
    587    // Create the media stream
    588    RefPtr<NrIceMediaStream> stream = ice_ctx_->CreateStream(name, name, 1);
    589    // Attach our slots
    590    stream->SignalGatheringStateChange.connect(
    591        this, &TransportTestPeer::GatheringStateChange);
    592 
    593    ASSERT_TRUE(stream != nullptr);
    594    stream->SetIceCredentials("ufrag", "pass");
    595    streams_.push_back(stream);
    596 
    597    // Listen for candidates
    598    stream->SignalCandidate.connect(this, &TransportTestPeer::GotCandidate);
    599 
    600    // Create the transport layer
    601    ice_ = new TransportLayerIce();
    602    ice_->SetParameters(stream, 1);
    603 
    604    test_utils_->SyncDispatchToSTS(
    605        WrapRunnableRet(&res, this, &TransportTestPeer::InitIce_s));
    606 
    607    ASSERT_EQ((nsresult)NS_OK, res);
    608 
    609    // Listen for media events
    610    dtls_->SignalPacketReceived.connect(this,
    611                                        &TransportTestPeer::PacketReceived);
    612    dtls_->SignalStateChange.connect(this, &TransportTestPeer::StateChanged);
    613 
    614    // Start gathering
    615    test_utils_->SyncDispatchToSTS(WrapRunnableRet(
    616        &res, ice_ctx_, &NrIceCtx::StartGathering, false, false));
    617    ASSERT_TRUE(NS_SUCCEEDED(res));
    618  }
    619 
    620  void ConnectIce(TransportTestPeer* peer) {
    621    peer_ = peer;
    622 
    623    // If gathering is already complete, push the candidates over
    624    if (gathering_complete_) GatheringComplete();
    625  }
    626 
    627  // New candidate
    628  void GotCandidate(NrIceMediaStream* stream, const std::string& candidate,
    629                    const std::string& ufrag, const std::string& mdns_addr,
    630                    const std::string& actual_addr) {
    631    std::cerr << "Got candidate " << candidate << " (ufrag=" << ufrag << ")"
    632              << std::endl;
    633  }
    634 
    635  void GatheringStateChange(const std::string& aTransportId,
    636                            NrIceMediaStream::GatheringState state) {
    637    // We only use one stream, no need to check whether all streams are done
    638    // gathering.
    639    (void)aTransportId;
    640    if (state == NrIceMediaStream::ICE_STREAM_GATHER_COMPLETE) {
    641      GatheringComplete();
    642    }
    643  }
    644 
    645  // Gathering complete, so send our candidates and start
    646  // connecting on the other peer.
    647  void GatheringComplete() {
    648    // Don't send to the other side
    649    if (!peer_) {
    650      gathering_complete_ = true;
    651      return;
    652    }
    653 
    654    test_utils_->SyncDispatchToSTS(
    655        WrapRunnable(this, &TransportTestPeer::GatheringComplete_s));
    656  }
    657 
    658  void GatheringComplete_s() {
    659    // First send attributes
    660    nsresult res =
    661        peer_->ice_ctx_->ParseGlobalAttributes(ice_ctx_->GetGlobalAttributes());
    662    ASSERT_TRUE(NS_SUCCEEDED(res));
    663 
    664    for (size_t i = 0; i < streams_.size(); ++i) {
    665      res = peer_->streams_[i]->ConnectToPeer("ufrag", "pass",
    666                                              streams_[i]->GetAttributes());
    667      ASSERT_TRUE(NS_SUCCEEDED(res));
    668    }
    669 
    670    // Start checks on the other peer.
    671    res = peer_->ice_ctx_->StartChecks();
    672    ASSERT_TRUE(NS_SUCCEEDED(res));
    673  }
    674 
    675  // WrapRunnable/lambda and move semantics (MediaPacket is not copyable) don't
    676  // get along yet, so we need a wrapper. Gross.
    677  static TransportResult SendPacketWrapper(TransportLayer* layer,
    678                                           MediaPacket* packet) {
    679    return layer->SendPacket(*packet);
    680  }
    681 
    682  TransportResult SendPacket(MediaPacket& packet) {
    683    TransportResult ret;
    684 
    685    test_utils_->SyncDispatchToSTS(WrapRunnableNMRet(
    686        &ret, &TransportTestPeer::SendPacketWrapper, dtls_, &packet));
    687 
    688    return ret;
    689  }
    690 
    691  void StateChanged(TransportLayer* layer, TransportLayer::State state) {
    692    if (state == TransportLayer::TS_OPEN) {
    693      std::cerr << "Now connected" << std::endl;
    694    }
    695  }
    696 
    697  void PacketReceived(TransportLayer* layer, MediaPacket& packet) {
    698    std::cerr << "Received " << packet.len() << " bytes" << std::endl;
    699    ++received_packets_;
    700    received_bytes_ += packet.len();
    701  }
    702 
    703  void SetLoss(uint32_t loss) { lossy_->SetLoss(loss); }
    704 
    705  void SetCombinePackets(bool combine) { loopback_->CombinePackets(combine); }
    706 
    707  void SetInspector(UniquePtr<Inspector> inspector) {
    708    lossy_->SetInspector(std::move(inspector));
    709  }
    710 
    711  void SetInspector(Inspector* in) {
    712    UniquePtr<Inspector> inspector(in);
    713 
    714    lossy_->SetInspector(std::move(inspector));
    715  }
    716 
    717  void SetCipherSuiteChanges(const std::vector<uint16_t>& enableThese,
    718                             const std::vector<uint16_t>& disableThese) {
    719    disabled_cipersuites_ = disableThese;
    720    enabled_cipersuites_ = enableThese;
    721  }
    722 
    723  void SetPostSetup(const std::function<void(PRFileDesc*)>& setup) {
    724    post_setup_ = std::move(setup);
    725  }
    726 
    727  TransportLayer::State state() {
    728    TransportLayer::State tstate;
    729 
    730    RUN_ON_THREAD(test_utils_->sts_target(),
    731                  WrapRunnableRet(&tstate, dtls_, &TransportLayer::state));
    732 
    733    return tstate;
    734  }
    735 
    736  bool connected() { return state() == TransportLayer::TS_OPEN; }
    737 
    738  bool failed() { return state() == TransportLayer::TS_ERROR; }
    739 
    740  size_t receivedPackets() { return received_packets_; }
    741 
    742  size_t receivedBytes() { return received_bytes_; }
    743 
    744  uint16_t cipherSuite() const {
    745    nsresult rv;
    746    uint16_t cipher;
    747    RUN_ON_THREAD(
    748        test_utils_->sts_target(),
    749        WrapRunnableRet(&rv, dtls_, &TransportLayerDtls::GetCipherSuite,
    750                        &cipher));
    751 
    752    if (NS_FAILED(rv)) {
    753      return TLS_NULL_WITH_NULL_NULL;  // i.e., not good
    754    }
    755    return cipher;
    756  }
    757 
    758  uint16_t srtpCipher() const {
    759    nsresult rv;
    760    uint16_t cipher;
    761    RUN_ON_THREAD(test_utils_->sts_target(),
    762                  WrapRunnableRet(&rv, dtls_,
    763                                  &TransportLayerDtls::GetSrtpCipher, &cipher));
    764    if (NS_FAILED(rv)) {
    765      return 0;  // the SRTP equivalent of TLS_NULL_WITH_NULL_NULL
    766    }
    767    return cipher;
    768  }
    769 
    770 private:
    771  std::string name_;
    772  bool offerer_;
    773  nsCOMPtr<nsIEventTarget> target_;
    774  std::atomic<size_t> received_packets_;
    775  std::atomic<size_t> received_bytes_;
    776  RefPtr<TransportFlow> flow_;
    777  TransportLayerLoopback* loopback_;
    778  TransportLayerLogging* logging_;
    779  TransportLayerLossy* lossy_;
    780  TransportLayerDtls* dtls_;
    781  TransportLayerIce* ice_;
    782  RefPtr<DtlsIdentity> identity_;
    783  RefPtr<NrIceCtx> ice_ctx_;
    784  std::vector<RefPtr<NrIceMediaStream> > streams_;
    785  TransportTestPeer* peer_;
    786  bool gathering_complete_;
    787  DtlsDigest digest_;
    788  std::vector<uint16_t> enabled_cipersuites_;
    789  std::vector<uint16_t> disabled_cipersuites_;
    790  MtransportTestUtils* test_utils_;
    791  std::function<void(PRFileDesc* fd)> post_setup_ = nullptr;
    792 };
    793 
    794 class TransportTest : public MtransportTest {
    795 public:
    796  TransportTest() {
    797    fds_[0] = nullptr;
    798    fds_[1] = nullptr;
    799    p1_ = nullptr;
    800    p2_ = nullptr;
    801  }
    802 
    803  void TearDown() override {
    804    delete p1_;
    805    delete p2_;
    806 
    807    //    Can't detach these
    808    //    PR_Close(fds_[0]);
    809    //    PR_Close(fds_[1]);
    810    MtransportTest::TearDown();
    811  }
    812 
    813  void DestroyPeerFlows() {
    814    p1_->DisconnectDestroyFlow();
    815    p2_->DisconnectDestroyFlow();
    816  }
    817 
    818  void SetUp() override {
    819    MtransportTest::SetUp();
    820 
    821    nsresult rv;
    822    target_ = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv);
    823    ASSERT_TRUE(NS_SUCCEEDED(rv));
    824 
    825    Reset();
    826  }
    827 
    828  void Reset() {
    829    if (p1_) {
    830      delete p1_;
    831    }
    832    if (p2_) {
    833      delete p2_;
    834    }
    835    p1_ = new TransportTestPeer(target_, "P1", test_utils_);
    836    p2_ = new TransportTestPeer(target_, "P2", test_utils_);
    837  }
    838 
    839  void SetupSrtp() {
    840    p1_->SetupSrtp();
    841    p2_->SetupSrtp();
    842  }
    843 
    844  void SetDtlsPeer(int digests = 1, unsigned int damage = 0) {
    845    p1_->SetDtlsPeer(p2_, digests, damage);
    846    p2_->SetDtlsPeer(p1_, digests, damage);
    847  }
    848 
    849  void SetDtlsAllowAll() {
    850    p1_->SetDtlsAllowAll();
    851    p2_->SetDtlsAllowAll();
    852  }
    853 
    854  void SetAlpn(std::string first, std::string second,
    855               bool withDefaults = true) {
    856    if (!first.empty()) {
    857      p1_->SetAlpn(first, withDefaults, "bogus");
    858    }
    859    if (!second.empty()) {
    860      p2_->SetAlpn(second, withDefaults);
    861    }
    862  }
    863 
    864  void CheckAlpn(std::string first, std::string second) {
    865    ASSERT_EQ(first, p1_->GetAlpn());
    866    ASSERT_EQ(second, p2_->GetAlpn());
    867  }
    868 
    869  void ConnectSocket() {
    870    ConnectSocketInternal();
    871    ASSERT_TRUE_WAIT(p1_->connected(), 10000);
    872    ASSERT_TRUE_WAIT(p2_->connected(), 10000);
    873 
    874    ASSERT_EQ(p1_->cipherSuite(), p2_->cipherSuite());
    875    ASSERT_EQ(p1_->srtpCipher(), p2_->srtpCipher());
    876  }
    877 
    878  void ConnectSocketExpectFail() {
    879    ConnectSocketInternal();
    880    ASSERT_TRUE_WAIT(p1_->failed(), 10000);
    881    ASSERT_TRUE_WAIT(p2_->failed(), 10000);
    882  }
    883 
    884  void ConnectSocketExpectState(TransportLayer::State s1,
    885                                TransportLayer::State s2) {
    886    ConnectSocketInternal();
    887    ASSERT_EQ_WAIT(s1, p1_->state(), 10000);
    888    ASSERT_EQ_WAIT(s2, p2_->state(), 10000);
    889  }
    890 
    891  void ConnectIce() {
    892    p1_->InitIce();
    893    p2_->InitIce();
    894    p1_->ConnectIce(p2_);
    895    p2_->ConnectIce(p1_);
    896    ASSERT_TRUE_WAIT(p1_->connected(), 10000);
    897    ASSERT_TRUE_WAIT(p2_->connected(), 10000);
    898  }
    899 
    900  void TransferTest(size_t count, size_t bytes = 1024) {
    901    unsigned char buf[bytes];
    902 
    903    for (size_t i = 0; i < count; ++i) {
    904      memset(buf, count & 0xff, sizeof(buf));
    905      MediaPacket packet;
    906      packet.Copy(buf, sizeof(buf));
    907      TransportResult rv = p1_->SendPacket(packet);
    908      ASSERT_TRUE(rv > 0);
    909    }
    910 
    911    std::cerr << "Received == " << p2_->receivedPackets() << " packets"
    912              << std::endl;
    913    ASSERT_TRUE_WAIT(count == p2_->receivedPackets(), 10000);
    914    ASSERT_TRUE((count * sizeof(buf)) == p2_->receivedBytes());
    915  }
    916 
    917 protected:
    918  void ConnectSocketInternal() {
    919    test_utils_->SyncDispatchToSTS(
    920        WrapRunnable(p1_, &TransportTestPeer::ConnectSocket, p2_));
    921    test_utils_->SyncDispatchToSTS(
    922        WrapRunnable(p2_, &TransportTestPeer::ConnectSocket, p1_));
    923  }
    924 
    925  PRFileDesc* fds_[2];
    926  TransportTestPeer* p1_;
    927  TransportTestPeer* p2_;
    928  nsCOMPtr<nsIEventTarget> target_;
    929 };
    930 
    931 TEST_F(TransportTest, TestNoDtlsVerificationSettings) {
    932  ConnectSocketExpectFail();
    933 }
    934 
    935 static void DisableChaCha(TransportTestPeer* peer) {
    936  // On ARM, ChaCha20Poly1305 might be preferred; disable it for the tests that
    937  // want to check the cipher suite.  It doesn't matter which peer disables the
    938  // suite, disabling on either side has the same effect.
    939  std::vector<uint16_t> chachaSuites;
    940  chachaSuites.push_back(TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256);
    941  chachaSuites.push_back(TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256);
    942  peer->SetCipherSuiteChanges(std::vector<uint16_t>(), chachaSuites);
    943 }
    944 
    945 TEST_F(TransportTest, TestConnect) {
    946  SetDtlsPeer();
    947  DisableChaCha(p1_);
    948  ConnectSocket();
    949 
    950  // check that we got the right suite
    951  ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, p1_->cipherSuite());
    952 
    953  // no SRTP on this one
    954  ASSERT_EQ(0, p1_->srtpCipher());
    955 }
    956 
    957 TEST_F(TransportTest, TestConnectSrtp) {
    958  SetupSrtp();
    959  SetDtlsPeer();
    960  DisableChaCha(p2_);
    961  ConnectSocket();
    962 
    963  ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, p1_->cipherSuite());
    964 
    965  // SRTP is on with default value
    966  ASSERT_EQ(kDtlsSrtpAeadAes128Gcm, p1_->srtpCipher());
    967 }
    968 
    969 TEST_F(TransportTest, TestConnectDestroyFlowsMainThread) {
    970  SetDtlsPeer();
    971  ConnectSocket();
    972  DestroyPeerFlows();
    973 }
    974 
    975 TEST_F(TransportTest, TestConnectAllowAll) {
    976  SetDtlsAllowAll();
    977  ConnectSocket();
    978 }
    979 
    980 TEST_F(TransportTest, TestConnectAlpn) {
    981  SetDtlsPeer();
    982  SetAlpn("a", "a");
    983  ConnectSocket();
    984  CheckAlpn("a", "a");
    985 }
    986 
    987 TEST_F(TransportTest, TestConnectAlpnMismatch) {
    988  SetDtlsPeer();
    989  SetAlpn("something", "different");
    990  ConnectSocketExpectFail();
    991 }
    992 
    993 TEST_F(TransportTest, TestConnectAlpnServerDefault) {
    994  SetDtlsPeer();
    995  SetAlpn("def", "");
    996  // server allows default, client doesn't support
    997  ConnectSocket();
    998  CheckAlpn("def", "");
    999 }
   1000 
   1001 TEST_F(TransportTest, TestConnectAlpnClientDefault) {
   1002  SetDtlsPeer();
   1003  SetAlpn("", "clientdef");
   1004  // client allows default, but server will ignore the extension
   1005  ConnectSocket();
   1006  CheckAlpn("", "clientdef");
   1007 }
   1008 
   1009 TEST_F(TransportTest, TestConnectClientNoAlpn) {
   1010  SetDtlsPeer();
   1011  // Here the server has ALPN, but no default is allowed.
   1012  // Reminder: p1 == server, p2 == client
   1013  SetAlpn("server-nodefault", "", false);
   1014  // The server doesn't see the extension, so negotiates without it.
   1015  // But then the server is forced to close when it discovers that ALPN wasn't
   1016  // negotiated; the client sees a close.
   1017  ConnectSocketExpectState(TransportLayer::TS_ERROR, TransportLayer::TS_CLOSED);
   1018 }
   1019 
   1020 TEST_F(TransportTest, TestConnectServerNoAlpn) {
   1021  SetDtlsPeer();
   1022  SetAlpn("", "client-nodefault", false);
   1023  // The client aborts; the server doesn't realize this is a problem and just
   1024  // sees the close.
   1025  ConnectSocketExpectState(TransportLayer::TS_CLOSED, TransportLayer::TS_ERROR);
   1026 }
   1027 
   1028 TEST_F(TransportTest, TestConnectNoDigest) {
   1029  SetDtlsPeer(0, 0);
   1030 
   1031  ConnectSocketExpectFail();
   1032 }
   1033 
   1034 TEST_F(TransportTest, TestConnectBadDigest) {
   1035  SetDtlsPeer(1, 1);
   1036 
   1037  ConnectSocketExpectFail();
   1038 }
   1039 
   1040 TEST_F(TransportTest, TestConnectTwoDigests) {
   1041  SetDtlsPeer(2, 0);
   1042 
   1043  ConnectSocket();
   1044 }
   1045 
   1046 TEST_F(TransportTest, TestConnectTwoDigestsFirstBad) {
   1047  SetDtlsPeer(2, 1);
   1048 
   1049  ConnectSocket();
   1050 }
   1051 
   1052 TEST_F(TransportTest, TestConnectTwoDigestsSecondBad) {
   1053  SetDtlsPeer(2, 2);
   1054 
   1055  ConnectSocket();
   1056 }
   1057 
   1058 TEST_F(TransportTest, TestConnectTwoDigestsBothBad) {
   1059  SetDtlsPeer(2, 3);
   1060 
   1061  ConnectSocketExpectFail();
   1062 }
   1063 
   1064 TEST_F(TransportTest, TestConnectInjectCCS) {
   1065  SetDtlsPeer();
   1066  p2_->SetInspector(MakeUnique<DtlsInspectorInjector>(
   1067      kTlsHandshakeType, kTlsHandshakeCertificate, kTlsFakeChangeCipherSpec,
   1068      sizeof(kTlsFakeChangeCipherSpec)));
   1069 
   1070  ConnectSocket();
   1071 }
   1072 
   1073 TEST_F(TransportTest, TestConnectVerifyNewECDHE) {
   1074  SetDtlsPeer();
   1075  DtlsInspectorRecordHandshakeMessage* i1 =
   1076      new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
   1077  p1_->SetInspector(i1);
   1078  ConnectSocket();
   1079  TlsServerKeyExchangeECDHE dhe1;
   1080  ASSERT_TRUE(dhe1.Parse(i1->buffer().data(), i1->buffer().len()));
   1081 
   1082  Reset();
   1083  SetDtlsPeer();
   1084  DtlsInspectorRecordHandshakeMessage* i2 =
   1085      new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
   1086  p1_->SetInspector(i2);
   1087  ConnectSocket();
   1088  TlsServerKeyExchangeECDHE dhe2;
   1089  ASSERT_TRUE(dhe2.Parse(i2->buffer().data(), i2->buffer().len()));
   1090 
   1091  // Now compare these two to see if they are the same.
   1092  ASSERT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) &&
   1093               (!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(),
   1094                        dhe1.public_key_.len())));
   1095 }
   1096 
   1097 TEST_F(TransportTest, TestConnectVerifyReusedECDHE) {
   1098  auto set_reuse_ecdhe_key = [](PRFileDesc* fd) {
   1099    // TransportLayerDtls automatically sets this pref to false
   1100    // so set it back for test.
   1101    // This is pretty gross. Dig directly into the NSS FD. The problem
   1102    // is that we are testing a feature which TransaportLayerDtls doesn't
   1103    // expose.
   1104    SECStatus rv = SSL_OptionSet(fd, SSL_REUSE_SERVER_ECDHE_KEY, PR_TRUE);
   1105    ASSERT_EQ(SECSuccess, rv);
   1106  };
   1107 
   1108  SetDtlsPeer();
   1109  DtlsInspectorRecordHandshakeMessage* i1 =
   1110      new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
   1111  p1_->SetInspector(i1);
   1112  p1_->SetPostSetup(set_reuse_ecdhe_key);
   1113  ConnectSocket();
   1114  TlsServerKeyExchangeECDHE dhe1;
   1115  ASSERT_TRUE(dhe1.Parse(i1->buffer().data(), i1->buffer().len()));
   1116 
   1117  Reset();
   1118  SetDtlsPeer();
   1119  DtlsInspectorRecordHandshakeMessage* i2 =
   1120      new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
   1121 
   1122  p1_->SetInspector(i2);
   1123  p1_->SetPostSetup(set_reuse_ecdhe_key);
   1124 
   1125  ConnectSocket();
   1126  TlsServerKeyExchangeECDHE dhe2;
   1127  ASSERT_TRUE(dhe2.Parse(i2->buffer().data(), i2->buffer().len()));
   1128 
   1129  // Now compare these two to see if they are the same.
   1130  ASSERT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len());
   1131  ASSERT_TRUE(!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(),
   1132                      dhe1.public_key_.len()));
   1133 }
   1134 
   1135 TEST_F(TransportTest, TestTransfer) {
   1136  SetDtlsPeer();
   1137  ConnectSocket();
   1138  TransferTest(1);
   1139 }
   1140 
   1141 TEST_F(TransportTest, TestTransferMaxSize) {
   1142  SetDtlsPeer();
   1143  ConnectSocket();
   1144  /* transportlayerdtls uses a 9216 bytes buffer - as this test uses the
   1145   * loopback implementation it does not have to take into account the extra
   1146   * bytes added by the DTLS layer below. */
   1147  TransferTest(1, 9216);
   1148 }
   1149 
   1150 TEST_F(TransportTest, TestTransferMultiple) {
   1151  SetDtlsPeer();
   1152  ConnectSocket();
   1153  TransferTest(3);
   1154 }
   1155 
   1156 TEST_F(TransportTest, TestTransferCombinedPackets) {
   1157  SetDtlsPeer();
   1158  ConnectSocket();
   1159  p2_->SetCombinePackets(true);
   1160  TransferTest(3);
   1161 }
   1162 
   1163 TEST_F(TransportTest, TestConnectLoseFirst) {
   1164  SetDtlsPeer();
   1165  p1_->SetLoss(0);
   1166  ConnectSocket();
   1167  TransferTest(1);
   1168 }
   1169 
   1170 TEST_F(TransportTest, TestConnectIce) {
   1171  SetDtlsPeer();
   1172  ConnectIce();
   1173 }
   1174 
   1175 TEST_F(TransportTest, TestTransferIceMaxSize) {
   1176  SetDtlsPeer();
   1177  ConnectIce();
   1178  /* nICEr and transportlayerdtls both use 9216 bytes buffers. But the DTLS
   1179   * layer add extra bytes to the packet, which size depends on chosen cipher
   1180   * etc. Sending more then 9216 bytes works, but on the receiving side the call
   1181   * to PR_recvfrom() will truncate any packet bigger then nICEr's buffer size
   1182   * of 9216 bytes, which then results in the DTLS layer discarding the packet.
   1183   * Therefore we leave some headroom (according to
   1184   * https://bugzilla.mozilla.org/show_bug.cgi?id=1214269#c29 256 bytes should
   1185   * be save choice) here for the DTLS bytes to make it safely into the
   1186   * receiving buffer in nICEr. */
   1187  TransferTest(1, 8960);
   1188 }
   1189 
   1190 TEST_F(TransportTest, TestTransferIceMultiple) {
   1191  SetDtlsPeer();
   1192  ConnectIce();
   1193  TransferTest(3);
   1194 }
   1195 
   1196 TEST_F(TransportTest, TestTransferIceCombinedPackets) {
   1197  SetDtlsPeer();
   1198  ConnectIce();
   1199  p2_->SetCombinePackets(true);
   1200  TransferTest(3);
   1201 }
   1202 
   1203 // test the default configuration against a peer that supports only
   1204 // one of the mandatory-to-implement suites, which should succeed
   1205 static void ConfigureOneCipher(TransportTestPeer* peer, uint16_t suite) {
   1206  std::vector<uint16_t> justOne;
   1207  justOne.push_back(suite);
   1208  std::vector<uint16_t> everythingElse(
   1209      SSL_GetImplementedCiphers(),
   1210      SSL_GetImplementedCiphers() + SSL_GetNumImplementedCiphers());
   1211  everythingElse.erase(
   1212      std::remove(everythingElse.begin(), everythingElse.end(), suite));
   1213  peer->SetCipherSuiteChanges(justOne, everythingElse);
   1214 }
   1215 
   1216 TEST_F(TransportTest, TestCipherMismatch) {
   1217  SetDtlsPeer();
   1218  ConfigureOneCipher(p1_, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256);
   1219  ConfigureOneCipher(p2_, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA);
   1220  ConnectSocketExpectFail();
   1221 }
   1222 
   1223 TEST_F(TransportTest, TestCipherMandatoryOnlyGcm) {
   1224  SetDtlsPeer();
   1225  ConfigureOneCipher(p1_, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256);
   1226  ConnectSocket();
   1227  ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, p1_->cipherSuite());
   1228 }
   1229 
   1230 TEST_F(TransportTest, TestCipherMandatoryOnlyCbc) {
   1231  SetDtlsPeer();
   1232  ConfigureOneCipher(p1_, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA);
   1233  ConnectSocket();
   1234  ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, p1_->cipherSuite());
   1235 }
   1236 
   1237 TEST_F(TransportTest, TestSrtpMismatch) {
   1238  std::vector<uint16_t> setA;
   1239  setA.push_back(kDtlsSrtpAes128CmHmacSha1_80);
   1240  std::vector<uint16_t> setB;
   1241  setB.push_back(kDtlsSrtpAes128CmHmacSha1_32);
   1242 
   1243  p1_->SetSrtpCiphers(setA);
   1244  p2_->SetSrtpCiphers(setB);
   1245  SetDtlsPeer();
   1246  ConnectSocketExpectFail();
   1247 
   1248  ASSERT_EQ(0, p1_->srtpCipher());
   1249  ASSERT_EQ(0, p2_->srtpCipher());
   1250 }
   1251 
   1252 static SECStatus NoopXtnHandler(PRFileDesc* fd, SSLHandshakeType message,
   1253                                const uint8_t* data, unsigned int len,
   1254                                SSLAlertDescription* alert, void* arg) {
   1255  return SECSuccess;
   1256 }
   1257 
   1258 static PRBool WriteFixedXtn(PRFileDesc* fd, SSLHandshakeType message,
   1259                            uint8_t* data, unsigned int* len,
   1260                            unsigned int max_len, void* arg) {
   1261  // When we enable TLS 1.3, change ssl_hs_server_hello here to
   1262  // ssl_hs_encrypted_extensions.  At the same time, add a test that writes to
   1263  // ssl_hs_server_hello, which should fail.
   1264  if (message != ssl_hs_client_hello && message != ssl_hs_server_hello) {
   1265    return false;
   1266  }
   1267 
   1268  auto v = reinterpret_cast<std::vector<uint8_t>*>(arg);
   1269  memcpy(data, &((*v)[0]), v->size());
   1270  *len = v->size();
   1271  return true;
   1272 }
   1273 
   1274 // Note that |value| needs to be readable after this function returns.
   1275 static void InstallBadSrtpExtensionWriter(TransportTestPeer* peer,
   1276                                          std::vector<uint8_t>* value) {
   1277  peer->SetPostSetup([value](PRFileDesc* fd) {
   1278    // Override the handler that is installed by the DTLS setup.
   1279    SECStatus rv = SSL_InstallExtensionHooks(
   1280        fd, ssl_use_srtp_xtn, WriteFixedXtn, value, NoopXtnHandler, nullptr);
   1281    ASSERT_EQ(SECSuccess, rv);
   1282  });
   1283 }
   1284 
   1285 TEST_F(TransportTest, TestSrtpErrorServerSendsTwoSrtpCiphers) {
   1286  // Server (p1_) sends an extension with two values, and empty MKI.
   1287  std::vector<uint8_t> xtn = {0x04, 0x00, 0x01, 0x00, 0x02, 0x00};
   1288  InstallBadSrtpExtensionWriter(p1_, &xtn);
   1289  SetupSrtp();
   1290  SetDtlsPeer();
   1291  ConnectSocketExpectFail();
   1292 }
   1293 
   1294 TEST_F(TransportTest, TestSrtpErrorServerSendsTwoMki) {
   1295  // Server (p1_) sends an MKI.
   1296  std::vector<uint8_t> xtn = {0x02, 0x00, 0x01, 0x01, 0x00};
   1297  InstallBadSrtpExtensionWriter(p1_, &xtn);
   1298  SetupSrtp();
   1299  SetDtlsPeer();
   1300  ConnectSocketExpectFail();
   1301 }
   1302 
   1303 TEST_F(TransportTest, TestSrtpErrorServerSendsUnknownValue) {
   1304  std::vector<uint8_t> xtn = {0x02, 0x9a, 0xf1, 0x00};
   1305  InstallBadSrtpExtensionWriter(p1_, &xtn);
   1306  SetupSrtp();
   1307  SetDtlsPeer();
   1308  ConnectSocketExpectFail();
   1309 }
   1310 
   1311 TEST_F(TransportTest, TestSrtpErrorServerSendsOverflow) {
   1312  std::vector<uint8_t> xtn = {0x32, 0x00, 0x01, 0x00};
   1313  InstallBadSrtpExtensionWriter(p1_, &xtn);
   1314  SetupSrtp();
   1315  SetDtlsPeer();
   1316  ConnectSocketExpectFail();
   1317 }
   1318 
   1319 TEST_F(TransportTest, TestSrtpErrorServerSendsUnevenList) {
   1320  std::vector<uint8_t> xtn = {0x01, 0x00, 0x00};
   1321  InstallBadSrtpExtensionWriter(p1_, &xtn);
   1322  SetupSrtp();
   1323  SetDtlsPeer();
   1324  ConnectSocketExpectFail();
   1325 }
   1326 
   1327 TEST_F(TransportTest, TestSrtpErrorClientSendsUnevenList) {
   1328  std::vector<uint8_t> xtn = {0x01, 0x00, 0x00};
   1329  InstallBadSrtpExtensionWriter(p2_, &xtn);
   1330  SetupSrtp();
   1331  SetDtlsPeer();
   1332  ConnectSocketExpectFail();
   1333 }
   1334 
   1335 TEST_F(TransportTest, OnlyServerSendsSrtpXtn) {
   1336  p1_->SetupSrtp();
   1337  SetDtlsPeer();
   1338  // This should connect, but with no SRTP extension neogtiated.
   1339  // The client side might negotiate a data channel only.
   1340  ConnectSocket();
   1341  ASSERT_NE(TLS_NULL_WITH_NULL_NULL, p1_->cipherSuite());
   1342  ASSERT_EQ(0, p1_->srtpCipher());
   1343 }
   1344 
   1345 TEST_F(TransportTest, OnlyClientSendsSrtpXtn) {
   1346  p2_->SetupSrtp();
   1347  SetDtlsPeer();
   1348  // This should connect, but with no SRTP extension neogtiated.
   1349  // The server side might negotiate a data channel only.
   1350  ConnectSocket();
   1351  ASSERT_NE(TLS_NULL_WITH_NULL_NULL, p1_->cipherSuite());
   1352  ASSERT_EQ(0, p1_->srtpCipher());
   1353 }
   1354 
   1355 class TransportSrtpParameterTest
   1356    : public TransportTest,
   1357      public ::testing::WithParamInterface<uint16_t> {};
   1358 
   1359 INSTANTIATE_TEST_SUITE_P(
   1360    SrtpParamInit, TransportSrtpParameterTest,
   1361    ::testing::ValuesIn(TransportLayerDtls::GetDefaultSrtpCiphers()));
   1362 
   1363 TEST_P(TransportSrtpParameterTest, TestSrtpCiphersMismatchCombinations) {
   1364  uint16_t cipher = GetParam();
   1365  std::cerr << "Checking cipher: " << cipher << std::endl;
   1366 
   1367  p1_->SetupSrtp();
   1368 
   1369  std::vector<uint16_t> setB;
   1370  setB.push_back(cipher);
   1371 
   1372  p2_->SetSrtpCiphers(setB);
   1373  SetDtlsPeer();
   1374  ConnectSocket();
   1375 
   1376  ASSERT_EQ(cipher, p1_->srtpCipher());
   1377  ASSERT_EQ(cipher, p2_->srtpCipher());
   1378 }
   1379 
   1380 // NSS doesn't support DHE suites on the server end.
   1381 // This checks to see if we barf when that's the only option available.
   1382 TEST_F(TransportTest, TestDheOnlyFails) {
   1383  SetDtlsPeer();
   1384 
   1385  // p2_ is the client
   1386  // setting this on p1_ (the server) causes NSS to assert
   1387  ConfigureOneCipher(p2_, TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
   1388  ConnectSocketExpectFail();
   1389 }
   1390 
   1391 }  // end namespace