echo_client.py (25791B)
1 #!/usr/bin/env python 2 # 3 # Copyright 2011, Google Inc. 4 # All rights reserved. 5 # 6 # Redistribution and use in source and binary forms, with or without 7 # modification, are permitted provided that the following conditions are 8 # met: 9 # 10 # * Redistributions of source code must retain the above copyright 11 # notice, this list of conditions and the following disclaimer. 12 # * Redistributions in binary form must reproduce the above 13 # copyright notice, this list of conditions and the following disclaimer 14 # in the documentation and/or other materials provided with the 15 # distribution. 16 # * Neither the name of Google Inc. nor the names of its 17 # contributors may be used to endorse or promote products derived from 18 # this software without specific prior written permission. 19 # 20 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 """Simple WebSocket client named echo_client just because of historical reason. 32 33 pywebsocket3 directory must be in PYTHONPATH. 
34 35 Example Usage: 36 37 # server setup 38 % cd $pywebsocket3 39 % PYTHONPATH=$cwd/src python ./pywebsocket3/standalone.py -p 8880 \ 40 -d $cwd/src/example 41 42 # run client 43 % PYTHONPATH=$cwd/src python ./src/example/echo_client.py -p 8880 \ 44 -s localhost \ 45 -o http://localhost -r /echo -m test 46 """ 47 48 from __future__ import absolute_import 49 from __future__ import print_function 50 51 import argparse 52 import base64 53 import codecs 54 import logging 55 import os 56 import re 57 import socket 58 import ssl 59 import sys 60 from hashlib import sha1 61 62 import six 63 64 from pywebsocket3 import common, util 65 from pywebsocket3.extensions import ( 66 PerMessageDeflateExtensionProcessor, 67 _PerMessageDeflateFramer, 68 _parse_window_bits, 69 ) 70 from pywebsocket3.stream import Stream, StreamOptions 71 72 _TIMEOUT_SEC = 10 73 _UNDEFINED_PORT = -1 74 75 _UPGRADE_HEADER = 'Upgrade: websocket\r\n' 76 _CONNECTION_HEADER = 'Connection: Upgrade\r\n' 77 78 # Special message that tells the echo server to start closing handshake 79 _GOODBYE_MESSAGE = 'Goodbye' 80 81 _PROTOCOL_VERSION_HYBI13 = 'hybi13' 82 83 84 class ClientHandshakeError(Exception): 85 pass 86 87 88 def _build_method_line(resource): 89 return 'GET %s HTTP/1.1\r\n' % resource 90 91 92 def _origin_header(header, origin): 93 # 4.1 13. concatenation of the string "Origin:", a U+0020 SPACE character, 94 # and the /origin/ value, converted to ASCII lowercase, to /fields/. 95 return '%s: %s\r\n' % (header, origin.lower()) 96 97 98 def _format_host_header(host, port, secure): 99 # 4.1 9. Let /hostport/ be an empty string. 100 # 4.1 10. Append the /host/ value, converted to ASCII lowercase, to 101 # /hostport/ 102 hostport = host.lower() 103 # 4.1 11. 
If /secure/ is false, and /port/ is not 80, or if /secure/ 104 # is true, and /port/ is not 443, then append a U+003A COLON character 105 # (:) followed by the value of /port/, expressed as a base-ten integer, 106 # to /hostport/ 107 if ((not secure and port != common.DEFAULT_WEB_SOCKET_PORT) 108 or (secure and port != common.DEFAULT_WEB_SOCKET_SECURE_PORT)): 109 hostport += ':' + str(port) 110 # 4.1 12. concatenation of the string "Host:", a U+0020 SPACE 111 # character, and /hostport/, to /fields/. 112 return '%s: %s\r\n' % (common.HOST_HEADER, hostport) 113 114 115 def _receive_bytes(socket, length): 116 recv_bytes = [] 117 remaining = length 118 while remaining > 0: 119 received_bytes = socket.recv(remaining) 120 if not received_bytes: 121 raise IOError( 122 'Connection closed before receiving requested length ' 123 '(requested %d bytes but received only %d bytes)' % 124 (length, length - remaining)) 125 recv_bytes.append(received_bytes) 126 remaining -= len(received_bytes) 127 return b''.join(recv_bytes) 128 129 130 def _get_mandatory_header(fields, name): 131 """Gets the value of the header specified by name from fields. 132 133 This function expects that there's only one header with the specified name 134 in fields. Otherwise, raises an ClientHandshakeError. 135 """ 136 137 values = fields.get(name.lower()) 138 if values is None or len(values) == 0: 139 raise ClientHandshakeError('%s header not found: %r' % (name, values)) 140 if len(values) > 1: 141 raise ClientHandshakeError('Multiple %s headers found: %r' % 142 (name, values)) 143 return values[0] 144 145 146 def _validate_mandatory_header(fields, 147 name, 148 expected_value, 149 case_sensitive=False): 150 """Gets and validates the value of the header specified by name from 151 fields. 152 153 If expected_value is specified, compares expected value and actual value 154 and raises an ClientHandshakeError on failure. You can specify case 155 sensitiveness in this comparison by case_sensitive parameter. 
This function 156 expects that there's only one header with the specified name in fields. 157 Otherwise, raises an ClientHandshakeError. 158 """ 159 160 value = _get_mandatory_header(fields, name) 161 162 if ((case_sensitive and value != expected_value) or 163 (not case_sensitive and value.lower() != expected_value.lower())): 164 raise ClientHandshakeError( 165 'Illegal value for header %s: %r (expected) vs %r (actual)' % 166 (name, expected_value, value)) 167 168 169 class _TLSSocket(object): 170 """Wrapper for a TLS connection.""" 171 def __init__(self, raw_socket): 172 self._logger = util.get_class_logger(self) 173 174 self._tls_socket = ssl.wrap_socket(raw_socket) 175 176 # Print cipher in use. Handshake is done on wrap_socket call. 177 self._logger.info("Cipher: %s", self._tls_socket.cipher()) 178 179 def send(self, data): 180 return self._tls_socket.write(data) 181 182 def sendall(self, data): 183 return self._tls_socket.sendall(data) 184 185 def recv(self, size=-1): 186 return self._tls_socket.read(size) 187 188 def close(self): 189 return self._tls_socket.close() 190 191 def getpeername(self): 192 return self._tls_socket.getpeername() 193 194 195 class ClientHandshakeBase(object): 196 """A base class for WebSocket opening handshake processors for each 197 protocol version. 198 """ 199 def __init__(self): 200 self._logger = util.get_class_logger(self) 201 202 def _read_fields(self): 203 # 4.1 32. let /fields/ be a list of name-value pairs, initially empty. 204 fields = {} 205 while True: # "Field" 206 # 4.1 33. let /name/ and /value/ be empty byte arrays 207 name = b'' 208 value = b'' 209 # 4.1 34. read /name/ 210 name = self._read_name() 211 if name is None: 212 break 213 # 4.1 35. read spaces 214 # TODO(tyoshino): Skip only one space as described in the spec. 215 ch = self._skip_spaces() 216 # 4.1 36. read /value/ 217 value = self._read_value(ch) 218 # 4.1 37. 
read a byte from the server 219 ch = _receive_bytes(self._socket, 1) 220 if ch != b'\n': # 0x0A 221 raise ClientHandshakeError( 222 'Expected LF but found %r while reading value %r for ' 223 'header %r' % (ch, value, name)) 224 self._logger.debug('Received %r header', name) 225 # 4.1 38. append an entry to the /fields/ list that has the name 226 # given by the string obtained by interpreting the /name/ byte 227 # array as a UTF-8 stream and the value given by the string 228 # obtained by interpreting the /value/ byte array as a UTF-8 byte 229 # stream. 230 fields.setdefault(name.decode('UTF-8'), 231 []).append(value.decode('UTF-8')) 232 # 4.1 39. return to the "Field" step above 233 return fields 234 235 def _read_name(self): 236 # 4.1 33. let /name/ be empty byte arrays 237 name = b'' 238 while True: 239 # 4.1 34. read a byte from the server 240 ch = _receive_bytes(self._socket, 1) 241 if ch == b'\r': # 0x0D 242 return None 243 elif ch == b'\n': # 0x0A 244 raise ClientHandshakeError( 245 'Unexpected LF when reading header name %r' % name) 246 elif ch == b':': # 0x3A 247 return name.lower() 248 else: 249 name += ch 250 251 def _skip_spaces(self): 252 # 4.1 35. read a byte from the server 253 while True: 254 ch = _receive_bytes(self._socket, 1) 255 if ch == b' ': # 0x20 256 continue 257 return ch 258 259 def _read_value(self, ch): 260 # 4.1 33. let /value/ be empty byte arrays 261 value = b'' 262 # 4.1 36. read a byte from server. 263 while True: 264 if ch == b'\r': # 0x0D 265 return value 266 elif ch == b'\n': # 0x0A 267 raise ClientHandshakeError( 268 'Unexpected LF when reading header value %r' % value) 269 else: 270 value += ch 271 ch = _receive_bytes(self._socket, 1) 272 273 274 def _get_permessage_deflate_framer(extension_response): 275 """Validate the response and return a framer object using the parameters in 276 the response. This method doesn't accept the server_.* parameters. 
277 """ 278 279 client_max_window_bits = None 280 client_no_context_takeover = None 281 282 client_max_window_bits_name = ( 283 PerMessageDeflateExtensionProcessor._CLIENT_MAX_WINDOW_BITS_PARAM) 284 client_no_context_takeover_name = ( 285 PerMessageDeflateExtensionProcessor._CLIENT_NO_CONTEXT_TAKEOVER_PARAM) 286 287 # We didn't send any server_.* parameter. 288 # Handle those parameters as invalid if found in the response. 289 290 for param_name, param_value in extension_response.get_parameters(): 291 if param_name == client_max_window_bits_name: 292 if client_max_window_bits is not None: 293 raise ClientHandshakeError('Multiple %s found' % 294 client_max_window_bits_name) 295 296 parsed_value = _parse_window_bits(param_value) 297 if parsed_value is None: 298 raise ClientHandshakeError( 299 'Bad %s: %r' % (client_max_window_bits_name, param_value)) 300 client_max_window_bits = parsed_value 301 elif param_name == client_no_context_takeover_name: 302 if client_no_context_takeover is not None: 303 raise ClientHandshakeError('Multiple %s found' % 304 client_no_context_takeover_name) 305 306 if param_value is not None: 307 raise ClientHandshakeError( 308 'Bad %s: Has value %r' % 309 (client_no_context_takeover_name, param_value)) 310 client_no_context_takeover = True 311 312 if client_no_context_takeover is None: 313 client_no_context_takeover = False 314 315 return _PerMessageDeflateFramer(client_max_window_bits, 316 client_no_context_takeover) 317 318 319 class ClientHandshakeProcessor(ClientHandshakeBase): 320 """WebSocket opening handshake processor 321 """ 322 def __init__(self, socket, options): 323 super(ClientHandshakeProcessor, self).__init__() 324 325 self._socket = socket 326 self._options = options 327 328 self._logger = util.get_class_logger(self) 329 330 def handshake(self): 331 """Performs opening handshake on the specified socket. 332 333 Raises: 334 ClientHandshakeError: handshake failed. 
335 """ 336 337 request_line = _build_method_line(self._options.resource) 338 self._logger.debug('Client\'s opening handshake Request-Line: %r', 339 request_line) 340 self._socket.sendall(request_line.encode('UTF-8')) 341 342 fields = [] 343 fields.append( 344 _format_host_header(self._options.server_host, 345 self._options.server_port, 346 self._options.use_tls)) 347 fields.append(_UPGRADE_HEADER) 348 fields.append(_CONNECTION_HEADER) 349 if self._options.origin is not None: 350 fields.append( 351 _origin_header(common.ORIGIN_HEADER, self._options.origin)) 352 353 original_key = os.urandom(16) 354 self._key = base64.b64encode(original_key) 355 self._logger.debug('%s: %r (%s)', common.SEC_WEBSOCKET_KEY_HEADER, 356 self._key, util.hexify(original_key)) 357 fields.append( 358 '%s: %s\r\n' % 359 (common.SEC_WEBSOCKET_KEY_HEADER, self._key.decode('UTF-8'))) 360 361 fields.append( 362 '%s: %d\r\n' % 363 (common.SEC_WEBSOCKET_VERSION_HEADER, common.VERSION_HYBI_LATEST)) 364 365 extensions_to_request = [] 366 367 if self._options.use_permessage_deflate: 368 extension = common.ExtensionParameter( 369 common.PERMESSAGE_DEFLATE_EXTENSION) 370 # Accept the client_max_window_bits extension parameter by default. 371 extension.add_parameter( 372 PerMessageDeflateExtensionProcessor. 
373 _CLIENT_MAX_WINDOW_BITS_PARAM, None) 374 extensions_to_request.append(extension) 375 376 if len(extensions_to_request) != 0: 377 fields.append('%s: %s\r\n' % 378 (common.SEC_WEBSOCKET_EXTENSIONS_HEADER, 379 common.format_extensions(extensions_to_request))) 380 381 for field in fields: 382 self._socket.sendall(field.encode('UTF-8')) 383 384 self._socket.sendall(b'\r\n') 385 386 self._logger.debug('Sent client\'s opening handshake headers: %r', 387 fields) 388 self._logger.debug('Start reading Status-Line') 389 390 status_line = b'' 391 while True: 392 ch = _receive_bytes(self._socket, 1) 393 status_line += ch 394 if ch == b'\n': 395 break 396 397 m = re.match(b'HTTP/\\d+\.\\d+ (\\d\\d\\d) .*\r\n', status_line) 398 if m is None: 399 raise ClientHandshakeError('Wrong status line format: %r' % 400 status_line) 401 status_code = m.group(1) 402 if status_code != b'101': 403 self._logger.debug( 404 'Unexpected status code %s with following headers: %r', 405 status_code, self._read_fields()) 406 raise ClientHandshakeError( 407 'Expected HTTP status code 101 but found %r' % status_code) 408 409 self._logger.debug('Received valid Status-Line') 410 self._logger.debug('Start reading headers until we see an empty line') 411 412 fields = self._read_fields() 413 414 ch = _receive_bytes(self._socket, 1) 415 if ch != b'\n': # 0x0A 416 raise ClientHandshakeError( 417 'Expected LF but found %r while reading value %r for header ' 418 'name %r' % (ch, value, name)) 419 420 self._logger.debug('Received an empty line') 421 self._logger.debug('Server\'s opening handshake headers: %r', fields) 422 423 _validate_mandatory_header(fields, common.UPGRADE_HEADER, 424 common.WEBSOCKET_UPGRADE_TYPE, False) 425 426 _validate_mandatory_header(fields, common.CONNECTION_HEADER, 427 common.UPGRADE_CONNECTION_TYPE, False) 428 429 accept = _get_mandatory_header(fields, 430 common.SEC_WEBSOCKET_ACCEPT_HEADER) 431 432 # Validate 433 try: 434 binary_accept = base64.b64decode(accept) 435 except 
TypeError: 436 raise HandshakeError('Illegal value for header %s: %r' % 437 (common.SEC_WEBSOCKET_ACCEPT_HEADER, accept)) 438 439 if len(binary_accept) != 20: 440 raise ClientHandshakeError( 441 'Decoded value of %s is not 20-byte long' % 442 common.SEC_WEBSOCKET_ACCEPT_HEADER) 443 444 self._logger.debug('Response for challenge : %r (%s)', accept, 445 util.hexify(binary_accept)) 446 447 binary_expected_accept = sha1(self._key + 448 common.WEBSOCKET_ACCEPT_UUID).digest() 449 expected_accept = base64.b64encode(binary_expected_accept) 450 451 self._logger.debug('Expected response for challenge: %r (%s)', 452 expected_accept, 453 util.hexify(binary_expected_accept)) 454 455 if accept != expected_accept.decode('UTF-8'): 456 raise ClientHandshakeError( 457 'Invalid %s header: %r (expected: %s)' % 458 (common.SEC_WEBSOCKET_ACCEPT_HEADER, accept, expected_accept)) 459 460 permessage_deflate_accepted = False 461 462 extensions_header = fields.get( 463 common.SEC_WEBSOCKET_EXTENSIONS_HEADER.lower()) 464 accepted_extensions = [] 465 if extensions_header is not None and len(extensions_header) != 0: 466 accepted_extensions = common.parse_extensions(extensions_header[0]) 467 468 for extension in accepted_extensions: 469 extension_name = extension.name() 470 if (extension_name == common.PERMESSAGE_DEFLATE_EXTENSION 471 and self._options.use_permessage_deflate): 472 permessage_deflate_accepted = True 473 474 framer = _get_permessage_deflate_framer(extension) 475 framer.set_compress_outgoing_enabled(True) 476 self._options.use_permessage_deflate = framer 477 continue 478 479 raise ClientHandshakeError('Unexpected extension %r' % 480 extension_name) 481 482 if (self._options.use_permessage_deflate 483 and not permessage_deflate_accepted): 484 raise ClientHandshakeError( 485 'Requested %s, but the server rejected it' % 486 common.PERMESSAGE_DEFLATE_EXTENSION) 487 488 # TODO(tyoshino): Handle Sec-WebSocket-Protocol 489 # TODO(tyoshino): Handle Cookie, etc. 


class ClientConnection(object):
    """A wrapper for socket object to provide the mp_conn interface.

    Only write(), read() and the remote_addr property are provided.
    """
    def __init__(self, socket):
        self._socket = socket

    def write(self, data):
        self._socket.sendall(data)

    def read(self, n):
        return self._socket.recv(n)

    def get_remote_addr(self):
        return self._socket.getpeername()

    remote_addr = property(get_remote_addr)


class ClientRequest(object):
    """A wrapper class just to make it able to pass a socket object to
    functions that expect a mp_request object.
    """
    def __init__(self, socket):
        self._logger = util.get_class_logger(self)

        self._socket = socket
        self.connection = ClientConnection(socket)
        self.ws_version = common.VERSION_HYBI_LATEST


class EchoClient(object):
    """WebSocket echo client."""
    def __init__(self, options):
        self._options = options
        self._socket = None

        self._logger = util.get_class_logger(self)

    def run(self):
        """Run the client.

        Shake hands, then send each comma-separated message and receive its
        echo, and finally perform the closing handshake. The socket is closed
        on exit, including when an error is raised.
        """
        # create_connection() tries every address the host resolves to (IPv4
        # or IPv6) and applies the timeout before connecting, like the
        # previous settimeout() + connect() sequence.
        self._socket = socket.create_connection(
            (self._options.server_host, self._options.server_port),
            self._options.socket_timeout)
        try:
            if self._options.use_tls:
                self._socket = _TLSSocket(self._socket)

            self._handshake = ClientHandshakeProcessor(self._socket,
                                                       self._options)

            self._handshake.handshake()

            self._logger.info('Connection established')

            request = ClientRequest(self._socket)

            stream_option = StreamOptions()
            stream_option.mask_send = True
            stream_option.unmask_receive = False

            # After a successful handshake this option holds either False or
            # the negotiated permessage-deflate framer.
            if self._options.use_permessage_deflate is not False:
                framer = self._options.use_permessage_deflate
                framer.setup_stream_options(stream_option)

            self._stream = Stream(request, stream_option)

            for line in self._options.message.split(','):
                self._stream.send_message(line)
                if self._options.verbose:
                    print('Send: %s' % line)
                try:
                    received = self._stream.receive_message()

                    if self._options.verbose:
                        print('Recv: %s' % received)
                except Exception as e:
                    if self._options.verbose:
                        print('Error: %s' % e)
                    raise

            self._do_closing_handshake()
        finally:
            self._socket.close()

    def _do_closing_handshake(self):
        """Perform the closing handshake.

        If the last message sent was the goodbye message, first wait for the
        server to start closing; a None from receive_message() is treated as
        the server's close. Otherwise the client starts the closing handshake.
        Progress lines are printed only in verbose mode, consistent with
        run().
        """
        verbose = self._options.verbose
        if self._options.message.split(',')[-1] == _GOODBYE_MESSAGE:
            # requested server initiated closing handshake, so
            # expecting closing handshake message from server.
            self._logger.info('Wait for server-initiated closing handshake')
            message = self._stream.receive_message()
            if message is None:
                if verbose:
                    print('Recv close')
                    print('Send ack')
                self._logger.info('Received closing handshake and sent ack')
                return
        if verbose:
            print('Send close')
        self._stream.close_connection()
        self._logger.info('Sent closing handshake')
        if verbose:
            print('Recv ack')
        self._logger.info('Received ack')


def main():
    """Parse command-line options and run the echo client."""
    # Force Python 2 to use the locale encoding, even when the output is not a
    # tty. This makes the behaviour the same as Python 3. The encoding won't
    # necessarily support all unicode characters. This problem is particularly
    # prevalent on Windows.
    if six.PY2:
        import locale
        encoding = locale.getpreferredencoding()
        sys.stdout = codecs.getwriter(encoding)(sys.stdout)

    parser = argparse.ArgumentParser()
    # We accept --command_line_flag style flags which is the same as Google
    # gflags in addition to common --command-line-flag style flags.
    parser.add_argument('-s',
                        '--server-host',
                        '--server_host',
                        dest='server_host',
                        type=six.text_type,
                        default='localhost',
                        help='server host')
    parser.add_argument('-p',
                        '--server-port',
                        '--server_port',
                        dest='server_port',
                        type=int,
                        default=_UNDEFINED_PORT,
                        help='server port')
    parser.add_argument('-o',
                        '--origin',
                        dest='origin',
                        type=six.text_type,
                        default=None,
                        help='origin')
    parser.add_argument('-r',
                        '--resource',
                        dest='resource',
                        type=six.text_type,
                        default='/echo',
                        help='resource path')
    parser.add_argument(
        '-m',
        '--message',
        dest='message',
        type=six.text_type,
        default=u'Hello,<>',
        help=('comma-separated messages to send. '
              '%s will force close the connection from server.' %
              _GOODBYE_MESSAGE))
    parser.add_argument('-q',
                        '--quiet',
                        dest='verbose',
                        action='store_false',
                        default=True,
                        help='suppress messages')
    parser.add_argument('-t',
                        '--tls',
                        dest='use_tls',
                        action='store_true',
                        default=False,
                        help='use TLS (wss://).')
    parser.add_argument('-k',
                        '--socket-timeout',
                        '--socket_timeout',
                        dest='socket_timeout',
                        type=int,
                        default=_TIMEOUT_SEC,
                        help='Timeout(sec) for sockets')
    parser.add_argument('--use-permessage-deflate',
                        '--use_permessage_deflate',
                        dest='use_permessage_deflate',
                        action='store_true',
                        default=False,
                        help='Use the permessage-deflate extension.')
    parser.add_argument('--log-level',
                        '--log_level',
                        type=six.text_type,
                        dest='log_level',
                        default='warn',
                        choices=['debug', 'info', 'warn', 'error', 'critical'],
                        help='Log level.')

    options = parser.parse_args()

    logging.basicConfig(level=logging.getLevelName(options.log_level.upper()))

    # Default port number depends on whether TLS is used.
    if options.server_port == _UNDEFINED_PORT:
        if options.use_tls:
            options.server_port = common.DEFAULT_WEB_SOCKET_SECURE_PORT
        else:
            options.server_port = common.DEFAULT_WEB_SOCKET_PORT

    EchoClient(options).run()


if __name__ == '__main__':
    main()

# vi:sts=4 sw=4 et