#!/usr/bin/python
# ntor_v3_ref.py

"""Reference implementation of the "ntor3-curve25519-sha3_256-1" handshake."""

import binascii
import hashlib
import os
import struct

import donna25519
from Crypto.Cipher import AES
from Crypto.Util import Counter

# Define basic wrappers.

DIGEST_LEN = 32
ENC_KEY_LEN = 32
PUB_KEY_LEN = 32
SEC_KEY_LEN = 32
IDENTITY_LEN = 32


def sha3_256(s):
    """Return the DIGEST_LEN-byte SHA3-256 digest of `s`."""
    d = hashlib.sha3_256(s).digest()
    assert len(d) == DIGEST_LEN
    return d


def shake_256(s):
    """Return a fixed 1024 bytes of SHAKE-256 output for `s`."""
    # Note: In reality, you wouldn't want to generate more bytes than needed.
    MAX_KEY_BYTES = 1024
    return hashlib.shake_256(s).digest(MAX_KEY_BYTES)


def curve25519(pk, sk):
    """Return the donna25519 Diffie-Hellman output of public `pk` and secret `sk`."""
    assert len(pk) == PUB_KEY_LEN
    assert len(sk) == SEC_KEY_LEN
    private = donna25519.PrivateKey.load(sk)
    public = donna25519.PublicKey(pk)
    return private.do_exchange(public)


def keygen():
    """Generate a fresh key pair and return it as (secret_key, public_key).

    Both values are whatever donna25519 exposes as `.private` / `.public`
    (presumably raw 32-byte strings -- they are fed back into curve25519()).
    """
    private = donna25519.PrivateKey()
    public = private.get_public()
    return (private.private, public.public)


def aes256_ctr(k, s):
    """Encrypt (equivalently, decrypt) `s` with AES-256-CTR under key `k`.

    The 128-bit counter always starts at 0, so a given key must never be
    used for more than one message.
    """
    assert len(k) == ENC_KEY_LEN
    cipher = AES.new(k, AES.MODE_CTR, counter=Counter.new(128, initial_value=0))
    return cipher.encrypt(s)

# Byte-oriented helper. We use this for decoding keystreams and messages.


class ByteSeq:
    """A byte string that is consumed from the front as values are taken."""

    def __init__(self, data):
        self.data = data

    def take(self, n):
        """Remove and return the next `n` bytes.

        Raises:
          ValueError: if `n` is negative or larger than the bytes remaining.
            (An `assert` would vanish under `python -O`, and a negative `n`
            would otherwise silently slice from the end.)
        """
        if not 0 <= n <= len(self.data):
            raise ValueError(
                "cannot take {} bytes; {} remaining".format(n, len(self.data)))
        result = self.data[:n]
        self.data = self.data[n:]
        return result

    def exhausted(self):
        """Return True when every byte has been taken."""
        return len(self.data) == 0

    def remaining(self):
        """Return how many bytes have not been taken yet."""
        return len(self.data)

# Low-level functions

MAC_KEY_LEN = 32
MAC_LEN = DIGEST_LEN

hash_func = sha3_256


def encapsulate(s):
    """encapsulate `s` with a length prefix.

    The prefix is len(s) as an 8-byte big-endian unsigned integer. We use
    this whenever we need to avoid message ambiguities in cryptographic
    inputs.
    """
    # "!Q" emits the same 8 bytes as the old b"\0\0\0\0" + pack("!L", n) for
    # every n < 2**32, and also handles longer inputs.
    header = struct.pack("!Q", len(s))
    return header + s


def h(s, tweak):
    """Tweaked hash: H(ENCAP(tweak) | s)."""
    return hash_func(encapsulate(tweak) + s)


def mac(s, key, tweak):
    """Tweaked MAC: H(ENCAP(tweak) | ENCAP(key) | s)."""
    return hash_func(encapsulate(tweak) + encapsulate(key) + s)


def kdf(s, tweak):
    """Tweaked KDF: SHAKE-256(ENCAP(tweak) | s), returned as a ByteSeq."""
    data = shake_256(encapsulate(tweak) + s)
    return ByteSeq(data)


def enc(s, k):
    """Stream-encrypt (or decrypt) `s` with key `k`."""
    return aes256_ctr(k, s)

# Tweaked wrappers

PROTOID = b"ntor3-curve25519-sha3_256-1"
T_KDF_PHASE1 = PROTOID + b":kdf_phase1"
T_MAC_PHASE1 = PROTOID + b":msg_mac"
T_KDF_FINAL = PROTOID + b":kdf_final"
T_KEY_SEED = PROTOID + b":key_seed"
T_VERIFY = PROTOID + b":verify"
T_AUTH = PROTOID + b":auth_final"


def kdf_phase1(s):
    return kdf(s, T_KDF_PHASE1)


