tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

transportlayerdtls.cpp (48943B)


      1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim: set ts=2 et sw=2 tw=80: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
      7 // Original author: ekr@rtfm.com
      8 
      9 #include "transportlayerdtls.h"
     10 
     11 #include <algorithm>
     12 #include <iomanip>
     13 #include <queue>
     14 #include <sstream>
     15 
     16 #include "dtlsidentity.h"
     17 #include "keyhi.h"
     18 #include "logging.h"
     19 #include "mozilla/StaticPrefs_media.h"
     20 #include "mozilla/StaticPrefs_security.h"
     21 #include "mozilla/UniquePtr.h"
     22 #include "mozilla/glean/DomMediaWebrtcMetrics.h"
     23 #include "nsCOMPtr.h"
     24 #include "nsNetCID.h"
     25 #include "nsServiceManagerUtils.h"
     26 #include "sslexp.h"
     27 #include "sslproto.h"
     28 
     29 namespace mozilla {
     30 
// Selects the "mtransport" log module for the MOZ_MTLOG calls in this file
// (presumably; the macro is defined in logging.h).
MOZ_MTLOG_MODULE("mtransport")

// NSPR layer identity used for the stub I/O layer created in Setup(). It starts
// as PR_INVALID_IO_LAYER and is assigned once via PR_GetUniqueIdentity() the
// first time Setup() runs.
static PRDescIdentity transport_layer_identity = PR_INVALID_IO_LAYER;
     34 
     35 // TODO: Implement a mode for this where
     36 // the channel is not ready until confirmed externally
     37 // (e.g., after cert check).
     38 
     39 #define UNIMPLEMENTED                                                     \
     40  MOZ_MTLOG(ML_ERROR, "Call to unimplemented function " << __FUNCTION__); \
     41  MOZ_ASSERT(false);                                                      \
     42  PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0)
     43 
     44 #define MAX_ALPN_LENGTH 255
     45 
     46 // We need to adapt the NSPR/libssl model to the TransportFlow model.
     47 // The former wants pull semantics and TransportFlow wants push.
     48 //
     49 // - A TransportLayerDtls assumes it is sitting on top of another
     50 //   TransportLayer, which means that events come in asynchronously.
     51 // - NSS (libssl) wants to sit on top of a PRFileDesc and poll.
     52 // - The TransportLayerNSPRAdapter is a PRFileDesc containing a
     53 //   FIFO.
     54 // - When TransportLayerDtls.PacketReceived() is called, we insert
     55 //   the packets in the FIFO and then do a PR_Recv() on the NSS
     56 //   PRFileDesc, which eventually reads off the FIFO.
     57 //
     58 // All of this stuff is assumed to happen solely in a single thread
     59 // (generally the SocketTransportService thread)
     60 
     61 void TransportLayerNSPRAdapter::PacketReceived(MediaPacket& packet) {
     62  if (enabled_) {
     63    input_.push(new MediaPacket(std::move(packet)));
     64  }
     65 }
     66 
     67 int32_t TransportLayerNSPRAdapter::Recv(void* buf, int32_t buflen) {
     68  if (input_.empty()) {
     69    PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
     70    return -1;
     71  }
     72 
     73  MediaPacket* front = input_.front();
     74  int32_t count = static_cast<int32_t>(front->len());
     75 
     76  if (buflen < count) {
     77    MOZ_ASSERT(false, "Not enough buffer space to receive into");
     78    PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0);
     79    return -1;
     80  }
     81 
     82  memcpy(buf, front->data(), count);
     83 
     84  input_.pop();
     85  delete front;
     86 
     87  return count;
     88 }
     89 
     90 int32_t TransportLayerNSPRAdapter::Write(const void* buf, int32_t length) {
     91  if (!enabled_) {
     92    MOZ_MTLOG(ML_WARNING, "Writing to disabled transport layer");
     93    return -1;
     94  }
     95 
     96  MediaPacket packet;
     97  // Copies. Oh well.
     98  packet.Copy(static_cast<const uint8_t*>(buf), static_cast<size_t>(length));
     99  packet.SetType(MediaPacket::DTLS);
    100 
    101  TransportResult r = output_->SendPacket(packet);
    102  if (r >= 0) {
    103    return r;
    104  }
    105 
    106  if (r == TE_WOULDBLOCK) {
    107    PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
    108  } else {
    109    PR_SetError(PR_IO_ERROR, 0);
    110  }
    111 
    112  return -1;
    113 }
    114 
    115 // Implementation of NSPR methods
    116 static PRStatus TransportLayerClose(PRFileDesc* f) {
    117  f->dtor(f);
    118  return PR_SUCCESS;
    119 }
    120 
    121 static int32_t TransportLayerRead(PRFileDesc* f, void* buf, int32_t length) {
    122  UNIMPLEMENTED;
    123  return -1;
    124 }
    125 
    126 static int32_t TransportLayerWrite(PRFileDesc* f, const void* buf,
    127                                   int32_t length) {
    128  TransportLayerNSPRAdapter* io =
    129      reinterpret_cast<TransportLayerNSPRAdapter*>(f->secret);
    130  return io->Write(buf, length);
    131 }
    132 
    133 static int32_t TransportLayerAvailable(PRFileDesc* f) {
    134  UNIMPLEMENTED;
    135  return -1;
    136 }
    137 
    138 int64_t TransportLayerAvailable64(PRFileDesc* f) {
    139  UNIMPLEMENTED;
    140  return -1;
    141 }
    142 
    143 static PRStatus TransportLayerSync(PRFileDesc* f) {
    144  UNIMPLEMENTED;
    145  return PR_FAILURE;
    146 }
    147 
    148 static int32_t TransportLayerSeek(PRFileDesc* f, int32_t offset,
    149                                  PRSeekWhence how) {
    150  UNIMPLEMENTED;
    151  return -1;
    152 }
    153 
    154 static int64_t TransportLayerSeek64(PRFileDesc* f, int64_t offset,
    155                                    PRSeekWhence how) {
    156  UNIMPLEMENTED;
    157  return -1;
    158 }
    159 
    160 static PRStatus TransportLayerFileInfo(PRFileDesc* f, PRFileInfo* info) {
    161  UNIMPLEMENTED;
    162  return PR_FAILURE;
    163 }
    164 
    165 static PRStatus TransportLayerFileInfo64(PRFileDesc* f, PRFileInfo64* info) {
    166  UNIMPLEMENTED;
    167  return PR_FAILURE;
    168 }
    169 
    170 static int32_t TransportLayerWritev(PRFileDesc* f, const PRIOVec* iov,
    171                                    int32_t iov_size, PRIntervalTime to) {
    172  UNIMPLEMENTED;
    173  return -1;
    174 }
    175 
    176 static PRStatus TransportLayerConnect(PRFileDesc* f, const PRNetAddr* addr,
    177                                      PRIntervalTime to) {
    178  UNIMPLEMENTED;
    179  return PR_FAILURE;
    180 }
    181 
    182 static PRFileDesc* TransportLayerAccept(PRFileDesc* sd, PRNetAddr* addr,
    183                                        PRIntervalTime to) {
    184  UNIMPLEMENTED;
    185  return nullptr;
    186 }
    187 
    188 static PRStatus TransportLayerBind(PRFileDesc* f, const PRNetAddr* addr) {
    189  UNIMPLEMENTED;
    190  return PR_FAILURE;
    191 }
    192 
    193 static PRStatus TransportLayerListen(PRFileDesc* f, int32_t depth) {
    194  UNIMPLEMENTED;
    195  return PR_FAILURE;
    196 }
    197 
    198 static PRStatus TransportLayerShutdown(PRFileDesc* f, int32_t how) {
    199  // This is only called from NSS when we are the server and the client refuses
    200  // to provide a certificate.  In this case, the handshake is destined for
    201  // failure, so we will just let this pass.
    202  TransportLayerNSPRAdapter* io =
    203      reinterpret_cast<TransportLayerNSPRAdapter*>(f->secret);
    204  io->SetEnabled(false);
    205  return PR_SUCCESS;
    206 }
    207 
    208 // This function does not support peek, or waiting until `to`
    209 static int32_t TransportLayerRecv(PRFileDesc* f, void* buf, int32_t buflen,
    210                                  int32_t flags, PRIntervalTime to) {
    211  MOZ_ASSERT(flags == 0);
    212  if (flags != 0) {
    213    PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
    214    return -1;
    215  }
    216 
    217  TransportLayerNSPRAdapter* io =
    218      reinterpret_cast<TransportLayerNSPRAdapter*>(f->secret);
    219  return io->Recv(buf, buflen);
    220 }
    221 
    222 // Note: this is always nonblocking and assumes a zero timeout.
    223 static int32_t TransportLayerSend(PRFileDesc* f, const void* buf,
    224                                  int32_t amount, int32_t flags,
    225                                  PRIntervalTime to) {
    226  int32_t written = TransportLayerWrite(f, buf, amount);
    227  return written;
    228 }
    229 
    230 static int32_t TransportLayerRecvfrom(PRFileDesc* f, void* buf, int32_t amount,
    231                                      int32_t flags, PRNetAddr* addr,
    232                                      PRIntervalTime to) {
    233  UNIMPLEMENTED;
    234  return -1;
    235 }
    236 
    237 static int32_t TransportLayerSendto(PRFileDesc* f, const void* buf,
    238                                    int32_t amount, int32_t flags,
    239                                    const PRNetAddr* addr, PRIntervalTime to) {
    240  UNIMPLEMENTED;
    241  return -1;
    242 }
    243 
    244 static int16_t TransportLayerPoll(PRFileDesc* f, int16_t in_flags,
    245                                  int16_t* out_flags) {
    246  UNIMPLEMENTED;
    247  return -1;
    248 }
    249 
    250 static int32_t TransportLayerAcceptRead(PRFileDesc* sd, PRFileDesc** nd,
    251                                        PRNetAddr** raddr, void* buf,
    252                                        int32_t amount, PRIntervalTime t) {
    253  UNIMPLEMENTED;
    254  return -1;
    255 }
    256 
    257 static int32_t TransportLayerTransmitFile(PRFileDesc* sd, PRFileDesc* f,
    258                                          const void* headers, int32_t hlen,
    259                                          PRTransmitFileFlags flags,
    260                                          PRIntervalTime t) {
    261  UNIMPLEMENTED;
    262  return -1;
    263 }
    264 
    265 static PRStatus TransportLayerGetpeername(PRFileDesc* f, PRNetAddr* addr) {
    266  // TODO: Modify to return unique names for each channel
    267  // somehow, as opposed to always the same static address. The current
    268  // implementation messes up the session cache, which is why it's off
    269  // elsewhere
    270  addr->inet.family = PR_AF_INET;
    271  addr->inet.port = 0;
    272  addr->inet.ip = 0;
    273 
    274  return PR_SUCCESS;
    275 }
    276 
    277 static PRStatus TransportLayerGetsockname(PRFileDesc* f, PRNetAddr* addr) {
    278  UNIMPLEMENTED;
    279  return PR_FAILURE;
    280 }
    281 
    282 static PRStatus TransportLayerGetsockoption(PRFileDesc* f,
    283                                            PRSocketOptionData* opt) {
    284  switch (opt->option) {
    285    case PR_SockOpt_Nonblocking:
    286      opt->value.non_blocking = PR_TRUE;
    287      return PR_SUCCESS;
    288    default:
    289      UNIMPLEMENTED;
    290      break;
    291  }
    292 
    293  return PR_FAILURE;
    294 }
    295 
    296 // Imitate setting socket options. These are mostly noops.
    297 static PRStatus TransportLayerSetsockoption(PRFileDesc* f,
    298                                            const PRSocketOptionData* opt) {
    299  switch (opt->option) {
    300    case PR_SockOpt_Nonblocking:
    301      return PR_SUCCESS;
    302    case PR_SockOpt_NoDelay:
    303      return PR_SUCCESS;
    304    default:
    305      UNIMPLEMENTED;
    306      break;
    307  }
    308 
    309  return PR_FAILURE;
    310 }
    311 
    312 static int32_t TransportLayerSendfile(PRFileDesc* out, PRSendFileData* in,
    313                                      PRTransmitFileFlags flags,
    314                                      PRIntervalTime to) {
    315  UNIMPLEMENTED;
    316  return -1;
    317 }
    318 
    319 static PRStatus TransportLayerConnectContinue(PRFileDesc* f, int16_t flags) {
    320  UNIMPLEMENTED;
    321  return PR_FAILURE;
    322 }
    323 
    324 static int32_t TransportLayerReserved(PRFileDesc* f) {
    325  UNIMPLEMENTED;
    326  return -1;
    327 }
    328 
