tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

ssl_masking_unittest.cc (13346B)


      1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim: set ts=2 et sw=2 tw=80: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
#include <algorithm>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include "keyhi.h"
#include "pk11pub.h"
#include "secerr.h"
#include "ssl.h"
#include "sslerr.h"
#include "sslexp.h"
#include "sslproto.h"

#include "gtest_utils.h"
#include "nss_scoped_ptrs.h"
#include "scoped_ptrs_ssl.h"
#include "tls_connect.h"
     21 
     22 namespace nss_test {
     23 
// From tls_hkdf_unittest.cc:
// Declaration of a helper that is defined in another test translation unit.
// Nothing in this file currently calls it.
// NOTE(review): possibly intended to size the secret generated in
// MaskingTest::InitSecret() from its |hash_type| argument; confirm, or drop
// this declaration.
extern size_t GetHashLength(SSLHashType ht);
     26 
     27 const std::string kLabel = "sn";
     28 
     29 class MaskingTest : public ::testing::Test {
     30 public:
     31  MaskingTest() : slot_(PK11_GetInternalSlot()) {}
     32 
     33  void InitSecret(SSLHashType hash_type) {
     34    ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
     35    PK11SymKey *s = PK11_KeyGen(slot_.get(), CKM_GENERIC_SECRET_KEY_GEN,
     36                                nullptr, AES_128_KEY_LENGTH, nullptr);
     37    ASSERT_NE(nullptr, s);
     38    secret_.reset(s);
     39  }
     40 
     41  void SetUp() override {
     42    InitSecret(ssl_hash_sha256);
     43    PORT_SetError(0);
     44  }
     45 
     46 protected:
     47  ScopedPK11SymKey secret_;
     48  ScopedPK11SlotInfo slot_;
     49  // Should have 4B ctr, 12B nonce for ChaCha, or >=16B ciphertext for AES.
     50  // Use the same default size for mask output.
     51  static const int kSampleSize = 16;
     52  static const int kMaskSize = 16;
     53  void CreateMask(PRUint16 ciphersuite, SSLProtocolVariant variant,
     54                  std::string label, const std::vector<uint8_t> &sample,
     55                  std::vector<uint8_t> *out_mask) {
     56    ASSERT_NE(nullptr, out_mask);
     57    SSLMaskingContext *ctx_init = nullptr;
     58    EXPECT_EQ(SECSuccess,
     59              SSL_CreateVariantMaskingContext(
     60                  SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite, variant,
     61                  secret_.get(), label.c_str(), label.size(), &ctx_init));
     62    ASSERT_NE(nullptr, ctx_init);
     63    ScopedSSLMaskingContext ctx(ctx_init);
     64 
     65    EXPECT_EQ(SECSuccess,
     66              SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
     67                             out_mask->data(), out_mask->size()));
     68    bool all_zeros = std::all_of(out_mask->begin(), out_mask->end(),
     69                                 [](uint8_t v) { return v == 0; });
     70 
     71    // If out_mask is short, |all_zeros| will be (expectedly) true often enough
     72    // to fail tests.
     73    // In this case, just retry to make sure we're not outputting zeros
     74    // continuously.
     75    if (all_zeros && out_mask->size() < 3) {
     76      unsigned int tries = 2;
     77      std::vector<uint8_t> tmp_sample = sample;
     78      std::vector<uint8_t> tmp_mask(out_mask->size());
     79      while (tries--) {
     80        tmp_sample.data()[0]++;  // Tweak something to get a new mask.
     81        EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), tmp_sample.data(),
     82                                             tmp_sample.size(), tmp_mask.data(),
     83                                             tmp_mask.size()));
     84        bool retry_zero = std::all_of(tmp_mask.begin(), tmp_mask.end(),
     85                                      [](uint8_t v) { return v == 0; });
     86        if (!retry_zero) {
     87          all_zeros = false;
     88          break;
     89        }
     90      }
     91    }
     92    EXPECT_FALSE(all_zeros);
     93  }
     94 };
     95 
     96 class SuiteTest : public MaskingTest,
     97                  public ::testing::WithParamInterface<uint16_t> {
     98 public:
     99  SuiteTest() : ciphersuite_(GetParam()) {}
    100  void CreateMask(std::string label, const std::vector<uint8_t> &sample,
    101                  std::vector<uint8_t> *out_mask) {
    102    MaskingTest::CreateMask(ciphersuite_, ssl_variant_datagram, label, sample,
    103                            out_mask);
    104  }
    105 
    106 protected:
    107  const uint16_t ciphersuite_;
    108 };
    109 
    110 class VariantTest : public MaskingTest,
    111                    public ::testing::WithParamInterface<SSLProtocolVariant> {
    112 public:
    113  VariantTest() : variant_(GetParam()) {}
    114  void CreateMask(uint16_t ciphersuite, std::string label,
    115                  const std::vector<uint8_t> &sample,
    116                  std::vector<uint8_t> *out_mask) {
    117    MaskingTest::CreateMask(ciphersuite, variant_, label, sample, out_mask);
    118  }
    119 
    120 protected:
    121  const SSLProtocolVariant variant_;
    122 };
    123 
    124 class VariantSuiteTest : public MaskingTest,
    125                         public ::testing::WithParamInterface<
    126                             std::tuple<SSLProtocolVariant, uint16_t>> {
    127 public:
    128  VariantSuiteTest()
    129      : variant_(std::get<0>(GetParam())),
    130        ciphersuite_(std::get<1>(GetParam())) {}
    131  void CreateMask(std::string label, const std::vector<uint8_t> &sample,
    132                  std::vector<uint8_t> *out_mask) {
    133    MaskingTest::CreateMask(ciphersuite_, variant_, label, sample, out_mask);
    134  }
    135 
    136 protected:
    137  const SSLProtocolVariant variant_;
    138  const uint16_t ciphersuite_;
    139 };
    140 
    141 TEST_P(VariantSuiteTest, MaskContextNoLabel) {
    142  std::vector<uint8_t> sample(kSampleSize);
    143  std::vector<uint8_t> mask(kMaskSize);
    144  CreateMask(std::string(""), sample, &mask);
    145 }
    146 
    147 TEST_P(VariantSuiteTest, MaskNoSample) {
    148  std::vector<uint8_t> mask(kMaskSize);
    149  SSLMaskingContext *ctx_init = nullptr;
    150  EXPECT_EQ(SECSuccess,
    151            SSL_CreateVariantMaskingContext(
    152                SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_,
    153                secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init));
    154  ASSERT_NE(nullptr, ctx_init);
    155  ScopedSSLMaskingContext ctx(ctx_init);
    156 
    157  EXPECT_EQ(SECFailure,
    158            SSL_CreateMask(ctx.get(), nullptr, 0, mask.data(), mask.size()));
    159  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    160 
    161  EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), nullptr, mask.size(),
    162                                       mask.data(), mask.size()));
    163  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    164 }
    165 
    166 TEST_P(VariantSuiteTest, MaskShortSample) {
    167  std::vector<uint8_t> sample(kSampleSize);
    168  std::vector<uint8_t> mask(kMaskSize);
    169  SSLMaskingContext *ctx_init = nullptr;
    170  EXPECT_EQ(SECSuccess,
    171            SSL_CreateVariantMaskingContext(
    172                SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_,
    173                secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init));
    174  ASSERT_NE(nullptr, ctx_init);
    175  ScopedSSLMaskingContext ctx(ctx_init);
    176 
    177  EXPECT_EQ(SECFailure,
    178            SSL_CreateMask(ctx.get(), sample.data(), sample.size() - 1,
    179                           mask.data(), mask.size()));
    180  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    181 }
    182 
    183 TEST_P(VariantSuiteTest, MaskContextUnsupportedMech) {
    184  std::vector<uint8_t> sample(kSampleSize);
    185  std::vector<uint8_t> mask(kMaskSize);
    186  SSLMaskingContext *ctx_init = nullptr;
    187  EXPECT_EQ(SECFailure,
    188            SSL_CreateVariantMaskingContext(
    189                SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_AES_128_CBC_SHA256,
    190                variant_, secret_.get(), nullptr, 0, &ctx_init));
    191  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    192  EXPECT_EQ(nullptr, ctx_init);
    193 }
    194 
    195 TEST_P(VariantSuiteTest, MaskContextUnsupportedVersion) {
    196  std::vector<uint8_t> sample(kSampleSize);
    197  std::vector<uint8_t> mask(kMaskSize);
    198  SSLMaskingContext *ctx_init = nullptr;
    199  EXPECT_EQ(SECFailure, SSL_CreateVariantMaskingContext(
    200                            SSL_LIBRARY_VERSION_TLS_1_2, ciphersuite_, variant_,
    201                            secret_.get(), nullptr, 0, &ctx_init));
    202  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    203  EXPECT_EQ(nullptr, ctx_init);
    204 }
    205 
    206 TEST_P(VariantSuiteTest, MaskMaxLength) {
    207  uint32_t max_mask_len = kMaskSize;
    208  if (ciphersuite_ == TLS_CHACHA20_POLY1305_SHA256) {
    209    // Internal limitation for ChaCha20 masks.
    210    max_mask_len = 128;
    211  }
    212 
    213  std::vector<uint8_t> sample(kSampleSize);
    214  std::vector<uint8_t> mask(max_mask_len + 1);
    215  SSLMaskingContext *ctx_init = nullptr;
    216  EXPECT_EQ(SECSuccess,
    217            SSL_CreateVariantMaskingContext(
    218                SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_,
    219                secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init));
    220  ASSERT_NE(nullptr, ctx_init);
    221  ScopedSSLMaskingContext ctx(ctx_init);
    222 
    223  EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
    224                                       mask.data(), mask.size() - 1));
    225  EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
    226                                       mask.data(), mask.size()));
    227  EXPECT_EQ(SEC_ERROR_OUTPUT_LEN, PORT_GetError());
    228 }
    229 
    230 TEST_P(VariantSuiteTest, MaskMinLength) {
    231  std::vector<uint8_t> sample(kSampleSize);
    232  std::vector<uint8_t> mask(1);  // Don't pass a null
    233 
    234  SSLMaskingContext *ctx_init = nullptr;
    235  EXPECT_EQ(SECSuccess,
    236            SSL_CreateVariantMaskingContext(
    237                SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_,
    238                secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init));
    239  ASSERT_NE(nullptr, ctx_init);
    240  ScopedSSLMaskingContext ctx(ctx_init);
    241  EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
    242                                       mask.data(), 0));
    243  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    244  EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
    245                                       mask.data(), 1));
    246 }
    247 
    248 TEST_P(VariantSuiteTest, MaskRotateLabel) {
    249  std::vector<uint8_t> sample(kSampleSize);
    250  std::vector<uint8_t> mask1(kMaskSize);
    251  std::vector<uint8_t> mask2(kMaskSize);
    252  EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
    253                                                  sample.size()));
    254 
    255  CreateMask(kLabel, sample, &mask1);
    256  CreateMask(std::string("sn1"), sample, &mask2);
    257  EXPECT_FALSE(mask1 == mask2);
    258 }
    259 
    260 TEST_P(VariantSuiteTest, MaskRotateSample) {
    261  std::vector<uint8_t> sample(kSampleSize);
    262  std::vector<uint8_t> mask1(kMaskSize);
    263  std::vector<uint8_t> mask2(kMaskSize);
    264 
    265  EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
    266                                                  sample.size()));
    267  CreateMask(kLabel, sample, &mask1);
    268 
    269  EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
    270                                                  sample.size()));
    271  CreateMask(kLabel, sample, &mask2);
    272  EXPECT_FALSE(mask1 == mask2);
    273 }
    274 
    275 TEST_P(VariantSuiteTest, MaskRederive) {
    276  std::vector<uint8_t> sample(kSampleSize);
    277  std::vector<uint8_t> mask1(kMaskSize);
    278  std::vector<uint8_t> mask2(kMaskSize);
    279 
    280  SECStatus rv =
    281      PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), sample.size());
    282  EXPECT_EQ(SECSuccess, rv);
    283 
    284  // Check that re-using inputs with a new context produces the same mask.
    285  CreateMask(kLabel, sample, &mask1);
    286  CreateMask(kLabel, sample, &mask2);
    287  EXPECT_TRUE(mask1 == mask2);
    288 }
    289 
    290 TEST_P(SuiteTest, MaskTlsVariantKeySeparation) {
    291  std::vector<uint8_t> sample(kSampleSize);
    292  std::vector<uint8_t> tls_mask(kMaskSize);
    293  std::vector<uint8_t> dtls_mask(kMaskSize);
    294  SSLMaskingContext *stream_ctx_init = nullptr;
    295  SSLMaskingContext *datagram_ctx_init = nullptr;
    296 
    297  // Init
    298  EXPECT_EQ(SECSuccess, SSL_CreateVariantMaskingContext(
    299                            SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_,
    300                            ssl_variant_stream, secret_.get(), kLabel.c_str(),
    301                            kLabel.size(), &stream_ctx_init));
    302  ASSERT_NE(nullptr, stream_ctx_init);
    303  EXPECT_EQ(SECSuccess, SSL_CreateVariantMaskingContext(
    304                            SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_,
    305                            ssl_variant_datagram, secret_.get(), kLabel.c_str(),
    306                            kLabel.size(), &datagram_ctx_init));
    307  ASSERT_NE(nullptr, datagram_ctx_init);
    308  ScopedSSLMaskingContext tls_ctx(stream_ctx_init);
    309  ScopedSSLMaskingContext dtls_ctx(datagram_ctx_init);
    310 
    311  // Derive
    312  EXPECT_EQ(SECSuccess,
    313            SSL_CreateMask(tls_ctx.get(), sample.data(), sample.size(),
    314                           tls_mask.data(), tls_mask.size()));
    315 
    316  EXPECT_EQ(SECSuccess,
    317            SSL_CreateMask(dtls_ctx.get(), sample.data(), sample.size(),
    318                           dtls_mask.data(), dtls_mask.size()));
    319  EXPECT_NE(tls_mask, dtls_mask);
    320 }
    321 
    322 TEST_P(VariantTest, MaskChaChaRederiveOddSizes) {
    323  // Non-block-aligned.
    324  std::vector<uint8_t> sample(27);
    325  std::vector<uint8_t> mask1(26);
    326  std::vector<uint8_t> mask2(25);
    327  EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
    328                                                  sample.size()));
    329  CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask1);
    330  CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask2);
    331  mask1.pop_back();
    332  EXPECT_TRUE(mask1 == mask2);
    333 }
    334 
    335 static const uint16_t kMaskingCiphersuites[] = {TLS_CHACHA20_POLY1305_SHA256,
    336                                                TLS_AES_128_GCM_SHA256,
    337                                                TLS_AES_256_GCM_SHA384};
    338 ::testing::internal::ParamGenerator<uint16_t> kMaskingCiphersuiteParams =
    339    ::testing::ValuesIn(kMaskingCiphersuites);
    340 
    341 INSTANTIATE_TEST_SUITE_P(GenericMasking, SuiteTest, kMaskingCiphersuiteParams);
    342 
    343 INSTANTIATE_TEST_SUITE_P(GenericMasking, VariantTest,
    344                         TlsConnectTestBase::kTlsVariantsAll);
    345 
    346 INSTANTIATE_TEST_SUITE_P(GenericMasking, VariantSuiteTest,
    347                         ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
    348                                            kMaskingCiphersuiteParams));
    349 
    350 }  // namespace nss_test