vendor.py (12833B)
#!/usr/bin/python3
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

"""Vendor the pq-crystals Kyber reference implementation.

Fetches a tarball of https://github.com/pq-crystals/kyber at a given commit
and produces:

* kyber-pqcrystals-ref.c.orig -- the selected ref/ headers and sources
  concatenated without modification, for diffing;
* kyber-pqcrystals-ref.c -- the same files with include guards and include
  directives removed, ``#ifdef KYBER_90S`` blocks evaluated with the macro
  undefined, and functions outside kem/fips202 made ``static``;
* kyber-pqcrystals-ref.h -- ref/api.h wrapped in a new include guard.

Usage: python3 vendor.py <commit>
"""

import io
import os
import sys
import tarfile

# Preprocessor directives that open a conditional block.
_IF_DIRECTIVES = ("#if", "#ifdef", "#ifndef")

# Characters allowed in a commit id given on the command line.
_HEX_DIGITS = frozenset("0123456789abcdefABCDEF")


def remove_include_guard(x: io.StringIO, guard: str) -> io.StringIO:
    """Strip the ``#ifndef guard`` / ``#define guard`` / ``#endif`` wrapper.

    Only an ``#ifndef`` of *guard* at nesting depth 1 is treated as the
    include guard; the ``#endif`` that brings the depth back to 0 afterwards
    is dropped with it. Any ``#define`` of *guard* is removed wherever it
    appears. All other lines are copied unchanged.

    Returns a new StringIO positioned at its start.
    """
    out = io.StringIO()
    depth = 0
    inside_guard = False
    for line in x.readlines():
        tokens = line.split()
        if tokens and tokens[0] in _IF_DIRECTIVES:
            depth += 1
        if len(tokens) > 1 and tokens[0] == "#ifndef" and tokens[1] == guard:
            assert depth == 1, "error: nested include guard"
            inside_guard = True
            continue
        if len(tokens) > 1 and tokens[0] == "#define" and tokens[1] == guard:
            continue
        if tokens and tokens[0] == "#endif":
            depth -= 1
            # The #endif that closes the guard itself is dropped.
            if depth == 0 and inside_guard:
                inside_guard = False
                continue
        out.write(line)
    out.seek(0)
    return out


def remove_includes(x: io.StringIO) -> io.StringIO:
    """Return a copy of *x* without any ``#include`` lines, positioned at start."""
    out = io.StringIO()
    for line in x.readlines():
        tokens = line.split()
        if tokens and tokens[0] == "#include":
            continue
        out.write(line)
    out.seek(0)
    return out


# Take the else branch of any #ifdef KYBER_90S ... #else ... #endif
def remove_kyber90s(x: io.StringIO) -> io.StringIO:
    """Evaluate ``KYBER_90S`` conditionals as if the macro were undefined.

    For ``#ifdef KYBER_90S`` the body is dropped and any ``#else`` branch is
    kept; for ``#ifndef KYBER_90S`` the body is kept and any ``#else`` branch
    is dropped. The KYBER_90S directives themselves are removed. Nested
    conditionals inside a dropped branch are dropped along with it.

    Raises AssertionError on nested KYBER_90S conditionals or on an ``#elif``
    branch of one, neither of which this line scanner supports.

    Returns a new StringIO positioned at its start.
    """
    out = io.StringIO()
    # "before": outside any KYBER_90S block; "during-drop" / "during-keep":
    # inside one, discarding / keeping lines respectively.
    state = "before"
    current_depth = 0
    kyber90s_depth = None  # nesting depth of the active KYBER_90S conditional
    for line in x.readlines():
        tokens = line.split()
        if tokens and tokens[0] in _IF_DIRECTIVES:
            current_depth += 1
        if len(tokens) > 1 and tokens[0] == "#ifdef" and tokens[1] == "KYBER_90S":
            assert kyber90s_depth is None, "cannot handle nested #ifdef KYBER90S"
            kyber90s_depth = current_depth
            state = "during-drop"
            continue
        if len(tokens) > 1 and tokens[0] == "#ifndef" and tokens[1] == "KYBER_90S":
            assert kyber90s_depth is None, "cannot handle nested #ifndef KYBER90S"
            kyber90s_depth = current_depth
            state = "during-keep"
            continue
        # #else/#elif/#endif belonging to the KYBER_90S conditional itself.
        if current_depth == kyber90s_depth and tokens:
            if tokens[0] == "#else":
                assert state != "before"
                state = "during-keep" if state == "during-drop" else "during-drop"
                continue
            if tokens[0] == "#elif":
                assert False, "cannot handle #elif branch of #ifdef KYBER90S"
            if tokens[0] == "#endif":
                assert state != "before"
                state = "before"
                kyber90s_depth = None
                current_depth -= 1
                continue
        if tokens and tokens[0] == "#endif":
            current_depth -= 1
        if state == "during-drop":
            continue
        out.write(line)
    out.seek(0)
    return out


