tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

kyber-pqcrystals-ref.c (90148B)


      1 /*
      2 * SPDX-License-Identifier: Apache-2.0
      3 *
      4 * This file was generated from
      5 *   https://github.com/pq-crystals/kyber/commit/e0d1c6ff
      6 *
      7 * Files from that repository are listed here surrounded by
      8 * "* begin: [file] *" and "* end: [file] *" comments.
      9 *
     10 * The following changes have been made:
     11 *  - include guards have been removed,
     12 *  - include directives have been removed,
     13 *  - "#ifdef KYBER90S" blocks have been evaluated with "KYBER90S" undefined,
     14 *  - functions outside of kem.c have been made static.
     15 */
     16 
     17 /** begin: ref/LICENSE **
     18 Public Domain (https://creativecommons.org/share-your-work/public-domain/cc0/);
     19 or Apache 2.0 License (https://www.apache.org/licenses/LICENSE-2.0.html).
     20 
     21 For Keccak and AES we are using public-domain
     22 code from sources and by authors listed in
     23 comments on top of the respective files.
     24 ** end: ref/LICENSE **/
     25 
     26 /** begin: ref/AUTHORS **
     27 Joppe Bos,
     28 Léo Ducas,
     29 Eike Kiltz,
     30 Tancrède Lepoint,
     31 Vadim Lyubashevsky,
     32 John Schanck,
     33 Peter Schwabe,
     34 Gregor Seiler,
     35 Damien Stehlé
     36 ** end: ref/AUTHORS **/
     37 
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
     42 
     43 #ifdef FREEBL_NO_DEPEND
     44 #include "stubs.h"
     45 #endif
     46 
     47 #include "secport.h"
     48 
     49 // We need to provide an implementation of randombytes to avoid an unused
     50 // function warning. We don't use the randomized API in freebl, so we'll make
     51 // calling randombytes an error.
     52 static void
     53 randombytes(uint8_t *out, size_t outlen)
     54 {
     55    // this memset is to avoid "maybe-uninitialized" warnings that gcc-11 issues
     56    // for the (unused) crypto_kem_keypair and crypto_kem_enc functions.
     57    memset(out, 0, outlen);
     58    assert(0);
     59 }
     60 
     61 /*************************************************
     62 * Name:        verify
     63 *
     64 * Description: Compare two arrays for equality in constant time.
     65 *
     66 * Arguments:   const uint8_t *a: pointer to first byte array
     67 *              const uint8_t *b: pointer to second byte array
     68 *              size_t len:       length of the byte arrays
     69 *
     70 * Returns 0 if the byte arrays are equal, 1 otherwise
     71 **************************************************/
     72 static int
     73 verify(const uint8_t *a, const uint8_t *b, size_t len)
     74 {
     75    return NSS_SecureMemcmp(a, b, len);
     76 }
     77 
     78 /*************************************************
     79 * Name:        cmov
     80 *
     81 * Description: Copy len bytes from x to r if b is 1;
     82 *              don't modify x if b is 0. Requires b to be in {0,1};
     83 *              assumes two's complement representation of negative integers.
     84 *              Runs in constant time.
     85 *
     86 * Arguments:   uint8_t *r:       pointer to output byte array
     87 *              const uint8_t *x: pointer to input byte array
     88 *              size_t len:       Amount of bytes to be copied
     89 *              uint8_t b:        Condition bit; has to be in {0,1}
     90 **************************************************/
     91 static void
     92 cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b)
     93 {
     94    NSS_SecureSelect(r, r, x, len, b);
     95 }
     96 
     97 /** begin: ref/params.h **/
     98 #ifndef KYBER_K
     99 #define KYBER_K 3 /* Change this for different security strengths */
    100 #endif
    101 
    102 //#define KYBER_90S	/* Uncomment this if you want the 90S variant */
    103 
    104 /* Don't change parameters below this line */
    105 #if (KYBER_K == 2)
    106 #define KYBER_NAMESPACE(s) pqcrystals_kyber512_ref_##s
    107 #elif (KYBER_K == 3)
    108 #define KYBER_NAMESPACE(s) pqcrystals_kyber768_ref_##s
    109 #elif (KYBER_K == 4)
    110 #define KYBER_NAMESPACE(s) pqcrystals_kyber1024_ref_##s
    111 #else
    112 #error "KYBER_K must be in {2,3,4}"
    113 #endif
    114 
    115 #define KYBER_N 256
    116 #define KYBER_Q 3329
    117 
    118 #define KYBER_SYMBYTES 32 /* size in bytes of hashes, and seeds */
    119 #define KYBER_SSBYTES 32  /* size in bytes of shared key */
    120 
    121 #define KYBER_POLYBYTES 384
    122 #define KYBER_POLYVECBYTES (KYBER_K * KYBER_POLYBYTES)
    123 
    124 #if KYBER_K == 2
    125 #define KYBER_ETA1 3
    126 #define KYBER_POLYCOMPRESSEDBYTES 128
    127 #define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 320)
    128 #elif KYBER_K == 3
    129 #define KYBER_ETA1 2
    130 #define KYBER_POLYCOMPRESSEDBYTES 128
    131 #define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 320)
    132 #elif KYBER_K == 4
    133 #define KYBER_ETA1 2
    134 #define KYBER_POLYCOMPRESSEDBYTES 160
    135 #define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 352)
    136 #endif
    137 
    138 #define KYBER_ETA2 2
    139 
    140 #define KYBER_INDCPA_MSGBYTES (KYBER_SYMBYTES)
    141 #define KYBER_INDCPA_PUBLICKEYBYTES (KYBER_POLYVECBYTES + KYBER_SYMBYTES)
    142 #define KYBER_INDCPA_SECRETKEYBYTES (KYBER_POLYVECBYTES)
    143 #define KYBER_INDCPA_BYTES (KYBER_POLYVECCOMPRESSEDBYTES + KYBER_POLYCOMPRESSEDBYTES)
    144 
    145 #define KYBER_PUBLICKEYBYTES (KYBER_INDCPA_PUBLICKEYBYTES)
    146 /* 32 bytes of additional space to save H(pk) */
    147 #define KYBER_SECRETKEYBYTES (KYBER_INDCPA_SECRETKEYBYTES + KYBER_INDCPA_PUBLICKEYBYTES + 2 * KYBER_SYMBYTES)
    148 #define KYBER_CIPHERTEXTBYTES (KYBER_INDCPA_BYTES)
    149 /** end: ref/params.h **/
    150 
/** begin: ref/reduce.h **/
/* Montgomery constants for q = 3329 and R = 2^16 (centered representatives). */
#define MONT -1044 // 2^16 mod q
#define QINV -3327 // q^-1 mod 2^16

/* Each helper is renamed through KYBER_NAMESPACE so its symbol carries the
 * parameter-set prefix; the static prototypes allow use before definition. */
#define montgomery_reduce KYBER_NAMESPACE(montgomery_reduce)
static int16_t montgomery_reduce(int32_t a);

#define barrett_reduce KYBER_NAMESPACE(barrett_reduce)
static int16_t barrett_reduce(int16_t a);
/** end: ref/reduce.h **/
    161 
/** begin: ref/ntt.h **/
/* Twiddle factors in Montgomery form, stored in bit-reversed order (see the
 * generator comment above the definition). Unlike the functions, this table
 * keeps external linkage. */
#define zetas KYBER_NAMESPACE(zetas)
extern const int16_t zetas[128];

/* In-place forward NTT; output is in bit-reversed order. */
#define ntt KYBER_NAMESPACE(ntt)
static void ntt(int16_t poly[256]);

/* In-place inverse NTT, including the final Montgomery-factor scaling. */
#define invntt KYBER_NAMESPACE(invntt)
static void invntt(int16_t poly[256]);

/* Product of two degree-1 polynomials modulo (X^2 - zeta). */
#define basemul KYBER_NAMESPACE(basemul)
static void basemul(int16_t r[2], const int16_t a[2], const int16_t b[2], int16_t zeta);
/** end: ref/ntt.h **/
    175 
/** begin: ref/poly.h **/
/*
* Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial
* coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1]
*/
typedef struct {
    int16_t coeffs[KYBER_N];
} poly;

/* Compressed serialization (KYBER_POLYCOMPRESSEDBYTES bytes) and its
 * approximate inverse. */
#define poly_compress KYBER_NAMESPACE(poly_compress)
static void poly_compress(uint8_t r[KYBER_POLYCOMPRESSEDBYTES], const poly *a);
#define poly_decompress KYBER_NAMESPACE(poly_decompress)
static void poly_decompress(poly *r, const uint8_t a[KYBER_POLYCOMPRESSEDBYTES]);

/* Uncompressed serialization (KYBER_POLYBYTES bytes). */
#define poly_tobytes KYBER_NAMESPACE(poly_tobytes)
static void poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a);
#define poly_frombytes KYBER_NAMESPACE(poly_frombytes)
static void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]);

/* Conversion between a KYBER_INDCPA_MSGBYTES message and a polynomial. */
#define poly_frommsg KYBER_NAMESPACE(poly_frommsg)
static void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]);
#define poly_tomsg KYBER_NAMESPACE(poly_tomsg)
static void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *r);

/* Noise sampling from a KYBER_SYMBYTES seed and a one-byte nonce. */
#define poly_getnoise_eta1 KYBER_NAMESPACE(poly_getnoise_eta1)
static void poly_getnoise_eta1(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce);

#define poly_getnoise_eta2 KYBER_NAMESPACE(poly_getnoise_eta2)
static void poly_getnoise_eta2(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce);

/* NTT-domain transforms, multiplication and Montgomery conversion. */
#define poly_ntt KYBER_NAMESPACE(poly_ntt)
static void poly_ntt(poly *r);
#define poly_invntt_tomont KYBER_NAMESPACE(poly_invntt_tomont)
static void poly_invntt_tomont(poly *r);
#define poly_basemul_montgomery KYBER_NAMESPACE(poly_basemul_montgomery)
static void poly_basemul_montgomery(poly *r, const poly *a, const poly *b);
#define poly_tomont KYBER_NAMESPACE(poly_tomont)
static void poly_tomont(poly *r);

/* Coefficient-wise reduction. */
#define poly_reduce KYBER_NAMESPACE(poly_reduce)
static void poly_reduce(poly *r);

/* Coefficient-wise addition and subtraction. */
#define poly_add KYBER_NAMESPACE(poly_add)
static void poly_add(poly *r, const poly *a, const poly *b);
#define poly_sub KYBER_NAMESPACE(poly_sub)
static void poly_sub(poly *r, const poly *a, const poly *b);
/** end: ref/poly.h **/
    223 
/** begin: ref/cbd.h **/
/* Centered-binomial samplers; each consumes eta * KYBER_N / 4 input bytes. */
#define poly_cbd_eta1 KYBER_NAMESPACE(poly_cbd_eta1)
static void poly_cbd_eta1(poly *r, const uint8_t buf[KYBER_ETA1 * KYBER_N / 4]);

#define poly_cbd_eta2 KYBER_NAMESPACE(poly_cbd_eta2)
static void poly_cbd_eta2(poly *r, const uint8_t buf[KYBER_ETA2 * KYBER_N / 4]);
/** end: ref/cbd.h **/
    231 
/** begin: ref/polyvec.h **/
/* Vector of KYBER_K polynomials (an element of R_q^k). */
typedef struct {
    poly vec[KYBER_K];
} polyvec;

/* Compressed serialization (KYBER_POLYVECCOMPRESSEDBYTES bytes). */
#define polyvec_compress KYBER_NAMESPACE(polyvec_compress)
static void polyvec_compress(uint8_t r[KYBER_POLYVECCOMPRESSEDBYTES], const polyvec *a);
#define polyvec_decompress KYBER_NAMESPACE(polyvec_decompress)
static void polyvec_decompress(polyvec *r, const uint8_t a[KYBER_POLYVECCOMPRESSEDBYTES]);

/* Uncompressed serialization (KYBER_POLYVECBYTES bytes). */
#define polyvec_tobytes KYBER_NAMESPACE(polyvec_tobytes)
static void polyvec_tobytes(uint8_t r[KYBER_POLYVECBYTES], const polyvec *a);
#define polyvec_frombytes KYBER_NAMESPACE(polyvec_frombytes)
static void polyvec_frombytes(polyvec *r, const uint8_t a[KYBER_POLYVECBYTES]);

/* Element-wise forward / inverse NTT. */
#define polyvec_ntt KYBER_NAMESPACE(polyvec_ntt)
static void polyvec_ntt(polyvec *r);
#define polyvec_invntt_tomont KYBER_NAMESPACE(polyvec_invntt_tomont)
static void polyvec_invntt_tomont(polyvec *r);

/* Inner product of two vectors in the NTT domain, written to a single poly. */
#define polyvec_basemul_acc_montgomery KYBER_NAMESPACE(polyvec_basemul_acc_montgomery)
static void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b);

#define polyvec_reduce KYBER_NAMESPACE(polyvec_reduce)
static void polyvec_reduce(polyvec *r);

#define polyvec_add KYBER_NAMESPACE(polyvec_add)
static void polyvec_add(polyvec *r, const polyvec *a, const polyvec *b);
/** end: ref/polyvec.h **/
    261 
/** begin: ref/indcpa.h **/
/* Expands a KYBER_K x KYBER_K matrix of polynomials from a seed; the
 * `transposed` flag presumably selects the transposed matrix — confirm in
 * indcpa.c. */
#define gen_matrix KYBER_NAMESPACE(gen_matrix)
static void gen_matrix(polyvec *a, const uint8_t seed[KYBER_SYMBYTES], int transposed);

/* IND-CPA key generation; all randomness comes from the explicit coins. */
#define indcpa_keypair_derand KYBER_NAMESPACE(indcpa_keypair_derand)
static void indcpa_keypair_derand(uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES],
                                  uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES],
                                  const uint8_t coins[KYBER_SYMBYTES]);

/* IND-CPA encryption of a KYBER_INDCPA_MSGBYTES message using explicit coins. */
#define indcpa_enc KYBER_NAMESPACE(indcpa_enc)
static void indcpa_enc(uint8_t c[KYBER_INDCPA_BYTES],
                       const uint8_t m[KYBER_INDCPA_MSGBYTES],
                       const uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES],
                       const uint8_t coins[KYBER_SYMBYTES]);

/* IND-CPA decryption of a KYBER_INDCPA_BYTES ciphertext. */
#define indcpa_dec KYBER_NAMESPACE(indcpa_dec)
static void indcpa_dec(uint8_t m[KYBER_INDCPA_MSGBYTES],
                       const uint8_t c[KYBER_INDCPA_BYTES],
                       const uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES]);
/** end: ref/indcpa.h **/
    282 
/** begin: ref/fips202.h **/
/* Keccak sponge rates in bytes for each FIPS 202 function. */
#define SHAKE128_RATE 168
#define SHAKE256_RATE 136
#define SHA3_256_RATE 136
#define SHA3_512_RATE 72

/* Symbol prefix for the SHA-3/SHAKE primitives. Unlike the Kyber helpers,
 * these prototypes are not static and no definitions appear in this part of
 * the file. NOTE(review): presumably they are provided by another
 * translation unit — confirm. */
#define FIPS202_NAMESPACE(s) pqcrystals_kyber_fips202_ref_##s

