commit 110d20790a79123ae8f8371b49ac138d6c35eeaf
parent 5fa015e8d0cd7f8fd0140e33a06141d5b8d887c6
Author: Nick Mathewson <nickm@torproject.org>
Date: Thu, 15 May 2025 10:00:21 -0400
Speed up polyval through pipelining.
This optimization helps because:
- We're not blocking the computation of each block on the computation of the
previous one, which leads to fewer pipeline stalls.
- We're deferring reduction until the end of handling a bunch of blocks.
Diffstat:
3 files changed, 146 insertions(+), 8 deletions(-)
diff --git a/src/ext/polyval/pclmul.c b/src/ext/polyval/pclmul.c
@@ -148,11 +148,85 @@ pclmulqdq11(__m128i x, __m128i y)
_mm_slli_epi64(x2, 57))); \
} while (0)
+// Number of 16-byte blocks consumed per call to pv_add_multiple_pclmul().
+#define PCLMUL_BLOCK_STRIDE 4
+// Precomputed powers of the key h, filled in by expand_key_pclmul().
+struct expanded_key_pclmul {
+ // Powers of h in reverse order: k[0] holds h^PCLMUL_BLOCK_STRIDE and
+ // k[PCLMUL_BLOCK_STRIDE-1] holds h^1, so that k[i] is the multiplier
+ // for the i-th input block of a stride.
+ __m128i k[PCLMUL_BLOCK_STRIDE];
+};
+
+// Fill `out` from the key stored in `pv`: the last slot of out->k gets the
+// key h itself, and each lower slot gets the slot above it multiplied by h
+// (see the ordering described on struct expanded_key_pclmul).
+BR_TARGET("ssse3,pclmul")
+static inline void
+expand_key_pclmul(const polyval_t *pv, struct expanded_key_pclmul *out)
+{
+ __m128i h1w, h1x;
+ __m128i lastw, lastx;
+ __m128i t0, t1, t2, t3;
+
+ h1w = PCLMUL_MEMBER(pv->key.h);
+ // NOTE(review): BK() is defined outside this hunk; presumably it derives
+ // the Karatsuba "middle" operand (second argument) from the first --
+ // confirm against its definition.
+ BK(h1w, h1x);
+ // h^1 goes in the last slot and seeds the running product.
+ out->k[PCLMUL_BLOCK_STRIDE-1] = lastw = h1w;
+
+ // Walk downward: each iteration multiplies the running product by h.
+ for (int i = PCLMUL_BLOCK_STRIDE - 2; i >= 0; --i) {
+ BK(lastw, lastx);
+
+ // Three carry-less multiplies (presumably high halves, low halves, and
+ // the BK()-derived operands -- confirm against pclmulqdq11/00).
+ t1 = pclmulqdq11(lastw, h1w);
+ t3 = pclmulqdq00(lastw, h1w);
+ t2 = pclmulqdq00(lastx, h1x);
+ // Fold the three partial products into the t0..t3 layout that
+ // REDUCE_F128 takes as input.
+ t2 = _mm_xor_si128(t2, _mm_xor_si128(t1, t3));
+ t0 = _mm_shuffle_epi32(t1, 0x0E);
+ t1 = _mm_xor_si128(t1, _mm_shuffle_epi32(t2, 0x0E));
+ t2 = _mm_xor_si128(t2, _mm_shuffle_epi32(t3, 0x0E));
+ REDUCE_F128(t0, t1, t2, t3);
+ // Pack the low 64 bits of t1 and t0 into one 128-bit value; it is both
+ // stored and used as the multiplicand for the next iteration.
+ out->k[i] = lastw = _mm_unpacklo_epi64(t1, t0);
+ }
+}
+
+// Add PCLMUL_BLOCK_STRIDE * 16 bytes from input to the running state in
+// pv->y, using the precomputed key powers in `expanded`.
+//
+// The first block is XORed with the current pv->y before being multiplied;
+// block i is multiplied by expanded->k[i].  The unreduced partial products
+// of all blocks are XOR-accumulated, and REDUCE_F128 runs only once, after
+// the loop.  The result is written back to pv->y.
+//
+// The caller must make sure that at least PCLMUL_BLOCK_STRIDE * 16 bytes
+// are readable at `input`.
+BR_TARGET("ssse3,pclmul")
+static inline void
+pv_add_multiple_pclmul(polyval_t *pv,
+                       const uint8_t *input,
+                       const struct expanded_key_pclmul *expanded)
+{
+  __m128i t0, t1, t2, t3;
+
+  // Accumulators for the three partial products; t0 is derived from t1
+  // after the loop.
+  t1 = _mm_setzero_si128();
+  t2 = _mm_setzero_si128();
+  t3 = _mm_setzero_si128();
+
+  for (int i = 0; i < PCLMUL_BLOCK_STRIDE; ++i, input += 16) {
+    // Unaligned load.  Cast through const void * so that the const
+    // qualifier on `input` is not discarded.
+    __m128i aw = _mm_loadu_si128((const void *)(input));
+    __m128i ax;
+    __m128i hx;
+    if (i == 0) {
+      // Only the first block absorbs the previous accumulator value.
+      aw = _mm_xor_si128(aw, PCLMUL_MEMBER(pv->y));
+    }
+    // NOTE(review): BK() is defined outside this hunk; presumably it
+    // derives the Karatsuba "middle" operand from its first argument --
+    // confirm against its definition.
+    BK(aw, ax);
+    BK(expanded->k[i], hx);
+    t1 = _mm_xor_si128(t1,
+                       pclmulqdq11(aw, expanded->k[i]));
+    t3 = _mm_xor_si128(t3,
+                       pclmulqdq00(aw, expanded->k[i]));
+    t2 = _mm_xor_si128(t2,
+                       pclmulqdq00(ax, hx));
+  }
+
+  // Fold the accumulated partial products into the t0..t3 layout that
+  // REDUCE_F128 takes as input.
+  t2 = _mm_xor_si128(t2, _mm_xor_si128(t1, t3));
+  t0 = _mm_shuffle_epi32(t1, 0x0E);
+  t1 = _mm_xor_si128(t1, _mm_shuffle_epi32(t2, 0x0E));
+  t2 = _mm_xor_si128(t2, _mm_shuffle_epi32(t3, 0x0E));
+
+  REDUCE_F128(t0, t1, t2, t3);
+  // Pack the low 64 bits of t1 and t0 into the new accumulator value.
+  PCLMUL_MEMBER(pv->y) = _mm_unpacklo_epi64(t1, t0);
+}
+
/* see bearssl_hash.h */
BR_TARGET("ssse3,pclmul")
-static
-void pv_mul_y_h_pclmul(polyval_t *pv)
+static inline void
+pv_mul_y_h_pclmul(polyval_t *pv)
{
__m128i yw, h1w, h1x;
diff --git a/src/ext/polyval/polyval.c b/src/ext/polyval/polyval.c
@@ -219,12 +219,30 @@ pv_xor_y_ctmul(polyval_t *pv, u128 val)
}
#endif
+// Placeholder type and functions for backends that have no multi-block
+// code path (the ones that pass a block stride of 0 to PV_DECLARE).
+struct expanded_key_none {
+  // ISO C does not permit a struct with no members (an empty struct is a
+  // compiler extension), so carry one byte that is never read.
+  char unused;
+};
+// Stand-in for a multi-block add function: does nothing.
+static inline void add_multiple_none(polyval_t *pv,
+                                     const uint8_t *input,
+                                     const struct expanded_key_none *expanded)
+{
+  (void) pv;
+  (void) input;
+  (void) expanded;
+}
+// Stand-in for a key-expansion function: does nothing and leaves `out`
+// untouched.
+static inline void expand_key_none(const polyval_t *inp,
+                                   struct expanded_key_none *out)
+{
+  (void) inp;
+  (void) out;
+}
+
#define PV_DECLARE(prefix, \
st, \
u128_from_bytes, \
u128_to_bytes, \
pv_xor_y, \
- pv_mul_y_h) \
+ pv_mul_y_h, \
+ block_stride, \
+ expanded_key_tp, expand_fn, add_multiple_fn) \
st void \
prefix ## polyval_key_init(polyval_key_t *pvk, const uint8_t *key) \
{ \
@@ -252,6 +270,15 @@ pv_xor_y_ctmul(polyval_t *pv, u128 val)
st void \
prefix ## polyval_add_zpad(polyval_t *pv, const uint8_t *data, size_t n) \
{ \
+    /* Multi-block fast path.  block_stride is a compile-time constant; */ \
+    /* backends without such a path pass 0.  They must skip this        */ \
+    /* section entirely: with a stride of 0 the loop condition would be */ \
+    /* "n > 0" while n and data never change, so it could not finish.   */ \
+    if ((block_stride) != 0 && n > (block_stride) * 16) {                  \
+      expanded_key_tp expanded_key;                                        \
+      expand_fn(pv, &expanded_key);                                        \
+      /* Strictly greater-than: at least one byte is always left for   */  \
+      /* the code below to handle.                                     */  \
+      while (n > (block_stride) * 16) {                                    \
+        add_multiple_fn(pv, data, &expanded_key);                          \
+        n -= (block_stride) * 16;                                          \
+        data += (block_stride) * 16;                                       \
+      }                                                                    \
+    }                                                                      \
while (n > 16) { \
polyval_add_block(pv, data); \
data += 16; \
@@ -288,13 +315,21 @@ PV_DECLARE(pclmul_, static,
u128_from_bytes_pclmul,
u128_to_bytes_pclmul,
pv_xor_y_pclmul,
- pv_mul_y_h_pclmul)
+ pv_mul_y_h_pclmul,
+ PCLMUL_BLOCK_STRIDE,
+ struct expanded_key_pclmul,
+ expand_key_pclmul,
+ pv_add_multiple_pclmul)
PV_DECLARE(ctmul64_, static,
u128_from_bytes_ctmul64,
u128_to_bytes_ctmul64,
pv_xor_y_ctmul64,
- pv_mul_y_h_ctmul64)
+ pv_mul_y_h_ctmul64,
+ 0,
+ struct expanded_key_none,
+ expand_key_none,
+ add_multiple_none)
void
polyval_key_init(polyval_key_t *pv, const uint8_t *key)
@@ -358,20 +393,32 @@ PV_DECLARE(, ,
u128_from_bytes_pclmul,
u128_to_bytes_pclmul,
pv_xor_y_pclmul,
- pv_mul_y_h_pclmul)
+ pv_mul_y_h_pclmul,
+ PCLMUL_BLOCK_STRIDE,
+ struct expanded_key_pclmul,
+ expand_key_pclmul,
+ pv_add_multiple_pclmul)
#elif defined(PV_USE_CTMUL64)
PV_DECLARE(, ,
u128_from_bytes_ctmul64,
u128_to_bytes_ctmul64,
pv_xor_y_ctmul64,
- pv_mul_y_h_ctmul64)
+ pv_mul_y_h_ctmul64,
+ 0,
+ struct expanded_key_none,
+ expand_key_none,
+ add_multiple_none)
#elif defined(PV_USE_CTMUL)
PV_DECLARE(, , u128_from_bytes_ctmul,
u128_to_bytes_ctmul,
pv_xor_y_ctmul,
- pv_mul_y_h_ctmul)
+ pv_mul_y_h_ctmul,
+ 0,
+ struct expanded_key_none,
+ expand_key_none,
+ add_multiple_none)
#endif
#ifdef PV_USE_PCLMUL_DETECT
diff --git a/src/test/test_crypto.c b/src/test/test_crypto.c
@@ -3201,6 +3201,7 @@ test_crypto_polyval(void *arg)
uint8_t output[16];
uint8_t output2[16];
char *mem_op_hex_tmp=NULL;
+ uint8_t *longer = NULL;
// From RFC 8452
const char *key_hex = "25629347589242761d31f826ba4b757b";
@@ -3236,8 +3237,24 @@ test_crypto_polyval(void *arg)
polyval_get_tag(&pv, output2);
tt_mem_op(output, OP_EQ, output2, 16);
+ // Try a long input both ways, and make sure the answer is the same.
+ longer = tor_malloc_zero(4096);
+ crypto_rand((char *)longer, 4090); // leave zeros at the end.
+ polyval_reset(&pv);
+ polyval_add_zpad(&pv, longer, 4090);
+ polyval_get_tag(&pv, output);
+
+ polyval_reset(&pv);
+ const uint8_t *cp;
+ for (cp = longer; cp < longer + 4096; cp += 16) {
+ polyval_add_block(&pv, cp);
+ }
+ polyval_get_tag(&pv, output2);
+ tt_mem_op(output, OP_EQ, output2, 16);
+
done:
tor_free(mem_op_hex_tmp);
+ tor_free(longer);
}
static void