tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

iceserver.py (36096B)


      1 # vim: set ts=4 et sw=4 tw=80
      2 # This Source Code Form is subject to the terms of the Mozilla Public
      3 # License, v. 2.0. If a copy of the MPL was not distributed with this
      4 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
      5 
      6 import ipaddr
      7 import socket
      8 import hmac
      9 import hashlib
     10 import passlib.utils  # for saslprep
     11 import copy
     12 import random
     13 import operator
     14 import os
     15 import platform
     16 import six
     17 import string
     18 import time
     19 from functools import reduce
     20 from string import Template
     21 from twisted.internet import reactor, protocol
     22 from twisted.internet.task import LoopingCall
     23 from twisted.internet.address import IPv4Address
     24 from twisted.internet.address import IPv6Address
     25 
     26 MAGIC_COOKIE = 0x2112A442
     27 
     28 REQUEST = 0
     29 INDICATION = 1
     30 SUCCESS_RESPONSE = 2
     31 ERROR_RESPONSE = 3
     32 
     33 BINDING = 0x001
     34 ALLOCATE = 0x003
     35 REFRESH = 0x004
     36 SEND = 0x006
     37 DATA_MSG = 0x007
     38 CREATE_PERMISSION = 0x008
     39 CHANNEL_BIND = 0x009
     40 
     41 # STUN spec chose silly values for these
     42 STUN_IPV4 = 1
     43 STUN_IPV6 = 2
     44 
     45 MAPPED_ADDRESS = 0x0001
     46 USERNAME = 0x0006
     47 MESSAGE_INTEGRITY = 0x0008
     48 ERROR_CODE = 0x0009
     49 UNKNOWN_ATTRIBUTES = 0x000A
     50 LIFETIME = 0x000D
     51 DATA_ATTR = 0x0013
     52 XOR_PEER_ADDRESS = 0x0012
     53 REALM = 0x0014
     54 NONCE = 0x0015
     55 XOR_RELAYED_ADDRESS = 0x0016
     56 REQUESTED_TRANSPORT = 0x0019
     57 DONT_FRAGMENT = 0x001A
     58 XOR_MAPPED_ADDRESS = 0x0020
     59 SOFTWARE = 0x8022
     60 ALTERNATE_SERVER = 0x8023
     61 FINGERPRINT = 0x8028
     62 
     63 STUN_PORT = 3478
     64 STUNS_PORT = 5349
     65 
     66 TURN_REDIRECT_PORT = 3479
     67 TURNS_REDIRECT_PORT = 5350
     68 
     69 
     70 def unpack_uint(bytes_buf):
     71    result = 0
     72    for byte in bytes_buf:
     73        result = (result << 8) + byte
     74    return result
     75 
     76 
     77 def pack_uint(value, width):
     78    if value < 0:
     79        raise ValueError("Invalid value: {}".format(value))
     80    buf = bytearray([0] * width)
     81    for i in range(0, width):
     82        buf[i] = (value >> (8 * (width - i - 1))) & 0xFF
     83 
     84    return buf
     85 
     86 
     87 def unpack(bytes_buf, format_array):
     88    results = ()
     89    for width in format_array:
     90        results = results + (unpack_uint(bytes_buf[0:width]),)
     91        bytes_buf = bytes_buf[width:]
     92    return results
     93 
     94 
     95 def pack(values, format_array):
     96    if len(values) != len(format_array):
     97        raise ValueError()
     98    buf = bytearray()
     99    for i in range(0, len(values)):
    100        buf.extend(pack_uint(values[i], format_array[i]))
    101    return buf
    102 
    103 
    104 def bitwise_pack(source, dest, start_bit, num_bits):
    105    if num_bits <= 0 or num_bits > start_bit + 1:
    106        raise ValueError(
    107            "Invalid num_bits: {}, start_bit = {}".format(num_bits, start_bit)
    108        )
    109    last_bit = start_bit - num_bits + 1
    110    source = source >> last_bit
    111    dest = dest << num_bits
    112    mask = (1 << num_bits) - 1
    113    dest += source & mask
    114    return dest
    115 
    116 
    117 def to_ipaddress(protocol, host, port):
    118    if ":" not in host:
    119        return IPv4Address(protocol, host, port)
    120 
    121    return IPv6Address(protocol, host, port)
    122 
    123 
    124 class StunAttribute(object):
    125    """
    126    Represents a STUN attribute in a raw format, according to the following:
    127 
    128     0                   1                   2                   3
    129     0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    130    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    131    |   StunAttribute.attr_type     |  Length (derived as needed)   |
    132    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    133    |           StunAttribute.data (variable length)             ....
    134    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    135    """
    136 
    137    __attr_header_fmt = [2, 2]
    138    __attr_header_size = reduce(operator.add, __attr_header_fmt)
    139 
    140    def __init__(self, attr_type=0, buf=bytearray()):
    141        self.attr_type = attr_type
    142        self.data = buf
    143 
    144    def build(self):
    145        buf = pack((self.attr_type, len(self.data)), self.__attr_header_fmt)
    146        buf.extend(self.data)
    147        # add padding if necessary
    148        if len(buf) % 4:
    149            buf.extend([0] * (4 - (len(buf) % 4)))
    150        return buf
    151 
    152    def parse(self, buf):
    153        if self.__attr_header_size > len(buf):
    154            raise Exception("truncated at attribute: incomplete header")
    155 
    156        self.attr_type, length = unpack(buf, self.__attr_header_fmt)
    157        length += self.__attr_header_size
    158 
    159        if length > len(buf):
    160            raise Exception("truncated at attribute: incomplete contents")
    161 
    162        self.data = buf[self.__attr_header_size : length]
    163 
    164        # verify padding
    165        while length % 4:
    166            if buf[length]:
    167                raise ValueError("Non-zero padding")
    168            length += 1
    169 
    170        return length
    171 
    172 
    173 class StunMessage(object):
    174    """
    175    Represents a STUN message. Contains a method, msg_class, cookie,
    176    transaction_id, and attributes (as an array of StunAttribute).
    177 
    178    Has various functions for getting/adding attributes.
    179    """
    180 
    181    def __init__(self):
    182        self.method = 0
    183        self.msg_class = 0
    184        self.cookie = MAGIC_COOKIE
    185        self.transaction_id = 0
    186        self.attributes = []
    187 
    188    #      0                   1                   2                   3
    189    #      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    190    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    191    #     |0 0|M M M M M|C|M M M|C|M M M M|         Message Length        |
    192    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    193    #     |                         Magic Cookie                          |
    194    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    195    #     |                                                               |
    196    #     |                     Transaction ID (96 bits)                  |
    197    #     |                                                               |
    198    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    199    __header_fmt = [2, 2, 4, 12]
    200    __header_size = reduce(operator.add, __header_fmt)
    201 
    202    # Returns how many bytes were parsed if buf was large enough, or how many
    203    # bytes we would have needed if not. Throws if buf is malformed.
    204    def parse(self, buf):
    205        min_buf_size = self.__header_size
    206        if len(buf) < min_buf_size:
    207            return min_buf_size
    208 
    209        message_type, length, cookie, self.transaction_id = unpack(
    210            buf, self.__header_fmt
    211        )
    212        min_buf_size += length
    213        if len(buf) < min_buf_size:
    214            return min_buf_size
    215 
    216        # Avert your eyes...
    217        self.method = bitwise_pack(message_type, 0, 13, 5)
    218        self.msg_class = bitwise_pack(message_type, 0, 8, 1)
    219        self.method = bitwise_pack(message_type, self.method, 7, 3)
    220        self.msg_class = bitwise_pack(message_type, self.msg_class, 4, 1)
    221        self.method = bitwise_pack(message_type, self.method, 3, 4)
    222 
    223        if cookie != self.cookie:
    224            raise Exception("Invalid cookie: {}".format(cookie))
    225 
    226        buf = buf[self.__header_size : min_buf_size]
    227        while len(buf):
    228            attr = StunAttribute()
    229            length = attr.parse(buf)
    230            buf = buf[length:]
    231            self.attributes.append(attr)
    232 
    233        return min_buf_size
    234 
    235    # stop_after_attr_type is useful for calculating MESSAGE-DIGEST
    236    def build(self, stop_after_attr_type=0):
    237        attrs = bytearray()
    238        for attr in self.attributes:
    239            attrs.extend(attr.build())
    240            if attr.attr_type == stop_after_attr_type:
    241                break
    242 
    243        message_type = bitwise_pack(self.method, 0, 11, 5)
    244        message_type = bitwise_pack(self.msg_class, message_type, 1, 1)
    245        message_type = bitwise_pack(self.method, message_type, 6, 3)
    246        message_type = bitwise_pack(self.msg_class, message_type, 0, 1)
    247        message_type = bitwise_pack(self.method, message_type, 3, 4)
    248 
    249        message = pack(
    250            (message_type, len(attrs), self.cookie, self.transaction_id),
    251            self.__header_fmt,
    252        )
    253        message.extend(attrs)
    254 
    255        return message
    256 
    257    def add_error_code(self, code, phrase=None):
    258        #      0                   1                   2                   3
    259        #      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    260        #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    261        #     |           Reserved, should be 0         |Class|     Number    |
    262        #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    263        #     |      Reason Phrase (variable)                                ..
    264        #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    265        error_code_fmt = [3, 1]
    266        error_code = pack((code // 100, code % 100), error_code_fmt)
    267        if phrase != None:
    268            error_code.extend(bytearray(phrase, "utf-8"))
    269        self.attributes.append(StunAttribute(ERROR_CODE, error_code))
    270 
    271    #     0                   1                   2                   3
    272    #     0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    273    #    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    274    #    |x x x x x x x x|    Family     |         X-Port                |
    275    #    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    276    #    |                X-Address (Variable)
    277    #    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    278    __v4addr_fmt = [1, 1, 2, 4]
    279    __v6addr_fmt = [1, 1, 2, 16]
    280    __v4addr_size = reduce(operator.add, __v4addr_fmt)
    281    __v6addr_size = reduce(operator.add, __v6addr_fmt)
    282 
    283    def add_address(self, ip_address, version, port, attr_type):
    284        if version == STUN_IPV4:
    285            address = pack((0, STUN_IPV4, port, ip_address), self.__v4addr_fmt)
    286        elif version == STUN_IPV6:
    287            address = pack((0, STUN_IPV6, port, ip_address), self.__v6addr_fmt)
    288        else:
    289            raise ValueError("Invalid ip version: {}".format(version))
    290        self.attributes.append(StunAttribute(attr_type, address))
    291 
    292    def get_xaddr(self, ip_addr, version):
    293        if version == STUN_IPV4:
    294            return self.cookie ^ ip_addr
    295        elif version == STUN_IPV6:
    296            return ((self.cookie << 96) + self.transaction_id) ^ ip_addr
    297        else:
    298            raise ValueError("Invalid family: {}".format(version))
    299 
    300    def get_xport(self, port):
    301        return (self.cookie >> 16) ^ port
    302 
    303    def add_xor_address(self, addr_port, attr_type):
    304        ip_address = ipaddr.IPAddress(addr_port.host)
    305        version = STUN_IPV6 if ip_address.version == 6 else STUN_IPV4
    306        xaddr = self.get_xaddr(int(ip_address), version)
    307        xport = self.get_xport(addr_port.port)
    308        self.add_address(xaddr, version, xport, attr_type)
    309 
    310    def add_data(self, buf):
    311        self.attributes.append(StunAttribute(DATA_ATTR, buf))
    312 
    313    def find(self, attr_type):
    314        for attr in self.attributes:
    315            if attr.attr_type == attr_type:
    316                return attr
    317        return None
    318 
    319    def get_xor_address(self, attr_type):
    320        addr_attr = self.find(attr_type)
    321        if not addr_attr:
    322            return None
    323 
    324        padding, family, xport, xaddr = unpack(addr_attr.data, self.__v4addr_fmt)
    325        addr_ctor = IPv4Address
    326        if family == STUN_IPV6:
    327            padding, family, xport, xaddr = unpack(addr_attr.data, self.__v6addr_fmt)
    328            addr_ctor = IPv6Address
    329        elif family != STUN_IPV4:
    330            raise ValueError("Invalid family: {}".format(family))
    331 
    332        return addr_ctor(
    333            "UDP",
    334            str(ipaddr.IPAddress(self.get_xaddr(xaddr, family))),
    335            self.get_xport(xport),
    336        )
    337 
    338    def add_nonce(self, nonce):
    339        self.attributes.append(StunAttribute(NONCE, bytearray(nonce, "utf-8")))
    340 
    341    def add_realm(self, realm):
    342        self.attributes.append(StunAttribute(REALM, bytearray(realm, "utf-8")))
    343 
    344    def calculate_message_digest(self, username, realm, password):
    345        digest_buf = self.build(MESSAGE_INTEGRITY)
    346        # Trim off the MESSAGE-INTEGRITY attr
    347        digest_buf = digest_buf[: len(digest_buf) - 24]
    348        password = passlib.utils.saslprep(six.text_type(password))
    349        key_string = "{}:{}:{}".format(username, realm, password)
    350        md5 = hashlib.md5()
    351        md5.update(bytearray(key_string, "utf-8"))
    352        key = md5.digest()
    353        return bytearray(hmac.new(key, digest_buf, hashlib.sha1).digest())
    354 
    355    def add_lifetime(self, lifetime):
    356        self.attributes.append(StunAttribute(LIFETIME, pack_uint(lifetime, 4)))
    357 
    358    def get_lifetime(self):
    359        lifetime_attr = self.find(LIFETIME)
    360        if not lifetime_attr:
    361            return None
    362        return unpack_uint(lifetime_attr.data[0:4])
    363 
    364    def get_username(self):
    365        username = self.find(USERNAME)
    366        if not username:
    367            return None
    368        return str(username.data)
    369 
    370    def add_message_integrity(self, username, realm, password):
    371        dummy_value = bytearray([0] * 20)
    372        self.attributes.append(StunAttribute(MESSAGE_INTEGRITY, dummy_value))
    373        digest = self.calculate_message_digest(username, realm, password)
    374        self.find(MESSAGE_INTEGRITY).data = digest
    375 
    376    def add_alternate_server(self, host, port):
    377        address = ipaddr.IPAddress(host)
    378        version = STUN_IPV6 if address.version == 6 else STUN_IPV4
    379        self.add_address(int(address), version, port, ALTERNATE_SERVER)
    380 
    381 
    382 class Allocation(protocol.DatagramProtocol):
    383    """
    384    Comprises the socket for a TURN allocation, a back-reference to the
    385    transport we will forward received traffic on, the allocator's address and
    386    username, the set of permissions for the allocation, and the allocation's
    387    expiry.
    388    """
    389 
    390    def __init__(self, other_transport_handler, allocator_address, username):
    391        self.permissions = set()  # str, int tuples
    392        # Handler to use when sending stuff that arrives on the allocation
    393        self.other_transport_handler = other_transport_handler
    394        self.allocator_address = allocator_address
    395        self.username = username
    396        self.expiry = time.time()
    397        self.port = reactor.listenUDP(0, self, interface=v4_address)
    398 
    399    def datagramReceived(self, data, address):
    400        host = address[0]
    401        port = address[1]
    402        if not host in self.permissions:
    403            print(
    404                "Dropping packet from {}:{}, no permission on allocation {}".format(
    405                    host, port, self.transport.getHost()
    406                )
    407            )
    408            return
    409 
    410        data_indication = StunMessage()
    411        data_indication.method = DATA_MSG
    412        data_indication.msg_class = INDICATION
    413        data_indication.transaction_id = random.getrandbits(96)
    414 
    415        # Only handles UDP allocations. Doubtful that we need more than this.
    416        data_indication.add_xor_address(
    417            to_ipaddress("UDP", host, port), XOR_PEER_ADDRESS
    418        )
    419        data_indication.add_data(data)
    420 
    421        self.other_transport_handler.write(
    422            data_indication.build(), self.allocator_address
    423        )
    424 
    425    def close(self):
    426        self.port.stopListening()
    427        self.port = None
    428 
    429 
    430 class StunHandler(object):
    431    """
    432    Frames and handles STUN messages. This is the core logic of the TURN
    433    server, along with Allocation.
    434    """
    435 
    436    def __init__(self, transport_handler):
    437        self.client_address = None
    438        self.data = bytearray()
    439        self.transport_handler = transport_handler
    440 
    441    def data_received(self, data, address):
    442        self.data += data
    443        while True:
    444            stun_message = StunMessage()
    445            parsed_len = stun_message.parse(self.data)
    446            if parsed_len > len(self.data):
    447                break
    448            self.data = self.data[parsed_len:]
    449 
    450            response = self.handle_stun(stun_message, address)
    451            if response:
    452                self.transport_handler.write(response, address)
    453 
    454    def handle_stun(self, stun_message, address):
    455        self.client_address = address
    456        if stun_message.msg_class == INDICATION:
    457            if stun_message.method == SEND:
    458                self.handle_send_indication(stun_message)
    459            else:
    460                print(
    461                    "Dropping unknown indication method: {}".format(stun_message.method)
    462                )
    463            return None
    464 
    465        if stun_message.msg_class != REQUEST:
    466            print("Dropping STUN response, method: {}".format(stun_message.method))
    467            return None
    468 
    469        if stun_message.method == BINDING:
    470            return self.make_success_response(stun_message).build()
    471        elif stun_message.method == ALLOCATE:
    472            return self.handle_allocation(stun_message).build()
    473        elif stun_message.method == REFRESH:
    474            return self.handle_refresh(stun_message).build()
    475        elif stun_message.method == CREATE_PERMISSION:
    476            return self.handle_permission(stun_message).build()
    477        else:
    478            return self.make_error_response(
    479                stun_message,
    480                400,
    481                ("Unsupported STUN request, method: {}".format(stun_message.method)),
    482            ).build()
    483 
    484    def get_allocation_tuple(self):
    485        return (
    486            self.client_address.host,
    487            self.client_address.port,
    488            self.transport_handler.transport.getHost().type,
    489            self.transport_handler.transport.getHost().host,
    490            self.transport_handler.transport.getHost().port,
    491        )
    492 
    493    def handle_allocation(self, request):
    494        allocate_response = self.check_long_term_auth(request)
    495        if allocate_response.msg_class == SUCCESS_RESPONSE:
    496            if self.get_allocation_tuple() in allocations:
    497                return self.make_error_response(
    498                    request,
    499                    437,
    500                    (
    501                        "Duplicate allocation request for tuple {}".format(
    502                            self.get_allocation_tuple()
    503                        )
    504                    ),
    505                )
    506 
    507            allocation = Allocation(
    508                self.transport_handler, self.client_address, request.get_username()
    509            )
    510 
    511            allocate_response.add_xor_address(
    512                allocation.transport.getHost(), XOR_RELAYED_ADDRESS
    513            )
    514 
    515            lifetime = request.get_lifetime()
    516            if lifetime == None:
    517                return self.make_error_response(
    518                    request, 400, "Missing lifetime attribute in allocation request"
    519                )
    520 
    521            lifetime = min(lifetime, 3600)
    522            allocate_response.add_lifetime(lifetime)
    523            allocation.expiry = time.time() + lifetime
    524 
    525            allocate_response.add_message_integrity(turn_user, turn_realm, turn_pass)
    526            allocations[self.get_allocation_tuple()] = allocation
    527        return allocate_response
    528 
    529    def handle_refresh(self, request):
    530        refresh_response = self.check_long_term_auth(request)
    531        if refresh_response.msg_class == SUCCESS_RESPONSE:
    532            try:
    533                allocation = allocations[self.get_allocation_tuple()]
    534            except KeyError:
    535                return self.make_error_response(
    536                    request,
    537                    437,
    538                    (
    539                        "Refresh request for non-existing allocation, tuple {}".format(
    540                            self.get_allocation_tuple()
    541                        )
    542                    ),
    543                )
    544 
    545            if allocation.username != request.get_username():
    546                return self.make_error_response(
    547                    request,
    548                    441,
    549                    (
    550                        "Refresh request with wrong user, exp {}, got {}".format(
    551                            allocation.username, request.get_username()
    552                        )
    553                    ),
    554                )
    555 
    556            lifetime = request.get_lifetime()
    557            if lifetime == None:
    558                return self.make_error_response(
    559                    request, 400, "Missing lifetime attribute in allocation request"
    560                )
    561 
    562            lifetime = min(lifetime, 3600)
    563            refresh_response.add_lifetime(lifetime)
    564            allocation.expiry = time.time() + lifetime
    565 
    566            refresh_response.add_message_integrity(turn_user, turn_realm, turn_pass)
    567        return refresh_response
    568 
    569    def handle_permission(self, request):
    570        permission_response = self.check_long_term_auth(request)
    571        if permission_response.msg_class == SUCCESS_RESPONSE:
    572            try:
    573                allocation = allocations[self.get_allocation_tuple()]
    574            except KeyError:
    575                return self.make_error_response(
    576                    request,
    577                    437,
    578                    (
    579                        "No such allocation for permission request, tuple {}".format(
    580                            self.get_allocation_tuple()
    581                        )
    582                    ),
    583                )
    584 
    585            if allocation.username != request.get_username():
    586                return self.make_error_response(
    587                    request,
    588                    441,
    589                    (
    590                        "Permission request with wrong user, exp {}, got {}".format(
    591                            allocation.username, request.get_username()
    592                        )
    593                    ),
    594                )
    595 
    596            # TODO: Handle multiple XOR-PEER-ADDRESS
    597            peer_address = request.get_xor_address(XOR_PEER_ADDRESS)
    598            if not peer_address:
    599                return self.make_error_response(
    600                    request, 400, "Missing XOR-PEER-ADDRESS on permission request"
    601                )
    602 
    603            permission_response.add_message_integrity(turn_user, turn_realm, turn_pass)
    604            allocation.permissions.add(peer_address.host)
    605 
    606        return permission_response
    607 
    608    def handle_send_indication(self, indication):
    609        try:
    610            allocation = allocations[self.get_allocation_tuple()]
    611        except KeyError:
    612            print(
    613                "Dropping send indication; no allocation for tuple {}".format(
    614                    self.get_allocation_tuple()
    615                )
    616            )
    617            return
    618 
    619        peer_address = indication.get_xor_address(XOR_PEER_ADDRESS)
    620        if not peer_address:
    621            print("Dropping send indication, missing XOR-PEER-ADDRESS")
    622            return
    623 
    624        data_attr = indication.find(DATA_ATTR)
    625        if not data_attr:
    626            print("Dropping send indication, missing DATA")
    627            return
    628 
    629        if indication.find(DONT_FRAGMENT):
    630            print("Dropping send indication, DONT-FRAGMENT set")
    631            return
    632 
    633        if not peer_address.host in allocation.permissions:
    634            print(
    635                "Dropping send indication, no permission for {} on tuple {}".format(
    636                    peer_address.host, self.get_allocation_tuple()
    637                )
    638            )
    639            return
    640 
    641        allocation.transport.write(
    642            data_attr.data, (peer_address.host, peer_address.port)
    643        )
    644 
    645    def make_success_response(self, request):
    646        response = copy.deepcopy(request)
    647        response.attributes = []
    648        response.add_xor_address(self.client_address, XOR_MAPPED_ADDRESS)
    649        response.msg_class = SUCCESS_RESPONSE
    650        return response
    651 
    652    def make_error_response(self, request, code, reason=None):
    653        if reason:
    654            print("{}: rejecting with {}".format(reason, code))
    655        response = copy.deepcopy(request)
    656        response.attributes = []
    657        response.add_error_code(code, reason)
    658        response.msg_class = ERROR_RESPONSE
    659        return response
    660 
    661    def make_challenge_response(self, request, reason=None):
    662        response = self.make_error_response(request, 401, reason)
    663        # 65 means the hex encoding will need padding half the time
    664        response.add_nonce("{:x}".format(random.getrandbits(65)))
    665        response.add_realm(turn_realm)
    666        return response
    667 
    668    def check_long_term_auth(self, request):
    669        message_integrity = request.find(MESSAGE_INTEGRITY)
    670        if not message_integrity:
    671            return self.make_challenge_response(request)
    672 
    673        username = request.find(USERNAME)
    674        realm = request.find(REALM)
    675        nonce = request.find(NONCE)
    676        if not username or not realm or not nonce:
    677            return self.make_error_response(
    678                request, 400, "Missing either USERNAME, NONCE, or REALM"
    679            )
    680 
    681        if username.data.decode("utf-8") != turn_user:
    682            return self.make_challenge_response(
    683                request, "Wrong user {}, exp {}".format(username.data, turn_user)
    684            )
    685 
    686        expected_message_digest = request.calculate_message_digest(
    687            turn_user, turn_realm, turn_pass
    688        )
    689        if message_integrity.data != expected_message_digest:
    690            return self.make_challenge_response(request, "Incorrect message disgest")
    691 
    692        return self.make_success_response(request)
    693 
    694 
    695 class StunRedirectHandler(StunHandler):
    696    """
    697    Frames and handles STUN messages by redirecting to the "real" server port.
    698    Performs the redirect with auth, so does a 401 to unauthed requests.
    699    Can be used to test port-based redirect handling.
    700    """
    701 
    702    def __init__(self, transport_handler):
    703        super(StunRedirectHandler, self).__init__(transport_handler)
    704 
    705    def handle_stun(self, stun_message, address):
    706        self.client_address = address
    707        if stun_message.msg_class == REQUEST:
    708            challenge_response = self.check_long_term_auth(stun_message)
    709 
    710            if challenge_response.msg_class == SUCCESS_RESPONSE:
    711                return self.make_redirect_response(stun_message).build()
    712 
    713            return challenge_response.build()
    714 
    715    def make_redirect_response(self, request):
    716        response = self.make_error_response(request, 300, "Try alternate")
    717        port = STUN_PORT
    718        if self.transport_handler.transport.getHost().port == TURNS_REDIRECT_PORT:
    719            port = STUNS_PORT
    720 
    721        response.add_alternate_server(
    722            self.transport_handler.transport.getHost().host, port
    723        )
    724 
    725        response.add_message_integrity(turn_user, turn_realm, turn_pass)
    726        return response
    727 
    728 
    729 class UdpStunHandler(protocol.DatagramProtocol):
    730    """
    731    Represents a UDP listen port for TURN.
    732    """
    733 
    734    def datagramReceived(self, data, address):
    735        stun_handler = StunHandler(self)
    736        stun_handler.data_received(data, to_ipaddress("UDP", address[0], address[1]))
    737 
    738    def write(self, data, address):
    739        self.transport.write(bytes(data), (address.host, address.port))
    740 
    741 
    742 class UdpStunRedirectHandler(protocol.DatagramProtocol):
    743    """
    744    Represents a UDP listen port for TURN that will redirect.
    745    """
    746 
    747    def datagramReceived(self, data, address):
    748        stun_handler = StunRedirectHandler(self)
    749        stun_handler.data_received(data, to_ipaddress("UDP", address[0], address[1]))
    750 
    751    def write(self, data, address):
    752        self.transport.write(bytes(data), (address.host, address.port))
    753 
    754 
    755 class TcpStunHandlerFactory(protocol.Factory):
    756    """
    757    Represents a TCP listen port for TURN.
    758    """
    759 
    760    def buildProtocol(self, addr):
    761        return TcpStunHandler(addr)
    762 
    763 
    764 class TcpStunHandler(protocol.Protocol):
    765    """
    766    Represents a connected TCP port for TURN.
    767    """
    768 
    769    def __init__(self, addr):
    770        self.address = addr
    771        self.stun_handler = None
    772 
    773    def dataReceived(self, data):
    774        # This needs to persist, since it handles framing
    775        if not self.stun_handler:
    776            self.stun_handler = StunHandler(self)
    777        self.stun_handler.data_received(data, self.address)
    778 
    779    def connectionLost(self, reason):
    780        print("Lost connection from {}".format(self.address))
    781        # Destroy allocations that this connection made
    782        keys_to_delete = []
    783        for key, allocation in allocations.items():
    784            if allocation.other_transport_handler == self:
    785                print("Closing allocation due to dropped connection: {}".format(key))
    786                keys_to_delete.append(key)
    787                allocation.close()
    788 
    789        for key in keys_to_delete:
    790            del allocations[key]
    791 
    792    def write(self, data, address):
    793        self.transport.write(bytes(data))
    794 
    795 
    796 class TcpStunRedirectHandlerFactory(protocol.Factory):
    797    """
    798    Represents a TCP listen port for TURN that will redirect.
    799    """
    800 
    801    def buildProtocol(self, addr):
    802        return TcpStunRedirectHandler(addr)
    803 
    804 
    805 class TcpStunRedirectHandler(protocol.DatagramProtocol):
    806    def __init__(self, addr):
    807        self.address = addr
    808        self.stun_handler = None
    809 
    810    def dataReceived(self, data):
    811        # This needs to persist, since it handles framing. Framing matters here
    812        # because we do a round of auth before redirecting.
    813        if not self.stun_handler:
    814            self.stun_handler = StunRedirectHandler(self)
    815        self.stun_handler.data_received(data, self.address)
    816 
    817    def write(self, data, address):
    818        self.transport.write(bytes(data))
    819 
    820    def connectionLost(self, reason):
    821        print("Lost connection from {}".format(self.address))
    822 
    823 
    824 def get_default_route(family):
    825    dummy_socket = socket.socket(family, socket.SOCK_DGRAM)
    826    if family is socket.AF_INET:
    827        dummy_socket.connect(("8.8.8.8", 53))
    828    else:
    829        dummy_socket.connect(("2001:4860:4860::8888", 53))
    830 
    831    default_route = dummy_socket.getsockname()[0]
    832    dummy_socket.close()
    833    return default_route
    834 
    835 
    836 turn_user = "foo"
    837 turn_pass = "bar"
    838 turn_realm = "mozilla.invalid"
    839 allocations = {}
    840 v4_address = get_default_route(socket.AF_INET)
    841 try:
    842    v6_address = get_default_route(socket.AF_INET6)
    843 except:
    844    v6_address = ""
    845 
    846 
    847 def prune_allocations():
    848    now = time.time()
    849    keys_to_delete = []
    850    for key, allocation in allocations.items():
    851        if allocation.expiry < now:
    852            print("Allocation expired: {}".format(key))
    853            keys_to_delete.append(key)
    854            allocation.close()
    855 
    856    for key in keys_to_delete:
    857        del allocations[key]
    858 
    859 
    860 CERT_FILE = "selfsigned.crt"
    861 KEY_FILE = "private.key"
    862 
    863 
    864 def create_self_signed_cert(name):
    865    # pyOpenSSL used to have some wrappers to help with this, but those have
    866    # been deprecated, and they have instructed users to use stuff from
    867    # cryptography.hazmat directly. This strikes me as a bad idea, but here we
    868    # go...
    869 
    870    from cryptography.hazmat.primitives.asymmetric import rsa
    871    from cryptography import x509
    872    from cryptography.x509.oid import NameOID
    873    from cryptography.hazmat.primitives import hashes
    874    import datetime
    875    from cryptography.hazmat.primitives import serialization
    876 
    877    # Not ideal, but in order to avoid generating certs with duplicate serial
    878    # numbers, we don't regenerate if there's one there already. If we wanted
    879    # to regenerate, we'd need to load the cert if it was there, determine its
    880    # serial number, and then make a new cert with a higher serial number.
    881    if os.path.isfile(CERT_FILE) and os.path.isfile(KEY_FILE):
    882        return
    883 
    884    # Key size does not need to be big, this is a self-signed cert for testing,
    885    # but I'm going to use something common to avoid warnings that might come
    886    # up in the future.
    887    # Why 65537? Because the documentation says so, citing a document written
    888    # by Colin Percival in 2009. Will this ever be out of date? Is it out of
    889    # date already? Who knows!
    890    key = rsa.generate_private_key(key_size=2048, public_exponent=65537)
    891 
    892    subject = x509.Name(
    893        [
    894            x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
    895            x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "TX"),
    896            x509.NameAttribute(NameOID.LOCALITY_NAME, "Dallas"),
    897            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Mozilla test iceserver"),
    898            x509.NameAttribute(NameOID.COMMON_NAME, name),
    899        ]
    900    )
    901 
    902    # create a self-signed cert
    903    cert = (
    904        x509.CertificateBuilder()
    905        .subject_name(subject)
    906        .issuer_name(subject)
    907        .serial_number(1000)
    908        .not_valid_before(datetime.datetime.now(datetime.timezone.utc))
    909        .not_valid_after(
    910            datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=365)
    911        )
    912        .public_key(key.public_key())
    913        .add_extension(
    914            x509.SubjectAlternativeName([x509.DNSName(name)]),
    915            critical=False,
    916        )
    917        .sign(key, hashes.SHA256())
    918    )
    919 
    920    open(CERT_FILE, "wb").write(cert.public_bytes(encoding=serialization.Encoding.PEM))
    921    open(KEY_FILE, "wb").write(
    922        key.private_bytes(
    923            encoding=serialization.Encoding.PEM,
    924            format=serialization.PrivateFormat.PKCS8,
    925            encryption_algorithm=serialization.NoEncryption(),
    926        )
    927    )
    928 
    929 
    930 if __name__ == "__main__":
    931    random.seed()
    932 
    933    if platform.system() == "Windows":
    934        # Windows is finicky about allowing real interfaces to talk to loopback.
    935        interface_4 = v4_address
    936        interface_6 = v6_address
    937        hostname = socket.gethostname()
    938    else:
    939        # Our linux builders do not have a hostname that resolves to the real
    940        # interface.
    941        interface_4 = "127.0.0.1"
    942        interface_6 = "::1"
    943        hostname = "localhost"
    944 
    945    reactor.listenUDP(STUN_PORT, UdpStunHandler(), interface=interface_4)
    946    reactor.listenTCP(STUN_PORT, TcpStunHandlerFactory(), interface=interface_4)
    947 
    948    reactor.listenUDP(
    949        TURN_REDIRECT_PORT, UdpStunRedirectHandler(), interface=interface_4
    950    )
    951    reactor.listenTCP(
    952        TURN_REDIRECT_PORT, TcpStunRedirectHandlerFactory(), interface=interface_4
    953    )
    954 
    955    try:
    956        reactor.listenUDP(STUN_PORT, UdpStunHandler(), interface=interface_6)
    957        reactor.listenTCP(STUN_PORT, TcpStunHandlerFactory(), interface=interface_6)
    958 
    959        reactor.listenUDP(
    960            TURN_REDIRECT_PORT, UdpStunRedirectHandler(), interface=interface_6
    961        )
    962        reactor.listenTCP(
    963            TURN_REDIRECT_PORT, TcpStunRedirectHandlerFactory(), interface=interface_6
    964        )
    965    except:
    966        pass
    967 
    968    try:
    969        from twisted.internet import ssl
    970        from OpenSSL import SSL
    971 
    972        create_self_signed_cert(hostname)
    973        tls_context_factory = ssl.DefaultOpenSSLContextFactory(
    974            KEY_FILE, CERT_FILE, SSL.TLSv1_2_METHOD
    975        )
    976        reactor.listenSSL(
    977            STUNS_PORT,
    978            TcpStunHandlerFactory(),
    979            tls_context_factory,
    980            interface=interface_4,
    981        )
    982 
    983        try:
    984            reactor.listenSSL(
    985                STUNS_PORT,
    986                TcpStunHandlerFactory(),
    987                tls_context_factory,
    988                interface=interface_6,
    989            )
    990 
    991            reactor.listenSSL(
    992                TURNS_REDIRECT_PORT,
    993                TcpStunRedirectHandlerFactory(),
    994                tls_context_factory,
    995                interface=interface_6,
    996            )
    997        except:
    998            pass
    999 
   1000        f = open(CERT_FILE, "r")
   1001        lines = f.readlines()
   1002        lines.pop(0)  # Remove BEGIN CERTIFICATE
   1003        lines.pop()  # Remove END CERTIFICATE
   1004        # pylint --py3k: W1636 W1649
   1005        lines = list(map(str.strip, lines))
   1006        certbase64 = "".join(lines)  # pylint --py3k: W1649
   1007 
   1008        turns_url = ', "turns:' + hostname + '"'
   1009        cert_prop = ', "cert":"' + certbase64 + '"'
   1010    except:
   1011        turns_url = ""
   1012        cert_prop = ""
   1013        pass
   1014 
   1015    allocation_pruner = LoopingCall(prune_allocations)
   1016    allocation_pruner.start(1)
   1017 
   1018    template = Template(
   1019        '[\
   1020 {"urls":["stun:$hostname", "stun:$hostname?transport=tcp"]}, \
   1021 {"username":"$user","credential":"$pwd","turn_redirect_port":"$TURN_REDIRECT_PORT","turns_redirect_port":"$TURNS_REDIRECT_PORT","urls": \
   1022 ["turn:$hostname", "turn:$hostname?transport=tcp" $turns_url] \
   1023 $cert_prop}]'  # Hack to make it easier to override cert checks
   1024    )
   1025 
   1026    print(
   1027        template.substitute(
   1028            user=turn_user,
   1029            pwd=turn_pass,
   1030            hostname=hostname,
   1031            turns_url=turns_url,
   1032            cert_prop=cert_prop,
   1033            TURN_REDIRECT_PORT=TURN_REDIRECT_PORT,
   1034            TURNS_REDIRECT_PORT=TURNS_REDIRECT_PORT,
   1035        )
   1036    )
   1037 
   1038    reactor.run()