nss_countbytes.c (7601B)
1 /* Copyright 2018-2021, The Tor Project Inc. */ 2 /* See LICENSE for licensing information */ 3 4 /** 5 * \file nss_countbytes.c 6 * \brief A PRFileDesc layer to let us count the number of bytes 7 * bytes actually written on a PRFileDesc. 8 **/ 9 10 #include "orconfig.h" 11 12 #include "lib/log/util_bug.h" 13 #include "lib/malloc/malloc.h" 14 #include "lib/tls/nss_countbytes.h" 15 16 #include <stdlib.h> 17 #include <string.h> 18 19 #include <prio.h> 20 21 /** Boolean: have we initialized this module */ 22 static bool countbytes_initialized = false; 23 24 /** Integer to identity this layer. */ 25 static PRDescIdentity countbytes_layer_id = PR_INVALID_IO_LAYER; 26 27 /** Table of methods for this layer.*/ 28 static PRIOMethods countbytes_methods; 29 30 /** Default close function provided by NSPR. We use this to help 31 * implement our own close function.*/ 32 static PRStatus(*default_close_fn)(PRFileDesc *fd); 33 34 static PRStatus countbytes_close_fn(PRFileDesc *fd); 35 static PRInt32 countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount); 36 static PRInt32 countbytes_write_fn(PRFileDesc *fd, const void *buf, 37 PRInt32 amount); 38 static PRInt32 countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov, 39 PRInt32 size, PRIntervalTime timeout); 40 static PRInt32 countbytes_send_fn(PRFileDesc *fd, const void *buf, 41 PRInt32 amount, PRIntn flags, 42 PRIntervalTime timeout); 43 static PRInt32 countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount, 44 PRIntn flags, PRIntervalTime timeout); 45 46 /** Private fields for the byte-counter layer. We cast this to and from 47 * PRFilePrivate*, which is supposed to be allowed. */ 48 typedef struct tor_nss_bytecounts_t { 49 uint64_t n_read; 50 uint64_t n_written; 51 } tor_nss_bytecounts_t; 52 53 /** 54 * Initialize this module, if it is not already initialized. 55 **/ 56 void 57 tor_nss_countbytes_init(void) 58 { 59 if (countbytes_initialized) 60 return; 61 62 countbytes_layer_id = PR_GetUniqueIdentity("Tor byte-counting layer"); 63 tor_assert(countbytes_layer_id != PR_INVALID_IO_LAYER); 64 65 memcpy(&countbytes_methods, PR_GetDefaultIOMethods(), sizeof(PRIOMethods)); 66 67 default_close_fn = countbytes_methods.close; 68 countbytes_methods.close = countbytes_close_fn; 69 countbytes_methods.read = countbytes_read_fn; 70 countbytes_methods.write = countbytes_write_fn; 71 countbytes_methods.writev = countbytes_writev_fn; 72 countbytes_methods.send = countbytes_send_fn; 73 countbytes_methods.recv = countbytes_recv_fn; 74 /* NOTE: We aren't wrapping recvfrom, sendto, or sendfile, since I think 75 * NSS won't be using them for TLS connections. */ 76 77 countbytes_initialized = true; 78 } 79 80 /** 81 * Return the tor_nss_bytecounts_t object for a given IO layer. Asserts that 82 * the IO layer is in fact a layer created by this module. 83 */ 84 static tor_nss_bytecounts_t * 85 get_counts(PRFileDesc *fd) 86 { 87 tor_assert(fd->identity == countbytes_layer_id); 88 return (tor_nss_bytecounts_t*) fd->secret; 89 } 90 91 /** Helper: increment the read-count of an fd by n. */ 92 #define INC_READ(fd, n) STMT_BEGIN \ 93 get_counts(fd)->n_read += (n); \ 94 STMT_END 95 96 /** Helper: increment the write-count of an fd by n. */ 97 #define INC_WRITTEN(fd, n) STMT_BEGIN \ 98 get_counts(fd)->n_written += (n); \ 99 STMT_END 100 101 /** Implementation for PR_Close: frees the 'secret' field, then passes control 102 * to the default close function */ 103 static PRStatus 104 countbytes_close_fn(PRFileDesc *fd) 105 { 106 tor_assert(fd); 107 108 tor_nss_bytecounts_t *counts = (tor_nss_bytecounts_t *)fd->secret; 109 tor_free(counts); 110 fd->secret = NULL; 111 112 return default_close_fn(fd); 113 } 114 115 /** Implementation for PR_Read: Calls the lower-level read function, 116 * and records what it said. */ 117 static PRInt32 118 countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount) 119 { 120 tor_assert(fd); 121 tor_assert(fd->lower); 122 123 PRInt32 result = (fd->lower->methods->read)(fd->lower, buf, amount); 124 if (result > 0) 125 INC_READ(fd, result); 126 return result; 127 } 128 /** Implementation for PR_Write: Calls the lower-level write function, 129 * and records what it said. */ 130 static PRInt32 131 countbytes_write_fn(PRFileDesc *fd, const void *buf, PRInt32 amount) 132 { 133 tor_assert(fd); 134 tor_assert(fd->lower); 135 136 PRInt32 result = (fd->lower->methods->write)(fd->lower, buf, amount); 137 if (result > 0) 138 INC_WRITTEN(fd, result); 139 return result; 140 } 141 /** Implementation for PR_Writev: Calls the lower-level writev function, 142 * and records what it said. */ 143 static PRInt32 144 countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov, 145 PRInt32 size, PRIntervalTime timeout) 146 { 147 tor_assert(fd); 148 tor_assert(fd->lower); 149 150 PRInt32 result = (fd->lower->methods->writev)(fd->lower, iov, size, timeout); 151 if (result > 0) 152 INC_WRITTEN(fd, result); 153 return result; 154 } 155 /** Implementation for PR_Send: Calls the lower-level send function, 156 * and records what it said. */ 157 static PRInt32 158 countbytes_send_fn(PRFileDesc *fd, const void *buf, 159 PRInt32 amount, PRIntn flags, PRIntervalTime timeout) 160 { 161 tor_assert(fd); 162 tor_assert(fd->lower); 163 164 PRInt32 result = (fd->lower->methods->send)(fd->lower, buf, amount, flags, 165 timeout); 166 if (result > 0) 167 INC_WRITTEN(fd, result); 168 return result; 169 } 170 /** Implementation for PR_Recv: Calls the lower-level recv function, 171 * and records what it said. */ 172 static PRInt32 173 countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount, 174 PRIntn flags, PRIntervalTime timeout) 175 { 176 tor_assert(fd); 177 tor_assert(fd->lower); 178 179 PRInt32 result = (fd->lower->methods->recv)(fd->lower, buf, amount, flags, 180 timeout); 181 if (result > 0) 182 INC_READ(fd, result); 183 return result; 184 } 185 186 /** 187 * Wrap a PRFileDesc from NSPR with a new PRFileDesc that will count the 188 * total number of bytes read and written. Return the new PRFileDesc. 189 * 190 * This function takes ownership of its input. 191 */ 192 PRFileDesc * 193 tor_wrap_prfiledesc_with_byte_counter(PRFileDesc *stack) 194 { 195 if (BUG(! countbytes_initialized)) { 196 tor_nss_countbytes_init(); 197 } 198 199 tor_nss_bytecounts_t *bytecounts = tor_malloc_zero(sizeof(*bytecounts)); 200 201 PRFileDesc *newfd = PR_CreateIOLayerStub(countbytes_layer_id, 202 &countbytes_methods); 203 tor_assert(newfd); 204 newfd->secret = (PRFilePrivate *)bytecounts; 205 206 /* This does some complicated messing around with the headers of these 207 objects; see the NSPR documentation for more. The upshot is that 208 after PushIOLayer, "stack" will be the head of the stack. 209 */ 210 PRStatus status = PR_PushIOLayer(stack, PR_TOP_IO_LAYER, newfd); 211 tor_assert(status == PR_SUCCESS); 212 213 return stack; 214 } 215 216 /** 217 * Given a PRFileDesc returned by tor_wrap_prfiledesc_with_byte_counter(), 218 * or another PRFileDesc wrapping that PRFileDesc, set the provided 219 * pointers to the number of bytes read and written on the descriptor since 220 * it was created. 221 * 222 * Return 0 on success, -1 on failure. 223 */ 224 int 225 tor_get_prfiledesc_byte_counts(PRFileDesc *fd, 226 uint64_t *n_read_out, 227 uint64_t *n_written_out) 228 { 229 if (BUG(! countbytes_initialized)) { 230 tor_nss_countbytes_init(); 231 } 232 233 tor_assert(fd); 234 PRFileDesc *bclayer = PR_GetIdentitiesLayer(fd, countbytes_layer_id); 235 if (BUG(bclayer == NULL)) 236 return -1; 237 238 tor_nss_bytecounts_t *counts = get_counts(bclayer); 239 240 *n_read_out = counts->n_read; 241 *n_written_out = counts->n_written; 242 243 return 0; 244 }