tls_protect.cc (4625B)
1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 2 /* vim: set ts=2 et sw=2 tw=80: */ 3 /* This Source Code Form is subject to the terms of the Mozilla Public 4 * License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 * You can obtain one at http://mozilla.org/MPL/2.0/. */ 6 7 #include "tls_protect.h" 8 #include "sslproto.h" 9 #include "tls_filter.h" 10 11 namespace nss_test { 12 13 static uint64_t FirstSeqno(bool dtls, uint16_t epoc) { 14 if (dtls) { 15 return static_cast<uint64_t>(epoc) << 48; 16 } 17 return 0; 18 } 19 20 TlsCipherSpec::TlsCipherSpec(bool dtls, uint16_t epoc) 21 : dtls_(dtls), 22 epoch_(epoc), 23 in_seqno_(FirstSeqno(dtls, epoc)), 24 out_seqno_(FirstSeqno(dtls, epoc)) {} 25 26 bool TlsCipherSpec::SetKeys(SSLCipherSuiteInfo* cipherinfo, 27 PK11SymKey* secret) { 28 SSLAeadContext* aead_ctx; 29 SSLProtocolVariant variant = 30 dtls_ ? ssl_variant_datagram : ssl_variant_stream; 31 SECStatus rv = 32 SSL_MakeVariantAead(SSL_LIBRARY_VERSION_TLS_1_3, cipherinfo->cipherSuite, 33 variant, secret, "", 0, // Use the default labels. 34 &aead_ctx); 35 if (rv != SECSuccess) { 36 return false; 37 } 38 aead_.reset(aead_ctx); 39 40 SSLMaskingContext* mask_ctx; 41 const char kHkdfPurposeSn[] = "sn"; 42 rv = SSL_CreateVariantMaskingContext( 43 SSL_LIBRARY_VERSION_TLS_1_3, cipherinfo->cipherSuite, variant, secret, 44 kHkdfPurposeSn, strlen(kHkdfPurposeSn), &mask_ctx); 45 if (rv != SECSuccess) { 46 return false; 47 } 48 mask_.reset(mask_ctx); 49 return true; 50 } 51 52 bool TlsCipherSpec::Unprotect(const TlsRecordHeader& header, 53 const DataBuffer& ciphertext, 54 DataBuffer* plaintext, 55 TlsRecordHeader* out_header) { 56 if (!aead_ || !out_header) { 57 return false; 58 } 59 *out_header = header; 60 61 // Make space. 62 plaintext->Allocate(ciphertext.len()); 63 64 unsigned int len; 65 uint64_t seqno = dtls_ ? header.sequence_number() : in_seqno_; 66 SECStatus rv; 67 68 if (header.is_dtls13_ciphertext()) { 69 if (!mask_ || !out_header) { 70 return false; 71 } 72 PORT_Assert(ciphertext.len() >= 16); 73 DataBuffer mask(2); 74 rv = SSL_CreateMask(mask_.get(), ciphertext.data(), ciphertext.len(), 75 mask.data(), mask.len()); 76 if (rv != SECSuccess) { 77 return false; 78 } 79 80 if (!out_header->MaskSequenceNumber(mask)) { 81 return false; 82 } 83 seqno = out_header->sequence_number(); 84 } 85 86 if (header.is_dtls() && (header.version() >= SSL_LIBRARY_VERSION_TLS_1_3)) { 87 // Removing the epoch (16 first bits) 88 seqno = seqno & 0xffffffffffff; 89 } 90 91 auto header_bytes = out_header->header(); 92 rv = SSL_AeadDecrypt(aead_.get(), seqno, header_bytes.data(), 93 header_bytes.len(), ciphertext.data(), ciphertext.len(), 94 plaintext->data(), &len, plaintext->len()); 95 if (rv != SECSuccess) { 96 return false; 97 } 98 99 RecordUnprotected(seqno); 100 plaintext->Truncate(static_cast<size_t>(len)); 101 102 return true; 103 } 104 105 bool TlsCipherSpec::Protect(const TlsRecordHeader& header, 106 const DataBuffer& plaintext, DataBuffer* ciphertext, 107 TlsRecordHeader* out_header) { 108 if (!aead_ || !out_header) { 109 return false; 110 } 111 112 *out_header = header; 113 114 // Make a padded buffer. 115 ciphertext->Allocate(plaintext.len() + 116 32); // Room for any plausible auth tag 117 unsigned int len; 118 119 DataBuffer header_bytes; 120 (void)header.WriteHeader(&header_bytes, 0, plaintext.len() + 16); 121 uint64_t seqno = dtls_ ? header.sequence_number() : out_seqno_; 122 123 if (header.is_dtls() && (header.version() >= SSL_LIBRARY_VERSION_TLS_1_3)) { 124 // Removing the epoch (16 first bits) 125 seqno = seqno & 0xffffffffffff; 126 } 127 128 SECStatus rv = 129 SSL_AeadEncrypt(aead_.get(), seqno, header_bytes.data(), 130 header_bytes.len(), plaintext.data(), plaintext.len(), 131 ciphertext->data(), &len, ciphertext->len()); 132 if (rv != SECSuccess) { 133 return false; 134 } 135 136 if (header.is_dtls13_ciphertext()) { 137 if (!mask_ || !out_header) { 138 return false; 139 } 140 PORT_Assert(ciphertext->len() >= 16); 141 DataBuffer mask(2); 142 rv = SSL_CreateMask(mask_.get(), ciphertext->data(), ciphertext->len(), 143 mask.data(), mask.len()); 144 if (rv != SECSuccess) { 145 return false; 146 } 147 if (!out_header->MaskSequenceNumber(mask)) { 148 return false; 149 } 150 } 151 152 RecordProtected(); 153 ciphertext->Truncate(len); 154 155 return true; 156 } 157 158 } // namespace nss_test