tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

disflow_neon.c (11367B)


      1 /*
      2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include "aom_dsp/flow_estimation/disflow.h"
     13 
     14 #include <arm_neon.h>
     15 #include <math.h>
     16 
     17 #include "aom_dsp/arm/mem_neon.h"
     18 #include "aom_dsp/arm/sum_neon.h"
     19 #include "aom_dsp/flow_estimation/arm/disflow_neon.h"
     20 #include "config/aom_config.h"
     21 #include "config/aom_dsp_rtcd.h"
     22 
     23 // Compare two regions of width x height pixels, one rooted at position
     24 // (x, y) in src and the other at (x + u, y + v) in ref.
     25 // This function returns the sum of squared pixel differences between
     26 // the two regions.
     27 static inline void compute_flow_error(const uint8_t *src, const uint8_t *ref,
     28                                      int width, int height, int stride, int x,
     29                                      int y, double u, double v, int16_t *dt) {
     30  // Split offset into integer and fractional parts, and compute cubic
     31  // interpolation kernels
     32  const int u_int = (int)floor(u);
     33  const int v_int = (int)floor(v);
     34  const double u_frac = u - floor(u);
     35  const double v_frac = v - floor(v);
     36 
     37  int h_kernel[4];
     38  int v_kernel[4];
     39  get_cubic_kernel_int(u_frac, h_kernel);
     40  get_cubic_kernel_int(v_frac, v_kernel);
     41 
     42  int16_t tmp_[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 3)];
     43 
     44  // Clamp coordinates so that all pixels we fetch will remain within the
     45  // allocated border region, but allow them to go far enough out that
     46  // the border pixels' values do not change.
     47  // Since we are calculating an 8x8 block, the bottom-right pixel
     48  // in the block has coordinates (x0 + 7, y0 + 7). Then, the cubic
     49  // interpolation has 4 taps, meaning that the output of pixel
     50  // (x_w, y_w) depends on the pixels in the range
     51  // ([x_w - 1, x_w + 2], [y_w - 1, y_w + 2]).
     52  //
     53  // Thus the most extreme coordinates which will be fetched are
     54  // (x0 - 1, y0 - 1) and (x0 + 9, y0 + 9).
     55  const int x0 = clamp(x + u_int, -9, width);
     56  const int y0 = clamp(y + v_int, -9, height);
     57 
     58  // Horizontal convolution.
     59  const uint8_t *ref_start = ref + (y0 - 1) * stride + (x0 - 1);
     60  int16x4_t h_filter = vmovn_s32(vld1q_s32(h_kernel));
     61 
     62  for (int i = 0; i < DISFLOW_PATCH_SIZE + 3; ++i) {
     63    uint8x16_t r = vld1q_u8(ref_start + i * stride);
     64    uint16x8_t r0 = vmovl_u8(vget_low_u8(r));
     65    uint16x8_t r1 = vmovl_u8(vget_high_u8(r));
     66 
     67    int16x8_t s0 = vreinterpretq_s16_u16(r0);
     68    int16x8_t s1 = vreinterpretq_s16_u16(vextq_u16(r0, r1, 1));
     69    int16x8_t s2 = vreinterpretq_s16_u16(vextq_u16(r0, r1, 2));
     70    int16x8_t s3 = vreinterpretq_s16_u16(vextq_u16(r0, r1, 3));
     71 
     72    int32x4_t sum_lo = vmull_lane_s16(vget_low_s16(s0), h_filter, 0);
     73    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s1), h_filter, 1);
     74    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s2), h_filter, 2);
     75    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s3), h_filter, 3);
     76 
     77    int32x4_t sum_hi = vmull_lane_s16(vget_high_s16(s0), h_filter, 0);
     78    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s1), h_filter, 1);
     79    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s2), h_filter, 2);
     80    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s3), h_filter, 3);
     81 
     82    // 6 is the maximum allowable number of extra bits which will avoid
     83    // the intermediate values overflowing an int16_t. The most extreme
     84    // intermediate value occurs when:
     85    // * The input pixels are [0, 255, 255, 0]
     86    // * u_frac = 0.5
     87    // In this case, the un-scaled output is 255 * 1.125 = 286.875.
     88    // As an integer with 6 fractional bits, that is 18360, which fits
     89    // in an int16_t. But with 7 fractional bits it would be 36720,
     90    // which is too large.
     91 
     92    int16x8_t sum = vcombine_s16(vrshrn_n_s32(sum_lo, DISFLOW_INTERP_BITS - 6),
     93                                 vrshrn_n_s32(sum_hi, DISFLOW_INTERP_BITS - 6));
     94    vst1q_s16(tmp_ + i * DISFLOW_PATCH_SIZE, sum);
     95  }
     96 
     97  // Vertical convolution.
     98  int16x4_t v_filter = vmovn_s32(vld1q_s32(v_kernel));
     99  int16_t *tmp_start = tmp_ + DISFLOW_PATCH_SIZE;
    100 
    101  for (int i = 0; i < DISFLOW_PATCH_SIZE; ++i) {
    102    int16x8_t t0 = vld1q_s16(tmp_start + (i - 1) * DISFLOW_PATCH_SIZE);
    103    int16x8_t t1 = vld1q_s16(tmp_start + i * DISFLOW_PATCH_SIZE);
    104    int16x8_t t2 = vld1q_s16(tmp_start + (i + 1) * DISFLOW_PATCH_SIZE);
    105    int16x8_t t3 = vld1q_s16(tmp_start + (i + 2) * DISFLOW_PATCH_SIZE);
    106 
    107    int32x4_t sum_lo = vmull_lane_s16(vget_low_s16(t0), v_filter, 0);
    108    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(t1), v_filter, 1);
    109    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(t2), v_filter, 2);
    110    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(t3), v_filter, 3);
    111 
    112    int32x4_t sum_hi = vmull_lane_s16(vget_high_s16(t0), v_filter, 0);
    113    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(t1), v_filter, 1);
    114    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(t2), v_filter, 2);
    115    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(t3), v_filter, 3);
    116 
    117    uint8x8_t s = vld1_u8(src + (i + y) * stride + x);
    118    int16x8_t s_s16 = vreinterpretq_s16_u16(vshll_n_u8(s, 3));
    119 
    120    // This time, we have to round off the 6 extra bits which were kept
    121    // earlier, but we also want to keep DISFLOW_DERIV_SCALE_LOG2 extra bits
    122    // of precision to match the scale of the dx and dy arrays.
    123    sum_lo = vrshrq_n_s32(sum_lo,
    124                          DISFLOW_INTERP_BITS + 6 - DISFLOW_DERIV_SCALE_LOG2);
    125    sum_hi = vrshrq_n_s32(sum_hi,
    126                          DISFLOW_INTERP_BITS + 6 - DISFLOW_DERIV_SCALE_LOG2);
    127    int32x4_t err_lo = vsubw_s16(sum_lo, vget_low_s16(s_s16));
    128    int32x4_t err_hi = vsubw_s16(sum_hi, vget_high_s16(s_s16));
    129    vst1q_s16(dt + i * DISFLOW_PATCH_SIZE,
    130              vcombine_s16(vmovn_s32(err_lo), vmovn_s32(err_hi)));
    131  }
    132 }
    133 
    134 // Computes the components of the system of equations used to solve for
    135 // a flow vector.
    136 //
    137 // The flow equations are a least-squares system, derived as follows:
    138 //
    139 // For each pixel in the patch, we calculate the current error `dt`,
    140 // and the x and y gradients `dx` and `dy` of the source patch.
    141 // This means that, to first order, the squared error for this pixel is
    142 //
    143 //    (dt + u * dx + v * dy)^2
    144 //
    145 // where (u, v) are the incremental changes to the flow vector.
    146 //
    147 // We then want to find the values of u and v which minimize the sum
    148 // of the squared error across all pixels. Conveniently, this fits exactly
    149 // into the form of a least squares problem, with one equation
    150 //
    151 //   u * dx + v * dy = -dt
    152 //
    153 // for each pixel.
    154 //
    155 // Summing across all pixels in a square window of size DISFLOW_PATCH_SIZE,
    156 // and absorbing the - sign elsewhere, this results in the least squares system
    157 //
    158 //   M = |sum(dx * dx)  sum(dx * dy)|
    159 //       |sum(dx * dy)  sum(dy * dy)|
    160 //
    161 //   b = |sum(dx * dt)|
    162 //       |sum(dy * dt)|
    163 static inline void compute_flow_matrix(const int16_t *dx, int dx_stride,
    164                                       const int16_t *dy, int dy_stride,
    165                                       double *M_inv) {
    166  int32x4_t sum[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0),
    167                       vdupq_n_s32(0) };
    168 
    169  for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
    170    int16x8_t x = vld1q_s16(dx + i * dx_stride);
    171    int16x8_t y = vld1q_s16(dy + i * dy_stride);
    172    sum[0] = vmlal_s16(sum[0], vget_low_s16(x), vget_low_s16(x));
    173    sum[0] = vmlal_s16(sum[0], vget_high_s16(x), vget_high_s16(x));
    174 
    175    sum[1] = vmlal_s16(sum[1], vget_low_s16(x), vget_low_s16(y));
    176    sum[1] = vmlal_s16(sum[1], vget_high_s16(x), vget_high_s16(y));
    177 
    178    sum[3] = vmlal_s16(sum[3], vget_low_s16(y), vget_low_s16(y));
    179    sum[3] = vmlal_s16(sum[3], vget_high_s16(y), vget_high_s16(y));
    180  }
    181  sum[2] = sum[1];
    182 
    183  int32x4_t res = horizontal_add_4d_s32x4(sum);
    184 
    185  // Apply regularization
    186  // We follow the standard regularization method of adding `k * I` before
    187  // inverting. This ensures that the matrix will be invertible.
    188  //
    189  // Setting the regularization strength k to 1 seems to work well here, as
    190  // typical values coming from the other equations are very large (1e5 to
    191  // 1e6, with an upper limit of around 6e7, at the time of writing).
    192  // It also preserves the property that all matrix values are whole numbers,
    193  // which is convenient for integerized SIMD implementation.
    194 
    195  double M0 = (double)vgetq_lane_s32(res, 0) + 1;
    196  double M1 = (double)vgetq_lane_s32(res, 1);
    197  double M2 = (double)vgetq_lane_s32(res, 2);
    198  double M3 = (double)vgetq_lane_s32(res, 3) + 1;
    199 
    200  // Invert matrix M.
    201  double det = (M0 * M3) - (M1 * M2);
    202  assert(det >= 1);
    203  const double det_inv = 1 / det;
    204 
    205  M_inv[0] = M3 * det_inv;
    206  M_inv[1] = -M1 * det_inv;
    207  M_inv[2] = -M2 * det_inv;
    208  M_inv[3] = M0 * det_inv;
    209 }
    210 
    211 static inline void compute_flow_vector(const int16_t *dx, int dx_stride,
    212                                       const int16_t *dy, int dy_stride,
    213                                       const int16_t *dt, int dt_stride,
    214                                       int *b) {
    215  int32x4_t b_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
    216 
    217  for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
    218    int16x8_t dx16 = vld1q_s16(dx + i * dx_stride);
    219    int16x8_t dy16 = vld1q_s16(dy + i * dy_stride);
    220    int16x8_t dt16 = vld1q_s16(dt + i * dt_stride);
    221 
    222    b_s32[0] = vmlal_s16(b_s32[0], vget_low_s16(dx16), vget_low_s16(dt16));
    223    b_s32[0] = vmlal_s16(b_s32[0], vget_high_s16(dx16), vget_high_s16(dt16));
    224 
    225    b_s32[1] = vmlal_s16(b_s32[1], vget_low_s16(dy16), vget_low_s16(dt16));
    226    b_s32[1] = vmlal_s16(b_s32[1], vget_high_s16(dy16), vget_high_s16(dt16));
    227  }
    228 
    229  int32x4_t b_red = horizontal_add_2d_s32(b_s32[0], b_s32[1]);
    230  vst1_s32(b, add_pairwise_s32x4(b_red));
    231 }
    232 
    233 void aom_compute_flow_at_point_neon(const uint8_t *src, const uint8_t *ref,
    234                                    int x, int y, int width, int height,
    235                                    int stride, double *u, double *v) {
    236  double M_inv[4];
    237  int b[2];
    238  int16_t dt[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
    239  int16_t dx[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
    240  int16_t dy[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
    241 
    242  // Compute gradients within this patch
    243  const uint8_t *src_patch = &src[y * stride + x];
    244  sobel_filter_x(src_patch, stride, dx, DISFLOW_PATCH_SIZE);
    245  sobel_filter_y(src_patch, stride, dy, DISFLOW_PATCH_SIZE);
    246 
    247  compute_flow_matrix(dx, DISFLOW_PATCH_SIZE, dy, DISFLOW_PATCH_SIZE, M_inv);
    248 
    249  for (int itr = 0; itr < DISFLOW_MAX_ITR; itr++) {
    250    compute_flow_error(src, ref, width, height, stride, x, y, *u, *v, dt);
    251    compute_flow_vector(dx, DISFLOW_PATCH_SIZE, dy, DISFLOW_PATCH_SIZE, dt,
    252                        DISFLOW_PATCH_SIZE, b);
    253 
    254    // Solve flow equations to find a better estimate for the flow vector
    255    // at this point
    256    const double step_u = M_inv[0] * b[0] + M_inv[1] * b[1];
    257    const double step_v = M_inv[2] * b[0] + M_inv[3] * b[1];
    258    *u += fclamp(step_u * DISFLOW_STEP_SIZE, -2, 2);
    259    *v += fclamp(step_v * DISFLOW_STEP_SIZE, -2, 2);
    260 
    261    if (fabs(step_u) + fabs(step_v) < DISFLOW_STEP_SIZE_THRESOLD) {
    262      // Stop iteration when we're close to convergence
    263      break;
    264    }
    265  }
    266 }