def add_static_to_fns(x: io.StringIO) -> io.StringIO:
    """Prefix top-level function declarations/definitions with ``static``.

    A line is rewritten when it is at brace depth 0 and starts (at column 0)
    with one of the return types ``void``, ``uint32_t``, ``int16_t`` or
    ``int``; note ``int`` also matches ``int32_t`` etc. by prefix.

    Returns a new StringIO positioned at its start.
    """
    out = io.StringIO()
    depth = 0
    for line in x.readlines():
        # assumes return type starts on column 0
        if depth == 0 and line.startswith(("void", "uint32_t", "int16_t", "int")):
            out.write("static " + line)
        else:
            out.write(line)
        # Count every brace rather than just testing for one, so a line that
        # opens or closes several scopes (e.g. "{ {") keeps depth accurate.
        depth += line.count("{") - line.count("}")
    out.seek(0)
    return out


def file_block(x: io.StringIO, filename: str) -> io.StringIO:
    """Wrap the whitespace-stripped contents of *x* in begin/end marker comments.

    Returns a new StringIO positioned at its start.
    """
    out = io.StringIO()
    out.write(f"\n/** begin: {filename} **/\n")
    out.write(x.read().strip())
    out.write(f"\n/** end: {filename} **/\n")
    out.seek(0)
    return out


def test():
    """Self-test the text transformations; raises AssertionError on failure."""
    assert 0 == len(remove_includes(io.StringIO("#include <stddef.h>")).read())
    assert 0 == len(remove_kyber90s(io.StringIO("#ifdef KYBER_90S\nx\n#endif")).read())

    test_remove_kyber90s_expect = "#ifdef OTHER\nx\n#else\nx\n#endif"
    test_remove_ifdef_kyber90s = f"""
#ifdef KYBER_90S
x
{test_remove_kyber90s_expect}
x
#else
{test_remove_kyber90s_expect}
#endif
"""
    test_remove_ifdef_kyber90s_actual = (
        remove_kyber90s(io.StringIO(test_remove_ifdef_kyber90s)).read().strip()
    )
    assert (
        test_remove_kyber90s_expect == test_remove_ifdef_kyber90s_actual
    ), "remove_kyber90s unit test"

    test_remove_ifndef_kyber90s = f"""
#ifndef KYBER_90S
{test_remove_kyber90s_expect}
#else
x
{test_remove_kyber90s_expect}
x
#endif
"""
    test_remove_ifndef_kyber90s_actual = (
        remove_kyber90s(io.StringIO(test_remove_ifndef_kyber90s)).read().strip()
    )
    assert (
        test_remove_kyber90s_expect == test_remove_ifndef_kyber90s_actual
    ), "remove_kyber90s unit test"

    test_add_static_to_fns = """\
void fn() {
  int x[3] = {1,2,3};
}"""
    assert (
        f"static {test_add_static_to_fns}"
        == add_static_to_fns(io.StringIO(test_add_static_to_fns)).read()
    )

    # Several braces on one line must still keep the nesting depth balanced.
    test_add_static_nested = "void f() { {\n}\n}\nint g();\n"
    assert (
        "static void f() { {\n}\n}\nstatic int g();\n"
        == add_static_to_fns(io.StringIO(test_add_static_nested)).read()
    ), "add_static_to_fns brace counting"

    test_remove_include_guard = """\
#ifndef TEST_H
#define TEST_H
#endif"""

    assert 0 == len(
        remove_include_guard(io.StringIO(test_remove_include_guard), "TEST_H").read()
    )
    assert (
        test_remove_include_guard
        == remove_include_guard(
            io.StringIO(test_remove_include_guard), "OTHER_H"
        ).read()
    )

    assert is_hex("0123456789abcdefABCDEF")
    assert not is_hex("")
    assert not is_hex("0x1234ab")
    assert not is_hex("dead_beef")
    assert not is_hex(" cafe ")


