tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

pk11_hpke_unittest.cc (31593B)


      1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim: set ts=2 et sw=2 tw=80: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
#include <memory>
#include <utility>
#include "blapi.h"
#include "gtest/gtest.h"
#include "json_reader.h"
#include "nss.h"
#include "nss_scoped_ptrs.h"
#include "pk11hpke.h"
#include "pk11pub.h"
#include "secerr.h"
#include "sechash.h"
#include "util.h"
     18 
     19 extern std::string g_source_dir;
     20 
     21 namespace nss_test {
     22 
     23 /* See note in pk11pub.h. */
     24 #include "cpputil.h"
     25 
     26 class HpkeTest {
     27 protected:
     28  void CheckEquality(const std::vector<uint8_t> &expected, SECItem *actual) {
     29    if (!actual) {
     30      EXPECT_TRUE(expected.empty());
     31      return;
     32    }
     33    std::vector<uint8_t> vact(actual->data, actual->data + actual->len);
     34    EXPECT_EQ(expected, vact);
     35  }
     36 
     37  void CheckEquality(SECItem *expected, SECItem *actual) {
     38    EXPECT_EQ(!!expected, !!actual);
     39    if (expected && actual) {
     40      EXPECT_EQ(expected->len, actual->len);
     41      if (expected->len == actual->len) {
     42        EXPECT_EQ(0, memcmp(expected->data, actual->data, actual->len));
     43      }
     44    }
     45  }
     46 
     47  void CheckEquality(const std::vector<uint8_t> &expected, PK11SymKey *actual) {
     48    if (!actual) {
     49      EXPECT_TRUE(expected.empty());
     50      return;
     51    }
     52    SECStatus rv = PK11_ExtractKeyValue(actual);
     53    EXPECT_EQ(SECSuccess, rv);
     54    if (rv != SECSuccess) {
     55      return;
     56    }
     57    SECItem *rawkey = PK11_GetKeyData(actual);
     58    CheckEquality(expected, rawkey);
     59  }
     60 
     61  void CheckEquality(PK11SymKey *expected, PK11SymKey *actual) {
     62    if (!actual || !expected) {
     63      EXPECT_EQ(!!expected, !!actual);
     64      return;
     65    }
     66    SECStatus rv = PK11_ExtractKeyValue(expected);
     67    EXPECT_EQ(SECSuccess, rv);
     68    if (rv != SECSuccess) {
     69      return;
     70    }
     71    SECItem *raw = PK11_GetKeyData(expected);
     72    ASSERT_NE(nullptr, raw);
     73    ASSERT_NE(nullptr, raw->data);
     74    std::vector<uint8_t> expected_vec(raw->data, raw->data + raw->len);
     75    CheckEquality(expected_vec, actual);
     76  }
     77 
     78  void Seal(const ScopedHpkeContext &cx, const std::vector<uint8_t> &aad_vec,
     79            const std::vector<uint8_t> &pt_vec,
     80            std::vector<uint8_t> *out_sealed) {
     81    SECItem aad_item = {siBuffer, toUcharPtr(aad_vec.data()),
     82                        static_cast<unsigned int>(aad_vec.size())};
     83    SECItem pt_item = {siBuffer, toUcharPtr(pt_vec.data()),
     84                       static_cast<unsigned int>(pt_vec.size())};
     85 
     86    SECItem *sealed_item = nullptr;
     87    EXPECT_EQ(SECSuccess,
     88              PK11_HPKE_Seal(cx.get(), &aad_item, &pt_item, &sealed_item));
     89    ASSERT_NE(nullptr, sealed_item);
     90    ScopedSECItem sealed(sealed_item);
     91    out_sealed->assign(sealed->data, sealed->data + sealed->len);
     92  }
     93 
     94  void Open(const ScopedHpkeContext &cx, const std::vector<uint8_t> &aad_vec,
     95            const std::vector<uint8_t> &ct_vec,
     96            std::vector<uint8_t> *out_opened) {
     97    SECItem aad_item = {siBuffer, toUcharPtr(aad_vec.data()),
     98                        static_cast<unsigned int>(aad_vec.size())};
     99    SECItem ct_item = {siBuffer, toUcharPtr(ct_vec.data()),
    100                       static_cast<unsigned int>(ct_vec.size())};
    101    SECItem *opened_item = nullptr;
    102    EXPECT_EQ(SECSuccess,
    103              PK11_HPKE_Open(cx.get(), &aad_item, &ct_item, &opened_item));
    104    ASSERT_NE(nullptr, opened_item);
    105    ScopedSECItem opened(opened_item);
    106    out_opened->assign(opened->data, opened->data + opened->len);
    107  }
    108 
    109  void SealOpen(const ScopedHpkeContext &sender,
    110                const ScopedHpkeContext &receiver,
    111                const std::vector<uint8_t> &msg,
    112                const std::vector<uint8_t> &aad,
    113                const std::vector<uint8_t> *expect) {
    114    std::vector<uint8_t> sealed;
    115    std::vector<uint8_t> opened;
    116    Seal(sender, aad, msg, &sealed);
    117    if (expect) {
    118      EXPECT_EQ(*expect, sealed);
    119    }
    120    Open(receiver, aad, sealed, &opened);
    121    EXPECT_EQ(msg, opened);
    122  }
    123 
    124  void ExportSecret(const ScopedHpkeContext &receiver,
    125                    ScopedPK11SymKey &exported) {
    126    std::vector<uint8_t> context = {'c', 't', 'x', 't'};
    127    SECItem context_item = {siBuffer, context.data(),
    128                            static_cast<unsigned int>(context.size())};
    129    PK11SymKey *tmp_exported = nullptr;
    130    ASSERT_EQ(SECSuccess, PK11_HPKE_ExportSecret(receiver.get(), &context_item,
    131                                                 64, &tmp_exported));
    132    exported.reset(tmp_exported);
    133  }
    134 
    135  void ExportImportRecvContext(ScopedHpkeContext &scoped_cx,
    136                               PK11SymKey *wrapping_key) {
    137    SECItem *tmp_exported = nullptr;
    138    EXPECT_EQ(SECSuccess, PK11_HPKE_ExportContext(scoped_cx.get(), wrapping_key,
    139                                                  &tmp_exported));
    140    EXPECT_NE(nullptr, tmp_exported);
    141    ScopedSECItem context(tmp_exported);
    142    scoped_cx.reset();
    143 
    144    HpkeContext *tmp_imported =
    145        PK11_HPKE_ImportContext(context.get(), wrapping_key);
    146    EXPECT_NE(nullptr, tmp_imported);
    147    scoped_cx.reset(tmp_imported);
    148  }
    149 
    150  bool GenerateKeyPair(ScopedSECKEYPublicKey &pub_key,
    151                       ScopedSECKEYPrivateKey &priv_key) {
    152    ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
    153    if (!slot) {
    154      ADD_FAILURE() << "Couldn't get slot";
    155      return false;
    156    }
    157 
    158    unsigned char param_buf[65];
    159    SECItem ecdsa_params = {siBuffer, param_buf, sizeof(param_buf)};
    160    SECOidData *oid_data = SECOID_FindOIDByTag(SEC_OID_CURVE25519);
    161    if (!oid_data) {
    162      ADD_FAILURE() << "Couldn't get oid_data";
    163      return false;
    164    }
    165    ecdsa_params.data[0] = SEC_ASN1_OBJECT_ID;
    166    ecdsa_params.data[1] = oid_data->oid.len;
    167    memcpy(ecdsa_params.data + 2, oid_data->oid.data, oid_data->oid.len);
    168    ecdsa_params.len = oid_data->oid.len + 2;
    169 
    170    SECKEYPublicKey *pub_tmp;
    171    SECKEYPrivateKey *priv_tmp;
    172    priv_tmp =
    173        PK11_GenerateKeyPair(slot.get(), CKM_EC_KEY_PAIR_GEN, &ecdsa_params,
    174                             &pub_tmp, PR_FALSE, PR_TRUE, nullptr);
    175    if (!pub_tmp || !priv_tmp) {
    176      ADD_FAILURE() << "PK11_GenerateKeyPair failed";
    177      return false;
    178    }
    179 
    180    pub_key.reset(pub_tmp);
    181    priv_key.reset(priv_tmp);
    182    return true;
    183  }
    184 
    185  void SetUpEphemeralContexts(ScopedHpkeContext &sender,
    186                              ScopedHpkeContext &receiver,
    187                              HpkeModeId mode = HpkeModeBase,
    188                              HpkeKemId kem = HpkeDhKemX25519Sha256,
    189                              HpkeKdfId kdf = HpkeKdfHkdfSha256,
    190                              HpkeAeadId aead = HpkeAeadAes128Gcm) {
    191    // Generate a PSK, if the mode calls for it.
    192    PRUint8 psk_id_buf[] = {'p', 's', 'k', '-', 'i', 'd'};
    193    SECItem psk_id = {siBuffer, psk_id_buf, sizeof(psk_id_buf)};
    194    SECItem *psk_id_item = (mode == HpkeModePsk) ? &psk_id : nullptr;
    195    ScopedPK11SymKey psk;
    196    if (mode == HpkeModePsk) {
    197      ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
    198      ASSERT_TRUE(slot);
    199      PK11SymKey *tmp_psk =
    200          PK11_KeyGen(slot.get(), CKM_HKDF_DERIVE, nullptr, 16, nullptr);
    201      ASSERT_NE(nullptr, tmp_psk);
    202      psk.reset(tmp_psk);
    203    }
    204 
    205    std::vector<uint8_t> info = {'t', 'e', 's', 't', '-', 'i', 'n', 'f', 'o'};
    206    SECItem info_item = {siBuffer, info.data(),
    207                         static_cast<unsigned int>(info.size())};
    208    sender.reset(PK11_HPKE_NewContext(kem, kdf, aead, psk.get(), psk_id_item));
    209    receiver.reset(
    210        PK11_HPKE_NewContext(kem, kdf, aead, psk.get(), psk_id_item));
    211    ASSERT_TRUE(sender);
    212    ASSERT_TRUE(receiver);
    213 
    214    ScopedSECKEYPublicKey pub_key_r;
    215    ScopedSECKEYPrivateKey priv_key_r;
    216    ASSERT_TRUE(GenerateKeyPair(pub_key_r, priv_key_r));
    217    EXPECT_EQ(SECSuccess, PK11_HPKE_SetupS(sender.get(), nullptr, nullptr,
    218                                           pub_key_r.get(), &info_item));
    219 
    220    const SECItem *enc = PK11_HPKE_GetEncapPubKey(sender.get());
    221    EXPECT_NE(nullptr, enc);
    222    EXPECT_EQ(SECSuccess, PK11_HPKE_SetupR(
    223                              receiver.get(), pub_key_r.get(), priv_key_r.get(),
    224                              const_cast<SECItem *>(enc), &info_item));
    225  }
    226 };
    227 
    228 struct HpkeEncryptVector {
    229  std::vector<uint8_t> pt;
    230  std::vector<uint8_t> aad;
    231  std::vector<uint8_t> ct;
    232 
    233  static std::vector<HpkeEncryptVector> ReadVec(JsonReader &r) {
    234    std::vector<HpkeEncryptVector> all;
    235 
    236    while (r.NextItemArray()) {
    237      HpkeEncryptVector enc;
    238      while (r.NextItem()) {
    239        std::string n = r.ReadLabel();
    240        if (n == "") {
    241          break;
    242        }
    243        if (n == "plaintext") {
    244          enc.pt = r.ReadHex();
    245        } else if (n == "aad") {
    246          enc.aad = r.ReadHex();
    247        } else if (n == "ciphertext") {
    248          enc.ct = r.ReadHex();
    249        } else {
    250          r.SkipValue();
    251        }
    252      }
    253      all.push_back(enc);
    254    }
    255 
    256    return all;
    257  }
    258 };
    259 
    260 struct HpkeExportVector {
    261  std::vector<uint8_t> ctxt;
    262  size_t len;
    263  std::vector<uint8_t> exported;
    264 
    265  static std::vector<HpkeExportVector> ReadVec(JsonReader &r) {
    266    std::vector<HpkeExportVector> all;
    267 
    268    while (r.NextItemArray()) {
    269      HpkeExportVector exp;
    270      while (r.NextItem()) {
    271        std::string n = r.ReadLabel();
    272        if (n == "") {
    273          break;
    274        }
    275        if (n == "exporter_context") {
    276          exp.ctxt = r.ReadHex();
    277        } else if (n == "L") {
    278          exp.len = r.ReadInt();
    279        } else if (n == "exported_value") {
    280          exp.exported = r.ReadHex();
    281        } else {
    282          r.SkipValue();
    283        }
    284      }
    285      all.push_back(exp);
    286    }
    287 
    288    return all;
    289  }
    290 };
    291 
    292 struct HpkeVector {
    293  uint32_t test_id;
    294  HpkeModeId mode;
    295  HpkeKemId kem_id;
    296  HpkeKdfId kdf_id;
    297  HpkeAeadId aead_id;
    298  std::vector<uint8_t> info;
    299  std::vector<uint8_t> pkcs8_e;
    300  std::vector<uint8_t> pkcs8_r;
    301  std::vector<uint8_t> psk;
    302  std::vector<uint8_t> psk_id;
    303  std::vector<uint8_t> enc;
    304  std::vector<uint8_t> key;
    305  std::vector<uint8_t> nonce;
    306  std::vector<HpkeEncryptVector> encryptions;
    307  std::vector<HpkeExportVector> exports;
    308 
    309  static std::vector<uint8_t> Pkcs8(const std::vector<uint8_t> &sk,
    310                                    const std::vector<uint8_t> &pk) {
    311    // Only X25519 format.
    312    std::vector<uint8_t> v(105);
    313    v.assign({0x30, 0x67, 0x02, 0x01, 0x00, 0x30, 0x14, 0x06, 0x07,
    314              0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01, 0x06, 0x09,
    315              0x2b, 0x06, 0x01, 0x04, 0x01, 0xda, 0x47, 0x0f, 0x01,
    316              0x04, 0x4c, 0x30, 0x4a, 0x02, 0x01, 0x01, 0x04, 0x20});
    317    v.insert(v.end(), sk.begin(), sk.end());
    318    v.insert(v.end(), {0xa1, 0x23, 0x03, 0x21, 0x00});
    319    v.insert(v.end(), pk.begin(), pk.end());
    320    return v;
    321  }
    322 
    323  static std::vector<HpkeVector> Read(JsonReader &r) {
    324    std::vector<HpkeVector> all_tests;
    325    uint32_t test_id = 0;
    326 
    327    while (r.NextItemArray()) {
    328      HpkeVector vec = {0};
    329      uint32_t fields = 0;
    330      enum class RequiredFields {
    331        mode,
    332        kem,
    333        kdf,
    334        aead,
    335        skEm,
    336        skRm,
    337        pkEm,
    338        pkRm,
    339        all
    340      };
    341      std::vector<uint8_t> sk_e, pk_e, sk_r, pk_r;
    342      test_id++;
    343 
    344      while (r.NextItem()) {
    345        std::string n = r.ReadLabel();
    346        if (n == "") {
    347          break;
    348        }
    349        if (n == "mode") {
    350          vec.mode = static_cast<HpkeModeId>(r.ReadInt());
    351          fields |= 1 << static_cast<uint32_t>(RequiredFields::mode);
    352        } else if (n == "kem_id") {
    353          vec.kem_id = static_cast<HpkeKemId>(r.ReadInt());
    354          fields |= 1 << static_cast<uint32_t>(RequiredFields::kem);
    355        } else if (n == "kdf_id") {
    356          vec.kdf_id = static_cast<HpkeKdfId>(r.ReadInt());
    357          fields |= 1 << static_cast<uint32_t>(RequiredFields::kdf);
    358        } else if (n == "aead_id") {
    359          vec.aead_id = static_cast<HpkeAeadId>(r.ReadInt());
    360          fields |= 1 << static_cast<uint32_t>(RequiredFields::aead);
    361        } else if (n == "info") {
    362          vec.info = r.ReadHex();
    363        } else if (n == "skEm") {
    364          sk_e = r.ReadHex();
    365          fields |= 1 << static_cast<uint32_t>(RequiredFields::skEm);
    366        } else if (n == "pkEm") {
    367          pk_e = r.ReadHex();
    368          fields |= 1 << static_cast<uint32_t>(RequiredFields::pkEm);
    369        } else if (n == "skRm") {
    370          sk_r = r.ReadHex();
    371          fields |= 1 << static_cast<uint32_t>(RequiredFields::skRm);
    372        } else if (n == "pkRm") {
    373          pk_r = r.ReadHex();
    374          fields |= 1 << static_cast<uint32_t>(RequiredFields::pkRm);
    375        } else if (n == "psk") {
    376          vec.psk = r.ReadHex();
    377        } else if (n == "psk_id") {
    378          vec.psk_id = r.ReadHex();
    379        } else if (n == "enc") {
    380          vec.enc = r.ReadHex();
    381        } else if (n == "key") {
    382          vec.key = r.ReadHex();
    383        } else if (n == "base_nonce") {
    384          vec.nonce = r.ReadHex();
    385        } else if (n == "encryptions") {
    386          vec.encryptions = HpkeEncryptVector::ReadVec(r);
    387        } else if (n == "exports") {
    388          vec.exports = HpkeExportVector::ReadVec(r);
    389        } else {
    390          r.SkipValue();
    391        }
    392      }
    393 
    394      if (fields != (1 << static_cast<uint32_t>(RequiredFields::all)) - 1) {
    395        std::cerr << "Skipping entry " << test_id << " for missing fields"
    396                  << std::endl;
    397        continue;
    398      }
    399      // Skip modes and configurations we don't support.
    400      if (vec.mode != HpkeModeBase && vec.mode != HpkeModePsk) {
    401        continue;
    402      }
    403      SECStatus rv =
    404          PK11_HPKE_ValidateParameters(vec.kem_id, vec.kdf_id, vec.aead_id);
    405      if (rv != SECSuccess) {
    406        continue;
    407      }
    408 
    409      vec.test_id = test_id;
    410      vec.pkcs8_e = HpkeVector::Pkcs8(sk_e, pk_e);
    411      vec.pkcs8_r = HpkeVector::Pkcs8(sk_r, pk_r);
    412      all_tests.push_back(vec);
    413    }
    414 
    415    return all_tests;
    416  }
    417 };
    418 
    419 class TestVectors : public HpkeTest, public ::testing::Test {
    420  struct Endpoint {
    421    bool init(const HpkeVector &vec, const std::vector<uint8_t> &sk_data) {
    422      ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
    423      if (!slot) {
    424        ADD_FAILURE() << "No slot";
    425        return false;
    426      }
    427 
    428      cx_ = Endpoint::MakeContext(slot, vec);
    429 
    430      SECItem item = {siBuffer, toUcharPtr(sk_data.data()),
    431                      static_cast<unsigned int>(sk_data.size())};
    432      SECKEYPrivateKey *sk = nullptr;
    433      SECStatus rv = PK11_ImportDERPrivateKeyInfoAndReturnKey(
    434          slot.get(), &item, nullptr, nullptr, false, false, KU_ALL, &sk,
    435          nullptr);
    436      if (rv != SECSuccess) {
    437        ADD_FAILURE() << "Failed to import secret";
    438        return false;
    439      }
    440      sk_.reset(sk);
    441      SECKEYPublicKey *pk = SECKEY_ConvertToPublicKey(sk_.get());
    442      pk_.reset(pk);
    443      return cx_ && sk_ && pk_;
    444    }
    445 
    446    static ScopedHpkeContext MakeContext(const ScopedPK11SlotInfo &slot,
    447                                         const HpkeVector &vec) {
    448      ScopedPK11SymKey psk = Endpoint::ReadPsk(slot, vec);
    449      SECItem psk_id_item = {siBuffer, toUcharPtr(vec.psk_id.data()),
    450                             static_cast<unsigned int>(vec.psk_id.size())};
    451      SECItem *psk_id = psk ? &psk_id_item : nullptr;
    452      return ScopedHpkeContext(PK11_HPKE_NewContext(
    453          vec.kem_id, vec.kdf_id, vec.aead_id, psk.get(), psk_id));
    454    }
    455 
    456    static ScopedPK11SymKey ReadPsk(const ScopedPK11SlotInfo &slot,
    457                                    const HpkeVector &vec) {
    458      ScopedPK11SymKey psk;
    459      if (!vec.psk.empty()) {
    460        SECItem psk_item = {siBuffer, toUcharPtr(vec.psk.data()),
    461                            static_cast<unsigned int>(vec.psk.size())};
    462        PK11SymKey *psk_key =
    463            PK11_ImportSymKey(slot.get(), CKM_HKDF_KEY_GEN, PK11_OriginUnwrap,
    464                              CKA_WRAP, &psk_item, nullptr);
    465        EXPECT_NE(nullptr, psk_key);
    466        psk.reset(psk_key);
    467      }
    468      return psk;
    469    }
    470 
    471    ScopedHpkeContext cx_;
    472    ScopedSECKEYPublicKey pk_;
    473    ScopedSECKEYPrivateKey sk_;
    474  };
    475 
    476 protected:
    477  void TestExports(const HpkeVector &vec, const Endpoint &sender,
    478                   const Endpoint &receiver) {
    479    for (auto &exp : vec.exports) {
    480      SECItem context_item = {siBuffer, toUcharPtr(exp.ctxt.data()),
    481                              static_cast<unsigned int>(exp.ctxt.size())};
    482      PK11SymKey *actual_r = nullptr;
    483      PK11SymKey *actual_s = nullptr;
    484      ASSERT_EQ(SECSuccess,
    485                PK11_HPKE_ExportSecret(sender.cx_.get(), &context_item, exp.len,
    486                                       &actual_s));
    487      ASSERT_EQ(SECSuccess,
    488                PK11_HPKE_ExportSecret(receiver.cx_.get(), &context_item,
    489                                       exp.len, &actual_r));
    490      ScopedPK11SymKey scoped_act_s(actual_s);
    491      ScopedPK11SymKey scoped_act_r(actual_r);
    492      CheckEquality(exp.exported, scoped_act_s.get());
    493      CheckEquality(exp.exported, scoped_act_r.get());
    494    }
    495  }
    496 
    497  void TestEncryptions(const HpkeVector &vec, const Endpoint &sender,
    498                       const Endpoint &receiver) {
    499    for (auto &enc : vec.encryptions) {
    500      SealOpen(sender.cx_, receiver.cx_, enc.pt, enc.aad, &enc.ct);
    501    }
    502  }
    503 
    504  void SetupS(const ScopedHpkeContext &cx, const ScopedSECKEYPublicKey &pkE,
    505              const ScopedSECKEYPrivateKey &skE,
    506              const ScopedSECKEYPublicKey &pkR,
    507              const std::vector<uint8_t> &info) {
    508    SECItem info_item = {siBuffer, toUcharPtr(info.data()),
    509                         static_cast<unsigned int>(info.size())};
    510    EXPECT_EQ(SECSuccess, PK11_HPKE_SetupS(cx.get(), pkE.get(), skE.get(),
    511                                           pkR.get(), &info_item));
    512  }
    513 
    514  void SetupR(const ScopedHpkeContext &cx, const ScopedSECKEYPublicKey &pkR,
    515              const ScopedSECKEYPrivateKey &skR,
    516              const std::vector<uint8_t> &enc,
    517              const std::vector<uint8_t> &info) {
    518    SECItem enc_item = {siBuffer, toUcharPtr(enc.data()),
    519                        static_cast<unsigned int>(enc.size())};
    520    SECItem info_item = {siBuffer, toUcharPtr(info.data()),
    521                         static_cast<unsigned int>(info.size())};
    522    EXPECT_EQ(SECSuccess, PK11_HPKE_SetupR(cx.get(), pkR.get(), skR.get(),
    523                                           &enc_item, &info_item));
    524  }
    525 
    526  void SetupSenderReceiver(const HpkeVector &vec, const Endpoint &sender,
    527                           const Endpoint &receiver) {
    528    SetupS(sender.cx_, sender.pk_, sender.sk_, receiver.pk_, vec.info);
    529    uint8_t buf[32];  // Curve25519 only, fixed size.
    530    SECItem encap_item = {siBuffer, const_cast<uint8_t *>(buf), sizeof(buf)};
    531    ASSERT_EQ(SECSuccess, PK11_HPKE_Serialize(sender.pk_.get(), encap_item.data,
    532                                              &encap_item.len, encap_item.len));
    533    CheckEquality(vec.enc, &encap_item);
    534    SetupR(receiver.cx_, receiver.pk_, receiver.sk_, vec.enc, vec.info);
    535  }
    536 
    537  void RunTestVector(const HpkeVector &vec) {
    538    Endpoint sender;
    539    ASSERT_TRUE(sender.init(vec, vec.pkcs8_e));
    540    Endpoint receiver;
    541    ASSERT_TRUE(receiver.init(vec, vec.pkcs8_r));
    542 
    543    SetupSenderReceiver(vec, sender, receiver);
    544    TestEncryptions(vec, sender, receiver);
    545    TestExports(vec, sender, receiver);
    546  }
    547 };
    548 
    549 TEST_F(TestVectors, HpkeVectors) {
    550  JsonReader r(::g_source_dir + "/hpke-vectors.json");
    551  auto all_tests = HpkeVector::Read(r);
    552  for (auto &vec : all_tests) {
    553    std::cout << "HPKE vector " << vec.test_id << std::endl;
    554    RunTestVector(vec);
    555  }
    556 }
    557 
    558 class ModeParameterizedTest
    559    : public HpkeTest,
    560      public ::testing::TestWithParam<
    561          std::tuple<HpkeModeId, HpkeKemId, HpkeKdfId, HpkeAeadId>> {};
    562 
    563 static const HpkeModeId kHpkeModesAll[] = {HpkeModeBase, HpkeModePsk};
    564 static const HpkeKemId kHpkeKemIdsAll[] = {HpkeDhKemX25519Sha256};
    565 static const HpkeKdfId kHpkeKdfIdsAll[] = {HpkeKdfHkdfSha256, HpkeKdfHkdfSha384,
    566                                           HpkeKdfHkdfSha512};
    567 static const HpkeAeadId kHpkeAeadIdsAll[] = {HpkeAeadAes128Gcm,
    568                                             HpkeAeadChaCha20Poly1305};
    569 
    570 INSTANTIATE_TEST_SUITE_P(
    571    Pk11Hpke, ModeParameterizedTest,
    572    ::testing::Combine(::testing::ValuesIn(kHpkeModesAll),
    573                       ::testing::ValuesIn(kHpkeKemIdsAll),
    574                       ::testing::ValuesIn(kHpkeKdfIdsAll),
    575                       ::testing::ValuesIn(kHpkeAeadIdsAll)));
    576 
    577 TEST_F(ModeParameterizedTest, BadEncapsulatedPubKey) {
    578  ScopedHpkeContext sender(
    579      PK11_HPKE_NewContext(HpkeDhKemX25519Sha256, HpkeKdfHkdfSha256,
    580                           HpkeAeadAes128Gcm, nullptr, nullptr));
    581  ScopedHpkeContext receiver(
    582      PK11_HPKE_NewContext(HpkeDhKemX25519Sha256, HpkeKdfHkdfSha256,
    583                           HpkeAeadAes128Gcm, nullptr, nullptr));
    584 
    585  SECItem empty = {siBuffer, nullptr, 0};
    586  uint8_t buf[100];
    587  SECItem short_encap = {siBuffer, buf, 1};
    588  SECItem long_encap = {siBuffer, buf, sizeof(buf)};
    589 
    590  SECKEYPublicKey *tmp_pub_key;
    591  ScopedSECKEYPublicKey pub_key;
    592  ScopedSECKEYPrivateKey priv_key;
    593  ASSERT_TRUE(GenerateKeyPair(pub_key, priv_key));
    594 
    595  // Decapsulating an empty buffer should fail.
    596  EXPECT_EQ(SECFailure, PK11_HPKE_Deserialize(sender.get(), empty.data,
    597                                              empty.len, &tmp_pub_key));
    598  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    599 
    600  // Decapsulating anything short will succeed, but the setup will fail.
    601  EXPECT_EQ(SECSuccess, PK11_HPKE_Deserialize(sender.get(), short_encap.data,
    602                                              short_encap.len, &tmp_pub_key));
    603  ScopedSECKEYPublicKey bad_pub_key(tmp_pub_key);
    604 
    605  EXPECT_EQ(SECFailure,
    606            PK11_HPKE_SetupS(receiver.get(), pub_key.get(), priv_key.get(),
    607                             bad_pub_key.get(), &empty));
    608  EXPECT_EQ(SEC_ERROR_INVALID_KEY, PORT_GetError());
    609 
    610  // Test the same for a receiver.
    611  EXPECT_EQ(SECFailure, PK11_HPKE_SetupR(sender.get(), pub_key.get(),
    612                                         priv_key.get(), &empty, &empty));
    613  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    614  EXPECT_EQ(SECFailure, PK11_HPKE_SetupR(sender.get(), pub_key.get(),
    615                                         priv_key.get(), &short_encap, &empty));
    616  EXPECT_EQ(SEC_ERROR_INVALID_KEY, PORT_GetError());
    617 
    618  // Encapsulated key too long
    619  EXPECT_EQ(SECSuccess, PK11_HPKE_Deserialize(sender.get(), long_encap.data,
    620                                              long_encap.len, &tmp_pub_key));
    621  bad_pub_key.reset(tmp_pub_key);
    622  EXPECT_EQ(SECFailure,
    623            PK11_HPKE_SetupS(receiver.get(), pub_key.get(), priv_key.get(),
    624                             bad_pub_key.get(), &empty));
    625  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    626 
    627  EXPECT_EQ(SECFailure, PK11_HPKE_SetupR(sender.get(), pub_key.get(),
    628                                         priv_key.get(), &long_encap, &empty));
    629  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    630 }
    631 
    632 TEST_P(ModeParameterizedTest, ContextExportImportEncrypt) {
    633  std::vector<uint8_t> msg = {'s', 'e', 'c', 'r', 'e', 't'};
    634  std::vector<uint8_t> aad = {'a', 'a', 'd'};
    635 
    636  ScopedHpkeContext sender;
    637  ScopedHpkeContext receiver;
    638  SetUpEphemeralContexts(sender, receiver, std::get<0>(GetParam()),
    639                         std::get<1>(GetParam()), std::get<2>(GetParam()),
    640                         std::get<3>(GetParam()));
    641  SealOpen(sender, receiver, msg, aad, nullptr);
    642  ExportImportRecvContext(receiver, nullptr);
    643  SealOpen(sender, receiver, msg, aad, nullptr);
    644 }
    645 
    646 TEST_P(ModeParameterizedTest, ContextExportImportExport) {
    647  ScopedHpkeContext sender;
    648  ScopedHpkeContext receiver;
    649  ScopedPK11SymKey sender_export;
    650  ScopedPK11SymKey receiver_export;
    651  ScopedPK11SymKey receiver_reexport;
    652  SetUpEphemeralContexts(sender, receiver, std::get<0>(GetParam()),
    653                         std::get<1>(GetParam()), std::get<2>(GetParam()),
    654                         std::get<3>(GetParam()));
    655  ExportSecret(sender, sender_export);
    656  ExportSecret(receiver, receiver_export);
    657  CheckEquality(sender_export.get(), receiver_export.get());
    658  ExportImportRecvContext(receiver, nullptr);
    659  ExportSecret(receiver, receiver_reexport);
    660  CheckEquality(receiver_export.get(), receiver_reexport.get());
    661 }
    662 
    663 TEST_P(ModeParameterizedTest, ContextExportImportWithWrap) {
    664  std::vector<uint8_t> msg = {'s', 'e', 'c', 'r', 'e', 't'};
    665  std::vector<uint8_t> aad = {'a', 'a', 'd'};
    666 
    667  // Generate a wrapping key, then use it for export.
    668  ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
    669  ASSERT_TRUE(slot);
    670  ScopedPK11SymKey kek(
    671      PK11_KeyGen(slot.get(), CKM_AES_CBC, nullptr, 16, nullptr));
    672  ASSERT_NE(nullptr, kek);
    673 
    674  ScopedHpkeContext sender;
    675  ScopedHpkeContext receiver;
    676  SetUpEphemeralContexts(sender, receiver, std::get<0>(GetParam()),
    677                         std::get<1>(GetParam()), std::get<2>(GetParam()),
    678                         std::get<3>(GetParam()));
    679  SealOpen(sender, receiver, msg, aad, nullptr);
    680  ExportImportRecvContext(receiver, kek.get());
    681  SealOpen(sender, receiver, msg, aad, nullptr);
    682 }
    683 
    684 TEST_P(ModeParameterizedTest, ExportSenderContext) {
    685  std::vector<uint8_t> msg = {'s', 'e', 'c', 'r', 'e', 't'};
    686  std::vector<uint8_t> aad = {'a', 'a', 'd'};
    687 
    688  ScopedHpkeContext sender;
    689  ScopedHpkeContext receiver;
    690  SetUpEphemeralContexts(sender, receiver, std::get<0>(GetParam()),
    691                         std::get<1>(GetParam()), std::get<2>(GetParam()),
    692                         std::get<3>(GetParam()));
    693 
    694  SECItem *tmp_exported = nullptr;
    695  EXPECT_EQ(SECFailure,
    696            PK11_HPKE_ExportContext(sender.get(), nullptr, &tmp_exported));
    697  EXPECT_EQ(nullptr, tmp_exported);
    698  EXPECT_EQ(SEC_ERROR_NOT_A_RECIPIENT, PORT_GetError());
    699 }
    700 
    701 TEST_P(ModeParameterizedTest, ContextUnwrapBadKey) {
    702  std::vector<uint8_t> msg = {'s', 'e', 'c', 'r', 'e', 't'};
    703  std::vector<uint8_t> aad = {'a', 'a', 'd'};
    704 
    705  // Generate a wrapping key, then use it for export.
    706  ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
    707  ASSERT_TRUE(slot);
    708  ScopedPK11SymKey kek(
    709      PK11_KeyGen(slot.get(), CKM_AES_CBC, nullptr, 16, nullptr));
    710  ASSERT_NE(nullptr, kek);
    711  ScopedPK11SymKey not_kek(
    712      PK11_KeyGen(slot.get(), CKM_AES_CBC, nullptr, 16, nullptr));
    713  ASSERT_NE(nullptr, not_kek);
    714  ScopedHpkeContext sender;
    715  ScopedHpkeContext receiver;
    716 
    717  SetUpEphemeralContexts(sender, receiver, std::get<0>(GetParam()),
    718                         std::get<1>(GetParam()), std::get<2>(GetParam()),
    719                         std::get<3>(GetParam()));
    720 
    721  SECItem *tmp_exported = nullptr;
    722  EXPECT_EQ(SECSuccess,
    723            PK11_HPKE_ExportContext(receiver.get(), kek.get(), &tmp_exported));
    724  EXPECT_NE(nullptr, tmp_exported);
    725  ScopedSECItem context(tmp_exported);
    726 
    727  EXPECT_EQ(nullptr, PK11_HPKE_ImportContext(context.get(), not_kek.get()));
    728  EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError());
    729 }
    730 
    731 TEST_P(ModeParameterizedTest, EphemeralKeys) {
    732  std::vector<uint8_t> msg = {'s', 'e', 'c', 'r', 'e', 't'};
    733  std::vector<uint8_t> aad = {'a', 'a', 'd'};
    734  SECItem msg_item = {siBuffer, msg.data(),
    735                      static_cast<unsigned int>(msg.size())};
    736  SECItem aad_item = {siBuffer, aad.data(),
    737                      static_cast<unsigned int>(aad.size())};
    738  ScopedHpkeContext sender;
    739  ScopedHpkeContext receiver;
    740  SetUpEphemeralContexts(sender, receiver, std::get<0>(GetParam()),
    741                         std::get<1>(GetParam()), std::get<2>(GetParam()),
    742                         std::get<3>(GetParam()));
    743 
    744  SealOpen(sender, receiver, msg, aad, nullptr);
    745 
    746  // Seal for negative tests
    747  SECItem *tmp_sealed = nullptr;
    748  SECItem *tmp_unsealed = nullptr;
    749  EXPECT_EQ(SECSuccess,
    750            PK11_HPKE_Seal(sender.get(), &aad_item, &msg_item, &tmp_sealed));
    751  ASSERT_NE(nullptr, tmp_sealed);
    752  ScopedSECItem sealed(tmp_sealed);
    753 
    754  // Drop AAD
    755  EXPECT_EQ(SECFailure, PK11_HPKE_Open(receiver.get(), nullptr, sealed.get(),
    756                                       &tmp_unsealed));
    757  EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError());
    758  EXPECT_EQ(nullptr, tmp_unsealed);
    759 
    760  // Modify AAD
    761  aad_item.data[0] ^= 0xff;
    762  EXPECT_EQ(SECFailure, PK11_HPKE_Open(receiver.get(), &aad_item, sealed.get(),
    763                                       &tmp_unsealed));
    764  EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError());
    765  EXPECT_EQ(nullptr, tmp_unsealed);
    766  aad_item.data[0] ^= 0xff;
    767 
    768  // Modify ciphertext
    769  sealed->data[0] ^= 0xff;
    770  EXPECT_EQ(SECFailure, PK11_HPKE_Open(receiver.get(), &aad_item, sealed.get(),
    771                                       &tmp_unsealed));
    772  EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError());
    773  EXPECT_EQ(nullptr, tmp_unsealed);
    774  sealed->data[0] ^= 0xff;
    775 
    776  EXPECT_EQ(SECSuccess, PK11_HPKE_Open(receiver.get(), &aad_item, sealed.get(),
    777                                       &tmp_unsealed));
    778  EXPECT_NE(nullptr, tmp_unsealed);
    779  ScopedSECItem unsealed(tmp_unsealed);
    780  CheckEquality(&msg_item, unsealed.get());
    781 }
    782 
    783 TEST_F(ModeParameterizedTest, InvalidContextParams) {
    784  HpkeContext *cx =
    785      PK11_HPKE_NewContext(static_cast<HpkeKemId>(0xff), HpkeKdfHkdfSha256,
    786                           HpkeAeadChaCha20Poly1305, nullptr, nullptr);
    787  EXPECT_EQ(nullptr, cx);
    788  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    789 
    790  cx = PK11_HPKE_NewContext(HpkeDhKemX25519Sha256, static_cast<HpkeKdfId>(0xff),
    791                            HpkeAeadChaCha20Poly1305, nullptr, nullptr);
    792  EXPECT_EQ(nullptr, cx);
    793  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    794  cx = PK11_HPKE_NewContext(HpkeDhKemX25519Sha256, HpkeKdfHkdfSha256,
    795                            static_cast<HpkeAeadId>(0xff), nullptr, nullptr);
    796  EXPECT_EQ(nullptr, cx);
    797  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
    798 }
    799 
    800 TEST_F(ModeParameterizedTest, InvalidReceiverKeyType) {
    801  ScopedHpkeContext sender(
    802      PK11_HPKE_NewContext(HpkeDhKemX25519Sha256, HpkeKdfHkdfSha256,
    803                           HpkeAeadChaCha20Poly1305, nullptr, nullptr));
    804  ASSERT_TRUE(!!sender);
    805 
    806  ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
    807  if (!slot) {
    808    ADD_FAILURE() << "No slot";
    809    return;
    810  }
    811 
    812  // Give the client an RSA key
    813  PK11RSAGenParams rsa_param;
    814  rsa_param.keySizeInBits = 1024;
    815  rsa_param.pe = 65537L;
    816  SECKEYPublicKey *pub_tmp;
    817  ScopedSECKEYPublicKey pub_key;
    818  ScopedSECKEYPrivateKey priv_key(
    819      PK11_GenerateKeyPair(slot.get(), CKM_RSA_PKCS_KEY_PAIR_GEN, &rsa_param,
    820                           &pub_tmp, PR_FALSE, PR_FALSE, nullptr));
    821  ASSERT_NE(nullptr, priv_key);
    822  ASSERT_NE(nullptr, pub_tmp);
    823  pub_key.reset(pub_tmp);
    824 
    825  SECItem info_item = {siBuffer, nullptr, 0};
    826  EXPECT_EQ(SECFailure, PK11_HPKE_SetupS(sender.get(), nullptr, nullptr,
    827                                         pub_key.get(), &info_item));
    828  EXPECT_EQ(SEC_ERROR_BAD_KEY, PORT_GetError());
    829 
    830  // Try with an unexpected curve
    831  StackSECItem ecParams;
    832  SECOidData *oidData = SECOID_FindOIDByTag(SEC_OID_ANSIX962_EC_PRIME256V1);
    833  ASSERT_NE(oidData, nullptr);
    834  if (!SECITEM_AllocItem(nullptr, &ecParams, (2 + oidData->oid.len))) {
    835    FAIL() << "Couldn't allocate memory for OID.";
    836  }
    837  ecParams.data[0] = SEC_ASN1_OBJECT_ID;
    838  ecParams.data[1] = oidData->oid.len;
    839  memcpy(ecParams.data + 2, oidData->oid.data, oidData->oid.len);
    840 
    841  priv_key.reset(PK11_GenerateKeyPair(slot.get(), CKM_EC_KEY_PAIR_GEN,
    842                                      &ecParams, &pub_tmp, PR_FALSE, PR_FALSE,
    843                                      nullptr));
    844  ASSERT_NE(nullptr, priv_key);
    845  ASSERT_NE(nullptr, pub_tmp);
    846  pub_key.reset(pub_tmp);
    847  EXPECT_EQ(SECFailure, PK11_HPKE_SetupS(sender.get(), nullptr, nullptr,
    848                                         pub_key.get(), &info_item));
    849  EXPECT_EQ(SEC_ERROR_BAD_KEY, PORT_GetError());
    850 }
    851 
    852 }  // namespace nss_test