tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

ssl_recordsep_unittest.cc (23319B)


      1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim: set ts=2 et sw=2 tw=80: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
      7 #include "secerr.h"
      8 #include "ssl.h"
      9 #include "sslerr.h"
     10 #include "sslproto.h"
     11 
     12 extern "C" {
     13 // This is not something that should make you happy.
     14 #include "libssl_internals.h"
     15 }
     16 
     17 #include <queue>
     18 #include "gtest_utils.h"
     19 #include "nss_scoped_ptrs.h"
     20 #include "tls_connect.h"
     21 #include "tls_filter.h"
     22 #include "tls_parser.h"
     23 
     24 namespace nss_test {
     25 
     26 class HandshakeSecretTracker {
     27 public:
     28  HandshakeSecretTracker(const std::shared_ptr<TlsAgent>& agent,
     29                         uint16_t first_read_epoch, uint16_t first_write_epoch)
     30      : agent_(agent),
     31        next_read_epoch_(first_read_epoch),
     32        next_write_epoch_(first_write_epoch) {
     33    EXPECT_EQ(SECSuccess,
     34              SSL_SecretCallback(agent_->ssl_fd(),
     35                                 HandshakeSecretTracker::SecretCb, this));
     36  }
     37 
     38  void CheckComplete() const {
     39    EXPECT_EQ(0, next_read_epoch_);
     40    EXPECT_EQ(0, next_write_epoch_);
     41  }
     42 
     43 private:
     44  static void SecretCb(PRFileDesc* fd, PRUint16 epoch, SSLSecretDirection dir,
     45                       PK11SymKey* secret, void* arg) {
     46    HandshakeSecretTracker* t = reinterpret_cast<HandshakeSecretTracker*>(arg);
     47    t->SecretUpdated(epoch, dir, secret);
     48  }
     49 
     50  void SecretUpdated(PRUint16 epoch, SSLSecretDirection dir,
     51                     PK11SymKey* secret) {
     52    if (g_ssl_gtest_verbose) {
     53      std::cerr << agent_->role_str() << ": secret callback for " << dir
     54                << " epoch " << epoch << std::endl;
     55    }
     56 
     57    EXPECT_TRUE(secret);
     58    uint16_t* p;
     59    if (dir == ssl_secret_read) {
     60      p = &next_read_epoch_;
     61    } else {
     62      ASSERT_EQ(ssl_secret_write, dir);
     63      p = &next_write_epoch_;
     64    }
     65    EXPECT_EQ(*p, epoch);
     66    switch (*p) {
     67      case 1:  // 1 == 0-RTT, next should be handshake.
     68      case 2:  // 2 == handshake, next should be application data.
     69        (*p)++;
     70        break;
     71 
     72      case 3:  // 3 == application data, there should be no more.
     73        // Use 0 as a sentinel value.
     74        *p = 0;
     75        break;
     76 
     77      default:
     78        ADD_FAILURE() << "Unexpected next epoch: " << *p;
     79    }
     80  }
     81 
     82  std::shared_ptr<TlsAgent> agent_;
     83  uint16_t next_read_epoch_;
     84  uint16_t next_write_epoch_;
     85 };
     86 
// A full TLS 1.3 handshake without 0-RTT should report the handshake (2) and
// then application data (3) secrets for both directions on both peers.
TEST_F(TlsConnectTest, HandshakeSecrets) {
  ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
  // The trackers need the agents' sockets to exist before connecting.
  EnsureTlsSetup();

  // Arguments are (first read epoch, first write epoch); without 0-RTT both
  // directions start at the handshake epoch.
  HandshakeSecretTracker c(client_, 2, 2);
  HandshakeSecretTracker s(server_, 2, 2);

  Connect();
  SendReceive();

  c.CheckComplete();
  s.CheckComplete();
}
    100 
// With accepted 0-RTT, the client's write direction and the server's read
// direction also pass through the 0-RTT epoch (1) before the handshake epoch.
TEST_F(TlsConnectTest, ZeroRttSecrets) {
  SetupForZeroRtt();

  // Arguments are (first read epoch, first write epoch).
  HandshakeSecretTracker c(client_, 2, 1);
  HandshakeSecretTracker s(server_, 1, 2);

  client_->Set0RttEnabled(true);
  server_->Set0RttEnabled(true);
  ExpectResumption(RESUME_TICKET);
  ZeroRttSendReceive(true, true);
  Handshake();
  ExpectEarlyDataAccepted(true);
  CheckConnected();
  SendReceive();

  c.CheckComplete();
  s.CheckComplete();
}
    119 
    120 class KeyUpdateTracker {
    121 public:
    122  KeyUpdateTracker(const std::shared_ptr<TlsAgent>& agent,
    123                   bool expect_read_secret)
    124      : agent_(agent), expect_read_secret_(expect_read_secret), called_(false) {
    125    EXPECT_EQ(SECSuccess, SSL_SecretCallback(agent_->ssl_fd(),
    126                                             KeyUpdateTracker::SecretCb, this));
    127  }
    128 
    129  void CheckCalled() const { EXPECT_TRUE(called_); }
    130 
    131 private:
    132  static void SecretCb(PRFileDesc* fd, PRUint16 epoch, SSLSecretDirection dir,
    133                       PK11SymKey* secret, void* arg) {
    134    KeyUpdateTracker* t = reinterpret_cast<KeyUpdateTracker*>(arg);
    135    t->SecretUpdated(epoch, dir, secret);
    136  }
    137 
    138  void SecretUpdated(PRUint16 epoch, SSLSecretDirection dir,
    139                     PK11SymKey* secret) {
    140    EXPECT_EQ(4U, epoch);
    141    EXPECT_EQ(expect_read_secret_, dir == ssl_secret_read);
    142    EXPECT_TRUE(secret);
    143    called_ = true;
    144  }
    145 
    146  std::shared_ptr<TlsAgent> agent_;
    147  bool expect_read_secret_;
    148  bool called_;
    149 };
    150 
// A client-initiated SSL_KeyUpdate() should produce an epoch-4 write secret on
// the client and an epoch-4 read secret on the server.
TEST_F(TlsConnectTest, KeyUpdateSecrets) {
  ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
  Connect();
  // The update is to the client write secret; the server read secret.
  KeyUpdateTracker c(client_, false);
  KeyUpdateTracker s(server_, true);
  EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE));
  // NOTE(review): two exchanges presumably give both peers a chance to process
  // the KeyUpdate before the epochs are checked; confirm against SendReceive.
  SendReceive(50);
  SendReceive(60);
  CheckEpochs(4, 3);
  c.CheckCalled();
  s.CheckCalled();
}
    164 
    165 // BadPrSocket is an instance of a PR IO layer that crashes the test if it is
    166 // ever used for reading or writing.  It does that by failing to overwrite any
    167 // of the DummyIOLayerMethods, which all crash when invoked.
    168 class BadPrSocket : public DummyIOLayerMethods {
    169 public:
    170  BadPrSocket(std::shared_ptr<TlsAgent>& agent) : DummyIOLayerMethods() {
    171    static PRDescIdentity bad_identity = PR_GetUniqueIdentity("bad NSPR id");
    172    fd_ = DummyIOLayerMethods::CreateFD(bad_identity, this);
    173 
    174    // This is terrible, but NSPR doesn't provide an easy way to replace the
    175    // bottom layer of an IO stack.  Take the DummyPrSocket and replace its
    176    // NSPR method vtable with the ones from this object.
    177    dummy_layer_ =
    178        PR_GetIdentitiesLayer(agent->ssl_fd(), DummyPrSocket::LayerId());
    179    EXPECT_TRUE(dummy_layer_);
    180    original_methods_ = dummy_layer_->methods;
    181    original_secret_ = dummy_layer_->secret;
    182    dummy_layer_->methods = fd_->methods;
    183    dummy_layer_->secret = reinterpret_cast<PRFilePrivate*>(this);
    184  }
    185 
    186  // This will be destroyed before the agent, so we need to restore the state
    187  // before we tampered with it.
    188  virtual ~BadPrSocket() {
    189    dummy_layer_->methods = original_methods_;
    190    dummy_layer_->secret = original_secret_;
    191  }
    192 
    193 private:
    194  ScopedPRFileDesc fd_;
    195  PRFileDesc* dummy_layer_;
    196  const PRIOMethods* original_methods_;
    197  PRFilePrivate* original_secret_;
    198 };
    199 
    200 class StagedRecords {
    201 public:
    202  StagedRecords(std::shared_ptr<TlsAgent>& agent) : agent_(agent), records_() {
    203    EXPECT_EQ(SECSuccess,
    204              SSL_RecordLayerWriteCallback(
    205                  agent_->ssl_fd(), StagedRecords::StageRecordData, this));
    206  }
    207 
    208  virtual ~StagedRecords() {
    209    // Uninstall so that the callback doesn't fire during cleanup.
    210    EXPECT_EQ(SECSuccess,
    211              SSL_RecordLayerWriteCallback(agent_->ssl_fd(), nullptr, nullptr));
    212  }
    213 
    214  bool empty() const { return records_.empty(); }
    215 
    216  void ForwardAll(std::shared_ptr<TlsAgent>& peer) {
    217    EXPECT_NE(agent_, peer) << "can't forward to self";
    218    for (auto r : records_) {
    219      r.Forward(peer);
    220    }
    221    records_.clear();
    222  }
    223 
    224  // This forwards all saved data and checks the resulting state.
    225  void ForwardAll(std::shared_ptr<TlsAgent>& peer,
    226                  TlsAgent::State expected_state) {
    227    ForwardAll(peer);
    228    switch (expected_state) {
    229      case TlsAgent::STATE_CONNECTED:
    230        // The handshake callback should have been called, so check that before
    231        // checking that SSL_ForceHandshake succeeds.
    232        EXPECT_EQ(expected_state, peer->state());
    233        EXPECT_EQ(SECSuccess, SSL_ForceHandshake(peer->ssl_fd()));
    234        break;
    235 
    236      case TlsAgent::STATE_CONNECTING:
    237        // Check that SSL_ForceHandshake() blocks.
    238        EXPECT_EQ(SECFailure, SSL_ForceHandshake(peer->ssl_fd()));
    239        EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
    240        // Update and check the state.
    241        peer->Handshake();
    242        EXPECT_EQ(TlsAgent::STATE_CONNECTING, peer->state());
    243        break;
    244 
    245      default:
    246        ADD_FAILURE() << "No idea how to handle this state";
    247    }
    248  }
    249 
    250  void ForwardPartial(std::shared_ptr<TlsAgent>& peer) {
    251    if (records_.empty()) {
    252      ADD_FAILURE() << "No records to slice";
    253      return;
    254    }
    255    auto& last = records_.back();
    256    auto tail = last.SliceTail();
    257    ForwardAll(peer, TlsAgent::STATE_CONNECTING);
    258    records_.push_back(tail);
    259    EXPECT_EQ(TlsAgent::STATE_CONNECTING, peer->state());
    260  }
    261 
    262 private:
    263  // A single record.
    264  class StagedRecord {
    265   public:
    266    StagedRecord(const std::string role, uint16_t epoch, SSLContentType ct,
    267                 const uint8_t* data, size_t len)
    268        : role_(role), epoch_(epoch), content_type_(ct), data_(data, len) {
    269      if (g_ssl_gtest_verbose) {
    270        std::cerr << role_ << ": staged epoch " << epoch_ << " "
    271                  << content_type_ << ": " << data_ << std::endl;
    272      }
    273    }
    274 
    275    // This forwards staged data to the identified agent.
    276    void Forward(std::shared_ptr<TlsAgent>& peer) {
    277      // Now there should be staged data.
    278      EXPECT_FALSE(data_.empty());
    279      if (g_ssl_gtest_verbose) {
    280        std::cerr << role_ << ": forward epoch " << epoch_ << " " << data_
    281                  << std::endl;
    282      }
    283      EXPECT_EQ(SECSuccess,
    284                SSL_RecordLayerData(peer->ssl_fd(), epoch_, content_type_,
    285                                    data_.data(),
    286                                    static_cast<unsigned int>(data_.len())));
    287    }
    288 
    289    // Slices the tail off this record and returns it.
    290    StagedRecord SliceTail() {
    291      size_t slice = 1;
    292      if (data_.len() <= slice) {
    293        ADD_FAILURE() << "record too small to slice in two";
    294        slice = 0;
    295      }
    296      size_t keep = data_.len() - slice;
    297      StagedRecord tail(role_, epoch_, content_type_, data_.data() + keep,
    298                        slice);
    299      data_.Truncate(keep);
    300      return tail;
    301    }
    302 
    303   private:
    304    std::string role_;
    305    uint16_t epoch_;
    306    SSLContentType content_type_;
    307    DataBuffer data_;
    308  };
    309 
    310  // This is an SSLRecordWriteCallback that stages data.
    311  static SECStatus StageRecordData(PRFileDesc* fd, PRUint16 epoch,
    312                                   SSLContentType content_type,
    313                                   const PRUint8* data, unsigned int len,
    314                                   void* arg) {
    315    auto stage = reinterpret_cast<StagedRecords*>(arg);
    316    stage->records_.push_back(StagedRecord(stage->agent_->role_str(), epoch,
    317                                           content_type, data,
    318                                           static_cast<size_t>(len)));
    319    return SECSuccess;
    320  }
    321 
    322  std::shared_ptr<TlsAgent>& agent_;
    323  std::deque<StagedRecord> records_;
    324 };
    325 
    326 // Attempting to feed application data in before the handshake is complete
    327 // should be caught.
    328 static void RefuseApplicationData(std::shared_ptr<TlsAgent>& peer,
    329                                  uint16_t epoch) {
    330  static const uint8_t d[] = {1, 2, 3};
    331  EXPECT_EQ(SECFailure,
    332            SSL_RecordLayerData(peer->ssl_fd(), epoch, ssl_ct_application_data,
    333                                d, static_cast<unsigned int>(sizeof(d))));
    334  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    335 }
    336 
    337 static void SendForwardReceive(std::shared_ptr<TlsAgent>& sender,
    338                               StagedRecords& sender_stage,
    339                               std::shared_ptr<TlsAgent>& receiver) {
    340  const size_t count = 10;
    341  sender->SendData(count, count);
    342  sender_stage.ForwardAll(receiver);
    343  receiver->ReadBytes(count);
    344 }
    345 