def is_hex(s: str) -> bool:
    """Return True if *s* is a non-empty string made only of hex digits.

    Unlike ``int(s, 16)``, this rejects a ``0x`` prefix, a sign, underscores,
    surrounding whitespace and non-ASCII digits, none of which belong in a
    commit id.
    """
    return bool(s) and all(c in _HEX_DIGITS for c in s)


# C code written between the license comment and the vendored files. It
# provides randombytes(), verify() and cmov() (not taken from the reference
# sources) on top of NSS helpers.
_FREEBL_PRELUDE = """
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#ifdef FREEBL_NO_DEPEND
#include "stubs.h"
#endif

#include "secport.h"

// We need to provide an implementation of randombytes to avoid an unused
// function warning. We don't use the randomized API in freebl, so we'll make
// calling randombytes an error.
static void randombytes(uint8_t *out, size_t outlen) {
  // this memset is to avoid "maybe-uninitialized" warnings that gcc-11 issues
  // for the (unused) crypto_kem_keypair and crypto_kem_enc functions.
  memset(out, 0, outlen);
  assert(0);
}

/*************************************************
* Name:        verify
*
* Description: Compare two arrays for equality in constant time.
*
* Arguments:   const uint8_t *a: pointer to first byte array
*              const uint8_t *b: pointer to second byte array
*              size_t len:       length of the byte arrays
*
* Returns 0 if the byte arrays are equal, 1 otherwise
**************************************************/
static int verify(const uint8_t *a, const uint8_t *b, size_t len) {
  return NSS_SecureMemcmp(a, b, len);
}

/*************************************************
* Name:        cmov
*
* Description: Copy len bytes from x to r if b is 1;
*              don't modify x if b is 0. Requires b to be in {0,1};
*              assumes two's complement representation of negative integers.
*              Runs in constant time.
*
* Arguments:   uint8_t *r:       pointer to output byte array
*              const uint8_t *x: pointer to input byte array
*              size_t len:       Amount of bytes to be copied
*              uint8_t b:        Condition bit; has to be in {0,1}
**************************************************/
static void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b)
{
  NSS_SecureSelect(r, r, x, len, b);
}
"""


def _fetch_archive(url: str, dest: str) -> None:
    """Download *url* and save the response body to *dest*.

    Exits the process with status 1 if the server returns an error status.
    """
    # Imported here rather than at module level: requests is a third-party
    # package needed only when the archive is not already on disk, and the
    # text-processing helpers above should stay usable without it.
    import requests

    print(f"* fetching {url}")
    # A timeout keeps a stalled connection from hanging the script forever.
    req = requests.get(url, timeout=60)
    if not req.ok:
        print(f"* failed: {req.reason}")
        sys.exit(1)
    with open(dest, "wb") as f:
        f.write(req.content)


def _read_member(tarball: tarfile.TarFile, path: str) -> io.StringIO:
    """Return archive member *path*, decoded as UTF-8, in a StringIO."""
    member = tarball.extractfile(path)
    if member is None:  # neither a regular file nor a link (e.g. a directory)
        sys.exit(f"* failed: {path} is not a regular file in the archive")
    return io.StringIO(member.read().decode("utf-8"))


