tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

ssl_fragment_unittest.cc (5289B)


      1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim: set ts=2 et sw=2 tw=80: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
      7 #include "secerr.h"
      8 #include "ssl.h"
      9 #include "sslerr.h"
     10 #include "sslproto.h"
     11 
     12 #include "gtest_utils.h"
     13 #include "nss_scoped_ptrs.h"
     14 #include "tls_connect.h"
     15 #include "tls_filter.h"
     16 #include "tls_parser.h"
     17 
     18 namespace nss_test {
     19 
     20 // This class cuts every unencrypted handshake record into two parts.
     21 class RecordFragmenter : public PacketFilter {
     22 public:
     23  RecordFragmenter(bool is_dtls13)
     24      : is_dtls13_(is_dtls13), sequence_number_(0), splitting_(true) {}
     25 
     26 private:
     27  class HandshakeSplitter {
     28   public:
     29    HandshakeSplitter(bool is_dtls13, const DataBuffer& input,
     30                      DataBuffer* output, uint64_t* sequence_number)
     31        : is_dtls13_(is_dtls13),
     32          input_(input),
     33          output_(output),
     34          cursor_(0),
     35          sequence_number_(sequence_number) {}
     36 
     37   private:
     38    void WriteRecord(TlsRecordHeader& record_header,
     39                     DataBuffer& record_fragment) {
     40      TlsRecordHeader fragment_header(
     41          record_header.variant(), record_header.version(),
     42          record_header.content_type(), *sequence_number_);
     43      ++*sequence_number_;
     44      if (::g_ssl_gtest_verbose) {
     45        std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment
     46                  << std::endl;
     47      }
     48      cursor_ = fragment_header.Write(output_, cursor_, record_fragment);
     49    }
     50 
     51    bool SplitRecord(TlsRecordHeader& record_header, DataBuffer& record) {
     52      TlsParser parser(record);
     53      while (parser.remaining()) {
     54        TlsHandshakeFilter::HandshakeHeader handshake_header;
     55        DataBuffer handshake_body;
     56        bool complete = false;
     57        if (!handshake_header.Parse(&parser, record_header, DataBuffer(),
     58                                    &handshake_body, &complete)) {
     59          ADD_FAILURE() << "couldn't parse handshake header";
     60          return false;
     61        }
     62        if (!complete) {
     63          ADD_FAILURE() << "don't want to deal with fragmented messages";
     64          return false;
     65        }
     66 
     67        DataBuffer record_fragment;
     68        // We can't fragment handshake records that are too small.
     69        if (handshake_body.len() < 2) {
     70          handshake_header.Write(&record_fragment, 0U, handshake_body);
     71          WriteRecord(record_header, record_fragment);
     72          continue;
     73        }
     74 
     75        size_t cut = handshake_body.len() / 2;
     76        handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, 0U,
     77                                       cut);
     78        WriteRecord(record_header, record_fragment);
     79 
     80        handshake_header.WriteFragment(&record_fragment, 0U, handshake_body,
     81                                       cut, handshake_body.len() - cut);
     82        WriteRecord(record_header, record_fragment);
     83      }
     84      return true;
     85    }
     86 
     87   public:
     88    bool Split() {
     89      TlsParser parser(input_);
     90      while (parser.remaining()) {
     91        TlsRecordHeader header;
     92        DataBuffer record;
     93        if (!header.Parse(is_dtls13_, 0, &parser, &record)) {
     94          ADD_FAILURE() << "bad record header";
     95          return false;
     96        }
     97 
     98        if (::g_ssl_gtest_verbose) {
     99          std::cerr << "Record: " << header << ' ' << record << std::endl;
    100        }
    101 
    102        // Don't touch packets from a non-zero epoch.  Leave these unmodified.
    103        if ((header.sequence_number() >> 48) != 0ULL) {
    104          cursor_ = header.Write(output_, cursor_, record);
    105          continue;
    106        }
    107 
    108        // Just rewrite the sequence number (CCS only).
    109        if (header.content_type() != ssl_ct_handshake) {
    110          EXPECT_EQ(ssl_ct_change_cipher_spec, header.content_type());
    111          WriteRecord(header, record);
    112          continue;
    113        }
    114 
    115        if (!SplitRecord(header, record)) {
    116          return false;
    117        }
    118      }
    119      return true;
    120    }
    121 
    122   private:
    123    bool is_dtls13_;
    124    const DataBuffer& input_;
    125    DataBuffer* output_;
    126    size_t cursor_;
    127    uint64_t* sequence_number_;
    128  };
    129 
    130 protected:
    131  virtual PacketFilter::Action Filter(const DataBuffer& input,
    132                                      DataBuffer* output) override {
    133    if (!splitting_) {
    134      return KEEP;
    135    }
    136 
    137    output->Allocate(input.len());
    138    HandshakeSplitter splitter(is_dtls13_, input, output, &sequence_number_);
    139    if (!splitter.Split()) {
    140      // If splitting fails, we obviously reached encrypted packets.
    141      // Stop splitting from that point onward.
    142      splitting_ = false;
    143      return KEEP;
    144    }
    145 
    146    return CHANGE;
    147  }
    148 
    149 private:
    150  bool is_dtls13_;
    151  uint64_t sequence_number_;
    152  bool splitting_;
    153 };
    154 
    155 TEST_P(TlsConnectDatagram, FragmentClientPackets) {
    156  bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3;
    157  client_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13));
    158  Connect();
    159  SendReceive();
    160 }
    161 
    162 TEST_P(TlsConnectDatagram, FragmentServerPackets) {
    163  bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3;
    164  server_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13));
    165  Connect();
    166  SendReceive();
    167 }
    168 
    169 }  // namespace nss_test