// NSPR I/O method table for the stub layer that Setup() creates with
// PR_CreateIOLayerStub(). Entries are positional and must follow the field
// order of NSPR's PRIOMethods struct (prio.h). The trailing comments name the
// slot each entry fills.
static const struct PRIOMethods TransportLayerMethods = {
    PR_DESC_LAYERED,                // file_type
    TransportLayerClose,            // close
    TransportLayerRead,             // read
    TransportLayerWrite,            // write
    TransportLayerAvailable,        // available
    TransportLayerAvailable64,      // available64
    TransportLayerSync,             // fsync
    TransportLayerSeek,             // seek
    TransportLayerSeek64,           // seek64
    TransportLayerFileInfo,         // fileInfo
    TransportLayerFileInfo64,       // fileInfo64
    TransportLayerWritev,           // writev
    TransportLayerConnect,          // connect
    TransportLayerAccept,           // accept
    TransportLayerBind,             // bind
    TransportLayerListen,           // listen
    TransportLayerShutdown,         // shutdown
    TransportLayerRecv,             // recv
    TransportLayerSend,             // send
    TransportLayerRecvfrom,         // recvfrom
    TransportLayerSendto,           // sendto
    TransportLayerPoll,             // poll
    TransportLayerAcceptRead,       // acceptread
    TransportLayerTransmitFile,     // transmitfile
    TransportLayerGetsockname,      // getsockname
    TransportLayerGetpeername,      // getpeername
    TransportLayerReserved,         // reserved
    TransportLayerReserved,         // reserved
    TransportLayerGetsockoption,    // getsocketoption
    TransportLayerSetsockoption,    // setsocketoption
    TransportLayerSendfile,         // sendfile
    TransportLayerConnectContinue,  // connectcontinue
    TransportLayerReserved,         // reserved
    TransportLayerReserved,         // reserved
    TransportLayerReserved,         // reserved
    TransportLayerReserved};        // reserved
    366 
    367 TransportLayerDtls::~TransportLayerDtls() {
    368  // Destroy the NSS instance first so it can still send out an alert before
    369  // we disable the nspr_io_adapter_.
    370  ssl_fd_ = nullptr;
    371  nspr_io_adapter_->SetEnabled(false);
    372  if (timer_) {
    373    timer_->Cancel();
    374  }
    375 }
    376 
    377 nsresult TransportLayerDtls::InitInternal() {
    378  // Get the transport service as an event target
    379  nsresult rv;
    380  target_ = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv);
    381 
    382  if (NS_FAILED(rv)) {
    383    MOZ_MTLOG(ML_ERROR, "Couldn't get socket transport service");
    384    return rv;
    385  }
    386 
    387  timer_ = NS_NewTimer();
    388  if (!timer_) {
    389    MOZ_MTLOG(ML_ERROR, "Couldn't get timer");
    390    return rv;
    391  }
    392 
    393  return NS_OK;
    394 }
    395 
    396 void TransportLayerDtls::WasInserted() {
    397  // Connect to the lower layers
    398  if (!Setup()) {
    399    TL_SET_STATE(TS_ERROR);
    400  }
    401 }
    402 
    403 // Set the permitted and default ALPN identifiers.
    404 // The default is here to allow for peers that don't want to negotiate ALPN
    405 // in that case, the default string will be reported from GetNegotiatedAlpn().
    406 // Setting the default to the empty string causes the transport layer to fail
    407 // if ALPN is not negotiated.
    408 // Note: we only support Unicode strings here, which are encoded into UTF-8,
    409 // even though ALPN ostensibly allows arbitrary octet sequences.
    410 nsresult TransportLayerDtls::SetAlpn(const std::set<std::string>& alpn_allowed,
    411                                     const std::string& alpn_default) {
    412  alpn_allowed_ = alpn_allowed;
    413  alpn_default_ = alpn_default;
    414 
    415  return NS_OK;
    416 }
    417 
    418 nsresult TransportLayerDtls::SetVerificationAllowAll() {
    419  // Defensive programming
    420  if (verification_mode_ != VERIFY_UNSET) return NS_ERROR_ALREADY_INITIALIZED;
    421 
    422  verification_mode_ = VERIFY_ALLOW_ALL;
    423 
    424  return NS_OK;
    425 }
    426 
    427 nsresult TransportLayerDtls::SetVerificationDigest(const DtlsDigest& digest) {
    428  // Defensive programming
    429  if (verification_mode_ != VERIFY_UNSET &&
    430      verification_mode_ != VERIFY_DIGEST) {
    431    return NS_ERROR_ALREADY_INITIALIZED;
    432  }
    433 
    434  digests_.push_back(digest);
    435  verification_mode_ = VERIFY_DIGEST;
    436  return NS_OK;
    437 }
    438 
    439 void TransportLayerDtls::SetMinMaxVersion(Version min_version,
    440                                          Version max_version) {
    441  if (min_version < Version::DTLS_1_0 || min_version > Version::DTLS_1_3 ||
    442      max_version < Version::DTLS_1_0 || max_version > Version::DTLS_1_3 ||
    443      min_version > max_version || max_version < min_version) {
    444    return;
    445  }
    446  minVersion_ = min_version;
    447  maxVersion_ = max_version;
    448 }
    449 
    450 // TODO: make sure this is called from STS. Otherwise
    451 // we have thread safety issues
    452 bool TransportLayerDtls::Setup() {
    453  CheckThread();
    454  SECStatus rv;
    455 
    456  if (!downward_) {
    457    MOZ_MTLOG(ML_ERROR, "DTLS layer with nothing below. This is useless");
    458    return false;
    459  }
    460  nspr_io_adapter_ = MakeUnique<TransportLayerNSPRAdapter>(downward_);
    461 
    462  if (!identity_) {
    463    MOZ_MTLOG(ML_ERROR, "Can't start DTLS without an identity");
    464    return false;
    465  }
    466 
    467  if (verification_mode_ == VERIFY_UNSET) {
    468    MOZ_MTLOG(ML_ERROR,
    469              "Can't start DTLS without specifying a verification mode");
    470    return false;
    471  }
    472 
    473  if (transport_layer_identity == PR_INVALID_IO_LAYER) {
    474    transport_layer_identity = PR_GetUniqueIdentity("nssstreamadapter");
    475  }
    476 
    477  UniquePRFileDesc pr_fd(
    478      PR_CreateIOLayerStub(transport_layer_identity, &TransportLayerMethods));
    479  MOZ_ASSERT(pr_fd != nullptr);
    480  if (!pr_fd) return false;
    481  pr_fd->secret = reinterpret_cast<PRFilePrivate*>(nspr_io_adapter_.get());
    482 
    483  UniquePRFileDesc ssl_fd(DTLS_ImportFD(nullptr, pr_fd.get()));
    484  MOZ_ASSERT(ssl_fd != nullptr);  // This should never happen
    485  if (!ssl_fd) {
    486    return false;
    487  }
    488 
    489  (void)pr_fd.release();  // ownership transfered to ssl_fd;
    490 
    491  if (role_ == CLIENT) {
    492    MOZ_MTLOG(ML_INFO, "Setting up DTLS as client");
    493    rv = SSL_GetClientAuthDataHook(ssl_fd.get(), GetClientAuthDataHook, this);
    494    if (rv != SECSuccess) {
    495      MOZ_MTLOG(ML_ERROR, "Couldn't set identity");
    496      return false;
    497    }
    498 
    499    if (maxVersion_ >= Version::DTLS_1_3) {
    500      MOZ_MTLOG(ML_INFO, "Setting DTLS1.3 supported_versions workaround");
    501      rv = SSL_SetDtls13VersionWorkaround(ssl_fd.get(), PR_TRUE);
    502      if (rv != SECSuccess) {
    503        MOZ_MTLOG(ML_ERROR, "Couldn't set DTLS1.3 workaround");
    504        return false;
    505      }
    506    }
    507  } else {
    508    MOZ_MTLOG(ML_INFO, "Setting up DTLS as server");
    509    // Server side
    510    rv = SSL_ConfigSecureServer(ssl_fd.get(), identity_->cert().get(),
    511                                identity_->privkey().get(),
    512                                identity_->auth_type());
    513    if (rv != SECSuccess) {
    514      MOZ_MTLOG(ML_ERROR, "Couldn't set identity");
    515      return false;
    516    }
    517 
    518    UniqueCERTCertList zero_certs(CERT_NewCertList());
    519    rv = SSL_SetTrustAnchors(ssl_fd.get(), zero_certs.get());
    520    if (rv != SECSuccess) {
    521      MOZ_MTLOG(ML_ERROR, "Couldn't set trust anchors");
    522      return false;
    523    }
    524 
    525    // Insist on a certificate from the client
    526    rv = SSL_OptionSet(ssl_fd.get(), SSL_REQUEST_CERTIFICATE, PR_TRUE);
    527    if (rv != SECSuccess) {
    528      MOZ_MTLOG(ML_ERROR, "Couldn't request certificate");
    529      return false;
    530    }
    531 
    532    rv = SSL_OptionSet(ssl_fd.get(), SSL_REQUIRE_CERTIFICATE, PR_TRUE);
    533    if (rv != SECSuccess) {
    534      MOZ_MTLOG(ML_ERROR, "Couldn't require certificate");
    535      return false;
    536    }
    537  }
    538 
    539  SSLVersionRange version_range = {static_cast<PRUint16>(minVersion_),
    540                                   static_cast<PRUint16>(maxVersion_)};
    541 
    542  rv = SSL_VersionRangeSet(ssl_fd.get(), &version_range);
    543  if (rv != SECSuccess) {
    544    MOZ_MTLOG(ML_ERROR, "Can't disable SSLv3");
    545    return false;
    546  }
    547 
    548  rv = SSL_OptionSet(ssl_fd.get(), SSL_ENABLE_SESSION_TICKETS, PR_FALSE);
    549  if (rv != SECSuccess) {
    550    MOZ_MTLOG(ML_ERROR, "Couldn't disable session tickets");
    551    return false;
    552  }
    553 
    554  rv = SSL_OptionSet(ssl_fd.get(), SSL_NO_CACHE, PR_TRUE);
    555  if (rv != SECSuccess) {
    556    MOZ_MTLOG(ML_ERROR, "Couldn't disable session caching");
    557    return false;
    558  }
    559 
    560  rv = SSL_OptionSet(ssl_fd.get(), SSL_ENABLE_DEFLATE, PR_FALSE);
    561  if (rv != SECSuccess) {
    562    MOZ_MTLOG(ML_ERROR, "Couldn't disable deflate");
    563    return false;
    564  }
    565 
    566  rv = SSL_OptionSet(ssl_fd.get(), SSL_ENABLE_RENEGOTIATION,
    567                     SSL_RENEGOTIATE_NEVER);
    568  if (rv != SECSuccess) {
    569    MOZ_MTLOG(ML_ERROR, "Couldn't disable renegotiation");
    570    return false;
    571  }
    572 
    573  rv = SSL_OptionSet(ssl_fd.get(), SSL_ENABLE_FALSE_START, PR_FALSE);
    574  if (rv != SECSuccess) {
    575    MOZ_MTLOG(ML_ERROR, "Couldn't disable false start");
    576    return false;
    577  }
    578 
    579  rv = SSL_OptionSet(ssl_fd.get(), SSL_NO_LOCKS, PR_TRUE);
    580  if (rv != SECSuccess) {
    581    MOZ_MTLOG(ML_ERROR, "Couldn't disable locks");
    582    return false;
    583  }
    584 
    585  rv = SSL_OptionSet(ssl_fd.get(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
    586  if (rv != SECSuccess) {
    587    MOZ_MTLOG(ML_ERROR, "Couldn't disable ECDHE key reuse");
    588    return false;
    589  }
    590 
    591  if (!SetupCipherSuites(ssl_fd)) {
    592    return false;
    593  }
    594 
    595  // If the version is DTLS1.3 and pq in enabled, we will indicate support for
    596  // ML-KEM
    597  bool enable_mlkem = maxVersion_ >= Version::DTLS_1_3 &&
    598                      StaticPrefs::media_webrtc_enable_pq_hybrid_kex();
    599 
    600  bool send_mlkem = StaticPrefs::media_webrtc_send_mlkem_keyshare();
    601 
    602  if (!enable_mlkem && send_mlkem) {
    603    MOZ_MTLOG(ML_NOTICE,
    604              "The PQ preferences are inconsistent. ML-KEM support will not be "
    605              "advertised, nor will the ML-KEM key share be sent.");
    606  }
    607 
    608  std::vector<SSLNamedGroup> namedGroups;
    609 
    610  if (enable_mlkem && send_mlkem) {
    611    // RFC 8446: client_shares:  A list of offered KeyShareEntry values in
    612    // descending order of client preference.
    613    // {ssl_grp_kem_mlkem768x25519}, {ssl_grp_ec_curve2551} key shared to be
    614    // sent ML-KEM has the highest preference, so it's sent first.
    615    namedGroups = {ssl_grp_kem_mlkem768x25519, ssl_grp_ec_curve25519,
    616                   ssl_grp_ec_secp256r1,       ssl_grp_ec_secp384r1,
    617                   ssl_grp_ffdhe_2048,         ssl_grp_ffdhe_3072};
    618 
    619    if (SECSuccess != SSL_SendAdditionalKeyShares(ssl_fd.get(), 1)) {
    620      MOZ_MTLOG(ML_ERROR, "Couldn't set up additional key shares");
    621      return false;
    622    }
    623  }
    624  // Else we don't send any additional key share
    625  else if (enable_mlkem && !send_mlkem) {
    626    // Here the order of the namedGroups is different than in the first if
    627    // {ssl_grp_ec_curve25519} is first because it's the key share
    628    // that will be sent by default
    629    namedGroups = {ssl_grp_ec_curve25519, ssl_grp_kem_mlkem768x25519,
    630                   ssl_grp_ec_secp256r1,  ssl_grp_ec_secp384r1,
    631                   ssl_grp_ffdhe_2048,    ssl_grp_ffdhe_3072};
    632  }
    633  // ml_kem is disabled
    634  // {ssl_grp_ec_curve25519} will be send as a default key_share
    635  else {
    636    namedGroups = {ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1,
    637                   ssl_grp_ec_secp384r1, ssl_grp_ffdhe_2048,
    638                   ssl_grp_ffdhe_3072};
    639  }
    640 
    641  rv = SSL_NamedGroupConfig(ssl_fd.get(), namedGroups.data(),
    642                            std::size(namedGroups));
    643 
    644  if (rv != SECSuccess) {
    645    MOZ_MTLOG(ML_ERROR, "Couldn't set named groups");
    646    return false;
    647  }
    648 
    649  // Certificate validation
    650  rv = SSL_AuthCertificateHook(ssl_fd.get(), AuthCertificateHook,
    651                               reinterpret_cast<void*>(this));
    652  if (rv != SECSuccess) {
    653    MOZ_MTLOG(ML_ERROR, "Couldn't set certificate validation hook");
    654    return false;
    655  }
    656 
    657  if (!SetupAlpn(ssl_fd)) {
    658    return false;
    659  }
    660 
    661  // Now start the handshake
    662  rv = SSL_ResetHandshake(ssl_fd.get(), role_ == SERVER ? PR_TRUE : PR_FALSE);
    663  if (rv != SECSuccess) {
    664    MOZ_MTLOG(ML_ERROR, "Couldn't reset handshake");
    665    return false;
    666  }
    667  ssl_fd_ = std::move(ssl_fd);
    668 
    669  // Finally, get ready to receive data
    670  downward_->SignalStateChange.connect(this, &TransportLayerDtls::StateChange);
    671  downward_->SignalPacketReceived.connect(this,
    672                                          &TransportLayerDtls::PacketReceived);
    673 
    674  if (downward_->state() == TS_OPEN) {
    675    TL_SET_STATE(TS_CONNECTING);
    676    Handshake();
    677  }
    678 
    679  return true;
    680 }
    681 
    682 bool TransportLayerDtls::SetupAlpn(UniquePRFileDesc& ssl_fd) const {
    683  if (alpn_allowed_.empty()) {
    684    return true;
    685  }
    686 
    687  SECStatus rv = SSL_OptionSet(ssl_fd.get(), SSL_ENABLE_NPN, PR_FALSE);
    688  if (rv != SECSuccess) {
    689    MOZ_MTLOG(ML_ERROR, "Couldn't disable NPN");
    690    return false;
    691  }
    692 
    693  rv = SSL_OptionSet(ssl_fd.get(), SSL_ENABLE_ALPN, PR_TRUE);
    694  if (rv != SECSuccess) {
    695    MOZ_MTLOG(ML_ERROR, "Couldn't enable ALPN");
    696    return false;
    697  }
    698 
    699  unsigned char buf[MAX_ALPN_LENGTH];
    700  size_t offset = 0;
    701  for (const auto& tag : alpn_allowed_) {
    702    if ((offset + 1 + tag.length()) >= sizeof(buf)) {
    703      MOZ_MTLOG(ML_ERROR, "ALPN too long");
    704      return false;
    705    }
    706    buf[offset++] = tag.length();
    707    memcpy(buf + offset, tag.c_str(), tag.length());
    708    offset += tag.length();
    709  }
    710  rv = SSL_SetNextProtoNego(ssl_fd.get(), buf, offset);
    711  if (rv != SECSuccess) {
    712    MOZ_MTLOG(ML_ERROR, "Couldn't set ALPN string");
    713    return false;
    714  }
    715  return true;
    716 }
    717 
    718 // Ciphers we need to enable.  These are on by default in standard firefox
    719 // builds, but can be disabled with prefs and they aren't on in our unit tests
    720 // since that uses NSS default configuration.
    721 //
    722 // Only override prefs to comply with MUST statements in the security-arch
    723 // doc. Anything outside this list is governed by the usual combination of
    724 // policy and user preferences.
    725 static const uint32_t EnabledCiphers[] = {
    726    TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
    727    TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA};
    728 
// Cipher suites that SetupCipherSuites() turns off whenever they are currently
// enabled: suites without PFS, or with old and rusty ciphers. Suites outside
// this list are governed by the usual combination of policy and user
// preferences.
static const uint32_t DisabledCiphers[] = {
    // Bug 1310061: disable all SHA384 ciphers until fixed
    TLS_AES_256_GCM_SHA384,
    TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
    TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
    TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384,
    TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384,
    TLS_DHE_RSA_WITH_AES_256_GCM_SHA384,
    TLS_DHE_DSS_WITH_AES_256_GCM_SHA384,

    // Finite-field DHE with AES-CBC
    TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
    TLS_DHE_RSA_WITH_AES_256_CBC_SHA,

    // 3DES / RC4 with ECDHE key exchange
    TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA,
    TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
    TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
    TLS_ECDHE_RSA_WITH_RC4_128_SHA,

    // 3DES / RC4 with DHE key exchange
    TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA,
    TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA,
    TLS_DHE_DSS_WITH_RC4_128_SHA,

    // Static (non-ephemeral) ECDH key exchange
    TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
    TLS_ECDH_RSA_WITH_AES_128_CBC_SHA,
    TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
    TLS_ECDH_RSA_WITH_AES_256_CBC_SHA,
    TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA,
    TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA,
    TLS_ECDH_ECDSA_WITH_RC4_128_SHA,
    TLS_ECDH_RSA_WITH_RC4_128_SHA,

    // RSA key exchange
    TLS_RSA_WITH_AES_128_GCM_SHA256,
    TLS_RSA_WITH_AES_256_GCM_SHA384,
    TLS_RSA_WITH_AES_128_CBC_SHA,
    TLS_RSA_WITH_AES_128_CBC_SHA256,
    TLS_RSA_WITH_CAMELLIA_128_CBC_SHA,
    TLS_RSA_WITH_AES_256_CBC_SHA,
    TLS_RSA_WITH_AES_256_CBC_SHA256,
    TLS_RSA_WITH_CAMELLIA_256_CBC_SHA,
    TLS_RSA_WITH_SEED_CBC_SHA,
    TLS_RSA_WITH_3DES_EDE_CBC_SHA,
    TLS_RSA_WITH_RC4_128_SHA,
    TLS_RSA_WITH_RC4_128_MD5,

    // Single DES
    TLS_DHE_RSA_WITH_DES_CBC_SHA,
    TLS_DHE_DSS_WITH_DES_CBC_SHA,
    TLS_RSA_WITH_DES_CBC_SHA,

    // NULL cipher (no encryption)
    TLS_ECDHE_ECDSA_WITH_NULL_SHA,
    TLS_ECDHE_RSA_WITH_NULL_SHA,
    TLS_ECDH_ECDSA_WITH_NULL_SHA,
    TLS_ECDH_RSA_WITH_NULL_SHA,
    TLS_RSA_WITH_NULL_SHA,
    TLS_RSA_WITH_NULL_SHA256,
    TLS_RSA_WITH_NULL_MD5,
};
    788 
    789 bool TransportLayerDtls::SetupCipherSuites(UniquePRFileDesc& ssl_fd) {
    790  SECStatus rv;
    791 
    792  // Set the SRTP ciphers
    793  if (!enabled_srtp_ciphers_.empty()) {
    794    rv = SSL_InstallExtensionHooks(ssl_fd.get(), ssl_use_srtp_xtn,
    795                                   TransportLayerDtls::WriteSrtpXtn, this,
    796                                   TransportLayerDtls::HandleSrtpXtn, this);
    797    if (rv != SECSuccess) {
    798      MOZ_MTLOG(ML_ERROR, LAYER_INFO << "unable to set SRTP extension handler");
    799      return false;
    800    }
    801  }
    802 
    803  for (const auto& cipher : EnabledCiphers) {
    804    MOZ_MTLOG(ML_DEBUG, LAYER_INFO << "Enabling: " << cipher);
    805    rv = SSL_CipherPrefSet(ssl_fd.get(), cipher, PR_TRUE);
    806    if (rv != SECSuccess) {
    807      MOZ_MTLOG(ML_ERROR, LAYER_INFO << "Unable to enable suite: " << cipher);
    808      return false;
    809    }
    810  }
    811 
    812  for (const auto& cipher : DisabledCiphers) {
    813    MOZ_MTLOG(ML_DEBUG, LAYER_INFO << "Disabling: " << cipher);
    814 
    815    PRBool enabled = false;
    816    rv = SSL_CipherPrefGet(ssl_fd.get(), cipher, &enabled);
    817    if (rv != SECSuccess) {
    818      MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "Unable to check if suite is enabled: "
    819                                      << cipher);
    820      return false;
    821    }
    822    if (enabled) {
    823      rv = SSL_CipherPrefSet(ssl_fd.get(), cipher, PR_FALSE);
    824      if (rv != SECSuccess) {
    825        MOZ_MTLOG(ML_NOTICE,
    826                  LAYER_INFO << "Unable to disable suite: " << cipher);
    827        return false;
    828      }
    829    }
    830  }
    831 
    832  return true;
    833 }
    834 
    835 nsresult TransportLayerDtls::GetCipherSuite(uint16_t* cipherSuite) const {
    836  CheckThread();
    837  if (!cipherSuite) {
    838    MOZ_MTLOG(ML_ERROR, LAYER_INFO << "GetCipherSuite passed a nullptr");
    839    return NS_ERROR_NULL_POINTER;
    840  }
    841  if (state_ != TS_OPEN) {
    842    return NS_ERROR_NOT_AVAILABLE;
    843  }
    844  SSLChannelInfo info;
    845  SECStatus rv = SSL_GetChannelInfo(ssl_fd_.get(), &info, sizeof(info));
    846  if (rv != SECSuccess) {
    847    MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "GetCipherSuite can't get channel info");
    848    return NS_ERROR_FAILURE;
    849  }
    850  *cipherSuite = info.cipherSuite;
    851  return NS_OK;
    852 }
    853 
    854 std::vector<uint16_t> TransportLayerDtls::GetDefaultSrtpCiphers() {
    855  std::vector<uint16_t> ciphers;
    856 
    857  ciphers.push_back(kDtlsSrtpAeadAes128Gcm);
    858  // Since we don't support DTLS 1.3 or SHA384 ciphers (see bug 1312976)
    859  // we don't really enough entropy to prefer this over 128 bit
    860  ciphers.push_back(kDtlsSrtpAeadAes256Gcm);
    861  ciphers.push_back(kDtlsSrtpAes128CmHmacSha1_80);
    862 #ifndef NIGHTLY_BUILD
    863  // To support bug 1491583 lets try to find out if we get bug reports if we
    864  // no longer offer this in Nightly builds.
    865  ciphers.push_back(kDtlsSrtpAes128CmHmacSha1_32);
    866 #endif
    867 
    868  return ciphers;
    869 }
    870 
    871 void TransportLayerDtls::StateChange(TransportLayer* layer, State state) {
    872  switch (state) {
    873    case TS_NONE:
    874      MOZ_ASSERT(false);  // Can't happen
    875      break;
    876 
    877    case TS_INIT:
    878      MOZ_MTLOG(ML_ERROR,
    879                LAYER_INFO << "State change of lower layer to INIT forbidden");
    880      TL_SET_STATE(TS_ERROR);
    881      break;
    882 
    883    case TS_CONNECTING:
    884      MOZ_MTLOG(ML_INFO, LAYER_INFO << "Lower layer is connecting.");
    885      break;
    886 
    887    case TS_OPEN:
    888      if (timer_) {
    889        MOZ_MTLOG(ML_INFO,
    890                  LAYER_INFO << "Lower layer is now open; starting TLS");
    891        timer_->Cancel();
    892        timer_->SetTarget(target_);
    893        // Async, since the ICE layer might need to send a STUN response, and
    894        // we don't want the handshake to start until that is sent.
    895        timer_->InitWithNamedFuncCallback(
    896            TimerCallback, this, 0, nsITimer::TYPE_ONE_SHOT,
    897            "TransportLayerDtls::TimerCallback"_ns);
    898        TL_SET_STATE(TS_CONNECTING);
    899      } else {
    900        // We have already completed DTLS. Can happen if the ICE layer failed
    901        // due to a loss of network, and then recovered.
    902        TL_SET_STATE(TS_OPEN);
    903      }
    904      break;
    905 
    906    case TS_CLOSED:
    907      MOZ_MTLOG(ML_INFO, LAYER_INFO << "Lower layer is now closed");
    908      TL_SET_STATE(TS_CLOSED);
    909      break;
    910 
    911    case TS_ERROR:
    912      MOZ_MTLOG(ML_ERROR, LAYER_INFO << "Lower layer experienced an error");
    913      TL_SET_STATE(TS_ERROR);
    914      break;
    915  }
    916 }
    917 
    918 void TransportLayerDtls::Handshake() {
    919  if (!timer_) {
    920    // We are done with DTLS, regardless of the state changes of lower layers
    921    return;
    922  }
    923 
    924  if (!handshakeTelemetryRecorded) {
    925    RecordStartedHandshakeTelemetry();
    926    handshakeTelemetryRecorded = true;
    927  }
    928 
    929  // Clear the retransmit timer
    930  timer_->Cancel();
    931 
    932  MOZ_ASSERT(state_ == TS_CONNECTING);
    933 
    934  SECStatus rv = SSL_ForceHandshake(ssl_fd_.get());
    935 
    936  if (rv == SECSuccess) {
    937    MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "****** SSL handshake completed ******");
    938    if (!cert_ok_) {
    939      MOZ_MTLOG(ML_ERROR, LAYER_INFO << "Certificate check never occurred");
    940      RecordHandshakeCompletionTelemetry("CERT_FAILURE");
    941      TL_SET_STATE(TS_ERROR);
    942      return;
    943    }
    944    if (!CheckAlpn()) {
    945      // Despite connecting, the connection doesn't have a valid ALPN label.
    946      // Forcibly close the connection so that the peer isn't left hanging
    947      // (assuming the close_notify isn't dropped).
    948      ssl_fd_ = nullptr;
    949      RecordHandshakeCompletionTelemetry("ALPN_FAILURE");
    950      TL_SET_STATE(TS_ERROR);
    951      return;
    952    }
    953 
    954    RecordHandshakeCompletionTelemetry("SUCCESS");
    955    TL_SET_STATE(TS_OPEN);
    956 
    957    RecordTlsTelemetry();
    958    timer_ = nullptr;
    959  } else {
    960    int32_t err = PR_GetError();
    961    switch (err) {
    962      case SSL_ERROR_RX_MALFORMED_HANDSHAKE:
    963        MOZ_MTLOG(ML_ERROR, LAYER_INFO << "Malformed DTLS message; ignoring");
    964        // If this were TLS (and not DTLS), this would be fatal, but
    965        // here we're required to ignore bad messages, so fall through
    966        [[fallthrough]];
    967      case PR_WOULD_BLOCK_ERROR:
    968        MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "Handshake would have blocked");
    969        PRIntervalTime timeout;
    970        rv = DTLS_GetHandshakeTimeout(ssl_fd_.get(), &timeout);
    971        if (rv == SECSuccess) {
    972          uint32_t timeout_ms = PR_IntervalToMilliseconds(timeout);
    973 
    974          MOZ_MTLOG(ML_DEBUG,
    975                    LAYER_INFO << "Setting DTLS timeout to " << timeout_ms);
    976          timer_->SetTarget(target_);
    977          timer_->InitWithNamedFuncCallback(
    978              TimerCallback, this, timeout_ms, nsITimer::TYPE_ONE_SHOT,
    979              "TransportLayerDtls::TimerCallback"_ns);
    980        }
    981        break;
    982      default:
    983        const char* err_msg = PR_ErrorToName(err);
    984        MOZ_MTLOG(ML_ERROR, LAYER_INFO << "DTLS handshake error " << err << " ("
    985                                       << err_msg << ")");
    986        RecordHandshakeCompletionTelemetry(err_msg);
    987        TL_SET_STATE(TS_ERROR);
    988        break;
    989    }
    990  }
    991 }
    992 
    993 // Checks if ALPN was negotiated correctly and returns false if it wasn't.
    994 // After this returns successfully, alpn_ will be set to the negotiated
    995 // protocol.
    996 bool TransportLayerDtls::CheckAlpn() {
    997  if (alpn_allowed_.empty()) {
    998    return true;
    999  }
   1000 
   1001  SSLNextProtoState alpnState;
   1002  char chosenAlpn[MAX_ALPN_LENGTH];
   1003  unsigned int chosenAlpnLen;
   1004  SECStatus rv = SSL_GetNextProto(ssl_fd_.get(), &alpnState,
   1005                                  reinterpret_cast<unsigned char*>(chosenAlpn),
   1006                                  &chosenAlpnLen, sizeof(chosenAlpn));
   1007  if (rv != SECSuccess) {
   1008    MOZ_MTLOG(ML_ERROR, LAYER_INFO << "ALPN error");
   1009    return false;
   1010  }
   1011  switch (alpnState) {
   1012    case SSL_NEXT_PROTO_SELECTED:
   1013    case SSL_NEXT_PROTO_NEGOTIATED:
   1014      break;  // OK
   1015 
   1016    case SSL_NEXT_PROTO_NO_SUPPORT:
   1017      MOZ_MTLOG(ML_NOTICE,
   1018                LAYER_INFO << "ALPN not negotiated, "
   1019                           << (alpn_default_.empty() ? "failing"
   1020                                                     : "selecting default"));
   1021      alpn_ = alpn_default_;
   1022      return !alpn_.empty();
   1023 
   1024    case SSL_NEXT_PROTO_NO_OVERLAP:
   1025      // This only happens if there is a custom NPN/ALPN callback installed
   1026      // and that callback doesn't properly handle ALPN.
   1027      MOZ_MTLOG(ML_ERROR, LAYER_INFO << "error in ALPN selection callback");
   1028      return false;
   1029 
   1030    case SSL_NEXT_PROTO_EARLY_VALUE:
   1031      MOZ_CRASH("Unexpected 0-RTT ALPN value");
   1032      return false;
   1033  }
   1034 
   1035  // Warning: NSS won't null terminate the ALPN string for us.
   1036  std::string chosen(chosenAlpn, chosenAlpnLen);
   1037  MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "Selected ALPN string: " << chosen);
   1038  if (alpn_allowed_.find(chosen) == alpn_allowed_.end()) {
   1039    // Maybe our peer chose a protocol we didn't offer (when we are client),
   1040    // or something is seriously wrong.
   1041    std::ostringstream ss;
   1042    for (auto i = alpn_allowed_.begin(); i != alpn_allowed_.end(); ++i) {
   1043      ss << (i == alpn_allowed_.begin() ? " '" : ", '") << *i << "'";
   1044    }
   1045    MOZ_MTLOG(ML_ERROR, LAYER_INFO << "Bad ALPN string: '" << chosen
   1046                                   << "'; permitted:" << ss.str());
   1047    return false;
   1048  }
   1049  alpn_ = chosen;
   1050  return true;
   1051 }
   1052 
   1053 void TransportLayerDtls::PacketReceived(TransportLayer* layer,
   1054                                        MediaPacket& packet) {
   1055  CheckThread();
   1056  MOZ_MTLOG(ML_DEBUG, LAYER_INFO << "PacketReceived(" << packet.len() << ")");
   1057 
   1058  if (state_ != TS_CONNECTING && state_ != TS_OPEN) {
   1059    MOZ_MTLOG(ML_DEBUG,
   1060              LAYER_INFO << "Discarding packet in inappropriate state");
   1061    return;
   1062  }
   1063 
   1064  if (!packet.data()) {
   1065    // Something ate this, probably the SRTP layer
   1066    return;
   1067  }
   1068 
   1069  if (packet.type() != MediaPacket::DTLS) {
   1070    return;
   1071  }
   1072 
   1073  nspr_io_adapter_->PacketReceived(packet);
   1074  GetDecryptedPackets();
   1075 }
   1076 
   1077 void TransportLayerDtls::GetDecryptedPackets() {
   1078  // If we're still connecting, try to handshake
   1079  if (state_ == TS_CONNECTING) {
   1080    Handshake();
   1081  }
   1082 
   1083  // Now try a recv if we're open, since there might be data left
   1084  if (state_ == TS_OPEN) {
   1085    int32_t rv;
   1086    // One packet might contain several DTLS packets
   1087    do {
   1088      // nICEr uses a 9216 bytes buffer to allow support for jumbo frames
   1089      // Can we peek to get a better idea of the actual size?
   1090      static const size_t kBufferSize = 9216;
   1091      auto buffer = MakeUnique<uint8_t[]>(kBufferSize);
   1092      rv = PR_Recv(ssl_fd_.get(), buffer.get(), kBufferSize, 0,
   1093                   PR_INTERVAL_NO_WAIT);
   1094      if (rv > 0) {
   1095        // We have data
   1096        MOZ_MTLOG(ML_DEBUG, LAYER_INFO << "Read " << rv << " bytes from NSS");
   1097        MediaPacket packet;
   1098        packet.SetType(MediaPacket::SCTP);
   1099        packet.Take(std::move(buffer), static_cast<size_t>(rv));
   1100        SignalPacketReceived(this, packet);
   1101      } else if (rv == 0) {
   1102        TL_SET_STATE(TS_CLOSED);
   1103      } else {
   1104        int32_t err = PR_GetError();
   1105 
   1106        if (err == PR_WOULD_BLOCK_ERROR) {
   1107          // This gets ignored
   1108          MOZ_MTLOG(ML_DEBUG, LAYER_INFO << "Receive would have blocked");
   1109        } else {
   1110          MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "NSS Error " << err);
   1111          TL_SET_STATE(TS_ERROR);
   1112        }
   1113      }
   1114    } while (rv > 0);
   1115  }
   1116 }
   1117 
   1118 void TransportLayerDtls::SetState(State state, const char* file,
   1119                                  unsigned line) {
   1120  if (timer_) {
   1121    switch (state) {
   1122      case TS_NONE:
   1123      case TS_INIT:
   1124        MOZ_ASSERT(false);
   1125        break;
   1126      case TS_CONNECTING:
   1127        break;
   1128      case TS_OPEN:
   1129      case TS_CLOSED:
   1130      case TS_ERROR:
   1131        timer_->Cancel();
   1132        break;
   1133    }
   1134  }
   1135 
   1136  TransportLayer::SetState(state, file, line);
   1137 }
   1138 
   1139 TransportResult TransportLayerDtls::SendPacket(MediaPacket& packet) {
   1140  CheckThread();
   1141  if (state_ != TS_OPEN) {
   1142    MOZ_MTLOG(ML_ERROR,
   1143              LAYER_INFO << "Can't call SendPacket() in state " << state_);
   1144    return TE_ERROR;
   1145  }
   1146 
   1147  int32_t rv = PR_Send(ssl_fd_.get(), packet.data(), packet.len(), 0,
   1148                       PR_INTERVAL_NO_WAIT);
   1149 
   1150  if (rv > 0) {
   1151    // We have data
   1152    MOZ_MTLOG(ML_DEBUG, LAYER_INFO << "Wrote " << rv << " bytes to SSL Layer");
   1153    return rv;
   1154  }
   1155 
   1156  if (rv == 0) {
   1157    TL_SET_STATE(TS_CLOSED);
   1158    return 0;
   1159  }
   1160 
   1161  int32_t err = PR_GetError();
   1162 
   1163  if (err == PR_WOULD_BLOCK_ERROR) {
   1164    // This gets ignored
   1165    MOZ_MTLOG(ML_DEBUG, LAYER_INFO << "Send would have blocked");
   1166    return TE_WOULDBLOCK;
   1167  }
   1168 
   1169  MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "NSS Error " << err);
   1170  TL_SET_STATE(TS_ERROR);
   1171  return TE_ERROR;
   1172 }
   1173 
   1174 SECStatus TransportLayerDtls::GetClientAuthDataHook(
   1175    void* arg, PRFileDesc* fd, CERTDistNames* caNames,
   1176    CERTCertificate** pRetCert, SECKEYPrivateKey** pRetKey) {
   1177  MOZ_MTLOG(ML_DEBUG, "Server requested client auth");
   1178 
   1179  TransportLayerDtls* stream = reinterpret_cast<TransportLayerDtls*>(arg);
   1180  stream->CheckThread();
   1181 
   1182  if (!stream->identity_) {
   1183    MOZ_MTLOG(ML_ERROR, "No identity available");
   1184    PR_SetError(SSL_ERROR_NO_CERTIFICATE, 0);
   1185    return SECFailure;
   1186  }
   1187 
   1188  *pRetCert = CERT_DupCertificate(stream->identity_->cert().get());
   1189  if (!*pRetCert) {
   1190    PR_SetError(PR_OUT_OF_MEMORY_ERROR, 0);
   1191    return SECFailure;
   1192  }
   1193 
   1194  *pRetKey = SECKEY_CopyPrivateKey(stream->identity_->privkey().get());
   1195  if (!*pRetKey) {
   1196    CERT_DestroyCertificate(*pRetCert);
   1197    *pRetCert = nullptr;
   1198    PR_SetError(PR_OUT_OF_MEMORY_ERROR, 0);
   1199    return SECFailure;
   1200  }
   1201 
   1202  return SECSuccess;
   1203 }
   1204 
   1205 nsresult TransportLayerDtls::SetSrtpCiphers(
   1206    const std::vector<uint16_t>& ciphers) {
   1207  enabled_srtp_ciphers_ = std::move(ciphers);
   1208  return NS_OK;
   1209 }
   1210 
   1211 nsresult TransportLayerDtls::GetSrtpCipher(uint16_t* cipher) const {
   1212  CheckThread();
   1213  if (srtp_cipher_ == 0) {
   1214    return NS_ERROR_NOT_AVAILABLE;
   1215  }
   1216  *cipher = srtp_cipher_;
   1217  return NS_OK;
   1218 }
   1219 
   1220 static uint8_t* WriteUint16(uint8_t* cursor, uint16_t v) {
   1221  *cursor++ = v >> 8;
   1222  *cursor++ = v & 0xff;
   1223  return cursor;
   1224 }
   1225 
   1226 static SSLHandshakeType SrtpXtnServerMessage(PRFileDesc* fd) {
   1227  SSLPreliminaryChannelInfo preinfo;
   1228  SECStatus rv = SSL_GetPreliminaryChannelInfo(fd, &preinfo, sizeof(preinfo));
   1229  if (rv != SECSuccess) {
   1230    MOZ_ASSERT(false, "Can't get version info");
   1231    return ssl_hs_client_hello;
   1232  }
   1233  return (preinfo.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3)
   1234             ? ssl_hs_encrypted_extensions
   1235             : ssl_hs_server_hello;
   1236 }
   1237 
   1238 /* static */
   1239 PRBool TransportLayerDtls::WriteSrtpXtn(PRFileDesc* fd,
   1240                                        SSLHandshakeType message, uint8_t* data,
   1241                                        unsigned int* len, unsigned int max_len,
   1242                                        void* arg) {
   1243  auto self = reinterpret_cast<TransportLayerDtls*>(arg);
   1244 
   1245  // ClientHello: send all supported versions.
   1246  if (message == ssl_hs_client_hello) {
   1247    MOZ_ASSERT(self->role_ == CLIENT);
   1248    MOZ_ASSERT(self->enabled_srtp_ciphers_.size(), "Haven't enabled SRTP");
   1249    // We will take 2 octets for each cipher, plus a 2 octet length and 1
   1250    // octet for the length of the empty MKI.
   1251    if (max_len < self->enabled_srtp_ciphers_.size() * 2 + 3) {
   1252      MOZ_ASSERT(false, "Not enough space to send SRTP extension");
   1253      return false;
   1254    }
   1255    uint8_t* cursor = WriteUint16(data, self->enabled_srtp_ciphers_.size() * 2);
   1256    for (auto cs : self->enabled_srtp_ciphers_) {
   1257      cursor = WriteUint16(cursor, cs);
   1258    }
   1259    *cursor++ = 0;  // MKI is empty
   1260    *len = cursor - data;
   1261    return true;
   1262  }
   1263 
   1264  if (message == SrtpXtnServerMessage(fd)) {
   1265    MOZ_ASSERT(self->role_ == SERVER);
   1266    if (!self->srtp_cipher_) {
   1267      // Not negotiated. Definitely bad, but the connection can fail later.
   1268      return false;
   1269    }
   1270    if (max_len < 5) {
   1271      MOZ_ASSERT(false, "Not enough space to send SRTP extension");
   1272      return false;
   1273    }
   1274 
   1275    uint8_t* cursor = WriteUint16(data, 2);  // Length = 2.
   1276    cursor = WriteUint16(cursor, self->srtp_cipher_);
   1277    *cursor++ = 0;  // No MKI
   1278    *len = cursor - data;
   1279    return true;
   1280  }
   1281 
   1282  return false;
   1283 }
   1284 
   1285 class TlsParser {
   1286 public:
   1287  TlsParser(const uint8_t* data, size_t len) : cursor_(data), remaining_(len) {}
   1288 
   1289  bool error() const { return error_; }
   1290  size_t remaining() const { return remaining_; }
   1291 
   1292  template <typename T,
   1293            class = typename std::enable_if<std::is_unsigned<T>::value>::type>
   1294  void Read(T* v, size_t sz = sizeof(T)) {
   1295    MOZ_ASSERT(sz <= sizeof(T),
   1296               "Type is too small to hold the value requested");
   1297    if (remaining_ < sz) {
   1298      error_ = true;
   1299      return;
   1300    }
   1301 
   1302    T result = 0;
   1303    for (size_t i = 0; i < sz; ++i) {
   1304      result = (result << 8) | *cursor_++;
   1305      remaining_--;
   1306    }
   1307    *v = result;
   1308  }
   1309 
   1310  template <typename T,
   1311            class = typename std::enable_if<std::is_unsigned<T>::value>::type>
   1312  void ReadVector(std::vector<T>* v, size_t w) {
   1313    MOZ_ASSERT(v->empty(), "vector needs to be empty");
   1314 
   1315    uint32_t len;
   1316    Read(&len, w);
   1317    if (error_ || len % sizeof(T) != 0 || len > remaining_) {
   1318      error_ = true;
   1319      return;
   1320    }
   1321 
   1322    size_t count = len / sizeof(T);
   1323    v->reserve(count);
   1324    for (T i = 0; !error_ && i < count; ++i) {
   1325      T item;
   1326      Read(&item);
   1327      if (!error_) {
   1328        v->push_back(item);
   1329      }
   1330    }
   1331  }
   1332 
   1333  void Skip(size_t n) {
   1334    if (remaining_ < n) {
   1335      error_ = true;
   1336    } else {
   1337      cursor_ += n;
   1338      remaining_ -= n;
   1339    }
   1340  }
   1341 
   1342  size_t SkipVector(size_t w) {
   1343    uint32_t len = 0;
   1344    Read(&len, w);
   1345    Skip(len);
   1346    return len;
   1347  }
   1348 
   1349 private:
   1350  const uint8_t* cursor_;
   1351  size_t remaining_;
   1352  bool error_ = false;
   1353 };
   1354 
   1355 /* static */
   1356 SECStatus TransportLayerDtls::HandleSrtpXtn(
   1357    PRFileDesc* fd, SSLHandshakeType message, const uint8_t* data,
   1358    unsigned int len, SSLAlertDescription* alert, void* arg) {
   1359  static const uint8_t kTlsAlertHandshakeFailure = 40;
   1360  static const uint8_t kTlsAlertIllegalParameter = 47;
   1361  static const uint8_t kTlsAlertDecodeError = 50;
   1362  static const uint8_t kTlsAlertUnsupportedExtension = 110;
   1363 
   1364  auto self = reinterpret_cast<TransportLayerDtls*>(arg);
   1365 
   1366  // Parse the extension.
   1367  TlsParser parser(data, len);
   1368  std::vector<uint16_t> advertised;
   1369  parser.ReadVector(&advertised, 2);
   1370  size_t mki_len = parser.SkipVector(1);
   1371  if (parser.error() || parser.remaining() > 0) {
   1372    *alert = kTlsAlertDecodeError;
   1373    return SECFailure;
   1374  }
   1375 
   1376  if (message == ssl_hs_client_hello) {
   1377    MOZ_ASSERT(self->role_ == SERVER);
   1378    if (self->enabled_srtp_ciphers_.empty()) {
   1379      // We don't have SRTP enabled, which is probably bad, but no sense in
   1380      // having the handshake fail at this point, let the client decide if
   1381      // this is a problem.
   1382      return SECSuccess;
   1383    }
   1384 
   1385    for (auto supported : self->enabled_srtp_ciphers_) {
   1386      auto it = std::find(advertised.begin(), advertised.end(), supported);
   1387      if (it != advertised.end()) {
   1388        self->srtp_cipher_ = supported;
   1389        return SECSuccess;
   1390      }
   1391    }
   1392 
   1393    // No common cipher.
   1394    *alert = kTlsAlertHandshakeFailure;
   1395    return SECFailure;
   1396  }
   1397 
   1398  if (message == SrtpXtnServerMessage(fd)) {
   1399    MOZ_ASSERT(self->role_ == CLIENT);
   1400    if (advertised.size() != 1 || mki_len > 0) {
   1401      *alert = kTlsAlertIllegalParameter;
   1402      return SECFailure;
   1403    }
   1404    self->srtp_cipher_ = advertised[0];
   1405    return SECSuccess;
   1406  }
   1407 
   1408  *alert = kTlsAlertUnsupportedExtension;
   1409  return SECFailure;
   1410 }
   1411 
   1412 nsresult TransportLayerDtls::ExportKeyingMaterial(const std::string& label,
   1413                                                  bool use_context,
   1414                                                  const std::string& context,
   1415                                                  unsigned char* out,
   1416                                                  unsigned int outlen) {
   1417  CheckThread();
   1418  if (state_ != TS_OPEN) {
   1419    MOZ_ASSERT(false, "Transport must be open for ExportKeyingMaterial");
   1420    return NS_ERROR_NOT_AVAILABLE;
   1421  }
   1422  SECStatus rv = SSL_ExportKeyingMaterial(
   1423      ssl_fd_.get(), label.c_str(), label.size(), use_context,
   1424      reinterpret_cast<const unsigned char*>(context.c_str()), context.size(),
   1425      out, outlen);
   1426  if (rv != SECSuccess) {
   1427    MOZ_MTLOG(ML_ERROR, "Couldn't export SSL keying material");
   1428    return NS_ERROR_FAILURE;
   1429  }
   1430 
   1431  return NS_OK;
   1432 }
   1433 
   1434 SECStatus TransportLayerDtls::AuthCertificateHook(void* arg, PRFileDesc* fd,
   1435                                                  PRBool checksig,
   1436                                                  PRBool isServer) {
   1437  TransportLayerDtls* stream = reinterpret_cast<TransportLayerDtls*>(arg);
   1438  stream->CheckThread();
   1439  return stream->AuthCertificateHook(fd, checksig, isServer);
   1440 }
   1441 
   1442 SECStatus TransportLayerDtls::CheckDigest(
   1443    const DtlsDigest& digest, UniqueCERTCertificate& peer_cert) const {
   1444  DtlsDigest computed_digest(digest.algorithm_);
   1445 
   1446  MOZ_MTLOG(ML_DEBUG,
   1447            LAYER_INFO << "Checking digest, algorithm=" << digest.algorithm_);
   1448  nsresult res = DtlsIdentity::ComputeFingerprint(peer_cert, &computed_digest);
   1449  if (NS_FAILED(res)) {
   1450    MOZ_MTLOG(ML_ERROR, "Could not compute peer fingerprint for digest "
   1451                            << digest.algorithm_);
   1452    // Go to end
   1453    PR_SetError(SSL_ERROR_BAD_CERTIFICATE, 0);
   1454    return SECFailure;
   1455  }
   1456 
   1457  if (computed_digest != digest) {
   1458    MOZ_MTLOG(ML_ERROR, "Digest does not match");
   1459    PR_SetError(SSL_ERROR_BAD_CERTIFICATE, 0);
   1460    return SECFailure;
   1461  }
   1462 
   1463  return SECSuccess;
   1464 }
   1465 
   1466 SECStatus TransportLayerDtls::AuthCertificateHook(PRFileDesc* fd,
   1467                                                  PRBool checksig,
   1468                                                  PRBool isServer) {
   1469  CheckThread();
   1470  UniqueCERTCertificate peer_cert(SSL_PeerCertificate(fd));
   1471 
   1472  // We are not set up to take this being called multiple
   1473  // times. Change this if we ever add renegotiation.
   1474  MOZ_ASSERT(!auth_hook_called_);
   1475  if (auth_hook_called_) {
   1476    PR_SetError(PR_UNKNOWN_ERROR, 0);
   1477    return SECFailure;
   1478  }
   1479  auth_hook_called_ = true;
   1480 
   1481  MOZ_ASSERT(verification_mode_ != VERIFY_UNSET);
   1482 
   1483  switch (verification_mode_) {
   1484    case VERIFY_UNSET:
   1485      // Break out to error exit
   1486      PR_SetError(PR_UNKNOWN_ERROR, 0);
   1487      break;
   1488 
   1489    case VERIFY_ALLOW_ALL:
   1490      cert_ok_ = true;
   1491      return SECSuccess;
   1492 
   1493    case VERIFY_DIGEST: {
   1494      MOZ_ASSERT(!digests_.empty());
   1495      // Check all the provided digests
   1496 
   1497      // Checking functions call PR_SetError()
   1498      SECStatus rv = SECFailure;
   1499      for (auto digest : digests_) {
   1500        rv = CheckDigest(digest, peer_cert);
   1501 
   1502        // Matches a digest, we are good to go
   1503        if (rv == SECSuccess) {
   1504          cert_ok_ = true;
   1505          return SECSuccess;
   1506        }
   1507      }
   1508    } break;
   1509    default:
   1510      MOZ_CRASH();  // Can't happen
   1511  }
   1512 
   1513  return SECFailure;
   1514 }
   1515 
   1516 void TransportLayerDtls::TimerCallback(nsITimer* timer, void* arg) {
   1517  TransportLayerDtls* dtls = reinterpret_cast<TransportLayerDtls*>(arg);
   1518 
   1519  MOZ_MTLOG(ML_DEBUG, "DTLS timer expired");
   1520 
   1521  dtls->Handshake();
   1522 }
   1523 
   1524 void TransportLayerDtls::RecordHandshakeCompletionTelemetry(
   1525    const char* aResult) {
   1526  if (role_ == CLIENT) {
   1527    mozilla::glean::webrtcdtls::client_handshake_result.Get(nsCString(aResult))
   1528        .Add(1);
   1529  } else {
   1530    mozilla::glean::webrtcdtls::server_handshake_result.Get(nsCString(aResult))
   1531        .Add(1);
   1532  }
   1533 }
   1534 
   1535 void TransportLayerDtls::RecordStartedHandshakeTelemetry() {
   1536  if (role_ == CLIENT) {
   1537    mozilla::glean::webrtcdtls::client_handshake_started_counter.Add(1);
   1538  } else {
   1539    mozilla::glean::webrtcdtls::server_handshake_started_counter.Add(1);
   1540  }
   1541 }
   1542 
   1543 void TransportLayerDtls::RecordTlsTelemetry() {
   1544  MOZ_ASSERT(state_ == TS_OPEN);
   1545  SSLChannelInfo info;
   1546  SECStatus ss = SSL_GetChannelInfo(ssl_fd_.get(), &info, sizeof(info));
   1547  if (ss != SECSuccess) {
   1548    MOZ_MTLOG(ML_NOTICE,
   1549              LAYER_INFO << "RecordTlsTelemetry failed to get channel info");
   1550    return;
   1551  }
   1552 
   1553  switch (info.protocolVersion) {
   1554    case SSL_LIBRARY_VERSION_TLS_1_1:
   1555      mozilla::glean::webrtcdtls::protocol_version.Get("1.0"_ns).Add(1);
   1556      break;
   1557    case SSL_LIBRARY_VERSION_TLS_1_2:
   1558      mozilla::glean::webrtcdtls::protocol_version.Get("1.2"_ns).Add(1);
   1559      break;
   1560    case SSL_LIBRARY_VERSION_TLS_1_3:
   1561      mozilla::glean::webrtcdtls::protocol_version.Get("1.3"_ns).Add(1);
   1562      break;
   1563    default:
   1564      MOZ_CRASH("Unknown SSL version");
   1565  }
   1566 
   1567  {
   1568    std::ostringstream oss;
   1569    // Record TLS cipher-suite ID as a string (eg;
   1570    // TLS_DHE_RSA_WITH_AES_128_CBC_SHA is 0x0033)
   1571    oss << "0x" << std::setfill('0') << std::setw(4) << std::hex
   1572        << info.cipherSuite;
   1573    mozilla::glean::webrtcdtls::cipher.Get(nsCString(oss.str().c_str())).Add(1);
   1574    MOZ_MTLOG(ML_DEBUG, "cipher: " << oss.str());
   1575  }
   1576 
   1577  // Record Key Exchange Algorithm Type
   1578  // keyExchange null=0, rsa=1, dh=2, ecdh=4, ecdh_hybrid=8
   1579  mozilla::glean::webrtcdtls::key_exchange_algorithm.AccumulateSingleSample(
   1580      info.keaType);
   1581 
   1582  uint16_t cipher;
   1583  nsresult rv = GetSrtpCipher(&cipher);
   1584 
   1585  if (NS_FAILED(rv)) {
   1586    MOZ_MTLOG(ML_DEBUG, "No SRTP cipher suite");
   1587    return;
   1588  }
   1589 
   1590  {
   1591    std::ostringstream oss;
   1592    // Record SRTP cipher-suite ID as a string (eg;
   1593    // SRTP_AES128_CM_HMAC_SHA1_80 is 0x0001)
   1594    oss << "0x" << std::setfill('0') << std::setw(4) << std::hex << cipher;
   1595    mozilla::glean::webrtcdtls::srtp_cipher.Get(nsCString(oss.str().c_str()))
   1596        .Add(1);
   1597    MOZ_MTLOG(ML_DEBUG, "srtp cipher: " << oss.str());
   1598  }
   1599 }
   1600 
   1601 }  // namespace mozilla