/* Incremental sponge state: 25 64-bit Keccak lanes plus `pos`, presumably
 * the byte offset within the current block — confirm in fips202.c. */
typedef struct {
    uint64_t s[25];
    unsigned int pos;
} keccak_state;

/* SHAKE128 incremental (init/absorb/finalize/squeeze) and block-wise API. */
#define shake128_init FIPS202_NAMESPACE(shake128_init)
void shake128_init(keccak_state *state);
#define shake128_absorb FIPS202_NAMESPACE(shake128_absorb)
void shake128_absorb(keccak_state *state, const uint8_t *in, size_t inlen);
#define shake128_finalize FIPS202_NAMESPACE(shake128_finalize)
void shake128_finalize(keccak_state *state);
#define shake128_squeeze FIPS202_NAMESPACE(shake128_squeeze)
void shake128_squeeze(uint8_t *out, size_t outlen, keccak_state *state);
#define shake128_absorb_once FIPS202_NAMESPACE(shake128_absorb_once)
void shake128_absorb_once(keccak_state *state, const uint8_t *in, size_t inlen);
#define shake128_squeezeblocks FIPS202_NAMESPACE(shake128_squeezeblocks)
void shake128_squeezeblocks(uint8_t *out, size_t nblocks, keccak_state *state);

/* SHAKE256 counterparts of the functions above. */
#define shake256_init FIPS202_NAMESPACE(shake256_init)
void shake256_init(keccak_state *state);
#define shake256_absorb FIPS202_NAMESPACE(shake256_absorb)
void shake256_absorb(keccak_state *state, const uint8_t *in, size_t inlen);
#define shake256_finalize FIPS202_NAMESPACE(shake256_finalize)
void shake256_finalize(keccak_state *state);
#define shake256_squeeze FIPS202_NAMESPACE(shake256_squeeze)
void shake256_squeeze(uint8_t *out, size_t outlen, keccak_state *state);
#define shake256_absorb_once FIPS202_NAMESPACE(shake256_absorb_once)
void shake256_absorb_once(keccak_state *state, const uint8_t *in, size_t inlen);
#define shake256_squeezeblocks FIPS202_NAMESPACE(shake256_squeezeblocks)
void shake256_squeezeblocks(uint8_t *out, size_t nblocks, keccak_state *state);

/* One-shot functions. */
#define shake128 FIPS202_NAMESPACE(shake128)
void shake128(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen);
#define shake256 FIPS202_NAMESPACE(shake256)
void shake256(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen);
#define sha3_256 FIPS202_NAMESPACE(sha3_256)
void sha3_256(uint8_t h[32], const uint8_t *in, size_t inlen);
#define sha3_512 FIPS202_NAMESPACE(sha3_512)
void sha3_512(uint8_t h[64], const uint8_t *in, size_t inlen);
/** end: ref/fips202.h **/
    331 
/** begin: ref/symmetric.h **/
/* The extendable-output function state is a plain Keccak sponge state. */
typedef keccak_state xof_state;

/* Absorbs a KYBER_SYMBYTES seed together with two index bytes x and y. */
#define kyber_shake128_absorb KYBER_NAMESPACE(kyber_shake128_absorb)
static void kyber_shake128_absorb(keccak_state *s,
                                  const uint8_t seed[KYBER_SYMBYTES],
                                  uint8_t x,
                                  uint8_t y);

/* Pseudorandom function keyed by a KYBER_SYMBYTES key and a one-byte nonce. */
#define kyber_shake256_prf KYBER_NAMESPACE(kyber_shake256_prf)
static void kyber_shake256_prf(uint8_t *out, size_t outlen, const uint8_t key[KYBER_SYMBYTES], uint8_t nonce);

#define XOF_BLOCKBYTES SHAKE128_RATE

/* Binding of Kyber's abstract symmetric primitives to FIPS 202:
 *   H = SHA3-256, G = SHA3-512, XOF = SHAKE128, PRF/KDF = SHAKE256. */
#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES)
#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES)
#define xof_absorb(STATE, SEED, X, Y) kyber_shake128_absorb(STATE, SEED, X, Y)
#define xof_squeezeblocks(OUT, OUTBLOCKS, STATE) shake128_squeezeblocks(OUT, OUTBLOCKS, STATE)
#define prf(OUT, OUTBYTES, KEY, NONCE) kyber_shake256_prf(OUT, OUTBYTES, KEY, NONCE)
#define kdf(OUT, IN, INBYTES) shake256(OUT, KYBER_SSBYTES, IN, INBYTES)
/** end: ref/symmetric.h **/
    353 
/** begin: ref/kem.h **/
/* Public KEM sizes in bytes. */
#define CRYPTO_SECRETKEYBYTES KYBER_SECRETKEYBYTES
#define CRYPTO_PUBLICKEYBYTES KYBER_PUBLICKEYBYTES
#define CRYPTO_CIPHERTEXTBYTES KYBER_CIPHERTEXTBYTES
#define CRYPTO_BYTES KYBER_SSBYTES

#if (KYBER_K == 2)
#define CRYPTO_ALGNAME "Kyber512"
#elif (KYBER_K == 3)
#define CRYPTO_ALGNAME "Kyber768"
#elif (KYBER_K == 4)
#define CRYPTO_ALGNAME "Kyber1024"
#endif

/* KEM entry points. Per the file header, kem.c functions keep external
 * linkage (everything else was made static). The *_derand variants take
 * their randomness from `coins`; the plain variants presumably obtain it
 * from randombytes(), which freebl turns into a hard error — confirm in
 * kem.c. */
#define crypto_kem_keypair_derand KYBER_NAMESPACE(keypair_derand)
int crypto_kem_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins);

#define crypto_kem_keypair KYBER_NAMESPACE(keypair)
int crypto_kem_keypair(uint8_t *pk, uint8_t *sk);

#define crypto_kem_enc_derand KYBER_NAMESPACE(enc_derand)
int crypto_kem_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins);

#define crypto_kem_enc KYBER_NAMESPACE(enc)
int crypto_kem_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk);

#define crypto_kem_dec KYBER_NAMESPACE(dec)
int crypto_kem_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk);
/** end: ref/kem.h **/
    383 
    384 /** begin: ref/reduce.c **/
    385 /*************************************************
    386 * Name:        montgomery_reduce
    387 *
    388 * Description: Montgomery reduction; given a 32-bit integer a, computes
    389 *              16-bit integer congruent to a * R^-1 mod q, where R=2^16
    390 *
    391 * Arguments:   - int32_t a: input integer to be reduced;
    392 *                           has to be in {-q2^15,...,q2^15-1}
    393 *
    394 * Returns:     integer in {-q+1,...,q-1} congruent to a * R^-1 modulo q.
    395 **************************************************/
    396 static int16_t
    397 montgomery_reduce(int32_t a)
    398 {
    399    int16_t t;
    400 
    401    t = (int16_t)a * QINV;
    402    t = (a - (int32_t)t * KYBER_Q) >> 16;
    403    return t;
    404 }
    405 
    406 /*************************************************
    407 * Name:        barrett_reduce
    408 *
    409 * Description: Barrett reduction; given a 16-bit integer a, computes
    410 *              centered representative congruent to a mod q in {-(q-1)/2,...,(q-1)/2}
    411 *
    412 * Arguments:   - int16_t a: input integer to be reduced
    413 *
    414 * Returns:     integer in {-(q-1)/2,...,(q-1)/2} congruent to a modulo q.
    415 **************************************************/
    416 static int16_t
    417 barrett_reduce(int16_t a)
    418 {
    419    int16_t t;
    420    const int16_t v = ((1 << 26) + KYBER_Q / 2) / KYBER_Q;
    421 
    422    t = ((int32_t)v * a + (1 << 25)) >> 26;
    423    t *= KYBER_Q;
    424    return a - t;
    425 }
    426 /** end: ref/reduce.c **/
    427 
    428 /** begin: ref/cbd.c **/
    429 /*************************************************
    430 * Name:        load32_littleendian
    431 *
    432 * Description: load 4 bytes into a 32-bit integer
    433 *              in little-endian order
    434 *
    435 * Arguments:   - const uint8_t *x: pointer to input byte array
    436 *
    437 * Returns 32-bit unsigned integer loaded from x
    438 **************************************************/
    439 static uint32_t
    440 load32_littleendian(const uint8_t x[4])
    441 {
    442    uint32_t r;
    443    r = (uint32_t)x[0];
    444    r |= (uint32_t)x[1] << 8;
    445    r |= (uint32_t)x[2] << 16;
    446    r |= (uint32_t)x[3] << 24;
    447    return r;
    448 }
    449 
    450 /*************************************************
    451 * Name:        load24_littleendian
    452 *
    453 * Description: load 3 bytes into a 32-bit integer
    454 *              in little-endian order.
    455 *              This function is only needed for Kyber-512
    456 *
    457 * Arguments:   - const uint8_t *x: pointer to input byte array
    458 *
    459 * Returns 32-bit unsigned integer loaded from x (most significant byte is zero)
    460 **************************************************/
    461 #if KYBER_ETA1 == 3
    462 static uint32_t
    463 load24_littleendian(const uint8_t x[3])
    464 {
    465    uint32_t r;
    466    r = (uint32_t)x[0];
    467    r |= (uint32_t)x[1] << 8;
    468    r |= (uint32_t)x[2] << 16;
    469    return r;
    470 }
    471 #endif
    472 
    473 /*************************************************
    474 * Name:        cbd2
    475 *
    476 * Description: Given an array of uniformly random bytes, compute
    477 *              polynomial with coefficients distributed according to
    478 *              a centered binomial distribution with parameter eta=2
    479 *
    480 * Arguments:   - poly *r: pointer to output polynomial
    481 *              - const uint8_t *buf: pointer to input byte array
    482 **************************************************/
    483 static void
    484 cbd2(poly *r, const uint8_t buf[2 * KYBER_N / 4])
    485 {
    486    unsigned int i, j;
    487    uint32_t t, d;
    488    int16_t a, b;
    489 
    490    for (i = 0; i < KYBER_N / 8; i++) {
    491        t = load32_littleendian(buf + 4 * i);
    492        d = t & 0x55555555;
    493        d += (t >> 1) & 0x55555555;
    494 
    495        for (j = 0; j < 8; j++) {
    496            a = (d >> (4 * j + 0)) & 0x3;
    497            b = (d >> (4 * j + 2)) & 0x3;
    498            r->coeffs[8 * i + j] = a - b;
    499        }
    500    }
    501 }
    502 
    503 /*************************************************
    504 * Name:        cbd3
    505 *
    506 * Description: Given an array of uniformly random bytes, compute
    507 *              polynomial with coefficients distributed according to
    508 *              a centered binomial distribution with parameter eta=3.
    509 *              This function is only needed for Kyber-512
    510 *
    511 * Arguments:   - poly *r: pointer to output polynomial
    512 *              - const uint8_t *buf: pointer to input byte array
    513 **************************************************/
    514 #if KYBER_ETA1 == 3
    515 static void
    516 cbd3(poly *r, const uint8_t buf[3 * KYBER_N / 4])
    517 {
    518    unsigned int i, j;
    519    uint32_t t, d;
    520    int16_t a, b;
    521 
    522    for (i = 0; i < KYBER_N / 4; i++) {
    523        t = load24_littleendian(buf + 3 * i);
    524        d = t & 0x00249249;
    525        d += (t >> 1) & 0x00249249;
    526        d += (t >> 2) & 0x00249249;
    527 
    528        for (j = 0; j < 4; j++) {
    529            a = (d >> (6 * j + 0)) & 0x7;
    530            b = (d >> (6 * j + 3)) & 0x7;
    531            r->coeffs[4 * i + j] = a - b;
    532        }
    533    }
    534 }
    535 #endif
    536 
    537 static void
    538 poly_cbd_eta1(poly *r, const uint8_t buf[KYBER_ETA1 * KYBER_N / 4])
    539 {
    540 #if KYBER_ETA1 == 2
    541    cbd2(r, buf);
    542 #elif KYBER_ETA1 == 3
    543    cbd3(r, buf);
    544 #else
    545 #error "This implementation requires eta1 in {2,3}"
    546 #endif
    547 }
    548 
    549 static void
    550 poly_cbd_eta2(poly *r, const uint8_t buf[KYBER_ETA2 * KYBER_N / 4])
    551 {
    552 #if KYBER_ETA2 == 2
    553    cbd2(r, buf);
    554 #else
    555 #error "This implementation requires eta2 = 2"
    556 #endif
    557 }
    558 /** end: ref/cbd.c **/
    559 
/** begin: ref/ntt.c **/
/* Code that generated the zetas table below (powers of the root of unity 17
   in Montgomery form, permuted into bit-reversed order by `tree`, then
   centered into {-(q-1)/2, ..., (q-1)/2}). Kept for reference only:

#define KYBER_ROOT_OF_UNITY 17

static const uint8_t tree[128] = {
  0, 64, 32, 96, 16, 80, 48, 112, 8, 72, 40, 104, 24, 88, 56, 120,
  4, 68, 36, 100, 20, 84, 52, 116, 12, 76, 44, 108, 28, 92, 60, 124,
  2, 66, 34, 98, 18, 82, 50, 114, 10, 74, 42, 106, 26, 90, 58, 122,
  6, 70, 38, 102, 22, 86, 54, 118, 14, 78, 46, 110, 30, 94, 62, 126,
  1, 65, 33, 97, 17, 81, 49, 113, 9, 73, 41, 105, 25, 89, 57, 121,
  5, 69, 37, 101, 21, 85, 53, 117, 13, 77, 45, 109, 29, 93, 61, 125,
  3, 67, 35, 99, 19, 83, 51, 115, 11, 75, 43, 107, 27, 91, 59, 123,
  7, 71, 39, 103, 23, 87, 55, 119, 15, 79, 47, 111, 31, 95, 63, 127
};

static void init_ntt() {
  unsigned int i;
  int16_t tmp[128];

  tmp[0] = MONT;
  for(i=1;i<128;i++)
    tmp[i] = fqmul(tmp[i-1],MONT*KYBER_ROOT_OF_UNITY % KYBER_Q);

  for(i=0;i<128;i++) {
    zetas[i] = tmp[tree[i]];
    if(zetas[i] > KYBER_Q/2)
      zetas[i] -= KYBER_Q;
    if(zetas[i] < -KYBER_Q/2)
      zetas[i] += KYBER_Q;
  }
}
*/

/* Twiddle factors used by ntt() (indices 1..127 in increasing order) and
 * invntt() (indices 127..1 in decreasing order); entry 0 equals MONT. */
