transport.py (13277B)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import json
import socket
import sys
import time
from threading import RLock


class SocketTimeout:
    """Context manager that temporarily overrides a socket's timeout.

    On entry the timeout of ``socket_ctx`` (a ``SocketContext``) is set to
    ``timeout``; on exit the previously configured value is restored.
    """

    def __init__(self, socket_ctx, timeout):
        self.socket_ctx = socket_ctx
        self.timeout = timeout
        self.old_timeout = None

    def __enter__(self):
        self.old_timeout = self.socket_ctx.socket_timeout
        self.socket_ctx.socket_timeout = self.timeout

    def __exit__(self, *args, **kwargs):
        self.socket_ctx.socket_timeout = self.old_timeout


class Message:
    """Base class for protocol messages.

    Messages compare equal by ``id`` only; the concrete subtype is ignored,
    so a ``Response`` compares equal to the ``Command`` with the same id.
    """

    def __init__(self, msgid):
        self.id = msgid

    def __eq__(self, other):
        # Comparing with a non-message (e.g. ``None`` when no command is
        # pending) must not raise AttributeError; defer to Python's fallback.
        if not isinstance(other, Message):
            return NotImplemented
        return self.id == other.id

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # pylint --py3k: W1641
        return hash(self.id)


class Command(Message):
    """A command message, serialised as ``[0, id, name, params]``."""

    TYPE = 0

    def __init__(self, msgid, name, params):
        Message.__init__(self, msgid)
        self.name = name
        self.params = params

    def __str__(self):
        return f"<Command id={self.id}, name={self.name}, params={self.params}>"

    def to_msg(self):
        """Return this command encoded as a JSON string."""
        msg = [Command.TYPE, self.id, self.name, self.params]
        return json.dumps(msg)

    @staticmethod
    def from_msg(data):
        """Build a ``Command`` from an already-decoded JSON list."""
        assert data[0] == Command.TYPE
        cmd = Command(data[1], data[2], data[3])
        return cmd


class Response(Message):
    """A response message, serialised as ``[1, id, error, result]``."""

    TYPE = 1

    def __init__(self, msgid, error, result):
        Message.__init__(self, msgid)
        self.error = error
        self.result = result

    def __str__(self):
        return f"<Response id={self.id}, error={self.error}, result={self.result}>"

    def to_msg(self):
        """Return this response encoded as a JSON string."""
        msg = [Response.TYPE, self.id, self.error, self.result]
        return json.dumps(msg)

    @staticmethod
    def from_msg(data):
        """Build a ``Response`` from an already-decoded JSON list."""
        assert data[0] == Response.TYPE
        return Response(data[1], data[2], data[3])


class SocketContext:
    """Object that guards access to a socket via a lock.

    The socket must be accessed using this object as a context manager;
    access to the socket outside of a context will bypass the lock."""

    def __init__(self, host, port, timeout):
        self.lock = RLock()

        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._sock.settimeout(timeout)
        self._sock.connect((host, port))

    @property
    def socket_timeout(self):
        return self._sock.gettimeout()

    @socket_timeout.setter
    def socket_timeout(self, value):
        self._sock.settimeout(value)

    def __enter__(self):
        self.lock.acquire()
        return self._sock

    def __exit__(self, *args, **kwargs):
        self.lock.release()


