tor

The Tor anonymity network
git clone https://git.dasho.dev/tor.git
Log | Files | Refs | README | LICENSE

compress_zstd.c (16385B)


      1 /* Copyright (c) 2004, Roger Dingledine.
      2 * Copyright (c) 2004-2006, Roger Dingledine, Nick Mathewson.
      3 * Copyright (c) 2007-2021, The Tor Project, Inc. */
      4 /* See LICENSE for licensing information */
      5 
      6 /**
      7 * \file compress_zstd.c
      8 * \brief Compression backend for Zstandard.
      9 *
     10 * This module should never be invoked directly. Use the compress module
     11 * instead.
     12 **/
     13 
     14 #include "orconfig.h"
     15 
     16 #include "lib/log/log.h"
     17 #include "lib/log/util_bug.h"
     18 #include "lib/compress/compress.h"
     19 #include "lib/compress/compress_zstd.h"
     20 #include "lib/string/printf.h"
     21 #include "lib/thread/threads.h"
     22 
     23 #ifdef ENABLE_ZSTD_ADVANCED_APIS
     24 /* This is a lie, but we make sure it doesn't get us in trouble by wrapping
     25 * all invocations of zstd's static-only functions in a check to make sure
     26 * that the compile-time version matches the run-time version. */
     27 #define ZSTD_STATIC_LINKING_ONLY
     28 #endif /* defined(ENABLE_ZSTD_ADVANCED_APIS) */
     29 
     30 #ifdef HAVE_ZSTD
     31 #ifdef HAVE_CFLAG_WUNUSED_CONST_VARIABLE
     32 DISABLE_GCC_WARNING("-Wunused-const-variable")
     33 #endif
     34 #include <zstd.h>
     35 #ifdef HAVE_CFLAG_WUNUSED_CONST_VARIABLE
     36 ENABLE_GCC_WARNING("-Wunused-const-variable")
     37 #endif
     38 #endif /* defined(HAVE_ZSTD) */
     39 
     40 /** Total number of bytes allocated for Zstandard state. */
     41 static atomic_counter_t total_zstd_allocation;
     42 
     43 #ifdef HAVE_ZSTD
     44 /** Given <b>level</b> return the memory level. */
     45 static int
     46 memory_level(compression_level_t level)
     47 {
     48  switch (level) {
     49    default:
     50    case BEST_COMPRESSION:
     51    case HIGH_COMPRESSION: return 9;
     52    case MEDIUM_COMPRESSION: return 3;
     53    case LOW_COMPRESSION: return 1;
     54  }
     55 }
     56 #endif /* defined(HAVE_ZSTD) */
     57 
     58 /** Return 1 if Zstandard compression is supported; otherwise 0. */
     59 int
     60 tor_zstd_method_supported(void)
     61 {
     62 #ifdef HAVE_ZSTD
     63  return 1;
     64 #else
     65  return 0;
     66 #endif
     67 }
     68 
     69 #ifdef HAVE_ZSTD
     70 /** Format a zstd version number as a string in <b>buf</b>. */
     71 static void
     72 tor_zstd_format_version(char *buf, size_t buflen, unsigned version_number)
     73 {
     74  tor_snprintf(buf, buflen,
     75               "%u.%u.%u",
     76               version_number / 10000 % 100,
     77               version_number / 100 % 100,
     78               version_number % 100);
     79 }
     80 #endif /* defined(HAVE_ZSTD) */
     81 
     82 #define VERSION_STR_MAX_LEN 16 /* more than enough space for 99.99.99 */
     83 
     84 /** Return a string representation of the version of the currently running
     85 * version of libzstd. Returns NULL if Zstandard is unsupported. */
     86 const char *
     87 tor_zstd_get_version_str(void)
     88 {
     89 #ifdef HAVE_ZSTD
     90  static char version_str[VERSION_STR_MAX_LEN];
     91 
     92  tor_zstd_format_version(version_str, sizeof(version_str),
     93                          ZSTD_versionNumber());
     94 
     95  return version_str;
     96 #else /* !defined(HAVE_ZSTD) */
     97  return NULL;
     98 #endif /* defined(HAVE_ZSTD) */
     99 }
    100 
    101 /** Return a string representation of the version of the version of libzstd
    102 * used at compilation time. Returns NULL if Zstandard is unsupported. */
    103 const char *
    104 tor_zstd_get_header_version_str(void)
    105 {
    106 #ifdef HAVE_ZSTD
    107  return ZSTD_VERSION_STRING;
    108 #else
    109  return NULL;
    110 #endif
    111 }
    112 
    113 #ifdef TOR_UNIT_TESTS
    114 static int static_apis_disable_for_testing = 0;
    115 #endif
    116 
    117 /** Return true iff we can use the "static-only" APIs. */
    118 int
    119 tor_zstd_can_use_static_apis(void)
    120 {
    121 #if defined(ZSTD_STATIC_LINKING_ONLY) && defined(HAVE_ZSTD)
    122 #ifdef TOR_UNIT_TESTS
    123  if (static_apis_disable_for_testing) {
    124    return 0;
    125  }
    126 #endif
    127  return (ZSTD_VERSION_NUMBER == ZSTD_versionNumber());
    128 #else /* !(defined(ZSTD_STATIC_LINKING_ONLY) && defined(HAVE_ZSTD)) */
    129  return 0;
    130 #endif /* defined(ZSTD_STATIC_LINKING_ONLY) && defined(HAVE_ZSTD) */
    131 }
    132 
    133 /** Internal Zstandard state for incremental compression/decompression.
    134 * The body of this struct is not exposed. */
    135 struct tor_zstd_compress_state_t {
    136 #ifdef HAVE_ZSTD
    137  union {
    138    /** Compression stream. Used when <b>compress</b> is true. */
    139    ZSTD_CStream *compress_stream;
    140    /** Decompression stream. Used when <b>compress</b> is false. */
    141    ZSTD_DStream *decompress_stream;
    142  } u; /**< Zstandard stream objects. */
    143 #endif /* defined(HAVE_ZSTD) */
    144 
    145  int compress; /**< True if we are compressing; false if we are inflating */
    146  int have_called_end; /**< True if we are compressing and we've called
    147                        * ZSTD_endStream */
    148 
    149  /** Number of bytes read so far.  Used to detect compression bombs. */
    150  size_t input_so_far;
    151  /** Number of bytes written so far.  Used to detect compression bombs. */
    152  size_t output_so_far;
    153 
    154  /** Approximate number of bytes allocated for this object. */
    155  size_t allocation;
    156 };
    157 
    158 #ifdef HAVE_ZSTD
    159 /** Return an approximate number of bytes stored in memory to hold the
    160 * Zstandard compression/decompression state. This is a fake estimate
    161 * based on inspecting the zstd source: tor_zstd_state_size_precalc() is
    162 * more accurate when it's allowed to use "static-only" functions */
    163 static size_t
    164 tor_zstd_state_size_precalc_fake(int compress, int preset)
    165 {
    166  tor_assert(preset > 0);
    167 
    168  size_t memory_usage = sizeof(tor_zstd_compress_state_t);
    169 
    170  // The Zstandard library provides a number of functions that would be useful
    171  // here, but they are, unfortunately, still considered experimental and are
    172  // thus only available in libzstd if we link against the library statically.
    173  //
    174  // The code in this function tries to approximate the calculations without
    175  // being able to use the following:
    176  //
    177  // - We do not have access to neither the internal members of ZSTD_CStream
    178  //   and ZSTD_DStream and their internal context objects.
    179  //
    180  // - We cannot use ZSTD_sizeof_CStream() and ZSTD_sizeof_DStream() since they
    181  //   are unexposed.
    182  //
    183  // In the future it might be useful to check if libzstd have started
    184  // providing these functions in a stable manner and simplify this function.
    185  if (compress) {
    186    // We try to approximate the ZSTD_sizeof_CStream(ZSTD_CStream *stream)
    187    // function here. This function uses the following fields to make its
    188    // estimate:
    189 
    190    // - sizeof(ZSTD_CStream): Around 192 bytes on a 64-bit machine:
    191    memory_usage += 192;
    192 
    193    // - ZSTD_sizeof_CCtx(stream->cctx): This function requires access to
    194    // variables that are not exposed via the public API. We use a _very_
    195    // simplified function to calculate the estimated amount of bytes used in
    196    // this struct.
    197    // memory_usage += (preset - 0.5) * 1024 * 1024;
    198    memory_usage += (preset * 1024 * 1024) - (512 * 1024);
    199    // - ZSTD_sizeof_CDict(stream->cdictLocal): Unused in Tor: 0 bytes.
    200    // - stream->outBuffSize: 128 KB:
    201    memory_usage += 128 * 1024;
    202    // - stream->inBuffSize: 2048 KB:
    203    memory_usage += 2048 * 1024;
    204  } else {
    205    // We try to approximate the ZSTD_sizeof_DStream(ZSTD_DStream *stream)
    206    // function here. This function uses the following fields to make its
    207    // estimate:
    208 
    209    // - sizeof(ZSTD_DStream): Around 208 bytes on a 64-bit machine:
    210    memory_usage += 208;
    211    // - ZSTD_sizeof_DCtx(stream->dctx): Around 150 KB.
    212    memory_usage += 150 * 1024;
    213 
    214    // - ZSTD_sizeof_DDict(stream->ddictLocal): Unused in Tor: 0 bytes.
    215    // - stream->inBuffSize: 0 KB.
    216    // - stream->outBuffSize: 0 KB.
    217  }
    218 
    219  return memory_usage;
    220 }
    221 
    222 /** Return an approximate number of bytes stored in memory to hold the
    223 * Zstandard compression/decompression state. */
    224 static size_t
    225 tor_zstd_state_size_precalc(int compress, int preset)
    226 {
    227 #ifdef ZSTD_STATIC_LINKING_ONLY
    228  if (tor_zstd_can_use_static_apis()) {
    229    if (compress) {
    230 #ifdef HAVE_ZSTD_ESTIMATECSTREAMSIZE
    231      return ZSTD_estimateCStreamSize(preset);
    232 #endif
    233    } else {
    234 #ifdef HAVE_ZSTD_ESTIMATEDCTXSIZE
    235      /* Could use DStream, but that takes a windowSize. */
    236      return ZSTD_estimateDCtxSize();
    237 #endif
    238    }
    239  }
    240 #endif /* defined(ZSTD_STATIC_LINKING_ONLY) */
    241  return tor_zstd_state_size_precalc_fake(compress, preset);
    242 }
    243 #endif /* defined(HAVE_ZSTD) */
    244 
    245 /** Construct and return a tor_zstd_compress_state_t object using
    246 * <b>method</b>. If <b>compress</b>, it's for compression; otherwise it's for
    247 * decompression. */
    248 tor_zstd_compress_state_t *
    249 tor_zstd_compress_new(int compress,
    250                      compress_method_t method,
    251                      compression_level_t level)
    252 {
    253  tor_assert(method == ZSTD_METHOD);
    254 
    255 #ifdef HAVE_ZSTD
    256  const int preset = memory_level(level);
    257  tor_zstd_compress_state_t *result;
    258  size_t retval;
    259 
    260  result = tor_malloc_zero(sizeof(tor_zstd_compress_state_t));
    261  result->compress = compress;
    262  result->allocation = tor_zstd_state_size_precalc(compress, preset);
    263 
    264  if (compress) {
    265    result->u.compress_stream = ZSTD_createCStream();
    266 
    267    if (result->u.compress_stream == NULL) {
    268      // LCOV_EXCL_START
    269      log_warn(LD_GENERAL, "Error while creating Zstandard compression "
    270               "stream");
    271      goto err;
    272      // LCOV_EXCL_STOP
    273    }
    274 
    275    retval = ZSTD_initCStream(result->u.compress_stream, preset);
    276 
    277    if (ZSTD_isError(retval)) {
    278      // LCOV_EXCL_START
    279      log_warn(LD_GENERAL, "Zstandard stream initialization error: %s",
    280               ZSTD_getErrorName(retval));
    281      goto err;
    282      // LCOV_EXCL_STOP
    283    }
    284  } else {
    285    result->u.decompress_stream = ZSTD_createDStream();
    286 
    287    if (result->u.decompress_stream == NULL) {
    288      // LCOV_EXCL_START
    289      log_warn(LD_GENERAL, "Error while creating Zstandard decompression "
    290               "stream");
    291      goto err;
    292      // LCOV_EXCL_STOP
    293    }
    294 
    295    retval = ZSTD_initDStream(result->u.decompress_stream);
    296 
    297    if (ZSTD_isError(retval)) {
    298      // LCOV_EXCL_START
    299      log_warn(LD_GENERAL, "Zstandard stream initialization error: %s",
    300               ZSTD_getErrorName(retval));
    301      goto err;
    302      // LCOV_EXCL_STOP
    303    }
    304  }
    305 
    306  atomic_counter_add(&total_zstd_allocation, result->allocation);
    307  return result;
    308 
    309 err:
    310  // LCOV_EXCL_START
    311  if (compress) {
    312    ZSTD_freeCStream(result->u.compress_stream);
    313  } else {
    314    ZSTD_freeDStream(result->u.decompress_stream);
    315  }
    316 
    317  tor_free(result);
    318  return NULL;
    319  // LCOV_EXCL_STOP
    320 #else /* !defined(HAVE_ZSTD) */
    321  (void)compress;
    322  (void)method;
    323  (void)level;
    324 
    325  return NULL;
    326 #endif /* defined(HAVE_ZSTD) */
    327 }
    328 
    329 /** Compress/decompress some bytes using <b>state</b>.  Read up to
    330 * *<b>in_len</b> bytes from *<b>in</b>, and write up to *<b>out_len</b> bytes
    331 * to *<b>out</b>, adjusting the values as we go.  If <b>finish</b> is true,
    332 * we've reached the end of the input.
    333 *
    334 * Return TOR_COMPRESS_DONE if we've finished the entire
    335 * compression/decompression.
    336 * Return TOR_COMPRESS_OK if we're processed everything from the input.
    337 * Return TOR_COMPRESS_BUFFER_FULL if we're out of space on <b>out</b>.
    338 * Return TOR_COMPRESS_ERROR if the stream is corrupt.
    339 */
    340 tor_compress_output_t
    341 tor_zstd_compress_process(tor_zstd_compress_state_t *state,
    342                          char **out, size_t *out_len,
    343                          const char **in, size_t *in_len,
    344                          int finish)
    345 {
    346 #ifdef HAVE_ZSTD
    347  size_t retval;
    348 
    349  tor_assert(state != NULL);
    350  tor_assert(*in_len <= UINT_MAX);
    351  tor_assert(*out_len <= UINT_MAX);
    352 
    353  ZSTD_inBuffer input = { *in, *in_len, 0 };
    354  ZSTD_outBuffer output = { *out, *out_len, 0 };
    355 
    356  if (BUG(finish == 0 && state->have_called_end)) {
    357    finish = 1;
    358  }
    359 
    360  if (state->compress) {
    361    if (! state->have_called_end)
    362      retval = ZSTD_compressStream(state->u.compress_stream,
    363                                   &output, &input);
    364    else
    365      retval = 0;
    366  } else {
    367    retval = ZSTD_decompressStream(state->u.decompress_stream,
    368                                   &output, &input);
    369  }
    370 
    371  if (ZSTD_isError(retval)) {
    372    log_warn(LD_GENERAL, "Zstandard %s didn't finish: %s.",
    373             state->compress ? "compression" : "decompression",
    374             ZSTD_getErrorName(retval));
    375    return TOR_COMPRESS_ERROR;
    376  }
    377 
    378  state->input_so_far += input.pos;
    379  state->output_so_far += output.pos;
    380 
    381  *out = (char *)output.dst + output.pos;
    382  *out_len = output.size - output.pos;
    383  *in = (char *)input.src + input.pos;
    384  *in_len = input.size - input.pos;
    385 
    386  if (! state->compress &&
    387      tor_compress_is_compression_bomb(state->input_so_far,
    388                                       state->output_so_far)) {
    389    log_warn(LD_DIR, "Possible compression bomb; abandoning stream.");
    390    return TOR_COMPRESS_ERROR;
    391  }
    392 
    393  if (state->compress && !state->have_called_end) {
    394    retval = ZSTD_flushStream(state->u.compress_stream, &output);
    395 
    396    *out = (char *)output.dst + output.pos;
    397    *out_len = output.size - output.pos;
    398 
    399    if (ZSTD_isError(retval)) {
    400      log_warn(LD_GENERAL, "Zstandard compression unable to flush: %s.",
    401               ZSTD_getErrorName(retval));
    402      return TOR_COMPRESS_ERROR;
    403    }
    404 
    405    // ZSTD_flushStream returns 0 if the frame is done, or >0 if it
    406    // is incomplete.
    407    if (retval > 0) {
    408      return TOR_COMPRESS_BUFFER_FULL;
    409    }
    410  }
    411 
    412  if (!finish) {
    413    // The caller says we're not done with the input, so no need to write an
    414    // epilogue.
    415    return TOR_COMPRESS_OK;
    416  } else if (state->compress) {
    417    if (*in_len) {
    418      // We say that we're not done with the input, so we can't write an
    419      // epilogue.
    420      return TOR_COMPRESS_OK;
    421    }
    422 
    423    retval = ZSTD_endStream(state->u.compress_stream, &output);
    424    state->have_called_end = 1;
    425    *out = (char *)output.dst + output.pos;
    426    *out_len = output.size - output.pos;
    427 
    428    if (ZSTD_isError(retval)) {
    429      log_warn(LD_GENERAL, "Zstandard compression unable to write "
    430               "epilogue: %s.",
    431               ZSTD_getErrorName(retval));
    432      return TOR_COMPRESS_ERROR;
    433    }
    434 
    435    // endStream returns the number of bytes that is needed to write the
    436    // epilogue.
    437    if (retval > 0)
    438      return TOR_COMPRESS_BUFFER_FULL;
    439 
    440    return TOR_COMPRESS_DONE;
    441  } else /* if (!state->compress) */ {
    442    // ZSTD_decompressStream returns 0 if the frame is done, or >0 if it
    443    // is incomplete.
    444    // We check this above.
    445    tor_assert_nonfatal(!ZSTD_isError(retval));
    446    // Start a new frame if this frame is done
    447    if (retval == 0)
    448      return TOR_COMPRESS_DONE;
    449    // Don't check out_len, it might have some space left if the next output
    450    // chunk is larger than the remaining space
    451    else if (*in_len > 0)
    452      return  TOR_COMPRESS_BUFFER_FULL;
    453    else
    454      return TOR_COMPRESS_OK;
    455  }
    456 
    457 #else /* !defined(HAVE_ZSTD) */
    458  (void)state;
    459  (void)out;
    460  (void)out_len;
    461  (void)in;
    462  (void)in_len;
    463  (void)finish;
    464 
    465  return TOR_COMPRESS_ERROR;
    466 #endif /* defined(HAVE_ZSTD) */
    467 }
    468 
    469 /** Deallocate <b>state</b>. */
    470 void
    471 tor_zstd_compress_free_(tor_zstd_compress_state_t *state)
    472 {
    473  if (state == NULL)
    474    return;
    475 
    476  atomic_counter_sub(&total_zstd_allocation, state->allocation);
    477 
    478 #ifdef HAVE_ZSTD
    479  if (state->compress) {
    480    ZSTD_freeCStream(state->u.compress_stream);
    481  } else {
    482    ZSTD_freeDStream(state->u.decompress_stream);
    483  }
    484 #endif /* defined(HAVE_ZSTD) */
    485 
    486  tor_free(state);
    487 }
    488 
    489 /** Return the approximate number of bytes allocated for <b>state</b>. */
    490 size_t
    491 tor_zstd_compress_state_size(const tor_zstd_compress_state_t *state)
    492 {
    493  tor_assert(state != NULL);
    494  return state->allocation;
    495 }
    496 
    497 /** Return the approximate number of bytes allocated for all Zstandard
    498 * states. */
    499 size_t
    500 tor_zstd_get_total_allocation(void)
    501 {
    502  return atomic_counter_get(&total_zstd_allocation);
    503 }
    504 
    505 /** Initialize the zstd module */
    506 void
    507 tor_zstd_init(void)
    508 {
    509  atomic_counter_init(&total_zstd_allocation);
    510 }
    511 
    512 /** Warn if the header and library versions don't match. */
    513 void
    514 tor_zstd_warn_if_version_mismatched(void)
    515 {
    516 #if defined(HAVE_ZSTD) && defined(ENABLE_ZSTD_ADVANCED_APIS)
    517  if (! tor_zstd_can_use_static_apis()) {
    518    char header_version[VERSION_STR_MAX_LEN];
    519    char runtime_version[VERSION_STR_MAX_LEN];
    520    tor_zstd_format_version(header_version, sizeof(header_version),
    521                            ZSTD_VERSION_NUMBER);
    522    tor_zstd_format_version(runtime_version, sizeof(runtime_version),
    523                            ZSTD_versionNumber());
    524 
    525    log_info(LD_GENERAL,
    526             "Tor was compiled with zstd %s, but is running with zstd %s. "
    527             "For ABI compatibility reasons, we'll avoid using advanced zstd "
    528             "functionality.",
    529             header_version, runtime_version);
    530  }
    531 #endif /* defined(HAVE_ZSTD) && defined(ENABLE_ZSTD_ADVANCED_APIS) */
    532 }
    533 
    534 #ifdef TOR_UNIT_TESTS
    535 /** Testing only: disable usage of static-only APIs, so we can make sure that
    536 * we still work without them. */
    537 void
    538 tor_zstd_set_static_apis_disabled_for_testing(int disabled)
    539 {
    540  static_apis_disable_for_testing = disabled;
    541 }
    542 #endif /* defined(TOR_UNIT_TESTS) */