// Completes a handshake with both peers' socket IO replaced.  BadPrSocket makes
// any real read or write crash the test, so every record is captured by
// StagedRecords and delivered to the other peer by hand.
TEST_P(TlsConnectStream, ReplaceRecordLayer) {
  StartConnect();
  client_->SetServerKeyBits(server_->server_key_bits());

  // BadPrSocket installs an IO layer that crashes when the SSL layer attempts
  // to read or write.
  BadPrSocket bad_layer_client(client_);
  BadPrSocket bad_layer_server(server_);

  // StagedRecords installs a handler for unprotected data from the socket, and
  // captures that data.
  StagedRecords client_stage(client_);
  StagedRecords server_stage(server_);

  // Both peers should refuse application data from epoch 0.
  RefuseApplicationData(client_, 0);
  RefuseApplicationData(server_, 0);

  // This first call forwards nothing, but it causes the client to handshake,
  // which starts things off.  This stages the ClientHello as a result.
  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
  // This processes the ClientHello and stages the first server flight.
  client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);

  // In TLS 1.3, this is 0-RTT; in <TLS 1.3, this is application data.
  // Neither is acceptable.
  RefuseApplicationData(client_, 1);
  RefuseApplicationData(server_, 1);

  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    // Application data in handshake is never acceptable.
    RefuseApplicationData(client_, 2);
    RefuseApplicationData(server_, 2);
    // Don't accept real data until the handshake is done.
    RefuseApplicationData(client_, 3);
    RefuseApplicationData(server_, 3);
    // Process the server flight and the client is done.
    server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
    client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
  } else {
    // Before TLS 1.3 the client is still connecting after the server's first
    // flight; its reply completes the server, and the server's final flight
    // then completes the client.
    server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
    client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
    server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
  }
  CheckKeys();

  // Reading and writing application data should work.
  SendForwardReceive(client_, client_stage, server_);
  SendForwardReceive(server_, server_stage, client_);
}
    396 
    397 TEST_F(TlsConnectStreamTls13, ReplaceRecordLayerZeroRtt) {
    398  SetupForZeroRtt();
    399 
    400  client_->Set0RttEnabled(true);
    401  server_->Set0RttEnabled(true);
    402  StartConnect();
    403  client_->SetServerKeyBits(server_->server_key_bits());
    404 
    405  BadPrSocket bad_layer_client(client_);
    406  BadPrSocket bad_layer_server(server_);
    407 
    408  StagedRecords client_stage(client_);
    409  StagedRecords server_stage(server_);
    410 
    411  ExpectResumption(RESUME_TICKET);
    412 
    413  // Send ClientHello
    414  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
    415 
    416  // The client can never accept 0-RTT.
    417  RefuseApplicationData(client_, 1);
    418 
    419  // Send some 0-RTT data, which get staged in `client_stage`.
    420  const char* kMsg = "EarlyData";
    421  const PRInt32 kMsgLen = static_cast<PRInt32>(strlen(kMsg));
    422  PRInt32 rv = PR_Write(client_->ssl_fd(), kMsg, kMsgLen);
    423  EXPECT_EQ(kMsgLen, rv);
    424 
    425  client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);
    426 
    427  // The server should now have 0-RTT to read.
    428  std::vector<uint8_t> buf(kMsgLen);
    429  rv = PR_Read(server_->ssl_fd(), buf.data(), kMsgLen);
    430  EXPECT_EQ(kMsgLen, rv);
    431 
    432  // The handshake should happily finish.
    433  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
    434  client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
    435  ExpectEarlyDataAccepted(true);
    436  CheckConnected();
    437 
    438  // Reading and writing application data should work.
    439  SendForwardReceive(client_, client_stage, server_);
    440  SendForwardReceive(server_, server_stage, client_);
    441 }
    442 
    443 static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) {
    444  return SECWouldBlock;
    445 }
    446 