const int16_t zetas[128] = {
    -1044, -758, -359, -1517, 1493, 1422, 287, 202,
    -171, 622, 1577, 182, 962, -1202, -1474, 1468,
    573, -1325, 264, 383, -829, 1458, -1602, -130,
    -681, 1017, 732, 608, -1542, 411, -205, -1571,
    1223, 652, -552, 1015, -1293, 1491, -282, -1544,
    516, -8, -320, -666, -1618, -1162, 126, 1469,
    -853, -90, -271, 830, 107, -1421, -247, -951,
    -398, 961, -1508, -725, 448, -1065, 677, -1275,
    -1103, 430, 555, 843, -1251, 871, 1550, 105,
    422, 587, 177, -235, -291, -460, 1574, 1653,
    -246, 778, 1159, -147, -777, 1483, -602, 1119,
    -1590, 644, -872, 349, 418, 329, -156, -75,
    817, 1097, 603, 610, 1322, -1285, -1465, 384,
    -1215, -136, 1218, -1335, -874, 220, -1187, -1659,
    -1185, -1530, -1278, 794, -1510, -854, -870, 478,
    -108, -308, 996, 991, 958, -1460, 1522, 1628
};
    612 
    613 /*************************************************
    614 * Name:        fqmul
    615 *
    616 * Description: Multiplication followed by Montgomery reduction
    617 *
    618 * Arguments:   - int16_t a: first factor
    619 *              - int16_t b: second factor
    620 *
    621 * Returns 16-bit integer congruent to a*b*R^{-1} mod q
    622 **************************************************/
    623 static int16_t
    624 fqmul(int16_t a, int16_t b)
    625 {
    626    return montgomery_reduce((int32_t)a * b);
    627 }
    628 
    629 /*************************************************
    630 * Name:        ntt
    631 *
    632 * Description: Inplace number-theoretic transform (NTT) in Rq.
    633 *              input is in standard order, output is in bitreversed order
    634 *
    635 * Arguments:   - int16_t r[256]: pointer to input/output vector of elements of Zq
    636 **************************************************/
    637 static void
    638 ntt(int16_t r[256])
    639 {
    640    unsigned int len, start, j, k;
    641    int16_t t, zeta;
    642 
    643    k = 1;
    644    for (len = 128; len >= 2; len >>= 1) {
    645        for (start = 0; start < 256; start = j + len) {
    646            zeta = zetas[k++];
    647            for (j = start; j < start + len; j++) {
    648                t = fqmul(zeta, r[j + len]);
    649                r[j + len] = r[j] - t;
    650                r[j] = r[j] + t;
    651            }
    652        }
    653    }
    654 }
    655 
    656 /*************************************************
    657 * Name:        invntt_tomont
    658 *
    659 * Description: Inplace inverse number-theoretic transform in Rq and
    660 *              multiplication by Montgomery factor 2^16.
    661 *              Input is in bitreversed order, output is in standard order
    662 *
    663 * Arguments:   - int16_t r[256]: pointer to input/output vector of elements of Zq
    664 **************************************************/
    665 static void
    666 invntt(int16_t r[256])
    667 {
    668    unsigned int start, len, j, k;
    669    int16_t t, zeta;
    670    const int16_t f = 1441; // mont^2/128
    671 
    672    k = 127;
    673    for (len = 2; len <= 128; len <<= 1) {
    674        for (start = 0; start < 256; start = j + len) {
    675            zeta = zetas[k--];
    676            for (j = start; j < start + len; j++) {
    677                t = r[j];
    678                r[j] = barrett_reduce(t + r[j + len]);
    679                r[j + len] = r[j + len] - t;
    680                r[j + len] = fqmul(zeta, r[j + len]);
    681            }
    682        }
    683    }
    684 
    685    for (j = 0; j < 256; j++)
    686        r[j] = fqmul(r[j], f);
    687 }
    688 
    689 /*************************************************
    690 * Name:        basemul
    691 *
    692 * Description: Multiplication of polynomials in Zq[X]/(X^2-zeta)
    693 *              used for multiplication of elements in Rq in NTT domain
    694 *
    695 * Arguments:   - int16_t r[2]: pointer to the output polynomial
    696 *              - const int16_t a[2]: pointer to the first factor
    697 *              - const int16_t b[2]: pointer to the second factor
    698 *              - int16_t zeta: integer defining the reduction polynomial
    699 **************************************************/
    700 static void
    701 basemul(int16_t r[2], const int16_t a[2], const int16_t b[2], int16_t zeta)
    702 {
    703    r[0] = fqmul(a[1], b[1]);
    704    r[0] = fqmul(r[0], zeta);
    705    r[0] += fqmul(a[0], b[0]);
    706    r[1] = fqmul(a[0], b[1]);
    707    r[1] += fqmul(a[1], b[0]);
    708 }
    709 /** end: ref/ntt.c **/
    710 
