ssl_skip_unittest.cc (9229B)
1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 2 /* vim: set ts=2 et sw=2 tw=80: */ 3 /* This Source Code Form is subject to the terms of the Mozilla Public 4 * License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 * You can obtain one at http://mozilla.org/MPL/2.0/. */ 6 7 #include "sslerr.h" 8 9 #include "tls_connect.h" 10 #include "tls_filter.h" 11 #include "tls_parser.h" 12 13 /* 14 * The tests in this file test that the TLS state machine is robust against 15 * attacks that alter the order of handshake messages. 16 * 17 * See <https://www.smacktls.com/smack.pdf> for a description of the problems 18 * that this sort of attack can enable. 19 */ 20 namespace nss_test { 21 22 class TlsHandshakeSkipFilter : public TlsRecordFilter { 23 public: 24 // A TLS record filter that skips handshake messages of the identified type. 25 TlsHandshakeSkipFilter(const std::shared_ptr<TlsAgent>& a, 26 uint8_t handshake_type) 27 : TlsRecordFilter(a), handshake_type_(handshake_type), skipped_(false) {} 28 29 protected: 30 // Takes a record; if it is a handshake record, it removes the first handshake 31 // message that is of handshake_type_ type. 32 virtual PacketFilter::Action FilterRecord( 33 const TlsRecordHeader& record_header, const DataBuffer& input, 34 DataBuffer* output) { 35 if (record_header.content_type() != ssl_ct_handshake) { 36 return KEEP; 37 } 38 39 size_t output_offset = 0U; 40 output->Allocate(input.len()); 41 42 TlsParser parser(input); 43 while (parser.remaining()) { 44 size_t start = parser.consumed(); 45 TlsHandshakeFilter::HandshakeHeader header; 46 DataBuffer ignored; 47 bool complete = false; 48 if (!header.Parse(&parser, record_header, DataBuffer(), &ignored, 49 &complete)) { 50 ADD_FAILURE() << "Error parsing handshake header"; 51 return KEEP; 52 } 53 if (!complete) { 54 ADD_FAILURE() << "Don't want to deal with fragmented input"; 55 return KEEP; 56 } 57 58 if (skipped_ || header.handshake_type() != handshake_type_) { 59 size_t entire_length = parser.consumed() - start; 60 output->Write(output_offset, input.data() + start, entire_length); 61 // DTLS sequence numbers need to be rewritten 62 if (skipped_ && header.is_dtls()) { 63 output->data()[start + 5] -= 1; 64 } 65 output_offset += entire_length; 66 } else { 67 std::cerr << "Dropping handshake: " 68 << static_cast<unsigned>(handshake_type_) << std::endl; 69 // We only need to report that the output contains changed data if we 70 // drop a handshake message. But once we've skipped one message, we 71 // have to modify all subsequent handshake messages so that they include 72 // the correct DTLS sequence numbers. 73 skipped_ = true; 74 } 75 } 76 output->Truncate(output_offset); 77 return skipped_ ? CHANGE : KEEP; 78 } 79 80 private: 81 // The type of handshake message to drop. 82 uint8_t handshake_type_; 83 // Whether this filter has ever skipped a handshake message. Track this so 84 // that sequence numbers on DTLS handshake messages can be rewritten in 85 // subsequent calls. 86 bool skipped_; 87 }; 88 89 class TlsSkipTest : public TlsConnectTestBase, 90 public ::testing::WithParamInterface< 91 std::tuple<SSLProtocolVariant, uint16_t>> { 92 protected: 93 TlsSkipTest() 94 : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} 95 96 void SetUp() override { 97 TlsConnectTestBase::SetUp(); 98 EnsureTlsSetup(); 99 } 100 101 void ServerSkipTest(std::shared_ptr<PacketFilter> filter, 102 uint8_t alert = kTlsAlertUnexpectedMessage) { 103 server_->SetFilter(filter); 104 ConnectExpectAlert(client_, alert); 105 } 106 }; 107 108 class Tls13SkipTest : public TlsConnectTestBase, 109 public ::testing::WithParamInterface<SSLProtocolVariant> { 110 protected: 111 Tls13SkipTest() 112 : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} 113 114 void SetUp() override { 115 TlsConnectTestBase::SetUp(); 116 EnsureTlsSetup(); 117 // until we can fix filters to work with MLKEM 118 client_->ConfigNamedGroups(kNonPQDHEGroups); 119 server_->ConfigNamedGroups(kNonPQDHEGroups); 120 } 121 122 void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { 123 filter->EnableDecryption(); 124 server_->SetFilter(filter); 125 ExpectAlert(client_, kTlsAlertUnexpectedMessage); 126 ConnectExpectFail(); 127 client_->CheckErrorCode(error); 128 server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); 129 } 130 131 void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { 132 filter->EnableDecryption(); 133 client_->SetFilter(filter); 134 server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); 135 ConnectExpectFailOneSide(TlsAgent::SERVER); 136 137 server_->CheckErrorCode(error); 138 ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); 139 140 client_->Handshake(); // Make sure to consume the alert the server sends. 141 } 142 }; 143 144 TEST_P(TlsSkipTest, SkipCertificateRsa) { 145 EnableOnlyStaticRsaCiphers(); 146 ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( 147 server_, kTlsHandshakeCertificate)); 148 client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); 149 } 150 151 TEST_P(TlsSkipTest, SkipCertificateDhe) { 152 ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( 153 server_, kTlsHandshakeCertificate)); 154 client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); 155 } 156 157 TEST_P(TlsSkipTest, SkipCertificateEcdhe) { 158 ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( 159 server_, kTlsHandshakeCertificate)); 160 client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); 161 } 162 163 TEST_P(TlsSkipTest, SkipCertificateEcdsa) { 164 Reset(TlsAgent::kServerEcdsa256); 165 ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( 166 server_, kTlsHandshakeCertificate)); 167 client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); 168 } 169 170 TEST_P(TlsSkipTest, SkipServerKeyExchange) { 171 ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( 172 server_, kTlsHandshakeServerKeyExchange)); 173 client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); 174 } 175 176 TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) { 177 Reset(TlsAgent::kServerEcdsa256); 178 ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( 179 server_, kTlsHandshakeServerKeyExchange)); 180 client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); 181 } 182 183 TEST_P(TlsSkipTest, SkipCertAndKeyExch) { 184 auto chain = std::make_shared<ChainedPacketFilter>( 185 ChainedPacketFilterInit{std::make_shared<TlsHandshakeSkipFilter>( 186 server_, kTlsHandshakeCertificate), 187 std::make_shared<TlsHandshakeSkipFilter>( 188 server_, kTlsHandshakeServerKeyExchange)}); 189 ServerSkipTest(chain); 190 client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); 191 } 192 193 TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) { 194 Reset(TlsAgent::kServerEcdsa256); 195 auto chain = std::make_shared<ChainedPacketFilter>(); 196 chain->Add(std::make_shared<TlsHandshakeSkipFilter>( 197 server_, kTlsHandshakeCertificate)); 198 chain->Add(std::make_shared<TlsHandshakeSkipFilter>( 199 server_, kTlsHandshakeServerKeyExchange)); 200 ServerSkipTest(chain); 201 client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); 202 } 203 204 TEST_P(Tls13SkipTest, SkipEncryptedExtensions) { 205 ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( 206 server_, kTlsHandshakeEncryptedExtensions), 207 SSL_ERROR_RX_UNEXPECTED_CERTIFICATE); 208 } 209 210 TEST_P(Tls13SkipTest, SkipServerCertificate) { 211 ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( 212 server_, kTlsHandshakeCertificate), 213 SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); 214 } 215 216 TEST_P(Tls13SkipTest, SkipServerCertificateVerify) { 217 ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( 218 server_, kTlsHandshakeCertificateVerify), 219 SSL_ERROR_RX_UNEXPECTED_FINISHED); 220 } 221 222 TEST_P(Tls13SkipTest, SkipClientCertificate) { 223 client_->SetupClientAuth(); 224 server_->RequestClientAuth(true); 225 client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); 226 ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>( 227 client_, kTlsHandshakeCertificate), 228 SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); 229 } 230 231 TEST_P(Tls13SkipTest, SkipClientCertificateVerify) { 232 client_->SetupClientAuth(); 233 server_->RequestClientAuth(true); 234 client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); 235 ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>( 236 client_, kTlsHandshakeCertificateVerify), 237 SSL_ERROR_RX_UNEXPECTED_FINISHED); 238 } 239 240 INSTANTIATE_TEST_SUITE_P( 241 SkipTls10, TlsSkipTest, 242 ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, 243 TlsConnectTestBase::kTlsV10)); 244 INSTANTIATE_TEST_SUITE_P(SkipVariants, TlsSkipTest, 245 ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, 246 TlsConnectTestBase::kTlsV11V12)); 247 INSTANTIATE_TEST_SUITE_P(Skip13Variants, Tls13SkipTest, 248 TlsConnectTestBase::kTlsVariantsAll); 249 } // namespace nss_test