// Blocks the client's certificate check after the server's full first flight
// has been delivered, then completes it and checks that the handshake finishes
// over the replaced record layer.
TEST_P(TlsConnectStream, ReplaceRecordLayerAsyncLateAuth) {
  StartConnect();
  client_->SetServerKeyBits(server_->server_key_bits());

  BadPrSocket bad_layer_client(client_);
  BadPrSocket bad_layer_server(server_);
  StagedRecords client_stage(client_);
  StagedRecords server_stage(server_);

  // Certificate authentication on the client now returns SECWouldBlock.
  client_->SetAuthCertificateCallback(AuthCompleteBlock);

  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
  client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);
  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);

  // Prior to TLS 1.3, the client sends its second flight immediately.  But in
  // TLS 1.3, a client won't send a Finished until it is happy with the server
  // certificate.  So blocking certificate validation causes the client to send
  // nothing.
  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    ASSERT_TRUE(client_stage.empty());

    // Client should have stopped reading when it saw the Certificate message,
    // so it will be reading handshake epoch, and writing cleartext.
    client_->CheckEpochs(2, 0);
    // Server should be reading handshake, and writing application data.
    server_->CheckEpochs(2, 3);

    // Handshake again and the client will read the remainder of the server's
    // flight, but it will remain blocked.
    client_->Handshake();
    ASSERT_TRUE(client_stage.empty());
    EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
  } else {
    // In prior versions, the client's second flight is always sent.
    ASSERT_FALSE(client_stage.empty());
  }

  // Now declare the certificate good.
  EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
  client_->Handshake();
  ASSERT_FALSE(client_stage.empty());

  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    // The client is already connected; its final flight completes the server.
    EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
    client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
  } else {
    // The server completes on the client's flight, and its reply completes
    // the client.
    client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
    server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
  }
  CheckKeys();

  // Reading and writing application data should work.
  SendForwardReceive(client_, client_stage, server_);
}
    502 
    503 TEST_F(TlsConnectStreamTls13, ReplaceRecordLayerAsyncPostHandshake) {
    504  StartConnect();
    505  client_->SetServerKeyBits(server_->server_key_bits());
    506 
    507  BadPrSocket bad_layer_client(client_);
    508  BadPrSocket bad_layer_server(server_);
    509  StagedRecords client_stage(client_);
    510  StagedRecords server_stage(server_);
    511 
    512  client_->SetAuthCertificateCallback(AuthCompleteBlock);
    513 
    514  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
    515  client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);
    516  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
    517 
    518  ASSERT_TRUE(client_stage.empty());
    519  client_->Handshake();
    520  ASSERT_TRUE(client_stage.empty());
    521  EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
    522 
    523  // Now declare the certificate good.
    524  EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
    525  client_->Handshake();
    526  ASSERT_FALSE(client_stage.empty());
    527 
    528  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    529    EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
    530    client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
    531  } else {
    532    client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
    533    server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
    534  }
    535  CheckKeys();
    536 
    537  // Reading and writing application data should work.
    538  SendForwardReceive(client_, client_stage, server_);
    539 
    540  // Post-handshake messages should work here.
    541  EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), nullptr, 0));
    542  SendForwardReceive(server_, server_stage, client_);
    543 }
    544 
