tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

ssl_record_unittest.cc (27519B)


      1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim: set ts=2 et sw=2 tw=80: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
      7 #include "nss.h"
      8 #include "ssl.h"
      9 #include "sslimpl.h"
     10 
     11 #include "databuffer.h"
     12 #include "gtest_utils.h"
     13 #include "tls_connect.h"
     14 #include "tls_filter.h"
     15 
     16 namespace nss_test {
     17 
     18 const static size_t kMacSize = 20;
     19 
     20 class TlsPaddingTest
     21    : public ::testing::Test,
     22      public ::testing::WithParamInterface<std::tuple<size_t, bool>> {
     23 public:
     24  TlsPaddingTest() : plaintext_len_(std::get<0>(GetParam())) {
     25    size_t extra =
     26        (plaintext_len_ + 1) % 16;  // Bytes past a block (1 == pad len)
     27    // Minimal padding.
     28    pad_len_ = extra ? 16 - extra : 0;
     29    if (std::get<1>(GetParam())) {
     30      // Maximal padding.
     31      pad_len_ += 240;
     32    }
     33    MakePaddedPlaintext();
     34  }
     35 
     36  // Makes a plaintext record with correct padding.
     37  void MakePaddedPlaintext() {
     38    EXPECT_EQ(0UL, (plaintext_len_ + pad_len_ + 1) % 16);
     39    size_t i = 0;
     40    plaintext_.Allocate(plaintext_len_ + pad_len_ + 1);
     41    for (; i < plaintext_len_; ++i) {
     42      plaintext_.Write(i, 'A', 1);
     43    }
     44 
     45    for (; i < plaintext_len_ + pad_len_ + 1; ++i) {
     46      plaintext_.Write(i, pad_len_, 1);
     47    }
     48  }
     49 
     50  void Unpad(bool expect_success) {
     51    std::cerr << "Content length=" << plaintext_len_
     52              << " padding length=" << pad_len_
     53              << " total length=" << plaintext_.len() << std::endl;
     54    std::cerr << "Plaintext: " << plaintext_ << std::endl;
     55    sslBuffer s;
     56    s.buf = const_cast<unsigned char*>(
     57        static_cast<const unsigned char*>(plaintext_.data()));
     58    s.len = plaintext_.len();
     59    SECStatus rv = ssl_RemoveTLSCBCPadding(&s, kMacSize);
     60    if (expect_success) {
     61      EXPECT_EQ(SECSuccess, rv);
     62      EXPECT_EQ(plaintext_len_, static_cast<size_t>(s.len));
     63    } else {
     64      EXPECT_EQ(SECFailure, rv);
     65    }
     66  }
     67 
     68 protected:
     69  size_t plaintext_len_;
     70  size_t pad_len_;
     71  DataBuffer plaintext_;
     72 };
     73 
     74 TEST_P(TlsPaddingTest, Correct) {
     75  if (plaintext_len_ >= kMacSize) {
     76    Unpad(true);
     77  } else {
     78    Unpad(false);
     79  }
     80 }
     81 
     82 TEST_P(TlsPaddingTest, PadTooLong) {
     83  if (plaintext_.len() < 255) {
     84    plaintext_.Write(plaintext_.len() - 1, plaintext_.len(), 1);
     85    Unpad(false);
     86  }
     87 }
     88 
     89 TEST_P(TlsPaddingTest, FirstByteOfPadWrong) {
     90  if (pad_len_) {
     91    plaintext_.Write(plaintext_len_, plaintext_.data()[plaintext_len_] + 1, 1);
     92    Unpad(false);
     93  }
     94 }
     95 
     96 TEST_P(TlsPaddingTest, LastByteOfPadWrong) {
     97  if (pad_len_) {
     98    plaintext_.Write(plaintext_.len() - 2,
     99                     plaintext_.data()[plaintext_.len() - 1] + 1, 1);
    100    Unpad(false);
    101  }
    102 }
    103 
    104 class RecordReplacer : public TlsRecordFilter {
    105 public:
    106  RecordReplacer(const std::shared_ptr<TlsAgent>& a, size_t size)
    107      : TlsRecordFilter(a), size_(size) {
    108    Disable();
    109  }
    110 
    111  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
    112                                    const DataBuffer& data,
    113                                    DataBuffer* changed) override {
    114    EXPECT_EQ(ssl_ct_application_data, header.content_type());
    115    changed->Allocate(size_);
    116 
    117    for (size_t i = 0; i < size_; ++i) {
    118      changed->data()[i] = i & 0xff;
    119    }
    120 
    121    Disable();
    122    return CHANGE;
    123  }
    124 
    125 private:
    126  size_t size_;
    127 };
    128 
    129 TEST_P(TlsConnectStream, BadRecordMac) {
    130  EnsureTlsSetup();
    131  Connect();
    132  client_->SetFilter(std::make_shared<TlsRecordLastByteDamager>(client_));
    133  ExpectAlert(server_, kTlsAlertBadRecordMac);
    134  client_->SendData(10);
    135 
    136  // Read from the client, get error.
    137  uint8_t buf[10];
    138  PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
    139  EXPECT_GT(0, rv);
    140  EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, PORT_GetError());
    141 
    142  // Read the server alert.
    143  rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf));
    144  EXPECT_GT(0, rv);
    145  EXPECT_EQ(SSL_ERROR_BAD_MAC_ALERT, PORT_GetError());
    146 }
    147 
    148 TEST_F(TlsConnectStreamTls13, LargeRecord) {
    149  EnsureTlsSetup();
    150 
    151  const size_t record_limit = 16384;
    152  auto replacer = MakeTlsFilter<RecordReplacer>(client_, record_limit);
    153  replacer->EnableDecryption();
    154  Connect();
    155 
    156  replacer->Enable();
    157  client_->SendData(10);
    158  WAIT_(server_->received_bytes() == record_limit, 2000);
    159  ASSERT_EQ(record_limit, server_->received_bytes());
    160 }
    161 
    162 TEST_F(TlsConnectStreamTls13, TooLargeRecord) {
    163  EnsureTlsSetup();
    164 
    165  const size_t record_limit = 16384;
    166  auto replacer = MakeTlsFilter<RecordReplacer>(client_, record_limit + 1);
    167  replacer->EnableDecryption();
    168  Connect();
    169 
    170  replacer->Enable();
    171  ExpectAlert(server_, kTlsAlertRecordOverflow);
    172  client_->SendData(10);  // This is expanded.
    173 
    174  uint8_t buf[record_limit + 2];
    175  PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
    176  EXPECT_GT(0, rv);
    177  EXPECT_EQ(SSL_ERROR_RX_RECORD_TOO_LONG, PORT_GetError());
    178 
    179  // Read the server alert.
    180  rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf));
    181  EXPECT_GT(0, rv);
    182  EXPECT_EQ(SSL_ERROR_RECORD_OVERFLOW_ALERT, PORT_GetError());
    183 }
    184 
    185 class ShortHeaderChecker : public PacketFilter {
    186 public:
    187  PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output) {
    188    // The first octet should be 0b001000xx.
    189    EXPECT_EQ(kCtDtlsCiphertext, (input.data()[0] & ~0x3));
    190    return KEEP;
    191  }
    192 };
    193 
    194 TEST_F(TlsConnectDatagram13, AeadLimit) {
    195  Connect();
    196  EXPECT_EQ(SECSuccess, SSLInt_AdvanceDtls13DecryptFailures(server_->ssl_fd(),
    197                                                            (1ULL << 36) - 2));
    198  SendReceive(50);
    199 
    200  // Expect this to increment the counter. We should still be able to talk.
    201  client_->SetFilter(std::make_shared<TlsRecordLastByteDamager>(client_));
    202  client_->SendData(10);
    203  server_->ReadBytes(10);
    204  client_->ClearFilter();
    205  client_->ResetSentBytes(50);
    206  SendReceive(60);
    207 
    208  // Expect alert when the limit is hit.
    209  client_->SetFilter(std::make_shared<TlsRecordLastByteDamager>(client_));
    210  client_->SendData(10);
    211  ExpectAlert(server_, kTlsAlertBadRecordMac);
    212 
    213  // Check the error on both endpoints.
    214  uint8_t buf[10];
    215  PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
    216  EXPECT_EQ(-1, rv);
    217  EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, PORT_GetError());
    218 
    219  rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf));
    220  EXPECT_EQ(-1, rv);
    221  EXPECT_EQ(SSL_ERROR_BAD_MAC_ALERT, PORT_GetError());
    222 }
    223 
    224 TEST_F(TlsConnectDatagram13, ShortHeadersClient) {
    225  Connect();
    226  client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_TRUE);
    227  client_->SetFilter(std::make_shared<ShortHeaderChecker>());
    228  SendReceive();
    229 }
    230 
    231 TEST_F(TlsConnectDatagram13, ShortHeadersServer) {
    232  Connect();
    233  server_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_TRUE);
    234  server_->SetFilter(std::make_shared<ShortHeaderChecker>());
    235  SendReceive();
    236 }
    237 
    238 // Send a DTLSCiphertext header with a 2B sequence number, and no length.
    239 TEST_F(TlsConnectDatagram13, DtlsAlternateShortHeader) {
    240  StartConnect();
    241  TlsSendCipherSpecCapturer capturer(client_);
    242  Connect();
    243  SendReceive(50);
    244 
    245  uint8_t buf[] = {0x32, 0x33, 0x34};
    246  auto spec = capturer.spec(1);
    247  ASSERT_NE(nullptr, spec.get());
    248  ASSERT_EQ(3, spec->epoch());
    249 
    250  uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno;
    251  TlsRecordHeader header(variant_, SSL_LIBRARY_VERSION_TLS_1_3, dtls13_ct,
    252                         0x0003000000000001);
    253  TlsRecordHeader out_header(header);
    254  DataBuffer msg(buf, sizeof(buf));
    255  msg.Write(msg.len(), ssl_ct_application_data, 1);
    256  DataBuffer ciphertext;
    257  EXPECT_TRUE(spec->Protect(header, msg, &ciphertext, &out_header));
    258 
    259  DataBuffer record;
    260  auto rv = out_header.Write(&record, 0, ciphertext);
    261  EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv);
    262  client_->SendDirect(record);
    263 
    264  server_->ReadBytes(3);
    265 }
    266 
    267 TEST_F(TlsConnectStreamTls13, UnencryptedFinishedMessage) {
    268  StartConnect();
    269  client_->Handshake();  // Send ClientHello
    270  server_->Handshake();  // Send first server flight
    271 
    272  // Record and drop the first record, which is the Finished.
    273  auto recorder = std::make_shared<TlsRecordRecorder>(client_);
    274  recorder->EnableDecryption();
    275  auto dropper = std::make_shared<SelectiveDropFilter>(1);
    276  client_->SetFilter(std::make_shared<ChainedPacketFilter>(
    277      ChainedPacketFilterInit({recorder, dropper})));
    278  client_->Handshake();  // Save and drop CFIN.
    279  EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
    280 
    281  ASSERT_EQ(1U, recorder->count());
    282  auto& finished = recorder->record(0);
    283 
    284  DataBuffer d;
    285  size_t offset = d.Write(0, ssl_ct_handshake, 1);
    286  offset = d.Write(offset, SSL_LIBRARY_VERSION_TLS_1_2, 2);
    287  offset = d.Write(offset, finished.buffer.len(), 2);
    288  d.Append(finished.buffer);
    289  client_->SendDirect(d);
    290 
    291  // Now process the message.
    292  ExpectAlert(server_, kTlsAlertUnexpectedMessage);
    293  // The server should generate an alert.
    294  server_->Handshake();
    295  EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state());
    296  server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE);
    297  // Have the client consume the alert.
    298  client_->Handshake();
    299  EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
    300  client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
    301 }
    302 
    303 const static size_t kContentSizesArr[] = {
    304    1, kMacSize - 1, kMacSize, 30, 31, 32, 36, 256, 257, 287, 288};
    305 
    306 auto kContentSizes = ::testing::ValuesIn(kContentSizesArr);
    307 const static bool kTrueFalseArr[] = {true, false};
    308 auto kTrueFalse = ::testing::ValuesIn(kTrueFalseArr);
    309 
    310 INSTANTIATE_TEST_SUITE_P(TlsPadding, TlsPaddingTest,
    311                         ::testing::Combine(kContentSizes, kTrueFalse));
    312 
    313 /* Filter to modify record header and content */
    314 class Tls13RecordModifier : public TlsRecordFilter {
    315 public:
    316  Tls13RecordModifier(const std::shared_ptr<TlsAgent>& a,
    317                      uint8_t contentType = ssl_ct_handshake, size_t size = 0,
    318                      size_t padding = 0)
    319      : TlsRecordFilter(a),
    320        contentType_(contentType),
    321        size_(size),
    322        padding_(padding) {}
    323 
    324 protected:
    325  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
    326                                    const DataBuffer& record, size_t* offset,
    327                                    DataBuffer* output) override {
    328    if (!header.is_protected()) {
    329      return KEEP;
    330    }
    331 
    332    uint16_t protection_epoch;
    333    uint8_t inner_content_type;
    334    DataBuffer plaintext;
    335    TlsRecordHeader out_header;
    336    if (!Unprotect(header, record, &protection_epoch, &inner_content_type,
    337                   &plaintext, &out_header)) {
    338      return KEEP;
    339    }
    340 
    341    if (decrypting() && inner_content_type != ssl_ct_application_data) {
    342      return KEEP;
    343    }
    344 
    345    DataBuffer ciphertext;
    346    bool ok = Protect(spec(protection_epoch), out_header, contentType_,
    347                      DataBuffer(size_), &ciphertext, &out_header, padding_);
    348    EXPECT_TRUE(ok);
    349    if (!ok) {
    350      return KEEP;
    351    }
    352 
    353    *offset = out_header.Write(output, *offset, ciphertext);
    354    return CHANGE;
    355  }
    356 
    357 private:
    358  uint8_t contentType_;
    359  size_t size_;
    360  size_t padding_;
    361 };
    362 
    363 /* Zero-length InnerPlaintext test class
    364 *
    365 * Parameter = Tuple of:
    366 * - TLS variant (datagram/stream)
    367 * - Content type to be set in zero-length inner plaintext record
    368 * - Padding of record plaintext
    369 */
    370 class ZeroLengthInnerPlaintextSetupTls13
    371    : public TlsConnectTestBase,
    372      public testing::WithParamInterface<
    373          std::tuple<SSLProtocolVariant, SSLContentType, size_t>> {
    374 public:
    375  ZeroLengthInnerPlaintextSetupTls13()
    376      : TlsConnectTestBase(std::get<0>(GetParam()),
    377                           SSL_LIBRARY_VERSION_TLS_1_3),
    378        contentType_(std::get<1>(GetParam())),
    379        padding_(std::get<2>(GetParam())){};
    380 
    381 protected:
    382  SSLContentType contentType_;
    383  size_t padding_;
    384 };
    385 
    386 /* Test correct rejection of TLS 1.3 encrypted handshake/alert records with
    387 * zero-length inner plaintext content length with and without padding.
    388 *
    389 * Implementations MUST NOT send Handshake and Alert records that have a
    390 * zero-length TLSInnerPlaintext.content; if such a message is received,
    391 * the receiving implementation MUST terminate the connection with an
    392 * "unexpected_message" alert [RFC8446, Section 5.4]. */
    393 TEST_P(ZeroLengthInnerPlaintextSetupTls13, ZeroLengthInnerPlaintextRun) {
    394  EnsureTlsSetup();
    395 
    396  // Filter modifies record to be zero-length
    397  auto filter =
    398      MakeTlsFilter<Tls13RecordModifier>(client_, contentType_, 0, padding_);
    399  filter->EnableDecryption();
    400  filter->Disable();
    401 
    402  Connect();
    403 
    404  filter->Enable();
    405 
    406  // Record will be overwritten
    407  client_->SendData(0xf);
    408 
    409  // Receive corrupt record
    410  if (variant_ == ssl_variant_stream) {
    411    server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
    412    // 22B = 16B MAC + 1B innerContentType + 5B Header
    413    server_->ReadBytes(22);
    414    // Process alert at peer
    415    client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
    416    client_->Handshake();
    417  } else { /* DTLS */
    418    size_t received = server_->received_bytes();
    419    // 22B = 16B MAC + 1B innerContentType + 5B Header
    420    server_->ReadBytes(22);
    421    // Check that no bytes were received => packet was dropped
    422    ASSERT_EQ(received, server_->received_bytes());
    423    // Check that we are still connected / not in error state
    424    EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
    425    EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
    426  }
    427 }
    428 
    429 // Test for TLS and DTLS
    430 const SSLProtocolVariant kZeroLengthInnerPlaintextVariants[] = {
    431    ssl_variant_stream, ssl_variant_datagram};
    432 // Test for handshake and alert fragments
    433 const SSLContentType kZeroLengthInnerPlaintextContentTypes[] = {
    434    ssl_ct_handshake, ssl_ct_alert};
    435 // Test with 0,1 and 100 octets of padding
    436 const size_t kZeroLengthInnerPlaintextPadding[] = {0, 1, 100};
    437 
    438 INSTANTIATE_TEST_SUITE_P(
    439    ZeroLengthInnerPlaintextTest, ZeroLengthInnerPlaintextSetupTls13,
    440    testing::Combine(testing::ValuesIn(kZeroLengthInnerPlaintextVariants),
    441                     testing::ValuesIn(kZeroLengthInnerPlaintextContentTypes),
    442                     testing::ValuesIn(kZeroLengthInnerPlaintextPadding)),
    443    [](const testing::TestParamInfo<
    444        ZeroLengthInnerPlaintextSetupTls13::ParamType>& inf) {
    445      return std::string(std::get<0>(inf.param) == ssl_variant_stream
    446                             ? "Tls"
    447                             : "Dtls") +
    448             "ZeroLengthInnerPlaintext" +
    449             (std::get<1>(inf.param) == ssl_ct_handshake ? "Handshake"
    450                                                         : "Alert") +
    451             (std::get<2>(inf.param)
    452                  ? "Padding" + std::to_string(std::get<2>(inf.param)) + "B"
    453                  : "") +
    454             "Test";
    455    });
    456 
    457 /* Zero-length record test class
    458 *
    459 * Parameter = Tuple of:
    460 * - TLS variant (datagram/stream)
    461 * - TLS version
    462 * - Content type to be set in zero-length record
    463 */
    464 class ZeroLengthRecordSetup
    465    : public TlsConnectTestBase,
    466      public testing::WithParamInterface<
    467          std::tuple<SSLProtocolVariant, uint16_t, SSLContentType>> {
    468 public:
    469  ZeroLengthRecordSetup()
    470      : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())),
    471        variant_(std::get<0>(GetParam())),
    472        contentType_(std::get<2>(GetParam())){};
    473 
    474  void createZeroLengthRecord(DataBuffer& buffer, unsigned epoch = 0,
    475                              unsigned seqn = 0) {
    476    size_t idx = 0;
    477    // Set header content type
    478    idx = buffer.Write(idx, contentType_, 1);
    479    // The record version is not checked during record layer handling
    480    idx = buffer.Write(idx, 0xDEAD, 2);
    481    // DTLS (version always < TLS 1.3)
    482    if (variant_ == ssl_variant_datagram) {
    483      // Set epoch (Should be 0 before handshake)
    484      idx = buffer.Write(idx, 0U, 2);
    485      // Set 6B sequence number (0 if send as first message)
    486      idx = buffer.Write(idx, 0U, 2);
    487      idx = buffer.Write(idx, 0U, 4);
    488    }
    489    // Set fragment to be of zero-length
    490    (void)buffer.Write(idx, 0U, 2);
    491  }
    492 
    493 protected:
    494  SSLProtocolVariant variant_;
    495  SSLContentType contentType_;
    496 };
    497 
    498 /* Test handling of zero-length (ciphertext/fragment) records before handshake.
    499 *
    500 * This is only tested before the first handshake, since after it all of these
    501 * messages are expected to be encrypted which is impossible for a content
    502 * length of zero, always leading to a bad record mac. For TLS 1.3 only
    503 * records of application data content type is legal after the handshake.
    504 *
    505 * Handshake records of length zero will be ignored in the record layer since
    506 * the RFC does only specify that such records MUST NOT be sent but it does not
    507 * state that an alert should be sent or the connection be terminated
    508 * [RFC8446, Section 5.1].
    509 *
    510 * Even though only handshake messages are handled (ignored) in the record
    511 * layer handling, this test covers zero-length records of all content types
    512 * for complete coverage of cases.
    513 *
    514 * !!! Expected TLS (Stream) behavior !!!
    515 * - Handshake records of zero length are ignored.
    516 * - Alert and ChangeCipherSpec records of zero-length lead to illegal
    517 * parameter alerts due to the malformed record content.
    518 * - ApplicationData before the handshake leads to an unexpected message alert.
    519 *
    520 * !!! Expected DTLS (Datagram) behavior !!!
    521 * - Handshake message of zero length are ignored.
    522 * - Alert messages lead to an illegal parameter alert due to malformed record
    523 * content.
    524 * - ChangeCipherSpec records before the first handshake are not expected and
    525 * ignored (see ssl3con.c, line 3276).
    526 * - ApplicationData before the handshake is ignored since it could be a packet
    527 * received in incorrect order (see ssl3con.c, line 13353).
    528 */
    529 TEST_P(ZeroLengthRecordSetup, ZeroLengthRecordRun) {
    530  EnsureTlsSetup();
    531 
    532  // Send zero-length record
    533  DataBuffer buffer;
    534  createZeroLengthRecord(buffer);
    535  client_->SendDirect(buffer);
    536  // This must be set, otherwise handshake completness assertions might fail
    537  server_->StartConnect();
    538 
    539  SSLAlertDescription alert = close_notify;
    540 
    541  switch (variant_) {
    542    case ssl_variant_datagram:
    543      switch (contentType_) {
    544        case ssl_ct_alert:
    545          // Should actually be ignored, see bug 1829391.
    546          alert = illegal_parameter;
    547          break;
    548        case ssl_ct_ack:
    549          if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
    550            // Skipped due to bug 1829391.
    551            GTEST_SKIP();
    552          }
    553          // DTLS versions < 1.3 correctly ignore the invalid record
    554          // so we fall through.
    555        case ssl_ct_change_cipher_spec:
    556        case ssl_ct_application_data:
    557        case ssl_ct_handshake:
    558          server_->Handshake();
    559          Connect();
    560          return;
    561      }
    562      break;
    563    case ssl_variant_stream:
    564      switch (contentType_) {
    565        case ssl_ct_alert:
    566        case ssl_ct_change_cipher_spec:
    567          alert = illegal_parameter;
    568          break;
    569        case ssl_ct_application_data:
    570        case ssl_ct_ack:
    571          alert = unexpected_message;
    572          break;
    573        case ssl_ct_handshake:
    574          // TLS ignores unprotected zero-length handshake records
    575          server_->Handshake();
    576          Connect();
    577          return;
    578      }
    579      break;
    580  }
    581 
    582  // Assert alert is send for TLS and DTLS alert records
    583  server_->ExpectSendAlert(alert);
    584  server_->Handshake();
    585 
    586  // Consume alert at peer, expect alert for TLS and DTLS alert records
    587  client_->StartConnect();
    588  client_->ExpectReceiveAlert(alert);
    589  client_->Handshake();
    590 }
    591 
    592 // Test for handshake, alert, change_cipher_spec and application data fragments
    593 const SSLContentType kZeroLengthRecordContentTypes[] = {
    594    ssl_ct_handshake, ssl_ct_alert, ssl_ct_change_cipher_spec,
    595    ssl_ct_application_data, ssl_ct_ack};
    596 
    597 INSTANTIATE_TEST_SUITE_P(
    598    ZeroLengthRecordTest, ZeroLengthRecordSetup,
    599    testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
    600                     TlsConnectTestBase::kTlsV11Plus,
    601                     testing::ValuesIn(kZeroLengthRecordContentTypes)),
    602    [](const testing::TestParamInfo<ZeroLengthRecordSetup::ParamType>& inf) {
    603      std::string variant =
    604          (std::get<0>(inf.param) == ssl_variant_stream) ? "Tls" : "Dtls";
    605      std::string version = VersionString(std::get<1>(inf.param));
    606      std::replace(version.begin(), version.end(), '.', '_');
    607      std::string contentType;
    608      switch (std::get<2>(inf.param)) {
    609        case ssl_ct_handshake:
    610          contentType = "Handshake";
    611          break;
    612        case ssl_ct_alert:
    613          contentType = "Alert";
    614          break;
    615        case ssl_ct_application_data:
    616          contentType = "ApplicationData";
    617          break;
    618        case ssl_ct_change_cipher_spec:
    619          contentType = "ChangeCipherSpec";
    620          break;
    621        case ssl_ct_ack:
    622          contentType = "Ack";
    623          break;
    624      }
    625      return variant + version + "ZeroLength" + contentType + "Test";
    626    });
    627 
    628 /* Test correct handling of records with invalid content types.
    629 *
    630 * TLS:
    631 * If a TLS implementation receives an unexpected record type, it MUST
    632 * terminate the connection with an "unexpected_message" alert
    633 * [RFC8446, Section 5].
    634 *
    635 * DTLS:
    636 * In general, invalid records SHOULD be silently discarded...
    637 * [RFC6347, Section 4.1.2.7]. */
    638 class UndefinedContentTypeSetup : public TlsConnectGeneric {
    639 public:
    640  UndefinedContentTypeSetup() : TlsConnectGeneric() { StartConnect(); };
    641 
    642  void createUndefinedContentTypeRecord(DataBuffer& buffer, unsigned epoch = 0,
    643                                        unsigned seqn = 0) {
    644    // dummy data
    645    uint8_t data[] = {0xAA, 0xBB, 0xCC, 0xDD, 0xEE};
    646 
    647    size_t idx = 0;
    648    // Set undefined content type
    649    idx = buffer.Write(idx, 0xFF, 1);
    650    // The record version is not checked during record layer handling
    651    idx = buffer.Write(idx, 0xDEAD, 2);
    652    // DTLS (version always < TLS 1.3)
    653    if (variant_ == ssl_variant_datagram) {
    654      // Set epoch (Should be 0 before/during handshake)
    655      idx = buffer.Write(idx, epoch, 2);
    656      // Set 6B sequence number (0 if send as first message)
    657      idx = buffer.Write(idx, 0U, 2);
    658      idx = buffer.Write(idx, seqn, 4);
    659    }
    660    // Set fragment length
    661    idx = buffer.Write(idx, 5U, 2);
    662    // Add data to record
    663    (void)buffer.Write(idx, data, 5);
    664  }
    665 
    666  void checkUndefinedContentTypeHandling(std::shared_ptr<TlsAgent> sender,
    667                                         std::shared_ptr<TlsAgent> receiver) {
    668    if (variant_ == ssl_variant_stream) {
    669      // Handle record and expect alert to be sent
    670      receiver->ExpectSendAlert(kTlsAlertUnexpectedMessage);
    671      receiver->ReadBytes();
    672      /* Digest and assert that the correct alert was received at peer
    673       *
    674       * The 1.3 server expects all messages other than the ClientHello to be
    675       * encrypted and responds with an unexpected message alert to alerts. */
    676      if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3 && sender == server_) {
    677        sender->ExpectSendAlert(kTlsAlertUnexpectedMessage);
    678      } else {
    679        sender->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
    680      }
    681      sender->ReadBytes();
    682    } else {  // DTLS drops invalid records silently
    683      size_t received = receiver->received_bytes();
    684      receiver->ReadBytes();
    685      // Ensure no bytes were received/record was dropped
    686      ASSERT_EQ(received, receiver->received_bytes());
    687    }
    688  }
    689 
    690 protected:
    691  DataBuffer buffer_;
    692 };
    693 
    694 INSTANTIATE_TEST_SUITE_P(
    695    UndefinedContentTypePreHandshakeStream, UndefinedContentTypeSetup,
    696    ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
    697                       TlsConnectTestBase::kTlsVAll));
    698 
    699 INSTANTIATE_TEST_SUITE_P(
    700    UndefinedContentTypePreHandshakeDatagram, UndefinedContentTypeSetup,
    701    ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
    702                       TlsConnectTestBase::kTlsV11Plus));
    703 
    704 TEST_P(UndefinedContentTypeSetup,
    705       ServerReceiveUndefinedContentTypePreClientHello) {
    706  createUndefinedContentTypeRecord(buffer_);
    707 
    708  // Send undefined content type record
    709  client_->SendDirect(buffer_);
    710 
    711  checkUndefinedContentTypeHandling(client_, server_);
    712 }
    713 
    714 TEST_P(UndefinedContentTypeSetup,
    715       ServerReceiveUndefinedContentTypePostClientHello) {
    716  // Set epoch to 0 (handshake), and sequence number to 1 since hello is sent
    717  createUndefinedContentTypeRecord(buffer_, 0, 1);
    718 
    719  // Send ClientHello
    720  client_->Handshake();
    721  // Send undefined content type record
    722  client_->SendDirect(buffer_);
    723 
    724  checkUndefinedContentTypeHandling(client_, server_);
    725 }
    726 
    727 TEST_P(UndefinedContentTypeSetup,
    728       ClientReceiveUndefinedContentTypePreClientHello) {
    729  createUndefinedContentTypeRecord(buffer_);
    730 
    731  // Send undefined content type record
    732  server_->SendDirect(buffer_);
    733 
    734  checkUndefinedContentTypeHandling(server_, client_);
    735 }
    736 
    737 TEST_P(UndefinedContentTypeSetup,
    738       ClientReceiveUndefinedContentTypePostClientHello) {
    739  // Set epoch to 0 (handshake), and sequence number to 1 since hello is sent
    740  createUndefinedContentTypeRecord(buffer_, 0, 1);
    741 
    742  // Send ClientHello
    743  client_->Handshake();
    744  // Send undefined content type record
    745  server_->SendDirect(buffer_);
    746 
    747  checkUndefinedContentTypeHandling(server_, client_);
    748 }
    749 
    750 class RecordOuterContentTypeSetter : public TlsRecordFilter {
    751 public:
    752  RecordOuterContentTypeSetter(const std::shared_ptr<TlsAgent>& a,
    753                               uint8_t contentType = ssl_ct_handshake)
    754      : TlsRecordFilter(a), contentType_(contentType) {}
    755 
    756 protected:
    757  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
    758                                    const DataBuffer& record, size_t* offset,
    759                                    DataBuffer* output) override {
    760    TlsRecordHeader hdr(header.variant(), header.version(), contentType_,
    761                        header.sequence_number());
    762 
    763    *offset = hdr.Write(output, *offset, record);
    764    return CHANGE;
    765  }
    766 
    767 private:
    768  uint8_t contentType_;
    769 };
    770 
    771 /* Test correct handling of invalid inner and outer record content type.
    772 * This is only possible for TLS 1.3, since only for this version decryption
    773 * and encryption of manipulated records is supported by the test suite. */
    774 TEST_P(TlsConnectTls13, UndefinedOuterContentType13) {
    775  EnsureTlsSetup();
    776  Connect();
    777 
    778  // Manipulate record: set invalid content type 0xff
    779  MakeTlsFilter<RecordOuterContentTypeSetter>(client_, 0xff);
    780  client_->SendData(50);
    781 
    782  if (variant_ == ssl_variant_stream) {
    783    // Handle invalid record
    784    server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
    785    server_->ReadBytes();
    786    // Handle alert at peer
    787    client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
    788    client_->ReadBytes();
    789  } else {
    790    // Make sure DTLS drops invalid record silently
    791    size_t received = server_->received_bytes();
    792    server_->ReadBytes();
    793    ASSERT_EQ(received, server_->received_bytes());
    794  }
    795 }
    796 
    797 TEST_P(TlsConnectTls13, UndefinedInnerContentType13) {
    798  EnsureTlsSetup();
    799 
    800  // Manipulate record: set invalid content type 0xff and length to 50.
    801  auto filter = MakeTlsFilter<Tls13RecordModifier>(client_, 0xff, 50, 0);
    802  filter->EnableDecryption();
    803  filter->Disable();
    804 
    805  Connect();
    806 
    807  filter->Enable();
    808  // Send manipulate record with invalid content type
    809  client_->SendData(50);
    810 
    811  if (variant_ == ssl_variant_stream) {
    812    // Handle invalid record
    813    server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
    814    server_->ReadBytes();
    815    // Handle alert at peer
    816    client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
    817    client_->ReadBytes();
    818  } else {
    819    // Make sure DTLS drops invalid record silently
    820    size_t received = server_->received_bytes();
    821    server_->ReadBytes();
    822    ASSERT_EQ(received, server_->received_bytes());
    823  }
    824 }
    825 
    826 }  // namespace nss_test