/** begin: ref/poly.c **/
/*************************************************
* Name:        poly_compress
*
* Description: Compression and subsequent serialization of a polynomial.
*              Each coefficient is rounded to 4 bits (when
*              KYBER_POLYCOMPRESSEDBYTES == 128) or 5 bits (when it is 160)
*              and the results are packed little-endian into r.
*
* Arguments:   - uint8_t *r: pointer to output byte array
*                            (of length KYBER_POLYCOMPRESSEDBYTES)
*              - const poly *a: pointer to input polynomial
*
* NOTE(review): the mapping step below makes negative coefficients
*   non-negative by adding q once, so inputs are presumably expected in
*   {-q+1, ..., q-1} — confirm with callers.
**************************************************/
static void
poly_compress(uint8_t r[KYBER_POLYCOMPRESSEDBYTES], const poly *a)
{
    unsigned int i, j;
    int16_t u;
    uint32_t d0;
    uint8_t t[8];

#if (KYBER_POLYCOMPRESSEDBYTES == 128)
    for (i = 0; i < KYBER_N / 8; i++) {
        for (j = 0; j < 8; j++) {
            // map to positive standard representatives
            // ((u >> 15) is all ones for negative u, so q is added only then)
            u = a->coeffs[8 * i + j];
            u += (u >> 15) & KYBER_Q;
            /*    t[j] = ((((uint16_t)u << 4) + KYBER_Q/2)/KYBER_Q) & 15; */
            // The commented formula is evaluated without a division:
            // 80635 approximates 2^28 / q, so multiply-then-shift-by-28
            // replaces "/ KYBER_Q". Presumably this avoids a variable-time
            // division on secret data; the constants must not be changed.
            d0 = u << 4;
            d0 += 1665;
            d0 *= 80635;
            d0 >>= 28;
            t[j] = d0 & 0xf;
        }

        // pack eight 4-bit values into four bytes
        r[0] = t[0] | (t[1] << 4);
        r[1] = t[2] | (t[3] << 4);
        r[2] = t[4] | (t[5] << 4);
        r[3] = t[6] | (t[7] << 4);
        r += 4;
    }
#elif (KYBER_POLYCOMPRESSEDBYTES == 160)
    for (i = 0; i < KYBER_N / 8; i++) {
        for (j = 0; j < 8; j++) {
            // map to positive standard representatives
            u = a->coeffs[8 * i + j];
            u += (u >> 15) & KYBER_Q;
            /*      t[j] = ((((uint32_t)u << 5) + KYBER_Q/2)/KYBER_Q) & 31; */
            // Division-free form of the formula above: 40318 approximates
            // 2^27 / q. Note the rounding offset here is 1664, not the 1665
            // used in the 4-bit branch; both are taken verbatim from the
            // upstream reference code — do not "harmonize" them.
            d0 = u << 5;
            d0 += 1664;
            d0 *= 40318;
            d0 >>= 27;
            t[j] = d0 & 0x1f;
        }

        // pack eight 5-bit values (40 bits) into five bytes
        r[0] = (t[0] >> 0) | (t[1] << 5);
        r[1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7);
        r[2] = (t[3] >> 1) | (t[4] << 4);
        r[3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6);
        r[4] = (t[6] >> 2) | (t[7] << 3);
        r += 5;
    }
#else
#error "KYBER_POLYCOMPRESSEDBYTES needs to be in {128, 160}"
#endif
}
    774 
    775 /*************************************************
    776 * Name:        poly_decompress
    777 *
    778 * Description: De-serialization and subsequent decompression of a polynomial;
    779 *              approximate inverse of poly_compress
    780 *
    781 * Arguments:   - poly *r: pointer to output polynomial
    782 *              - const uint8_t *a: pointer to input byte array
    783 *                                  (of length KYBER_POLYCOMPRESSEDBYTES bytes)
    784 **************************************************/
    785 static void
    786 poly_decompress(poly *r, const uint8_t a[KYBER_POLYCOMPRESSEDBYTES])
    787 {
    788    unsigned int i;
    789 
    790 #if (KYBER_POLYCOMPRESSEDBYTES == 128)
    791    for (i = 0; i < KYBER_N / 2; i++) {
    792        r->coeffs[2 * i + 0] = (((uint16_t)(a[0] & 15) * KYBER_Q) + 8) >> 4;
    793        r->coeffs[2 * i + 1] = (((uint16_t)(a[0] >> 4) * KYBER_Q) + 8) >> 4;
    794        a += 1;
    795    }
    796 #elif (KYBER_POLYCOMPRESSEDBYTES == 160)
    797    unsigned int j;
    798    uint8_t t[8];
    799    for (i = 0; i < KYBER_N / 8; i++) {
    800        t[0] = (a[0] >> 0);
    801        t[1] = (a[0] >> 5) | (a[1] << 3);
    802        t[2] = (a[1] >> 2);
    803        t[3] = (a[1] >> 7) | (a[2] << 1);
    804        t[4] = (a[2] >> 4) | (a[3] << 4);
    805        t[5] = (a[3] >> 1);
    806        t[6] = (a[3] >> 6) | (a[4] << 2);
    807        t[7] = (a[4] >> 3);
    808        a += 5;
    809 
    810        for (j = 0; j < 8; j++)
    811            r->coeffs[8 * i + j] = ((uint32_t)(t[j] & 31) * KYBER_Q + 16) >> 5;
    812    }
    813 #else
    814 #error "KYBER_POLYCOMPRESSEDBYTES needs to be in {128, 160}"
    815 #endif
    816 }
    817 
    818 /*************************************************
    819 * Name:        poly_tobytes
    820 *
    821 * Description: Serialization of a polynomial
    822 *
    823 * Arguments:   - uint8_t *r: pointer to output byte array
    824 *                            (needs space for KYBER_POLYBYTES bytes)
    825 *              - const poly *a: pointer to input polynomial
    826 **************************************************/
    827 static void
    828 poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a)
    829 {
    830    unsigned int i;
    831    uint16_t t0, t1;
    832 
    833    for (i = 0; i < KYBER_N / 2; i++) {
    834        // map to positive standard representatives
    835        t0 = a->coeffs[2 * i];
    836        t0 += ((int16_t)t0 >> 15) & KYBER_Q;
    837        t1 = a->coeffs[2 * i + 1];
    838        t1 += ((int16_t)t1 >> 15) & KYBER_Q;
    839        r[3 * i + 0] = (t0 >> 0);
    840        r[3 * i + 1] = (t0 >> 8) | (t1 << 4);
    841        r[3 * i + 2] = (t1 >> 4);
    842    }
    843 }
    844 
    845 /*************************************************
    846 * Name:        poly_frombytes
    847 *
    848 * Description: De-serialization of a polynomial;
    849 *              inverse of poly_tobytes
    850 *
    851 * Arguments:   - poly *r: pointer to output polynomial
    852 *              - const uint8_t *a: pointer to input byte array
    853 *                                  (of KYBER_POLYBYTES bytes)
    854 **************************************************/
    855 static void
    856 poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES])
    857 {
    858    unsigned int i;
    859    for (i = 0; i < KYBER_N / 2; i++) {
    860        r->coeffs[2 * i] = ((a[3 * i + 0] >> 0) | ((uint16_t)a[3 * i + 1] << 8)) & 0xFFF;
    861        r->coeffs[2 * i + 1] = ((a[3 * i + 1] >> 4) | ((uint16_t)a[3 * i + 2] << 4)) & 0xFFF;
    862    }
    863 }
    864 
    865 /*************************************************
    866 * Name:        poly_frommsg
    867 *
    868 * Description: Convert 32-byte message to polynomial
    869 *
    870 * Arguments:   - poly *r: pointer to output polynomial
    871 *              - const uint8_t *msg: pointer to input message
    872 **************************************************/
    873 static void
    874 poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES])
    875 {
    876    unsigned int i, j;
    877    int16_t mask;
    878 
    879 #if (KYBER_INDCPA_MSGBYTES != KYBER_N / 8)
    880 #error "KYBER_INDCPA_MSGBYTES must be equal to KYBER_N/8 bytes!"
    881 #endif
    882 
    883    for (i = 0; i < KYBER_N / 8; i++) {
    884        for (j = 0; j < 8; j++) {
    885            mask = -(int16_t)((msg[i] >> j) & 1);
    886            r->coeffs[8 * i + j] = mask & ((KYBER_Q + 1) / 2);
    887        }
    888    }
    889 }
    890 
    891 /*************************************************
    892 * Name:        poly_tomsg
    893 *
    894 * Description: Convert polynomial to 32-byte message
    895 *
    896 * Arguments:   - uint8_t *msg: pointer to output message
    897 *              - const poly *a: pointer to input polynomial
    898 **************************************************/
    899 static void
    900 poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *a)
    901 {
    902    unsigned int i, j;
    903    uint32_t t;
    904 
    905    for (i = 0; i < KYBER_N / 8; i++) {
    906        msg[i] = 0;
    907        for (j = 0; j < 8; j++) {
    908            t = a->coeffs[8 * i + j];
    909            // t += ((int16_t)t >> 15) & KYBER_Q;
    910            // t  = (((t << 1) + KYBER_Q/2)/KYBER_Q) & 1;
    911            t <<= 1;
    912            t += 1665;
    913            t *= 80635;
    914            t >>= 28;
    915            t &= 1;
    916            msg[i] |= t << j;
    917        }
    918    }
    919 }
    920 
    921 /*************************************************
    922 * Name:        poly_getnoise_eta1
    923 *
    924 * Description: Sample a polynomial deterministically from a seed and a nonce,
    925 *              with output polynomial close to centered binomial distribution
    926 *              with parameter KYBER_ETA1
    927 *
    928 * Arguments:   - poly *r: pointer to output polynomial
    929 *              - const uint8_t *seed: pointer to input seed
    930 *                                     (of length KYBER_SYMBYTES bytes)
    931 *              - uint8_t nonce: one-byte input nonce
    932 **************************************************/
    933 static void
    934 poly_getnoise_eta1(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce)
    935 {
    936    uint8_t buf[KYBER_ETA1 * KYBER_N / 4];
    937    prf(buf, sizeof(buf), seed, nonce);
    938    poly_cbd_eta1(r, buf);
    939 }
    940 
    941 /*************************************************
    942 * Name:        poly_getnoise_eta2
    943 *
    944 * Description: Sample a polynomial deterministically from a seed and a nonce,
    945 *              with output polynomial close to centered binomial distribution
    946 *              with parameter KYBER_ETA2
    947 *
    948 * Arguments:   - poly *r: pointer to output polynomial
    949 *              - const uint8_t *seed: pointer to input seed
    950 *                                     (of length KYBER_SYMBYTES bytes)
    951 *              - uint8_t nonce: one-byte input nonce
    952 **************************************************/
    953 static void
    954 poly_getnoise_eta2(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce)
    955 {
    956    uint8_t buf[KYBER_ETA2 * KYBER_N / 4];
    957    prf(buf, sizeof(buf), seed, nonce);
    958    poly_cbd_eta2(r, buf);
    959 }
    960 
    961 /*************************************************
    962 * Name:        poly_ntt
    963 *
    964 * Description: Computes negacyclic number-theoretic transform (NTT) of
    965 *              a polynomial in place;
    966 *              inputs assumed to be in normal order, output in bitreversed order
    967 *
    968 * Arguments:   - uint16_t *r: pointer to in/output polynomial
    969 **************************************************/
    970 static void
    971 poly_ntt(poly *r)
    972 {
    973    ntt(r->coeffs);
    974    poly_reduce(r);
    975 }
    976 
    977 /*************************************************
    978 * Name:        poly_invntt_tomont
    979 *
    980 * Description: Computes inverse of negacyclic number-theoretic transform (NTT)
    981 *              of a polynomial in place;
    982 *              inputs assumed to be in bitreversed order, output in normal order
    983 *
    984 * Arguments:   - uint16_t *a: pointer to in/output polynomial
    985 **************************************************/
    986 static void
    987 poly_invntt_tomont(poly *r)
    988 {
    989    invntt(r->coeffs);
    990 }
    991 
    992 /*************************************************
    993 * Name:        poly_basemul_montgomery
    994 *
    995 * Description: Multiplication of two polynomials in NTT domain
    996 *
    997 * Arguments:   - poly *r: pointer to output polynomial
    998 *              - const poly *a: pointer to first input polynomial
    999 *              - const poly *b: pointer to second input polynomial
   1000 **************************************************/
   1001 static void
   1002 poly_basemul_montgomery(poly *r, const poly *a, const poly *b)
   1003 {
   1004    unsigned int i;
   1005    for (i = 0; i < KYBER_N / 4; i++) {
   1006        basemul(&r->coeffs[4 * i], &a->coeffs[4 * i], &b->coeffs[4 * i], zetas[64 + i]);
   1007        basemul(&r->coeffs[4 * i + 2], &a->coeffs[4 * i + 2], &b->coeffs[4 * i + 2], -zetas[64 + i]);
   1008    }
   1009 }
   1010 
   1011 /*************************************************
   1012 * Name:        poly_tomont
   1013 *
   1014 * Description: Inplace conversion of all coefficients of a polynomial
   1015 *              from normal domain to Montgomery domain
   1016 *
   1017 * Arguments:   - poly *r: pointer to input/output polynomial
   1018 **************************************************/
   1019 static void
   1020 poly_tomont(poly *r)
   1021 {
   1022    unsigned int i;
   1023    const int16_t f = (1ULL << 32) % KYBER_Q;
   1024    for (i = 0; i < KYBER_N; i++)
   1025        r->coeffs[i] = montgomery_reduce((int32_t)r->coeffs[i] * f);
   1026 }
   1027 
   1028 /*************************************************
   1029 * Name:        poly_reduce
   1030 *
   1031 * Description: Applies Barrett reduction to all coefficients of a polynomial
   1032 *              for details of the Barrett reduction see comments in reduce.c
   1033 *
   1034 * Arguments:   - poly *r: pointer to input/output polynomial
   1035 **************************************************/
   1036 static void
   1037 poly_reduce(poly *r)
   1038 {
   1039    unsigned int i;
   1040    for (i = 0; i < KYBER_N; i++)
   1041        r->coeffs[i] = barrett_reduce(r->coeffs[i]);
   1042 }
   1043 
   1044 /*************************************************
   1045 * Name:        poly_add
   1046 *
   1047 * Description: Add two polynomials; no modular reduction is performed
   1048 *
   1049 * Arguments: - poly *r: pointer to output polynomial
   1050 *            - const poly *a: pointer to first input polynomial
   1051 *            - const poly *b: pointer to second input polynomial
   1052 **************************************************/
   1053 static void
   1054 poly_add(poly *r, const poly *a, const poly *b)
   1055 {
   1056    unsigned int i;
   1057    for (i = 0; i < KYBER_N; i++)
   1058        r->coeffs[i] = a->coeffs[i] + b->coeffs[i];
   1059 }
   1060 
   1061 /*************************************************
   1062 * Name:        poly_sub
   1063 *
   1064 * Description: Subtract two polynomials; no modular reduction is performed
   1065 *
   1066 * Arguments: - poly *r:       pointer to output polynomial
   1067 *            - const poly *a: pointer to first input polynomial
   1068 *            - const poly *b: pointer to second input polynomial
   1069 **************************************************/
   1070 static void
   1071 poly_sub(poly *r, const poly *a, const poly *b)
   1072 {
   1073    unsigned int i;
   1074    for (i = 0; i < KYBER_N; i++)
   1075        r->coeffs[i] = a->coeffs[i] - b->coeffs[i];
   1076 }
   1077 /** end: ref/poly.c **/
   1078 
   1079 /** begin: ref/polyvec.c **/
   1080 /*************************************************
   1081 * Name:        polyvec_compress
   1082 *
   1083 * Description: Compress and serialize vector of polynomials
   1084 *
   1085 * Arguments:   - uint8_t *r: pointer to output byte array
   1086 *                            (needs space for KYBER_POLYVECCOMPRESSEDBYTES)
   1087 *              - const polyvec *a: pointer to input vector of polynomials
   1088 **************************************************/
   1089 static void
   1090 polyvec_compress(uint8_t r[KYBER_POLYVECCOMPRESSEDBYTES], const polyvec *a)
   1091 {
   1092    unsigned int i, j, k;
   1093    uint64_t d0;
   1094 
   1095 #if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352))
   1096    uint16_t t[8];
   1097    for (i = 0; i < KYBER_K; i++) {
   1098        for (j = 0; j < KYBER_N / 8; j++) {
   1099            for (k = 0; k < 8; k++) {
   1100                t[k] = a->vec[i].coeffs[8 * j + k];
   1101                t[k] += ((int16_t)t[k] >> 15) & KYBER_Q;
   1102                /*      t[k]  = ((((uint32_t)t[k] << 11) + KYBER_Q/2)/KYBER_Q) & 0x7ff; */
   1103                d0 = t[k];
   1104                d0 <<= 11;
   1105                d0 += 1664;
   1106                d0 *= 645084;
   1107                d0 >>= 31;
   1108                t[k] = d0 & 0x7ff;
   1109            }
   1110 
   1111            r[0] = (t[0] >> 0);
   1112            r[1] = (t[0] >> 8) | (t[1] << 3);
   1113            r[2] = (t[1] >> 5) | (t[2] << 6);
   1114            r[3] = (t[2] >> 2);
   1115            r[4] = (t[2] >> 10) | (t[3] << 1);
   1116            r[5] = (t[3] >> 7) | (t[4] << 4);
   1117            r[6] = (t[4] >> 4) | (t[5] << 7);
   1118            r[7] = (t[5] >> 1);
   1119            r[8] = (t[5] >> 9) | (t[6] << 2);
   1120            r[9] = (t[6] >> 6) | (t[7] << 5);
   1121            r[10] = (t[7] >> 3);
   1122            r += 11;
   1123        }
   1124    }
   1125 #elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320))
   1126    uint16_t t[4];
   1127    for (i = 0; i < KYBER_K; i++) {
   1128        for (j = 0; j < KYBER_N / 4; j++) {
   1129            for (k = 0; k < 4; k++) {
   1130                t[k] = a->vec[i].coeffs[4 * j + k];
   1131                t[k] += ((int16_t)t[k] >> 15) & KYBER_Q;
   1132                /*      t[k]  = ((((uint32_t)t[k] << 10) + KYBER_Q/2)/ KYBER_Q) & 0x3ff; */
   1133                d0 = t[k];
   1134                d0 <<= 10;
   1135                d0 += 1665;
   1136                d0 *= 1290167;
   1137                d0 >>= 32;
   1138                t[k] = d0 & 0x3ff;
   1139            }
   1140 
   1141            r[0] = (t[0] >> 0);
   1142            r[1] = (t[0] >> 8) | (t[1] << 2);
   1143            r[2] = (t[1] >> 6) | (t[2] << 4);
   1144            r[3] = (t[2] >> 4) | (t[3] << 6);
   1145            r[4] = (t[3] >> 2);
   1146            r += 5;
   1147        }
   1148    }
   1149 #else
   1150 #error "KYBER_POLYVECCOMPRESSEDBYTES needs to be in {320*KYBER_K, 352*KYBER_K}"
   1151 #endif
   1152 }
   1153 
   1154 /*************************************************
   1155 * Name:        polyvec_decompress
   1156 *
   1157 * Description: De-serialize and decompress vector of polynomials;
   1158 *              approximate inverse of polyvec_compress
   1159 *
   1160 * Arguments:   - polyvec *r:       pointer to output vector of polynomials
   1161 *              - const uint8_t *a: pointer to input byte array
   1162 *                                  (of length KYBER_POLYVECCOMPRESSEDBYTES)
   1163 **************************************************/
   1164 static void
   1165 polyvec_decompress(polyvec *r, const uint8_t a[KYBER_POLYVECCOMPRESSEDBYTES])
   1166 {
   1167    unsigned int i, j, k;
   1168 
   1169 #if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352))
   1170    uint16_t t[8];
   1171    for (i = 0; i < KYBER_K; i++) {
   1172        for (j = 0; j < KYBER_N / 8; j++) {
   1173            t[0] = (a[0] >> 0) | ((uint16_t)a[1] << 8);
   1174            t[1] = (a[1] >> 3) | ((uint16_t)a[2] << 5);
   1175            t[2] = (a[2] >> 6) | ((uint16_t)a[3] << 2) | ((uint16_t)a[4] << 10);
   1176            t[3] = (a[4] >> 1) | ((uint16_t)a[5] << 7);
   1177            t[4] = (a[5] >> 4) | ((uint16_t)a[6] << 4);
   1178            t[5] = (a[6] >> 7) | ((uint16_t)a[7] << 1) | ((uint16_t)a[8] << 9);
   1179            t[6] = (a[8] >> 2) | ((uint16_t)a[9] << 6);
   1180            t[7] = (a[9] >> 5) | ((uint16_t)a[10] << 3);
   1181            a += 11;
   1182 
   1183            for (k = 0; k < 8; k++)
   1184                r->vec[i].coeffs[8 * j + k] = ((uint32_t)(t[k] & 0x7FF) * KYBER_Q + 1024) >> 11;
   1185        }
   1186    }
   1187 #elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320))
   1188    uint16_t t[4];
   1189    for (i = 0; i < KYBER_K; i++) {
   1190        for (j = 0; j < KYBER_N / 4; j++) {
   1191            t[0] = (a[0] >> 0) | ((uint16_t)a[1] << 8);
   1192            t[1] = (a[1] >> 2) | ((uint16_t)a[2] << 6);
   1193            t[2] = (a[2] >> 4) | ((uint16_t)a[3] << 4);
   1194            t[3] = (a[3] >> 6) | ((uint16_t)a[4] << 2);
   1195            a += 5;
   1196 
   1197            for (k = 0; k < 4; k++)
   1198                r->vec[i].coeffs[4 * j + k] = ((uint32_t)(t[k] & 0x3FF) * KYBER_Q + 512) >> 10;
   1199        }
   1200    }
   1201 #else
   1202 #error "KYBER_POLYVECCOMPRESSEDBYTES needs to be in {320*KYBER_K, 352*KYBER_K}"
   1203 #endif
   1204 }
   1205 
   1206 /*************************************************
   1207 * Name:        polyvec_tobytes
   1208 *
   1209 * Description: Serialize vector of polynomials
   1210 *
   1211 * Arguments:   - uint8_t *r: pointer to output byte array
   1212 *                            (needs space for KYBER_POLYVECBYTES)
   1213 *              - const polyvec *a: pointer to input vector of polynomials
   1214 **************************************************/
   1215 static void
   1216 polyvec_tobytes(uint8_t r[KYBER_POLYVECBYTES], const polyvec *a)
   1217 {
   1218    unsigned int i;
   1219    for (i = 0; i < KYBER_K; i++)
   1220        poly_tobytes(r + i * KYBER_POLYBYTES, &a->vec[i]);
   1221 }
   1222 
   1223 /*************************************************
   1224 * Name:        polyvec_frombytes
   1225 *
   1226 * Description: De-serialize vector of polynomials;
   1227 *              inverse of polyvec_tobytes
   1228 *
   1229 * Arguments:   - uint8_t *r:       pointer to output byte array
   1230 *              - const polyvec *a: pointer to input vector of polynomials
   1231 *                                  (of length KYBER_POLYVECBYTES)
   1232 **************************************************/
   1233 static void
   1234 polyvec_frombytes(polyvec *r, const uint8_t a[KYBER_POLYVECBYTES])
   1235 {
   1236    unsigned int i;
   1237    for (i = 0; i < KYBER_K; i++)
   1238        poly_frombytes(&r->vec[i], a + i * KYBER_POLYBYTES);
   1239 }
   1240 
   1241 /*************************************************
   1242 * Name:        polyvec_ntt
   1243 *
   1244 * Description: Apply forward NTT to all elements of a vector of polynomials
   1245 *
   1246 * Arguments:   - polyvec *r: pointer to in/output vector of polynomials
   1247 **************************************************/
   1248 static void
   1249 polyvec_ntt(polyvec *r)
   1250 {
   1251    unsigned int i;
   1252    for (i = 0; i < KYBER_K; i++)
   1253        poly_ntt(&r->vec[i]);
   1254 }
   1255 
   1256 /*************************************************
   1257 * Name:        polyvec_invntt_tomont
   1258 *
   1259 * Description: Apply inverse NTT to all elements of a vector of polynomials
   1260 *              and multiply by Montgomery factor 2^16
   1261 *
   1262 * Arguments:   - polyvec *r: pointer to in/output vector of polynomials
   1263 **************************************************/
   1264 static void
   1265 polyvec_invntt_tomont(polyvec *r)
   1266 {
   1267    unsigned int i;
   1268    for (i = 0; i < KYBER_K; i++)
   1269        poly_invntt_tomont(&r->vec[i]);
   1270 }
   1271 
   1272 /*************************************************
   1273 * Name:        polyvec_basemul_acc_montgomery
   1274 *
   1275 * Description: Multiply elements of a and b in NTT domain, accumulate into r,
   1276 *              and multiply by 2^-16.
   1277 *
   1278 * Arguments: - poly *r: pointer to output polynomial
   1279 *            - const polyvec *a: pointer to first input vector of polynomials
   1280 *            - const polyvec *b: pointer to second input vector of polynomials
   1281 **************************************************/
   1282 static void
   1283 polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b)
   1284 {
   1285    unsigned int i;
   1286    poly t;
   1287 
   1288    poly_basemul_montgomery(r, &a->vec[0], &b->vec[0]);
   1289    for (i = 1; i < KYBER_K; i++) {
   1290        poly_basemul_montgomery(&t, &a->vec[i], &b->vec[i]);
   1291        poly_add(r, r, &t);
   1292    }
   1293 
   1294    poly_reduce(r);
   1295 }
   1296 
   1297 /*************************************************
   1298 * Name:        polyvec_reduce
   1299 *
   1300 * Description: Applies Barrett reduction to each coefficient
   1301 *              of each element of a vector of polynomials;
   1302 *              for details of the Barrett reduction see comments in reduce.c
   1303 *
   1304 * Arguments:   - polyvec *r: pointer to input/output polynomial
   1305 **************************************************/
   1306 static void
   1307 polyvec_reduce(polyvec *r)
   1308 {
   1309    unsigned int i;
   1310    for (i = 0; i < KYBER_K; i++)
   1311        poly_reduce(&r->vec[i]);
   1312 }
   1313 
   1314 /*************************************************
   1315 * Name:        polyvec_add
   1316 *
   1317 * Description: Add vectors of polynomials
   1318 *
   1319 * Arguments: - polyvec *r: pointer to output vector of polynomials
   1320 *            - const polyvec *a: pointer to first input vector of polynomials
   1321 *            - const polyvec *b: pointer to second input vector of polynomials
   1322 **************************************************/
   1323 static void
   1324 polyvec_add(polyvec *r, const polyvec *a, const polyvec *b)
   1325 {
   1326    unsigned int i;
   1327    for (i = 0; i < KYBER_K; i++)
   1328        poly_add(&r->vec[i], &a->vec[i], &b->vec[i]);
   1329 }
   1330 /** end: ref/polyvec.c **/
   1331 
   1332 /** begin: ref/indcpa.c **/
   1333 /*************************************************
   1334 * Name:        pack_pk
   1335 *
   1336 * Description: Serialize the public key as concatenation of the
   1337 *              serialized vector of polynomials pk
   1338 *              and the public seed used to generate the matrix A.
   1339 *
   1340 * Arguments:   uint8_t *r: pointer to the output serialized public key
   1341 *              polyvec *pk: pointer to the input public-key polyvec
   1342 *              const uint8_t *seed: pointer to the input public seed
   1343 **************************************************/
   1344 static void
   1345 pack_pk(uint8_t r[KYBER_INDCPA_PUBLICKEYBYTES],
   1346        polyvec *pk,
   1347        const uint8_t seed[KYBER_SYMBYTES])
   1348 {
   1349    size_t i;
   1350    polyvec_tobytes(r, pk);
   1351    for (i = 0; i < KYBER_SYMBYTES; i++)
   1352        r[i + KYBER_POLYVECBYTES] = seed[i];
   1353 }
   1354 
   1355 /*************************************************
   1356 * Name:        unpack_pk
   1357 *
   1358 * Description: De-serialize public key from a byte array;
   1359 *              approximate inverse of pack_pk
   1360 *
   1361 * Arguments:   - polyvec *pk: pointer to output public-key polynomial vector
   1362 *              - uint8_t *seed: pointer to output seed to generate matrix A
   1363 *              - const uint8_t *packedpk: pointer to input serialized public key
   1364 **************************************************/
   1365 static void
   1366 unpack_pk(polyvec *pk,
   1367          uint8_t seed[KYBER_SYMBYTES],
   1368          const uint8_t packedpk[KYBER_INDCPA_PUBLICKEYBYTES])
   1369 {
   1370    size_t i;
   1371    polyvec_frombytes(pk, packedpk);
   1372    for (i = 0; i < KYBER_SYMBYTES; i++)
   1373        seed[i] = packedpk[i + KYBER_POLYVECBYTES];
   1374 }
   1375 
   1376 /*************************************************
   1377 * Name:        pack_sk
   1378 *
   1379 * Description: Serialize the secret key
   1380 *
   1381 * Arguments:   - uint8_t *r: pointer to output serialized secret key
   1382 *              - polyvec *sk: pointer to input vector of polynomials (secret key)
   1383 **************************************************/
   1384 static void
   1385 pack_sk(uint8_t r[KYBER_INDCPA_SECRETKEYBYTES], polyvec *sk)
   1386 {
   1387    polyvec_tobytes(r, sk);
   1388 }
   1389 
   1390 /*************************************************
   1391 * Name:        unpack_sk
   1392 *
   1393 * Description: De-serialize the secret key; inverse of pack_sk
   1394 *
   1395 * Arguments:   - polyvec *sk: pointer to output vector of polynomials (secret key)
   1396 *              - const uint8_t *packedsk: pointer to input serialized secret key
   1397 **************************************************/
   1398 static void
   1399 unpack_sk(polyvec *sk, const uint8_t packedsk[KYBER_INDCPA_SECRETKEYBYTES])
   1400 {
   1401    polyvec_frombytes(sk, packedsk);
   1402 }
   1403 
   1404 /*************************************************
   1405 * Name:        pack_ciphertext
   1406 *
   1407 * Description: Serialize the ciphertext as concatenation of the
   1408 *              compressed and serialized vector of polynomials b
   1409 *              and the compressed and serialized polynomial v
   1410 *
   1411 * Arguments:   uint8_t *r: pointer to the output serialized ciphertext
   1412 *              poly *pk: pointer to the input vector of polynomials b
   1413 *              poly *v: pointer to the input polynomial v
   1414 **************************************************/
   1415 static void
   1416 pack_ciphertext(uint8_t r[KYBER_INDCPA_BYTES], polyvec *b, poly *v)
   1417 {
   1418    polyvec_compress(r, b);
   1419    poly_compress(r + KYBER_POLYVECCOMPRESSEDBYTES, v);
   1420 }
   1421 
   1422 /*************************************************
   1423 * Name:        unpack_ciphertext
   1424 *
   1425 * Description: De-serialize and decompress ciphertext from a byte array;
   1426 *              approximate inverse of pack_ciphertext
   1427 *
   1428 * Arguments:   - polyvec *b: pointer to the output vector of polynomials b
   1429 *              - poly *v: pointer to the output polynomial v
   1430 *              - const uint8_t *c: pointer to the input serialized ciphertext
   1431 **************************************************/
   1432 static void
   1433 unpack_ciphertext(polyvec *b, poly *v, const uint8_t c[KYBER_INDCPA_BYTES])
   1434 {
   1435    polyvec_decompress(b, c);
   1436    poly_decompress(v, c + KYBER_POLYVECCOMPRESSEDBYTES);
   1437 }
   1438 
   1439 /*************************************************
   1440 * Name:        rej_uniform
   1441 *
   1442 * Description: Run rejection sampling on uniform random bytes to generate
   1443 *              uniform random integers mod q
   1444 *
   1445 * Arguments:   - int16_t *r: pointer to output buffer
   1446 *              - unsigned int len: requested number of 16-bit integers (uniform mod q)
   1447 *              - const uint8_t *buf: pointer to input buffer (assumed to be uniformly random bytes)
   1448 *              - unsigned int buflen: length of input buffer in bytes
   1449 *
   1450 * Returns number of sampled 16-bit integers (at most len)
   1451 **************************************************/
   1452 static unsigned int
   1453 rej_uniform(int16_t *r,
   1454            unsigned int len,
   1455            const uint8_t *buf,
   1456            unsigned int buflen)
   1457 {
   1458    unsigned int ctr, pos;
   1459    uint16_t val0, val1;
   1460 
   1461    ctr = pos = 0;
   1462    while (ctr < len && pos + 3 <= buflen) {
   1463        val0 = ((buf[pos + 0] >> 0) | ((uint16_t)buf[pos + 1] << 8)) & 0xFFF;
   1464        val1 = ((buf[pos + 1] >> 4) | ((uint16_t)buf[pos + 2] << 4)) & 0xFFF;
   1465        pos += 3;
   1466 
   1467        if (val0 < KYBER_Q)
   1468            r[ctr++] = val0;
   1469        if (ctr < len && val1 < KYBER_Q)
   1470            r[ctr++] = val1;
   1471    }
   1472 
   1473    return ctr;
   1474 }
   1475 
