highbd_subtract_sse2.c (11678B)
1 /* 2 * Copyright (c) 2016, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include <assert.h> 13 #include <emmintrin.h> 14 #include <stddef.h> 15 16 #include "config/aom_config.h" 17 #include "config/aom_dsp_rtcd.h" 18 19 typedef void (*SubtractWxHFuncType)(int16_t *diff, ptrdiff_t diff_stride, 20 const uint16_t *src, ptrdiff_t src_stride, 21 const uint16_t *pred, 22 ptrdiff_t pred_stride); 23 24 static void subtract_4x4(int16_t *diff, ptrdiff_t diff_stride, 25 const uint16_t *src, ptrdiff_t src_stride, 26 const uint16_t *pred, ptrdiff_t pred_stride) { 27 __m128i u0, u1, u2, u3; 28 __m128i v0, v1, v2, v3; 29 __m128i x0, x1, x2, x3; 30 int64_t *store_diff = (int64_t *)(diff + 0 * diff_stride); 31 32 u0 = _mm_loadl_epi64((__m128i const *)(src + 0 * src_stride)); 33 u1 = _mm_loadl_epi64((__m128i const *)(src + 1 * src_stride)); 34 u2 = _mm_loadl_epi64((__m128i const *)(src + 2 * src_stride)); 35 u3 = _mm_loadl_epi64((__m128i const *)(src + 3 * src_stride)); 36 37 v0 = _mm_loadl_epi64((__m128i const *)(pred + 0 * pred_stride)); 38 v1 = _mm_loadl_epi64((__m128i const *)(pred + 1 * pred_stride)); 39 v2 = _mm_loadl_epi64((__m128i const *)(pred + 2 * pred_stride)); 40 v3 = _mm_loadl_epi64((__m128i const *)(pred + 3 * pred_stride)); 41 42 x0 = _mm_sub_epi16(u0, v0); 43 x1 = _mm_sub_epi16(u1, v1); 44 x2 = _mm_sub_epi16(u2, v2); 45 x3 = _mm_sub_epi16(u3, v3); 46 47 _mm_storel_epi64((__m128i *)store_diff, x0); 48 store_diff = (int64_t *)(diff + 1 * diff_stride); 49 _mm_storel_epi64((__m128i *)store_diff, x1); 50 store_diff = (int64_t *)(diff + 2 * diff_stride); 51 _mm_storel_epi64((__m128i *)store_diff, x2); 52 store_diff = (int64_t *)(diff + 3 * diff_stride); 53 _mm_storel_epi64((__m128i *)store_diff, x3); 54 } 55 56 static void subtract_4x8(int16_t *diff, ptrdiff_t diff_stride, 57 const uint16_t *src, ptrdiff_t src_stride, 58 const uint16_t *pred, ptrdiff_t pred_stride) { 59 __m128i u0, u1, u2, u3, u4, u5, u6, u7; 60 __m128i v0, v1, v2, v3, v4, v5, v6, v7; 61 __m128i x0, x1, x2, x3, x4, x5, x6, x7; 62 int64_t *store_diff = (int64_t *)(diff + 0 * diff_stride); 63 64 u0 = _mm_loadl_epi64((__m128i const *)(src + 0 * src_stride)); 65 u1 = _mm_loadl_epi64((__m128i const *)(src + 1 * src_stride)); 66 u2 = _mm_loadl_epi64((__m128i const *)(src + 2 * src_stride)); 67 u3 = _mm_loadl_epi64((__m128i const *)(src + 3 * src_stride)); 68 u4 = _mm_loadl_epi64((__m128i const *)(src + 4 * src_stride)); 69 u5 = _mm_loadl_epi64((__m128i const *)(src + 5 * src_stride)); 70 u6 = _mm_loadl_epi64((__m128i const *)(src + 6 * src_stride)); 71 u7 = _mm_loadl_epi64((__m128i const *)(src + 7 * src_stride)); 72 73 v0 = _mm_loadl_epi64((__m128i const *)(pred + 0 * pred_stride)); 74 v1 = _mm_loadl_epi64((__m128i const *)(pred + 1 * pred_stride)); 75 v2 = _mm_loadl_epi64((__m128i const *)(pred + 2 * pred_stride)); 76 v3 = _mm_loadl_epi64((__m128i const *)(pred + 3 * pred_stride)); 77 v4 = _mm_loadl_epi64((__m128i const *)(pred + 4 * pred_stride)); 78 v5 = _mm_loadl_epi64((__m128i const *)(pred + 5 * pred_stride)); 79 v6 = _mm_loadl_epi64((__m128i const *)(pred + 6 * pred_stride)); 80 v7 = _mm_loadl_epi64((__m128i const *)(pred + 7 * pred_stride)); 81 82 x0 = _mm_sub_epi16(u0, v0); 83 x1 = _mm_sub_epi16(u1, v1); 84 x2 = _mm_sub_epi16(u2, v2); 85 x3 = _mm_sub_epi16(u3, v3); 86 x4 = _mm_sub_epi16(u4, v4); 87 x5 = _mm_sub_epi16(u5, v5); 88 x6 = _mm_sub_epi16(u6, v6); 89 x7 = _mm_sub_epi16(u7, v7); 90 91 _mm_storel_epi64((__m128i *)store_diff, x0); 92 store_diff = (int64_t *)(diff + 1 * diff_stride); 93 _mm_storel_epi64((__m128i *)store_diff, x1); 94 store_diff = (int64_t *)(diff + 2 * diff_stride); 95 _mm_storel_epi64((__m128i *)store_diff, x2); 96 store_diff = (int64_t *)(diff + 3 * diff_stride); 97 _mm_storel_epi64((__m128i *)store_diff, x3); 98 store_diff = (int64_t *)(diff + 4 * diff_stride); 99 _mm_storel_epi64((__m128i *)store_diff, x4); 100 store_diff = (int64_t *)(diff + 5 * diff_stride); 101 _mm_storel_epi64((__m128i *)store_diff, x5); 102 store_diff = (int64_t *)(diff + 6 * diff_stride); 103 _mm_storel_epi64((__m128i *)store_diff, x6); 104 store_diff = (int64_t *)(diff + 7 * diff_stride); 105 _mm_storel_epi64((__m128i *)store_diff, x7); 106 } 107 108 static void subtract_8x4(int16_t *diff, ptrdiff_t diff_stride, 109 const uint16_t *src, ptrdiff_t src_stride, 110 const uint16_t *pred, ptrdiff_t pred_stride) { 111 __m128i u0, u1, u2, u3; 112 __m128i v0, v1, v2, v3; 113 __m128i x0, x1, x2, x3; 114 115 u0 = _mm_loadu_si128((__m128i const *)(src + 0 * src_stride)); 116 u1 = _mm_loadu_si128((__m128i const *)(src + 1 * src_stride)); 117 u2 = _mm_loadu_si128((__m128i const *)(src + 2 * src_stride)); 118 u3 = _mm_loadu_si128((__m128i const *)(src + 3 * src_stride)); 119 120 v0 = _mm_loadu_si128((__m128i const *)(pred + 0 * pred_stride)); 121 v1 = _mm_loadu_si128((__m128i const *)(pred + 1 * pred_stride)); 122 v2 = _mm_loadu_si128((__m128i const *)(pred + 2 * pred_stride)); 123 v3 = _mm_loadu_si128((__m128i const *)(pred + 3 * pred_stride)); 124 125 x0 = _mm_sub_epi16(u0, v0); 126 x1 = _mm_sub_epi16(u1, v1); 127 x2 = _mm_sub_epi16(u2, v2); 128 x3 = _mm_sub_epi16(u3, v3); 129 130 _mm_storeu_si128((__m128i *)(diff + 0 * diff_stride), x0); 131 _mm_storeu_si128((__m128i *)(diff + 1 * diff_stride), x1); 132 _mm_storeu_si128((__m128i *)(diff + 2 * diff_stride), x2); 133 _mm_storeu_si128((__m128i *)(diff + 3 * diff_stride), x3); 134 } 135 136 static void subtract_8x8(int16_t *diff, ptrdiff_t diff_stride, 137 const uint16_t *src, ptrdiff_t src_stride, 138 const uint16_t *pred, ptrdiff_t pred_stride) { 139 __m128i u0, u1, u2, u3, u4, u5, u6, u7; 140 __m128i v0, v1, v2, v3, v4, v5, v6, v7; 141 __m128i x0, x1, x2, x3, x4, x5, x6, x7; 142 143 u0 = _mm_loadu_si128((__m128i const *)(src + 0 * src_stride)); 144 u1 = _mm_loadu_si128((__m128i const *)(src + 1 * src_stride)); 145 u2 = _mm_loadu_si128((__m128i const *)(src + 2 * src_stride)); 146 u3 = _mm_loadu_si128((__m128i const *)(src + 3 * src_stride)); 147 u4 = _mm_loadu_si128((__m128i const *)(src + 4 * src_stride)); 148 u5 = _mm_loadu_si128((__m128i const *)(src + 5 * src_stride)); 149 u6 = _mm_loadu_si128((__m128i const *)(src + 6 * src_stride)); 150 u7 = _mm_loadu_si128((__m128i const *)(src + 7 * src_stride)); 151 152 v0 = _mm_loadu_si128((__m128i const *)(pred + 0 * pred_stride)); 153 v1 = _mm_loadu_si128((__m128i const *)(pred + 1 * pred_stride)); 154 v2 = _mm_loadu_si128((__m128i const *)(pred + 2 * pred_stride)); 155 v3 = _mm_loadu_si128((__m128i const *)(pred + 3 * pred_stride)); 156 v4 = _mm_loadu_si128((__m128i const *)(pred + 4 * pred_stride)); 157 v5 = _mm_loadu_si128((__m128i const *)(pred + 5 * pred_stride)); 158 v6 = _mm_loadu_si128((__m128i const *)(pred + 6 * pred_stride)); 159 v7 = _mm_loadu_si128((__m128i const *)(pred + 7 * pred_stride)); 160 161 x0 = _mm_sub_epi16(u0, v0); 162 x1 = _mm_sub_epi16(u1, v1); 163 x2 = _mm_sub_epi16(u2, v2); 164 x3 = _mm_sub_epi16(u3, v3); 165 x4 = _mm_sub_epi16(u4, v4); 166 x5 = _mm_sub_epi16(u5, v5); 167 x6 = _mm_sub_epi16(u6, v6); 168 x7 = _mm_sub_epi16(u7, v7); 169 170 _mm_storeu_si128((__m128i *)(diff + 0 * diff_stride), x0); 171 _mm_storeu_si128((__m128i *)(diff + 1 * diff_stride), x1); 172 _mm_storeu_si128((__m128i *)(diff + 2 * diff_stride), x2); 173 _mm_storeu_si128((__m128i *)(diff + 3 * diff_stride), x3); 174 _mm_storeu_si128((__m128i *)(diff + 4 * diff_stride), x4); 175 _mm_storeu_si128((__m128i *)(diff + 5 * diff_stride), x5); 176 _mm_storeu_si128((__m128i *)(diff + 6 * diff_stride), x6); 177 _mm_storeu_si128((__m128i *)(diff + 7 * diff_stride), x7); 178 } 179 180 #define STACK_V(h, fun) \ 181 do { \ 182 fun(diff, diff_stride, src, src_stride, pred, pred_stride); \ 183 fun(diff + diff_stride * h, diff_stride, src + src_stride * h, src_stride, \ 184 pred + pred_stride * h, pred_stride); \ 185 } while (0) 186 187 #define STACK_H(w, fun) \ 188 do { \ 189 fun(diff, diff_stride, src, src_stride, pred, pred_stride); \ 190 fun(diff + w, diff_stride, src + w, src_stride, pred + w, pred_stride); \ 191 } while (0) 192 193 #define SUBTRACT_FUN(size) \ 194 static void subtract_##size(int16_t *diff, ptrdiff_t diff_stride, \ 195 const uint16_t *src, ptrdiff_t src_stride, \ 196 const uint16_t *pred, ptrdiff_t pred_stride) 197 198 SUBTRACT_FUN(8x16) { STACK_V(8, subtract_8x8); } 199 SUBTRACT_FUN(16x8) { STACK_H(8, subtract_8x8); } 200 SUBTRACT_FUN(16x16) { STACK_V(8, subtract_16x8); } 201 SUBTRACT_FUN(16x32) { STACK_V(16, subtract_16x16); } 202 SUBTRACT_FUN(32x16) { STACK_H(16, subtract_16x16); } 203 SUBTRACT_FUN(32x32) { STACK_V(16, subtract_32x16); } 204 SUBTRACT_FUN(32x64) { STACK_V(32, subtract_32x32); } 205 SUBTRACT_FUN(64x32) { STACK_H(32, subtract_32x32); } 206 SUBTRACT_FUN(64x64) { STACK_V(32, subtract_64x32); } 207 SUBTRACT_FUN(64x128) { STACK_V(64, subtract_64x64); } 208 SUBTRACT_FUN(128x64) { STACK_H(64, subtract_64x64); } 209 SUBTRACT_FUN(128x128) { STACK_V(64, subtract_128x64); } 210 SUBTRACT_FUN(4x16) { STACK_V(8, subtract_4x8); } 211 SUBTRACT_FUN(16x4) { STACK_H(8, subtract_8x4); } 212 SUBTRACT_FUN(8x32) { STACK_V(16, subtract_8x16); } 213 SUBTRACT_FUN(32x8) { STACK_H(16, subtract_16x8); } 214 SUBTRACT_FUN(16x64) { STACK_V(32, subtract_16x32); } 215 SUBTRACT_FUN(64x16) { STACK_H(32, subtract_32x16); } 216 217 static SubtractWxHFuncType getSubtractFunc(int rows, int cols) { 218 if (rows == 4) { 219 if (cols == 4) return subtract_4x4; 220 if (cols == 8) return subtract_8x4; 221 if (cols == 16) return subtract_16x4; 222 } 223 if (rows == 8) { 224 if (cols == 4) return subtract_4x8; 225 if (cols == 8) return subtract_8x8; 226 if (cols == 16) return subtract_16x8; 227 if (cols == 32) return subtract_32x8; 228 } 229 if (rows == 16) { 230 if (cols == 4) return subtract_4x16; 231 if (cols == 8) return subtract_8x16; 232 if (cols == 16) return subtract_16x16; 233 if (cols == 32) return subtract_32x16; 234 if (cols == 64) return subtract_64x16; 235 } 236 if (rows == 32) { 237 if (cols == 8) return subtract_8x32; 238 if (cols == 16) return subtract_16x32; 239 if (cols == 32) return subtract_32x32; 240 if (cols == 64) return subtract_64x32; 241 } 242 if (rows == 64) { 243 if (cols == 16) return subtract_16x64; 244 if (cols == 32) return subtract_32x64; 245 if (cols == 64) return subtract_64x64; 246 if (cols == 128) return subtract_128x64; 247 } 248 if (rows == 128) { 249 if (cols == 64) return subtract_64x128; 250 if (cols == 128) return subtract_128x128; 251 } 252 assert(0); 253 return NULL; 254 } 255 256 void aom_highbd_subtract_block_sse2(int rows, int cols, int16_t *diff, 257 ptrdiff_t diff_stride, const uint8_t *src8, 258 ptrdiff_t src_stride, const uint8_t *pred8, 259 ptrdiff_t pred_stride) { 260 uint16_t *src = CONVERT_TO_SHORTPTR(src8); 261 uint16_t *pred = CONVERT_TO_SHORTPTR(pred8); 262 SubtractWxHFuncType func; 263 264 func = getSubtractFunc(rows, cols); 265 func(diff, diff_stride, src, src_stride, pred, pred_stride); 266 }