def kdf_final(s):
    return kdf(s, T_KDF_FINAL)


def mac_phase1(s, key):
    return mac(s, key, T_MAC_PHASE1)


def h_key_seed(s):
    return h(s, T_KEY_SEED)


def h_verify(s):
    return h(s, T_VERIFY)


def h_auth(s):
    return h(s, T_AUTH)

# Handshake.


def client_phase1(msg, verification, B, ID):
    """Build the client's first handshake message.

    Args:
      msg: plaintext message to send (encrypted) to the server.
      verification: verification string; the server must use the same one.
      B: the server's curve25519 public key.
      ID: the server's identity.

    Returns:
      (client_handshake, state): the bytes ID | B | X | ENC(msg) | MAC, and a
      dict that client_phase2() needs to finish the handshake.
    """
    assert len(B) == PUB_KEY_LEN
    assert len(ID) == IDENTITY_LEN

    (x, X) = keygen()
    p(["x", "X"], locals())
    p(["msg", "verification"], locals())
    Bx = curve25519(B, x)
    secret_input_phase1 = Bx + ID + X + B + PROTOID + encapsulate(verification)

    phase1_keys = kdf_phase1(secret_input_phase1)
    enc_key = phase1_keys.take(ENC_KEY_LEN)
    mac_key = phase1_keys.take(MAC_KEY_LEN)
    p(["enc_key", "mac_key"], locals())

    msg_0 = ID + B + X + enc(msg, enc_key)
    # Named msg_mac so the module-level mac() function is not shadowed; it is
    # still printed under the label "mac".
    msg_mac = mac_phase1(msg_0, mac_key)
    p(["mac"], {"mac": msg_mac})

    client_handshake = msg_0 + msg_mac
    state = dict(x=x, X=X, B=B, ID=ID, Bx=Bx, mac=msg_mac, verification=verification)

    p(["client_handshake"], locals())

    return (client_handshake, state)

# server.
class Reject(Exception):
    """Raised when a received handshake message fails validation."""
    pass


def _ct_equal(a, b):
    """Compare two byte strings with hmac.compare_digest (no early exit)."""
    import hmac  # local import keeps this section independent of the header
    return hmac.compare_digest(a, b)


def _reject_if_zero(shared):
    """Raise Reject if a curve25519 output is all zero bytes.

    An all-zero output means the peer's public key was a low-order point and
    the exchange contributes no secret; RFC 7748 section 6.1 permits aborting.
    """
    if _ct_equal(shared, bytes(len(shared))):
        raise Reject()


def server_part1(cmsg, verification, b, B, ID):
    """Validate and decrypt the client's handshake message.

    Args:
      cmsg: client handshake bytes: ID | B | X | encrypted message | MAC.
      verification: verification string; must match the client's.
      b, B: the server's curve25519 secret and public keys.
      ID: the server's identity.

    Returns:
      (client_msg, state): the decrypted client message, and a dict to pass
      to server_part2().

    Raises:
      Reject: if cmsg is too short, names a different ID or B, carries a
        degenerate X, or has a bad MAC.
    """
    assert len(B) == PUB_KEY_LEN
    assert len(ID) == IDENTITY_LEN
    assert len(b) == SEC_KEY_LEN

    if len(cmsg) < (IDENTITY_LEN + PUB_KEY_LEN * 2 + MAC_LEN):
        raise Reject()

    mac_covered_portion = cmsg[0:-MAC_LEN]
    cmsg = ByteSeq(cmsg)
    cmsg_id = cmsg.take(IDENTITY_LEN)
    cmsg_B = cmsg.take(PUB_KEY_LEN)
    cmsg_X = cmsg.take(PUB_KEY_LEN)
    cmsg_msg = cmsg.take(cmsg.remaining() - MAC_LEN)
    cmsg_mac = cmsg.take(MAC_LEN)

    assert cmsg.exhausted()

    if not _ct_equal(cmsg_id, ID) or not _ct_equal(cmsg_B, B):
        raise Reject()

    Xb = curve25519(cmsg_X, b)
    _reject_if_zero(Xb)
    secret_input_phase1 = Xb + ID + cmsg_X + B + PROTOID + encapsulate(verification)

    phase1_keys = kdf_phase1(secret_input_phase1)
    enc_key = phase1_keys.take(ENC_KEY_LEN)
    mac_key = phase1_keys.take(MAC_KEY_LEN)

    # Recompute the MAC locally and compare in constant time so the check
    # does not reveal how many leading bytes of a forged MAC were correct.
    mac_computed = mac_phase1(mac_covered_portion, mac_key)
    if not _ct_equal(mac_computed, cmsg_mac):
        raise Reject()

    client_msg = enc(cmsg_msg, enc_key)
    state = dict(
        b=b,
        B=B,
        X=cmsg_X,
        mac_received=mac_computed,
        Xb=Xb,
        ID=ID,
        verification=verification)

    return (client_msg, state)