// Convenience wrappers around gen_matrix: gen_a generates A and gen_at
// generates its transpose A^T (the third argument is the "transposed" flag).
#define gen_a(A, B) gen_matrix(A, B, 0)
#define gen_at(A, B) gen_matrix(A, B, 1)
   1478 
   1479 /*************************************************
   1480 * Name:        gen_matrix
   1481 *
   1482 * Description: Deterministically generate matrix A (or the transpose of A)
   1483 *              from a seed. Entries of the matrix are polynomials that look
   1484 *              uniformly random. Performs rejection sampling on output of
   1485 *              a XOF
   1486 *
   1487 * Arguments:   - polyvec *a: pointer to ouptput matrix A
   1488 *              - const uint8_t *seed: pointer to input seed
   1489 *              - int transposed: boolean deciding whether A or A^T is generated
   1490 **************************************************/
   1491 #define GEN_MATRIX_NBLOCKS ((12 * KYBER_N / 8 * (1 << 12) / KYBER_Q + XOF_BLOCKBYTES) / XOF_BLOCKBYTES)
   1492 // Not static for benchmarking
   1493 static void
   1494 gen_matrix(polyvec *a, const uint8_t seed[KYBER_SYMBYTES], int transposed)
   1495 {
   1496    unsigned int ctr, i, j, k;
   1497    unsigned int buflen, off;
   1498    uint8_t buf[GEN_MATRIX_NBLOCKS * XOF_BLOCKBYTES + 2];
   1499    xof_state state;
   1500 
   1501    for (i = 0; i < KYBER_K; i++) {
   1502        for (j = 0; j < KYBER_K; j++) {
   1503            if (transposed)
   1504                xof_absorb(&state, seed, i, j);
   1505            else
   1506                xof_absorb(&state, seed, j, i);
   1507 
   1508            xof_squeezeblocks(buf, GEN_MATRIX_NBLOCKS, &state);
   1509            buflen = GEN_MATRIX_NBLOCKS * XOF_BLOCKBYTES;
   1510            ctr = rej_uniform(a[i].vec[j].coeffs, KYBER_N, buf, buflen);
   1511 
   1512            while (ctr < KYBER_N) {
   1513                off = buflen % 3;
   1514                for (k = 0; k < off; k++)
   1515                    buf[k] = buf[buflen - off + k];
   1516                xof_squeezeblocks(buf + off, 1, &state);
   1517                buflen = off + XOF_BLOCKBYTES;
   1518                ctr += rej_uniform(a[i].vec[j].coeffs + ctr, KYBER_N - ctr, buf, buflen);
   1519            }
   1520        }
   1521    }
   1522 }
   1523 
   1524 /*************************************************
   1525 * Name:        indcpa_keypair_derand
   1526 *
   1527 * Description: Generates public and private key for the CPA-secure
   1528 *              public-key encryption scheme underlying Kyber
   1529 *
   1530 * Arguments:   - uint8_t *pk: pointer to output public key
   1531 *                             (of length KYBER_INDCPA_PUBLICKEYBYTES bytes)
   1532 *              - uint8_t *sk: pointer to output private key
   1533 *                             (of length KYBER_INDCPA_SECRETKEYBYTES bytes)
   1534 *              - const uint8_t *coins: pointer to input randomness
   1535 *                             (of length KYBER_SYMBYTES bytes)
   1536 **************************************************/
   1537 static void
   1538 indcpa_keypair_derand(uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES],
   1539                      uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES],
   1540                      const uint8_t coins[KYBER_SYMBYTES])
   1541 {
   1542    unsigned int i;
   1543    uint8_t buf[2 * KYBER_SYMBYTES];
   1544    const uint8_t *publicseed = buf;
   1545    const uint8_t *noiseseed = buf + KYBER_SYMBYTES;
   1546    uint8_t nonce = 0;
   1547    polyvec a[KYBER_K], e, pkpv, skpv;
   1548 
   1549    hash_g(buf, coins, KYBER_SYMBYTES);
   1550 
   1551    gen_a(a, publicseed);
   1552 
   1553    for (i = 0; i < KYBER_K; i++)
   1554        poly_getnoise_eta1(&skpv.vec[i], noiseseed, nonce++);
   1555    for (i = 0; i < KYBER_K; i++)
   1556        poly_getnoise_eta1(&e.vec[i], noiseseed, nonce++);
   1557 
   1558    polyvec_ntt(&skpv);
   1559    polyvec_ntt(&e);
   1560 
   1561    // matrix-vector multiplication
   1562    for (i = 0; i < KYBER_K; i++) {
   1563        polyvec_basemul_acc_montgomery(&pkpv.vec[i], &a[i], &skpv);
   1564        poly_tomont(&pkpv.vec[i]);
   1565    }
   1566 
   1567    polyvec_add(&pkpv, &pkpv, &e);
   1568    polyvec_reduce(&pkpv);
   1569 
   1570    pack_sk(sk, &skpv);
   1571    pack_pk(pk, &pkpv, publicseed);
   1572 }
   1573 
   1574 /*************************************************
   1575 * Name:        indcpa_enc
   1576 *
   1577 * Description: Encryption function of the CPA-secure
   1578 *              public-key encryption scheme underlying Kyber.
   1579 *
   1580 * Arguments:   - uint8_t *c: pointer to output ciphertext
   1581 *                            (of length KYBER_INDCPA_BYTES bytes)
   1582 *              - const uint8_t *m: pointer to input message
   1583 *                                  (of length KYBER_INDCPA_MSGBYTES bytes)
   1584 *              - const uint8_t *pk: pointer to input public key
   1585 *                                   (of length KYBER_INDCPA_PUBLICKEYBYTES)
   1586 *              - const uint8_t *coins: pointer to input random coins used as seed
   1587 *                                      (of length KYBER_SYMBYTES) to deterministically
   1588 *                                      generate all randomness
   1589 **************************************************/
   1590 static void
   1591 indcpa_enc(uint8_t c[KYBER_INDCPA_BYTES],
   1592           const uint8_t m[KYBER_INDCPA_MSGBYTES],
   1593           const uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES],
   1594           const uint8_t coins[KYBER_SYMBYTES])
   1595 {
   1596    unsigned int i;
   1597    uint8_t seed[KYBER_SYMBYTES];
   1598    uint8_t nonce = 0;
   1599    polyvec sp, pkpv, ep, at[KYBER_K], b;
   1600    poly v, k, epp;
   1601 
   1602    unpack_pk(&pkpv, seed, pk);
   1603    poly_frommsg(&k, m);
   1604    gen_at(at, seed);
   1605 
   1606    for (i = 0; i < KYBER_K; i++)
   1607        poly_getnoise_eta1(sp.vec + i, coins, nonce++);
   1608    for (i = 0; i < KYBER_K; i++)
   1609        poly_getnoise_eta2(ep.vec + i, coins, nonce++);
   1610    poly_getnoise_eta2(&epp, coins, nonce++);
   1611 
   1612    polyvec_ntt(&sp);
   1613 
   1614    // matrix-vector multiplication
   1615    for (i = 0; i < KYBER_K; i++)
   1616        polyvec_basemul_acc_montgomery(&b.vec[i], &at[i], &sp);
   1617 
   1618    polyvec_basemul_acc_montgomery(&v, &pkpv, &sp);
   1619 
   1620    polyvec_invntt_tomont(&b);
   1621    poly_invntt_tomont(&v);
   1622 
   1623    polyvec_add(&b, &b, &ep);
   1624    poly_add(&v, &v, &epp);
   1625    poly_add(&v, &v, &k);
   1626    polyvec_reduce(&b);
   1627    poly_reduce(&v);
   1628 
   1629    pack_ciphertext(c, &b, &v);
   1630 }
   1631 
   1632 /*************************************************
   1633 * Name:        indcpa_dec
   1634 *
   1635 * Description: Decryption function of the CPA-secure
   1636 *              public-key encryption scheme underlying Kyber.
   1637 *
   1638 * Arguments:   - uint8_t *m: pointer to output decrypted message
   1639 *                            (of length KYBER_INDCPA_MSGBYTES)
   1640 *              - const uint8_t *c: pointer to input ciphertext
   1641 *                                  (of length KYBER_INDCPA_BYTES)
   1642 *              - const uint8_t *sk: pointer to input secret key
   1643 *                                   (of length KYBER_INDCPA_SECRETKEYBYTES)
   1644 **************************************************/
   1645 static void
   1646 indcpa_dec(uint8_t m[KYBER_INDCPA_MSGBYTES],
   1647           const uint8_t c[KYBER_INDCPA_BYTES],
   1648           const uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES])
   1649 {
   1650    polyvec b, skpv;
   1651    poly v, mp;
   1652 
   1653    unpack_ciphertext(&b, &v, c);
   1654    unpack_sk(&skpv, sk);
   1655 
   1656    polyvec_ntt(&b);
   1657    polyvec_basemul_acc_montgomery(&mp, &skpv, &b);
   1658    poly_invntt_tomont(&mp);
   1659 
   1660    poly_sub(&mp, &v, &mp);
   1661    poly_reduce(&mp);
   1662 
   1663    poly_tomsg(m, &mp);
   1664 }
   1665 /** end: ref/indcpa.c **/
   1666 
   1667 /** begin: ref/fips202.c **/
   1668 /* Based on the public domain implementation in crypto_hash/keccakc512/simple/ from
   1669 * http://bench.cr.yp.to/supercop.html by Ronny Van Keer and the public domain "TweetFips202"
   1670 * implementation from https://twitter.com/tweetfips202 by Gilles Van Assche, Daniel J. Bernstein,
   1671 * and Peter Schwabe */
   1672 
   1673 #define NROUNDS 24
   1674 #define ROL(a, offset) ((a << offset) ^ (a >> (64 - offset)))
   1675 
   1676 /*************************************************
   1677 * Name:        load64
   1678 *
   1679 * Description: Load 8 bytes into uint64_t in little-endian order
   1680 *
   1681 * Arguments:   - const uint8_t *x: pointer to input byte array
   1682 *
   1683 * Returns the loaded 64-bit unsigned integer
   1684 **************************************************/
   1685 static uint64_t
   1686 load64(const uint8_t x[8])
   1687 {
   1688    unsigned int i;
   1689    uint64_t r = 0;
   1690 
   1691    for (i = 0; i < 8; i++)
   1692        r |= (uint64_t)x[i] << 8 * i;
   1693 
   1694    return r;
   1695 }
   1696 
   1697 /*************************************************
   1698 * Name:        store64
   1699 *
   1700 * Description: Store a 64-bit integer to array of 8 bytes in little-endian order
   1701 *
   1702 * Arguments:   - uint8_t *x: pointer to the output byte array (allocated)
   1703 *              - uint64_t u: input 64-bit unsigned integer
   1704 **************************************************/
   1705 static void
   1706 store64(uint8_t x[8], uint64_t u)
   1707 {
   1708    unsigned int i;
   1709 
   1710    for (i = 0; i < 8; i++)
   1711        x[i] = u >> 8 * i;
   1712 }
   1713 
   1714 /* Keccak round constants */
   1715 static const uint64_t KeccakF_RoundConstants[NROUNDS] = {
   1716    (uint64_t)0x0000000000000001ULL,
   1717    (uint64_t)0x0000000000008082ULL,
   1718    (uint64_t)0x800000000000808aULL,
   1719    (uint64_t)0x8000000080008000ULL,
   1720    (uint64_t)0x000000000000808bULL,
   1721    (uint64_t)0x0000000080000001ULL,
   1722    (uint64_t)0x8000000080008081ULL,
   1723    (uint64_t)0x8000000000008009ULL,
   1724    (uint64_t)0x000000000000008aULL,
   1725    (uint64_t)0x0000000000000088ULL,
   1726    (uint64_t)0x0000000080008009ULL,
   1727    (uint64_t)0x000000008000000aULL,
   1728    (uint64_t)0x000000008000808bULL,
   1729    (uint64_t)0x800000000000008bULL,
   1730    (uint64_t)0x8000000000008089ULL,
   1731    (uint64_t)0x8000000000008003ULL,
   1732    (uint64_t)0x8000000000008002ULL,
   1733    (uint64_t)0x8000000000000080ULL,
   1734    (uint64_t)0x000000000000800aULL,
   1735    (uint64_t)0x800000008000000aULL,
   1736    (uint64_t)0x8000000080008081ULL,
   1737    (uint64_t)0x8000000000008080ULL,
   1738    (uint64_t)0x0000000080000001ULL,
   1739    (uint64_t)0x8000000080008008ULL
   1740 };
   1741 
   1742 /*************************************************
   1743 * Name:        KeccakF1600_StatePermute
   1744 *
   1745 * Description: The Keccak F1600 Permutation
   1746 *
   1747 * Arguments:   - uint64_t *state: pointer to input/output Keccak state
   1748 **************************************************/
