tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

transport.py (13277B)


      1 # This Source Code Form is subject to the terms of the Mozilla Public
      2 # License, v. 2.0. If a copy of the MPL was not distributed with this
      3 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
      4 
      5 import json
      6 import socket
      7 import sys
      8 import time
      9 from threading import RLock
     10 
     11 
     12 class SocketTimeout:
     13    def __init__(self, socket_ctx, timeout):
     14        self.socket_ctx = socket_ctx
     15        self.timeout = timeout
     16        self.old_timeout = None
     17 
     18    def __enter__(self):
     19        self.old_timeout = self.socket_ctx.socket_timeout
     20        self.socket_ctx.socket_timeout = self.timeout
     21 
     22    def __exit__(self, *args, **kwargs):
     23        self.socket_ctx.socket_timeout = self.old_timeout
     24 
     25 
     26 class Message:
     27    def __init__(self, msgid):
     28        self.id = msgid
     29 
     30    def __eq__(self, other):
     31        return self.id == other.id
     32 
     33    def __ne__(self, other):
     34        return not self.__eq__(other)
     35 
     36    def __hash__(self):
     37        # pylint --py3k: W1641
     38        return hash(self.id)
     39 
     40 
     41 class Command(Message):
     42    TYPE = 0
     43 
     44    def __init__(self, msgid, name, params):
     45        Message.__init__(self, msgid)
     46        self.name = name
     47        self.params = params
     48 
     49    def __str__(self):
     50        return f"<Command id={self.id}, name={self.name}, params={self.params}>"
     51 
     52    def to_msg(self):
     53        msg = [Command.TYPE, self.id, self.name, self.params]
     54        return json.dumps(msg)
     55 
     56    @staticmethod
     57    def from_msg(data):
     58        assert data[0] == Command.TYPE
     59        cmd = Command(data[1], data[2], data[3])
     60        return cmd
     61 
     62 
     63 class Response(Message):
     64    TYPE = 1
     65 
     66    def __init__(self, msgid, error, result):
     67        Message.__init__(self, msgid)
     68        self.error = error
     69        self.result = result
     70 
     71    def __str__(self):
     72        return f"<Response id={self.id}, error={self.error}, result={self.result}>"
     73 
     74    def to_msg(self):
     75        msg = [Response.TYPE, self.id, self.error, self.result]
     76        return json.dumps(msg)
     77 
     78    @staticmethod
     79    def from_msg(data):
     80        assert data[0] == Response.TYPE
     81        return Response(data[1], data[2], data[3])
     82 
     83 
     84 class SocketContext:
     85    """Object that guards access to a socket via a lock.
     86 
     87    The socket must be accessed using this object as a context manager;
     88    access to the socket outside of a context will bypass the lock."""
     89 
     90    def __init__(self, host, port, timeout):
     91        self.lock = RLock()
     92 
     93        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     94        self._sock.settimeout(timeout)
     95        self._sock.connect((host, port))
     96 
     97    @property
     98    def socket_timeout(self):
     99        return self._sock.gettimeout()
    100 
    101    @socket_timeout.setter
    102    def socket_timeout(self, value):
    103        self._sock.settimeout(value)
    104 
    105    def __enter__(self):
    106        self.lock.acquire()
    107        return self._sock
    108 
    109    def __exit__(self, *args, **kwargs):
    110        self.lock.release()
    111 
    112 
    113 class TcpTransport:
    114    """Socket client that communciates with Marionette via TCP.
    115 
    116    It speaks the protocol of the remote debugger in Gecko, in which
    117    messages are always preceded by the message length and a colon, e.g.:
    118 
    119        7:MESSAGE
    120 
    121    On top of this protocol it uses a Marionette message format, that
    122    depending on the protocol level offered by the remote server, varies.
    123    Supported protocol levels are `min_protocol_level` and above.
    124    """
    125 
    126    max_packet_length = 4096
    127    min_protocol_level = 3
    128 
    129    def __init__(self, host, port, socket_timeout=60.0):
    130        """If `socket_timeout` is `0` or `0.0`, non-blocking socket mode
    131        will be used.  Setting it to `1` or `None` disables timeouts on
    132        socket operations altogether.
    133        """
    134        self._socket_context = None
    135 
    136        self.host = host
    137        self.port = port
    138        self._socket_timeout = socket_timeout
    139 
    140        self.protocol = self.min_protocol_level
    141        self.application_type = None
    142        self.last_id = 0
    143        self.expected_response = None
    144 
    145    @property
    146    def socket_timeout(self):
    147        return self._socket_timeout
    148 
    149    @socket_timeout.setter
    150    def socket_timeout(self, value):
    151        self._socket_timeout = value
    152 
    153        if self._socket_context is not None:
    154            self._socket_context.socket_timeout = value
    155 
    156    def _unmarshal(self, packet):
    157        """Convert data from bytes to a Message subtype
    158 
    159        Message format is [type, msg_id, body1, body2], where body1 and body2 depend
    160        on the message type.
    161 
    162        :param packet: Bytes received over the wire representing a complete message.
    163        """
    164        msg = None
    165 
    166        data = json.loads(packet)
    167        msg_type = data[0]
    168 
    169        if msg_type == Command.TYPE:
    170            msg = Command.from_msg(data)
    171        elif msg_type == Response.TYPE:
    172            msg = Response.from_msg(data)
    173        else:
    174            raise ValueError(f"Invalid message body {packet!r}")
    175 
    176        return msg
    177 
    178    def receive(self, unmarshal=True):
    179        """Wait for the next complete response from the remote.
    180 
    181        Packet format is length-prefixed JSON:
    182 
    183          packet = digit+ ":" body
    184          digit = "0"-"9"
    185          body = JSON text
    186 
    187        :param unmarshal: Default is to deserialise the packet and
    188            return a ``Message`` type.  Setting this to false will return
    189            the raw packet.
    190        """
    191        # Initally we read 4 bytes. We don't support reading beyond the end of a message, and
    192        # so assuming the JSON body has to be an array or object, the minimum possible message
    193        # is 4 bytes: "2:{}". In practice the marionette format has some required fields so the
    194        # message is longer, but 4 bytes allows reading messages with bodies up to 999 bytes in
    195        # length in two reads, which is the common case.
    196        with self._socket_context as sock:
    197            recv_bytes = 4
    198 
    199            length_prefix = b""
    200 
    201            body_length = -1
    202            body_received = 0
    203            body_parts = []
    204 
    205            now = time.time()
    206            timeout_time = (
    207                now + self.socket_timeout if self.socket_timeout is not None else None
    208            )
    209 
    210            while recv_bytes > 0:
    211                if timeout_time is not None and time.time() > timeout_time:
    212                    raise socket.timeout(
    213                        f"Connection timed out after {self.socket_timeout}s"
    214                    )
    215 
    216                try:
    217                    chunk = sock.recv(recv_bytes)
    218                except socket.timeout:
    219                    # Lets handle it with our own timeout check
    220                    continue
    221 
    222                if not chunk:
    223                    raise OSError("No data received over socket")
    224 
    225                body_part = None
    226                if body_length > 0:
    227                    body_part = chunk
    228                else:
    229                    parts = chunk.split(b":", 1)
    230                    length_prefix += parts[0]
    231 
    232                    # With > 10 decimal digits we aren't going to have a 32 bit number
    233                    if len(length_prefix) > 10:
    234                        raise ValueError(f"Invalid message length: {length_prefix!r}")
    235 
    236                    if len(parts) == 2:
    237                        # We found a : so we know the full length
    238                        err = None
    239                        try:
    240                            body_length = int(length_prefix)
    241                        except ValueError:
    242                            err = "expected an integer"
    243                        else:
    244                            if body_length <= 0:
    245                                err = "expected a positive integer"
    246                            elif body_length > 2**32 - 1:
    247                                err = "expected a 32 bit integer"
    248                        if err is not None:
    249                            raise ValueError(
    250                                f"Invalid message length: {err} got {length_prefix!r}"
    251                            )
    252                        body_part = parts[1]
    253 
    254                    # If we didn't find a : yet we keep reading 4 bytes at a time until we do.
    255                    # We could increase this here to 7 bytes (since we can't have more than 10
    256                    # length bytes and a seperator byte), or just increase it to
    257                    # int(length_prefix) + 1 since that's the minimum total number of remaining
    258                    # bytes (if the : is in the next byte), but it's probably not worth optimising
    259                    # for large messages.
    260 
    261                if body_part is not None:
    262                    body_received += len(body_part)
    263                    body_parts.append(body_part)
    264                    recv_bytes = body_length - body_received
    265 
    266            body = b"".join(body_parts)
    267            if unmarshal:
    268                msg = self._unmarshal(body)
    269                self.last_id = msg.id
    270 
    271                # keep reading incoming responses until
    272                # we receive the user's expected response
    273                if isinstance(msg, Response) and msg != self.expected_response:
    274                    return self.receive(unmarshal)
    275 
    276                return msg
    277            return body
    278 
    279    def connect(self):
    280        """Connect to the server and process the hello message we expect
    281        to receive in response.
    282 
    283        Returns a tuple of the protocol level and the application type.
    284        """
    285        try:
    286            self._socket_context = SocketContext(
    287                self.host, self.port, self._socket_timeout
    288            )
    289        except Exception:
    290            # Unset so that the next attempt to send will cause
    291            # another connection attempt.
    292            self._socket_context = None
    293            raise
    294 
    295        try:
    296            with SocketTimeout(self._socket_context, 60.0):
    297                # first packet is always a JSON Object
    298                # which we can use to tell which protocol level we are at
    299                raw = self.receive(unmarshal=False)
    300        except socket.timeout:
    301            exc_cls, exc, tb = sys.exc_info()
    302            msg = "Connection attempt failed because no data has been received over the socket: {}"
    303            raise exc_cls(msg.format(exc)).with_traceback(tb)
    304 
    305        hello = json.loads(raw)
    306        application_type = hello.get("applicationType")
    307        protocol = hello.get("marionetteProtocol")
    308 
    309        if application_type != "gecko":
    310            raise ValueError(f"Application type '{application_type}' is not supported")
    311 
    312        if not isinstance(protocol, int) or protocol < self.min_protocol_level:
    313            msg = "Earliest supported protocol level is '{}' but got '{}'"
    314            raise ValueError(msg.format(self.min_protocol_level, protocol))
    315 
    316        self.application_type = application_type
    317        self.protocol = protocol
    318 
    319        return (self.protocol, self.application_type)
    320 
    321    def send(self, obj):
    322        """Send message to the remote server.  Allowed input is a
    323        ``Message`` instance or a JSON serialisable object.
    324        """
    325        if not self._socket_context:
    326            self.connect()
    327 
    328        if isinstance(obj, Message):
    329            data = obj.to_msg()
    330            if isinstance(obj, Command):
    331                self.expected_response = obj
    332        else:
    333            data = json.dumps(obj)
    334        data = data.encode()
    335        payload = str(len(data)).encode() + b":" + data
    336 
    337        with self._socket_context as sock:
    338            totalsent = 0
    339            while totalsent < len(payload):
    340                sent = sock.send(payload[totalsent:])
    341                if sent == 0:
    342                    raise OSError(
    343                        f"Socket error after sending {totalsent} of {len(payload)} bytes"
    344                    )
    345                else:
    346                    totalsent += sent
    347 
    348    def respond(self, obj):
    349        """Send a response to a command.  This can be an arbitrary JSON
    350        serialisable object or an ``Exception``.
    351        """
    352        res, err = None, None
    353        if isinstance(obj, Exception):
    354            err = obj
    355        else:
    356            res = obj
    357        msg = Response(self.last_id, err, res)
    358        self.send(msg)
    359        return self.receive()
    360 
    361    def request(self, name, params):
    362        """Sends a message to the remote server and waits for a response
    363        to come back.
    364        """
    365        self.last_id = self.last_id + 1
    366        cmd = Command(self.last_id, name, params)
    367        self.send(cmd)
    368        return self.receive()
    369 
    370    def close(self):
    371        """Close the socket.
    372 
    373        First forces the socket to not send data anymore, and then explicitly
    374        close it to free up its resources.
    375 
    376        See: https://docs.python.org/2/howto/sockets.html#disconnecting
    377        """
    378        if self._socket_context:
    379            with self._socket_context as sock:
    380                try:
    381                    sock.shutdown(socket.SHUT_RDWR)
    382                except OSError as exc:
    383                    # If the socket is already closed, don't care about:
    384                    #   Errno  57: Socket not connected
    385                    #   Errno 107: Transport endpoint is not connected
    386                    if exc.errno not in (57, 107):
    387                        raise
    388 
    389                if sock:
    390                    # Guard against unclean shutdown.
    391                    sock.close()
    392                    self._socket_context = None
    393 
    394    def __del__(self):
    395        self.close()