// This test ensures that data is correctly forwarded when the handshake is
// resumed after asynchronous server certificate authentication, when
// SSL_AuthCertificateComplete() is called.  The logic for resuming the
// handshake involves a different code path than the usual one, so this test
// exercises that code fully.
TEST_F(TlsConnectStreamTls13, ReplaceRecordLayerAsyncEarlyAuth) {
  StartConnect();
  client_->SetServerKeyBits(server_->server_key_bits());

  BadPrSocket bad_layer_client(client_);
  BadPrSocket bad_layer_server(server_);
  StagedRecords client_stage(client_);
  StagedRecords server_stage(server_);

  // Certificate authentication on the client now returns SECWouldBlock.
  client_->SetAuthCertificateCallback(AuthCompleteBlock);

  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
  client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);

  // Send a partial flight on to the client.
  // This includes enough to trigger the certificate callback.
  // (ForwardPartial holds back the last byte of the final staged record.)
  server_stage.ForwardPartial(client_);
  EXPECT_TRUE(client_stage.empty());

  // Declare the certificate good.  The client still can't finish because part
  // of the server flight has not arrived yet.
  EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
  client_->Handshake();
  EXPECT_TRUE(client_stage.empty());

  // Send the remainder of the server flight.
  PRBool pending = PR_FALSE;
  EXPECT_EQ(SECSuccess,
            SSLInt_HasPendingHandshakeData(client_->ssl_fd(), &pending));
  EXPECT_EQ(PR_TRUE, pending);
  EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
  client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
  CheckKeys();

  SendForwardReceive(server_, server_stage, client_);
}
    586 
