tor

The Tor anonymity network
git clone https://git.dasho.dev/tor.git
Log | Files | Refs | README | LICENSE

ntor_v3_ref.py (7597B)


      1 #!/usr/bin/python
      2 
import binascii
import hashlib
import hmac
import os
import struct

import donna25519
from Crypto.Cipher import AES
from Crypto.Util import Counter
     11 
     12 # Define basic wrappers.
     13 
# Sizes, in bytes, of the primitives' inputs and outputs.
DIGEST_LEN = 32    # output length of sha3_256()
ENC_KEY_LEN = 32   # key length for aes256_ctr()
PUB_KEY_LEN = 32   # curve25519 public key
SEC_KEY_LEN = 32   # curve25519 secret key
IDENTITY_LEN = 32  # the ID value passed to the handshake functions
     19 
     20 def sha3_256(s):
     21    d = hashlib.sha3_256(s).digest()
     22    assert len(d) == DIGEST_LEN
     23    return d
     24 
     25 def shake_256(s):
     26    # Note: In reality, you wouldn't want to generate more bytes than needed.
     27    MAX_KEY_BYTES = 1024
     28    return hashlib.shake_256(s).digest(MAX_KEY_BYTES)
     29 
     30 def curve25519(pk, sk):
     31    assert len(pk) == PUB_KEY_LEN
     32    assert len(sk) == SEC_KEY_LEN
     33    private = donna25519.PrivateKey.load(sk)
     34    public = donna25519.PublicKey(pk)
     35    return private.do_exchange(public)
     36 
     37 def keygen():
     38    private = donna25519.PrivateKey()
     39    public = private.get_public()
     40    return (private.private, public.public)
     41 
     42 def aes256_ctr(k, s):
     43    assert len(k) == ENC_KEY_LEN
     44    cipher = AES.new(k, AES.MODE_CTR, counter=Counter.new(128, initial_value=0))
     45    return cipher.encrypt(s)
     46 
     47 # Byte-oriented helper.  We use this for decoding keystreams and messages.
     48 
     49 class ByteSeq:
     50    def __init__(self, data):
     51        self.data = data
     52 
     53    def take(self, n):
     54        assert n <= len(self.data)
     55        result = self.data[:n]
     56        self.data = self.data[n:]
     57        return result
     58 
     59    def exhausted(self):
     60        return len(self.data) == 0
     61 
     62    def remaining(self):
     63        return len(self.data)
     64 
     65 # Low-level functions
     66 
# The MAC key is taken from the phase-1 KDF stream. A MAC is a hash_func
# output, so it is DIGEST_LEN bytes long.
MAC_KEY_LEN = 32
MAC_LEN = DIGEST_LEN

# Hash function used by h() and mac() below.
hash_func = sha3_256
     71 
     72 def encapsulate(s):
     73    """encapsulate `s` with a length prefix.
     74 
     75       We use this whenever we need to avoid message ambiguities in
     76       cryptographic inputs.
     77    """
     78    assert len(s) <= 0xffffffff
     79    header = b"\0\0\0\0" + struct.pack("!L", len(s))
     80    assert len(header) == 8
     81    return header + s
     82 
     83 def h(s, tweak):
     84    return hash_func(encapsulate(tweak) + s)
     85 
     86 def mac(s, key, tweak):
     87    return hash_func(encapsulate(tweak) + encapsulate(key) + s)
     88 
     89 def kdf(s, tweak):
     90    data = shake_256(encapsulate(tweak) + s)
     91    return ByteSeq(data)
     92 
     93 def enc(s, k):
     94    return aes256_ctr(k, s)
     95 
     96 # Tweaked wrappers
     97 
     98 PROTOID = b"ntor3-curve25519-sha3_256-1"
     99 T_KDF_PHASE1 = PROTOID + b":kdf_phase1"
    100 T_MAC_PHASE1 = PROTOID + b":msg_mac"
    101 T_KDF_FINAL = PROTOID + b":kdf_final"
    102 T_KEY_SEED = PROTOID + b":key_seed"
    103 T_VERIFY = PROTOID + b":verify"
    104 T_AUTH = PROTOID + b":auth_final"
    105 
    106 def kdf_phase1(s):
    107    return kdf(s, T_KDF_PHASE1)
    108 
    109 def kdf_final(s):
    110    return kdf(s, T_KDF_FINAL)
    111 
    112 def mac_phase1(s, key):
    113    return mac(s, key, T_MAC_PHASE1)
    114 
    115 def h_key_seed(s):
    116    return h(s, T_KEY_SEED)
    117 
    118 def h_verify(s):
    119    return h(s, T_VERIFY)
    120 
    121 def h_auth(s):
    122    return h(s, T_AUTH)
    123 
    124 # Handshake.
    125 
