tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

vendor.py (12833B)


      1 #!/usr/bin/python3
      2 # This Source Code Form is subject to the terms of the Mozilla Public
      3 # License, v. 2.0. If a copy of the MPL was not distributed with this
      4 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
      5 
      6 import io
      7 import os
      8 import requests
      9 import sys
     10 import tarfile
     11 
     12 
     13 def remove_include_guard(x: io.StringIO, guard: str) -> io.StringIO:
     14    out = io.StringIO()
     15    depth = 0
     16    inside_guard = False
     17    for line in x.readlines():
     18        tokens = line.split()
     19        if tokens and tokens[0] in ["#if", "#ifdef", "#ifndef"]:
     20            depth += 1
     21        if len(tokens) > 1 and tokens[0] == "#ifndef" and tokens[1] == guard:
     22            assert depth == 1, "error: nested include guard"
     23            inside_guard = True
     24            continue
     25        if len(tokens) > 1 and tokens[0] == "#define" and tokens[1] == guard:
     26            continue
     27        if tokens and tokens[0] == "#endif":
     28            depth -= 1
     29            if depth == 0 and inside_guard:
     30                inside_guard = False
     31                continue
     32        out.write(line)
     33    out.seek(0)
     34    return out
     35 
     36 
     37 def remove_includes(x: io.StringIO) -> io.StringIO:
     38    out = io.StringIO()
     39    for line in x.readlines():
     40        tokens = line.split()
     41        if tokens and tokens[0] == "#include":
     42            continue
     43        out.write(line)
     44    out.seek(0)
     45    return out
     46 
     47 
     48 # Take the else branch of any #ifdef KYBER90s ... #else ... #endif
     49 def remove_kyber90s(x: io.StringIO) -> io.StringIO:
     50    out = io.StringIO()
     51    states = ["before", "during-drop", "during-keep"]
     52    state = "before"
     53    current_depth = 0
     54    kyber90s_depth = None
     55    for line in x.readlines():
     56        tokens = line.split()
     57        if tokens and tokens[0] in ["#if", "#ifdef", "#ifndef"]:
     58            current_depth += 1
     59        if len(tokens) > 1 and tokens[0] == "#ifdef" and tokens[1] == "KYBER_90S":
     60            assert kyber90s_depth == None, "cannot handle nested #ifdef KYBER90S"
     61            kyber90s_depth = current_depth
     62            state = "during-drop"
     63            continue
     64        if len(tokens) > 1 and tokens[0] == "#ifndef" and tokens[1] == "KYBER_90S":
     65            assert kyber90s_depth == None, "cannot handle nested #ifndef KYBER90S"
     66            kyber90s_depth = current_depth
     67            state = "during-keep"
     68            continue
     69        if current_depth == kyber90s_depth and tokens:
     70            if tokens[0] == "#else":
     71                assert state != "before"
     72                state = "during-keep" if state == "during-drop" else "during-drop"
     73                continue
     74            if tokens[0] == "#elif":
     75                assert False, "cannot handle #elif branch of #ifdef KYBER90S"
     76            if tokens[0] == "#endif":
     77                assert state != "before"
     78                state = "before"
     79                kyber90s_depth = None
     80                current_depth -= 1
     81                continue
     82        if tokens and tokens[0] == "#endif":
     83            current_depth -= 1
     84        if state == "during-drop":
     85            continue
     86        out.write(line)
     87    out.seek(0)
     88    return out
     89 
     90 
     91 def add_static_to_fns(x: io.StringIO) -> io.StringIO:
     92    out = io.StringIO()
     93    depth = 0
     94    for line in x.readlines():
     95        tokens = line.split()
     96        # assumes return type starts on column 0
     97        if depth == 0 and any(
     98            line.startswith(typ) for typ in ["void", "uint32_t", "int16_t", "int"]
     99        ):
    100            out.write("static " + line)
    101        else:
    102            out.write(line)
    103        if "{" in line:
    104            depth += 1
    105        if "}" in line:
    106            depth -= 1
    107    out.seek(0)
    108    return out
    109 
    110 
    111 def file_block(x: io.StringIO, filename: str) -> io.StringIO:
    112    out = io.StringIO()
    113    out.write(f"\n/** begin: {filename} **/\n")
    114    out.write(x.read().strip())
    115    out.write(f"\n/** end: {filename} **/\n")
    116    out.seek(0)
    117    return out
    118 
    119 
    120 def test():
    121    assert 0 == len(remove_includes(io.StringIO("#include <stddef.h>")).read())
    122    assert 0 == len(remove_kyber90s(io.StringIO("#ifdef KYBER_90S\nx\n#endif")).read())
    123 
    124    test_remove_kyber90s_expect = "#ifdef OTHER\nx\n#else\nx\n#endif"
    125    test_remove_ifdef_kyber90s = f"""
    126 #ifdef KYBER_90S
    127 x
    128 {test_remove_kyber90s_expect}
    129 x
    130 #else
    131 {test_remove_kyber90s_expect}
    132 #endif
    133 """
    134    test_remove_ifdef_kyber90s_actual = (
    135        remove_kyber90s(io.StringIO(test_remove_ifdef_kyber90s)).read().strip()
    136    )
    137    assert (
    138        test_remove_kyber90s_expect == test_remove_ifdef_kyber90s_actual
    139    ), "remove_kyber90s unit test"
    140 
    141    test_remove_ifndef_kyber90s = f"""
    142 #ifndef KYBER_90S
    143 {test_remove_kyber90s_expect}
    144 #else
    145 x
    146 {test_remove_kyber90s_expect}
    147 x
    148 #endif
    149 """
    150    test_remove_ifndef_kyber90s_actual = (
    151        remove_kyber90s(io.StringIO(test_remove_ifndef_kyber90s)).read().strip()
    152    )
    153    assert (
    154        test_remove_kyber90s_expect == test_remove_ifndef_kyber90s_actual
    155    ), "remove_kyber90s unit test"
    156 
    157    test_add_static_to_fns = """\
    158 void fn() {
    159 int x[3] = {1,2,3};
    160 }"""
    161    assert (
    162        f"static {test_add_static_to_fns}"
    163        == add_static_to_fns(io.StringIO(test_add_static_to_fns)).read()
    164    )
    165 
    166    test_remove_include_guard = """\
    167 #ifndef TEST_H
    168 #define TEST_H
    169 #endif"""
    170 
    171    assert 0 == len(
    172        remove_include_guard(io.StringIO(test_remove_include_guard), "TEST_H").read()
    173    )
    174    assert (
    175        test_remove_include_guard
    176        == remove_include_guard(
    177            io.StringIO(test_remove_include_guard), "OTHER_H"
    178        ).read()
    179    )
    180 
    181 
    182 def is_hex(s: str) -> bool:
    183    try:
    184        int(s, 16)
    185    except ValueError:
    186        return False
    187    return True
    188 
    189 
    190 if __name__ == "__main__":
    191    test()
    192 
    193    repo = f"https://github.com/pq-crystals/kyber"
    194    out = "kyber-pqcrystals-ref.c"
    195    out_api = "kyber-pqcrystals-ref.h"
    196    out_orig = "kyber-pqcrystals-ref.c.orig"
    197 
    198    if len(sys.argv) == 2 and len(sys.argv[1]) >= 6 and is_hex(sys.argv[1]):
    199        commit = sys.argv[1]
    200        print(f"* using commit id {commit}")
    201    else:
    202        print(
    203            f"""\
    204 Usage: python3 {sys.argv[0]} [commit]
    205       where [commit] is an 8+ hex digit commit id from {repo}.
    206 """
    207        )
    208        sys.exit(1)
    209 
    210    short_commit = commit[:8]
    211    tarball_url = f"{repo}/tarball/{commit}"
    212    archive = f"kyber-{short_commit}.tar.gz"
    213 
    214    headers = [
    215        "params.h",
    216        "reduce.h",
    217        "ntt.h",
    218        "poly.h",
    219        "cbd.h",
    220        "polyvec.h",
    221        "indcpa.h",
    222        "fips202.h",
    223        "symmetric.h",
    224        "kem.h",
    225    ]
    226 
    227    sources = [
    228        "reduce.c",
    229        "cbd.c",
    230        "ntt.c",
    231        "poly.c",
    232        "polyvec.c",
    233        "indcpa.c",
    234        "fips202.c",
    235        "symmetric-shake.c",
    236        "kem.c",
    237    ]
    238 
    239    if not os.path.isfile(archive):
    240        print(f"* fetching {tarball_url}")
    241        req = requests.request(method="GET", url=tarball_url)
    242        if not req.ok:
    243            print(f"* failed: {req.reason}")
    244            sys.exit(1)
    245        with open(archive, "wb") as f:
    246            f.write(req.content)
    247 
    248    print(f"* extracting files from {archive}")
    249    with open(archive, "rb") as f:
    250        tarball = tarfile.open(mode="r:gz", fileobj=f)
    251 
    252        topdir = tarball.members[0].path
    253        assert (
    254            topdir == f"pq-crystals-kyber-{commit[:7]}"
    255        ), "tarball directory structure changed"
    256 
    257        # Write a single-file copy without modifications for easy diffing
    258        print(f"* writing unmodified files to {out_orig}")
    259        with open(out_orig, "w") as f:
    260            for filename in headers:
    261                x = tarball.extractfile(f"{topdir}/ref/{filename}")
    262                x = io.StringIO(x.read().decode("utf-8"))
    263                x = file_block(x, "ref/" + filename)
    264                f.write(x.read())
    265 
    266            for filename in sources:
    267                x = tarball.extractfile(f"{topdir}/ref/{filename}")
    268                x = io.StringIO(x.read().decode("utf-8"))
    269                x = file_block(x, "ref/" + filename)
    270                f.write(x.read())
    271 
    272        comment = io.StringIO()
    273        comment.write(
    274            f"""/*
    275 * SPDX-License-Identifier: Apache-2.0
    276 *
    277 * This file was generated from
    278 *   https://github.com/pq-crystals/kyber/commit/{short_commit}
    279 *
    280 * Files from that repository are listed here surrounded by
    281 * "* begin: [file] *" and "* end: [file] *" comments.
    282 *
    283 * The following changes have been made:
    284 *  - include guards have been removed,
    285 *  - include directives have been removed,
    286 *  - "#ifdef KYBER90S" blocks have been evaluated with "KYBER90S" undefined,
    287 *  - functions outside of kem.c have been made static.
    288 */
    289 """
    290        )
    291        for filename in ["LICENSE", "AUTHORS"]:
    292            comment.write(f"""\n/** begin: ref/{filename} **\n""")
    293            x = tarball.extractfile(f"{topdir}/{filename}")
    294            x = io.StringIO(x.read().decode("utf-8"))
    295            for line in x.readlines():
    296                comment.write(line)
    297            comment.write(f"""** end: ref/{filename} **/\n""")
    298        comment.seek(0)
    299 
    300        print(f"* writing modified files to {out}")
    301        with open(out, "w") as f:
    302            f.write(comment.read())
    303            f.write(
    304                """
    305 #include <assert.h>
    306 #include <stddef.h>
    307 #include <stdint.h>
    308 #include <string.h>
    309 
    310 #ifdef FREEBL_NO_DEPEND
    311 #include "stubs.h"
    312 #endif
    313 
    314 #include "secport.h"
    315 
    316 // We need to provide an implementation of randombytes to avoid an unused
    317 // function warning. We don't use the randomized API in freebl, so we'll make
    318 // calling randombytes an error.
    319 static void randombytes(uint8_t *out, size_t outlen) {
    320    // this memset is to avoid "maybe-uninitialized" warnings that gcc-11 issues
    321    // for the (unused) crypto_kem_keypair and crypto_kem_enc functions.
    322    memset(out, 0, outlen);
    323    assert(0);
    324 }
    325 
    326 /*************************************************
    327 * Name:        verify
    328 *
    329 * Description: Compare two arrays for equality in constant time.
    330 *
    331 * Arguments:   const uint8_t *a: pointer to first byte array
    332 *              const uint8_t *b: pointer to second byte array
    333 *              size_t len:       length of the byte arrays
    334 *
    335 * Returns 0 if the byte arrays are equal, 1 otherwise
    336 **************************************************/
    337 static int verify(const uint8_t *a, const uint8_t *b, size_t len) {
    338    return NSS_SecureMemcmp(a, b, len);
    339 }
    340 
    341 /*************************************************
    342 * Name:        cmov
    343 *
    344 * Description: Copy len bytes from x to r if b is 1;
    345 *              don't modify x if b is 0. Requires b to be in {0,1};
    346 *              assumes two's complement representation of negative integers.
    347 *              Runs in constant time.
    348 *
    349 * Arguments:   uint8_t *r:       pointer to output byte array
    350 *              const uint8_t *x: pointer to input byte array
    351 *              size_t len:       Amount of bytes to be copied
    352 *              uint8_t b:        Condition bit; has to be in {0,1}
    353 **************************************************/
    354 static void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b)
    355 {
    356    NSS_SecureSelect(r, r, x, len, b);
    357 }
    358 """
    359            )
    360            for filename in headers:
    361                x = tarball.extractfile(f"{topdir}/ref/{filename}")
    362                x = io.StringIO(x.read().decode("utf-8"))
    363                x = remove_include_guard(x, filename.upper().replace(".", "_"))
    364                x = remove_includes(x)
    365                x = remove_kyber90s(x)
    366                if filename not in ["kem.h", "fips202.h"]:
    367                    x = add_static_to_fns(x)
    368                x = file_block(x, "ref/" + filename)
    369                f.write(x.read())
    370 
    371            for filename in sources:
    372                x = tarball.extractfile(f"{topdir}/ref/{filename}")
    373                x = io.StringIO(x.read().decode("utf-8"))
    374                x = remove_includes(x)
    375                x = remove_kyber90s(x)
    376                if filename not in ["kem.c", "fips202.c"]:
    377                    x = add_static_to_fns(x)
    378                x = file_block(x, "ref/" + filename)
    379                f.write(x.read())
    380 
    381        print(f"* writing private header to {out_api}")
    382        with open(out_api, "w") as f:
    383            filename = "api.h"
    384            comment.seek(0)
    385            f.write(comment.read())
    386            f.write(
    387                """
    388 #ifndef KYBER_PQCRYSTALS_REF_H
    389 #define KYBER_PQCRYSTALS_REF_H
    390 """
    391            )
    392            x = tarball.extractfile(f"{topdir}/ref/{filename}")
    393            x = io.StringIO(x.read().decode("utf-8"))
    394            x = remove_include_guard(x, filename.upper().replace(".", "_"))
    395            x = file_block(x, "ref/" + filename)
    396            f.write(x.read())
    397            f.write(
    398                f"""
    399 #endif // KYBER_PQCRYSTALS_REF_H
    400 """
    401            )
    402        print(
    403            f"""* done!
    404 
    405 You should now:
    406    1) Check the output by running `diff {out_orig} {out}`
    407    2) Move {out} to lib/freebl/{out}
    408    3) Move {out_api} to lib/freebl/{out_api}
    409    4) Delete {out_orig} and {archive}.
    410 """
    411        )