tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

ssl_skip_unittest.cc (9229B)


      1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim: set ts=2 et sw=2 tw=80: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
      7 #include "sslerr.h"
      8 
      9 #include "tls_connect.h"
     10 #include "tls_filter.h"
     11 #include "tls_parser.h"
     12 
     13 /*
     14 * The tests in this file test that the TLS state machine is robust against
     15 * attacks that alter the order of handshake messages.
     16 *
     17 * See <https://www.smacktls.com/smack.pdf> for a description of the problems
     18 * that this sort of attack can enable.
     19 */
     20 namespace nss_test {
     21 
     22 class TlsHandshakeSkipFilter : public TlsRecordFilter {
     23 public:
     24  // A TLS record filter that skips handshake messages of the identified type.
     25  TlsHandshakeSkipFilter(const std::shared_ptr<TlsAgent>& a,
     26                         uint8_t handshake_type)
     27      : TlsRecordFilter(a), handshake_type_(handshake_type), skipped_(false) {}
     28 
     29 protected:
     30  // Takes a record; if it is a handshake record, it removes the first handshake
     31  // message that is of handshake_type_ type.
     32  virtual PacketFilter::Action FilterRecord(
     33      const TlsRecordHeader& record_header, const DataBuffer& input,
     34      DataBuffer* output) {
     35    if (record_header.content_type() != ssl_ct_handshake) {
     36      return KEEP;
     37    }
     38 
     39    size_t output_offset = 0U;
     40    output->Allocate(input.len());
     41 
     42    TlsParser parser(input);
     43    while (parser.remaining()) {
     44      size_t start = parser.consumed();
     45      TlsHandshakeFilter::HandshakeHeader header;
     46      DataBuffer ignored;
     47      bool complete = false;
     48      if (!header.Parse(&parser, record_header, DataBuffer(), &ignored,
     49                        &complete)) {
     50        ADD_FAILURE() << "Error parsing handshake header";
     51        return KEEP;
     52      }
     53      if (!complete) {
     54        ADD_FAILURE() << "Don't want to deal with fragmented input";
     55        return KEEP;
     56      }
     57 
     58      if (skipped_ || header.handshake_type() != handshake_type_) {
     59        size_t entire_length = parser.consumed() - start;
     60        output->Write(output_offset, input.data() + start, entire_length);
     61        // DTLS sequence numbers need to be rewritten
     62        if (skipped_ && header.is_dtls()) {
     63          output->data()[start + 5] -= 1;
     64        }
     65        output_offset += entire_length;
     66      } else {
     67        std::cerr << "Dropping handshake: "
     68                  << static_cast<unsigned>(handshake_type_) << std::endl;
     69        // We only need to report that the output contains changed data if we
     70        // drop a handshake message.  But once we've skipped one message, we
     71        // have to modify all subsequent handshake messages so that they include
     72        // the correct DTLS sequence numbers.
     73        skipped_ = true;
     74      }
     75    }
     76    output->Truncate(output_offset);
     77    return skipped_ ? CHANGE : KEEP;
     78  }
     79 
     80 private:
     81  // The type of handshake message to drop.
     82  uint8_t handshake_type_;
     83  // Whether this filter has ever skipped a handshake message.  Track this so
     84  // that sequence numbers on DTLS handshake messages can be rewritten in
     85  // subsequent calls.
     86  bool skipped_;
     87 };
     88 
     89 class TlsSkipTest : public TlsConnectTestBase,
     90                    public ::testing::WithParamInterface<
     91                        std::tuple<SSLProtocolVariant, uint16_t>> {
     92 protected:
     93  TlsSkipTest()
     94      : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
     95 
     96  void SetUp() override {
     97    TlsConnectTestBase::SetUp();
     98    EnsureTlsSetup();
     99  }
    100 
    101  void ServerSkipTest(std::shared_ptr<PacketFilter> filter,
    102                      uint8_t alert = kTlsAlertUnexpectedMessage) {
    103    server_->SetFilter(filter);
    104    ConnectExpectAlert(client_, alert);
    105  }
    106 };
    107 
    108 class Tls13SkipTest : public TlsConnectTestBase,
    109                      public ::testing::WithParamInterface<SSLProtocolVariant> {
    110 protected:
    111  Tls13SkipTest()
    112      : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
    113 
    114  void SetUp() override {
    115    TlsConnectTestBase::SetUp();
    116    EnsureTlsSetup();
    117    // until we can fix filters to work with MLKEM
    118    client_->ConfigNamedGroups(kNonPQDHEGroups);
    119    server_->ConfigNamedGroups(kNonPQDHEGroups);
    120  }
    121 
    122  void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
    123    filter->EnableDecryption();
    124    server_->SetFilter(filter);
    125    ExpectAlert(client_, kTlsAlertUnexpectedMessage);
    126    ConnectExpectFail();
    127    client_->CheckErrorCode(error);
    128    server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
    129  }
    130 
    131  void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
    132    filter->EnableDecryption();
    133    client_->SetFilter(filter);
    134    server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
    135    ConnectExpectFailOneSide(TlsAgent::SERVER);
    136 
    137    server_->CheckErrorCode(error);
    138    ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
    139 
    140    client_->Handshake();  // Make sure to consume the alert the server sends.
    141  }
    142 };
    143 
    144 TEST_P(TlsSkipTest, SkipCertificateRsa) {
    145  EnableOnlyStaticRsaCiphers();
    146  ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
    147      server_, kTlsHandshakeCertificate));
    148  client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
    149 }
    150 
    151 TEST_P(TlsSkipTest, SkipCertificateDhe) {
    152  ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
    153      server_, kTlsHandshakeCertificate));
    154  client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
    155 }
    156 
    157 TEST_P(TlsSkipTest, SkipCertificateEcdhe) {
    158  ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
    159      server_, kTlsHandshakeCertificate));
    160  client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
    161 }
    162 
    163 TEST_P(TlsSkipTest, SkipCertificateEcdsa) {
    164  Reset(TlsAgent::kServerEcdsa256);
    165  ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
    166      server_, kTlsHandshakeCertificate));
    167  client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
    168 }
    169 
    170 TEST_P(TlsSkipTest, SkipServerKeyExchange) {
    171  ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
    172      server_, kTlsHandshakeServerKeyExchange));
    173  client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
    174 }
    175 
    176 TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) {
    177  Reset(TlsAgent::kServerEcdsa256);
    178  ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
    179      server_, kTlsHandshakeServerKeyExchange));
    180  client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
    181 }
    182 
    183 TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
    184  auto chain = std::make_shared<ChainedPacketFilter>(
    185      ChainedPacketFilterInit{std::make_shared<TlsHandshakeSkipFilter>(
    186                                  server_, kTlsHandshakeCertificate),
    187                              std::make_shared<TlsHandshakeSkipFilter>(
    188                                  server_, kTlsHandshakeServerKeyExchange)});
    189  ServerSkipTest(chain);
    190  client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
    191 }
    192 
    193 TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) {
    194  Reset(TlsAgent::kServerEcdsa256);
    195  auto chain = std::make_shared<ChainedPacketFilter>();
    196  chain->Add(std::make_shared<TlsHandshakeSkipFilter>(
    197      server_, kTlsHandshakeCertificate));
    198  chain->Add(std::make_shared<TlsHandshakeSkipFilter>(
    199      server_, kTlsHandshakeServerKeyExchange));
    200  ServerSkipTest(chain);
    201  client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
    202 }
    203 
    204 TEST_P(Tls13SkipTest, SkipEncryptedExtensions) {
    205  ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
    206                     server_, kTlsHandshakeEncryptedExtensions),
    207                 SSL_ERROR_RX_UNEXPECTED_CERTIFICATE);
    208 }
    209 
    210 TEST_P(Tls13SkipTest, SkipServerCertificate) {
    211  ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
    212                     server_, kTlsHandshakeCertificate),
    213                 SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
    214 }
    215 
    216 TEST_P(Tls13SkipTest, SkipServerCertificateVerify) {
    217  ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
    218                     server_, kTlsHandshakeCertificateVerify),
    219                 SSL_ERROR_RX_UNEXPECTED_FINISHED);
    220 }
    221 
    222 TEST_P(Tls13SkipTest, SkipClientCertificate) {
    223  client_->SetupClientAuth();
    224  server_->RequestClientAuth(true);
    225  client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
    226  ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
    227                     client_, kTlsHandshakeCertificate),
    228                 SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
    229 }
    230 
    231 TEST_P(Tls13SkipTest, SkipClientCertificateVerify) {
    232  client_->SetupClientAuth();
    233  server_->RequestClientAuth(true);
    234  client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
    235  ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
    236                     client_, kTlsHandshakeCertificateVerify),
    237                 SSL_ERROR_RX_UNEXPECTED_FINISHED);
    238 }
    239 
    240 INSTANTIATE_TEST_SUITE_P(
    241    SkipTls10, TlsSkipTest,
    242    ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
    243                       TlsConnectTestBase::kTlsV10));
    244 INSTANTIATE_TEST_SUITE_P(SkipVariants, TlsSkipTest,
    245                         ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
    246                                            TlsConnectTestBase::kTlsV11V12));
    247 INSTANTIATE_TEST_SUITE_P(Skip13Variants, Tls13SkipTest,
    248                         TlsConnectTestBase::kTlsVariantsAll);
    249 }  // namespace nss_test