tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

refmvs.c (36817B)


      1 /*
      2 * Copyright © 2020, VideoLAN and dav1d authors
      3 * Copyright © 2020, Two Orioles, LLC
      4 * All rights reserved.
      5 *
      6 * Redistribution and use in source and binary forms, with or without
      7 * modification, are permitted provided that the following conditions are met:
      8 *
      9 * 1. Redistributions of source code must retain the above copyright notice, this
     10 *    list of conditions and the following disclaimer.
     11 *
     12 * 2. Redistributions in binary form must reproduce the above copyright notice,
     13 *    this list of conditions and the following disclaimer in the documentation
     14 *    and/or other materials provided with the distribution.
     15 *
     16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
     17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
     18 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
     19 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
     20 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
     21 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
     22 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
     23 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
     25 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     26 */
     27 
     28 #include "config.h"
     29 
     30 #include <limits.h>
     31 #include <stdlib.h>
     32 
     33 #include "dav1d/common.h"
     34 
     35 #include "common/intops.h"
     36 
     37 #include "src/env.h"
     38 #include "src/mem.h"
     39 #include "src/refmvs.h"
     40 
     41 static void add_spatial_candidate(refmvs_candidate *const mvstack, int *const cnt,
     42                                  const int weight, const refmvs_block *const b,
     43                                  const union refmvs_refpair ref, const mv gmv[2],
     44                                  int *const have_newmv_match,
     45                                  int *const have_refmv_match)
     46 {
     47    if (b->mv.mv[0].n == INVALID_MV) return; // intra block, no intrabc
     48 
     49    if (ref.ref[1] == -1) {
     50        for (int n = 0; n < 2; n++) {
     51            if (b->ref.ref[n] == ref.ref[0]) {
     52                const mv cand_mv = ((b->mf & 1) && gmv[0].n != INVALID_MV) ?
     53                                   gmv[0] : b->mv.mv[n];
     54 
     55                *have_refmv_match = 1;
     56                *have_newmv_match |= b->mf >> 1;
     57 
     58                const int last = *cnt;
     59                for (int m = 0; m < last; m++)
     60                    if (mvstack[m].mv.mv[0].n == cand_mv.n) {
     61                        mvstack[m].weight += weight;
     62                        return;
     63                    }
     64 
     65                if (last < 8) {
     66                    mvstack[last].mv.mv[0] = cand_mv;
     67                    mvstack[last].weight = weight;
     68                    *cnt = last + 1;
     69                }
     70                return;
     71            }
     72        }
     73    } else if (b->ref.pair == ref.pair) {
     74        const refmvs_mvpair cand_mv = { .mv = {
     75            [0] = ((b->mf & 1) && gmv[0].n != INVALID_MV) ? gmv[0] : b->mv.mv[0],
     76            [1] = ((b->mf & 1) && gmv[1].n != INVALID_MV) ? gmv[1] : b->mv.mv[1],
     77        }};
     78 
     79        *have_refmv_match = 1;
     80        *have_newmv_match |= b->mf >> 1;
     81 
     82        const int last = *cnt;
     83        for (int n = 0; n < last; n++)
     84            if (mvstack[n].mv.n == cand_mv.n) {
     85                mvstack[n].weight += weight;
     86                return;
     87            }
     88 
     89        if (last < 8) {
     90            mvstack[last].mv = cand_mv;
     91            mvstack[last].weight = weight;
     92            *cnt = last + 1;
     93        }
     94    }
     95 }
     96 
     97 static int scan_row(refmvs_candidate *const mvstack, int *const cnt,
     98                    const union refmvs_refpair ref, const mv gmv[2],
     99                    const refmvs_block *b, const int bw4, const int w4,
    100                    const int max_rows, const int step,
    101                    int *const have_newmv_match, int *const have_refmv_match)
    102 {
    103    const refmvs_block *cand_b = b;
    104    const enum BlockSize first_cand_bs = cand_b->bs;
    105    const uint8_t *const first_cand_b_dim = dav1d_block_dimensions[first_cand_bs];
    106    int cand_bw4 = first_cand_b_dim[0];
    107    int len = imax(step, imin(bw4, cand_bw4));
    108 
    109    if (bw4 <= cand_bw4) {
    110        // FIXME weight can be higher for odd blocks (bx4 & 1), but then the
    111        // position of the first block has to be odd already, i.e. not just
    112        // for row_offset=-3/-5
    113        // FIXME why can this not be cand_bw4?
    114        const int weight = bw4 == 1 ? 2 :
    115                           imax(2, imin(2 * max_rows, first_cand_b_dim[1]));
    116        add_spatial_candidate(mvstack, cnt, len * weight, cand_b, ref, gmv,
    117                              have_newmv_match, have_refmv_match);
    118        return weight >> 1;
    119    }
    120 
    121    for (int x = 0;;) {
    122        // FIXME if we overhang above, we could fill a bitmask so we don't have
    123        // to repeat the add_spatial_candidate() for the next row, but just increase
    124        // the weight here
    125        add_spatial_candidate(mvstack, cnt, len * 2, cand_b, ref, gmv,
    126                              have_newmv_match, have_refmv_match);
    127        x += len;
    128        if (x >= w4) return 1;
    129        cand_b = &b[x];
    130        cand_bw4 = dav1d_block_dimensions[cand_b->bs][0];
    131        assert(cand_bw4 < bw4);
    132        len = imax(step, cand_bw4);
    133    }
    134 }
    135 
    136 static int scan_col(refmvs_candidate *const mvstack, int *const cnt,
    137                    const union refmvs_refpair ref, const mv gmv[2],
    138                    /*const*/ refmvs_block *const *b, const int bh4, const int h4,
    139                    const int bx4, const int max_cols, const int step,
    140                    int *const have_newmv_match, int *const have_refmv_match)
    141 {
    142    const refmvs_block *cand_b = &b[0][bx4];
    143    const enum BlockSize first_cand_bs = cand_b->bs;
    144    const uint8_t *const first_cand_b_dim = dav1d_block_dimensions[first_cand_bs];
    145    int cand_bh4 = first_cand_b_dim[1];
    146    int len = imax(step, imin(bh4, cand_bh4));
    147 
    148    if (bh4 <= cand_bh4) {
    149        // FIXME weight can be higher for odd blocks (by4 & 1), but then the
    150        // position of the first block has to be odd already, i.e. not just
    151        // for col_offset=-3/-5
    152        // FIXME why can this not be cand_bh4?
    153        const int weight = bh4 == 1 ? 2 :
    154                           imax(2, imin(2 * max_cols, first_cand_b_dim[0]));
    155        add_spatial_candidate(mvstack, cnt, len * weight, cand_b, ref, gmv,
    156                            have_newmv_match, have_refmv_match);
    157        return weight >> 1;
    158    }
    159 
    160    for (int y = 0;;) {
    161        // FIXME if we overhang above, we could fill a bitmask so we don't have
    162        // to repeat the add_spatial_candidate() for the next row, but just increase
    163        // the weight here
    164        add_spatial_candidate(mvstack, cnt, len * 2, cand_b, ref, gmv,
    165                              have_newmv_match, have_refmv_match);
    166        y += len;
    167        if (y >= h4) return 1;
    168        cand_b = &b[y][bx4];
    169        cand_bh4 = dav1d_block_dimensions[cand_b->bs][1];
    170        assert(cand_bh4 < bh4);
    171        len = imax(step, cand_bh4);
    172    }
    173 }
    174 
    175 static inline union mv mv_projection(const union mv mv, const int num, const int den) {
    176    static const uint16_t div_mult[32] = {
    177           0, 16384, 8192, 5461, 4096, 3276, 2730, 2340,
    178        2048,  1820, 1638, 1489, 1365, 1260, 1170, 1092,
    179        1024,   963,  910,  862,  819,  780,  744,  712,
    180         682,   655,  630,  606,  585,  564,  546,  528
    181    };
    182    assert(den > 0 && den < 32);
    183    assert(num > -32 && num < 32);
    184    const int frac = num * div_mult[den];
    185    const int y = mv.y * frac, x = mv.x * frac;
    186    // Round and clip according to AV1 spec section 7.9.3
    187    return (union mv) { // 0x3fff == (1 << 14) - 1
    188        .y = iclip((y + 8192 + (y >> 31)) >> 14, -0x3fff, 0x3fff),
    189        .x = iclip((x + 8192 + (x >> 31)) >> 14, -0x3fff, 0x3fff)
    190    };
    191 }
    192 
    193 static void add_temporal_candidate(const refmvs_frame *const rf,
    194                                   refmvs_candidate *const mvstack, int *const cnt,
    195                                   const refmvs_temporal_block *const rb,
    196                                   const union refmvs_refpair ref, int *const globalmv_ctx,
    197                                   const union mv gmv[])
    198 {
    199    if (rb->mv.n == INVALID_MV) return;
    200 
    201    union mv mv = mv_projection(rb->mv, rf->pocdiff[ref.ref[0] - 1], rb->ref);
    202    fix_mv_precision(rf->frm_hdr, &mv);
    203 
    204    const int last = *cnt;
    205    if (ref.ref[1] == -1) {
    206        if (globalmv_ctx)
    207            *globalmv_ctx = (abs(mv.x - gmv[0].x) | abs(mv.y - gmv[0].y)) >= 16;
    208 
    209        for (int n = 0; n < last; n++)
    210            if (mvstack[n].mv.mv[0].n == mv.n) {
    211                mvstack[n].weight += 2;
    212                return;
    213            }
    214        if (last < 8) {
    215            mvstack[last].mv.mv[0] = mv;
    216            mvstack[last].weight = 2;
    217            *cnt = last + 1;
    218        }
    219    } else {
    220        refmvs_mvpair mvp = { .mv = {
    221            [0] = mv,
    222            [1] = mv_projection(rb->mv, rf->pocdiff[ref.ref[1] - 1], rb->ref),
    223        }};
    224        fix_mv_precision(rf->frm_hdr, &mvp.mv[1]);
    225 
    226        for (int n = 0; n < last; n++)
    227            if (mvstack[n].mv.n == mvp.n) {
    228                mvstack[n].weight += 2;
    229                return;
    230            }
    231        if (last < 8) {
    232            mvstack[last].mv = mvp;
    233            mvstack[last].weight = 2;
    234            *cnt = last + 1;
    235        }
    236    }
    237 }
    238 
    239 static void add_compound_extended_candidate(refmvs_candidate *const same,
    240                                            int *const same_count,
    241                                            const refmvs_block *const cand_b,
    242                                            const int sign0, const int sign1,
    243                                            const union refmvs_refpair ref,
    244                                            const uint8_t *const sign_bias)
    245 {
    246    refmvs_candidate *const diff = &same[2];
    247    int *const diff_count = &same_count[2];
    248 
    249    for (int n = 0; n < 2; n++) {
    250        const int cand_ref = cand_b->ref.ref[n];
    251 
    252        if (cand_ref <= 0) break;
    253 
    254        mv cand_mv = cand_b->mv.mv[n];
    255        if (cand_ref == ref.ref[0]) {
    256            if (same_count[0] < 2)
    257                same[same_count[0]++].mv.mv[0] = cand_mv;
    258            if (diff_count[1] < 2) {
    259                if (sign1 ^ sign_bias[cand_ref - 1]) {
    260                    cand_mv.y = -cand_mv.y;
    261                    cand_mv.x = -cand_mv.x;
    262                }
    263                diff[diff_count[1]++].mv.mv[1] = cand_mv;
    264            }
    265        } else if (cand_ref == ref.ref[1]) {
    266            if (same_count[1] < 2)
    267                same[same_count[1]++].mv.mv[1] = cand_mv;
    268            if (diff_count[0] < 2) {
    269                if (sign0 ^ sign_bias[cand_ref - 1]) {
    270                    cand_mv.y = -cand_mv.y;
    271                    cand_mv.x = -cand_mv.x;
    272                }
    273                diff[diff_count[0]++].mv.mv[0] = cand_mv;
    274            }
    275        } else {
    276            mv i_cand_mv = (union mv) {
    277                .x = -cand_mv.x,
    278                .y = -cand_mv.y
    279            };
    280 
    281            if (diff_count[0] < 2) {
    282                diff[diff_count[0]++].mv.mv[0] =
    283                    sign0 ^ sign_bias[cand_ref - 1] ?
    284                    i_cand_mv : cand_mv;
    285            }
    286 
    287            if (diff_count[1] < 2) {
    288                diff[diff_count[1]++].mv.mv[1] =
    289                    sign1 ^ sign_bias[cand_ref - 1] ?
    290                    i_cand_mv : cand_mv;
    291            }
    292        }
    293    }
    294 }
    295 
    296 static void add_single_extended_candidate(refmvs_candidate mvstack[8], int *const cnt,
    297                                          const refmvs_block *const cand_b,
    298                                          const int sign, const uint8_t *const sign_bias)
    299 {
    300    for (int n = 0; n < 2; n++) {
    301        const int cand_ref = cand_b->ref.ref[n];
    302 
    303        if (cand_ref <= 0) break;
    304        // we need to continue even if cand_ref == ref.ref[0], since
    305        // the candidate could have been added as a globalmv variant,
    306        // which changes the value
    307        // FIXME if scan_{row,col}() returned a mask for the nearest
    308        // edge, we could skip the appropriate ones here
    309 
    310        mv cand_mv = cand_b->mv.mv[n];
    311        if (sign ^ sign_bias[cand_ref - 1]) {
    312            cand_mv.y = -cand_mv.y;
    313            cand_mv.x = -cand_mv.x;
    314        }
    315 
    316        int m;
    317        const int last = *cnt;
    318        for (m = 0; m < last; m++)
    319            if (cand_mv.n == mvstack[m].mv.mv[0].n)
    320                break;
    321        if (m == last) {
    322            mvstack[m].mv.mv[0] = cand_mv;
    323            mvstack[m].weight = 2; // "minimal"
    324            *cnt = last + 1;
    325        }
    326    }
    327 }
    328 
    329 /*
 * refmvs_frame allocates memory for one sbrow (32 blocks high, whole frame
 * wide), plus 3 spare rows, of 4x4-resolution refmvs_block entries for
 * spatial MV referencing. refmvs_tile.r[] keeps row pointers into this
 * memory: one per row of the current sbrow plus the 3 rows above it
 * (y=-5/-3/-1; the y=-4/-2 slots are NULL). dav1d_refmvs_tile_sbrow_init()
 * rebuilds these pointers at the start of each tile/sbrow, and on odd sbrows
 * it exchanges the bottom entries (y=27/29/31) with the top (-5/-3/-1) ones,
 * so that the previous sbrow's bottom rows serve as the above rows without
 * being copied.
    337 * For temporal MV referencing, we call dav1d_refmvs_save_tmvs() at the end of
    338 * each tile/sbrow (when tile column threading is enabled), or at the start of
    339 * each interleaved sbrow (i.e. once for all tile columns together, when tile
    340 * column threading is disabled). This will copy the 4x4-resolution spatial MVs
    341 * into 8x8-resolution refmvs_temporal_block structures. Then, for subsequent
    342 * frames, at the start of each tile/sbrow (when tile column threading is
    343 * enabled) or at the start of each interleaved sbrow (when tile column
    344 * threading is disabled), we call load_tmvs(), which will project the MVs to
    345 * their respective position in the current frame.
    346 */
    347 
    348 void dav1d_refmvs_find(const refmvs_tile *const rt,
    349                       refmvs_candidate mvstack[8], int *const cnt,
    350                       int *const ctx,
    351                       const union refmvs_refpair ref, const enum BlockSize bs,
    352                       const enum EdgeFlags edge_flags,
    353                       const int by4, const int bx4)
    354 {
    355    const refmvs_frame *const rf = rt->rf;
    356    const uint8_t *const b_dim = dav1d_block_dimensions[bs];
    357    const int bw4 = b_dim[0], w4 = imin(imin(bw4, 16), rt->tile_col.end - bx4);
    358    const int bh4 = b_dim[1], h4 = imin(imin(bh4, 16), rt->tile_row.end - by4);
    359    mv gmv[2], tgmv[2];
    360 
    361    *cnt = 0;
    362    assert(ref.ref[0] >=  0 && ref.ref[0] <= 8 &&
    363           ref.ref[1] >= -1 && ref.ref[1] <= 8);
    364    if (ref.ref[0] > 0) {
    365        tgmv[0] = get_gmv_2d(&rf->frm_hdr->gmv[ref.ref[0] - 1],
    366                             bx4, by4, bw4, bh4, rf->frm_hdr);
    367        gmv[0] = rf->frm_hdr->gmv[ref.ref[0] - 1].type > DAV1D_WM_TYPE_TRANSLATION ?
    368                 tgmv[0] : (mv) { .n = INVALID_MV };
    369    } else {
    370        tgmv[0] = (mv) { .n = 0 };
    371        gmv[0] = (mv) { .n = INVALID_MV };
    372    }
    373    if (ref.ref[1] > 0) {
    374        tgmv[1] = get_gmv_2d(&rf->frm_hdr->gmv[ref.ref[1] - 1],
    375                             bx4, by4, bw4, bh4, rf->frm_hdr);
    376        gmv[1] = rf->frm_hdr->gmv[ref.ref[1] - 1].type > DAV1D_WM_TYPE_TRANSLATION ?
    377                 tgmv[1] : (mv) { .n = INVALID_MV };
    378    }
    379 
    380    // top
    381    int have_newmv = 0, have_col_mvs = 0, have_row_mvs = 0;
    382    unsigned max_rows = 0, n_rows = ~0;
    383    const refmvs_block *b_top;
    384    if (by4 > rt->tile_row.start) {
    385        max_rows = imin((by4 - rt->tile_row.start + 1) >> 1, 2 + (bh4 > 1));
    386        b_top = &rt->r[(by4 & 31) - 1 + 5][bx4];
    387        n_rows = scan_row(mvstack, cnt, ref, gmv, b_top,
    388                          bw4, w4, max_rows, bw4 >= 16 ? 4 : 1,
    389                          &have_newmv, &have_row_mvs);
    390    }
    391 
    392    // left
    393    unsigned max_cols = 0, n_cols = ~0U;
    394    refmvs_block *const *b_left;
    395    if (bx4 > rt->tile_col.start) {
    396        max_cols = imin((bx4 - rt->tile_col.start + 1) >> 1, 2 + (bw4 > 1));
    397        b_left = &rt->r[(by4 & 31) + 5];
    398        n_cols = scan_col(mvstack, cnt, ref, gmv, b_left,
    399                          bh4, h4, bx4 - 1, max_cols, bh4 >= 16 ? 4 : 1,
    400                          &have_newmv, &have_col_mvs);
    401    }
    402 
    403    // top/right
    404    if (n_rows != ~0U && edge_flags & EDGE_I444_TOP_HAS_RIGHT &&
    405        imax(bw4, bh4) <= 16 && bw4 + bx4 < rt->tile_col.end)
    406    {
    407        add_spatial_candidate(mvstack, cnt, 4, &b_top[bw4], ref, gmv,
    408                              &have_newmv, &have_row_mvs);
    409    }
    410 
    411    const int nearest_match = have_col_mvs + have_row_mvs;
    412    const int nearest_cnt = *cnt;
    413    for (int n = 0; n < nearest_cnt; n++)
    414        mvstack[n].weight += 640;
    415 
    416    // temporal
    417    int globalmv_ctx = rf->frm_hdr->use_ref_frame_mvs;
    418    if (rf->use_ref_frame_mvs) {
    419        const ptrdiff_t stride = rf->rp_stride;
    420        const int by8 = by4 >> 1, bx8 = bx4 >> 1;
    421        const refmvs_temporal_block *const rbi = &rt->rp_proj[(by8 & 15) * stride + bx8];
    422        const refmvs_temporal_block *rb = rbi;
    423        const int step_h = bw4 >= 16 ? 2 : 1, step_v = bh4 >= 16 ? 2 : 1;
    424        const int w8 = imin((w4 + 1) >> 1, 8), h8 = imin((h4 + 1) >> 1, 8);
    425        for (int y = 0; y < h8; y += step_v) {
    426            for (int x = 0; x < w8; x+= step_h) {
    427                add_temporal_candidate(rf, mvstack, cnt, &rb[x], ref,
    428                                       !(x | y) ? &globalmv_ctx : NULL, tgmv);
    429            }
    430            rb += stride * step_v;
    431        }
    432        if (imin(bw4, bh4) >= 2 && imax(bw4, bh4) < 16) {
    433            const int bh8 = bh4 >> 1, bw8 = bw4 >> 1;
    434            rb = &rbi[bh8 * stride];
    435            const int has_bottom = by8 + bh8 < imin(rt->tile_row.end >> 1,
    436                                                    (by8 & ~7) + 8);
    437            if (has_bottom && bx8 - 1 >= imax(rt->tile_col.start >> 1, bx8 & ~7)) {
    438                add_temporal_candidate(rf, mvstack, cnt, &rb[-1], ref,
    439                                       NULL, NULL);
    440            }
    441            if (bx8 + bw8 < imin(rt->tile_col.end >> 1, (bx8 & ~7) + 8)) {
    442                if (has_bottom) {
    443                    add_temporal_candidate(rf, mvstack, cnt, &rb[bw8], ref,
    444                                           NULL, NULL);
    445                }
    446                if (by8 + bh8 - 1 < imin(rt->tile_row.end >> 1, (by8 & ~7) + 8)) {
    447                    add_temporal_candidate(rf, mvstack, cnt, &rb[bw8 - stride],
    448                                           ref, NULL, NULL);
    449                }
    450            }
    451        }
    452    }
    453    assert(*cnt <= 8);
    454 
    455    // top/left (which, confusingly, is part of "secondary" references)
    456    int have_dummy_newmv_match;
    457    if ((n_rows | n_cols) != ~0U) {
    458        add_spatial_candidate(mvstack, cnt, 4, &b_top[-1], ref, gmv,
    459                              &have_dummy_newmv_match, &have_row_mvs);
    460    }
    461 
    462    // "secondary" (non-direct neighbour) top & left edges
    463    // what is different about secondary is that everything is now in 8x8 resolution
    464    for (int n = 2; n <= 3; n++) {
    465        if ((unsigned) n > n_rows && (unsigned) n <= max_rows) {
    466            n_rows += scan_row(mvstack, cnt, ref, gmv,
    467                               &rt->r[(((by4 & 31) - 2 * n + 1) | 1) + 5][bx4 | 1],
    468                               bw4, w4, 1 + max_rows - n, bw4 >= 16 ? 4 : 2,
    469                               &have_dummy_newmv_match, &have_row_mvs);
    470        }
    471 
    472        if ((unsigned) n > n_cols && (unsigned) n <= max_cols) {
    473            n_cols += scan_col(mvstack, cnt, ref, gmv, &rt->r[((by4 & 31) | 1) + 5],
    474                               bh4, h4, (bx4 - n * 2 + 1) | 1,
    475                               1 + max_cols - n, bh4 >= 16 ? 4 : 2,
    476                               &have_dummy_newmv_match, &have_col_mvs);
    477        }
    478    }
    479    assert(*cnt <= 8);
    480 
    481    const int ref_match_count = have_col_mvs + have_row_mvs;
    482 
    483    // context build-up
    484    int refmv_ctx, newmv_ctx;
    485    switch (nearest_match) {
    486    case 0:
    487        refmv_ctx = imin(2, ref_match_count);
    488        newmv_ctx = ref_match_count > 0;
    489        break;
    490    case 1:
    491        refmv_ctx = imin(ref_match_count * 3, 4);
    492        newmv_ctx = 3 - have_newmv;
    493        break;
    494    case 2:
    495        refmv_ctx = 5;
    496        newmv_ctx = 5 - have_newmv;
    497        break;
    498    }
    499 
    500    // sorting (nearest, then "secondary")
    501    int len = nearest_cnt;
    502    while (len) {
    503        int last = 0;
    504        for (int n = 1; n < len; n++) {
    505            if (mvstack[n - 1].weight < mvstack[n].weight) {
    506 #define EXCHANGE(a, b) do { refmvs_candidate tmp = a; a = b; b = tmp; } while (0)
    507                EXCHANGE(mvstack[n - 1], mvstack[n]);
    508                last = n;
    509            }
    510        }
    511        len = last;
    512    }
    513    len = *cnt;
    514    while (len > nearest_cnt) {
    515        int last = nearest_cnt;
    516        for (int n = nearest_cnt + 1; n < len; n++) {
    517            if (mvstack[n - 1].weight < mvstack[n].weight) {
    518                EXCHANGE(mvstack[n - 1], mvstack[n]);
    519 #undef EXCHANGE
    520                last = n;
    521            }
    522        }
    523        len = last;
    524    }
    525 
    526    if (ref.ref[1] > 0) {
    527        if (*cnt < 2) {
    528            const int sign0 = rf->sign_bias[ref.ref[0] - 1];
    529            const int sign1 = rf->sign_bias[ref.ref[1] - 1];
    530            const int sz4 = imin(w4, h4);
    531            refmvs_candidate *const same = &mvstack[*cnt];
    532            int same_count[4] = { 0 };
    533 
    534            // non-self references in top
    535            if (n_rows != ~0U) for (int x = 0; x < sz4;) {
    536                const refmvs_block *const cand_b = &b_top[x];
    537                add_compound_extended_candidate(same, same_count, cand_b,
    538                                                sign0, sign1, ref, rf->sign_bias);
    539                x += dav1d_block_dimensions[cand_b->bs][0];
    540            }
    541 
    542            // non-self references in left
    543            if (n_cols != ~0U) for (int y = 0; y < sz4;) {
    544                const refmvs_block *const cand_b = &b_left[y][bx4 - 1];
    545                add_compound_extended_candidate(same, same_count, cand_b,
    546                                                sign0, sign1, ref, rf->sign_bias);
    547                y += dav1d_block_dimensions[cand_b->bs][1];
    548            }
    549 
    550            refmvs_candidate *const diff = &same[2];
    551            const int *const diff_count = &same_count[2];
    552 
    553            // merge together
    554            for (int n = 0; n < 2; n++) {
    555                int m = same_count[n];
    556 
    557                if (m >= 2) continue;
    558 
    559                const int l = diff_count[n];
    560                if (l) {
    561                    same[m].mv.mv[n] = diff[0].mv.mv[n];
    562                    if (++m == 2) continue;
    563                    if (l == 2) {
    564                        same[1].mv.mv[n] = diff[1].mv.mv[n];
    565                        continue;
    566                    }
    567                }
    568                do {
    569                    same[m].mv.mv[n] = tgmv[n];
    570                } while (++m < 2);
    571            }
    572 
    573            // if the first extended was the same as the non-extended one,
    574            // then replace it with the second extended one
    575            int n = *cnt;
    576            if (n == 1 && mvstack[0].mv.n == same[0].mv.n)
    577                mvstack[1].mv = mvstack[2].mv;
    578            do {
    579                mvstack[n].weight = 2;
    580            } while (++n < 2);
    581            *cnt = 2;
    582        }
    583 
    584        // clamping
    585        const int left = -(bx4 + bw4 + 4) * 4 * 8;
    586        const int right = (rf->iw4 - bx4 + 4) * 4 * 8;
    587        const int top = -(by4 + bh4 + 4) * 4 * 8;
    588        const int bottom = (rf->ih4 - by4 + 4) * 4 * 8;
    589 
    590        const int n_refmvs = *cnt;
    591        int n = 0;
    592        do {
    593            mvstack[n].mv.mv[0].x = iclip(mvstack[n].mv.mv[0].x, left, right);
    594            mvstack[n].mv.mv[0].y = iclip(mvstack[n].mv.mv[0].y, top, bottom);
    595            mvstack[n].mv.mv[1].x = iclip(mvstack[n].mv.mv[1].x, left, right);
    596            mvstack[n].mv.mv[1].y = iclip(mvstack[n].mv.mv[1].y, top, bottom);
    597        } while (++n < n_refmvs);
    598 
    599        switch (refmv_ctx >> 1) {
    600        case 0:
    601            *ctx = imin(newmv_ctx, 1);
    602            break;
    603        case 1:
    604            *ctx = 1 + imin(newmv_ctx, 3);
    605            break;
    606        case 2:
    607            *ctx = iclip(3 + newmv_ctx, 4, 7);
    608            break;
    609        }
    610 
    611        return;
    612    } else if (*cnt < 2 && ref.ref[0] > 0) {
    613        const int sign = rf->sign_bias[ref.ref[0] - 1];
    614        const int sz4 = imin(w4, h4);
    615 
    616        // non-self references in top
    617        if (n_rows != ~0U) for (int x = 0; x < sz4 && *cnt < 2;) {
    618            const refmvs_block *const cand_b = &b_top[x];
    619            add_single_extended_candidate(mvstack, cnt, cand_b, sign, rf->sign_bias);
    620            x += dav1d_block_dimensions[cand_b->bs][0];
    621        }
    622 
    623        // non-self references in left
    624        if (n_cols != ~0U) for (int y = 0; y < sz4 && *cnt < 2;) {
    625            const refmvs_block *const cand_b = &b_left[y][bx4 - 1];
    626            add_single_extended_candidate(mvstack, cnt, cand_b, sign, rf->sign_bias);
    627            y += dav1d_block_dimensions[cand_b->bs][1];
    628        }
    629    }
    630    assert(*cnt <= 8);
    631 
    632    // clamping
    633    int n_refmvs = *cnt;
    634    if (n_refmvs) {
    635        const int left = -(bx4 + bw4 + 4) * 4 * 8;
    636        const int right = (rf->iw4 - bx4 + 4) * 4 * 8;
    637        const int top = -(by4 + bh4 + 4) * 4 * 8;
    638        const int bottom = (rf->ih4 - by4 + 4) * 4 * 8;
    639 
    640        int n = 0;
    641        do {
    642            mvstack[n].mv.mv[0].x = iclip(mvstack[n].mv.mv[0].x, left, right);
    643            mvstack[n].mv.mv[0].y = iclip(mvstack[n].mv.mv[0].y, top, bottom);
    644        } while (++n < n_refmvs);
    645    }
    646 
    647    for (int n = *cnt; n < 2; n++)
    648        mvstack[n].mv.mv[0] = tgmv[0];
    649 
    650    *ctx = (refmv_ctx << 4) | (globalmv_ctx << 3) | newmv_ctx;
    651 }
    652 
    653 void dav1d_refmvs_tile_sbrow_init(refmvs_tile *const rt, const refmvs_frame *const rf,
    654                                  const int tile_col_start4, const int tile_col_end4,
    655                                  const int tile_row_start4, const int tile_row_end4,
    656                                  const int sby, int tile_row_idx, const int pass)
    657 {
    658    if (rf->n_tile_threads == 1) tile_row_idx = 0;
    659    rt->rp_proj = &rf->rp_proj[16 * rf->rp_stride * tile_row_idx];
    660    const ptrdiff_t r_stride = rf->rp_stride * 2;
    661    const ptrdiff_t pass_off = (rf->n_frame_threads > 1 && pass == 2) ?
    662        35 * 2 * rf->n_blocks : 0;
    663    refmvs_block *r = &rf->r[35 * r_stride * tile_row_idx + pass_off];
    664    const int sbsz = rf->sbsz;
    665    const int off = (sbsz * sby) & 16;
    666    for (int i = 0; i < sbsz; i++, r += r_stride)
    667        rt->r[off + 5 + i] = r;
    668    rt->r[off + 0] = r;
    669    r += r_stride;
    670    rt->r[off + 1] = NULL;
    671    rt->r[off + 2] = r;
    672    r += r_stride;
    673    rt->r[off + 3] = NULL;
    674    rt->r[off + 4] = r;
    675    if (sby & 1) {
    676 #define EXCHANGE(a, b) do { void *const tmp = a; a = b; b = tmp; } while (0)
    677        EXCHANGE(rt->r[off + 0], rt->r[off + sbsz + 0]);
    678        EXCHANGE(rt->r[off + 2], rt->r[off + sbsz + 2]);
    679        EXCHANGE(rt->r[off + 4], rt->r[off + sbsz + 4]);
    680 #undef EXCHANGE
    681    }
    682 
    683    rt->rf = rf;
    684    rt->tile_row.start = tile_row_start4;
    685    rt->tile_row.end = imin(tile_row_end4, rf->ih4);
    686    rt->tile_col.start = tile_col_start4;
    687    rt->tile_col.end = imin(tile_col_end4, rf->iw4);
    688 }
    689 
    690 static void load_tmvs_c(const refmvs_frame *const rf, int tile_row_idx,
    691                        const int col_start8, const int col_end8,
    692                        const int row_start8, int row_end8)
    693 {
    694    if (rf->n_tile_threads == 1) tile_row_idx = 0;
    695    assert(row_start8 >= 0);
    696    assert((unsigned) (row_end8 - row_start8) <= 16U);
    697    row_end8 = imin(row_end8, rf->ih8);
    698    const int col_start8i = imax(col_start8 - 8, 0);
    699    const int col_end8i = imin(col_end8 + 8, rf->iw8);
    700 
    701    const ptrdiff_t stride = rf->rp_stride;
    702    refmvs_temporal_block *rp_proj =
    703        &rf->rp_proj[16 * stride * tile_row_idx + (row_start8 & 15) * stride];
    704    for (int y = row_start8; y < row_end8; y++) {
    705        for (int x = col_start8; x < col_end8; x++)
    706            rp_proj[x].mv.n = INVALID_MV;
    707        rp_proj += stride;
    708    }
    709 
    710    rp_proj = &rf->rp_proj[16 * stride * tile_row_idx];
    711    for (int n = 0; n < rf->n_mfmvs; n++) {
    712        const int ref2cur = rf->mfmv_ref2cur[n];
    713        if (ref2cur == INVALID_REF2CUR) continue;
    714 
    715        const int ref = rf->mfmv_ref[n];
    716        const int ref_sign = ref - 4;
    717        const refmvs_temporal_block *r = &rf->rp_ref[ref][row_start8 * stride];
    718        for (int y = row_start8; y < row_end8; y++) {
    719            const int y_sb_align = y & ~7;
    720            const int y_proj_start = imax(y_sb_align, row_start8);
    721            const int y_proj_end = imin(y_sb_align + 8, row_end8);
    722            for (int x = col_start8i; x < col_end8i; x++) {
    723                const refmvs_temporal_block *rb = &r[x];
    724                const int b_ref = rb->ref;
    725                if (!b_ref) continue;
    726                const int ref2ref = rf->mfmv_ref2ref[n][b_ref - 1];
    727                if (!ref2ref) continue;
    728                const mv b_mv = rb->mv;
    729                const mv offset = mv_projection(b_mv, ref2cur, ref2ref);
    730                int pos_x = x + apply_sign(abs(offset.x) >> 6,
    731                                           offset.x ^ ref_sign);
    732                const int pos_y = y + apply_sign(abs(offset.y) >> 6,
    733                                                 offset.y ^ ref_sign);
    734                if (pos_y >= y_proj_start && pos_y < y_proj_end) {
    735                    const ptrdiff_t pos = (pos_y & 15) * stride;
    736                    for (;;) {
    737                        const int x_sb_align = x & ~7;
    738                        if (pos_x >= imax(x_sb_align - 8, col_start8) &&
    739                            pos_x < imin(x_sb_align + 16, col_end8))
    740                        {
    741                            rp_proj[pos + pos_x].mv = rb->mv;
    742                            rp_proj[pos + pos_x].ref = ref2ref;
    743                        }
    744                        if (++x >= col_end8i) break;
    745                        rb++;
    746                        if (rb->ref != b_ref || rb->mv.n != b_mv.n) break;
    747                        pos_x++;
    748                    }
    749                } else {
    750                    for (;;) {
    751                        if (++x >= col_end8i) break;
    752                        rb++;
    753                        if (rb->ref != b_ref || rb->mv.n != b_mv.n) break;
    754                    }
    755                }
    756                x--;
    757            }
    758            r += stride;
    759        }
    760    }
    761 }
    762 
    763 static void save_tmvs_c(refmvs_temporal_block *rp, const ptrdiff_t stride,
    764                        refmvs_block *const *const rr,
    765                        const uint8_t *const ref_sign,
    766                        const int col_end8, const int row_end8,
    767                        const int col_start8, const int row_start8)
    768 {
    769    for (int y = row_start8; y < row_end8; y++) {
    770        const refmvs_block *const b = rr[(y & 15) * 2];
    771 
    772        for (int x = col_start8; x < col_end8;) {
    773            const refmvs_block *const cand_b = &b[x * 2 + 1];
    774            const int bw8 = (dav1d_block_dimensions[cand_b->bs][0] + 1) >> 1;
    775 
    776            if (cand_b->ref.ref[1] > 0 && ref_sign[cand_b->ref.ref[1] - 1] &&
    777                (abs(cand_b->mv.mv[1].y) | abs(cand_b->mv.mv[1].x)) < 4096)
    778            {
    779                for (int n = 0; n < bw8; n++, x++)
    780                    rp[x] = (refmvs_temporal_block) { .mv = cand_b->mv.mv[1],
    781                                                      .ref = cand_b->ref.ref[1] };
    782            } else if (cand_b->ref.ref[0] > 0 && ref_sign[cand_b->ref.ref[0] - 1] &&
    783                       (abs(cand_b->mv.mv[0].y) | abs(cand_b->mv.mv[0].x)) < 4096)
    784            {
    785                for (int n = 0; n < bw8; n++, x++)
    786                    rp[x] = (refmvs_temporal_block) { .mv = cand_b->mv.mv[0],
    787                                                      .ref = cand_b->ref.ref[0] };
    788            } else {
    789                for (int n = 0; n < bw8; n++, x++) {
    790                    rp[x].mv.n = 0;
    791                    rp[x].ref = 0; // "invalid"
    792                }
    793            }
    794        }
    795        rp += stride;
    796    }
    797 }
    798 
    799 int dav1d_refmvs_init_frame(refmvs_frame *const rf,
    800                            const Dav1dSequenceHeader *const seq_hdr,
    801                            const Dav1dFrameHeader *const frm_hdr,
    802                            const uint8_t ref_poc[7],
    803                            refmvs_temporal_block *const rp,
    804                            const uint8_t ref_ref_poc[7][7],
    805                            /*const*/ refmvs_temporal_block *const rp_ref[7],
    806                            const int n_tile_threads, const int n_frame_threads)
    807 {
    808    const int rp_stride = ((frm_hdr->width[0] + 127) & ~127) >> 3;
    809    const int n_tile_rows = n_tile_threads > 1 ? frm_hdr->tiling.rows : 1;
    810    const int n_blocks = rp_stride * n_tile_rows;
    811 
    812    rf->sbsz = 16 << seq_hdr->sb128;
    813    rf->frm_hdr = frm_hdr;
    814    rf->iw8 = (frm_hdr->width[0] + 7) >> 3;
    815    rf->ih8 = (frm_hdr->height + 7) >> 3;
    816    rf->iw4 = rf->iw8 << 1;
    817    rf->ih4 = rf->ih8 << 1;
    818    rf->rp = rp;
    819    rf->rp_stride = rp_stride;
    820    rf->n_tile_threads = n_tile_threads;
    821    rf->n_frame_threads = n_frame_threads;
    822 
    823    if (n_blocks != rf->n_blocks) {
    824        const size_t r_sz = sizeof(*rf->r) * 35 * 2 * n_blocks * (1 + (n_frame_threads > 1));
    825        const size_t rp_proj_sz = sizeof(*rf->rp_proj) * 16 * n_blocks;
    826        /* Note that sizeof(*rf->r) == 12, but it's accessed using 16-byte unaligned
    827         * loads in save_tmvs() asm which can overread 4 bytes into rp_proj. */
    828        dav1d_free_aligned(rf->r);
    829        rf->r = dav1d_alloc_aligned(ALLOC_REFMVS, r_sz + rp_proj_sz, 64);
    830        if (!rf->r) {
    831            rf->n_blocks = 0;
    832            return DAV1D_ERR(ENOMEM);
    833        }
    834 
    835        rf->rp_proj = (refmvs_temporal_block*)((uintptr_t)rf->r + r_sz);
    836        rf->n_blocks = n_blocks;
    837    }
    838 
    839    const int poc = frm_hdr->frame_offset;
    840    for (int i = 0; i < 7; i++) {
    841        const int poc_diff = get_poc_diff(seq_hdr->order_hint_n_bits,
    842                                          ref_poc[i], poc);
    843        rf->sign_bias[i] = poc_diff > 0;
    844        rf->mfmv_sign[i] = poc_diff < 0;
    845        rf->pocdiff[i] = iclip(get_poc_diff(seq_hdr->order_hint_n_bits,
    846                                            poc, ref_poc[i]), -31, 31);
    847    }
    848 
    849    // temporal MV setup
    850    rf->n_mfmvs = 0;
    851    rf->rp_ref = rp_ref;
    852    if (frm_hdr->use_ref_frame_mvs && seq_hdr->order_hint_n_bits) {
    853        int total = 2;
    854        if (rp_ref[0] && ref_ref_poc[0][6] != ref_poc[3] /* alt-of-last != gold */) {
    855            rf->mfmv_ref[rf->n_mfmvs++] = 0; // last
    856            total = 3;
    857        }
    858        if (rp_ref[4] && get_poc_diff(seq_hdr->order_hint_n_bits, ref_poc[4],
    859                                      frm_hdr->frame_offset) > 0)
    860        {
    861            rf->mfmv_ref[rf->n_mfmvs++] = 4; // bwd
    862        }
    863        if (rp_ref[5] && get_poc_diff(seq_hdr->order_hint_n_bits, ref_poc[5],
    864                                      frm_hdr->frame_offset) > 0)
    865        {
    866            rf->mfmv_ref[rf->n_mfmvs++] = 5; // altref2
    867        }
    868        if (rf->n_mfmvs < total && rp_ref[6] &&
    869            get_poc_diff(seq_hdr->order_hint_n_bits, ref_poc[6],
    870                         frm_hdr->frame_offset) > 0)
    871        {
    872            rf->mfmv_ref[rf->n_mfmvs++] = 6; // altref
    873        }
    874        if (rf->n_mfmvs < total && rp_ref[1])
    875            rf->mfmv_ref[rf->n_mfmvs++] = 1; // last2
    876 
    877        for (int n = 0; n < rf->n_mfmvs; n++) {
    878            const int rpoc = ref_poc[rf->mfmv_ref[n]];
    879            const int diff1 = get_poc_diff(seq_hdr->order_hint_n_bits,
    880                                           rpoc, frm_hdr->frame_offset);
    881            if (abs(diff1) > 31) {
    882                rf->mfmv_ref2cur[n] = INVALID_REF2CUR;
    883            } else {
    884                rf->mfmv_ref2cur[n] = rf->mfmv_ref[n] < 4 ? -diff1 : diff1;
    885                for (int m = 0; m < 7; m++) {
    886                    const int rrpoc = ref_ref_poc[rf->mfmv_ref[n]][m];
    887                    const int diff2 = get_poc_diff(seq_hdr->order_hint_n_bits,
    888                                                   rpoc, rrpoc);
    889                    // unsigned comparison also catches the < 0 case
    890                    rf->mfmv_ref2ref[n][m] = (unsigned) diff2 > 31U ? 0 : diff2;
    891                }
    892            }
    893        }
    894    }
    895    rf->use_ref_frame_mvs = rf->n_mfmvs > 0;
    896 
    897    return 0;
    898 }
    899 
    900 static void splat_mv_c(refmvs_block **rr, const refmvs_block *const rmv,
    901                       const int bx4, const int bw4, int bh4)
    902 {
    903    do {
    904        refmvs_block *const r = *rr++ + bx4;
    905        for (int x = 0; x < bw4; x++)
    906            r[x] = *rmv;
    907    } while (--bh4);
    908 }
    909 
    910 #if HAVE_ASM
    911 #if ARCH_AARCH64 || ARCH_ARM
    912 #include "src/arm/refmvs.h"
    913 #elif ARCH_LOONGARCH64
    914 #include "src/loongarch/refmvs.h"
    915 #elif ARCH_X86
    916 #include "src/x86/refmvs.h"
    917 #endif
    918 #endif
    919 
    920 COLD void dav1d_refmvs_dsp_init(Dav1dRefmvsDSPContext *const c)
    921 {
    922    c->load_tmvs = load_tmvs_c;
    923    c->save_tmvs = save_tmvs_c;
    924    c->splat_mv = splat_mv_c;
    925 
    926 #if HAVE_ASM
    927 #if ARCH_AARCH64 || ARCH_ARM
    928    refmvs_dsp_init_arm(c);
    929 #elif ARCH_LOONGARCH64
    930    refmvs_dsp_init_loongarch(c);
    931 #elif ARCH_X86
    932    refmvs_dsp_init_x86(c);
    933 #endif
    934 #endif
    935 }