jwt_helper.py (3785B)
import json
import base64
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding


def decode_jwt(token, key=None):
    """Decode a JWT and verify its RS256 signature.

    If ``key`` (a JWK, as a dict or JSON string) is provided, it is used for
    signature verification. Otherwise the ``jwk`` entry of the token's own
    header is used. In that second mode a successful result only proves that
    the token was signed by the holder of the key embedded in the token; the
    caller is responsible for deciding whether that key is trusted.

    Returns a tuple ``(decoded_header, decoded_payload, verify_succeeded)``.
    On any failure (token is not three dot-separated segments, a segment does
    not decode, the header or payload is empty, the key is unsupported, or the
    signature does not verify) the result is ``(None, None, False)``.
    """
    try:
        # Decode the header and payload. Unpacking raises if the token does
        # not consist of exactly three dot-separated segments.
        header, payload, signature = token.split('.')
        decoded_header = decode_base64_json(header)
        decoded_payload = decode_base64_json(payload)

        # If decoding produced nothing usable, return nothing.
        if not decoded_header or not decoded_payload:
            return None, None, False

        # If there is a key passed in (for refresh), use that for checking the
        # signature below. Otherwise (for registration), use the key sent
        # within the JWT to check the signature.
        if key is None:
            key = decoded_header.get('jwk')
        public_key = serialization.load_pem_public_key(jwk_to_pem(key))

        # Verifying the signature will throw an exception if it fails.
        verify_rs256_signature(header, payload, signature, public_key)
        return decoded_header, decoded_payload, True
    except Exception:
        # Deliberate catch-all: every decoding or verification problem is
        # reported to the caller uniformly as "not verified".
        return None, None, False


def jwk_to_pem(jwk_data):
    """Convert an RSA public JWK into PEM (SubjectPublicKeyInfo) bytes.

    ``jwk_data`` may be a dict or a JSON string holding the JWK.

    Raises ``ValueError`` if the key type (``kty``) is not ``"RSA"`` and
    ``KeyError`` if the ``n`` or ``e`` member is missing.
    """
    jwk = json.loads(jwk_data) if isinstance(jwk_data, str) else jwk_data
    key_type = jwk.get("kty")

    if key_type != "RSA":
        raise ValueError(f"Unsupported key type: {key_type}")

    # The modulus and exponent are base64url-encoded big-endian integers.
    n = int.from_bytes(decode_base64url(jwk["n"]), 'big')
    e = int.from_bytes(decode_base64url(jwk["e"]), 'big')
    public_key = rsa.RSAPublicNumbers(e, n).public_key()
    return public_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )


def verify_rs256_signature(encoded_header, encoded_payload, signature, public_key):
    """Verify an RS256 (PKCS#1 v1.5 with SHA-256) JWT signature.

    The signed message is the still-encoded ``header.payload`` string, and
    ``signature`` is the base64url-encoded third segment of the token.

    Returns ``None`` on success; an exception is raised if the signature
    cannot be decoded or does not verify.
    """
    message = f'{encoded_header}.{encoded_payload}'.encode('utf-8')
    signature_bytes = decode_base64(signature)
    # This will throw an exception if verification fails.
    public_key.verify(
        signature_bytes,
        message,
        padding.PKCS1v15(),
        hashes.SHA256(),
    )


def add_base64_padding(encoded_data):
    """Return ``encoded_data`` with ``=`` appended up to a multiple of four."""
    remainder = len(encoded_data) % 4
    if remainder > 0:
        encoded_data += '=' * (4 - remainder)
    return encoded_data


def decode_base64url(encoded_data):
    """Strictly decode a base64url string, with or without padding, to bytes.

    Both the URL-safe (``-``/``_``) and the standard (``+``/``/``) alphabets
    are accepted. Any other character raises ``binascii.Error`` instead of
    being silently dropped, so a given byte string cannot be smuggled in
    under several different spellings.
    """
    encoded_data = add_base64_padding(encoded_data)
    encoded_data = encoded_data.replace("-", "+").replace("_", "/")
    # validate=True rejects characters outside the alphabet; the default
    # lenient mode would discard them and decode whatever is left.
    return base64.b64decode(encoded_data, validate=True)


def decode_base64(encoded_data):
    """Decode a base64url string to bytes (same as ``decode_base64url``)."""
    return decode_base64url(encoded_data)


def decode_base64_json(encoded_data):
    """Decode a base64url string and parse the result as JSON."""
    return json.loads(decode_base64(encoded_data))


def thumbprint_for_jwk(jwk):
    """Compute a SHA-256 thumbprint of an RSA or EC JWK.

    Only the identifying members of the key (``e``, ``kty``, ``n`` for RSA;
    ``crv``, ``kty``, ``x``, ``y`` for EC) are serialized, as compact JSON
    with sorted keys, then hashed. This appears to follow the RFC 7638 JWK
    thumbprint construction.

    Returns the unpadded base64url digest as a ``str``, or ``None`` when the
    key type is neither ``"RSA"`` nor ``"EC"``. Raises ``KeyError`` if ``kty``
    or one of the required members is missing.
    """
    key_type = jwk['kty']
    if key_type == 'RSA':
        members = ('kty', 'n', 'e')
    elif key_type == 'EC':
        members = ('kty', 'crv', 'x', 'y')
    else:
        return None

    # Drop every optional member (kid, use, alg, ...) so that the same key
    # always produces the same thumbprint.
    filtered_jwk = {member: jwk[member] for member in members}
    serialized_jwk = json.dumps(filtered_jwk, sort_keys=True, separators=(',', ':'))

    digest = hashes.Hash(hashes.SHA256())
    digest.update(serialized_jwk.encode("utf-8"))

    thumbprint_base64 = base64.urlsafe_b64encode(digest.finalize()).rstrip(b"=")
    return thumbprint_base64.decode('ascii')