tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

jwt_helper.py (3785B)


      1 import json
      2 import base64
      3 from cryptography.hazmat.primitives import serialization
      4 from cryptography.hazmat.primitives import hashes
      5 from cryptography.hazmat.primitives.asymmetric import rsa, padding
      6 
      7 # This method decodes the JWT and verifies the signature. If a key is provided,
      8 # that will be used for signature verification. Otherwise, the key sent within
      9 # the JWT payload will be used instead.
     10 # This returns a tuple of (decoded_header, decoded_payload, verify_succeeded).
     11 def decode_jwt(token, key=None):
     12    try:
     13        # Decode the header and payload.
     14        header, payload, signature = token.split('.')
     15        decoded_header = decode_base64_json(header)
     16        decoded_payload = decode_base64_json(payload)
     17 
     18        # If decoding failed, return nothing.
     19        if not decoded_header or not decoded_payload:
     20            return None, None, False
     21 
     22        # If there is a key passed in (for refresh), use that for checking the signature below.
     23        # Otherwise (for registration), use the key sent within the JWT to check the signature.
     24        if key == None:
     25            key = decoded_header.get('jwk')
     26        public_key = serialization.load_pem_public_key(jwk_to_pem(key))
     27        # Verifying the signature will throw an exception if it fails.
     28        verify_rs256_signature(header, payload, signature, public_key)
     29        return decoded_header, decoded_payload, True
     30    except Exception:
     31        return None, None, False
     32 
     33 def jwk_to_pem(jwk_data):
     34    jwk = json.loads(jwk_data) if isinstance(jwk_data, str) else jwk_data
     35    key_type = jwk.get("kty")
     36 
     37    if key_type != "RSA":
     38        raise ValueError(f"Unsupported key type: {key_type}")
     39 
     40    n = int.from_bytes(decode_base64url(jwk["n"]), 'big')
     41    e = int.from_bytes(decode_base64url(jwk["e"]), 'big')
     42    public_key = rsa.RSAPublicNumbers(e, n).public_key()
     43    pem_public_key = public_key.public_bytes(
     44        encoding=serialization.Encoding.PEM,
     45        format=serialization.PublicFormat.SubjectPublicKeyInfo
     46    )
     47    return pem_public_key
     48 
     49 def verify_rs256_signature(encoded_header, encoded_payload, signature, public_key):
     50    message = (f'{encoded_header}.{encoded_payload}').encode('utf-8')
     51    signature_bytes = decode_base64(signature)
     52    # This will throw an exception if verification fails.
     53    public_key.verify(
     54        signature_bytes,
     55        message,
     56        padding.PKCS1v15(),
     57        hashes.SHA256()
     58    )
     59 
     60 def add_base64_padding(encoded_data):
     61    remainder = len(encoded_data) % 4
     62    if remainder > 0:
     63        encoded_data += '=' * (4 - remainder)
     64    return encoded_data
     65 
     66 def decode_base64url(encoded_data):
     67    encoded_data = add_base64_padding(encoded_data)
     68    encoded_data = encoded_data.replace("-", "+").replace("_", "/")
     69    return base64.b64decode(encoded_data)
     70 
     71 def decode_base64(encoded_data):
     72    encoded_data = add_base64_padding(encoded_data)
     73    return base64.urlsafe_b64decode(encoded_data)
     74 
     75 def decode_base64_json(encoded_data):
     76    return json.loads(decode_base64(encoded_data))
     77 
     78 def thumbprint_for_jwk(jwk):
     79    filtered_jwk = None
     80    if jwk['kty'] == 'RSA':
     81        filtered_jwk = dict()
     82        filtered_jwk['kty'] = jwk['kty']
     83        filtered_jwk['n'] = jwk['n']
     84        filtered_jwk['e'] = jwk['e']
     85    elif jwk['kty'] == 'EC':
     86        filtered_jwk = dict()
     87        filtered_jwk['kty'] = jwk['kty']
     88        filtered_jwk['crv'] = jwk['crv']
     89        filtered_jwk['x'] = jwk['x']
     90        filtered_jwk['y'] = jwk['y']
     91    else:
     92        return None
     93 
     94    serialized_jwk = json.dumps(filtered_jwk, sort_keys=True, separators=(',',':'))
     95 
     96    digest = hashes.Hash(hashes.SHA256())
     97    digest.update(serialized_jwk.encode("utf-8"))
     98 
     99    thumbprint_base64 = base64.b64encode(digest.finalize(), altchars=b"-_").rstrip(b"=")
    100    return thumbprint_base64.decode('ascii')