tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

pk11_ecdh_unittest.cc (7731B)


      1 /* This Source Code Form is subject to the terms of the Mozilla Public
      2 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      3 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      4 
      5 #include <algorithm>
      6 #include <memory>
      7 #include "nss.h"
      8 #include "pk11pub.h"
      9 #include "prerror.h"
     10 
     11 #include "cpputil.h"
     12 #include "json_reader.h"
     13 #include "gtest/gtest.h"
     14 #include "nss_scoped_ptrs.h"
     15 #include "testvectors_base/test-structs.h"
     16 
     17 namespace nss_test {
     18 
     19 class Pkcs11EcdhTest : public ::testing::Test {
     20 protected:
     21  void Derive(const std::string& curve, const EcdhTestVector& vec) {
     22    std::cout << "Run test " << vec.id << std::endl;
     23 
     24    SECItem spki_item = {siBuffer, toUcharPtr(vec.public_key.data()),
     25                         static_cast<unsigned int>(vec.public_key.size())};
     26    ScopedCERTSubjectPublicKeyInfo cert_spki(
     27        SECKEY_DecodeDERSubjectPublicKeyInfo(&spki_item));
     28    if (vec.valid) {
     29      ASSERT_TRUE(!!cert_spki);
     30    } else if (!cert_spki) {
     31      ASSERT_TRUE(vec.invalid_asn);
     32      return;
     33    }
     34 
     35    ScopedSECKEYPublicKey pub_key(SECKEY_ExtractPublicKey(cert_spki.get()));
     36    if (vec.valid) {
     37      ASSERT_TRUE(!!pub_key);
     38    } else if (!pub_key) {
     39      return;
     40    }
     41 
     42    ScopedSECKEYPrivateKey priv_key = ImportPrivateKey(curve, vec);
     43    if (vec.valid) {
     44      ASSERT_TRUE(priv_key);
     45    } else if (!priv_key) {
     46      return;
     47    }
     48 
     49    ScopedPK11SymKey sym_key(
     50        PK11_PubDeriveWithKDF(priv_key.get(), pub_key.get(), false, nullptr,
     51                              nullptr, CKM_ECDH1_DERIVE, CKM_SHA512_HMAC,
     52                              CKA_DERIVE, 0, CKD_NULL, nullptr, nullptr));
     53    if (vec.valid) {
     54      ASSERT_TRUE(!!sym_key);
     55 
     56      SECStatus rv = PK11_ExtractKeyValue(sym_key.get());
     57      EXPECT_EQ(SECSuccess, rv);
     58 
     59      SECItem expect_item = {siBuffer, toUcharPtr(vec.secret.data()),
     60                             static_cast<unsigned int>(vec.secret.size())};
     61 
     62      SECItem* derived_key = PK11_GetKeyData(sym_key.get());
     63      EXPECT_EQ(0, SECITEM_CompareItem(derived_key, &expect_item));
     64    } else if (!vec.invalid_asn) {
     65      // Invalid encodings could produce an output if we get here, so only
     66      // check when the encoding is valid.
     67      ASSERT_FALSE(!!sym_key);
     68    }
     69  };
     70 
     71  static void ReadTestAttr(EcdhTestVector& t, const std::string& n,
     72                           JsonReader& r) {
     73    if (n == "public") {
     74      t.public_key = r.ReadHex();
     75    } else if (n == "private") {
     76      t.private_key = r.ReadHex();
     77    } else if (n == "shared") {
     78      t.secret = r.ReadHex();
     79    } else {
     80      FAIL() << "unsupported test case field: " << n;
     81    }
     82  }
     83 
     84  void RunGroup(JsonReader& r) {
     85    std::vector<EcdhTestVector> tests;
     86    std::string curve;
     87    while (r.NextItem()) {
     88      std::string n = r.ReadLabel();
     89      if (n == "") {
     90        break;
     91      }
     92      if (n == "curve") {
     93        curve = r.ReadString();
     94      } else if (n == "encoding") {
     95        ASSERT_EQ("asn", r.ReadString());
     96      } else if (n == "type") {
     97        ASSERT_EQ("EcdhTest", r.ReadString());
     98      } else if (n == "tests") {
     99        WycheproofReadTests(r, &tests, ReadTestAttr, false,
    100                            [](EcdhTestVector& t, const std::string&,
    101                               const std::vector<std::string>& flags) {
    102                              t.invalid_asn =
    103                                  std::find(flags.begin(), flags.end(),
    104                                            "InvalidAsn") != flags.end();
    105                            });
    106      } else {
    107        FAIL() << "unknown group label: " << n;
    108      }
    109    }
    110 
    111    for (auto& t : tests) {
    112      Derive(curve, t);
    113    }
    114  }
    115 
    116  void Run(const std::string& file) {
    117    WycheproofHeader(file, "ECDH", "ecdh_test_schema.json",
    118                     [this](JsonReader& r) { RunGroup(r); });
    119  }
    120 
    121 private:
    122  void OidForCurve(const std::string& curve, std::vector<uint8_t>* der) {
    123    SECOidTag tag;
    124    if (curve == "secp256r1") {
    125      tag = SEC_OID_SECG_EC_SECP256R1;
    126    } else if (curve == "secp384r1") {
    127      tag = SEC_OID_SECG_EC_SECP384R1;
    128    } else if (curve == "secp521r1") {
    129      tag = SEC_OID_SECG_EC_SECP521R1;
    130    } else {
    131      FAIL() << "unknown curve: " << curve;
    132    }
    133    SECOidData* oid_data = SECOID_FindOIDByTag(tag);
    134    ASSERT_TRUE(oid_data);
    135    der->push_back(SEC_ASN1_OBJECT_ID);
    136    der->push_back(oid_data->oid.len);
    137    der->insert(der->end(), oid_data->oid.data,
    138                oid_data->oid.data + oid_data->oid.len);
    139  }
    140 
    141  // Construct a garbage public value for the given curve.
    142  // NSS needs a value for this, but it doesn't care what it is.
    143  void PublicValue(const std::string& curve, std::vector<uint8_t>* der) {
    144    size_t len;
    145    if (curve == "secp256r1") {
    146      len = 32;
    147    } else if (curve == "secp384r1") {
    148      len = 48;
    149    } else if (curve == "secp521r1") {
    150      len = 64;
    151    } else {
    152      FAIL() << "unknown curve: " << curve;
    153    }
    154    der->push_back(0x04);
    155    for (size_t i = 0; i < len * 2; ++i) {
    156      der->push_back(0x00);
    157    }
    158  }
    159 
    160  void InsertLength(std::vector<uint8_t>* der, size_t offset) {
    161    size_t len = der->size() - offset;
    162    ASSERT_GT(256u, len) << "unsupported length for DER";
    163    if (len > 127) {
    164      der->insert(der->begin() + offset, 0x81);
    165      offset++;
    166    }
    167    der->insert(der->begin() + offset, static_cast<uint8_t>(len));
    168  }
    169 
    170  // A very hacking PKCS#8 encoder that is sufficient to dupe NSS into
    171  // thinking that it is a valid EC private key.
    172  std::vector<uint8_t> BuildDerPrivateKey(const std::string& curve,
    173                                          const EcdhTestVector& vec) {
    174    std::vector<uint8_t> der;
    175    std::vector<size_t> length_inserts;
    176 
    177    der.push_back(0x30);
    178    length_inserts.push_back(der.size());
    179    der.insert(der.end(), {0x02, 0x01, 0x00, 0x30});
    180    size_t oid_length_insert = der.size();
    181    der.insert(der.end(),
    182               {0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01});
    183    OidForCurve(curve, &der);
    184    InsertLength(&der, oid_length_insert);
    185 
    186    der.push_back(0x04);
    187    length_inserts.push_back(der.size());
    188    der.push_back(0x30);
    189    length_inserts.push_back(der.size());
    190 
    191    der.insert(der.end(), {0x02, 0x01, 0x01, 0x04});
    192    der.push_back(vec.private_key.size());
    193    der.insert(der.end(), vec.private_key.begin(), vec.private_key.end());
    194 
    195    der.push_back(0xa1);
    196    length_inserts.push_back(der.size());
    197    der.push_back(0x03);
    198    length_inserts.push_back(der.size());
    199    der.push_back(0x00);
    200    PublicValue(curve, &der);
    201 
    202    for (auto i = length_inserts.rbegin(); i != length_inserts.rend(); ++i) {
    203      InsertLength(&der, *i);
    204    }
    205    return der;
    206  }
    207 
    208  ScopedSECKEYPrivateKey ImportPrivateKey(const std::string& curve,
    209                                          const EcdhTestVector& vec) {
    210    ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
    211    EXPECT_TRUE(slot);
    212    if (!slot) {
    213      return nullptr;
    214    }
    215 
    216    std::vector<uint8_t> der = BuildDerPrivateKey(curve, vec);
    217    SECItem der_item = {siBuffer, const_cast<uint8_t*>(der.data()),
    218                        static_cast<unsigned int>(der.size())};
    219    SECKEYPrivateKey* key = nullptr;
    220    SECStatus rv = PK11_ImportDERPrivateKeyInfoAndReturnKey(
    221        slot.get(), &der_item, nullptr, nullptr, false, true, KU_KEY_AGREEMENT,
    222        &key, nullptr);
    223    if (vec.valid) {
    224      EXPECT_EQ(SECSuccess, rv)
    225          << "unable to load private key DER for test " << vec.id << ": "
    226          << PORT_ErrorToString(PORT_GetError());
    227    }
    228 
    229    return ScopedSECKEYPrivateKey(key);
    230  }
    231 };
    232 
    233 TEST_F(Pkcs11EcdhTest, P256) { Run("ecdh_secp256r1"); }
    234 TEST_F(Pkcs11EcdhTest, P384) { Run("ecdh_secp384r1"); }
    235 TEST_F(Pkcs11EcdhTest, P521) { Run("ecdh_secp521r1"); }
    236 
    237 }  // namespace nss_test