tor

The Tor anonymity network
git clone https://git.dasho.dev/tor.git
Log | Files | Refs | README | LICENSE

nss_countbytes.c (7601B)


      1 /* Copyright 2018-2021, The Tor Project Inc. */
      2 /* See LICENSE for licensing information */
      3 
      4 /**
      5 * \file nss_countbytes.c
      6 * \brief A PRFileDesc layer to let us count the number of bytes
      7 *        actually read and written on a PRFileDesc.
      8 **/
      9 
     10 #include "orconfig.h"
     11 
     12 #include "lib/log/util_bug.h"
     13 #include "lib/malloc/malloc.h"
     14 #include "lib/tls/nss_countbytes.h"
     15 
     16 #include <stdlib.h>
     17 #include <string.h>
     18 
     19 #include <prio.h>
     20 
     21 /** Boolean: have we initialized this module */
     22 static bool countbytes_initialized = false;
     23 
     24 /** Integer to identity this layer. */
     25 static PRDescIdentity countbytes_layer_id = PR_INVALID_IO_LAYER;
     26 
     27 /** Table of methods for this layer.*/
     28 static PRIOMethods countbytes_methods;
     29 
     30 /** Default close function provided by NSPR.  We use this to help
     31 *  implement our own close function.*/
     32 static PRStatus(*default_close_fn)(PRFileDesc *fd);
     33 
     34 static PRStatus countbytes_close_fn(PRFileDesc *fd);
     35 static PRInt32 countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount);
     36 static PRInt32 countbytes_write_fn(PRFileDesc *fd, const void *buf,
     37                                   PRInt32 amount);
     38 static PRInt32 countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov,
     39                                    PRInt32 size, PRIntervalTime timeout);
     40 static PRInt32 countbytes_send_fn(PRFileDesc *fd, const void *buf,
     41                                  PRInt32 amount, PRIntn flags,
     42                                  PRIntervalTime timeout);
     43 static PRInt32 countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount,
     44                                  PRIntn flags, PRIntervalTime timeout);
     45 
     46 /** Private fields for the byte-counter layer.  We cast this to and from
     47 * PRFilePrivate*, which is supposed to be allowed. */
     48 typedef struct tor_nss_bytecounts_t {
     49  uint64_t n_read;
     50  uint64_t n_written;
     51 } tor_nss_bytecounts_t;
     52 
     53 /**
     54 * Initialize this module, if it is not already initialized.
     55 **/
     56 void
     57 tor_nss_countbytes_init(void)
     58 {
     59  if (countbytes_initialized)
     60    return;
     61 
     62  countbytes_layer_id = PR_GetUniqueIdentity("Tor byte-counting layer");
     63  tor_assert(countbytes_layer_id != PR_INVALID_IO_LAYER);
     64 
     65  memcpy(&countbytes_methods, PR_GetDefaultIOMethods(), sizeof(PRIOMethods));
     66 
     67  default_close_fn = countbytes_methods.close;
     68  countbytes_methods.close = countbytes_close_fn;
     69  countbytes_methods.read = countbytes_read_fn;
     70  countbytes_methods.write = countbytes_write_fn;
     71  countbytes_methods.writev = countbytes_writev_fn;
     72  countbytes_methods.send = countbytes_send_fn;
     73  countbytes_methods.recv = countbytes_recv_fn;
     74  /* NOTE: We aren't wrapping recvfrom, sendto, or sendfile, since I think
     75   * NSS won't be using them for TLS connections. */
     76 
     77  countbytes_initialized = true;
     78 }
     79 
     80 /**
     81 * Return the tor_nss_bytecounts_t object for a given IO layer. Asserts that
     82 * the IO layer is in fact a layer created by this module.
     83 */
     84 static tor_nss_bytecounts_t *
     85 get_counts(PRFileDesc *fd)
     86 {
     87  tor_assert(fd->identity == countbytes_layer_id);
     88  return (tor_nss_bytecounts_t*) fd->secret;
     89 }
     90 
     91 /** Helper: increment the read-count of an fd by n. */
     92 #define INC_READ(fd, n) STMT_BEGIN                      \
     93    get_counts(fd)->n_read += (n);                      \
     94  STMT_END
     95 
     96 /** Helper: increment the write-count of an fd by n. */
     97 #define INC_WRITTEN(fd, n) STMT_BEGIN                      \
     98    get_counts(fd)->n_written += (n);                      \
     99  STMT_END
    100 
    101 /** Implementation for PR_Close: frees the 'secret' field, then passes control
    102 * to the default close function */
    103 static PRStatus
    104 countbytes_close_fn(PRFileDesc *fd)
    105 {
    106  tor_assert(fd);
    107 
    108  tor_nss_bytecounts_t *counts = (tor_nss_bytecounts_t *)fd->secret;
    109  tor_free(counts);
    110  fd->secret = NULL;
    111 
    112  return default_close_fn(fd);
    113 }
    114 
    115 /** Implementation for PR_Read: Calls the lower-level read function,
    116 * and records what it said. */
    117 static PRInt32
    118 countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount)
    119 {
    120  tor_assert(fd);
    121  tor_assert(fd->lower);
    122 
    123  PRInt32 result = (fd->lower->methods->read)(fd->lower, buf, amount);
    124  if (result > 0)
    125    INC_READ(fd, result);
    126  return result;
    127 }
    128 /** Implementation for PR_Write: Calls the lower-level write function,
    129 * and records what it said. */
    130 static PRInt32
    131 countbytes_write_fn(PRFileDesc *fd, const void *buf, PRInt32 amount)
    132 {
    133  tor_assert(fd);
    134  tor_assert(fd->lower);
    135 
    136  PRInt32 result = (fd->lower->methods->write)(fd->lower, buf, amount);
    137  if (result > 0)
    138    INC_WRITTEN(fd, result);
    139  return result;
    140 }
    141 /** Implementation for PR_Writev: Calls the lower-level writev function,
    142 * and records what it said. */
    143 static PRInt32
    144 countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov,
    145                     PRInt32 size, PRIntervalTime timeout)
    146 {
    147  tor_assert(fd);
    148  tor_assert(fd->lower);
    149 
    150  PRInt32 result = (fd->lower->methods->writev)(fd->lower, iov, size, timeout);
    151  if (result > 0)
    152    INC_WRITTEN(fd, result);
    153  return result;
    154 }
    155 /** Implementation for PR_Send: Calls the lower-level send function,
    156 * and records what it said. */
    157 static PRInt32
    158 countbytes_send_fn(PRFileDesc *fd, const void *buf,
    159                   PRInt32 amount, PRIntn flags, PRIntervalTime timeout)
    160 {
    161  tor_assert(fd);
    162  tor_assert(fd->lower);
    163 
    164  PRInt32 result = (fd->lower->methods->send)(fd->lower, buf, amount, flags,
    165                                              timeout);
    166  if (result > 0)
    167    INC_WRITTEN(fd, result);
    168  return result;
    169 }
    170 /** Implementation for PR_Recv: Calls the lower-level recv function,
    171 * and records what it said. */
    172 static PRInt32
    173 countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount,
    174                                  PRIntn flags, PRIntervalTime timeout)
    175 {
    176  tor_assert(fd);
    177  tor_assert(fd->lower);
    178 
    179  PRInt32 result = (fd->lower->methods->recv)(fd->lower, buf, amount, flags,
    180                                              timeout);
    181  if (result > 0)
    182    INC_READ(fd, result);
    183  return result;
    184 }
    185 
    186 /**
    187 * Wrap a PRFileDesc from NSPR with a new PRFileDesc that will count the
    188 * total number of bytes read and written.  Return the new PRFileDesc.
    189 *
    190 * This function takes ownership of its input.
    191 */
    192 PRFileDesc *
    193 tor_wrap_prfiledesc_with_byte_counter(PRFileDesc *stack)
    194 {
    195  if (BUG(! countbytes_initialized)) {
    196    tor_nss_countbytes_init();
    197  }
    198 
    199  tor_nss_bytecounts_t *bytecounts = tor_malloc_zero(sizeof(*bytecounts));
    200 
    201  PRFileDesc *newfd = PR_CreateIOLayerStub(countbytes_layer_id,
    202                                           &countbytes_methods);
    203  tor_assert(newfd);
    204  newfd->secret = (PRFilePrivate *)bytecounts;
    205 
    206  /* This does some complicated messing around with the headers of these
    207     objects; see the NSPR documentation for more. The upshot is that
    208     after PushIOLayer, "stack" will be the head of the stack.
    209  */
    210  PRStatus status = PR_PushIOLayer(stack, PR_TOP_IO_LAYER, newfd);
    211  tor_assert(status == PR_SUCCESS);
    212 
    213  return stack;
    214 }
    215 
    216 /**
    217 * Given a PRFileDesc returned by tor_wrap_prfiledesc_with_byte_counter(),
    218 * or another PRFileDesc wrapping that PRFileDesc, set the provided
    219 * pointers to the number of bytes read and written on the descriptor since
    220 * it was created.
    221 *
    222 * Return 0 on success, -1 on failure.
    223 */
    224 int
    225 tor_get_prfiledesc_byte_counts(PRFileDesc *fd,
    226                               uint64_t *n_read_out,
    227                               uint64_t *n_written_out)
    228 {
    229  if (BUG(! countbytes_initialized)) {
    230    tor_nss_countbytes_init();
    231  }
    232 
    233  tor_assert(fd);
    234  PRFileDesc *bclayer = PR_GetIdentitiesLayer(fd, countbytes_layer_id);
    235  if (BUG(bclayer == NULL))
    236    return -1;
    237 
    238  tor_nss_bytecounts_t *counts = get_counts(bclayer);
    239 
    240  *n_read_out = counts->n_read;
    241  *n_written_out = counts->n_written;
    242 
    243  return 0;
    244 }