static void
KeccakF1600_StatePermute(uint64_t state[25])
{
    int round;

    // Lane naming: the letter after A/E selects the row (b, g, k, m, s ->
    // 0..4) and the last letter the column (a, e, i, o, u -> 0..4), so lane
    // "A<row><col>" is state[5 * row + col] (see copyFromState below).
    // A* hold the state entering an even round, E* the state after it.
    uint64_t Aba, Abe, Abi, Abo, Abu;
    uint64_t Aga, Age, Agi, Ago, Agu;
    uint64_t Aka, Ake, Aki, Ako, Aku;
    uint64_t Ama, Ame, Ami, Amo, Amu;
    uint64_t Asa, Ase, Asi, Aso, Asu;
    // BC*: first the theta column parities, then reused for the rotated
    // lanes fed into chi. D*: the per-column theta terms.
    uint64_t BCa, BCe, BCi, BCo, BCu;
    uint64_t Da, De, Di, Do, Du;
    uint64_t Eba, Ebe, Ebi, Ebo, Ebu;
    uint64_t Ega, Ege, Egi, Ego, Egu;
    uint64_t Eka, Eke, Eki, Eko, Eku;
    uint64_t Ema, Eme, Emi, Emo, Emu;
    uint64_t Esa, Ese, Esi, Eso, Esu;

    // copyFromState(A, state)
    Aba = state[0];
    Abe = state[1];
    Abi = state[2];
    Abo = state[3];
    Abu = state[4];
    Aga = state[5];
    Age = state[6];
    Agi = state[7];
    Ago = state[8];
    Agu = state[9];
    Aka = state[10];
    Ake = state[11];
    Aki = state[12];
    Ako = state[13];
    Aku = state[14];
    Ama = state[15];
    Ame = state[16];
    Ami = state[17];
    Amo = state[18];
    Amu = state[19];
    Asa = state[20];
    Ase = state[21];
    Asi = state[22];
    Aso = state[23];
    Asu = state[24];

    // Two rounds per iteration: round `round` maps A -> E and round
    // `round + 1` maps E back -> A, so no copying is needed between rounds.
    // This relies on NROUNDS being even.
    for (round = 0; round < NROUNDS; round += 2) {
        //    prepareTheta
        BCa = Aba ^ Aga ^ Aka ^ Ama ^ Asa;
        BCe = Abe ^ Age ^ Ake ^ Ame ^ Ase;
        BCi = Abi ^ Agi ^ Aki ^ Ami ^ Asi;
        BCo = Abo ^ Ago ^ Ako ^ Amo ^ Aso;
        BCu = Abu ^ Agu ^ Aku ^ Amu ^ Asu;

        // thetaRhoPiChiIotaPrepareTheta(round, A, E)
        Da = BCu ^ ROL(BCe, 1);
        De = BCa ^ ROL(BCi, 1);
        Di = BCe ^ ROL(BCo, 1);
        Do = BCi ^ ROL(BCu, 1);
        Du = BCo ^ ROL(BCa, 1);

        // Each group below: theta (XOR in D*), rotate the five lanes that
        // land in one output row (rho/pi), then chi (x ^ (~y & z)) into E*.
        // The first group also applies iota (round constant into Eba).
        Aba ^= Da;
        BCa = Aba;
        Age ^= De;
        BCe = ROL(Age, 44);
        Aki ^= Di;
        BCi = ROL(Aki, 43);
        Amo ^= Do;
        BCo = ROL(Amo, 21);
        Asu ^= Du;
        BCu = ROL(Asu, 14);
        Eba = BCa ^ ((~BCe) & BCi);
        Eba ^= (uint64_t)KeccakF_RoundConstants[round];
        Ebe = BCe ^ ((~BCi) & BCo);
        Ebi = BCi ^ ((~BCo) & BCu);
        Ebo = BCo ^ ((~BCu) & BCa);
        Ebu = BCu ^ ((~BCa) & BCe);

        Abo ^= Do;
        BCa = ROL(Abo, 28);
        Agu ^= Du;
        BCe = ROL(Agu, 20);
        Aka ^= Da;
        BCi = ROL(Aka, 3);
        Ame ^= De;
        BCo = ROL(Ame, 45);
        Asi ^= Di;
        BCu = ROL(Asi, 61);
        Ega = BCa ^ ((~BCe) & BCi);
        Ege = BCe ^ ((~BCi) & BCo);
        Egi = BCi ^ ((~BCo) & BCu);
        Ego = BCo ^ ((~BCu) & BCa);
        Egu = BCu ^ ((~BCa) & BCe);

        Abe ^= De;
        BCa = ROL(Abe, 1);
        Agi ^= Di;
        BCe = ROL(Agi, 6);
        Ako ^= Do;
        BCi = ROL(Ako, 25);
        Amu ^= Du;
        BCo = ROL(Amu, 8);
        Asa ^= Da;
        BCu = ROL(Asa, 18);
        Eka = BCa ^ ((~BCe) & BCi);
        Eke = BCe ^ ((~BCi) & BCo);
        Eki = BCi ^ ((~BCo) & BCu);
        Eko = BCo ^ ((~BCu) & BCa);
        Eku = BCu ^ ((~BCa) & BCe);

        Abu ^= Du;
        BCa = ROL(Abu, 27);
        Aga ^= Da;
        BCe = ROL(Aga, 36);
        Ake ^= De;
        BCi = ROL(Ake, 10);
        Ami ^= Di;
        BCo = ROL(Ami, 15);
        Aso ^= Do;
        BCu = ROL(Aso, 56);
        Ema = BCa ^ ((~BCe) & BCi);
        Eme = BCe ^ ((~BCi) & BCo);
        Emi = BCi ^ ((~BCo) & BCu);
        Emo = BCo ^ ((~BCu) & BCa);
        Emu = BCu ^ ((~BCa) & BCe);

        Abi ^= Di;
        BCa = ROL(Abi, 62);
        Ago ^= Do;
        BCe = ROL(Ago, 55);
        Aku ^= Du;
        BCi = ROL(Aku, 39);
        Ama ^= Da;
        BCo = ROL(Ama, 41);
        Ase ^= De;
        BCu = ROL(Ase, 2);
        Esa = BCa ^ ((~BCe) & BCi);
        Ese = BCe ^ ((~BCi) & BCo);
        Esi = BCi ^ ((~BCo) & BCu);
        Eso = BCo ^ ((~BCu) & BCa);
        Esu = BCu ^ ((~BCa) & BCe);

        //    prepareTheta
        BCa = Eba ^ Ega ^ Eka ^ Ema ^ Esa;
        BCe = Ebe ^ Ege ^ Eke ^ Eme ^ Ese;
        BCi = Ebi ^ Egi ^ Eki ^ Emi ^ Esi;
        BCo = Ebo ^ Ego ^ Eko ^ Emo ^ Eso;
        BCu = Ebu ^ Egu ^ Eku ^ Emu ^ Esu;

        // thetaRhoPiChiIotaPrepareTheta(round+1, E, A)
        // Same structure as above with the roles of A and E swapped.
        Da = BCu ^ ROL(BCe, 1);
        De = BCa ^ ROL(BCi, 1);
        Di = BCe ^ ROL(BCo, 1);
        Do = BCi ^ ROL(BCu, 1);
        Du = BCo ^ ROL(BCa, 1);

        Eba ^= Da;
        BCa = Eba;
        Ege ^= De;
        BCe = ROL(Ege, 44);
        Eki ^= Di;
        BCi = ROL(Eki, 43);
        Emo ^= Do;
        BCo = ROL(Emo, 21);
        Esu ^= Du;
        BCu = ROL(Esu, 14);
        Aba = BCa ^ ((~BCe) & BCi);
        Aba ^= (uint64_t)KeccakF_RoundConstants[round + 1];
        Abe = BCe ^ ((~BCi) & BCo);
        Abi = BCi ^ ((~BCo) & BCu);
        Abo = BCo ^ ((~BCu) & BCa);
        Abu = BCu ^ ((~BCa) & BCe);

        Ebo ^= Do;
        BCa = ROL(Ebo, 28);
        Egu ^= Du;
        BCe = ROL(Egu, 20);
        Eka ^= Da;
        BCi = ROL(Eka, 3);
        Eme ^= De;
        BCo = ROL(Eme, 45);
        Esi ^= Di;
        BCu = ROL(Esi, 61);
        Aga = BCa ^ ((~BCe) & BCi);
        Age = BCe ^ ((~BCi) & BCo);
        Agi = BCi ^ ((~BCo) & BCu);
        Ago = BCo ^ ((~BCu) & BCa);
        Agu = BCu ^ ((~BCa) & BCe);

        Ebe ^= De;
        BCa = ROL(Ebe, 1);
        Egi ^= Di;
        BCe = ROL(Egi, 6);
        Eko ^= Do;
        BCi = ROL(Eko, 25);
        Emu ^= Du;
        BCo = ROL(Emu, 8);
        Esa ^= Da;
        BCu = ROL(Esa, 18);
        Aka = BCa ^ ((~BCe) & BCi);
        Ake = BCe ^ ((~BCi) & BCo);
        Aki = BCi ^ ((~BCo) & BCu);
        Ako = BCo ^ ((~BCu) & BCa);
        Aku = BCu ^ ((~BCa) & BCe);

        Ebu ^= Du;
        BCa = ROL(Ebu, 27);
        Ega ^= Da;
        BCe = ROL(Ega, 36);
        Eke ^= De;
        BCi = ROL(Eke, 10);
        Emi ^= Di;
        BCo = ROL(Emi, 15);
        Eso ^= Do;
        BCu = ROL(Eso, 56);
        Ama = BCa ^ ((~BCe) & BCi);
        Ame = BCe ^ ((~BCi) & BCo);
        Ami = BCi ^ ((~BCo) & BCu);
        Amo = BCo ^ ((~BCu) & BCa);
        Amu = BCu ^ ((~BCa) & BCe);

        Ebi ^= Di;
        BCa = ROL(Ebi, 62);
        Ego ^= Do;
        BCe = ROL(Ego, 55);
        Eku ^= Du;
        BCi = ROL(Eku, 39);
        Ema ^= Da;
        BCo = ROL(Ema, 41);
        Ese ^= De;
        BCu = ROL(Ese, 2);
        Asa = BCa ^ ((~BCe) & BCi);
        Ase = BCe ^ ((~BCi) & BCo);
        Asi = BCi ^ ((~BCo) & BCu);
        Aso = BCo ^ ((~BCu) & BCa);
        Asu = BCu ^ ((~BCa) & BCe);
    }

    // copyToState(state, A)
    // After an even number of rounds the result is back in the A* lanes.
    state[0] = Aba;
    state[1] = Abe;
    state[2] = Abi;
    state[3] = Abo;
    state[4] = Abu;
    state[5] = Aga;
    state[6] = Age;
    state[7] = Agi;
    state[8] = Ago;
    state[9] = Agu;
    state[10] = Aka;
    state[11] = Ake;
    state[12] = Aki;
    state[13] = Ako;
    state[14] = Aku;
    state[15] = Ama;
    state[16] = Ame;
    state[17] = Ami;
    state[18] = Amo;
    state[19] = Amu;
    state[20] = Asa;
    state[21] = Ase;
    state[22] = Asi;
    state[23] = Aso;
    state[24] = Asu;
}
   2013 
   2014 /*************************************************
   2015 * Name:        keccak_init
   2016 *
   2017 * Description: Initializes the Keccak state.
   2018 *
   2019 * Arguments:   - uint64_t *s: pointer to Keccak state
   2020 **************************************************/
   2021 static void
   2022 keccak_init(uint64_t s[25])
   2023 {
   2024    unsigned int i;
   2025    for (i = 0; i < 25; i++)
   2026        s[i] = 0;
   2027 }
   2028 
   2029 /*************************************************
   2030 * Name:        keccak_absorb
   2031 *
   2032 * Description: Absorb step of Keccak; incremental.
   2033 *
   2034 * Arguments:   - uint64_t *s: pointer to Keccak state
   2035 *              - unsigned int pos: position in current block to be absorbed
   2036 *              - unsigned int r: rate in bytes (e.g., 168 for SHAKE128)
   2037 *              - const uint8_t *in: pointer to input to be absorbed into s
   2038 *              - size_t inlen: length of input in bytes
   2039 *
   2040 * Returns new position pos in current block
   2041 **************************************************/
   2042 static unsigned int
   2043 keccak_absorb(uint64_t s[25],
   2044              unsigned int pos,
   2045              unsigned int r,
   2046              const uint8_t *in,
   2047              size_t inlen)
   2048 {
   2049    unsigned int i;
   2050 
   2051    while (pos + inlen >= r) {
   2052        for (i = pos; i < r; i++)
   2053            s[i / 8] ^= (uint64_t)*in++ << 8 * (i % 8);
   2054        inlen -= r - pos;
   2055        KeccakF1600_StatePermute(s);
   2056        pos = 0;
   2057    }
   2058 
   2059    for (i = pos; i < pos + inlen; i++)
   2060        s[i / 8] ^= (uint64_t)*in++ << 8 * (i % 8);
   2061 
   2062    return i;
   2063 }
   2064 
   2065 /*************************************************
   2066 * Name:        keccak_finalize
   2067 *
   2068 * Description: Finalize absorb step.
   2069 *
   2070 * Arguments:   - uint64_t *s: pointer to Keccak state
   2071 *              - unsigned int pos: position in current block to be absorbed
   2072 *              - unsigned int r: rate in bytes (e.g., 168 for SHAKE128)
   2073 *              - uint8_t p: domain separation byte
   2074 **************************************************/
   2075 static void
   2076 keccak_finalize(uint64_t s[25], unsigned int pos, unsigned int r, uint8_t p)
   2077 {
   2078    s[pos / 8] ^= (uint64_t)p << 8 * (pos % 8);
   2079    s[r / 8 - 1] ^= 1ULL << 63;
   2080 }
   2081 
   2082 /*************************************************
   2083 * Name:        keccak_squeeze
   2084 *
* Description: Squeeze step of Keccak. Squeezes arbitrarily many bytes.
   2086 *              Modifies the state. Can be called multiple times to keep
   2087 *              squeezing, i.e., is incremental.
   2088 *
   2089 * Arguments:   - uint8_t *out: pointer to output
   2090 *              - size_t outlen: number of bytes to be squeezed (written to out)
   2091 *              - uint64_t *s: pointer to input/output Keccak state
   2092 *              - unsigned int pos: number of bytes in current block already squeezed
   2093 *              - unsigned int r: rate in bytes (e.g., 168 for SHAKE128)
   2094 *
   2095 * Returns new position pos in current block
   2096 **************************************************/
   2097 static unsigned int
   2098 keccak_squeeze(uint8_t *out,
   2099               size_t outlen,
   2100               uint64_t s[25],
   2101               unsigned int pos,
   2102               unsigned int r)
   2103 {
   2104    unsigned int i;
   2105 
   2106    while (outlen) {
   2107        if (pos == r) {
   2108            KeccakF1600_StatePermute(s);
   2109            pos = 0;
   2110        }
   2111        for (i = pos; i < r && i < pos + outlen; i++)
   2112            *out++ = s[i / 8] >> 8 * (i % 8);
   2113        outlen -= i - pos;
   2114        pos = i;
   2115    }
   2116 
   2117    return pos;
   2118 }
   2119 
   2120 /*************************************************
   2121 * Name:        keccak_absorb_once
   2122 *
   2123 * Description: Absorb step of Keccak;
*              non-incremental, starts by zeroing the state.
   2125 *
   2126 * Arguments:   - uint64_t *s: pointer to (uninitialized) output Keccak state
   2127 *              - unsigned int r: rate in bytes (e.g., 168 for SHAKE128)
   2128 *              - const uint8_t *in: pointer to input to be absorbed into s
   2129 *              - size_t inlen: length of input in bytes
   2130 *              - uint8_t p: domain-separation byte for different Keccak-derived functions
   2131 **************************************************/
   2132 static void
   2133 keccak_absorb_once(uint64_t s[25],
   2134                   unsigned int r,
   2135                   const uint8_t *in,
   2136                   size_t inlen,
   2137                   uint8_t p)
   2138 {
   2139    unsigned int i;
   2140 
   2141    for (i = 0; i < 25; i++)
   2142        s[i] = 0;
   2143 
   2144    while (inlen >= r) {
   2145        for (i = 0; i < r / 8; i++)
   2146            s[i] ^= load64(in + 8 * i);
   2147        in += r;
   2148        inlen -= r;
   2149        KeccakF1600_StatePermute(s);
   2150    }
   2151 
   2152    for (i = 0; i < inlen; i++)
   2153        s[i / 8] ^= (uint64_t)in[i] << 8 * (i % 8);
   2154 
   2155    s[i / 8] ^= (uint64_t)p << 8 * (i % 8);
   2156    s[(r - 1) / 8] ^= 1ULL << 63;
   2157 }
   2158 
   2159 /*************************************************
   2160 * Name:        keccak_squeezeblocks
   2161 *
   2162 * Description: Squeeze step of Keccak. Squeezes full blocks of r bytes each.
   2163 *              Modifies the state. Can be called multiple times to keep
   2164 *              squeezing, i.e., is incremental. Assumes zero bytes of current
   2165 *              block have already been squeezed.
   2166 *
   2167 * Arguments:   - uint8_t *out: pointer to output blocks
   2168 *              - size_t nblocks: number of blocks to be squeezed (written to out)
   2169 *              - uint64_t *s: pointer to input/output Keccak state
   2170 *              - unsigned int r: rate in bytes (e.g., 168 for SHAKE128)
   2171 **************************************************/
   2172 static void
   2173 keccak_squeezeblocks(uint8_t *out,
   2174                     size_t nblocks,
   2175                     uint64_t s[25],
   2176                     unsigned int r)
   2177 {
   2178    unsigned int i;
   2179 
   2180    while (nblocks) {
   2181        KeccakF1600_StatePermute(s);
   2182        for (i = 0; i < r / 8; i++)
   2183            store64(out + 8 * i, s[i]);
   2184        out += r;
   2185        nblocks -= 1;
   2186    }
   2187 }
   2188 
   2189 /*************************************************
   2190 * Name:        shake128_init
   2191 *
* Description: Initializes Keccak state for use as SHAKE128 XOF
   2193 *
   2194 * Arguments:   - keccak_state *state: pointer to (uninitialized) Keccak state
   2195 **************************************************/
   2196 void
   2197 shake128_init(keccak_state *state)
   2198 {
   2199    keccak_init(state->s);
   2200    state->pos = 0;
   2201 }
   2202 
   2203 /*************************************************
   2204 * Name:        shake128_absorb
   2205 *
   2206 * Description: Absorb step of the SHAKE128 XOF; incremental.
   2207 *
   2208 * Arguments:   - keccak_state *state: pointer to (initialized) output Keccak state
   2209 *              - const uint8_t *in: pointer to input to be absorbed into s
   2210 *              - size_t inlen: length of input in bytes
   2211 **************************************************/
   2212 void
   2213 shake128_absorb(keccak_state *state, const uint8_t *in, size_t inlen)
   2214 {
   2215    state->pos = keccak_absorb(state->s, state->pos, SHAKE128_RATE, in, inlen);
   2216 }
   2217 
   2218 /*************************************************
   2219 * Name:        shake128_finalize
   2220 *
   2221 * Description: Finalize absorb step of the SHAKE128 XOF.
   2222 *
   2223 * Arguments:   - keccak_state *state: pointer to Keccak state
   2224 **************************************************/
   2225 void
   2226 shake128_finalize(keccak_state *state)
   2227 {
   2228    keccak_finalize(state->s, state->pos, SHAKE128_RATE, 0x1F);
   2229    state->pos = SHAKE128_RATE;
   2230 }
   2231 
   2232 /*************************************************
   2233 * Name:        shake128_squeeze
   2234 *
* Description: Squeeze step of SHAKE128 XOF. Squeezes arbitrarily many
*              bytes. Can be called multiple times to keep squeezing.
*
* Arguments:   - uint8_t *out: pointer to output bytes
*              - size_t outlen: number of bytes to be squeezed (written to output)
*              - keccak_state *state: pointer to input/output Keccak state
   2241 **************************************************/
   2242 void
   2243 shake128_squeeze(uint8_t *out, size_t outlen, keccak_state *state)
   2244 {
   2245    state->pos = keccak_squeeze(out, outlen, state->s, state->pos, SHAKE128_RATE);
   2246 }
   2247 
   2248 /*************************************************
   2249 * Name:        shake128_absorb_once
   2250 *
   2251 * Description: Initialize, absorb into and finalize SHAKE128 XOF; non-incremental.
   2252 *
   2253 * Arguments:   - keccak_state *state: pointer to (uninitialized) output Keccak state
   2254 *              - const uint8_t *in: pointer to input to be absorbed into s
   2255 *              - size_t inlen: length of input in bytes
   2256 **************************************************/
   2257 void
   2258 shake128_absorb_once(keccak_state *state, const uint8_t *in, size_t inlen)
   2259 {
   2260    keccak_absorb_once(state->s, SHAKE128_RATE, in, inlen, 0x1F);
   2261    state->pos = SHAKE128_RATE;
   2262 }
   2263 
   2264 /*************************************************
   2265 * Name:        shake128_squeezeblocks
   2266 *
   2267 * Description: Squeeze step of SHAKE128 XOF. Squeezes full blocks of
   2268 *              SHAKE128_RATE bytes each. Can be called multiple times
   2269 *              to keep squeezing. Assumes new block has not yet been
   2270 *              started (state->pos = SHAKE128_RATE).
   2271 *
   2272 * Arguments:   - uint8_t *out: pointer to output blocks
   2273 *              - size_t nblocks: number of blocks to be squeezed (written to output)
*              - keccak_state *state: pointer to input/output Keccak state
   2275 **************************************************/
   2276 void
   2277 shake128_squeezeblocks(uint8_t *out, size_t nblocks, keccak_state *state)
   2278 {
   2279    keccak_squeezeblocks(out, nblocks, state->s, SHAKE128_RATE);
   2280 }
   2281 
   2282 /*************************************************
   2283 * Name:        shake256_init
   2284 *
* Description: Initializes Keccak state for use as SHAKE256 XOF
   2286 *
   2287 * Arguments:   - keccak_state *state: pointer to (uninitialized) Keccak state
   2288 **************************************************/
   2289 void
   2290 shake256_init(keccak_state *state)
   2291 {
   2292    keccak_init(state->s);
   2293    state->pos = 0;
   2294 }
   2295 
   2296 /*************************************************
   2297 * Name:        shake256_absorb
   2298 *
   2299 * Description: Absorb step of the SHAKE256 XOF; incremental.
   2300 *
   2301 * Arguments:   - keccak_state *state: pointer to (initialized) output Keccak state
   2302 *              - const uint8_t *in: pointer to input to be absorbed into s
   2303 *              - size_t inlen: length of input in bytes
   2304 **************************************************/
   2305 void
   2306 shake256_absorb(keccak_state *state, const uint8_t *in, size_t inlen)
   2307 {
   2308    state->pos = keccak_absorb(state->s, state->pos, SHAKE256_RATE, in, inlen);
   2309 }
   2310 
   2311 /*************************************************
   2312 * Name:        shake256_finalize
   2313 *
   2314 * Description: Finalize absorb step of the SHAKE256 XOF.
   2315 *
   2316 * Arguments:   - keccak_state *state: pointer to Keccak state
   2317 **************************************************/
   2318 void
   2319 shake256_finalize(keccak_state *state)
   2320 {
   2321    keccak_finalize(state->s, state->pos, SHAKE256_RATE, 0x1F);
   2322    state->pos = SHAKE256_RATE;
   2323 }
   2324 
   2325 /*************************************************
   2326 * Name:        shake256_squeeze
   2327 *
* Description: Squeeze step of SHAKE256 XOF. Squeezes arbitrarily many
*              bytes. Can be called multiple times to keep squeezing.
*
* Arguments:   - uint8_t *out: pointer to output bytes
*              - size_t outlen: number of bytes to be squeezed (written to output)
*              - keccak_state *state: pointer to input/output Keccak state
   2334 **************************************************/
   2335 void
   2336 shake256_squeeze(uint8_t *out, size_t outlen, keccak_state *state)
   2337 {
   2338    state->pos = keccak_squeeze(out, outlen, state->s, state->pos, SHAKE256_RATE);
   2339 }
   2340 
   2341 /*************************************************
   2342 * Name:        shake256_absorb_once
   2343 *
   2344 * Description: Initialize, absorb into and finalize SHAKE256 XOF; non-incremental.
   2345 *
   2346 * Arguments:   - keccak_state *state: pointer to (uninitialized) output Keccak state
   2347 *              - const uint8_t *in: pointer to input to be absorbed into s
   2348 *              - size_t inlen: length of input in bytes
   2349 **************************************************/
   2350 void
   2351 shake256_absorb_once(keccak_state *state, const uint8_t *in, size_t inlen)
   2352 {
   2353    keccak_absorb_once(state->s, SHAKE256_RATE, in, inlen, 0x1F);
   2354    state->pos = SHAKE256_RATE;
   2355 }
   2356 
   2357 /*************************************************
   2358 * Name:        shake256_squeezeblocks
   2359 *
   2360 * Description: Squeeze step of SHAKE256 XOF. Squeezes full blocks of
   2361 *              SHAKE256_RATE bytes each. Can be called multiple times
   2362 *              to keep squeezing. Assumes next block has not yet been
   2363 *              started (state->pos = SHAKE256_RATE).
   2364 *
   2365 * Arguments:   - uint8_t *out: pointer to output blocks
   2366 *              - size_t nblocks: number of blocks to be squeezed (written to output)
   2367 *              - keccak_state *s: pointer to input/output Keccak state
   2368 **************************************************/
   2369 void
   2370 shake256_squeezeblocks(uint8_t *out, size_t nblocks, keccak_state *state)
   2371 {
   2372    keccak_squeezeblocks(out, nblocks, state->s, SHAKE256_RATE);
   2373 }
   2374 
   2375 /*************************************************
   2376 * Name:        shake128
   2377 *
   2378 * Description: SHAKE128 XOF with non-incremental API
   2379 *
   2380 * Arguments:   - uint8_t *out: pointer to output
   2381 *              - size_t outlen: requested output length in bytes
   2382 *              - const uint8_t *in: pointer to input
   2383 *              - size_t inlen: length of input in bytes
   2384 **************************************************/
   2385 void
   2386 shake128(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen)
   2387 {
   2388    size_t nblocks;
   2389    keccak_state state;
   2390 
   2391    shake128_absorb_once(&state, in, inlen);
   2392    nblocks = outlen / SHAKE128_RATE;
   2393    shake128_squeezeblocks(out, nblocks, &state);
   2394    outlen -= nblocks * SHAKE128_RATE;
   2395    out += nblocks * SHAKE128_RATE;
   2396    shake128_squeeze(out, outlen, &state);
   2397 }
   2398 
   2399 /*************************************************
   2400 * Name:        shake256
   2401 *
   2402 * Description: SHAKE256 XOF with non-incremental API
   2403 *
   2404 * Arguments:   - uint8_t *out: pointer to output
   2405 *              - size_t outlen: requested output length in bytes
   2406 *              - const uint8_t *in: pointer to input
   2407 *              - size_t inlen: length of input in bytes
   2408 **************************************************/
   2409 void
   2410 shake256(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen)
   2411 {
   2412    size_t nblocks;
   2413    keccak_state state;
   2414 
   2415    shake256_absorb_once(&state, in, inlen);
   2416    nblocks = outlen / SHAKE256_RATE;
   2417    shake256_squeezeblocks(out, nblocks, &state);
   2418    outlen -= nblocks * SHAKE256_RATE;
   2419    out += nblocks * SHAKE256_RATE;
   2420    shake256_squeeze(out, outlen, &state);
   2421 }
   2422 
   2423 /*************************************************
   2424 * Name:        sha3_256
   2425 *
   2426 * Description: SHA3-256 with non-incremental API
   2427 *
   2428 * Arguments:   - uint8_t *h: pointer to output (32 bytes)
   2429 *              - const uint8_t *in: pointer to input
   2430 *              - size_t inlen: length of input in bytes
   2431 **************************************************/
   2432 void
   2433 sha3_256(uint8_t h[32], const uint8_t *in, size_t inlen)
   2434 {
   2435    unsigned int i;
   2436    uint64_t s[25];
   2437 
   2438    keccak_absorb_once(s, SHA3_256_RATE, in, inlen, 0x06);
   2439    KeccakF1600_StatePermute(s);
   2440    for (i = 0; i < 4; i++)
   2441        store64(h + 8 * i, s[i]);
   2442 }
   2443 
   2444 /*************************************************
   2445 * Name:        sha3_512
   2446 *
   2447 * Description: SHA3-512 with non-incremental API
   2448 *
   2449 * Arguments:   - uint8_t *h: pointer to output (64 bytes)
   2450 *              - const uint8_t *in: pointer to input
   2451 *              - size_t inlen: length of input in bytes
   2452 **************************************************/
   2453 void
   2454 sha3_512(uint8_t h[64], const uint8_t *in, size_t inlen)
   2455 {
   2456    unsigned int i;
   2457    uint64_t s[25];
   2458 
   2459    keccak_absorb_once(s, SHA3_512_RATE, in, inlen, 0x06);
   2460    KeccakF1600_StatePermute(s);
   2461    for (i = 0; i < 8; i++)
   2462        store64(h + 8 * i, s[i]);
   2463 }
   2464 /** end: ref/fips202.c **/
   2465 
   2466 /** begin: ref/symmetric-shake.c **/
   2467 /*************************************************
   2468 * Name:        kyber_shake128_absorb
   2469 *
   2470 * Description: Absorb step of the SHAKE128 specialized for the Kyber context.
   2471 *
   2472 * Arguments:   - keccak_state *state: pointer to (uninitialized) output Keccak state
   2473 *              - const uint8_t *seed: pointer to KYBER_SYMBYTES input to be absorbed into state
   2474 *              - uint8_t i: additional byte of input
   2475 *              - uint8_t j: additional byte of input
   2476 **************************************************/
   2477 static void
   2478 kyber_shake128_absorb(keccak_state *state,
   2479                      const uint8_t seed[KYBER_SYMBYTES],
   2480                      uint8_t x,
   2481                      uint8_t y)
   2482 {
   2483    uint8_t extseed[KYBER_SYMBYTES + 2];
   2484 
   2485    memcpy(extseed, seed, KYBER_SYMBYTES);
   2486    extseed[KYBER_SYMBYTES + 0] = x;
   2487    extseed[KYBER_SYMBYTES + 1] = y;
   2488 
   2489    shake128_absorb_once(state, extseed, sizeof(extseed));
   2490 }
   2491 
   2492 /*************************************************
   2493 * Name:        kyber_shake256_prf
   2494 *
   2495 * Description: Usage of SHAKE256 as a PRF, concatenates secret and public input
   2496 *              and then generates outlen bytes of SHAKE256 output
   2497 *
   2498 * Arguments:   - uint8_t *out: pointer to output
   2499 *              - size_t outlen: number of requested output bytes
   2500 *              - const uint8_t *key: pointer to the key (of length KYBER_SYMBYTES)
   2501 *              - uint8_t nonce: single-byte nonce (public PRF input)
   2502 **************************************************/
   2503 static void
   2504 kyber_shake256_prf(uint8_t *out, size_t outlen, const uint8_t key[KYBER_SYMBYTES], uint8_t nonce)
   2505 {
   2506    uint8_t extkey[KYBER_SYMBYTES + 1];
   2507 
   2508    memcpy(extkey, key, KYBER_SYMBYTES);
   2509    extkey[KYBER_SYMBYTES] = nonce;
   2510 
   2511    shake256(out, outlen, extkey, sizeof(extkey));
   2512 }
   2513 /** end: ref/symmetric-shake.c **/
   2514 
   2515 /** begin: ref/kem.c **/
   2516 /*************************************************
   2517 * Name:        crypto_kem_keypair_derand
   2518 *
   2519 * Description: Generates public and private key
   2520 *              for CCA-secure Kyber key encapsulation mechanism
   2521 *
   2522 * Arguments:   - uint8_t *pk: pointer to output public key
   2523 *                (an already allocated array of KYBER_PUBLICKEYBYTES bytes)
   2524 *              - uint8_t *sk: pointer to output private key
   2525 *                (an already allocated array of KYBER_SECRETKEYBYTES bytes)
   2526 *              - uint8_t *coins: pointer to input randomness
   2527 *                (an already allocated array filled with 2*KYBER_SYMBYTES random bytes)
   2528 **
   2529 * Returns 0 (success)
   2530 **************************************************/
   2531 int
   2532 crypto_kem_keypair_derand(uint8_t *pk,
   2533                          uint8_t *sk,
   2534                          const uint8_t *coins)
   2535 {
   2536    size_t i;
   2537    indcpa_keypair_derand(pk, sk, coins);
   2538    for (i = 0; i < KYBER_INDCPA_PUBLICKEYBYTES; i++)
   2539        sk[i + KYBER_INDCPA_SECRETKEYBYTES] = pk[i];
   2540    hash_h(sk + KYBER_SECRETKEYBYTES - 2 * KYBER_SYMBYTES, pk, KYBER_PUBLICKEYBYTES);
   2541    /* Value z for pseudo-random output on reject */
   2542    for (i = 0; i < KYBER_SYMBYTES; i++)
   2543        sk[KYBER_SECRETKEYBYTES - KYBER_SYMBYTES + i] = coins[KYBER_SYMBYTES + i];
   2544    return 0;
   2545 }
   2546 
   2547 /*************************************************
   2548 * Name:        crypto_kem_keypair
   2549 *
   2550 * Description: Generates public and private key
   2551 *              for CCA-secure Kyber key encapsulation mechanism
   2552 *
   2553 * Arguments:   - uint8_t *pk: pointer to output public key
   2554 *                (an already allocated array of KYBER_PUBLICKEYBYTES bytes)
   2555 *              - uint8_t *sk: pointer to output private key
   2556 *                (an already allocated array of KYBER_SECRETKEYBYTES bytes)
   2557 *
   2558 * Returns 0 (success)
   2559 **************************************************/
   2560 int
   2561 crypto_kem_keypair(uint8_t *pk,
   2562                   uint8_t *sk)
   2563 {
   2564    uint8_t coins[2 * KYBER_SYMBYTES];
   2565    randombytes(coins, KYBER_SYMBYTES);
   2566    randombytes(coins + KYBER_SYMBYTES, KYBER_SYMBYTES);
   2567    crypto_kem_keypair_derand(pk, sk, coins);
   2568    return 0;
   2569 }
   2570 
   2571 /*************************************************
   2572 * Name:        crypto_kem_enc_derand
   2573 *
   2574 * Description: Generates cipher text and shared
   2575 *              secret for given public key
   2576 *
   2577 * Arguments:   - uint8_t *ct: pointer to output cipher text
   2578 *                (an already allocated array of KYBER_CIPHERTEXTBYTES bytes)
   2579 *              - uint8_t *ss: pointer to output shared secret
   2580 *                (an already allocated array of KYBER_SSBYTES bytes)
   2581 *              - const uint8_t *pk: pointer to input public key
   2582 *                (an already allocated array of KYBER_PUBLICKEYBYTES bytes)
   2583 *              - const uint8_t *coins: pointer to input randomness
   2584 *                (an already allocated array filled with KYBER_SYMBYTES random bytes)
   2585 **
   2586 * Returns 0 (success)
   2587 **************************************************/
   2588 int
   2589 crypto_kem_enc_derand(uint8_t *ct,
   2590                      uint8_t *ss,
   2591                      const uint8_t *pk,
   2592                      const uint8_t *coins)
   2593 {
   2594    uint8_t buf[2 * KYBER_SYMBYTES];
   2595    /* Will contain key, coins */
   2596    uint8_t kr[2 * KYBER_SYMBYTES];
   2597 
   2598    /* Don't release system RNG output */
   2599    hash_h(buf, coins, KYBER_SYMBYTES);
   2600 
   2601    /* Multitarget countermeasure for coins + contributory KEM */
   2602    hash_h(buf + KYBER_SYMBYTES, pk, KYBER_PUBLICKEYBYTES);
   2603    hash_g(kr, buf, 2 * KYBER_SYMBYTES);
   2604 
   2605    /* coins are in kr+KYBER_SYMBYTES */
   2606    indcpa_enc(ct, buf, pk, kr + KYBER_SYMBYTES);
   2607 
   2608    /* overwrite coins in kr with H(c) */
   2609    hash_h(kr + KYBER_SYMBYTES, ct, KYBER_CIPHERTEXTBYTES);
   2610    /* hash concatenation of pre-k and H(c) to k */
   2611    kdf(ss, kr, 2 * KYBER_SYMBYTES);
   2612    return 0;
   2613 }
   2614 
   2615 /*************************************************
   2616 * Name:        crypto_kem_enc
   2617 *
   2618 * Description: Generates cipher text and shared
   2619 *              secret for given public key
   2620 *
   2621 * Arguments:   - uint8_t *ct: pointer to output cipher text
   2622 *                (an already allocated array of KYBER_CIPHERTEXTBYTES bytes)
   2623 *              - uint8_t *ss: pointer to output shared secret
   2624 *                (an already allocated array of KYBER_SSBYTES bytes)
   2625 *              - const uint8_t *pk: pointer to input public key
   2626 *                (an already allocated array of KYBER_PUBLICKEYBYTES bytes)
   2627 *
   2628 * Returns 0 (success)
   2629 **************************************************/
   2630 int
   2631 crypto_kem_enc(uint8_t *ct,
   2632               uint8_t *ss,
   2633               const uint8_t *pk)
   2634 {
   2635    uint8_t coins[KYBER_SYMBYTES];
   2636    randombytes(coins, KYBER_SYMBYTES);
   2637    crypto_kem_enc_derand(ct, ss, pk, coins);
   2638    return 0;
   2639 }
   2640 
   2641 /*************************************************
   2642 * Name:        crypto_kem_dec
   2643 *
   2644 * Description: Generates shared secret for given
   2645 *              cipher text and private key
   2646 *
   2647 * Arguments:   - uint8_t *ss: pointer to output shared secret
   2648 *                (an already allocated array of KYBER_SSBYTES bytes)
   2649 *              - const uint8_t *ct: pointer to input cipher text
   2650 *                (an already allocated array of KYBER_CIPHERTEXTBYTES bytes)
   2651 *              - const uint8_t *sk: pointer to input private key
   2652 *                (an already allocated array of KYBER_SECRETKEYBYTES bytes)
   2653 *
   2654 * Returns 0.
   2655 *
   2656 * On failure, ss will contain a pseudo-random value.
   2657 **************************************************/
   2658 int
   2659 crypto_kem_dec(uint8_t *ss,
   2660               const uint8_t *ct,
   2661               const uint8_t *sk)
   2662 {
   2663    size_t i;
   2664    int fail;
   2665    uint8_t buf[2 * KYBER_SYMBYTES];
   2666    /* Will contain key, coins */
   2667    uint8_t kr[2 * KYBER_SYMBYTES];
   2668    uint8_t cmp[KYBER_CIPHERTEXTBYTES];
   2669    const uint8_t *pk = sk + KYBER_INDCPA_SECRETKEYBYTES;
   2670 
   2671    indcpa_dec(buf, ct, sk);
   2672 
   2673    /* Multitarget countermeasure for coins + contributory KEM */
   2674    for (i = 0; i < KYBER_SYMBYTES; i++)
   2675        buf[KYBER_SYMBYTES + i] = sk[KYBER_SECRETKEYBYTES - 2 * KYBER_SYMBYTES + i];
   2676    hash_g(kr, buf, 2 * KYBER_SYMBYTES);
   2677 
   2678    /* coins are in kr+KYBER_SYMBYTES */
   2679    indcpa_enc(cmp, buf, pk, kr + KYBER_SYMBYTES);
   2680 
   2681    fail = verify(ct, cmp, KYBER_CIPHERTEXTBYTES);
   2682 
   2683    /* overwrite coins in kr with H(c) */
   2684    hash_h(kr + KYBER_SYMBYTES, ct, KYBER_CIPHERTEXTBYTES);
   2685 
   2686    /* Overwrite pre-k with z on re-encryption failure */
   2687    cmov(kr, sk + KYBER_SECRETKEYBYTES - KYBER_SYMBYTES, KYBER_SYMBYTES, fail);
   2688 
   2689    /* hash concatenation of pre-k and H(c) to k */
   2690    kdf(ss, kr, 2 * KYBER_SYMBYTES);
   2691    return 0;
   2692 }
   2693 /** end: ref/kem.c **/