tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

FaultyServer.cpp (8228B)


      1 /* This Source Code Form is subject to the terms of the Mozilla Public
      2 * License, v. 2.0. If a copy of the MPL was not distributed with this
      3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
      4 
      5 #include <stdio.h>
      6 
      7 #include "nspr.h"
      8 #include "ScopedNSSTypes.h"
      9 #include "ssl.h"
     10 #include "ssl3prot.h"
     11 #include "sslexp.h"
     12 #include "sslimpl.h"
     13 #include "TLSServer.h"
     14 
     15 #include "mozilla/Sprintf.h"
     16 
     17 using namespace mozilla;
     18 using namespace mozilla::test;
     19 
     20 enum FaultType {
     21  None = 0,
     22  ZeroRtt,
     23  UnknownSNI,
     24  Mlkem768x25519,
     25 };
     26 
     27 struct FaultyServerHost {
     28  const char* mHostName;
     29  const char* mCertName;
     30  FaultType mFaultType;
     31 };
     32 
     33 const char* kHostOk = "ok.example.com";
     34 const char* kHostUnknown = "unknown.example.com";
     35 const char* kHostZeroRttAlertBadMac = "0rtt-alert-bad-mac.example.com";
     36 const char* kHostZeroRttAlertVersion =
     37    "0rtt-alert-protocol-version.example.com";
     38 const char* kHostZeroRttAlertUnexpected = "0rtt-alert-unexpected.example.com";
     39 const char* kHostZeroRttAlertDowngrade = "0rtt-alert-downgrade.example.com";
     40 
     41 const char* kHostMlkem768x25519NetInterrupt =
     42    "mlkem768x25519-net-interrupt.example.com";
     43 const char* kHostMlkem768x25519AlertAfterServerHello =
     44    "mlkem768x25519-alert-after-server-hello.example.com";
     45 
     46 const char* kCertWildcard = "default-ee";
     47 
     48 /* Each type of failure gets a different SNI.
     49 * the "default-ee" cert has a SAN for *.example.com
     50 * the "no-san-ee" cert is signed by the test-ca, but it doesn't have any SANs.
     51 */
     52 MOZ_RUNINIT const FaultyServerHost sFaultyServerHosts[]{
     53    {kHostOk, kCertWildcard, None},
     54    {kHostUnknown, kCertWildcard, UnknownSNI},
     55    {kHostZeroRttAlertBadMac, kCertWildcard, ZeroRtt},
     56    {kHostZeroRttAlertVersion, kCertWildcard, ZeroRtt},
     57    {kHostZeroRttAlertUnexpected, kCertWildcard, ZeroRtt},
     58    {kHostZeroRttAlertDowngrade, kCertWildcard, ZeroRtt},
     59    {kHostMlkem768x25519NetInterrupt, kCertWildcard, Mlkem768x25519},
     60    {kHostMlkem768x25519AlertAfterServerHello, kCertWildcard, Mlkem768x25519},
     61    {nullptr, nullptr},
     62 };
     63 
     64 nsresult SendAll(PRFileDesc* aSocket, const char* aData, size_t aDataLen) {
     65  if (gDebugLevel >= DEBUG_VERBOSE) {
     66    fprintf(stderr, "sending '%s'\n", aData);
     67  }
     68 
     69  int32_t len = static_cast<int32_t>(aDataLen);
     70  while (len > 0) {
     71    int32_t bytesSent = PR_Send(aSocket, aData, len, 0, PR_INTERVAL_NO_TIMEOUT);
     72    if (bytesSent == -1) {
     73      PrintPRError("PR_Send failed");
     74      return NS_ERROR_FAILURE;
     75    }
     76 
     77    len -= bytesSent;
     78    aData += bytesSent;
     79  }
     80 
     81  return NS_OK;
     82 }
     83 
     84 // returns 0 on success, non-zero on error
     85 int DoCallback(const char* path) {
     86  UniquePRFileDesc socket(PR_NewTCPSocket());
     87  if (!socket) {
     88    PrintPRError("PR_NewTCPSocket failed");
     89    return 1;
     90  }
     91 
     92  uint32_t port = 0;
     93  const char* callbackPort = PR_GetEnv("FAULTY_SERVER_CALLBACK_PORT");
     94  if (callbackPort) {
     95    port = atoi(callbackPort);
     96  }
     97  if (!port) {
     98    return 0;
     99  }
    100 
    101  PRNetAddr addr;
    102  PR_InitializeNetAddr(PR_IpAddrLoopback, port, &addr);
    103  if (PR_Connect(socket.get(), &addr, PR_INTERVAL_NO_TIMEOUT) != PR_SUCCESS) {
    104    PrintPRError("PR_Connect failed");
    105    return 1;
    106  }
    107 
    108  char request[512];
    109  SprintfLiteral(request, "GET %s HTTP/1.0\r\n\r\n", path);
    110  SendAll(socket.get(), request, strlen(request));
    111  char buf[4096];
    112  memset(buf, 0, sizeof(buf));
    113  int32_t bytesRead =
    114      PR_Recv(socket.get(), buf, sizeof(buf) - 1, 0, PR_INTERVAL_NO_TIMEOUT);
    115  if (bytesRead < 0) {
    116    PrintPRError("PR_Recv failed 1");
    117    return 1;
    118  }
    119  if (bytesRead == 0) {
    120    fprintf(stderr, "PR_Recv eof 1\n");
    121    return 1;
    122  }
    123  // fprintf(stderr, "%s\n", buf);
    124  return 0;
    125 }
    126 
    127 /* These are very rough examples. In practice the `arg` parameter to a callback
    128 * might need to be an object that holds some state, like the various traffic
    129 * secrets. */
    130 
    131 /* An SSLSecretCallback is called after every key derivation step in the TLS
    132 * 1.3 key schedule.
    133 *
    134 * Epoch 1 is for the early traffic secret.
    135 * Epoch 2 is for the handshake traffic secrets.
    136 * Epoch 3 is for the application traffic secrets.
    137 */
    138 void SecretCallbackFailZeroRtt(PRFileDesc* fd, PRUint16 epoch,
    139                               SSLSecretDirection dir, PK11SymKey* secret,
    140                               void* arg) {
    141  fprintf(stderr, "0RTT handler epoch=%d dir=%d\n", epoch, (uint32_t)dir);
    142  FaultyServerHost* host = static_cast<FaultyServerHost*>(arg);
    143 
    144  if (epoch == 1 && dir == ssl_secret_read) {
    145    sslSocket* ss = ssl_FindSocket(fd);
    146    if (!ss) {
    147      fprintf(stderr, "0RTT handler, no ss!\n");
    148      return;
    149    }
    150 
    151    char path[256];
    152    SprintfLiteral(path, "/callback/%d", epoch);
    153    DoCallback(path);
    154 
    155    fprintf(stderr, "0RTT handler, configuring alert\n");
    156    if (!strcmp(host->mHostName, kHostZeroRttAlertBadMac)) {
    157      SSL3_SendAlert(ss, alert_fatal, bad_record_mac);
    158    } else if (!strcmp(host->mHostName, kHostZeroRttAlertVersion)) {
    159      SSL3_SendAlert(ss, alert_fatal, protocol_version);
    160    } else if (!strcmp(host->mHostName, kHostZeroRttAlertUnexpected)) {
    161      SSL3_SendAlert(ss, alert_fatal, unexpected_message);
    162    }
    163  }
    164 }
    165 
    166 SECStatus FailingWriteCallback(PRFileDesc* fd, PRUint16 epoch,
    167                               SSLContentType contentType, const PRUint8* data,
    168                               unsigned int len, void* arg) {
    169  return SECFailure;
    170 }
    171 
    172 void SecretCallbackFailMlkem768x25519(PRFileDesc* fd, PRUint16 epoch,
    173                                      SSLSecretDirection dir,
    174                                      PK11SymKey* secret, void* arg) {
    175  fprintf(stderr, "Mlkem768x25519 handler epoch=%d dir=%d\n", epoch,
    176          (uint32_t)dir);
    177  FaultyServerHost* host = static_cast<FaultyServerHost*>(arg);
    178 
    179  if (epoch == 2 && dir == ssl_secret_write) {
    180    sslSocket* ss = ssl_FindSocket(fd);
    181    if (!ss) {
    182      fprintf(stderr, "Mlkem768x25519 handler, no ss!\n");
    183      return;
    184    }
    185 
    186    if (!ss->sec.keaGroup) {
    187      fprintf(stderr, "Mlkem768x25519 handler, no ss->sec.keaGroup!\n");
    188      return;
    189    }
    190 
    191    char path[256];
    192    SprintfLiteral(path, "/callback/%u", ss->sec.keaGroup->name);
    193    DoCallback(path);
    194 
    195    if (ss->sec.keaGroup->name != ssl_grp_kem_mlkem768x25519) {
    196      return;
    197    }
    198 
    199    fprintf(stderr, "Mlkem768x25519 handler, configuring alert\n");
    200    if (strcmp(host->mHostName, kHostMlkem768x25519NetInterrupt) == 0) {
    201      // Install a record write callback that causes the next write to fail.
    202      // The client will see this as a PR_END_OF_FILE / NS_ERROR_NET_INTERRUPT
    203      // error.
    204      ss->recordWriteCallback = FailingWriteCallback;
    205    } else if (!strcmp(host->mHostName,
    206                       kHostMlkem768x25519AlertAfterServerHello)) {
    207      SSL3_SendAlert(ss, alert_fatal, close_notify);
    208    }
    209  }
    210 }
    211 
    212 int32_t DoSNISocketConfig(PRFileDesc* aFd, const SECItem* aSrvNameArr,
    213                          uint32_t aSrvNameArrSize, void* aArg) {
    214  const FaultyServerHost* host =
    215      GetHostForSNI(aSrvNameArr, aSrvNameArrSize, sFaultyServerHosts);
    216  if (!host || host->mFaultType == UnknownSNI) {
    217    PrintPRError("No cert found for hostname");
    218    return SSL_SNI_SEND_ALERT;
    219  }
    220 
    221  if (gDebugLevel >= DEBUG_VERBOSE) {
    222    fprintf(stderr, "found pre-defined host '%s'\n", host->mHostName);
    223  }
    224 
    225  const SSLNamedGroup mlkemTestNamedGroups[] = {ssl_grp_kem_mlkem768x25519,
    226                                                ssl_grp_ec_curve25519};
    227 
    228  switch (host->mFaultType) {
    229    case ZeroRtt:
    230      SSL_SecretCallback(aFd, &SecretCallbackFailZeroRtt, (void*)host);
    231      break;
    232    case Mlkem768x25519:
    233      SSL_SecretCallback(aFd, &SecretCallbackFailMlkem768x25519, (void*)host);
    234      SSL_NamedGroupConfig(aFd, mlkemTestNamedGroups,
    235                           std::size(mlkemTestNamedGroups));
    236      break;
    237    case None:
    238      break;
    239    default:
    240      break;
    241  }
    242 
    243  UniqueCERTCertificate cert;
    244  SSLKEAType certKEA;
    245  if (SECSuccess != ConfigSecureServerWithNamedCert(aFd, host->mCertName, &cert,
    246                                                    &certKEA, nullptr)) {
    247    return SSL_SNI_SEND_ALERT;
    248  }
    249 
    250  return 0;
    251 }
    252 
    253 SECStatus ConfigureServer(PRFileDesc* aFd) { return SECSuccess; }
    254 
    255 int main(int argc, char* argv[]) {
    256  int rv = StartServer(argc, argv, DoSNISocketConfig, nullptr, ConfigureServer);
    257  if (rv < 0) {
    258    return rv;
    259  }
    260 }