// ssl_masking_unittest.cc (13346B)
1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 2 /* vim: set ts=2 et sw=2 tw=80: */ 3 /* This Source Code Form is subject to the terms of the Mozilla Public 4 * License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 * You can obtain one at http://mozilla.org/MPL/2.0/. */ 6 7 #include <memory> 8 9 #include "keyhi.h" 10 #include "pk11pub.h" 11 #include "secerr.h" 12 #include "ssl.h" 13 #include "sslerr.h" 14 #include "sslexp.h" 15 #include "sslproto.h" 16 17 #include "gtest_utils.h" 18 #include "nss_scoped_ptrs.h" 19 #include "scoped_ptrs_ssl.h" 20 #include "tls_connect.h" 21 22 namespace nss_test { 23 24 // From tls_hkdf_unittest.cc: 25 extern size_t GetHashLength(SSLHashType ht); 26 27 const std::string kLabel = "sn"; 28 29 class MaskingTest : public ::testing::Test { 30 public: 31 MaskingTest() : slot_(PK11_GetInternalSlot()) {} 32 33 void InitSecret(SSLHashType hash_type) { 34 ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); 35 PK11SymKey *s = PK11_KeyGen(slot_.get(), CKM_GENERIC_SECRET_KEY_GEN, 36 nullptr, AES_128_KEY_LENGTH, nullptr); 37 ASSERT_NE(nullptr, s); 38 secret_.reset(s); 39 } 40 41 void SetUp() override { 42 InitSecret(ssl_hash_sha256); 43 PORT_SetError(0); 44 } 45 46 protected: 47 ScopedPK11SymKey secret_; 48 ScopedPK11SlotInfo slot_; 49 // Should have 4B ctr, 12B nonce for ChaCha, or >=16B ciphertext for AES. 50 // Use the same default size for mask output. 
51 static const int kSampleSize = 16; 52 static const int kMaskSize = 16; 53 void CreateMask(PRUint16 ciphersuite, SSLProtocolVariant variant, 54 std::string label, const std::vector<uint8_t> &sample, 55 std::vector<uint8_t> *out_mask) { 56 ASSERT_NE(nullptr, out_mask); 57 SSLMaskingContext *ctx_init = nullptr; 58 EXPECT_EQ(SECSuccess, 59 SSL_CreateVariantMaskingContext( 60 SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite, variant, 61 secret_.get(), label.c_str(), label.size(), &ctx_init)); 62 ASSERT_NE(nullptr, ctx_init); 63 ScopedSSLMaskingContext ctx(ctx_init); 64 65 EXPECT_EQ(SECSuccess, 66 SSL_CreateMask(ctx.get(), sample.data(), sample.size(), 67 out_mask->data(), out_mask->size())); 68 bool all_zeros = std::all_of(out_mask->begin(), out_mask->end(), 69 [](uint8_t v) { return v == 0; }); 70 71 // If out_mask is short, |all_zeros| will be (expectedly) true often enough 72 // to fail tests. 73 // In this case, just retry to make sure we're not outputting zeros 74 // continuously. 75 if (all_zeros && out_mask->size() < 3) { 76 unsigned int tries = 2; 77 std::vector<uint8_t> tmp_sample = sample; 78 std::vector<uint8_t> tmp_mask(out_mask->size()); 79 while (tries--) { 80 tmp_sample.data()[0]++; // Tweak something to get a new mask. 
81 EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), tmp_sample.data(), 82 tmp_sample.size(), tmp_mask.data(), 83 tmp_mask.size())); 84 bool retry_zero = std::all_of(tmp_mask.begin(), tmp_mask.end(), 85 [](uint8_t v) { return v == 0; }); 86 if (!retry_zero) { 87 all_zeros = false; 88 break; 89 } 90 } 91 } 92 EXPECT_FALSE(all_zeros); 93 } 94 }; 95 96 class SuiteTest : public MaskingTest, 97 public ::testing::WithParamInterface<uint16_t> { 98 public: 99 SuiteTest() : ciphersuite_(GetParam()) {} 100 void CreateMask(std::string label, const std::vector<uint8_t> &sample, 101 std::vector<uint8_t> *out_mask) { 102 MaskingTest::CreateMask(ciphersuite_, ssl_variant_datagram, label, sample, 103 out_mask); 104 } 105 106 protected: 107 const uint16_t ciphersuite_; 108 }; 109 110 class VariantTest : public MaskingTest, 111 public ::testing::WithParamInterface<SSLProtocolVariant> { 112 public: 113 VariantTest() : variant_(GetParam()) {} 114 void CreateMask(uint16_t ciphersuite, std::string label, 115 const std::vector<uint8_t> &sample, 116 std::vector<uint8_t> *out_mask) { 117 MaskingTest::CreateMask(ciphersuite, variant_, label, sample, out_mask); 118 } 119 120 protected: 121 const SSLProtocolVariant variant_; 122 }; 123 124 class VariantSuiteTest : public MaskingTest, 125 public ::testing::WithParamInterface< 126 std::tuple<SSLProtocolVariant, uint16_t>> { 127 public: 128 VariantSuiteTest() 129 : variant_(std::get<0>(GetParam())), 130 ciphersuite_(std::get<1>(GetParam())) {} 131 void CreateMask(std::string label, const std::vector<uint8_t> &sample, 132 std::vector<uint8_t> *out_mask) { 133 MaskingTest::CreateMask(ciphersuite_, variant_, label, sample, out_mask); 134 } 135 136 protected: 137 const SSLProtocolVariant variant_; 138 const uint16_t ciphersuite_; 139 }; 140 141 TEST_P(VariantSuiteTest, MaskContextNoLabel) { 142 std::vector<uint8_t> sample(kSampleSize); 143 std::vector<uint8_t> mask(kMaskSize); 144 CreateMask(std::string(""), sample, &mask); 145 } 146 147 
TEST_P(VariantSuiteTest, MaskNoSample) { 148 std::vector<uint8_t> mask(kMaskSize); 149 SSLMaskingContext *ctx_init = nullptr; 150 EXPECT_EQ(SECSuccess, 151 SSL_CreateVariantMaskingContext( 152 SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_, 153 secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init)); 154 ASSERT_NE(nullptr, ctx_init); 155 ScopedSSLMaskingContext ctx(ctx_init); 156 157 EXPECT_EQ(SECFailure, 158 SSL_CreateMask(ctx.get(), nullptr, 0, mask.data(), mask.size())); 159 EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); 160 161 EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), nullptr, mask.size(), 162 mask.data(), mask.size())); 163 EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); 164 } 165 166 TEST_P(VariantSuiteTest, MaskShortSample) { 167 std::vector<uint8_t> sample(kSampleSize); 168 std::vector<uint8_t> mask(kMaskSize); 169 SSLMaskingContext *ctx_init = nullptr; 170 EXPECT_EQ(SECSuccess, 171 SSL_CreateVariantMaskingContext( 172 SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_, 173 secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init)); 174 ASSERT_NE(nullptr, ctx_init); 175 ScopedSSLMaskingContext ctx(ctx_init); 176 177 EXPECT_EQ(SECFailure, 178 SSL_CreateMask(ctx.get(), sample.data(), sample.size() - 1, 179 mask.data(), mask.size())); 180 EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); 181 } 182 183 TEST_P(VariantSuiteTest, MaskContextUnsupportedMech) { 184 std::vector<uint8_t> sample(kSampleSize); 185 std::vector<uint8_t> mask(kMaskSize); 186 SSLMaskingContext *ctx_init = nullptr; 187 EXPECT_EQ(SECFailure, 188 SSL_CreateVariantMaskingContext( 189 SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_AES_128_CBC_SHA256, 190 variant_, secret_.get(), nullptr, 0, &ctx_init)); 191 EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); 192 EXPECT_EQ(nullptr, ctx_init); 193 } 194 195 TEST_P(VariantSuiteTest, MaskContextUnsupportedVersion) { 196 std::vector<uint8_t> sample(kSampleSize); 197 std::vector<uint8_t> mask(kMaskSize); 198 SSLMaskingContext 
*ctx_init = nullptr; 199 EXPECT_EQ(SECFailure, SSL_CreateVariantMaskingContext( 200 SSL_LIBRARY_VERSION_TLS_1_2, ciphersuite_, variant_, 201 secret_.get(), nullptr, 0, &ctx_init)); 202 EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); 203 EXPECT_EQ(nullptr, ctx_init); 204 } 205 206 TEST_P(VariantSuiteTest, MaskMaxLength) { 207 uint32_t max_mask_len = kMaskSize; 208 if (ciphersuite_ == TLS_CHACHA20_POLY1305_SHA256) { 209 // Internal limitation for ChaCha20 masks. 210 max_mask_len = 128; 211 } 212 213 std::vector<uint8_t> sample(kSampleSize); 214 std::vector<uint8_t> mask(max_mask_len + 1); 215 SSLMaskingContext *ctx_init = nullptr; 216 EXPECT_EQ(SECSuccess, 217 SSL_CreateVariantMaskingContext( 218 SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_, 219 secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init)); 220 ASSERT_NE(nullptr, ctx_init); 221 ScopedSSLMaskingContext ctx(ctx_init); 222 223 EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), 224 mask.data(), mask.size() - 1)); 225 EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), 226 mask.data(), mask.size())); 227 EXPECT_EQ(SEC_ERROR_OUTPUT_LEN, PORT_GetError()); 228 } 229 230 TEST_P(VariantSuiteTest, MaskMinLength) { 231 std::vector<uint8_t> sample(kSampleSize); 232 std::vector<uint8_t> mask(1); // Don't pass a null 233 234 SSLMaskingContext *ctx_init = nullptr; 235 EXPECT_EQ(SECSuccess, 236 SSL_CreateVariantMaskingContext( 237 SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_, 238 secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init)); 239 ASSERT_NE(nullptr, ctx_init); 240 ScopedSSLMaskingContext ctx(ctx_init); 241 EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), 242 mask.data(), 0)); 243 EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); 244 EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), 245 mask.data(), 1)); 246 } 247 248 TEST_P(VariantSuiteTest, MaskRotateLabel) { 249 
std::vector<uint8_t> sample(kSampleSize); 250 std::vector<uint8_t> mask1(kMaskSize); 251 std::vector<uint8_t> mask2(kMaskSize); 252 EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), 253 sample.size())); 254 255 CreateMask(kLabel, sample, &mask1); 256 CreateMask(std::string("sn1"), sample, &mask2); 257 EXPECT_FALSE(mask1 == mask2); 258 } 259 260 TEST_P(VariantSuiteTest, MaskRotateSample) { 261 std::vector<uint8_t> sample(kSampleSize); 262 std::vector<uint8_t> mask1(kMaskSize); 263 std::vector<uint8_t> mask2(kMaskSize); 264 265 EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), 266 sample.size())); 267 CreateMask(kLabel, sample, &mask1); 268 269 EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), 270 sample.size())); 271 CreateMask(kLabel, sample, &mask2); 272 EXPECT_FALSE(mask1 == mask2); 273 } 274 275 TEST_P(VariantSuiteTest, MaskRederive) { 276 std::vector<uint8_t> sample(kSampleSize); 277 std::vector<uint8_t> mask1(kMaskSize); 278 std::vector<uint8_t> mask2(kMaskSize); 279 280 SECStatus rv = 281 PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), sample.size()); 282 EXPECT_EQ(SECSuccess, rv); 283 284 // Check that re-using inputs with a new context produces the same mask. 
285 CreateMask(kLabel, sample, &mask1); 286 CreateMask(kLabel, sample, &mask2); 287 EXPECT_TRUE(mask1 == mask2); 288 } 289 290 TEST_P(SuiteTest, MaskTlsVariantKeySeparation) { 291 std::vector<uint8_t> sample(kSampleSize); 292 std::vector<uint8_t> tls_mask(kMaskSize); 293 std::vector<uint8_t> dtls_mask(kMaskSize); 294 SSLMaskingContext *stream_ctx_init = nullptr; 295 SSLMaskingContext *datagram_ctx_init = nullptr; 296 297 // Init 298 EXPECT_EQ(SECSuccess, SSL_CreateVariantMaskingContext( 299 SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, 300 ssl_variant_stream, secret_.get(), kLabel.c_str(), 301 kLabel.size(), &stream_ctx_init)); 302 ASSERT_NE(nullptr, stream_ctx_init); 303 EXPECT_EQ(SECSuccess, SSL_CreateVariantMaskingContext( 304 SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, 305 ssl_variant_datagram, secret_.get(), kLabel.c_str(), 306 kLabel.size(), &datagram_ctx_init)); 307 ASSERT_NE(nullptr, datagram_ctx_init); 308 ScopedSSLMaskingContext tls_ctx(stream_ctx_init); 309 ScopedSSLMaskingContext dtls_ctx(datagram_ctx_init); 310 311 // Derive 312 EXPECT_EQ(SECSuccess, 313 SSL_CreateMask(tls_ctx.get(), sample.data(), sample.size(), 314 tls_mask.data(), tls_mask.size())); 315 316 EXPECT_EQ(SECSuccess, 317 SSL_CreateMask(dtls_ctx.get(), sample.data(), sample.size(), 318 dtls_mask.data(), dtls_mask.size())); 319 EXPECT_NE(tls_mask, dtls_mask); 320 } 321 322 TEST_P(VariantTest, MaskChaChaRederiveOddSizes) { 323 // Non-block-aligned. 
324 std::vector<uint8_t> sample(27); 325 std::vector<uint8_t> mask1(26); 326 std::vector<uint8_t> mask2(25); 327 EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), 328 sample.size())); 329 CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask1); 330 CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask2); 331 mask1.pop_back(); 332 EXPECT_TRUE(mask1 == mask2); 333 } 334 335 static const uint16_t kMaskingCiphersuites[] = {TLS_CHACHA20_POLY1305_SHA256, 336 TLS_AES_128_GCM_SHA256, 337 TLS_AES_256_GCM_SHA384}; 338 ::testing::internal::ParamGenerator<uint16_t> kMaskingCiphersuiteParams = 339 ::testing::ValuesIn(kMaskingCiphersuites); 340 341 INSTANTIATE_TEST_SUITE_P(GenericMasking, SuiteTest, kMaskingCiphersuiteParams); 342 343 INSTANTIATE_TEST_SUITE_P(GenericMasking, VariantTest, 344 TlsConnectTestBase::kTlsVariantsAll); 345 346 INSTANTIATE_TEST_SUITE_P(GenericMasking, VariantSuiteTest, 347 ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, 348 kMaskingCiphersuiteParams)); 349 350 } // namespace nss_test