// Once connected, SSL_RecordLayerData() must reject application data from an
// epoch that is already behind (TLS 1.3 only), and must block
// (PR_WOULD_BLOCK_ERROR) on data from an epoch the client has not reached yet.
TEST_P(TlsConnectStream, ForwardDataFromWrongEpoch) {
  const uint8_t data[] = {1};
  Connect();
  uint16_t next_epoch;
  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    // Epoch 2 is the TLS 1.3 handshake epoch, which is over by now.
    EXPECT_EQ(SECFailure,
              SSL_RecordLayerData(client_->ssl_fd(), 2, ssl_ct_application_data,
                                  data, sizeof(data)));
    EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError())
        << "Passing data from an old epoch is rejected";
    next_epoch = 4;
  } else {
    // Prior to TLS 1.3, the epoch is only updated once during the handshake.
    next_epoch = 2;
  }
  EXPECT_EQ(SECFailure,
            SSL_RecordLayerData(client_->ssl_fd(), next_epoch,
                                ssl_ct_application_data, data, sizeof(data)));
  EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError())
      << "Passing data from a future epoch blocks";
}
    608 
    609 TEST_F(TlsConnectStreamTls13, ForwardInvalidData) {
    610  const uint8_t data[1] = {0};
    611 
    612  EnsureTlsSetup();
    613  // Zero-length data.
    614  EXPECT_EQ(SECFailure, SSL_RecordLayerData(client_->ssl_fd(), 0,
    615                                            ssl_ct_application_data, data, 0));
    616  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    617 
    618  // NULL data.
    619  EXPECT_EQ(SECFailure,
    620            SSL_RecordLayerData(client_->ssl_fd(), 0, ssl_ct_application_data,
    621                                nullptr, 1));
    622  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    623 }
    624 
    625 TEST_F(TlsConnectDatagram13, ForwardDataDtls) {
    626  EnsureTlsSetup();
    627  const uint8_t data[1] = {0};
    628  EXPECT_EQ(SECFailure,
    629            SSL_RecordLayerData(client_->ssl_fd(), 0, ssl_ct_application_data,
    630                                data, sizeof(data)));
    631  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    632 }
    633 
    634 TEST_F(TlsConnectStreamTls13, SuppressEndOfEarlyData) {
    635  SetupForZeroRtt();
    636 
    637  client_->Set0RttEnabled(true);
    638  server_->Set0RttEnabled(true);
    639  client_->SetOption(SSL_SUPPRESS_END_OF_EARLY_DATA, true);
    640  server_->SetOption(SSL_SUPPRESS_END_OF_EARLY_DATA, true);
    641  StartConnect();
    642  client_->SetServerKeyBits(server_->server_key_bits());
    643 
    644  BadPrSocket bad_layer_client(client_);
    645  BadPrSocket bad_layer_server(server_);
    646 
    647  StagedRecords client_stage(client_);
    648  StagedRecords server_stage(server_);
    649 
    650  ExpectResumption(RESUME_TICKET);
    651 
    652  // Send ClientHello
    653  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
    654 
    655  // Send some 0-RTT data, which get staged in `client_stage`.
    656  const char* kMsg = "ABCDEF";
    657  const PRInt32 kMsgLen = static_cast<PRInt32>(strlen(kMsg));
    658  PRInt32 rv = PR_Write(client_->ssl_fd(), kMsg, kMsgLen);
    659  EXPECT_EQ(kMsgLen, rv);
    660 
    661  client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);
    662 
    663  // The server should now have 0-RTT to read.
    664  std::vector<uint8_t> buf(kMsgLen);
    665  rv = PR_Read(server_->ssl_fd(), buf.data(), kMsgLen);
    666  EXPECT_EQ(kMsgLen, rv);
    667 
    668  // The handshake should happily finish, without the end of the early data.
    669  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
    670  client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
    671  ExpectEarlyDataAccepted(true);
    672  CheckConnected();
    673 
    674  // Reading and writing application data should work.
    675  SendForwardReceive(client_, client_stage, server_);
    676  SendForwardReceive(server_, server_stage, client_);
    677 }
    678 
    679 }  // namespace nss_test