lib.rs (5299B)
1 /* This Source Code Form is subject to the terms of the Mozilla Public 2 * License, v. 2.0. If a copy of the MPL was not distributed with this 3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */ 4 5 mod builtins; 6 mod cert_storage; 7 mod tls; 8 9 use log; 10 use nserror::{nsresult, NS_ERROR_INVALID_ARG, NS_ERROR_UNEXPECTED, NS_OK}; 11 use std::fmt::Display; 12 use std::io::Write; 13 use tls::{CertificateMessage, CompressedCertEntry, UncompressedCertEntry}; 14 15 #[no_mangle] 16 pub extern "C" fn certs_are_available() -> bool { 17 let Some(hashes) = builtins::get_needed_hashes() else { 18 return false; 19 }; 20 match cert_storage::has_all_certs_by_hash(hashes) { 21 Ok(result) => { 22 log::debug!("certs_are_available {}", result); 23 return result; 24 } 25 Err(e) => { 26 log::warn!("certs_are_available failed: {:?}", e); 27 return false; 28 } 29 } 30 } 31 32 #[no_mangle] 33 pub extern "C" fn decompress( 34 input: *const u8, 35 input_len: usize, 36 output: *mut u8, 37 output_len: usize, 38 used_len: *mut usize, 39 ) -> nsresult { 40 if input.is_null() || output.is_null() || used_len.is_null() { 41 return NS_ERROR_INVALID_ARG; 42 } 43 let input_slice = unsafe { std::slice::from_raw_parts(input, input_len) }; 44 45 let mut output = unsafe { 46 std::ptr::write_bytes(output, 0, output_len); 47 std::slice::from_raw_parts_mut(output, output_len) 48 }; 49 50 let size = match pass_1_decompression(input_slice, output_len, &mut output) { 51 Ok(size) => size, 52 Err(e) => { 53 log::error!("Error during pass 1 decompression: {}", e); 54 return NS_ERROR_UNEXPECTED; 55 } 56 }; 57 58 unsafe { 59 *used_len = size; 60 } 61 62 log::debug!( 63 "successfully decompressed {} input bytes to {} output_bytes ", 64 input_len, 65 size, 66 ); 67 NS_OK 68 } 69 70 fn pass_1_mapping(mut entry: UncompressedCertEntry) -> Result<CompressedCertEntry, AbridgedError> { 71 let id_or_cert = &entry.data; 72 let Ok(id) = TryInto::<&[u8; 3]>::try_into(id_or_cert.as_slice()) else { 73 // Avoid doing a lookup when we know its not an identifier. 74 log::trace!("Passing through directly {:#02X?}", entry.data); 75 return Ok(entry); 76 }; 77 78 log::trace!("Abridged Certs Identifier: {:#02X?}", id); 79 let Some(hash) = builtins::id_to_hash(id) else { 80 return Err(AbridgedError::UnknownIdentifier(id_or_cert.to_vec())); 81 }; 82 83 match cert_storage::get_cert_from_hash(&hash) { 84 Ok(cert_bytes) => entry.data = cert_bytes.to_vec(), 85 Err(err) => { 86 return Err(AbridgedError::UnableToLoadCertByHash(hash.to_vec(), err)); 87 } 88 }; 89 Ok(entry) 90 } 91 fn pass_1_decompression( 92 input: &[u8], 93 expected_len: usize, 94 output: &mut impl Write, 95 ) -> Result<usize, AbridgedError> { 96 let (mut cert_msg, tail) = CertificateMessage::read_from_bytes(&input)?; 97 98 if !tail.is_empty() { 99 // Trailing data on a certificate message is forbidden by Abridged Certs spec 100 return Err(AbridgedError::ParsingInvalidCertificateMessage); 101 } 102 103 let mut new_entries = Vec::with_capacity(cert_msg.certificate_entries.len()); 104 105 // We keep a running tally of the decompressed message's size. 106 let mut current_size = cert_msg.get_size(); 107 108 for entry in cert_msg.certificate_entries { 109 current_size -= entry.get_size(); 110 let mapping = pass_1_mapping(entry)?; 111 current_size += mapping.get_size(); 112 new_entries.push(mapping); 113 if current_size > expected_len { 114 return Err(AbridgedError::LargerThanExpectedSize(expected_len)); 115 } 116 } 117 cert_msg.certificate_entries = new_entries; 118 assert_eq!(current_size, cert_msg.get_size()); 119 cert_msg.write_to_bytes(output)?; 120 Ok(cert_msg.get_size()) 121 } 122 123 #[derive(Debug)] 124 pub enum AbridgedError { 125 UnknownError, 126 ParsingInvalidTLSVec, 127 ParsingInvalidCertificateMessage, 128 InvalidOperation, 129 LargerThanExpectedSize(usize), 130 ReadingError(std::io::Error), 131 WritingError(std::io::Error), 132 UnableToLoadCertByHash(Vec<u8>, nsresult), 133 UnknownIdentifier(Vec<u8>), 134 } 135 136 impl Display for AbridgedError { 137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 138 match self { 139 AbridgedError::UnknownError => write!(f, "Unknown Error"), 140 AbridgedError::ParsingInvalidTLSVec => write!(f, "ParsingInvalidTLSVec"), 141 AbridgedError::ParsingInvalidCertificateMessage => { 142 write!(f, "ParsingInvalidCertificateMessage") 143 } 144 AbridgedError::InvalidOperation => write!(f, "InvalidOperation"), 145 AbridgedError::LargerThanExpectedSize(size) => { 146 write!(f, "Larger Than Expected Sizes {}", size) 147 } 148 AbridgedError::ReadingError(x) => write!(f, "Writing Error {}", x), 149 AbridgedError::WritingError(x) => write!(f, "Writing Error {}", x), 150 AbridgedError::UnableToLoadCertByHash(hash, error) => write!( 151 f, 152 "Unable to Load Cert for Hash {:#02X?}. Error: {}", 153 hash, error 154 ), 155 AbridgedError::UnknownIdentifier(id) => { 156 write!(f, "Unrecognized Identifier {:#02X?}", id) 157 } 158 } 159 } 160 }