def client_phase1(msg, verification, B, ID):
    """Client side, first half: build the client handshake.

    Args:
        msg: plaintext message to carry, encrypted, inside the handshake.
        verification: verification string; the server must use the same one.
        B: the server's curve25519 public key (PUB_KEY_LEN bytes).
        ID: the server's identity (IDENTITY_LEN bytes).

    Returns:
        (client_handshake, state): the handshake bytes, laid out as
        ID | B | X | encrypted msg | MAC, and the state dict that
        client_phase2 needs.

    Prints test-vector values via p() as it goes. p() looks names up in
    locals(), so the local names below are part of the printed output.
    """
    assert len(B) == PUB_KEY_LEN
    assert len(ID) == IDENTITY_LEN

    # Ephemeral keypair for this handshake.
    (x,X) = keygen()
    p(["x", "X"], locals())
    p(["msg", "verification"], locals())
    Bx = curve25519(B, x)
    secret_input_phase1 = Bx + ID + X + B + PROTOID + encapsulate(verification)

    # Phase-1 key stream: the encryption key first, then the MAC key.
    phase1_keys = kdf_phase1(secret_input_phase1)
    enc_key = phase1_keys.take(ENC_KEY_LEN)
    mac_key = phase1_keys.take(MAC_KEY_LEN)
    p(["enc_key", "mac_key"], locals())

    msg_0 = ID + B + X + enc(msg, enc_key)
    # NOTE: this local shadows the module-level mac() function for the rest
    # of the body. The name is kept because p() prints it as "mac".
    mac = mac_phase1(msg_0, mac_key)
    p(["mac"], locals())

    client_handshake = msg_0 + mac
    # The MAC is kept so that client_phase2 can include it in the AUTH input.
    state = dict(x=x, X=X, B=B, ID=ID, Bx=Bx, mac=mac, verification=verification)

    p(["client_handshake"], locals())

    return (client_handshake, state)
    151 
    152 # server.
    153 
    154 class Reject(Exception):
    155    pass
    156 
    157 def server_part1(cmsg, verification, b, B, ID):
    158    assert len(B) == PUB_KEY_LEN
    159    assert len(ID) == IDENTITY_LEN
    160    assert len(b) == SEC_KEY_LEN
    161 
    162    if len(cmsg) < (IDENTITY_LEN + PUB_KEY_LEN * 2 + MAC_LEN):
    163        raise Reject()
    164 
    165    mac_covered_portion = cmsg[0:-MAC_LEN]
    166    cmsg = ByteSeq(cmsg)
    167    cmsg_id = cmsg.take(IDENTITY_LEN)
    168    cmsg_B = cmsg.take(PUB_KEY_LEN)
    169    cmsg_X = cmsg.take(PUB_KEY_LEN)
    170    cmsg_msg = cmsg.take(cmsg.remaining() - MAC_LEN)
    171    cmsg_mac = cmsg.take(MAC_LEN)
    172 
    173    assert cmsg.exhausted()
    174 
    175    # XXXX for real purposes, you would use constant-time checks here
    176    if cmsg_id != ID or cmsg_B != B:
    177        raise Reject()
    178 
    179    Xb = curve25519(cmsg_X, b)
    180    secret_input_phase1 = Xb + ID + cmsg_X + B + PROTOID + encapsulate(verification)
    181 
    182    phase1_keys = kdf_phase1(secret_input_phase1)
    183    enc_key = phase1_keys.take(ENC_KEY_LEN)
    184    mac_key = phase1_keys.take(MAC_KEY_LEN)
    185 
    186    mac_received = mac_phase1(mac_covered_portion, mac_key)
    187    if mac_received != cmsg_mac:
    188        raise Reject()
    189 
    190    client_msg = enc(cmsg_msg, enc_key)
    191    state = dict(
    192        b=b,
    193        B=B,
    194        X=cmsg_X,
    195        mac_received=mac_received,
    196        Xb=Xb,
    197        ID=ID,
    198        verification=verification)
    199 
    200    return (client_msg, state)
    201 
    202 def server_part2(state, server_msg):
    203    X = state['X']
    204    Xb = state['Xb']
    205    B = state['B']
    206    b = state['b']
    207    ID = state['ID']
    208    mac_received = state['mac_received']
    209    verification = state['verification']
    210 
    211    p(["server_msg"], locals())
    212    
    213    (y,Y) = keygen()
    214    p(["y", "Y"], locals())
    215    Xy = curve25519(X, y)
    216 
    217    secret_input = Xy + Xb + ID + B + X + Y + PROTOID + encapsulate(verification)
    218    key_seed = h_key_seed(secret_input)
    219    verify = h_verify(secret_input)
    220    p(["key_seed", "verify"], locals())
    221 
    222    keys = kdf_final(key_seed)
    223    server_enc_key = keys.take(ENC_KEY_LEN)
    224    p(["server_enc_key"], locals())
    225 
    226    smsg_msg = enc(server_msg, server_enc_key)
    227 
    228    auth_input = verify + ID + B + Y + X + mac_received + encapsulate(smsg_msg) + PROTOID + b"Server"
    229 
    230    auth = h_auth(auth_input)
    231    server_handshake = Y + auth + smsg_msg
    232    p(["auth", "server_handshake"], locals())
    233 
    234    return (server_handshake, keys)
    235 
    236 def client_phase2(state, smsg):
    237    x = state['x']
    238    X = state['X']
    239    B = state['B']
    240    ID = state['ID']
    241    Bx = state['Bx']
    242    mac_sent = state['mac']
    243    verification = state['verification']
    244 
    245    if len(smsg) < PUB_KEY_LEN + DIGEST_LEN:
    246        raise Reject()
    247 
    248    smsg = ByteSeq(smsg)
    249    Y = smsg.take(PUB_KEY_LEN)
    250    auth_received = smsg.take(DIGEST_LEN)
    251    server_msg = smsg.take(smsg.remaining())
    252 
    253    Yx = curve25519(Y,x)
    254 
    255    secret_input = Yx + Bx + ID + B + X + Y + PROTOID + encapsulate(verification)
    256    key_seed = h_key_seed(secret_input)
    257    verify = h_verify(secret_input)
    258 
    259    auth_input = verify + ID + B + Y + X + mac_sent + encapsulate(server_msg) + PROTOID + b"Server"
    260 
    261    auth = h_auth(auth_input)
    262    if auth != auth_received:
    263        raise Reject()
    264 
    265    keys = kdf_final(key_seed)
    266    enc_key = keys.take(ENC_KEY_LEN)
    267 
    268    server_msg_decrypted = enc(server_msg, enc_key)
    269 
    270    return (keys, server_msg_decrypted)
    271 
    272 def p(varnames, localvars):
    273    for v in varnames:
    274        label = v
    275        val = localvars[label]
    276        print('{} = "{}"'.format(label, binascii.b2a_hex(val).decode("ascii")))
    277 
    278 def test():
    279    (b,B) = keygen()
    280    ID = os.urandom(IDENTITY_LEN)
    281 
    282    p(["b", "B", "ID"], locals())
    283 
    284    print("# ============")
    285    (c_handshake, c_state) = client_phase1(b"hello world", b"xyzzy", B, ID)
    286 
    287    print("# ============")
    288 
    289    (c_msg_got, s_state) = server_part1(c_handshake, b"xyzzy", b, B, ID)
    290 
    291    #print(repr(c_msg_got))
    292 
    293    (s_handshake, s_keys) = server_part2(s_state, b"Hola Mundo")
    294 
    295    print("# ============")
    296 
    297    (c_keys, s_msg_got) = client_phase2(c_state, s_handshake)
    298 
    299    #print(repr(s_msg_got))
    300 
    301    c_keys_256 = c_keys.take(256)
    302    p(["c_keys_256"], locals())
    303 
    304    assert (c_keys_256 == s_keys.take(256))
    305 
    306 
# Running this file directly performs one handshake and prints its
# intermediate values as test vectors.
if __name__ == '__main__':
    test()