tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

rsapkcs.c (50921B)


      1 /* This Source Code Form is subject to the terms of the Mozilla Public
      2 * License, v. 2.0. If a copy of the MPL was not distributed with this
      3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
      4 
      5 /*
      6 * RSA PKCS#1 v2.1 (RFC 3447) operations
      7 */
      8 
      9 #ifdef FREEBL_NO_DEPEND
     10 #include "stubs.h"
     11 #endif
     12 
     13 #include "secerr.h"
     14 
     15 #include "blapi.h"
     16 #include "secitem.h"
     17 #include "blapii.h"
     18 
     19 #define RSA_BLOCK_MIN_PAD_LEN 8
     20 #define RSA_BLOCK_FIRST_OCTET 0x00
     21 #define RSA_BLOCK_PRIVATE_PAD_OCTET 0xff
     22 #define RSA_BLOCK_AFTER_PAD_OCTET 0x00
     23 
     24 /*
     25 * RSA block types
     26 *
     27 * The values of RSA_BlockPrivate and RSA_BlockPublic are fixed.
     28 * The value of RSA_BlockRaw isn't fixed by definition, but we are keeping
     29 * the value that NSS has been using in the past.
     30 */
     31 typedef enum {
     32    RSA_BlockPrivate = 1, /* pad for a private-key operation */
     33    RSA_BlockPublic = 2,  /* pad for a public-key operation */
     34    RSA_BlockRaw = 4      /* simply justify the block appropriately */
     35 } RSA_BlockType;
     36 
     37 /* Needed for RSA-PSS functions */
     38 static const unsigned char eightZeros[] = { 0, 0, 0, 0, 0, 0, 0, 0 };
     39 
     40 /* Constant time comparison of a single byte.
     41 * Returns 1 iff a == b, otherwise returns 0.
     42 * Note: For ranges of bytes, use constantTimeCompare.
     43 */
     44 static unsigned char
     45 constantTimeEQ8(unsigned char a, unsigned char b)
     46 {
     47    unsigned char c = ~((a - b) | (b - a));
     48    c >>= 7;
     49    return c;
     50 }
     51 
     52 /* Constant time comparison of a range of bytes.
     53 * Returns 1 iff len bytes of a are identical to len bytes of b, otherwise
     54 * returns 0.
     55 */
     56 static unsigned char
     57 constantTimeCompare(const unsigned char *a,
     58                    const unsigned char *b,
     59                    unsigned int len)
     60 {
     61    unsigned char tmp = 0;
     62    unsigned int i;
     63    for (i = 0; i < len; ++i, ++a, ++b)
     64        tmp |= *a ^ *b;
     65    return constantTimeEQ8(0x00, tmp);
     66 }
     67 
     68 /* Constant time conditional.
     69 * Returns a if c is 1, or b if c is 0. The result is undefined if c is
     70 * not 0 or 1.
     71 */
     72 static unsigned int
     73 constantTimeCondition(unsigned int c,
     74                      unsigned int a,
     75                      unsigned int b)
     76 {
     77    return (~(c - 1) & a) | ((c - 1) & b);
     78 }
     79 
     80 static unsigned int
     81 rsa_modulusLen(SECItem *modulus)
     82 {
     83    if (modulus->len == 0) {
     84        return 0;
     85    }
     86 
     87    unsigned char byteZero = modulus->data[0];
     88    unsigned int modLen = modulus->len - !byteZero;
     89    return modLen;
     90 }
     91 
     92 static unsigned int
     93 rsa_modulusBits(SECItem *modulus)
     94 {
     95    if (modulus->len == 0) {
     96        return 0;
     97    }
     98 
     99    unsigned char byteZero = modulus->data[0];
    100    unsigned int numBits = (modulus->len - 1) * 8;
    101 
    102    if (byteZero == 0 && modulus->len == 1) {
    103        return 0;
    104    }
    105 
    106    if (byteZero == 0) {
    107        numBits -= 8;
    108        byteZero = modulus->data[1];
    109    }
    110 
    111    while (byteZero > 0) {
    112        numBits++;
    113        byteZero >>= 1;
    114    }
    115 
    116    return numBits;
    117 }
    118 
    119 /*
    120 * Format one block of data for public/private key encryption using
    121 * the rules defined in PKCS #1.
    122 */
    123 static unsigned char *
    124 rsa_FormatOneBlock(unsigned modulusLen,
    125                   RSA_BlockType blockType,
    126                   SECItem *data)
    127 {
    128    unsigned char *block;
    129    unsigned char *bp;
    130    unsigned int padLen;
    131    unsigned int i, j;
    132    SECStatus rv;
    133 
    134    block = (unsigned char *)PORT_Alloc(modulusLen);
    135    if (block == NULL)
    136        return NULL;
    137 
    138    bp = block;
    139 
    140    /*
    141     * All RSA blocks start with two octets:
    142     *  0x00 || BlockType
    143     */
    144    *bp++ = RSA_BLOCK_FIRST_OCTET;
    145    *bp++ = (unsigned char)blockType;
    146 
    147    switch (blockType) {
    148 
    149        /*
    150         * Blocks intended for private-key operation.
    151         */
    152        case RSA_BlockPrivate: /* preferred method */
    153            /*
    154             * 0x00 || BT || Pad || 0x00 || ActualData
    155             *   1      1   padLen    1      data->len
    156             * padLen must be at least RSA_BLOCK_MIN_PAD_LEN (8) bytes.
    157             * Pad is either all 0x00 or all 0xff bytes, depending on blockType.
    158             */
    159            padLen = modulusLen - data->len - 3;
    160            PORT_Assert(padLen >= RSA_BLOCK_MIN_PAD_LEN);
    161            if (padLen < RSA_BLOCK_MIN_PAD_LEN) {
    162                PORT_ZFree(block, modulusLen);
    163                return NULL;
    164            }
    165            PORT_Memset(bp, RSA_BLOCK_PRIVATE_PAD_OCTET, padLen);
    166            bp += padLen;
    167            *bp++ = RSA_BLOCK_AFTER_PAD_OCTET;
    168            PORT_Memcpy(bp, data->data, data->len);
    169            break;
    170 
    171        /*
    172         * Blocks intended for public-key operation.
    173         */
    174        case RSA_BlockPublic:
    175            /*
    176             * 0x00 || BT || Pad || 0x00 || ActualData
    177             *   1      1   padLen    1      data->len
    178             * Pad is 8 or more non-zero random bytes.
    179             *
    180             * Build the block left to right.
    181             * Fill the entire block from Pad to the end with random bytes.
    182             * Use the bytes after Pad as a supply of extra random bytes from
    183             * which to find replacements for the zero bytes in Pad.
    184             * If we need more than that, refill the bytes after Pad with
    185             * new random bytes as necessary.
    186             */
    187 
    188            padLen = modulusLen - (data->len + 3);
    189            PORT_Assert(padLen >= RSA_BLOCK_MIN_PAD_LEN);
    190            if (padLen < RSA_BLOCK_MIN_PAD_LEN) {
    191                PORT_ZFree(block, modulusLen);
    192                return NULL;
    193            }
    194            j = modulusLen - 2;
    195            rv = RNG_GenerateGlobalRandomBytes(bp, j);
    196            if (rv == SECSuccess) {
    197                for (i = 0; i < padLen;) {
    198                    unsigned char repl;
    199                    /* Pad with non-zero random data. */
    200                    if (bp[i] != RSA_BLOCK_AFTER_PAD_OCTET) {
    201                        ++i;
    202                        continue;
    203                    }
    204                    if (j <= padLen) {
    205                        rv = RNG_GenerateGlobalRandomBytes(bp + padLen,
    206                                                           modulusLen - (2 + padLen));
    207                        if (rv != SECSuccess)
    208                            break;
    209                        j = modulusLen - 2;
    210                    }
    211                    do {
    212                        repl = bp[--j];
    213                    } while (repl == RSA_BLOCK_AFTER_PAD_OCTET && j > padLen);
    214                    if (repl != RSA_BLOCK_AFTER_PAD_OCTET) {
    215                        bp[i++] = repl;
    216                    }
    217                }
    218            }
    219            if (rv != SECSuccess) {
    220                PORT_ZFree(block, modulusLen);
    221                PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
    222                return NULL;
    223            }
    224            bp += padLen;
    225            *bp++ = RSA_BLOCK_AFTER_PAD_OCTET;
    226            PORT_Memcpy(bp, data->data, data->len);
    227            break;
    228 
    229        default:
    230            PORT_Assert(0);
    231            PORT_ZFree(block, modulusLen);
    232            return NULL;
    233    }
    234 
    235    return block;
    236 }
    237 
    238 /* modulusLen has to be larger than RSA_BLOCK_MIN_PAD_LEN + 3, and data has to be smaller than modulus - (RSA_BLOCK_MIN_PAD_LEN + 3) */
    239 static SECStatus
    240 rsa_FormatBlock(SECItem *result,
    241                unsigned modulusLen,
    242                RSA_BlockType blockType,
    243                SECItem *data)
    244 {
    245    switch (blockType) {
    246        case RSA_BlockPrivate:
    247        case RSA_BlockPublic:
    248            /*
    249             * 0x00 || BT || Pad || 0x00 || ActualData
    250             *
    251             * The "3" below is the first octet + the second octet + the 0x00
    252             * octet that always comes just before the ActualData.
    253             */
    254            if (modulusLen < (3 + RSA_BLOCK_MIN_PAD_LEN) || data->len > (modulusLen - (3 + RSA_BLOCK_MIN_PAD_LEN))) {
    255                return SECFailure;
    256            }
    257            result->data = rsa_FormatOneBlock(modulusLen, blockType, data);
    258            if (result->data == NULL) {
    259                result->len = 0;
    260                return SECFailure;
    261            }
    262            result->len = modulusLen;
    263 
    264            break;
    265 
    266        case RSA_BlockRaw:
    267            /*
    268             * Pad || ActualData
    269             * Pad is zeros. The application is responsible for recovering
    270             * the actual data.
    271             */
    272            if (data->len > modulusLen) {
    273                return SECFailure;
    274            }
    275            result->data = (unsigned char *)PORT_ZAlloc(modulusLen);
    276            result->len = modulusLen;
    277            PORT_Memcpy(result->data + (modulusLen - data->len),
    278                        data->data, data->len);
    279            break;
    280 
    281        default:
    282            PORT_Assert(0);
    283            result->data = NULL;
    284            result->len = 0;
    285            return SECFailure;
    286    }
    287 
    288    return SECSuccess;
    289 }
    290 
    291 /*
    292 * Mask generation function MGF1 as defined in PKCS #1 v2.1 / RFC 3447.
    293 */
    294 static SECStatus
    295 MGF1(HASH_HashType hashAlg,
    296     unsigned char *mask,
    297     unsigned int maskLen,
    298     const unsigned char *mgfSeed,
    299     unsigned int mgfSeedLen)
    300 {
    301    unsigned int digestLen;
    302    PRUint32 counter;
    303    PRUint32 rounds;
    304    unsigned char *tempHash;
    305    unsigned char *temp;
    306    const SECHashObject *hash;
    307    void *hashContext;
    308    unsigned char C[4];
    309    SECStatus rv = SECSuccess;
    310 
    311    hash = HASH_GetRawHashObject(hashAlg);
    312    if (hash == NULL) {
    313        return SECFailure;
    314    }
    315 
    316    hashContext = (*hash->create)();
    317    rounds = (maskLen + hash->length - 1) / hash->length;
    318    for (counter = 0; counter < rounds; counter++) {
    319        C[0] = (unsigned char)((counter >> 24) & 0xff);
    320        C[1] = (unsigned char)((counter >> 16) & 0xff);
    321        C[2] = (unsigned char)((counter >> 8) & 0xff);
    322        C[3] = (unsigned char)(counter & 0xff);
    323 
    324        /* This could be optimized when the clone functions in
    325         * rawhash.c are implemented. */
    326        (*hash->begin)(hashContext);
    327        (*hash->update)(hashContext, mgfSeed, mgfSeedLen);
    328        (*hash->update)(hashContext, C, sizeof C);
    329 
    330        tempHash = mask + counter * hash->length;
    331        if (counter != (rounds - 1)) {
    332            (*hash->end)(hashContext, tempHash, &digestLen, hash->length);
    333        } else { /* we're in the last round and need to cut the hash */
    334            temp = (unsigned char *)PORT_Alloc(hash->length);
    335            if (!temp) {
    336                rv = SECFailure;
    337                goto done;
    338            }
    339            (*hash->end)(hashContext, temp, &digestLen, hash->length);
    340            PORT_Memcpy(tempHash, temp, maskLen - counter * hash->length);
    341            PORT_Free(temp);
    342        }
    343    }
    344 
    345 done:
    346    (*hash->destroy)(hashContext, PR_TRUE);
    347    return rv;
    348 }
    349 
    350 /* XXX Doesn't set error code */
    351 SECStatus
    352 RSA_SignRaw(RSAPrivateKey *key,
    353            unsigned char *output,
    354            unsigned int *outputLen,
    355            unsigned int maxOutputLen,
    356            const unsigned char *data,
    357            unsigned int dataLen)
    358 {
    359    SECStatus rv = SECSuccess;
    360    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
    361    SECItem formatted;
    362    SECItem unformatted;
    363 
    364    if (maxOutputLen < modulusLen)
    365        return SECFailure;
    366 
    367    unformatted.len = dataLen;
    368    unformatted.data = (unsigned char *)data;
    369    formatted.data = NULL;
    370    rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockRaw, &unformatted);
    371    if (rv != SECSuccess)
    372        goto done;
    373 
    374    rv = RSA_PrivateKeyOpDoubleChecked(key, output, formatted.data);
    375    *outputLen = modulusLen;
    376 
    377 done:
    378    if (formatted.data != NULL)
    379        PORT_ZFree(formatted.data, modulusLen);
    380    return rv;
    381 }
    382 
    383 /* XXX Doesn't set error code */
    384 SECStatus
    385 RSA_CheckSignRaw(RSAPublicKey *key,
    386                 const unsigned char *sig,
    387                 unsigned int sigLen,
    388                 const unsigned char *hash,
    389                 unsigned int hashLen)
    390 {
    391    SECStatus rv;
    392    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
    393    unsigned char *buffer;
    394 
    395    if (sigLen != modulusLen)
    396        goto failure;
    397    if (hashLen > modulusLen)
    398        goto failure;
    399 
    400    buffer = (unsigned char *)PORT_Alloc(modulusLen + 1);
    401    if (!buffer)
    402        goto failure;
    403 
    404    rv = RSA_PublicKeyOp(key, buffer, sig);
    405    if (rv != SECSuccess)
    406        goto loser;
    407 
    408    /*
    409     * make sure we get the same results
    410     */
    411    /* XXX(rsleevi): Constant time */
    412    /* NOTE: should we verify the leading zeros? */
    413    if (PORT_Memcmp(buffer + (modulusLen - hashLen), hash, hashLen) != 0)
    414        goto loser;
    415 
    416    PORT_Free(buffer);
    417    return SECSuccess;
    418 
    419 loser:
    420    PORT_Free(buffer);
    421 failure:
    422    return SECFailure;
    423 }
    424 
    425 /* XXX Doesn't set error code */
    426 SECStatus
    427 RSA_CheckSignRecoverRaw(RSAPublicKey *key,
    428                        unsigned char *data,
    429                        unsigned int *dataLen,
    430                        unsigned int maxDataLen,
    431                        const unsigned char *sig,
    432                        unsigned int sigLen)
    433 {
    434    SECStatus rv;
    435    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
    436 
    437    if (sigLen != modulusLen)
    438        goto failure;
    439    if (maxDataLen < modulusLen)
    440        goto failure;
    441 
    442    rv = RSA_PublicKeyOp(key, data, sig);
    443    if (rv != SECSuccess)
    444        goto failure;
    445 
    446    *dataLen = modulusLen;
    447    return SECSuccess;
    448 
    449 failure:
    450    return SECFailure;
    451 }
    452 
    453 /* XXX Doesn't set error code */
    454 SECStatus
    455 RSA_EncryptRaw(RSAPublicKey *key,
    456               unsigned char *output,
    457               unsigned int *outputLen,
    458               unsigned int maxOutputLen,
    459               const unsigned char *input,
    460               unsigned int inputLen)
    461 {
    462    SECStatus rv;
    463    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
    464    SECItem formatted;
    465    SECItem unformatted;
    466 
    467    formatted.data = NULL;
    468    if (maxOutputLen < modulusLen)
    469        goto failure;
    470 
    471    unformatted.len = inputLen;
    472    unformatted.data = (unsigned char *)input;
    473    formatted.data = NULL;
    474    rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockRaw, &unformatted);
    475    if (rv != SECSuccess)
    476        goto failure;
    477 
    478    rv = RSA_PublicKeyOp(key, output, formatted.data);
    479    if (rv != SECSuccess)
    480        goto failure;
    481 
    482    PORT_ZFree(formatted.data, modulusLen);
    483    *outputLen = modulusLen;
    484    return SECSuccess;
    485 
    486 failure:
    487    if (formatted.data != NULL)
    488        PORT_ZFree(formatted.data, modulusLen);
    489    return SECFailure;
    490 }
    491 
    492 /* XXX Doesn't set error code */
    493 SECStatus
    494 RSA_DecryptRaw(RSAPrivateKey *key,
    495               unsigned char *output,
    496               unsigned int *outputLen,
    497               unsigned int maxOutputLen,
    498               const unsigned char *input,
    499               unsigned int inputLen)
    500 {
    501    SECStatus rv;
    502    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
    503 
    504    if (modulusLen > maxOutputLen)
    505        goto failure;
    506    if (inputLen != modulusLen)
    507        goto failure;
    508 
    509    rv = RSA_PrivateKeyOp(key, output, input);
    510    if (rv != SECSuccess)
    511        goto failure;
    512 
    513    *outputLen = modulusLen;
    514    return SECSuccess;
    515 
    516 failure:
    517    return SECFailure;
    518 }
    519 
    520 /*
    521 * Decodes an EME-OAEP encoded block, validating the encoding in constant
    522 * time.
    523 * Described in RFC 3447, section 7.1.2.
    524 * input contains the encoded block, after decryption.
    525 * label is the optional value L that was associated with the message.
    526 * On success, the original message and message length will be stored in
    527 * output and outputLen.
    528 */
    529 static SECStatus
    530 eme_oaep_decode(unsigned char *output,
    531                unsigned int *outputLen,
    532                unsigned int maxOutputLen,
    533                const unsigned char *input,
    534                unsigned int inputLen,
    535                HASH_HashType hashAlg,
    536                HASH_HashType maskHashAlg,
    537                const unsigned char *label,
    538                unsigned int labelLen)
    539 {
    540    const SECHashObject *hash;
    541    void *hashContext;
    542    SECStatus rv = SECFailure;
    543    unsigned char labelHash[HASH_LENGTH_MAX];
    544    unsigned int i;
    545    unsigned int maskLen;
    546    unsigned int paddingOffset;
    547    unsigned char *mask = NULL;
    548    unsigned char *tmpOutput = NULL;
    549    unsigned char isGood;
    550    unsigned char foundPaddingEnd;
    551 
    552    hash = HASH_GetRawHashObject(hashAlg);
    553 
    554    /* 1.c */
    555    if (inputLen < (hash->length * 2) + 2) {
    556        PORT_SetError(SEC_ERROR_INPUT_LEN);
    557        return SECFailure;
    558    }
    559 
    560    /* Step 3.a - Generate lHash */
    561    hashContext = (*hash->create)();
    562    if (hashContext == NULL) {
    563        PORT_SetError(SEC_ERROR_NO_MEMORY);
    564        return SECFailure;
    565    }
    566    (*hash->begin)(hashContext);
    567    if (labelLen > 0)
    568        (*hash->update)(hashContext, label, labelLen);
    569    (*hash->end)(hashContext, labelHash, &i, sizeof(labelHash));
    570    (*hash->destroy)(hashContext, PR_TRUE);
    571 
    572    tmpOutput = (unsigned char *)PORT_Alloc(inputLen);
    573    if (tmpOutput == NULL) {
    574        PORT_SetError(SEC_ERROR_NO_MEMORY);
    575        goto done;
    576    }
    577 
    578    maskLen = inputLen - hash->length - 1;
    579    mask = (unsigned char *)PORT_Alloc(maskLen);
    580    if (mask == NULL) {
    581        PORT_SetError(SEC_ERROR_NO_MEMORY);
    582        goto done;
    583    }
    584 
    585    PORT_Memcpy(tmpOutput, input, inputLen);
    586 
    587    /* 3.c - Generate seedMask */
    588    MGF1(maskHashAlg, mask, hash->length, &tmpOutput[1 + hash->length],
    589         inputLen - hash->length - 1);
    590    /* 3.d - Unmask seed */
    591    for (i = 0; i < hash->length; ++i)
    592        tmpOutput[1 + i] ^= mask[i];
    593 
    594    /* 3.e - Generate dbMask */
    595    MGF1(maskHashAlg, mask, maskLen, &tmpOutput[1], hash->length);
    596    /* 3.f - Unmask DB */
    597    for (i = 0; i < maskLen; ++i)
    598        tmpOutput[1 + hash->length + i] ^= mask[i];
    599 
    600    /* 3.g - Compare Y, lHash, and PS in constant time
    601     * Warning: This code is timing dependent and must not disclose which of
    602     * these were invalid.
    603     */
    604    paddingOffset = 0;
    605    isGood = 1;
    606    foundPaddingEnd = 0;
    607 
    608    /* Compare Y */
    609    isGood &= constantTimeEQ8(0x00, tmpOutput[0]);
    610 
    611    /* Compare lHash and lHash' */
    612    isGood &= constantTimeCompare(&labelHash[0],
    613                                  &tmpOutput[1 + hash->length],
    614                                  hash->length);
    615 
    616    /* Compare that the padding is zero or more zero octets, followed by a
    617     * 0x01 octet */
    618    for (i = 1 + (hash->length * 2); i < inputLen; ++i) {
    619        unsigned char isZero = constantTimeEQ8(0x00, tmpOutput[i]);
    620        unsigned char isOne = constantTimeEQ8(0x01, tmpOutput[i]);
    621        /* non-constant time equivalent:
    622         * if (tmpOutput[i] == 0x01 && !foundPaddingEnd)
    623         *     paddingOffset = i;
    624         */
    625        paddingOffset = constantTimeCondition(isOne & ~foundPaddingEnd, i,
    626                                              paddingOffset);
    627        /* non-constant time equivalent:
    628         * if (tmpOutput[i] == 0x01)
    629         *    foundPaddingEnd = true;
    630         *
    631         * Note: This may yield false positives, as it will be set whenever
    632         * a 0x01 byte is encountered. If there was bad padding (eg:
    633         * 0x03 0x02 0x01), foundPaddingEnd will still be set to true, and
    634         * paddingOffset will still be set to 2.
    635         */
    636        foundPaddingEnd = constantTimeCondition(isOne, 1, foundPaddingEnd);
    637        /* non-constant time equivalent:
    638         * if (tmpOutput[i] != 0x00 && tmpOutput[i] != 0x01 &&
    639         *     !foundPaddingEnd) {
    640         *    isGood = false;
    641         * }
    642         *
    643         * Note: This may yield false positives, as a message (and padding)
    644         * that is entirely zeros will result in isGood still being true. Thus
    645         * it's necessary to check foundPaddingEnd is positive below.
    646         */
    647        isGood = constantTimeCondition(~foundPaddingEnd & ~isZero, 0, isGood);
    648    }
    649 
    650    /* While both isGood and foundPaddingEnd may have false positives, they
    651     * cannot BOTH have false positives. If both are not true, then an invalid
    652     * message was received. Note, this comparison must still be done in constant
    653     * time so as not to leak either condition.
    654     */
    655    if (!(isGood & foundPaddingEnd)) {
    656        PORT_SetError(SEC_ERROR_BAD_DATA);
    657        goto done;
    658    }
    659 
    660    /* End timing dependent code */
    661 
    662    ++paddingOffset; /* Skip the 0x01 following the end of PS */
    663 
    664    *outputLen = inputLen - paddingOffset;
    665    if (*outputLen > maxOutputLen) {
    666        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
    667        goto done;
    668    }
    669 
    670    if (*outputLen)
    671        PORT_Memcpy(output, &tmpOutput[paddingOffset], *outputLen);
    672    rv = SECSuccess;
    673 
    674 done:
    675    if (mask)
    676        PORT_ZFree(mask, maskLen);
    677    if (tmpOutput)
    678        PORT_ZFree(tmpOutput, inputLen);
    679    return rv;
    680 }
    681 
    682 /*
    683 * Generate an EME-OAEP encoded block for encryption
    684 * Described in RFC 3447, section 7.1.1
    685 * We use input instead of M for the message to be encrypted
    686 * label is the optional value L to be associated with the message.
    687 */
    688 static SECStatus
    689 eme_oaep_encode(unsigned char *em,
    690                unsigned int emLen,
    691                const unsigned char *input,
    692                unsigned int inputLen,
    693                HASH_HashType hashAlg,
    694                HASH_HashType maskHashAlg,
    695                const unsigned char *label,
    696                unsigned int labelLen,
    697                const unsigned char *seed,
    698                unsigned int seedLen)
    699 {
    700    const SECHashObject *hash;
    701    void *hashContext;
    702    SECStatus rv;
    703    unsigned char *mask;
    704    unsigned int reservedLen;
    705    unsigned int dbMaskLen;
    706    unsigned int i;
    707 
    708    hash = HASH_GetRawHashObject(hashAlg);
    709    PORT_Assert(seed == NULL || seedLen == hash->length);
    710 
    711    /* Step 1.b */
    712    reservedLen = (2 * hash->length) + 2;
    713    if (emLen < reservedLen || inputLen > (emLen - reservedLen)) {
    714        PORT_SetError(SEC_ERROR_INPUT_LEN);
    715        return SECFailure;
    716    }
    717 
    718    /*
    719     * From RFC 3447, Section 7.1
    720     *                      +----------+---------+-------+
    721     *                 DB = |  lHash   |    PS   |   M   |
    722     *                      +----------+---------+-------+
    723     *                                     |
    724     *           +----------+              V
    725     *           |   seed   |--> MGF ---> xor
    726     *           +----------+              |
    727     *                 |                   |
    728     *        +--+     V                   |
    729     *        |00|    xor <----- MGF <-----|
    730     *        +--+     |                   |
    731     *          |      |                   |
    732     *          V      V                   V
    733     *        +--+----------+----------------------------+
    734     *  EM =  |00|maskedSeed|          maskedDB          |
    735     *        +--+----------+----------------------------+
    736     *
    737     * We use mask to hold the result of the MGF functions, and all other
    738     * values are generated in their final resting place.
    739     */
    740    *em = 0x00;
    741 
    742    /* Step 2.a - Generate lHash */
    743    hashContext = (*hash->create)();
    744    if (hashContext == NULL) {
    745        PORT_SetError(SEC_ERROR_NO_MEMORY);
    746        return SECFailure;
    747    }
    748    (*hash->begin)(hashContext);
    749    if (labelLen > 0)
    750        (*hash->update)(hashContext, label, labelLen);
    751    (*hash->end)(hashContext, &em[1 + hash->length], &i, hash->length);
    752    (*hash->destroy)(hashContext, PR_TRUE);
    753 
    754    /* Step 2.b - Generate PS */
    755    if (emLen - reservedLen - inputLen > 0) {
    756        PORT_Memset(em + 1 + (hash->length * 2), 0x00,
    757                    emLen - reservedLen - inputLen);
    758    }
    759 
    760    /* Step 2.c. - Generate DB
    761     * DB = lHash || PS || 0x01 || M
    762     * Note that PS and lHash have already been placed into em at their
    763     * appropriate offsets. This just copies M into place
    764     */
    765    em[emLen - inputLen - 1] = 0x01;
    766    if (inputLen)
    767        PORT_Memcpy(em + emLen - inputLen, input, inputLen);
    768 
    769    if (seed == NULL) {
    770        /* Step 2.d - Generate seed */
    771        rv = RNG_GenerateGlobalRandomBytes(em + 1, hash->length);
    772        if (rv != SECSuccess) {
    773            return rv;
    774        }
    775    } else {
    776        /* For Known Answer Tests, copy the supplied seed. */
    777        PORT_Memcpy(em + 1, seed, seedLen);
    778    }
    779 
    780    /* Step 2.e - Generate dbMask*/
    781    dbMaskLen = emLen - hash->length - 1;
    782    mask = (unsigned char *)PORT_Alloc(dbMaskLen);
    783    if (mask == NULL) {
    784        PORT_SetError(SEC_ERROR_NO_MEMORY);
    785        return SECFailure;
    786    }
    787    MGF1(maskHashAlg, mask, dbMaskLen, em + 1, hash->length);
    788    /* Step 2.f - Compute maskedDB*/
    789    for (i = 0; i < dbMaskLen; ++i)
    790        em[1 + hash->length + i] ^= mask[i];
    791 
    792    /* Step 2.g - Generate seedMask */
    793    MGF1(maskHashAlg, mask, hash->length, &em[1 + hash->length], dbMaskLen);
    794    /* Step 2.h - Compute maskedSeed */
    795    for (i = 0; i < hash->length; ++i)
    796        em[1 + i] ^= mask[i];
    797 
    798    PORT_ZFree(mask, dbMaskLen);
    799    return SECSuccess;
    800 }
    801 
    802 SECStatus
    803 RSA_EncryptOAEP(RSAPublicKey *key,
    804                HASH_HashType hashAlg,
    805                HASH_HashType maskHashAlg,
    806                const unsigned char *label,
    807                unsigned int labelLen,
    808                const unsigned char *seed,
    809                unsigned int seedLen,
    810                unsigned char *output,
    811                unsigned int *outputLen,
    812                unsigned int maxOutputLen,
    813                const unsigned char *input,
    814                unsigned int inputLen)
    815 {
    816    SECStatus rv = SECFailure;
    817    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
    818    unsigned char *oaepEncoded = NULL;
    819 
    820    if (maxOutputLen < modulusLen) {
    821        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
    822        return SECFailure;
    823    }
    824 
    825    if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) {
    826        PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
    827        return SECFailure;
    828    }
    829 
    830    if ((labelLen == 0 && label != NULL) ||
    831        (labelLen > 0 && label == NULL)) {
    832        PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
    833        return SECFailure;
    834    }
    835 
    836    oaepEncoded = (unsigned char *)PORT_Alloc(modulusLen);
    837    if (oaepEncoded == NULL) {
    838        PORT_SetError(SEC_ERROR_NO_MEMORY);
    839        return SECFailure;
    840    }
    841    rv = eme_oaep_encode(oaepEncoded, modulusLen, input, inputLen,
    842                         hashAlg, maskHashAlg, label, labelLen, seed, seedLen);
    843    if (rv != SECSuccess)
    844        goto done;
    845 
    846    rv = RSA_PublicKeyOp(key, output, oaepEncoded);
    847    if (rv != SECSuccess)
    848        goto done;
    849    *outputLen = modulusLen;
    850 
    851 done:
    852    PORT_Free(oaepEncoded);
    853    return rv;
    854 }
    855 
    856 SECStatus
    857 RSA_DecryptOAEP(RSAPrivateKey *key,
    858                HASH_HashType hashAlg,
    859                HASH_HashType maskHashAlg,
    860                const unsigned char *label,
    861                unsigned int labelLen,
    862                unsigned char *output,
    863                unsigned int *outputLen,
    864                unsigned int maxOutputLen,
    865                const unsigned char *input,
    866                unsigned int inputLen)
    867 {
    868    SECStatus rv = SECFailure;
    869    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
    870    unsigned char *oaepEncoded = NULL;
    871 
    872    if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) {
    873        PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
    874        return SECFailure;
    875    }
    876 
    877    if (inputLen != modulusLen) {
    878        PORT_SetError(SEC_ERROR_INPUT_LEN);
    879        return SECFailure;
    880    }
    881 
    882    if ((labelLen == 0 && label != NULL) ||
    883        (labelLen > 0 && label == NULL)) {
    884        PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
    885        return SECFailure;
    886    }
    887 
    888    oaepEncoded = (unsigned char *)PORT_Alloc(modulusLen);
    889    if (oaepEncoded == NULL) {
    890        PORT_SetError(SEC_ERROR_NO_MEMORY);
    891        return SECFailure;
    892    }
    893 
    894    rv = RSA_PrivateKeyOpDoubleChecked(key, oaepEncoded, input);
    895    if (rv != SECSuccess) {
    896        goto done;
    897    }
    898    rv = eme_oaep_decode(output, outputLen, maxOutputLen, oaepEncoded,
    899                         modulusLen, hashAlg, maskHashAlg, label,
    900                         labelLen);
    901 
    902 done:
    903    if (oaepEncoded)
    904        PORT_ZFree(oaepEncoded, modulusLen);
    905    return rv;
    906 }
    907 
    908 /* XXX Doesn't set error code */
    909 SECStatus
    910 RSA_EncryptBlock(RSAPublicKey *key,
    911                 unsigned char *output,
    912                 unsigned int *outputLen,
    913                 unsigned int maxOutputLen,
    914                 const unsigned char *input,
    915                 unsigned int inputLen)
    916 {
    917    SECStatus rv;
    918    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
    919    SECItem formatted;
    920    SECItem unformatted;
    921 
    922    formatted.data = NULL;
    923    if (maxOutputLen < modulusLen)
    924        goto failure;
    925 
    926    unformatted.len = inputLen;
    927    unformatted.data = (unsigned char *)input;
    928    formatted.data = NULL;
    929    rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockPublic,
    930                         &unformatted);
    931    if (rv != SECSuccess)
    932        goto failure;
    933 
    934    rv = RSA_PublicKeyOp(key, output, formatted.data);
    935    if (rv != SECSuccess)
    936        goto failure;
    937 
    938    PORT_ZFree(formatted.data, modulusLen);
    939    *outputLen = modulusLen;
    940    return SECSuccess;
    941 
    942 failure:
    943    if (formatted.data != NULL)
    944        PORT_ZFree(formatted.data, modulusLen);
    945    return SECFailure;
    946 }
    947 
    948 static HMACContext *
    949 rsa_GetHMACContext(const SECHashObject *hash, RSAPrivateKey *key,
    950                   const unsigned char *input, unsigned int inputLen)
    951 {
    952    unsigned char keyHash[HASH_LENGTH_MAX];
    953    void *hashContext;
    954    HMACContext *hmac = NULL;
    955    unsigned int privKeyLen = key->privateExponent.len;
    956    unsigned int keyLen;
    957    SECStatus rv;
    958 
    959    /* first get the key hash (should store in the key structure) */
    960    PORT_Memset(keyHash, 0, sizeof(keyHash));
    961    hashContext = (*hash->create)();
    962    if (hashContext == NULL) {
    963        return NULL;
    964    }
    965    (*hash->begin)(hashContext);
    966    if (privKeyLen < inputLen) {
    967        int padLen = inputLen - privKeyLen;
    968        while (padLen > sizeof(keyHash)) {
    969            (*hash->update)(hashContext, keyHash, sizeof(keyHash));
    970            padLen -= sizeof(keyHash);
    971        }
    972        (*hash->update)(hashContext, keyHash, padLen);
    973    }
    974    (*hash->update)(hashContext, key->privateExponent.data, privKeyLen);
    975    (*hash->end)(hashContext, keyHash, &keyLen, sizeof(keyHash));
    976    (*hash->destroy)(hashContext, PR_TRUE);
    977 
    978    /* now create the hmac key */
    979    hmac = HMAC_Create(hash, keyHash, keyLen, PR_TRUE);
    980    if (hmac == NULL) {
    981        PORT_SafeZero(keyHash, sizeof(keyHash));
    982        return NULL;
    983    }
    984    HMAC_Begin(hmac);
    985    HMAC_Update(hmac, input, inputLen);
    986    rv = HMAC_Finish(hmac, keyHash, &keyLen, sizeof(keyHash));
    987    if (rv != SECSuccess) {
    988        PORT_SafeZero(keyHash, sizeof(keyHash));
    989        HMAC_Destroy(hmac, PR_TRUE);
    990        return NULL;
    991    }
    992    /* Finally set the new key into the hash context. We
    993     * reuse the original context allocated above so we don't
    994     * need to allocate and free another one */
    995    rv = HMAC_ReInit(hmac, hash, keyHash, keyLen, PR_TRUE);
    996    PORT_SafeZero(keyHash, sizeof(keyHash));
    997    if (rv != SECSuccess) {
    998        HMAC_Destroy(hmac, PR_TRUE);
    999        return NULL;
   1000    }
   1001 
   1002    return hmac;
   1003 }
   1004 
   1005 static SECStatus
   1006 rsa_HMACPrf(HMACContext *hmac, const char *label, int labelLen,
   1007            int hashLength, unsigned char *output, int length)
   1008 {
   1009    unsigned char iterator[2] = { 0, 0 };
   1010    unsigned char encodedLen[2] = { 0, 0 };
   1011    unsigned char hmacLast[HASH_LENGTH_MAX];
   1012    unsigned int left = length;
   1013    unsigned int hashReturn;
   1014    SECStatus rv = SECSuccess;
   1015 
   1016    /* encodedLen is in bits, length is in bytes, thus the shifts
   1017     * do an implied multiply by 8 */
   1018    encodedLen[0] = (length >> 5) & 0xff;
   1019    encodedLen[1] = (length << 3) & 0xff;
   1020 
   1021    while (left > hashLength) {
   1022        HMAC_Begin(hmac);
   1023        HMAC_Update(hmac, iterator, 2);
   1024        HMAC_Update(hmac, (const unsigned char *)label, labelLen);
   1025        HMAC_Update(hmac, encodedLen, 2);
   1026        rv = HMAC_Finish(hmac, output, &hashReturn, hashLength);
   1027        if (rv != SECSuccess) {
   1028            return rv;
   1029        }
   1030        iterator[1]++;
   1031        if (iterator[1] == 0)
   1032            iterator[0]++;
   1033        left -= hashLength;
   1034        output += hashLength;
   1035    }
   1036    if (left) {
   1037        HMAC_Begin(hmac);
   1038        HMAC_Update(hmac, iterator, 2);
   1039        HMAC_Update(hmac, (const unsigned char *)label, labelLen);
   1040        HMAC_Update(hmac, encodedLen, 2);
   1041        rv = HMAC_Finish(hmac, hmacLast, &hashReturn, sizeof(hmacLast));
   1042        if (rv != SECSuccess) {
   1043            return rv;
   1044        }
   1045        PORT_Memcpy(output, hmacLast, left);
   1046        PORT_SafeZero(hmacLast, sizeof(hmacLast));
   1047    }
   1048    return rv;
   1049 }
   1050 
   1051 /* This function takes a 16-bit input number and
   1052 * creates the smallest mask which covers
   1053 * the whole number. Examples:
   1054 *     0x81 -> 0xff
   1055 *     0x1af -> 0x1ff
   1056 *     0x4d1 -> 0x7ff
   1057 */
   1058 static int
   1059 makeMask16(int len)
   1060 {
   1061    // or the high bit in each bit location
   1062    len |= (len >> 1);
   1063    len |= (len >> 2);
   1064    len |= (len >> 4);
   1065    len |= (len >> 8);
   1066    return len;
   1067 }
   1068 
   1069 #define STRING_AND_LENGTH(s) s, sizeof(s) - 1
   1070 static int
   1071 rsa_GetErrorLength(HMACContext *hmac, int hashLen, int maxLegalLen)
   1072 {
   1073    unsigned char out[128 * 2];
   1074    unsigned char *outp;
   1075    int outLength = 0;
   1076    int lengthMask;
   1077    SECStatus rv;
   1078 
   1079    lengthMask = makeMask16(maxLegalLen);
   1080    rv = rsa_HMACPrf(hmac, STRING_AND_LENGTH("length"), hashLen,
   1081                     out, sizeof(out));
   1082    if (rv != SECSuccess) {
   1083        return -1;
   1084    }
   1085    for (outp = out; outp < out + sizeof(out); outp += 2) {
   1086        int candidate = outp[0] << 8 | outp[1];
   1087        candidate = candidate & lengthMask;
   1088        outLength = PORT_CT_SEL(PORT_CT_LT(candidate, maxLegalLen),
   1089                                candidate, outLength);
   1090    }
   1091    PORT_SafeZero(out, sizeof(out));
   1092    return outLength;
   1093 }
   1094 
   1095 /*
   1096 * This function can only fail in environmental cases: Programming errors
   1097 * and out of memory situations. It can't fail if the keys are valid and
   1098 * the inputs are the proper size. If the actual RSA decryption fails, a
   1099 * fake value and a fake length, both of which have already been generated
   1100 * based on the key and input, are returned.
   1101 * Applications are expected to detect decryption failures based on the fact
   1102 * that the decrypted value (usually a key) doesn't validate. The prevents
   1103 * Blecheinbaucher style attacks against the key. */
   1104 SECStatus
   1105 RSA_DecryptBlock(RSAPrivateKey *key,
   1106                 unsigned char *output,
   1107                 unsigned int *outputLen,
   1108                 unsigned int maxOutputLen,
   1109                 const unsigned char *input,
   1110                 unsigned int inputLen)
   1111 {
   1112    SECStatus rv;
   1113    PRUint32 fail;
   1114    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
   1115    unsigned int i;
   1116    unsigned char *buffer = NULL;
   1117    unsigned char *errorBuffer = NULL;
   1118    unsigned char *bp = NULL;
   1119    unsigned char *ep = NULL;
   1120    unsigned int outLen = modulusLen;
   1121    unsigned int maxLegalLen = modulusLen - 10;
   1122    unsigned int errorLength;
   1123    const SECHashObject *hashObj;
   1124    HMACContext *hmac = NULL;
   1125 
   1126    /* failures in the top section indicate failures in the environment
   1127     * (memory) or the library. OK to return errors in these cases because
   1128     * it doesn't provide any oracle information to attackers. */
   1129    if (inputLen != modulusLen || modulusLen < 10) {
   1130        PORT_SetError(SEC_ERROR_INVALID_ARGS);
   1131        return SECFailure;
   1132    }
   1133 
   1134    /* Allocate enough space to decrypt */
   1135    buffer = PORT_ZAlloc(modulusLen);
   1136    if (!buffer) {
   1137        goto loser;
   1138    }
   1139    errorBuffer = PORT_ZAlloc(modulusLen);
   1140    if (!errorBuffer) {
   1141        goto loser;
   1142    }
   1143    hashObj = HASH_GetRawHashObject(HASH_AlgSHA256);
   1144    if (hashObj == NULL) {
   1145        goto loser;
   1146    }
   1147 
   1148    /* calculate the values to return in the error case rather than
   1149     * the actual returned values. This data is the same for the
   1150     * same input and private key. */
   1151    hmac = rsa_GetHMACContext(hashObj, key, input, inputLen);
   1152    if (hmac == NULL) {
   1153        goto loser;
   1154    }
   1155    errorLength = rsa_GetErrorLength(hmac, hashObj->length, maxLegalLen);
   1156    if (((int)errorLength) < 0) {
   1157        goto loser;
   1158    }
   1159    /* we always have to generate a full moduluslen error string. Otherwise
   1160     * we create a timing dependency on errorLength, which could be used to
   1161     * determine the difference between errorLength and outputLen and tell
   1162     * us that there was a pkcs1 decryption failure */
   1163    rv = rsa_HMACPrf(hmac, STRING_AND_LENGTH("message"),
   1164                     hashObj->length, errorBuffer, modulusLen);
   1165    if (rv != SECSuccess) {
   1166        goto loser;
   1167    }
   1168 
   1169    HMAC_Destroy(hmac, PR_TRUE);
   1170    hmac = NULL;
   1171 
   1172    /* From here on out, we will always return success. If there is
   1173     * an error, we will return deterministic output based on the key
   1174     * and the input data. */
   1175    rv = RSA_PrivateKeyOp(key, buffer, input);
   1176 
   1177    fail = PORT_CT_NE(rv, SECSuccess);
   1178    fail |= PORT_CT_NE(buffer[0], RSA_BLOCK_FIRST_OCTET) | PORT_CT_NE(buffer[1], RSA_BlockPublic);
   1179 
   1180    /* There have to be at least 8 bytes of padding. */
   1181    for (i = 2; i < 10; i++) {
   1182        fail |= PORT_CT_EQ(buffer[i], RSA_BLOCK_AFTER_PAD_OCTET);
   1183    }
   1184 
   1185    for (i = 10; i < modulusLen; i++) {
   1186        unsigned int newLen = modulusLen - i - 1;
   1187        PRUint32 condition = PORT_CT_EQ(buffer[i], RSA_BLOCK_AFTER_PAD_OCTET) & PORT_CT_EQ(outLen, modulusLen);
   1188        outLen = PORT_CT_SEL(condition, newLen, outLen);
   1189    }
   1190    // this can only happen if a zero wasn't found above
   1191    fail |= PORT_CT_GE(outLen, modulusLen);
   1192 
   1193    outLen = PORT_CT_SEL(fail, errorLength, outLen);
   1194 
   1195    /* index into the correct buffer. Do it before we truncate outLen if the
   1196     * application was asking for less data than we can return */
   1197    bp = buffer + modulusLen - outLen;
   1198    ep = errorBuffer + modulusLen - outLen;
   1199 
   1200    /* at this point, outLen returns no information about decryption failures,
   1201     * no need to hide its value. maxOutputLen is how much data the
   1202     * application is expecting, which is also not sensitive. */
   1203    if (outLen > maxOutputLen) {
   1204        outLen = maxOutputLen;
   1205    }
   1206 
   1207    /* we can't use PORT_Memcpy because caching could create a time dependency
   1208     * on the status of fail. */
   1209    for (i = 0; i < outLen; i++) {
   1210        output[i] = PORT_CT_SEL(fail, ep[i], bp[i]);
   1211    }
   1212 
   1213    *outputLen = outLen;
   1214 
   1215    PORT_Free(buffer);
   1216    PORT_Free(errorBuffer);
   1217 
   1218    return SECSuccess;
   1219 
   1220 loser:
   1221    if (hmac) {
   1222        HMAC_Destroy(hmac, PR_TRUE);
   1223    }
   1224    PORT_Free(buffer);
   1225    PORT_Free(errorBuffer);
   1226 
   1227    return SECFailure;
   1228 }
   1229 
   1230 /*
   1231 * Encode a RSA-PSS signature.
   1232 * Described in RFC 3447, section 9.1.1.
   1233 * We use mHash instead of M as input.
   1234 * emBits from the RFC is just modBits - 1, see section 8.1.1.
   1235 * We only support MGF1 as the MGF.
   1236 */
   1237 SECStatus
   1238 RSA_EMSAEncodePSS(unsigned char *em,
   1239                  unsigned int emLen,
   1240                  unsigned int emBits,
   1241                  const unsigned char *mHash,
   1242                  HASH_HashType hashAlg,
   1243                  HASH_HashType maskHashAlg,
   1244                  const unsigned char *salt,
   1245                  unsigned int saltLen)
   1246 {
   1247    const SECHashObject *hash;
   1248    void *hash_context;
   1249    unsigned char *dbMask;
   1250    unsigned int dbMaskLen;
   1251    unsigned int i;
   1252    SECStatus rv;
   1253 
   1254    hash = HASH_GetRawHashObject(hashAlg);
   1255    dbMaskLen = emLen - hash->length - 1;
   1256 
   1257    /* Step 3 */
   1258    if (emLen < hash->length + saltLen + 2) {
   1259        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
   1260        return SECFailure;
   1261    }
   1262 
   1263    /* Step 4 */
   1264    if (salt == NULL) {
   1265        rv = RNG_GenerateGlobalRandomBytes(&em[dbMaskLen - saltLen], saltLen);
   1266        if (rv != SECSuccess) {
   1267            return rv;
   1268        }
   1269    } else {
   1270        PORT_Memcpy(&em[dbMaskLen - saltLen], salt, saltLen);
   1271    }
   1272 
   1273    /* Step 5 + 6 */
   1274    /* Compute H and store it at its final location &em[dbMaskLen]. */
   1275    hash_context = (*hash->create)();
   1276    if (hash_context == NULL) {
   1277        PORT_SetError(SEC_ERROR_NO_MEMORY);
   1278        return SECFailure;
   1279    }
   1280    (*hash->begin)(hash_context);
   1281    (*hash->update)(hash_context, eightZeros, 8);
   1282    (*hash->update)(hash_context, mHash, hash->length);
   1283    (*hash->update)(hash_context, &em[dbMaskLen - saltLen], saltLen);
   1284    (*hash->end)(hash_context, &em[dbMaskLen], &i, hash->length);
   1285    (*hash->destroy)(hash_context, PR_TRUE);
   1286 
   1287    /* Step 7 + 8 */
   1288    PORT_Memset(em, 0, dbMaskLen - saltLen - 1);
   1289    em[dbMaskLen - saltLen - 1] = 0x01;
   1290 
   1291    /* Step 9 */
   1292    dbMask = (unsigned char *)PORT_Alloc(dbMaskLen);
   1293    if (dbMask == NULL) {
   1294        PORT_SetError(SEC_ERROR_NO_MEMORY);
   1295        return SECFailure;
   1296    }
   1297    MGF1(maskHashAlg, dbMask, dbMaskLen, &em[dbMaskLen], hash->length);
   1298 
   1299    /* Step 10 */
   1300    for (i = 0; i < dbMaskLen; i++)
   1301        em[i] ^= dbMask[i];
   1302    PORT_Free(dbMask);
   1303 
   1304    /* Step 11 */
   1305    em[0] &= 0xff >> (8 * emLen - emBits);
   1306 
   1307    /* Step 12 */
   1308    em[emLen - 1] = 0xbc;
   1309 
   1310    return SECSuccess;
   1311 }
   1312 
   1313 /*
   1314 * Verify a RSA-PSS signature.
   1315 * Described in RFC 3447, section 9.1.2.
   1316 * We use mHash instead of M as input.
   1317 * emBits from the RFC is just modBits - 1, see section 8.1.2.
   1318 * We only support MGF1 as the MGF.
   1319 */
   1320 static SECStatus
   1321 emsa_pss_verify(const unsigned char *mHash,
   1322                const unsigned char *em,
   1323                unsigned int emLen,
   1324                unsigned int emBits,
   1325                HASH_HashType hashAlg,
   1326                HASH_HashType maskHashAlg,
   1327                unsigned int saltLen)
   1328 {
   1329    const SECHashObject *hash;
   1330    void *hash_context;
   1331    unsigned char *db;
   1332    unsigned char *H_; /* H' from the RFC */
   1333    unsigned int i;
   1334    unsigned int dbMaskLen;
   1335    unsigned int zeroBits;
   1336    SECStatus rv;
   1337 
   1338    hash = HASH_GetRawHashObject(hashAlg);
   1339    dbMaskLen = emLen - hash->length - 1;
   1340 
   1341    /* Step 3 + 4 */
   1342    if ((emLen < (hash->length + saltLen + 2)) ||
   1343        (em[emLen - 1] != 0xbc)) {
   1344        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1345        return SECFailure;
   1346    }
   1347 
   1348    /* Step 6 */
   1349    zeroBits = 8 * emLen - emBits;
   1350    if (em[0] >> (8 - zeroBits)) {
   1351        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1352        return SECFailure;
   1353    }
   1354 
   1355    /* Step 7 */
   1356    db = (unsigned char *)PORT_Alloc(dbMaskLen);
   1357    if (db == NULL) {
   1358        PORT_SetError(SEC_ERROR_NO_MEMORY);
   1359        return SECFailure;
   1360    }
   1361    /* &em[dbMaskLen] points to H, used as mgfSeed */
   1362    MGF1(maskHashAlg, db, dbMaskLen, &em[dbMaskLen], hash->length);
   1363 
   1364    /* Step 8 */
   1365    for (i = 0; i < dbMaskLen; i++) {
   1366        db[i] ^= em[i];
   1367    }
   1368 
   1369    /* Step 9 */
   1370    db[0] &= 0xff >> zeroBits;
   1371 
   1372    /* Step 10 */
   1373    for (i = 0; i < (dbMaskLen - saltLen - 1); i++) {
   1374        if (db[i] != 0) {
   1375            PORT_Free(db);
   1376            PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1377            return SECFailure;
   1378        }
   1379    }
   1380    if (db[dbMaskLen - saltLen - 1] != 0x01) {
   1381        PORT_Free(db);
   1382        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1383        return SECFailure;
   1384    }
   1385 
   1386    /* Step 12 + 13 */
   1387    H_ = (unsigned char *)PORT_Alloc(hash->length);
   1388    if (H_ == NULL) {
   1389        PORT_Free(db);
   1390        PORT_SetError(SEC_ERROR_NO_MEMORY);
   1391        return SECFailure;
   1392    }
   1393    hash_context = (*hash->create)();
   1394    if (hash_context == NULL) {
   1395        PORT_Free(db);
   1396        PORT_Free(H_);
   1397        PORT_SetError(SEC_ERROR_NO_MEMORY);
   1398        return SECFailure;
   1399    }
   1400    (*hash->begin)(hash_context);
   1401    (*hash->update)(hash_context, eightZeros, 8);
   1402    (*hash->update)(hash_context, mHash, hash->length);
   1403    (*hash->update)(hash_context, &db[dbMaskLen - saltLen], saltLen);
   1404    (*hash->end)(hash_context, H_, &i, hash->length);
   1405    (*hash->destroy)(hash_context, PR_TRUE);
   1406 
   1407    PORT_Free(db);
   1408 
   1409    /* Step 14 */
   1410    if (PORT_Memcmp(H_, &em[dbMaskLen], hash->length) != 0) {
   1411        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1412        rv = SECFailure;
   1413    } else {
   1414        rv = SECSuccess;
   1415    }
   1416 
   1417    PORT_Free(H_);
   1418    return rv;
   1419 }
   1420 
   1421 SECStatus
   1422 RSA_SignPSS(RSAPrivateKey *key,
   1423            HASH_HashType hashAlg,
   1424            HASH_HashType maskHashAlg,
   1425            const unsigned char *salt,
   1426            unsigned int saltLength,
   1427            unsigned char *output,
   1428            unsigned int *outputLen,
   1429            unsigned int maxOutputLen,
   1430            const unsigned char *input,
   1431            unsigned int inputLen)
   1432 {
   1433    SECStatus rv = SECSuccess;
   1434    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
   1435    unsigned int modulusBits = rsa_modulusBits(&key->modulus);
   1436    unsigned int emLen = modulusLen;
   1437    unsigned char *pssEncoded, *em;
   1438 
   1439    if (maxOutputLen < modulusLen) {
   1440        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
   1441        return SECFailure;
   1442    }
   1443 
   1444    if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) {
   1445        PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
   1446        return SECFailure;
   1447    }
   1448 
   1449    pssEncoded = em = (unsigned char *)PORT_Alloc(modulusLen);
   1450    if (pssEncoded == NULL) {
   1451        PORT_SetError(SEC_ERROR_NO_MEMORY);
   1452        return SECFailure;
   1453    }
   1454 
   1455    /* len(em) == ceil((modulusBits - 1) / 8). */
   1456    if (modulusBits % 8 == 1) {
   1457        em[0] = 0;
   1458        emLen--;
   1459        em++;
   1460    }
   1461    rv = RSA_EMSAEncodePSS(em, emLen, modulusBits - 1, input, hashAlg,
   1462                           maskHashAlg, salt, saltLength);
   1463    if (rv != SECSuccess)
   1464        goto done;
   1465 
   1466    // This sets error codes upon failure.
   1467    rv = RSA_PrivateKeyOpDoubleChecked(key, output, pssEncoded);
   1468    *outputLen = modulusLen;
   1469 
   1470 done:
   1471    PORT_Free(pssEncoded);
   1472    return rv;
   1473 }
   1474 
   1475 SECStatus
   1476 RSA_CheckSignPSS(RSAPublicKey *key,
   1477                 HASH_HashType hashAlg,
   1478                 HASH_HashType maskHashAlg,
   1479                 unsigned int saltLength,
   1480                 const unsigned char *sig,
   1481                 unsigned int sigLen,
   1482                 const unsigned char *hash,
   1483                 unsigned int hashLen)
   1484 {
   1485    SECStatus rv;
   1486    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
   1487    unsigned int modulusBits = rsa_modulusBits(&key->modulus);
   1488    unsigned int emLen = modulusLen;
   1489    unsigned char *buffer, *em;
   1490 
   1491    if (sigLen != modulusLen) {
   1492        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1493        return SECFailure;
   1494    }
   1495 
   1496    if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) {
   1497        PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
   1498        return SECFailure;
   1499    }
   1500 
   1501    buffer = em = (unsigned char *)PORT_Alloc(modulusLen);
   1502    if (!buffer) {
   1503        PORT_SetError(SEC_ERROR_NO_MEMORY);
   1504        return SECFailure;
   1505    }
   1506 
   1507    rv = RSA_PublicKeyOp(key, buffer, sig);
   1508    if (rv != SECSuccess) {
   1509        PORT_Free(buffer);
   1510        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1511        return SECFailure;
   1512    }
   1513 
   1514    /* len(em) == ceil((modulusBits - 1) / 8). */
   1515    if (modulusBits % 8 == 1) {
   1516        emLen--;
   1517        em++;
   1518    }
   1519    rv = emsa_pss_verify(hash, em, emLen, modulusBits - 1, hashAlg,
   1520                         maskHashAlg, saltLength);
   1521 
   1522    PORT_Free(buffer);
   1523    return rv;
   1524 }
   1525 
   1526 SECStatus
   1527 RSA_Sign(RSAPrivateKey *key,
   1528         unsigned char *output,
   1529         unsigned int *outputLen,
   1530         unsigned int maxOutputLen,
   1531         const unsigned char *input,
   1532         unsigned int inputLen)
   1533 {
   1534    SECStatus rv = SECFailure;
   1535    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
   1536    SECItem formatted = { siBuffer, NULL, 0 };
   1537    SECItem unformatted = { siBuffer, (unsigned char *)input, inputLen };
   1538 
   1539    if (maxOutputLen < modulusLen) {
   1540        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
   1541        goto done;
   1542    }
   1543 
   1544    rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockPrivate,
   1545                         &unformatted);
   1546    if (rv != SECSuccess) {
   1547        PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
   1548        goto done;
   1549    }
   1550 
   1551    // This sets error codes upon failure.
   1552    rv = RSA_PrivateKeyOpDoubleChecked(key, output, formatted.data);
   1553    *outputLen = modulusLen;
   1554 
   1555 done:
   1556    if (formatted.data != NULL) {
   1557        PORT_ZFree(formatted.data, modulusLen);
   1558    }
   1559    return rv;
   1560 }
   1561 
   1562 SECStatus
   1563 RSA_CheckSign(RSAPublicKey *key,
   1564              const unsigned char *sig,
   1565              unsigned int sigLen,
   1566              const unsigned char *data,
   1567              unsigned int dataLen)
   1568 {
   1569    SECStatus rv = SECFailure;
   1570    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
   1571    unsigned int i;
   1572    unsigned char *buffer = NULL;
   1573 
   1574    if (sigLen != modulusLen) {
   1575        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1576        goto done;
   1577    }
   1578 
   1579    /*
   1580     * 0x00 || BT || Pad || 0x00 || ActualData
   1581     *
   1582     * The "3" below is the first octet + the second octet + the 0x00
   1583     * octet that always comes just before the ActualData.
   1584     */
   1585    if (dataLen > modulusLen - (3 + RSA_BLOCK_MIN_PAD_LEN)) {
   1586        PORT_SetError(SEC_ERROR_BAD_DATA);
   1587        goto done;
   1588    }
   1589 
   1590    buffer = (unsigned char *)PORT_Alloc(modulusLen + 1);
   1591    if (!buffer) {
   1592        PORT_SetError(SEC_ERROR_NO_MEMORY);
   1593        goto done;
   1594    }
   1595 
   1596    if (RSA_PublicKeyOp(key, buffer, sig) != SECSuccess) {
   1597        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1598        goto done;
   1599    }
   1600 
   1601    /*
   1602     * check the padding that was used
   1603     */
   1604    if (buffer[0] != RSA_BLOCK_FIRST_OCTET ||
   1605        buffer[1] != (unsigned char)RSA_BlockPrivate) {
   1606        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1607        goto done;
   1608    }
   1609    for (i = 2; i < modulusLen - dataLen - 1; i++) {
   1610        if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET) {
   1611            PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1612            goto done;
   1613        }
   1614    }
   1615    if (buffer[i] != RSA_BLOCK_AFTER_PAD_OCTET) {
   1616        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1617        goto done;
   1618    }
   1619 
   1620    /*
   1621     * make sure we get the same results
   1622     */
   1623    if (PORT_Memcmp(buffer + modulusLen - dataLen, data, dataLen) == 0) {
   1624        rv = SECSuccess;
   1625    }
   1626 
   1627 done:
   1628    if (buffer) {
   1629        PORT_Free(buffer);
   1630    }
   1631    return rv;
   1632 }
   1633 
   1634 SECStatus
   1635 RSA_CheckSignRecover(RSAPublicKey *key,
   1636                     unsigned char *output,
   1637                     unsigned int *outputLen,
   1638                     unsigned int maxOutputLen,
   1639                     const unsigned char *sig,
   1640                     unsigned int sigLen)
   1641 {
   1642    SECStatus rv = SECFailure;
   1643    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
   1644    unsigned int i;
   1645    unsigned char *buffer = NULL;
   1646    unsigned int padLen;
   1647 
   1648    if (sigLen != modulusLen) {
   1649        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1650        goto done;
   1651    }
   1652 
   1653    buffer = (unsigned char *)PORT_Alloc(modulusLen + 1);
   1654    if (!buffer) {
   1655        PORT_SetError(SEC_ERROR_NO_MEMORY);
   1656        goto done;
   1657    }
   1658 
   1659    if (RSA_PublicKeyOp(key, buffer, sig) != SECSuccess) {
   1660        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1661        goto done;
   1662    }
   1663 
   1664    *outputLen = 0;
   1665 
   1666    /*
   1667     * check the padding that was used
   1668     */
   1669    if (buffer[0] != RSA_BLOCK_FIRST_OCTET ||
   1670        buffer[1] != (unsigned char)RSA_BlockPrivate) {
   1671        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1672        goto done;
   1673    }
   1674    for (i = 2; i < modulusLen; i++) {
   1675        if (buffer[i] == RSA_BLOCK_AFTER_PAD_OCTET) {
   1676            *outputLen = modulusLen - i - 1;
   1677            break;
   1678        }
   1679        if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET) {
   1680            PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1681            goto done;
   1682        }
   1683    }
   1684    padLen = i - 2;
   1685    if (padLen < RSA_BLOCK_MIN_PAD_LEN) {
   1686        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1687        goto done;
   1688    }
   1689    if (*outputLen == 0) {
   1690        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
   1691        goto done;
   1692    }
   1693    if (*outputLen > maxOutputLen) {
   1694        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
   1695        goto done;
   1696    }
   1697 
   1698    PORT_Memcpy(output, buffer + modulusLen - *outputLen, *outputLen);
   1699    rv = SECSuccess;
   1700 
   1701 done:
   1702    if (buffer) {
   1703        PORT_Free(buffer);
   1704    }
   1705    return rv;
   1706 }