class TcpTransport:
    """Socket client that communicates with Marionette via TCP.

    It speaks the protocol of the remote debugger in Gecko, in which
    messages are always preceded by the message length and a colon, e.g.:

        7:MESSAGE

    On top of this protocol it uses a Marionette message format, that
    depending on the protocol level offered by the remote server, varies.
    Supported protocol levels are `min_protocol_level` and above.
    """

    max_packet_length = 4096
    min_protocol_level = 3

    def __init__(self, host, port, socket_timeout=60.0):
        """If `socket_timeout` is `0` or `0.0`, non-blocking socket mode
        will be used. Setting it to `None` disables timeouts on socket
        operations altogether. No connection is made until `connect` (or
        the first `send`) is called.
        """
        self._socket_context = None

        self.host = host
        self.port = port
        self._socket_timeout = socket_timeout

        self.protocol = self.min_protocol_level
        self.application_type = None
        self.last_id = 0
        self.expected_response = None

    @property
    def socket_timeout(self):
        return self._socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value):
        self._socket_timeout = value

        if self._socket_context is not None:
            self._socket_context.socket_timeout = value

    def _unmarshal(self, packet):
        """Convert data from bytes to a Message subtype

        Message format is [type, msg_id, body1, body2], where body1 and body2 depend
        on the message type.

        :param packet: Bytes received over the wire representing a complete message.
        :raises ValueError: If the message type is neither a command nor a response.
        """
        data = json.loads(packet)
        msg_type = data[0]

        if msg_type == Command.TYPE:
            return Command.from_msg(data)
        if msg_type == Response.TYPE:
            return Response.from_msg(data)
        raise ValueError(f"Invalid message body {packet!r}")

    @staticmethod
    def _parse_length(length_prefix):
        """Validate a decimal length prefix and return it as an int.

        :raises ValueError: If the prefix is not an integer in [1, 2**32 - 1].
        """
        err = None
        try:
            body_length = int(length_prefix)
        except ValueError:
            err = "expected an integer"
        else:
            if body_length <= 0:
                err = "expected a positive integer"
            elif body_length > 2**32 - 1:
                err = "expected a 32 bit integer"
        if err is not None:
            raise ValueError(f"Invalid message length: {err} got {length_prefix!r}")
        return body_length

    def _read_packet(self, sock):
        """Read one complete length-prefixed packet and return its body.

        The overall deadline is derived from the socket's *current* timeout
        rather than from ``self.socket_timeout``, so a temporary override via
        ``SocketTimeout`` (as used by ``connect``) bounds the total wait too.
        Individual ``recv`` timeouts are retried until that deadline passes.

        :param sock: Connected socket; the caller must hold the socket lock.
        :returns: The packet body as bytes, without the length prefix.
        :raises socket.timeout: If no complete packet arrives before the deadline.
        :raises OSError: If the peer closed the connection.
        :raises ValueError: If the length prefix is malformed.
        """
        # Initially we read 4 bytes. We don't support reading beyond the end of a message, and
        # so assuming the JSON body has to be an array or object, the minimum possible message
        # is 4 bytes: "2:{}". In practice the marionette format has some required fields so the
        # message is longer, but 4 bytes allows reading messages with bodies up to 999 bytes in
        # length in two reads, which is the common case.
        recv_bytes = 4

        length_prefix = b""

        body_length = -1
        body_received = 0
        body_parts = []

        timeout = sock.gettimeout()
        deadline = time.time() + timeout if timeout is not None else None

        while recv_bytes > 0:
            if deadline is not None and time.time() > deadline:
                raise socket.timeout(f"Connection timed out after {timeout}s")

            try:
                chunk = sock.recv(recv_bytes)
            except socket.timeout:
                # Let the deadline check above decide when to give up.
                continue

            if not chunk:
                raise OSError("No data received over socket")

            body_part = None
            if body_length > 0:
                body_part = chunk
            else:
                head, separator, rest = chunk.partition(b":")
                length_prefix += head

                # With > 10 decimal digits we aren't going to have a 32 bit number
                if len(length_prefix) > 10:
                    raise ValueError(f"Invalid message length: {length_prefix!r}")

                if separator:
                    # We found a : so we know the full length
                    body_length = self._parse_length(length_prefix)
                    body_part = rest

                # If we didn't find a : yet we keep reading 4 bytes at a time until we do.
                # We could increase this here to 7 bytes (since we can't have more than 10
                # length bytes and a separator byte), or just increase it to
                # int(length_prefix) + 1 since that's the minimum total number of remaining
                # bytes (if the : is in the next byte), but it's probably not worth optimising
                # for large messages.

            if body_part is not None:
                body_received += len(body_part)
                body_parts.append(body_part)
                recv_bytes = body_length - body_received

        return b"".join(body_parts)

    def receive(self, unmarshal=True):
        """Wait for the next complete response from the remote.

        Packet format is length-prefixed JSON:

            packet = digit+ ":" body
            digit  = "0"-"9"
            body   = JSON text

        When unmarshalling, any ``Response`` that does not compare equal to
        ``expected_response`` is discarded and the next packet is read.
        Discarded packets are consumed in a loop (not by recursion), so a
        long run of them cannot exhaust the call stack.

        :param unmarshal: Default is to deserialise the packet and
            return a ``Message`` type. Setting this to false will return
            the raw packet.
        """
        with self._socket_context as sock:
            while True:
                body = self._read_packet(sock)
                if not unmarshal:
                    return body

                msg = self._unmarshal(body)
                self.last_id = msg.id

                # Keep reading incoming responses until we receive the
                # user's expected response.
                if isinstance(msg, Response) and msg != self.expected_response:
                    continue

                return msg

    def _abort_connection(self):
        """Drop a connection whose handshake failed.

        Socket errors while closing are ignored so that they cannot mask the
        handshake error. Afterwards ``_socket_context`` is ``None``, so the
        next ``send`` attempts a fresh connection.
        """
        try:
            self.close()
        except OSError:
            pass
        self._socket_context = None

    def connect(self):
        """Connect to the server and process the hello message we expect
        to receive in response.

        Returns a tuple of the protocol level and the application type.

        :raises socket.timeout: If no hello arrives within 60 seconds.
        :raises ValueError: If the application type or protocol level is
            not supported.
        """
        try:
            self._socket_context = SocketContext(
                self.host, self.port, self._socket_timeout
            )
        except Exception:
            # Unset so that the next attempt to send will cause
            # another connection attempt.
            self._socket_context = None
            raise

        try:
            try:
                with SocketTimeout(self._socket_context, 60.0):
                    # first packet is always a JSON Object
                    # which we can use to tell which protocol level we are at
                    raw = self.receive(unmarshal=False)
            except socket.timeout:
                exc_cls, exc, tb = sys.exc_info()
                msg = "Connection attempt failed because no data has been received over the socket: {}"
                raise exc_cls(msg.format(exc)).with_traceback(tb)

            hello = json.loads(raw)
            application_type = hello.get("applicationType")
            protocol = hello.get("marionetteProtocol")

            if application_type != "gecko":
                raise ValueError(f"Application type '{application_type}' is not supported")

            if not isinstance(protocol, int) or protocol < self.min_protocol_level:
                msg = "Earliest supported protocol level is '{}' but got '{}'"
                raise ValueError(msg.format(self.min_protocol_level, protocol))
        except BaseException:
            # Don't keep a socket that never completed the handshake;
            # otherwise the next send() would skip reconnecting.
            self._abort_connection()
            raise

        self.application_type = application_type
        self.protocol = protocol

        return (self.protocol, self.application_type)

    def send(self, obj):
        """Send message to the remote server. Allowed input is a
        ``Message`` instance or a JSON serialisable object.

        Connects first if there is no open connection. Sending a
        ``Command`` records it as ``expected_response``.
        """
        if not self._socket_context:
            self.connect()

        if isinstance(obj, Message):
            data = obj.to_msg()
            if isinstance(obj, Command):
                self.expected_response = obj
        else:
            data = json.dumps(obj)
        data = data.encode()
        payload = str(len(data)).encode() + b":" + data

        with self._socket_context as sock:
            totalsent = 0
            while totalsent < len(payload):
                sent = sock.send(payload[totalsent:])
                if sent == 0:
                    raise OSError(
                        f"Socket error after sending {totalsent} of {len(payload)} bytes"
                    )
                totalsent += sent

    def respond(self, obj):
        """Send a response to a command. This can be an arbitrary JSON
        serialisable object or an ``Exception``.
        """
        res, err = None, None
        if isinstance(obj, Exception):
            err = obj
        else:
            res = obj
        msg = Response(self.last_id, err, res)
        self.send(msg)
        return self.receive()

    def request(self, name, params):
        """Sends a message to the remote server and waits for a response
        to come back.
        """
        self.last_id = self.last_id + 1
        cmd = Command(self.last_id, name, params)
        self.send(cmd)
        return self.receive()

    def close(self):
        """Close the socket.

        First forces the socket to not send data anymore, and then explicitly
        close it to free up its resources. The socket is closed and the
        context cleared even if ``shutdown`` raises an unexpected error.

        See: https://docs.python.org/3/howto/sockets.html#disconnecting
        """
        if not self._socket_context:
            return

        try:
            with self._socket_context as sock:
                try:
                    sock.shutdown(socket.SHUT_RDWR)
                except OSError as exc:
                    # If the socket is already closed, don't care about:
                    #   Errno  57: Socket not connected
                    #   Errno 107: Transport endpoint is not connected
                    if exc.errno not in (57, 107):
                        raise
                finally:
                    sock.close()
        finally:
            self._socket_context = None

    def __del__(self):
        self.close()