pk11_ecdh_unittest.cc (7731B)
1 /* This Source Code Form is subject to the terms of the Mozilla Public 2 * License, v. 2.0. If a copy of the MPL was not distributed with this file, 3 * You can obtain one at http://mozilla.org/MPL/2.0/. */ 4 5 #include <algorithm> 6 #include <memory> 7 #include "nss.h" 8 #include "pk11pub.h" 9 #include "prerror.h" 10 11 #include "cpputil.h" 12 #include "json_reader.h" 13 #include "gtest/gtest.h" 14 #include "nss_scoped_ptrs.h" 15 #include "testvectors_base/test-structs.h" 16 17 namespace nss_test { 18 19 class Pkcs11EcdhTest : public ::testing::Test { 20 protected: 21 void Derive(const std::string& curve, const EcdhTestVector& vec) { 22 std::cout << "Run test " << vec.id << std::endl; 23 24 SECItem spki_item = {siBuffer, toUcharPtr(vec.public_key.data()), 25 static_cast<unsigned int>(vec.public_key.size())}; 26 ScopedCERTSubjectPublicKeyInfo cert_spki( 27 SECKEY_DecodeDERSubjectPublicKeyInfo(&spki_item)); 28 if (vec.valid) { 29 ASSERT_TRUE(!!cert_spki); 30 } else if (!cert_spki) { 31 ASSERT_TRUE(vec.invalid_asn); 32 return; 33 } 34 35 ScopedSECKEYPublicKey pub_key(SECKEY_ExtractPublicKey(cert_spki.get())); 36 if (vec.valid) { 37 ASSERT_TRUE(!!pub_key); 38 } else if (!pub_key) { 39 return; 40 } 41 42 ScopedSECKEYPrivateKey priv_key = ImportPrivateKey(curve, vec); 43 if (vec.valid) { 44 ASSERT_TRUE(priv_key); 45 } else if (!priv_key) { 46 return; 47 } 48 49 ScopedPK11SymKey sym_key( 50 PK11_PubDeriveWithKDF(priv_key.get(), pub_key.get(), false, nullptr, 51 nullptr, CKM_ECDH1_DERIVE, CKM_SHA512_HMAC, 52 CKA_DERIVE, 0, CKD_NULL, nullptr, nullptr)); 53 if (vec.valid) { 54 ASSERT_TRUE(!!sym_key); 55 56 SECStatus rv = PK11_ExtractKeyValue(sym_key.get()); 57 EXPECT_EQ(SECSuccess, rv); 58 59 SECItem expect_item = {siBuffer, toUcharPtr(vec.secret.data()), 60 static_cast<unsigned int>(vec.secret.size())}; 61 62 SECItem* derived_key = PK11_GetKeyData(sym_key.get()); 63 EXPECT_EQ(0, SECITEM_CompareItem(derived_key, &expect_item)); 64 } else if (!vec.invalid_asn) { 65 // Invalid encodings could produce an output if we get here, so only 66 // check when the encoding is valid. 67 ASSERT_FALSE(!!sym_key); 68 } 69 }; 70 71 static void ReadTestAttr(EcdhTestVector& t, const std::string& n, 72 JsonReader& r) { 73 if (n == "public") { 74 t.public_key = r.ReadHex(); 75 } else if (n == "private") { 76 t.private_key = r.ReadHex(); 77 } else if (n == "shared") { 78 t.secret = r.ReadHex(); 79 } else { 80 FAIL() << "unsupported test case field: " << n; 81 } 82 } 83 84 void RunGroup(JsonReader& r) { 85 std::vector<EcdhTestVector> tests; 86 std::string curve; 87 while (r.NextItem()) { 88 std::string n = r.ReadLabel(); 89 if (n == "") { 90 break; 91 } 92 if (n == "curve") { 93 curve = r.ReadString(); 94 } else if (n == "encoding") { 95 ASSERT_EQ("asn", r.ReadString()); 96 } else if (n == "type") { 97 ASSERT_EQ("EcdhTest", r.ReadString()); 98 } else if (n == "tests") { 99 WycheproofReadTests(r, &tests, ReadTestAttr, false, 100 [](EcdhTestVector& t, const std::string&, 101 const std::vector<std::string>& flags) { 102 t.invalid_asn = 103 std::find(flags.begin(), flags.end(), 104 "InvalidAsn") != flags.end(); 105 }); 106 } else { 107 FAIL() << "unknown group label: " << n; 108 } 109 } 110 111 for (auto& t : tests) { 112 Derive(curve, t); 113 } 114 } 115 116 void Run(const std::string& file) { 117 WycheproofHeader(file, "ECDH", "ecdh_test_schema.json", 118 [this](JsonReader& r) { RunGroup(r); }); 119 } 120 121 private: 122 void OidForCurve(const std::string& curve, std::vector<uint8_t>* der) { 123 SECOidTag tag; 124 if (curve == "secp256r1") { 125 tag = SEC_OID_SECG_EC_SECP256R1; 126 } else if (curve == "secp384r1") { 127 tag = SEC_OID_SECG_EC_SECP384R1; 128 } else if (curve == "secp521r1") { 129 tag = SEC_OID_SECG_EC_SECP521R1; 130 } else { 131 FAIL() << "unknown curve: " << curve; 132 } 133 SECOidData* oid_data = SECOID_FindOIDByTag(tag); 134 ASSERT_TRUE(oid_data); 135 der->push_back(SEC_ASN1_OBJECT_ID); 136 der->push_back(oid_data->oid.len); 137 der->insert(der->end(), oid_data->oid.data, 138 oid_data->oid.data + oid_data->oid.len); 139 } 140 141 // Construct a garbage public value for the given curve. 142 // NSS needs a value for this, but it doesn't care what it is. 143 void PublicValue(const std::string& curve, std::vector<uint8_t>* der) { 144 size_t len; 145 if (curve == "secp256r1") { 146 len = 32; 147 } else if (curve == "secp384r1") { 148 len = 48; 149 } else if (curve == "secp521r1") { 150 len = 64; 151 } else { 152 FAIL() << "unknown curve: " << curve; 153 } 154 der->push_back(0x04); 155 for (size_t i = 0; i < len * 2; ++i) { 156 der->push_back(0x00); 157 } 158 } 159 160 void InsertLength(std::vector<uint8_t>* der, size_t offset) { 161 size_t len = der->size() - offset; 162 ASSERT_GT(256u, len) << "unsupported length for DER"; 163 if (len > 127) { 164 der->insert(der->begin() + offset, 0x81); 165 offset++; 166 } 167 der->insert(der->begin() + offset, static_cast<uint8_t>(len)); 168 } 169 170 // A very hacking PKCS#8 encoder that is sufficient to dupe NSS into 171 // thinking that it is a valid EC private key. 172 std::vector<uint8_t> BuildDerPrivateKey(const std::string& curve, 173 const EcdhTestVector& vec) { 174 std::vector<uint8_t> der; 175 std::vector<size_t> length_inserts; 176 177 der.push_back(0x30); 178 length_inserts.push_back(der.size()); 179 der.insert(der.end(), {0x02, 0x01, 0x00, 0x30}); 180 size_t oid_length_insert = der.size(); 181 der.insert(der.end(), 182 {0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01}); 183 OidForCurve(curve, &der); 184 InsertLength(&der, oid_length_insert); 185 186 der.push_back(0x04); 187 length_inserts.push_back(der.size()); 188 der.push_back(0x30); 189 length_inserts.push_back(der.size()); 190 191 der.insert(der.end(), {0x02, 0x01, 0x01, 0x04}); 192 der.push_back(vec.private_key.size()); 193 der.insert(der.end(), vec.private_key.begin(), vec.private_key.end()); 194 195 der.push_back(0xa1); 196 length_inserts.push_back(der.size()); 197 der.push_back(0x03); 198 length_inserts.push_back(der.size()); 199 der.push_back(0x00); 200 PublicValue(curve, &der); 201 202 for (auto i = length_inserts.rbegin(); i != length_inserts.rend(); ++i) { 203 InsertLength(&der, *i); 204 } 205 return der; 206 } 207 208 ScopedSECKEYPrivateKey ImportPrivateKey(const std::string& curve, 209 const EcdhTestVector& vec) { 210 ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); 211 EXPECT_TRUE(slot); 212 if (!slot) { 213 return nullptr; 214 } 215 216 std::vector<uint8_t> der = BuildDerPrivateKey(curve, vec); 217 SECItem der_item = {siBuffer, const_cast<uint8_t*>(der.data()), 218 static_cast<unsigned int>(der.size())}; 219 SECKEYPrivateKey* key = nullptr; 220 SECStatus rv = PK11_ImportDERPrivateKeyInfoAndReturnKey( 221 slot.get(), &der_item, nullptr, nullptr, false, true, KU_KEY_AGREEMENT, 222 &key, nullptr); 223 if (vec.valid) { 224 EXPECT_EQ(SECSuccess, rv) 225 << "unable to load private key DER for test " << vec.id << ": " 226 << PORT_ErrorToString(PORT_GetError()); 227 } 228 229 return ScopedSECKEYPrivateKey(key); 230 } 231 }; 232 233 TEST_F(Pkcs11EcdhTest, P256) { Run("ecdh_secp256r1"); } 234 TEST_F(Pkcs11EcdhTest, P384) { Run("ecdh_secp384r1"); } 235 TEST_F(Pkcs11EcdhTest, P521) { Run("ecdh_secp521r1"); } 236 237 } // namespace nss_test