rsapkcs.c (50921B)
1 /* This Source Code Form is subject to the terms of the Mozilla Public 2 * License, v. 2.0. If a copy of the MPL was not distributed with this 3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ 4 5 /* 6 * RSA PKCS#1 v2.1 (RFC 3447) operations 7 */ 8 9 #ifdef FREEBL_NO_DEPEND 10 #include "stubs.h" 11 #endif 12 13 #include "secerr.h" 14 15 #include "blapi.h" 16 #include "secitem.h" 17 #include "blapii.h" 18 19 #define RSA_BLOCK_MIN_PAD_LEN 8 20 #define RSA_BLOCK_FIRST_OCTET 0x00 21 #define RSA_BLOCK_PRIVATE_PAD_OCTET 0xff 22 #define RSA_BLOCK_AFTER_PAD_OCTET 0x00 23 24 /* 25 * RSA block types 26 * 27 * The values of RSA_BlockPrivate and RSA_BlockPublic are fixed. 28 * The value of RSA_BlockRaw isn't fixed by definition, but we are keeping 29 * the value that NSS has been using in the past. 30 */ 31 typedef enum { 32 RSA_BlockPrivate = 1, /* pad for a private-key operation */ 33 RSA_BlockPublic = 2, /* pad for a public-key operation */ 34 RSA_BlockRaw = 4 /* simply justify the block appropriately */ 35 } RSA_BlockType; 36 37 /* Needed for RSA-PSS functions */ 38 static const unsigned char eightZeros[] = { 0, 0, 0, 0, 0, 0, 0, 0 }; 39 40 /* Constant time comparison of a single byte. 41 * Returns 1 iff a == b, otherwise returns 0. 42 * Note: For ranges of bytes, use constantTimeCompare. 43 */ 44 static unsigned char 45 constantTimeEQ8(unsigned char a, unsigned char b) 46 { 47 unsigned char c = ~((a - b) | (b - a)); 48 c >>= 7; 49 return c; 50 } 51 52 /* Constant time comparison of a range of bytes. 53 * Returns 1 iff len bytes of a are identical to len bytes of b, otherwise 54 * returns 0. 55 */ 56 static unsigned char 57 constantTimeCompare(const unsigned char *a, 58 const unsigned char *b, 59 unsigned int len) 60 { 61 unsigned char tmp = 0; 62 unsigned int i; 63 for (i = 0; i < len; ++i, ++a, ++b) 64 tmp |= *a ^ *b; 65 return constantTimeEQ8(0x00, tmp); 66 } 67 68 /* Constant time conditional. 69 * Returns a if c is 1, or b if c is 0. The result is undefined if c is 70 * not 0 or 1. 71 */ 72 static unsigned int 73 constantTimeCondition(unsigned int c, 74 unsigned int a, 75 unsigned int b) 76 { 77 return (~(c - 1) & a) | ((c - 1) & b); 78 } 79 80 static unsigned int 81 rsa_modulusLen(SECItem *modulus) 82 { 83 if (modulus->len == 0) { 84 return 0; 85 } 86 87 unsigned char byteZero = modulus->data[0]; 88 unsigned int modLen = modulus->len - !byteZero; 89 return modLen; 90 } 91 92 static unsigned int 93 rsa_modulusBits(SECItem *modulus) 94 { 95 if (modulus->len == 0) { 96 return 0; 97 } 98 99 unsigned char byteZero = modulus->data[0]; 100 unsigned int numBits = (modulus->len - 1) * 8; 101 102 if (byteZero == 0 && modulus->len == 1) { 103 return 0; 104 } 105 106 if (byteZero == 0) { 107 numBits -= 8; 108 byteZero = modulus->data[1]; 109 } 110 111 while (byteZero > 0) { 112 numBits++; 113 byteZero >>= 1; 114 } 115 116 return numBits; 117 } 118 119 /* 120 * Format one block of data for public/private key encryption using 121 * the rules defined in PKCS #1. 122 */ 123 static unsigned char * 124 rsa_FormatOneBlock(unsigned modulusLen, 125 RSA_BlockType blockType, 126 SECItem *data) 127 { 128 unsigned char *block; 129 unsigned char *bp; 130 unsigned int padLen; 131 unsigned int i, j; 132 SECStatus rv; 133 134 block = (unsigned char *)PORT_Alloc(modulusLen); 135 if (block == NULL) 136 return NULL; 137 138 bp = block; 139 140 /* 141 * All RSA blocks start with two octets: 142 * 0x00 || BlockType 143 */ 144 *bp++ = RSA_BLOCK_FIRST_OCTET; 145 *bp++ = (unsigned char)blockType; 146 147 switch (blockType) { 148 149 /* 150 * Blocks intended for private-key operation. 151 */ 152 case RSA_BlockPrivate: /* preferred method */ 153 /* 154 * 0x00 || BT || Pad || 0x00 || ActualData 155 * 1 1 padLen 1 data->len 156 * padLen must be at least RSA_BLOCK_MIN_PAD_LEN (8) bytes. 157 * Pad is either all 0x00 or all 0xff bytes, depending on blockType. 158 */ 159 padLen = modulusLen - data->len - 3; 160 PORT_Assert(padLen >= RSA_BLOCK_MIN_PAD_LEN); 161 if (padLen < RSA_BLOCK_MIN_PAD_LEN) { 162 PORT_ZFree(block, modulusLen); 163 return NULL; 164 } 165 PORT_Memset(bp, RSA_BLOCK_PRIVATE_PAD_OCTET, padLen); 166 bp += padLen; 167 *bp++ = RSA_BLOCK_AFTER_PAD_OCTET; 168 PORT_Memcpy(bp, data->data, data->len); 169 break; 170 171 /* 172 * Blocks intended for public-key operation. 173 */ 174 case RSA_BlockPublic: 175 /* 176 * 0x00 || BT || Pad || 0x00 || ActualData 177 * 1 1 padLen 1 data->len 178 * Pad is 8 or more non-zero random bytes. 179 * 180 * Build the block left to right. 181 * Fill the entire block from Pad to the end with random bytes. 182 * Use the bytes after Pad as a supply of extra random bytes from 183 * which to find replacements for the zero bytes in Pad. 184 * If we need more than that, refill the bytes after Pad with 185 * new random bytes as necessary. 186 */ 187 188 padLen = modulusLen - (data->len + 3); 189 PORT_Assert(padLen >= RSA_BLOCK_MIN_PAD_LEN); 190 if (padLen < RSA_BLOCK_MIN_PAD_LEN) { 191 PORT_ZFree(block, modulusLen); 192 return NULL; 193 } 194 j = modulusLen - 2; 195 rv = RNG_GenerateGlobalRandomBytes(bp, j); 196 if (rv == SECSuccess) { 197 for (i = 0; i < padLen;) { 198 unsigned char repl; 199 /* Pad with non-zero random data. */ 200 if (bp[i] != RSA_BLOCK_AFTER_PAD_OCTET) { 201 ++i; 202 continue; 203 } 204 if (j <= padLen) { 205 rv = RNG_GenerateGlobalRandomBytes(bp + padLen, 206 modulusLen - (2 + padLen)); 207 if (rv != SECSuccess) 208 break; 209 j = modulusLen - 2; 210 } 211 do { 212 repl = bp[--j]; 213 } while (repl == RSA_BLOCK_AFTER_PAD_OCTET && j > padLen); 214 if (repl != RSA_BLOCK_AFTER_PAD_OCTET) { 215 bp[i++] = repl; 216 } 217 } 218 } 219 if (rv != SECSuccess) { 220 PORT_ZFree(block, modulusLen); 221 PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); 222 return NULL; 223 } 224 bp += padLen; 225 *bp++ = RSA_BLOCK_AFTER_PAD_OCTET; 226 PORT_Memcpy(bp, data->data, data->len); 227 break; 228 229 default: 230 PORT_Assert(0); 231 PORT_ZFree(block, modulusLen); 232 return NULL; 233 } 234 235 return block; 236 } 237 238 /* modulusLen has to be larger than RSA_BLOCK_MIN_PAD_LEN + 3, and data has to be smaller than modulus - (RSA_BLOCK_MIN_PAD_LEN + 3) */ 239 static SECStatus 240 rsa_FormatBlock(SECItem *result, 241 unsigned modulusLen, 242 RSA_BlockType blockType, 243 SECItem *data) 244 { 245 switch (blockType) { 246 case RSA_BlockPrivate: 247 case RSA_BlockPublic: 248 /* 249 * 0x00 || BT || Pad || 0x00 || ActualData 250 * 251 * The "3" below is the first octet + the second octet + the 0x00 252 * octet that always comes just before the ActualData. 253 */ 254 if (modulusLen < (3 + RSA_BLOCK_MIN_PAD_LEN) || data->len > (modulusLen - (3 + RSA_BLOCK_MIN_PAD_LEN))) { 255 return SECFailure; 256 } 257 result->data = rsa_FormatOneBlock(modulusLen, blockType, data); 258 if (result->data == NULL) { 259 result->len = 0; 260 return SECFailure; 261 } 262 result->len = modulusLen; 263 264 break; 265 266 case RSA_BlockRaw: 267 /* 268 * Pad || ActualData 269 * Pad is zeros. The application is responsible for recovering 270 * the actual data. 271 */ 272 if (data->len > modulusLen) { 273 return SECFailure; 274 } 275 result->data = (unsigned char *)PORT_ZAlloc(modulusLen); 276 result->len = modulusLen; 277 PORT_Memcpy(result->data + (modulusLen - data->len), 278 data->data, data->len); 279 break; 280 281 default: 282 PORT_Assert(0); 283 result->data = NULL; 284 result->len = 0; 285 return SECFailure; 286 } 287 288 return SECSuccess; 289 } 290 291 /* 292 * Mask generation function MGF1 as defined in PKCS #1 v2.1 / RFC 3447. 293 */ 294 static SECStatus 295 MGF1(HASH_HashType hashAlg, 296 unsigned char *mask, 297 unsigned int maskLen, 298 const unsigned char *mgfSeed, 299 unsigned int mgfSeedLen) 300 { 301 unsigned int digestLen; 302 PRUint32 counter; 303 PRUint32 rounds; 304 unsigned char *tempHash; 305 unsigned char *temp; 306 const SECHashObject *hash; 307 void *hashContext; 308 unsigned char C[4]; 309 SECStatus rv = SECSuccess; 310 311 hash = HASH_GetRawHashObject(hashAlg); 312 if (hash == NULL) { 313 return SECFailure; 314 } 315 316 hashContext = (*hash->create)(); 317 rounds = (maskLen + hash->length - 1) / hash->length; 318 for (counter = 0; counter < rounds; counter++) { 319 C[0] = (unsigned char)((counter >> 24) & 0xff); 320 C[1] = (unsigned char)((counter >> 16) & 0xff); 321 C[2] = (unsigned char)((counter >> 8) & 0xff); 322 C[3] = (unsigned char)(counter & 0xff); 323 324 /* This could be optimized when the clone functions in 325 * rawhash.c are implemented. */ 326 (*hash->begin)(hashContext); 327 (*hash->update)(hashContext, mgfSeed, mgfSeedLen); 328 (*hash->update)(hashContext, C, sizeof C); 329 330 tempHash = mask + counter * hash->length; 331 if (counter != (rounds - 1)) { 332 (*hash->end)(hashContext, tempHash, &digestLen, hash->length); 333 } else { /* we're in the last round and need to cut the hash */ 334 temp = (unsigned char *)PORT_Alloc(hash->length); 335 if (!temp) { 336 rv = SECFailure; 337 goto done; 338 } 339 (*hash->end)(hashContext, temp, &digestLen, hash->length); 340 PORT_Memcpy(tempHash, temp, maskLen - counter * hash->length); 341 PORT_Free(temp); 342 } 343 } 344 345 done: 346 (*hash->destroy)(hashContext, PR_TRUE); 347 return rv; 348 } 349 350 /* XXX Doesn't set error code */ 351 SECStatus 352 RSA_SignRaw(RSAPrivateKey *key, 353 unsigned char *output, 354 unsigned int *outputLen, 355 unsigned int maxOutputLen, 356 const unsigned char *data, 357 unsigned int dataLen) 358 { 359 SECStatus rv = SECSuccess; 360 unsigned int modulusLen = rsa_modulusLen(&key->modulus); 361 SECItem formatted; 362 SECItem unformatted; 363 364 if (maxOutputLen < modulusLen) 365 return SECFailure; 366 367 unformatted.len = dataLen; 368 unformatted.data = (unsigned char *)data; 369 formatted.data = NULL; 370 rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockRaw, &unformatted); 371 if (rv != SECSuccess) 372 goto done; 373 374 rv = RSA_PrivateKeyOpDoubleChecked(key, output, formatted.data); 375 *outputLen = modulusLen; 376 377 done: 378 if (formatted.data != NULL) 379 PORT_ZFree(formatted.data, modulusLen); 380 return rv; 381 } 382 383 /* XXX Doesn't set error code */ 384 SECStatus 385 RSA_CheckSignRaw(RSAPublicKey *key, 386 const unsigned char *sig, 387 unsigned int sigLen, 388 const unsigned char *hash, 389 unsigned int hashLen) 390 { 391 SECStatus rv; 392 unsigned int modulusLen = rsa_modulusLen(&key->modulus); 393 unsigned char *buffer; 394 395 if (sigLen != modulusLen) 396 goto failure; 397 if (hashLen > modulusLen) 398 goto failure; 399 400 buffer = (unsigned char *)PORT_Alloc(modulusLen + 1); 401 if (!buffer) 402 goto failure; 403 404 rv = RSA_PublicKeyOp(key, buffer, sig); 405 if (rv != SECSuccess) 406 goto loser; 407 408 /* 409 * make sure we get the same results 410 */ 411 /* XXX(rsleevi): Constant time */ 412 /* NOTE: should we verify the leading zeros? */ 413 if (PORT_Memcmp(buffer + (modulusLen - hashLen), hash, hashLen) != 0) 414 goto loser; 415 416 PORT_Free(buffer); 417 return SECSuccess; 418 419 loser: 420 PORT_Free(buffer); 421 failure: 422 return SECFailure; 423 } 424 425 /* XXX Doesn't set error code */ 426 SECStatus 427 RSA_CheckSignRecoverRaw(RSAPublicKey *key, 428 unsigned char *data, 429 unsigned int *dataLen, 430 unsigned int maxDataLen, 431 const unsigned char *sig, 432 unsigned int sigLen) 433 { 434 SECStatus rv; 435 unsigned int modulusLen = rsa_modulusLen(&key->modulus); 436 437 if (sigLen != modulusLen) 438 goto failure; 439 if (maxDataLen < modulusLen) 440 goto failure; 441 442 rv = RSA_PublicKeyOp(key, data, sig); 443 if (rv != SECSuccess) 444 goto failure; 445 446 *dataLen = modulusLen; 447 return SECSuccess; 448 449 failure: 450 return SECFailure; 451 } 452 453 /* XXX Doesn't set error code */ 454 SECStatus 455 RSA_EncryptRaw(RSAPublicKey *key, 456 unsigned char *output, 457 unsigned int *outputLen, 458 unsigned int maxOutputLen, 459 const unsigned char *input, 460 unsigned int inputLen) 461 { 462 SECStatus rv; 463 unsigned int modulusLen = rsa_modulusLen(&key->modulus); 464 SECItem formatted; 465 SECItem unformatted; 466 467 formatted.data = NULL; 468 if (maxOutputLen < modulusLen) 469 goto failure; 470 471 unformatted.len = inputLen; 472 unformatted.data = (unsigned char *)input; 473 formatted.data = NULL; 474 rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockRaw, &unformatted); 475 if (rv != SECSuccess) 476 goto failure; 477 478 rv = RSA_PublicKeyOp(key, output, formatted.data); 479 if (rv != SECSuccess) 480 goto failure; 481 482 PORT_ZFree(formatted.data, modulusLen); 483 *outputLen = modulusLen; 484 return SECSuccess; 485 486 failure: 487 if (formatted.data != NULL) 488 PORT_ZFree(formatted.data, modulusLen); 489 return SECFailure; 490 } 491 492 /* XXX Doesn't set error code */ 493 SECStatus 494 RSA_DecryptRaw(RSAPrivateKey *key, 495 unsigned char *output, 496 unsigned int *outputLen, 497 unsigned int maxOutputLen, 498 const unsigned char *input, 499 unsigned int inputLen) 500 { 501 SECStatus rv; 502 unsigned int modulusLen = rsa_modulusLen(&key->modulus); 503 504 if (modulusLen > maxOutputLen) 505 goto failure; 506 if (inputLen != modulusLen) 507 goto failure; 508 509 rv = RSA_PrivateKeyOp(key, output, input); 510 if (rv != SECSuccess) 511 goto failure; 512 513 *outputLen = modulusLen; 514 return SECSuccess; 515 516 failure: 517 return SECFailure; 518 } 519 520 /* 521 * Decodes an EME-OAEP encoded block, validating the encoding in constant 522 * time. 523 * Described in RFC 3447, section 7.1.2. 524 * input contains the encoded block, after decryption. 525 * label is the optional value L that was associated with the message. 526 * On success, the original message and message length will be stored in 527 * output and outputLen. 528 */ 529 static SECStatus 530 eme_oaep_decode(unsigned char *output, 531 unsigned int *outputLen, 532 unsigned int maxOutputLen, 533 const unsigned char *input, 534 unsigned int inputLen, 535 HASH_HashType hashAlg, 536 HASH_HashType maskHashAlg, 537 const unsigned char *label, 538 unsigned int labelLen) 539 { 540 const SECHashObject *hash; 541 void *hashContext; 542 SECStatus rv = SECFailure; 543 unsigned char labelHash[HASH_LENGTH_MAX]; 544 unsigned int i; 545 unsigned int maskLen; 546 unsigned int paddingOffset; 547 unsigned char *mask = NULL; 548 unsigned char *tmpOutput = NULL; 549 unsigned char isGood; 550 unsigned char foundPaddingEnd; 551 552 hash = HASH_GetRawHashObject(hashAlg); 553 554 /* 1.c */ 555 if (inputLen < (hash->length * 2) + 2) { 556 PORT_SetError(SEC_ERROR_INPUT_LEN); 557 return SECFailure; 558 } 559 560 /* Step 3.a - Generate lHash */ 561 hashContext = (*hash->create)(); 562 if (hashContext == NULL) { 563 PORT_SetError(SEC_ERROR_NO_MEMORY); 564 return SECFailure; 565 } 566 (*hash->begin)(hashContext); 567 if (labelLen > 0) 568 (*hash->update)(hashContext, label, labelLen); 569 (*hash->end)(hashContext, labelHash, &i, sizeof(labelHash)); 570 (*hash->destroy)(hashContext, PR_TRUE); 571 572 tmpOutput = (unsigned char *)PORT_Alloc(inputLen); 573 if (tmpOutput == NULL) { 574 PORT_SetError(SEC_ERROR_NO_MEMORY); 575 goto done; 576 } 577 578 maskLen = inputLen - hash->length - 1; 579 mask = (unsigned char *)PORT_Alloc(maskLen); 580 if (mask == NULL) { 581 PORT_SetError(SEC_ERROR_NO_MEMORY); 582 goto done; 583 } 584 585 PORT_Memcpy(tmpOutput, input, inputLen); 586 587 /* 3.c - Generate seedMask */ 588 MGF1(maskHashAlg, mask, hash->length, &tmpOutput[1 + hash->length], 589 inputLen - hash->length - 1); 590 /* 3.d - Unmask seed */ 591 for (i = 0; i < hash->length; ++i) 592 tmpOutput[1 + i] ^= mask[i]; 593 594 /* 3.e - Generate dbMask */ 595 MGF1(maskHashAlg, mask, maskLen, &tmpOutput[1], hash->length); 596 /* 3.f - Unmask DB */ 597 for (i = 0; i < maskLen; ++i) 598 tmpOutput[1 + hash->length + i] ^= mask[i]; 599 600 /* 3.g - Compare Y, lHash, and PS in constant time 601 * Warning: This code is timing dependent and must not disclose which of 602 * these were invalid. 603 */ 604 paddingOffset = 0; 605 isGood = 1; 606 foundPaddingEnd = 0; 607 608 /* Compare Y */ 609 isGood &= constantTimeEQ8(0x00, tmpOutput[0]); 610 611 /* Compare lHash and lHash' */ 612 isGood &= constantTimeCompare(&labelHash[0], 613 &tmpOutput[1 + hash->length], 614 hash->length); 615 616 /* Compare that the padding is zero or more zero octets, followed by a 617 * 0x01 octet */ 618 for (i = 1 + (hash->length * 2); i < inputLen; ++i) { 619 unsigned char isZero = constantTimeEQ8(0x00, tmpOutput[i]); 620 unsigned char isOne = constantTimeEQ8(0x01, tmpOutput[i]); 621 /* non-constant time equivalent: 622 * if (tmpOutput[i] == 0x01 && !foundPaddingEnd) 623 * paddingOffset = i; 624 */ 625 paddingOffset = constantTimeCondition(isOne & ~foundPaddingEnd, i, 626 paddingOffset); 627 /* non-constant time equivalent: 628 * if (tmpOutput[i] == 0x01) 629 * foundPaddingEnd = true; 630 * 631 * Note: This may yield false positives, as it will be set whenever 632 * a 0x01 byte is encountered. If there was bad padding (eg: 633 * 0x03 0x02 0x01), foundPaddingEnd will still be set to true, and 634 * paddingOffset will still be set to 2. 635 */ 636 foundPaddingEnd = constantTimeCondition(isOne, 1, foundPaddingEnd); 637 /* non-constant time equivalent: 638 * if (tmpOutput[i] != 0x00 && tmpOutput[i] != 0x01 && 639 * !foundPaddingEnd) { 640 * isGood = false; 641 * } 642 * 643 * Note: This may yield false positives, as a message (and padding) 644 * that is entirely zeros will result in isGood still being true. Thus 645 * it's necessary to check foundPaddingEnd is positive below. 646 */ 647 isGood = constantTimeCondition(~foundPaddingEnd & ~isZero, 0, isGood); 648 } 649 650 /* While both isGood and foundPaddingEnd may have false positives, they 651 * cannot BOTH have false positives. If both are not true, then an invalid 652 * message was received. Note, this comparison must still be done in constant 653 * time so as not to leak either condition. 654 */ 655 if (!(isGood & foundPaddingEnd)) { 656 PORT_SetError(SEC_ERROR_BAD_DATA); 657 goto done; 658 } 659 660 /* End timing dependent code */ 661 662 ++paddingOffset; /* Skip the 0x01 following the end of PS */ 663 664 *outputLen = inputLen - paddingOffset; 665 if (*outputLen > maxOutputLen) { 666 PORT_SetError(SEC_ERROR_OUTPUT_LEN); 667 goto done; 668 } 669 670 if (*outputLen) 671 PORT_Memcpy(output, &tmpOutput[paddingOffset], *outputLen); 672 rv = SECSuccess; 673 674 done: 675 if (mask) 676 PORT_ZFree(mask, maskLen); 677 if (tmpOutput) 678 PORT_ZFree(tmpOutput, inputLen); 679 return rv; 680 } 681 682 /* 683 * Generate an EME-OAEP encoded block for encryption 684 * Described in RFC 3447, section 7.1.1 685 * We use input instead of M for the message to be encrypted 686 * label is the optional value L to be associated with the message. 687 */ 688 static SECStatus 689 eme_oaep_encode(unsigned char *em, 690 unsigned int emLen, 691 const unsigned char *input, 692 unsigned int inputLen, 693 HASH_HashType hashAlg, 694 HASH_HashType maskHashAlg, 695 const unsigned char *label, 696 unsigned int labelLen, 697 const unsigned char *seed, 698 unsigned int seedLen) 699 { 700 const SECHashObject *hash; 701 void *hashContext; 702 SECStatus rv; 703 unsigned char *mask; 704 unsigned int reservedLen; 705 unsigned int dbMaskLen; 706 unsigned int i; 707 708 hash = HASH_GetRawHashObject(hashAlg); 709 PORT_Assert(seed == NULL || seedLen == hash->length); 710 711 /* Step 1.b */ 712 reservedLen = (2 * hash->length) + 2; 713 if (emLen < reservedLen || inputLen > (emLen - reservedLen)) { 714 PORT_SetError(SEC_ERROR_INPUT_LEN); 715 return SECFailure; 716 } 717 718 /* 719 * From RFC 3447, Section 7.1 720 * +----------+---------+-------+ 721 * DB = | lHash | PS | M | 722 * +----------+---------+-------+ 723 * | 724 * +----------+ V 725 * | seed |--> MGF ---> xor 726 * +----------+ | 727 * | | 728 * +--+ V | 729 * |00| xor <----- MGF <-----| 730 * +--+ | | 731 * | | | 732 * V V V 733 * +--+----------+----------------------------+ 734 * EM = |00|maskedSeed| maskedDB | 735 * +--+----------+----------------------------+ 736 * 737 * We use mask to hold the result of the MGF functions, and all other 738 * values are generated in their final resting place. 739 */ 740 *em = 0x00; 741 742 /* Step 2.a - Generate lHash */ 743 hashContext = (*hash->create)(); 744 if (hashContext == NULL) { 745 PORT_SetError(SEC_ERROR_NO_MEMORY); 746 return SECFailure; 747 } 748 (*hash->begin)(hashContext); 749 if (labelLen > 0) 750 (*hash->update)(hashContext, label, labelLen); 751 (*hash->end)(hashContext, &em[1 + hash->length], &i, hash->length); 752 (*hash->destroy)(hashContext, PR_TRUE); 753 754 /* Step 2.b - Generate PS */ 755 if (emLen - reservedLen - inputLen > 0) { 756 PORT_Memset(em + 1 + (hash->length * 2), 0x00, 757 emLen - reservedLen - inputLen); 758 } 759 760 /* Step 2.c. - Generate DB 761 * DB = lHash || PS || 0x01 || M 762 * Note that PS and lHash have already been placed into em at their 763 * appropriate offsets. This just copies M into place 764 */ 765 em[emLen - inputLen - 1] = 0x01; 766 if (inputLen) 767 PORT_Memcpy(em + emLen - inputLen, input, inputLen); 768 769 if (seed == NULL) { 770 /* Step 2.d - Generate seed */ 771 rv = RNG_GenerateGlobalRandomBytes(em + 1, hash->length); 772 if (rv != SECSuccess) { 773 return rv; 774 } 775 } else { 776 /* For Known Answer Tests, copy the supplied seed. */ 777 PORT_Memcpy(em + 1, seed, seedLen); 778 } 779 780 /* Step 2.e - Generate dbMask*/ 781 dbMaskLen = emLen - hash->length - 1; 782 mask = (unsigned char *)PORT_Alloc(dbMaskLen); 783 if (mask == NULL) { 784 PORT_SetError(SEC_ERROR_NO_MEMORY); 785 return SECFailure; 786 } 787 MGF1(maskHashAlg, mask, dbMaskLen, em + 1, hash->length); 788 /* Step 2.f - Compute maskedDB*/ 789 for (i = 0; i < dbMaskLen; ++i) 790 em[1 + hash->length + i] ^= mask[i]; 791 792 /* Step 2.g - Generate seedMask */ 793 MGF1(maskHashAlg, mask, hash->length, &em[1 + hash->length], dbMaskLen); 794 /* Step 2.h - Compute maskedSeed */ 795 for (i = 0; i < hash->length; ++i) 796 em[1 + i] ^= mask[i]; 797 798 PORT_ZFree(mask, dbMaskLen); 799 return SECSuccess; 800 } 801 802 SECStatus 803 RSA_EncryptOAEP(RSAPublicKey *key, 804 HASH_HashType hashAlg, 805 HASH_HashType maskHashAlg, 806 const unsigned char *label, 807 unsigned int labelLen, 808 const unsigned char *seed, 809 unsigned int seedLen, 810 unsigned char *output, 811 unsigned int *outputLen, 812 unsigned int maxOutputLen, 813 const unsigned char *input, 814 unsigned int inputLen) 815 { 816 SECStatus rv = SECFailure; 817 unsigned int modulusLen = rsa_modulusLen(&key->modulus); 818 unsigned char *oaepEncoded = NULL; 819 820 if (maxOutputLen < modulusLen) { 821 PORT_SetError(SEC_ERROR_OUTPUT_LEN); 822 return SECFailure; 823 } 824 825 if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) { 826 PORT_SetError(SEC_ERROR_INVALID_ALGORITHM); 827 return SECFailure; 828 } 829 830 if ((labelLen == 0 && label != NULL) || 831 (labelLen > 0 && label == NULL)) { 832 PORT_SetError(SEC_ERROR_INVALID_ALGORITHM); 833 return SECFailure; 834 } 835 836 oaepEncoded = (unsigned char *)PORT_Alloc(modulusLen); 837 if (oaepEncoded == NULL) { 838 PORT_SetError(SEC_ERROR_NO_MEMORY); 839 return SECFailure; 840 } 841 rv = eme_oaep_encode(oaepEncoded, modulusLen, input, inputLen, 842 hashAlg, maskHashAlg, label, labelLen, seed, seedLen); 843 if (rv != SECSuccess) 844 goto done; 845 846 rv = RSA_PublicKeyOp(key, output, oaepEncoded); 847 if (rv != SECSuccess) 848 goto done; 849 *outputLen = modulusLen; 850 851 done: 852 PORT_Free(oaepEncoded); 853 return rv; 854 } 855 856 SECStatus 857 RSA_DecryptOAEP(RSAPrivateKey *key, 858 HASH_HashType hashAlg, 859 HASH_HashType maskHashAlg, 860 const unsigned char *label, 861 unsigned int labelLen, 862 unsigned char *output, 863 unsigned int *outputLen, 864 unsigned int maxOutputLen, 865 const unsigned char *input, 866 unsigned int inputLen) 867 { 868 SECStatus rv = SECFailure; 869 unsigned int modulusLen = rsa_modulusLen(&key->modulus); 870 unsigned char *oaepEncoded = NULL; 871 872 if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) { 873 PORT_SetError(SEC_ERROR_INVALID_ALGORITHM); 874 return SECFailure; 875 } 876 877 if (inputLen != modulusLen) { 878 PORT_SetError(SEC_ERROR_INPUT_LEN); 879 return SECFailure; 880 } 881 882 if ((labelLen == 0 && label != NULL) || 883 (labelLen > 0 && label == NULL)) { 884 PORT_SetError(SEC_ERROR_INVALID_ALGORITHM); 885 return SECFailure; 886 } 887 888 oaepEncoded = (unsigned char *)PORT_Alloc(modulusLen); 889 if (oaepEncoded == NULL) { 890 PORT_SetError(SEC_ERROR_NO_MEMORY); 891 return SECFailure; 892 } 893 894 rv = RSA_PrivateKeyOpDoubleChecked(key, oaepEncoded, input); 895 if (rv != SECSuccess) { 896 goto done; 897 } 898 rv = eme_oaep_decode(output, outputLen, maxOutputLen, oaepEncoded, 899 modulusLen, hashAlg, maskHashAlg, label, 900 labelLen); 901 902 done: 903 if (oaepEncoded) 904 PORT_ZFree(oaepEncoded, modulusLen); 905 return rv; 906 } 907 908 /* XXX Doesn't set error code */ 909 SECStatus 910 RSA_EncryptBlock(RSAPublicKey *key, 911 unsigned char *output, 912 unsigned int *outputLen, 913 unsigned int maxOutputLen, 914 const unsigned char *input, 915 unsigned int inputLen) 916 { 917 SECStatus rv; 918 unsigned int modulusLen = rsa_modulusLen(&key->modulus); 919 SECItem formatted; 920 SECItem unformatted; 921 922 formatted.data = NULL; 923 if (maxOutputLen < modulusLen) 924 goto failure; 925 926 unformatted.len = inputLen; 927 unformatted.data = (unsigned char *)input; 928 formatted.data = NULL; 929 rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockPublic, 930 &unformatted); 931 if (rv != SECSuccess) 932 goto failure; 933 934 rv = RSA_PublicKeyOp(key, output, formatted.data); 935 if (rv != SECSuccess) 936 goto failure; 937 938 PORT_ZFree(formatted.data, modulusLen); 939 *outputLen = modulusLen; 940 return SECSuccess; 941 942 failure: 943 if (formatted.data != NULL) 944 PORT_ZFree(formatted.data, modulusLen); 945 return SECFailure; 946 } 947 948 static HMACContext * 949 rsa_GetHMACContext(const SECHashObject *hash, RSAPrivateKey *key, 950 const unsigned char *input, unsigned int inputLen) 951 { 952 unsigned char keyHash[HASH_LENGTH_MAX]; 953 void *hashContext; 954 HMACContext *hmac = NULL; 955 unsigned int privKeyLen = key->privateExponent.len; 956 unsigned int keyLen; 957 SECStatus rv; 958 959 /* first get the key hash (should store in the key structure) */ 960 PORT_Memset(keyHash, 0, sizeof(keyHash)); 961 hashContext = (*hash->create)(); 962 if (hashContext == NULL) { 963 return NULL; 964 } 965 (*hash->begin)(hashContext); 966 if (privKeyLen < inputLen) { 967 int padLen = inputLen - privKeyLen; 968 while (padLen > sizeof(keyHash)) { 969 (*hash->update)(hashContext, keyHash, sizeof(keyHash)); 970 padLen -= sizeof(keyHash); 971 } 972 (*hash->update)(hashContext, keyHash, padLen); 973 } 974 (*hash->update)(hashContext, key->privateExponent.data, privKeyLen); 975 (*hash->end)(hashContext, keyHash, &keyLen, sizeof(keyHash)); 976 (*hash->destroy)(hashContext, PR_TRUE); 977 978 /* now create the hmac key */ 979 hmac = HMAC_Create(hash, keyHash, keyLen, PR_TRUE); 980 if (hmac == NULL) { 981 PORT_SafeZero(keyHash, sizeof(keyHash)); 982 return NULL; 983 } 984 HMAC_Begin(hmac); 985 HMAC_Update(hmac, input, inputLen); 986 rv = HMAC_Finish(hmac, keyHash, &keyLen, sizeof(keyHash)); 987 if (rv != SECSuccess) { 988 PORT_SafeZero(keyHash, sizeof(keyHash)); 989 HMAC_Destroy(hmac, PR_TRUE); 990 return NULL; 991 } 992 /* Finally set the new key into the hash context. We 993 * reuse the original context allocated above so we don't 994 * need to allocate and free another one */ 995 rv = HMAC_ReInit(hmac, hash, keyHash, keyLen, PR_TRUE); 996 PORT_SafeZero(keyHash, sizeof(keyHash)); 997 if (rv != SECSuccess) { 998 HMAC_Destroy(hmac, PR_TRUE); 999 return NULL; 1000 } 1001 1002 return hmac; 1003 } 1004 1005 static SECStatus 1006 rsa_HMACPrf(HMACContext *hmac, const char *label, int labelLen, 1007 int hashLength, unsigned char *output, int length) 1008 { 1009 unsigned char iterator[2] = { 0, 0 }; 1010 unsigned char encodedLen[2] = { 0, 0 }; 1011 unsigned char hmacLast[HASH_LENGTH_MAX]; 1012 unsigned int left = length; 1013 unsigned int hashReturn; 1014 SECStatus rv = SECSuccess; 1015 1016 /* encodedLen is in bits, length is in bytes, thus the shifts 1017 * do an implied multiply by 8 */ 1018 encodedLen[0] = (length >> 5) & 0xff; 1019 encodedLen[1] = (length << 3) & 0xff; 1020 1021 while (left > hashLength) { 1022 HMAC_Begin(hmac); 1023 HMAC_Update(hmac, iterator, 2); 1024 HMAC_Update(hmac, (const unsigned char *)label, labelLen); 1025 HMAC_Update(hmac, encodedLen, 2); 1026 rv = HMAC_Finish(hmac, output, &hashReturn, hashLength); 1027 if (rv != SECSuccess) { 1028 return rv; 1029 } 1030 iterator[1]++; 1031 if (iterator[1] == 0) 1032 iterator[0]++; 1033 left -= hashLength; 1034 output += hashLength; 1035 } 1036 if (left) { 1037 HMAC_Begin(hmac); 1038 HMAC_Update(hmac, iterator, 2); 1039 HMAC_Update(hmac, (const unsigned char *)label, labelLen); 1040 HMAC_Update(hmac, encodedLen, 2); 1041 rv = HMAC_Finish(hmac, hmacLast, &hashReturn, sizeof(hmacLast)); 1042 if (rv != SECSuccess) { 1043 return rv; 1044 } 1045 PORT_Memcpy(output, hmacLast, left); 1046 PORT_SafeZero(hmacLast, sizeof(hmacLast)); 1047 } 1048 return rv; 1049 } 1050 1051 /* This function takes a 16-bit input number and 1052 * creates the smallest mask which covers 1053 * the whole number. Examples: 1054 * 0x81 -> 0xff 1055 * 0x1af -> 0x1ff 1056 * 0x4d1 -> 0x7ff 1057 */ 1058 static int 1059 makeMask16(int len) 1060 { 1061 // or the high bit in each bit location 1062 len |= (len >> 1); 1063 len |= (len >> 2); 1064 len |= (len >> 4); 1065 len |= (len >> 8); 1066 return len; 1067 } 1068 1069 #define STRING_AND_LENGTH(s) s, sizeof(s) - 1 1070 static int 1071 rsa_GetErrorLength(HMACContext *hmac, int hashLen, int maxLegalLen) 1072 { 1073 unsigned char out[128 * 2]; 1074 unsigned char *outp; 1075 int outLength = 0; 1076 int lengthMask; 1077 SECStatus rv; 1078 1079 lengthMask = makeMask16(maxLegalLen); 1080 rv = rsa_HMACPrf(hmac, STRING_AND_LENGTH("length"), hashLen, 1081 out, sizeof(out)); 1082 if (rv != SECSuccess) { 1083 return -1; 1084 } 1085 for (outp = out; outp < out + sizeof(out); outp += 2) { 1086 int candidate = outp[0] << 8 | outp[1]; 1087 candidate = candidate & lengthMask; 1088 outLength = PORT_CT_SEL(PORT_CT_LT(candidate, maxLegalLen), 1089 candidate, outLength); 1090 } 1091 PORT_SafeZero(out, sizeof(out)); 1092 return outLength; 1093 } 1094 1095 /* 1096 * This function can only fail in environmental cases: Programming errors 1097 * and out of memory situations. It can't fail if the keys are valid and 1098 * the inputs are the proper size. If the actual RSA decryption fails, a 1099 * fake value and a fake length, both of which have already been generated 1100 * based on the key and input, are returned. 1101 * Applications are expected to detect decryption failures based on the fact 1102 * that the decrypted value (usually a key) doesn't validate. The prevents 1103 * Blecheinbaucher style attacks against the key. */ 1104 SECStatus 1105 RSA_DecryptBlock(RSAPrivateKey *key, 1106 unsigned char *output, 1107 unsigned int *outputLen, 1108 unsigned int maxOutputLen, 1109 const unsigned char *input, 1110 unsigned int inputLen) 1111 { 1112 SECStatus rv; 1113 PRUint32 fail; 1114 unsigned int modulusLen = rsa_modulusLen(&key->modulus); 1115 unsigned int i; 1116 unsigned char *buffer = NULL; 1117 unsigned char *errorBuffer = NULL; 1118 unsigned char *bp = NULL; 1119 unsigned char *ep = NULL; 1120 unsigned int outLen = modulusLen; 1121 unsigned int maxLegalLen = modulusLen - 10; 1122 unsigned int errorLength; 1123 const SECHashObject *hashObj; 1124 HMACContext *hmac = NULL; 1125 1126 /* failures in the top section indicate failures in the environment 1127 * (memory) or the library. OK to return errors in these cases because 1128 * it doesn't provide any oracle information to attackers. */ 1129 if (inputLen != modulusLen || modulusLen < 10) { 1130 PORT_SetError(SEC_ERROR_INVALID_ARGS); 1131 return SECFailure; 1132 } 1133 1134 /* Allocate enough space to decrypt */ 1135 buffer = PORT_ZAlloc(modulusLen); 1136 if (!buffer) { 1137 goto loser; 1138 } 1139 errorBuffer = PORT_ZAlloc(modulusLen); 1140 if (!errorBuffer) { 1141 goto loser; 1142 } 1143 hashObj = HASH_GetRawHashObject(HASH_AlgSHA256); 1144 if (hashObj == NULL) { 1145 goto loser; 1146 } 1147 1148 /* calculate the values to return in the error case rather than 1149 * the actual returned values. This data is the same for the 1150 * same input and private key. */ 1151 hmac = rsa_GetHMACContext(hashObj, key, input, inputLen); 1152 if (hmac == NULL) { 1153 goto loser; 1154 } 1155 errorLength = rsa_GetErrorLength(hmac, hashObj->length, maxLegalLen); 1156 if (((int)errorLength) < 0) { 1157 goto loser; 1158 } 1159 /* we always have to generate a full moduluslen error string. Otherwise 1160 * we create a timing dependency on errorLength, which could be used to 1161 * determine the difference between errorLength and outputLen and tell 1162 * us that there was a pkcs1 decryption failure */ 1163 rv = rsa_HMACPrf(hmac, STRING_AND_LENGTH("message"), 1164 hashObj->length, errorBuffer, modulusLen); 1165 if (rv != SECSuccess) { 1166 goto loser; 1167 } 1168 1169 HMAC_Destroy(hmac, PR_TRUE); 1170 hmac = NULL; 1171 1172 /* From here on out, we will always return success. If there is 1173 * an error, we will return deterministic output based on the key 1174 * and the input data. */ 1175 rv = RSA_PrivateKeyOp(key, buffer, input); 1176 1177 fail = PORT_CT_NE(rv, SECSuccess); 1178 fail |= PORT_CT_NE(buffer[0], RSA_BLOCK_FIRST_OCTET) | PORT_CT_NE(buffer[1], RSA_BlockPublic); 1179 1180 /* There have to be at least 8 bytes of padding. */ 1181 for (i = 2; i < 10; i++) { 1182 fail |= PORT_CT_EQ(buffer[i], RSA_BLOCK_AFTER_PAD_OCTET); 1183 } 1184 1185 for (i = 10; i < modulusLen; i++) { 1186 unsigned int newLen = modulusLen - i - 1; 1187 PRUint32 condition = PORT_CT_EQ(buffer[i], RSA_BLOCK_AFTER_PAD_OCTET) & PORT_CT_EQ(outLen, modulusLen); 1188 outLen = PORT_CT_SEL(condition, newLen, outLen); 1189 } 1190 // this can only happen if a zero wasn't found above 1191 fail |= PORT_CT_GE(outLen, modulusLen); 1192 1193 outLen = PORT_CT_SEL(fail, errorLength, outLen); 1194 1195 /* index into the correct buffer. Do it before we truncate outLen if the 1196 * application was asking for less data than we can return */ 1197 bp = buffer + modulusLen - outLen; 1198 ep = errorBuffer + modulusLen - outLen; 1199 1200 /* at this point, outLen returns no information about decryption failures, 1201 * no need to hide its value. maxOutputLen is how much data the 1202 * application is expecting, which is also not sensitive. */ 1203 if (outLen > maxOutputLen) { 1204 outLen = maxOutputLen; 1205 } 1206 1207 /* we can't use PORT_Memcpy because caching could create a time dependency 1208 * on the status of fail. */ 1209 for (i = 0; i < outLen; i++) { 1210 output[i] = PORT_CT_SEL(fail, ep[i], bp[i]); 1211 } 1212 1213 *outputLen = outLen; 1214 1215 PORT_Free(buffer); 1216 PORT_Free(errorBuffer); 1217 1218 return SECSuccess; 1219 1220 loser: 1221 if (hmac) { 1222 HMAC_Destroy(hmac, PR_TRUE); 1223 } 1224 PORT_Free(buffer); 1225 PORT_Free(errorBuffer); 1226 1227 return SECFailure; 1228 } 1229 1230 /* 1231 * Encode a RSA-PSS signature. 1232 * Described in RFC 3447, section 9.1.1. 1233 * We use mHash instead of M as input. 1234 * emBits from the RFC is just modBits - 1, see section 8.1.1. 1235 * We only support MGF1 as the MGF. 1236 */ 1237 SECStatus 1238 RSA_EMSAEncodePSS(unsigned char *em, 1239 unsigned int emLen, 1240 unsigned int emBits, 1241 const unsigned char *mHash, 1242 HASH_HashType hashAlg, 1243 HASH_HashType maskHashAlg, 1244 const unsigned char *salt, 1245 unsigned int saltLen) 1246 { 1247 const SECHashObject *hash; 1248 void *hash_context; 1249 unsigned char *dbMask; 1250 unsigned int dbMaskLen; 1251 unsigned int i; 1252 SECStatus rv; 1253 1254 hash = HASH_GetRawHashObject(hashAlg); 1255 dbMaskLen = emLen - hash->length - 1; 1256 1257 /* Step 3 */ 1258 if (emLen < hash->length + saltLen + 2) { 1259 PORT_SetError(SEC_ERROR_OUTPUT_LEN); 1260 return SECFailure; 1261 } 1262 1263 /* Step 4 */ 1264 if (salt == NULL) { 1265 rv = RNG_GenerateGlobalRandomBytes(&em[dbMaskLen - saltLen], saltLen); 1266 if (rv != SECSuccess) { 1267 return rv; 1268 } 1269 } else { 1270 PORT_Memcpy(&em[dbMaskLen - saltLen], salt, saltLen); 1271 } 1272 1273 /* Step 5 + 6 */ 1274 /* Compute H and store it at its final location &em[dbMaskLen]. */ 1275 hash_context = (*hash->create)(); 1276 if (hash_context == NULL) { 1277 PORT_SetError(SEC_ERROR_NO_MEMORY); 1278 return SECFailure; 1279 } 1280 (*hash->begin)(hash_context); 1281 (*hash->update)(hash_context, eightZeros, 8); 1282 (*hash->update)(hash_context, mHash, hash->length); 1283 (*hash->update)(hash_context, &em[dbMaskLen - saltLen], saltLen); 1284 (*hash->end)(hash_context, &em[dbMaskLen], &i, hash->length); 1285 (*hash->destroy)(hash_context, PR_TRUE); 1286 1287 /* Step 7 + 8 */ 1288 PORT_Memset(em, 0, dbMaskLen - saltLen - 1); 1289 em[dbMaskLen - saltLen - 1] = 0x01; 1290 1291 /* Step 9 */ 1292 dbMask = (unsigned char *)PORT_Alloc(dbMaskLen); 1293 if (dbMask == NULL) { 1294 PORT_SetError(SEC_ERROR_NO_MEMORY); 1295 return SECFailure; 1296 } 1297 MGF1(maskHashAlg, dbMask, dbMaskLen, &em[dbMaskLen], hash->length); 1298 1299 /* Step 10 */ 1300 for (i = 0; i < dbMaskLen; i++) 1301 em[i] ^= dbMask[i]; 1302 PORT_Free(dbMask); 1303 1304 /* Step 11 */ 1305 em[0] &= 0xff >> (8 * emLen - emBits); 1306 1307 /* Step 12 */ 1308 em[emLen - 1] = 0xbc; 1309 1310 return SECSuccess; 1311 } 1312 1313 /* 1314 * Verify a RSA-PSS signature. 1315 * Described in RFC 3447, section 9.1.2. 1316 * We use mHash instead of M as input. 1317 * emBits from the RFC is just modBits - 1, see section 8.1.2. 1318 * We only support MGF1 as the MGF. 1319 */ 1320 static SECStatus 1321 emsa_pss_verify(const unsigned char *mHash, 1322 const unsigned char *em, 1323 unsigned int emLen, 1324 unsigned int emBits, 1325 HASH_HashType hashAlg, 1326 HASH_HashType maskHashAlg, 1327 unsigned int saltLen) 1328 { 1329 const SECHashObject *hash; 1330 void *hash_context; 1331 unsigned char *db; 1332 unsigned char *H_; /* H' from the RFC */ 1333 unsigned int i; 1334 unsigned int dbMaskLen; 1335 unsigned int zeroBits; 1336 SECStatus rv; 1337 1338 hash = HASH_GetRawHashObject(hashAlg); 1339 dbMaskLen = emLen - hash->length - 1; 1340 1341 /* Step 3 + 4 */ 1342 if ((emLen < (hash->length + saltLen + 2)) || 1343 (em[emLen - 1] != 0xbc)) { 1344 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1345 return SECFailure; 1346 } 1347 1348 /* Step 6 */ 1349 zeroBits = 8 * emLen - emBits; 1350 if (em[0] >> (8 - zeroBits)) { 1351 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1352 return SECFailure; 1353 } 1354 1355 /* Step 7 */ 1356 db = (unsigned char *)PORT_Alloc(dbMaskLen); 1357 if (db == NULL) { 1358 PORT_SetError(SEC_ERROR_NO_MEMORY); 1359 return SECFailure; 1360 } 1361 /* &em[dbMaskLen] points to H, used as mgfSeed */ 1362 MGF1(maskHashAlg, db, dbMaskLen, &em[dbMaskLen], hash->length); 1363 1364 /* Step 8 */ 1365 for (i = 0; i < dbMaskLen; i++) { 1366 db[i] ^= em[i]; 1367 } 1368 1369 /* Step 9 */ 1370 db[0] &= 0xff >> zeroBits; 1371 1372 /* Step 10 */ 1373 for (i = 0; i < (dbMaskLen - saltLen - 1); i++) { 1374 if (db[i] != 0) { 1375 PORT_Free(db); 1376 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1377 return SECFailure; 1378 } 1379 } 1380 if (db[dbMaskLen - saltLen - 1] != 0x01) { 1381 PORT_Free(db); 1382 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1383 return SECFailure; 1384 } 1385 1386 /* Step 12 + 13 */ 1387 H_ = (unsigned char *)PORT_Alloc(hash->length); 1388 if (H_ == NULL) { 1389 PORT_Free(db); 1390 PORT_SetError(SEC_ERROR_NO_MEMORY); 1391 return SECFailure; 1392 } 1393 hash_context = (*hash->create)(); 1394 if (hash_context == NULL) { 1395 PORT_Free(db); 1396 PORT_Free(H_); 1397 PORT_SetError(SEC_ERROR_NO_MEMORY); 1398 return SECFailure; 1399 } 1400 (*hash->begin)(hash_context); 1401 (*hash->update)(hash_context, eightZeros, 8); 1402 (*hash->update)(hash_context, mHash, hash->length); 1403 (*hash->update)(hash_context, &db[dbMaskLen - saltLen], saltLen); 1404 (*hash->end)(hash_context, H_, &i, hash->length); 1405 (*hash->destroy)(hash_context, PR_TRUE); 1406 1407 PORT_Free(db); 1408 1409 /* Step 14 */ 1410 if (PORT_Memcmp(H_, &em[dbMaskLen], hash->length) != 0) { 1411 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1412 rv = SECFailure; 1413 } else { 1414 rv = SECSuccess; 1415 } 1416 1417 PORT_Free(H_); 1418 return rv; 1419 } 1420 1421 SECStatus 1422 RSA_SignPSS(RSAPrivateKey *key, 1423 HASH_HashType hashAlg, 1424 HASH_HashType maskHashAlg, 1425 const unsigned char *salt, 1426 unsigned int saltLength, 1427 unsigned char *output, 1428 unsigned int *outputLen, 1429 unsigned int maxOutputLen, 1430 const unsigned char *input, 1431 unsigned int inputLen) 1432 { 1433 SECStatus rv = SECSuccess; 1434 unsigned int modulusLen = rsa_modulusLen(&key->modulus); 1435 unsigned int modulusBits = rsa_modulusBits(&key->modulus); 1436 unsigned int emLen = modulusLen; 1437 unsigned char *pssEncoded, *em; 1438 1439 if (maxOutputLen < modulusLen) { 1440 PORT_SetError(SEC_ERROR_OUTPUT_LEN); 1441 return SECFailure; 1442 } 1443 1444 if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) { 1445 PORT_SetError(SEC_ERROR_INVALID_ALGORITHM); 1446 return SECFailure; 1447 } 1448 1449 pssEncoded = em = (unsigned char *)PORT_Alloc(modulusLen); 1450 if (pssEncoded == NULL) { 1451 PORT_SetError(SEC_ERROR_NO_MEMORY); 1452 return SECFailure; 1453 } 1454 1455 /* len(em) == ceil((modulusBits - 1) / 8). */ 1456 if (modulusBits % 8 == 1) { 1457 em[0] = 0; 1458 emLen--; 1459 em++; 1460 } 1461 rv = RSA_EMSAEncodePSS(em, emLen, modulusBits - 1, input, hashAlg, 1462 maskHashAlg, salt, saltLength); 1463 if (rv != SECSuccess) 1464 goto done; 1465 1466 // This sets error codes upon failure. 1467 rv = RSA_PrivateKeyOpDoubleChecked(key, output, pssEncoded); 1468 *outputLen = modulusLen; 1469 1470 done: 1471 PORT_Free(pssEncoded); 1472 return rv; 1473 } 1474 1475 SECStatus 1476 RSA_CheckSignPSS(RSAPublicKey *key, 1477 HASH_HashType hashAlg, 1478 HASH_HashType maskHashAlg, 1479 unsigned int saltLength, 1480 const unsigned char *sig, 1481 unsigned int sigLen, 1482 const unsigned char *hash, 1483 unsigned int hashLen) 1484 { 1485 SECStatus rv; 1486 unsigned int modulusLen = rsa_modulusLen(&key->modulus); 1487 unsigned int modulusBits = rsa_modulusBits(&key->modulus); 1488 unsigned int emLen = modulusLen; 1489 unsigned char *buffer, *em; 1490 1491 if (sigLen != modulusLen) { 1492 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1493 return SECFailure; 1494 } 1495 1496 if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) { 1497 PORT_SetError(SEC_ERROR_INVALID_ALGORITHM); 1498 return SECFailure; 1499 } 1500 1501 buffer = em = (unsigned char *)PORT_Alloc(modulusLen); 1502 if (!buffer) { 1503 PORT_SetError(SEC_ERROR_NO_MEMORY); 1504 return SECFailure; 1505 } 1506 1507 rv = RSA_PublicKeyOp(key, buffer, sig); 1508 if (rv != SECSuccess) { 1509 PORT_Free(buffer); 1510 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1511 return SECFailure; 1512 } 1513 1514 /* len(em) == ceil((modulusBits - 1) / 8). */ 1515 if (modulusBits % 8 == 1) { 1516 emLen--; 1517 em++; 1518 } 1519 rv = emsa_pss_verify(hash, em, emLen, modulusBits - 1, hashAlg, 1520 maskHashAlg, saltLength); 1521 1522 PORT_Free(buffer); 1523 return rv; 1524 } 1525 1526 SECStatus 1527 RSA_Sign(RSAPrivateKey *key, 1528 unsigned char *output, 1529 unsigned int *outputLen, 1530 unsigned int maxOutputLen, 1531 const unsigned char *input, 1532 unsigned int inputLen) 1533 { 1534 SECStatus rv = SECFailure; 1535 unsigned int modulusLen = rsa_modulusLen(&key->modulus); 1536 SECItem formatted = { siBuffer, NULL, 0 }; 1537 SECItem unformatted = { siBuffer, (unsigned char *)input, inputLen }; 1538 1539 if (maxOutputLen < modulusLen) { 1540 PORT_SetError(SEC_ERROR_OUTPUT_LEN); 1541 goto done; 1542 } 1543 1544 rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockPrivate, 1545 &unformatted); 1546 if (rv != SECSuccess) { 1547 PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); 1548 goto done; 1549 } 1550 1551 // This sets error codes upon failure. 1552 rv = RSA_PrivateKeyOpDoubleChecked(key, output, formatted.data); 1553 *outputLen = modulusLen; 1554 1555 done: 1556 if (formatted.data != NULL) { 1557 PORT_ZFree(formatted.data, modulusLen); 1558 } 1559 return rv; 1560 } 1561 1562 SECStatus 1563 RSA_CheckSign(RSAPublicKey *key, 1564 const unsigned char *sig, 1565 unsigned int sigLen, 1566 const unsigned char *data, 1567 unsigned int dataLen) 1568 { 1569 SECStatus rv = SECFailure; 1570 unsigned int modulusLen = rsa_modulusLen(&key->modulus); 1571 unsigned int i; 1572 unsigned char *buffer = NULL; 1573 1574 if (sigLen != modulusLen) { 1575 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1576 goto done; 1577 } 1578 1579 /* 1580 * 0x00 || BT || Pad || 0x00 || ActualData 1581 * 1582 * The "3" below is the first octet + the second octet + the 0x00 1583 * octet that always comes just before the ActualData. 1584 */ 1585 if (dataLen > modulusLen - (3 + RSA_BLOCK_MIN_PAD_LEN)) { 1586 PORT_SetError(SEC_ERROR_BAD_DATA); 1587 goto done; 1588 } 1589 1590 buffer = (unsigned char *)PORT_Alloc(modulusLen + 1); 1591 if (!buffer) { 1592 PORT_SetError(SEC_ERROR_NO_MEMORY); 1593 goto done; 1594 } 1595 1596 if (RSA_PublicKeyOp(key, buffer, sig) != SECSuccess) { 1597 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1598 goto done; 1599 } 1600 1601 /* 1602 * check the padding that was used 1603 */ 1604 if (buffer[0] != RSA_BLOCK_FIRST_OCTET || 1605 buffer[1] != (unsigned char)RSA_BlockPrivate) { 1606 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1607 goto done; 1608 } 1609 for (i = 2; i < modulusLen - dataLen - 1; i++) { 1610 if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET) { 1611 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1612 goto done; 1613 } 1614 } 1615 if (buffer[i] != RSA_BLOCK_AFTER_PAD_OCTET) { 1616 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1617 goto done; 1618 } 1619 1620 /* 1621 * make sure we get the same results 1622 */ 1623 if (PORT_Memcmp(buffer + modulusLen - dataLen, data, dataLen) == 0) { 1624 rv = SECSuccess; 1625 } 1626 1627 done: 1628 if (buffer) { 1629 PORT_Free(buffer); 1630 } 1631 return rv; 1632 } 1633 1634 SECStatus 1635 RSA_CheckSignRecover(RSAPublicKey *key, 1636 unsigned char *output, 1637 unsigned int *outputLen, 1638 unsigned int maxOutputLen, 1639 const unsigned char *sig, 1640 unsigned int sigLen) 1641 { 1642 SECStatus rv = SECFailure; 1643 unsigned int modulusLen = rsa_modulusLen(&key->modulus); 1644 unsigned int i; 1645 unsigned char *buffer = NULL; 1646 unsigned int padLen; 1647 1648 if (sigLen != modulusLen) { 1649 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1650 goto done; 1651 } 1652 1653 buffer = (unsigned char *)PORT_Alloc(modulusLen + 1); 1654 if (!buffer) { 1655 PORT_SetError(SEC_ERROR_NO_MEMORY); 1656 goto done; 1657 } 1658 1659 if (RSA_PublicKeyOp(key, buffer, sig) != SECSuccess) { 1660 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1661 goto done; 1662 } 1663 1664 *outputLen = 0; 1665 1666 /* 1667 * check the padding that was used 1668 */ 1669 if (buffer[0] != RSA_BLOCK_FIRST_OCTET || 1670 buffer[1] != (unsigned char)RSA_BlockPrivate) { 1671 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1672 goto done; 1673 } 1674 for (i = 2; i < modulusLen; i++) { 1675 if (buffer[i] == RSA_BLOCK_AFTER_PAD_OCTET) { 1676 *outputLen = modulusLen - i - 1; 1677 break; 1678 } 1679 if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET) { 1680 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1681 goto done; 1682 } 1683 } 1684 padLen = i - 2; 1685 if (padLen < RSA_BLOCK_MIN_PAD_LEN) { 1686 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1687 goto done; 1688 } 1689 if (*outputLen == 0) { 1690 PORT_SetError(SEC_ERROR_BAD_SIGNATURE); 1691 goto done; 1692 } 1693 if (*outputLen > maxOutputLen) { 1694 PORT_SetError(SEC_ERROR_OUTPUT_LEN); 1695 goto done; 1696 } 1697 1698 PORT_Memcpy(output, buffer + modulusLen - *outputLen, *outputLen); 1699 rv = SECSuccess; 1700 1701 done: 1702 if (buffer) { 1703 PORT_Free(buffer); 1704 } 1705 return rv; 1706 }