def main() -> None:
    """Generate the vendored Kyber files for the commit named on the command line."""
    test()

    repo = "https://github.com/pq-crystals/kyber"
    out = "kyber-pqcrystals-ref.c"
    out_api = "kyber-pqcrystals-ref.h"
    out_orig = "kyber-pqcrystals-ref.c.orig"

    # Require at least 8 digits, as the usage text promises; short_commit is
    # commit[:8] and the tarball directory check below uses commit[:7].
    if len(sys.argv) == 2 and len(sys.argv[1]) >= 8 and is_hex(sys.argv[1]):
        commit = sys.argv[1]
        print(f"* using commit id {commit}")
    else:
        print(
            f"""\
Usage: python3 {sys.argv[0]} [commit]
       where [commit] is an 8+ hex digit commit id from {repo}.
"""
        )
        sys.exit(1)

    short_commit = commit[:8]
    tarball_url = f"{repo}/tarball/{commit}"
    archive = f"kyber-{short_commit}.tar.gz"

    # Files taken from ref/, concatenated in this order (headers, then sources).
    headers = [
        "params.h",
        "reduce.h",
        "ntt.h",
        "poly.h",
        "cbd.h",
        "polyvec.h",
        "indcpa.h",
        "fips202.h",
        "symmetric.h",
        "kem.h",
    ]

    sources = [
        "reduce.c",
        "cbd.c",
        "ntt.c",
        "poly.c",
        "polyvec.c",
        "indcpa.c",
        "fips202.c",
        "symmetric-shake.c",
        "kem.c",
    ]

    # Reuse a previously downloaded archive if one is present.
    if not os.path.isfile(archive):
        _fetch_archive(tarball_url, archive)

    print(f"* extracting files from {archive}")
    with tarfile.open(archive, mode="r:gz") as tarball:
        topdir = tarball.getmembers()[0].name
        if topdir != f"pq-crystals-kyber-{commit[:7]}":
            sys.exit("* failed: tarball directory structure changed")

        # Write a single-file copy without modifications for easy diffing
        print(f"* writing unmodified files to {out_orig}")
        with open(out_orig, "w", encoding="utf-8") as f:
            for filename in headers + sources:
                x = _read_member(tarball, f"{topdir}/ref/{filename}")
                f.write(file_block(x, "ref/" + filename).read())

        comment = io.StringIO()
        comment.write(
            f"""/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * This file was generated from
 *   https://github.com/pq-crystals/kyber/commit/{short_commit}
 *
 * Files from that repository are listed here surrounded by
 * "* begin: [file] *" and "* end: [file] *" comments.
 *
 * The following changes have been made:
 *  - include guards have been removed,
 *  - include directives have been removed,
 *  - "#ifdef KYBER90S" blocks have been evaluated with "KYBER90S" undefined,
 *  - functions outside of kem.c have been made static.
 */
"""
        )
        for filename in ["LICENSE", "AUTHORS"]:
            text = _read_member(tarball, f"{topdir}/{filename}").read()
            # This text is pasted inside a C block comment; a "*/" in it would
            # terminate the comment early and break the generated C file.
            if "*/" in text:
                sys.exit(f"* failed: {filename} contains '*/'")
            comment.write(f"""\n/** begin: ref/{filename} **\n""")
            comment.write(text)
            comment.write(f"""** end: ref/{filename} **/\n""")
        header_comment = comment.getvalue()

        # utf-8 explicitly: the inputs were decoded as UTF-8 and may contain
        # non-ASCII text that the locale's default encoding cannot represent.
        print(f"* writing modified files to {out}")
        with open(out, "w", encoding="utf-8") as f:
            f.write(header_comment)
            f.write(_FREEBL_PRELUDE)
            for filename in headers:
                x = _read_member(tarball, f"{topdir}/ref/{filename}")
                x = remove_include_guard(x, filename.upper().replace(".", "_"))
                x = remove_includes(x)
                x = remove_kyber90s(x)
                if filename not in ["kem.h", "fips202.h"]:
                    x = add_static_to_fns(x)
                f.write(file_block(x, "ref/" + filename).read())

            for filename in sources:
                x = _read_member(tarball, f"{topdir}/ref/{filename}")
                x = remove_includes(x)
                x = remove_kyber90s(x)
                if filename not in ["kem.c", "fips202.c"]:
                    x = add_static_to_fns(x)
                f.write(file_block(x, "ref/" + filename).read())

        print(f"* writing private header to {out_api}")
        with open(out_api, "w", encoding="utf-8") as f:
            filename = "api.h"
            f.write(header_comment)
            f.write(
                """
#ifndef KYBER_PQCRYSTALS_REF_H
#define KYBER_PQCRYSTALS_REF_H
"""
            )
            x = _read_member(tarball, f"{topdir}/ref/{filename}")
            x = remove_include_guard(x, filename.upper().replace(".", "_"))
            f.write(file_block(x, "ref/" + filename).read())
            f.write(
                """
#endif // KYBER_PQCRYSTALS_REF_H
"""
            )

    print(
        f"""* done!

You should now:
  1) Check the output by running `diff {out_orig} {out}`
  2) Move {out} to lib/freebl/{out}
  3) Move {out_api} to lib/freebl/{out_api}
  4) Delete {out_orig} and {archive}.
"""
    )


if __name__ == "__main__":
    main()