def server_part2(state, server_msg):
    """Build the server's reply.

    Args:
      state: the dict returned by server_part1().
      server_msg: plaintext message to send (encrypted) to the client.

    Returns:
      (server_handshake, keys): the bytes Y | AUTH | ENC(server_msg), and the
      final-KDF ByteSeq with its first ENC_KEY_LEN bytes already consumed.
    """
    X = state['X']
    Xb = state['Xb']
    B = state['B']
    ID = state['ID']
    mac_received = state['mac_received']
    verification = state['verification']

    p(["server_msg"], locals())

    (y, Y) = keygen()
    p(["y", "Y"], locals())
    Xy = curve25519(X, y)

    secret_input = Xy + Xb + ID + B + X + Y + PROTOID + encapsulate(verification)
    key_seed = h_key_seed(secret_input)
    verify = h_verify(secret_input)
    p(["key_seed", "verify"], locals())

    keys = kdf_final(key_seed)
    server_enc_key = keys.take(ENC_KEY_LEN)
    p(["server_enc_key"], locals())

    smsg_msg = enc(server_msg, server_enc_key)

    auth_input = verify + ID + B + Y + X + mac_received + encapsulate(smsg_msg) + PROTOID + b"Server"

    auth = h_auth(auth_input)
    server_handshake = Y + auth + smsg_msg
    p(["auth", "server_handshake"], locals())

    return (server_handshake, keys)


def client_phase2(state, smsg):
    """Verify the server's reply and derive the final keys.

    Args:
      state: the dict returned by client_phase1().
      smsg: server handshake bytes: Y | AUTH | encrypted server message.

    Returns:
      (keys, server_msg_decrypted): the final-KDF ByteSeq with its first
      ENC_KEY_LEN bytes already consumed, and the decrypted server message.

    Raises:
      Reject: if smsg is too short, carries a degenerate Y, or fails auth.
    """
    x = state['x']
    X = state['X']
    B = state['B']
    ID = state['ID']
    Bx = state['Bx']
    mac_sent = state['mac']
    verification = state['verification']

    if len(smsg) < PUB_KEY_LEN + DIGEST_LEN:
        raise Reject()

    smsg = ByteSeq(smsg)
    Y = smsg.take(PUB_KEY_LEN)
    auth_received = smsg.take(DIGEST_LEN)
    server_msg = smsg.take(smsg.remaining())

    Yx = curve25519(Y, x)
    _reject_if_zero(Yx)

    secret_input = Yx + Bx + ID + B + X + Y + PROTOID + encapsulate(verification)
    key_seed = h_key_seed(secret_input)
    verify = h_verify(secret_input)

    auth_input = verify + ID + B + Y + X + mac_sent + encapsulate(server_msg) + PROTOID + b"Server"

    auth = h_auth(auth_input)
    if not _ct_equal(auth, auth_received):
        raise Reject()

    keys = kdf_final(key_seed)
    enc_key = keys.take(ENC_KEY_LEN)

    server_msg_decrypted = enc(server_msg, enc_key)

    return (keys, server_msg_decrypted)


def p(varnames, localvars):
    """Print each named byte-string value from `localvars` as hex.

    Output format: `name = "<hex>"`, one line per name.
    """
    for v in varnames:
        label = v
        val = localvars[label]
        print('{} = "{}"'.format(label, binascii.b2a_hex(val).decode("ascii")))


def test():
    """Run one full handshake, print the values, and check both sides agree."""
    (b, B) = keygen()
    ID = os.urandom(IDENTITY_LEN)

    p(["b", "B", "ID"], locals())

    print("# ============")
    (c_handshake, c_state) = client_phase1(b"hello world", b"xyzzy", B, ID)

    print("# ============")

    (c_msg_got, s_state) = server_part1(c_handshake, b"xyzzy", b, B, ID)
    assert c_msg_got == b"hello world"

    (s_handshake, s_keys) = server_part2(s_state, b"Hola Mundo")

    print("# ============")

    (c_keys, s_msg_got) = client_phase2(c_state, s_handshake)
    assert s_msg_got == b"Hola Mundo"

    c_keys_256 = c_keys.take(256)
    p(["c_keys_256"], locals())

    assert (c_keys_256 == s_keys.take(256))

    # A single flipped MAC bit must make the server reject the handshake.
    tampered = bytearray(c_handshake)
    tampered[-1] ^= 1
    try:
        server_part1(bytes(tampered), b"xyzzy", b, B, ID)
    except Reject:
        pass
    else:
        raise AssertionError("tampered client handshake was accepted")


if __name__ == '__main__':
    test()