test_msgutil.py (37829B)
#!/usr/bin/env python
#
# Copyright 2012, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 """Tests for msgutil module.""" 32 33 from __future__ import absolute_import 34 from __future__ import print_function 35 from __future__ import division 36 37 import random 38 import struct 39 import unittest 40 import zlib 41 42 from six import iterbytes 43 from six.moves import map 44 from six.moves import range 45 import six.moves.queue 46 47 import set_sys_path # Update sys.path to locate pywebsocket3 module. 48 from pywebsocket3 import common, msgutil, util 49 from pywebsocket3.extensions import PerMessageDeflateExtensionProcessor 50 from pywebsocket3.stream import ( 51 InvalidUTF8Exception, 52 Stream, 53 StreamOptions, 54 ) 55 from test import mock 56 57 # We use one fixed nonce for testing instead of cryptographically secure PRNG. 58 _MASKING_NONCE = b'ABCD' 59 60 61 def _mask_hybi(frame): 62 if isinstance(frame, six.text_type): 63 Exception('masking does not accept Texts') 64 65 frame_key = list(iterbytes(_MASKING_NONCE)) 66 frame_key_len = len(frame_key) 67 result = bytearray(frame) 68 count = 0 69 70 for i in range(len(result)): 71 result[i] ^= frame_key[count] 72 count = (count + 1) % frame_key_len 73 74 return _MASKING_NONCE + bytes(result) 75 76 77 def _install_extension_processor(processor, request, stream_options): 78 response = processor.get_extension_response() 79 if response is not None: 80 processor.setup_stream_options(stream_options) 81 request.ws_extension_processors.append(processor) 82 83 84 def _create_request_from_rawdata(read_data, permessage_deflate_request=None): 85 req = mock.MockRequest(connection=mock.MockConn(read_data)) 86 req.ws_version = common.VERSION_HYBI_LATEST 87 req.ws_extension_processors = [] 88 89 processor = None 90 if permessage_deflate_request is not None: 91 processor = PerMessageDeflateExtensionProcessor( 92 permessage_deflate_request) 93 94 stream_options = StreamOptions() 95 if processor is not None: 96 _install_extension_processor(processor, req, stream_options) 97 req.ws_stream = Stream(req, stream_options) 98 
99 return req 100 101 102 def _create_request(*frames): 103 """Creates MockRequest using data given as frames. 104 105 frames will be returned on calling request.connection.read() where request 106 is MockRequest returned by this function. 107 """ 108 109 read_data = [] 110 for (header, body) in frames: 111 read_data.append(header + _mask_hybi(body)) 112 113 return _create_request_from_rawdata(b''.join(read_data)) 114 115 116 def _create_blocking_request(): 117 """Creates MockRequest. 118 119 Data written to a MockRequest can be read out by calling 120 request.connection.written_data(). 121 """ 122 123 req = mock.MockRequest(connection=mock.MockBlockingConn()) 124 req.ws_version = common.VERSION_HYBI_LATEST 125 stream_options = StreamOptions() 126 req.ws_stream = Stream(req, stream_options) 127 return req 128 129 130 class BasicMessageTest(unittest.TestCase): 131 """Basic tests for Stream.""" 132 def test_send_message(self): 133 request = _create_request() 134 msgutil.send_message(request, 'Hello') 135 self.assertEqual(b'\x81\x05Hello', request.connection.written_data()) 136 137 payload = 'a' * 125 138 request = _create_request() 139 msgutil.send_message(request, payload) 140 self.assertEqual(b'\x81\x7d' + payload.encode('UTF-8'), 141 request.connection.written_data()) 142 143 def test_send_medium_message(self): 144 payload = 'a' * 126 145 request = _create_request() 146 msgutil.send_message(request, payload) 147 self.assertEqual(b'\x81\x7e\x00\x7e' + payload.encode('UTF-8'), 148 request.connection.written_data()) 149 150 payload = 'a' * ((1 << 16) - 1) 151 request = _create_request() 152 msgutil.send_message(request, payload) 153 self.assertEqual(b'\x81\x7e\xff\xff' + payload.encode('UTF-8'), 154 request.connection.written_data()) 155 156 def test_send_large_message(self): 157 payload = 'a' * (1 << 16) 158 request = _create_request() 159 msgutil.send_message(request, payload) 160 self.assertEqual( 161 b'\x81\x7f\x00\x00\x00\x00\x00\x01\x00\x00' + 162 
payload.encode('UTF-8'), request.connection.written_data()) 163 164 def test_send_message_unicode(self): 165 request = _create_request() 166 msgutil.send_message(request, u'\u65e5') 167 # U+65e5 is encoded as e6,97,a5 in UTF-8 168 self.assertEqual(b'\x81\x03\xe6\x97\xa5', 169 request.connection.written_data()) 170 171 def test_send_message_fragments(self): 172 request = _create_request() 173 msgutil.send_message(request, 'Hello', False) 174 msgutil.send_message(request, ' ', False) 175 msgutil.send_message(request, 'World', False) 176 msgutil.send_message(request, '!', True) 177 self.assertEqual(b'\x01\x05Hello\x00\x01 \x00\x05World\x80\x01!', 178 request.connection.written_data()) 179 180 def test_send_fragments_immediate_zero_termination(self): 181 request = _create_request() 182 msgutil.send_message(request, 'Hello World!', False) 183 msgutil.send_message(request, '', True) 184 self.assertEqual(b'\x01\x0cHello World!\x80\x00', 185 request.connection.written_data()) 186 187 def test_receive_message(self): 188 request = _create_request((b'\x81\x85', b'Hello'), 189 (b'\x81\x86', b'World!')) 190 self.assertEqual('Hello', msgutil.receive_message(request)) 191 self.assertEqual('World!', msgutil.receive_message(request)) 192 193 payload = b'a' * 125 194 request = _create_request((b'\x81\xfd', payload)) 195 self.assertEqual(payload.decode('UTF-8'), 196 msgutil.receive_message(request)) 197 198 def test_receive_medium_message(self): 199 payload = b'a' * 126 200 request = _create_request((b'\x81\xfe\x00\x7e', payload)) 201 self.assertEqual(payload.decode('UTF-8'), 202 msgutil.receive_message(request)) 203 204 payload = b'a' * ((1 << 16) - 1) 205 request = _create_request((b'\x81\xfe\xff\xff', payload)) 206 self.assertEqual(payload.decode('UTF-8'), 207 msgutil.receive_message(request)) 208 209 def test_receive_large_message(self): 210 payload = b'a' * (1 << 16) 211 request = _create_request( 212 (b'\x81\xff\x00\x00\x00\x00\x00\x01\x00\x00', payload)) 213 
self.assertEqual(payload.decode('UTF-8'), 214 msgutil.receive_message(request)) 215 216 def test_receive_length_not_encoded_using_minimal_number_of_bytes(self): 217 # Log warning on receiving bad payload length field that doesn't use 218 # minimal number of bytes but continue processing. 219 220 payload = b'a' 221 # 1 byte can be represented without extended payload length field. 222 request = _create_request( 223 (b'\x81\xff\x00\x00\x00\x00\x00\x00\x00\x01', payload)) 224 self.assertEqual(payload.decode('UTF-8'), 225 msgutil.receive_message(request)) 226 227 def test_receive_message_unicode(self): 228 request = _create_request((b'\x81\x83', b'\xe6\x9c\xac')) 229 # U+672c is encoded as e6,9c,ac in UTF-8 230 self.assertEqual(u'\u672c', msgutil.receive_message(request)) 231 232 def test_receive_message_erroneous_unicode(self): 233 # \x80 and \x81 are invalid as UTF-8. 234 request = _create_request((b'\x81\x82', b'\x80\x81')) 235 # Invalid characters should raise InvalidUTF8Exception 236 self.assertRaises(InvalidUTF8Exception, msgutil.receive_message, 237 request) 238 239 def test_receive_fragments(self): 240 request = _create_request((b'\x01\x85', b'Hello'), (b'\x00\x81', b' '), 241 (b'\x00\x85', b'World'), (b'\x80\x81', b'!')) 242 self.assertEqual('Hello World!', msgutil.receive_message(request)) 243 244 def test_receive_fragments_unicode(self): 245 # UTF-8 encodes U+6f22 into e6bca2 and U+5b57 into e5ad97. 
246 request = _create_request((b'\x01\x82', b'\xe6\xbc'), 247 (b'\x00\x82', b'\xa2\xe5'), 248 (b'\x80\x82', b'\xad\x97')) 249 self.assertEqual(u'\u6f22\u5b57', msgutil.receive_message(request)) 250 251 def test_receive_fragments_immediate_zero_termination(self): 252 request = _create_request((b'\x01\x8c', b'Hello World!'), 253 (b'\x80\x80', b'')) 254 self.assertEqual('Hello World!', msgutil.receive_message(request)) 255 256 def test_receive_fragments_duplicate_start(self): 257 request = _create_request((b'\x01\x85', b'Hello'), 258 (b'\x01\x85', b'World')) 259 self.assertRaises(msgutil.InvalidFrameException, 260 msgutil.receive_message, request) 261 262 def test_receive_fragments_intermediate_but_not_started(self): 263 request = _create_request((b'\x00\x85', b'Hello')) 264 self.assertRaises(msgutil.InvalidFrameException, 265 msgutil.receive_message, request) 266 267 def test_receive_fragments_end_but_not_started(self): 268 request = _create_request((b'\x80\x85', b'Hello')) 269 self.assertRaises(msgutil.InvalidFrameException, 270 msgutil.receive_message, request) 271 272 def test_receive_message_discard(self): 273 request = _create_request( 274 (b'\x8f\x86', b'IGNORE'), (b'\x81\x85', b'Hello'), 275 (b'\x8f\x89', b'DISREGARD'), (b'\x81\x86', b'World!')) 276 self.assertRaises(msgutil.UnsupportedFrameException, 277 msgutil.receive_message, request) 278 self.assertEqual('Hello', msgutil.receive_message(request)) 279 self.assertRaises(msgutil.UnsupportedFrameException, 280 msgutil.receive_message, request) 281 self.assertEqual('World!', msgutil.receive_message(request)) 282 283 def test_receive_close(self): 284 request = _create_request( 285 (b'\x88\x8a', struct.pack('!H', 1000) + b'Good bye')) 286 self.assertEqual(None, msgutil.receive_message(request)) 287 self.assertEqual(1000, request.ws_close_code) 288 self.assertEqual('Good bye', request.ws_close_reason) 289 290 def test_send_longest_close(self): 291 reason = 'a' * 123 292 request = _create_request( 293 
(b'\x88\xfd', struct.pack('!H', common.STATUS_NORMAL_CLOSURE) + 294 reason.encode('UTF-8'))) 295 request.ws_stream.close_connection(common.STATUS_NORMAL_CLOSURE, 296 reason) 297 self.assertEqual(request.ws_close_code, common.STATUS_NORMAL_CLOSURE) 298 self.assertEqual(request.ws_close_reason, reason) 299 300 def test_send_close_too_long(self): 301 request = _create_request() 302 self.assertRaises(msgutil.BadOperationException, 303 Stream.close_connection, request.ws_stream, 304 common.STATUS_NORMAL_CLOSURE, 'a' * 124) 305 306 def test_send_close_inconsistent_code_and_reason(self): 307 request = _create_request() 308 # reason parameter must not be specified when code is None. 309 self.assertRaises(msgutil.BadOperationException, 310 Stream.close_connection, request.ws_stream, None, 311 'a') 312 313 def test_send_ping(self): 314 request = _create_request() 315 msgutil.send_ping(request, 'Hello World!') 316 self.assertEqual(b'\x89\x0cHello World!', 317 request.connection.written_data()) 318 319 def test_send_longest_ping(self): 320 request = _create_request() 321 msgutil.send_ping(request, 'a' * 125) 322 self.assertEqual(b'\x89\x7d' + b'a' * 125, 323 request.connection.written_data()) 324 325 def test_send_ping_too_long(self): 326 request = _create_request() 327 self.assertRaises(msgutil.BadOperationException, msgutil.send_ping, 328 request, 'a' * 126) 329 330 def test_receive_ping(self): 331 """Tests receiving a ping control frame.""" 332 def handler(request, message): 333 request.called = True 334 335 # Stream automatically respond to ping with pong without any action 336 # by application layer. 
337 request = _create_request((b'\x89\x85', b'Hello'), 338 (b'\x81\x85', b'World')) 339 self.assertEqual('World', msgutil.receive_message(request)) 340 self.assertEqual(b'\x8a\x05Hello', request.connection.written_data()) 341 342 request = _create_request((b'\x89\x85', b'Hello'), 343 (b'\x81\x85', b'World')) 344 request.on_ping_handler = handler 345 self.assertEqual('World', msgutil.receive_message(request)) 346 self.assertTrue(request.called) 347 348 def test_receive_longest_ping(self): 349 request = _create_request((b'\x89\xfd', b'a' * 125), 350 (b'\x81\x85', b'World')) 351 self.assertEqual('World', msgutil.receive_message(request)) 352 self.assertEqual(b'\x8a\x7d' + b'a' * 125, 353 request.connection.written_data()) 354 355 def test_receive_ping_too_long(self): 356 request = _create_request((b'\x89\xfe\x00\x7e', b'a' * 126)) 357 self.assertRaises(msgutil.InvalidFrameException, 358 msgutil.receive_message, request) 359 360 def test_receive_pong(self): 361 """Tests receiving a pong control frame.""" 362 def handler(request, message): 363 request.called = True 364 365 request = _create_request((b'\x8a\x85', b'Hello'), 366 (b'\x81\x85', b'World')) 367 request.on_pong_handler = handler 368 msgutil.send_ping(request, 'Hello') 369 self.assertEqual(b'\x89\x05Hello', request.connection.written_data()) 370 # Valid pong is received, but receive_message won't return for it. 371 self.assertEqual('World', msgutil.receive_message(request)) 372 # Check that nothing was written after receive_message call. 373 self.assertEqual(b'\x89\x05Hello', request.connection.written_data()) 374 375 self.assertTrue(request.called) 376 377 def test_receive_unsolicited_pong(self): 378 # Unsolicited pong is allowed from HyBi 07. 379 request = _create_request((b'\x8a\x85', b'Hello'), 380 (b'\x81\x85', b'World')) 381 msgutil.receive_message(request) 382 383 request = _create_request((b'\x8a\x85', b'Hello'), 384 (b'\x81\x85', b'World')) 385 msgutil.send_ping(request, 'Jumbo') 386 # Body mismatch. 
387 msgutil.receive_message(request) 388 389 def test_ping_cannot_be_fragmented(self): 390 request = _create_request((b'\x09\x85', b'Hello')) 391 self.assertRaises(msgutil.InvalidFrameException, 392 msgutil.receive_message, request) 393 394 def test_ping_with_too_long_payload(self): 395 request = _create_request((b'\x89\xfe\x01\x00', b'a' * 256)) 396 self.assertRaises(msgutil.InvalidFrameException, 397 msgutil.receive_message, request) 398 399 400 class PerMessageDeflateTest(unittest.TestCase): 401 """Tests for permessage-deflate extension.""" 402 def test_response_parameters(self): 403 extension = common.ExtensionParameter( 404 common.PERMESSAGE_DEFLATE_EXTENSION) 405 extension.add_parameter('server_no_context_takeover', None) 406 processor = PerMessageDeflateExtensionProcessor(extension) 407 response = processor.get_extension_response() 408 self.assertTrue(response.has_parameter('server_no_context_takeover')) 409 self.assertEqual( 410 None, response.get_parameter_value('server_no_context_takeover')) 411 412 extension = common.ExtensionParameter( 413 common.PERMESSAGE_DEFLATE_EXTENSION) 414 extension.add_parameter('client_max_window_bits', None) 415 processor = PerMessageDeflateExtensionProcessor(extension) 416 417 processor.set_client_max_window_bits(8) 418 processor.set_client_no_context_takeover(True) 419 response = processor.get_extension_response() 420 self.assertEqual( 421 '8', response.get_parameter_value('client_max_window_bits')) 422 self.assertTrue(response.has_parameter('client_no_context_takeover')) 423 self.assertEqual( 424 None, response.get_parameter_value('client_no_context_takeover')) 425 426 def test_send_message(self): 427 extension = common.ExtensionParameter( 428 common.PERMESSAGE_DEFLATE_EXTENSION) 429 request = _create_request_from_rawdata( 430 b'', permessage_deflate_request=extension) 431 msgutil.send_message(request, 'Hello') 432 433 compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, 434 -zlib.MAX_WBITS) 435 
compressed_hello = compress.compress(b'Hello') 436 compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH) 437 compressed_hello = compressed_hello[:-4] 438 expected = b'\xc1%c' % len(compressed_hello) 439 expected += compressed_hello 440 self.assertEqual(expected, request.connection.written_data()) 441 442 def test_send_empty_message(self): 443 """Test that an empty message is compressed correctly.""" 444 445 extension = common.ExtensionParameter( 446 common.PERMESSAGE_DEFLATE_EXTENSION) 447 request = _create_request_from_rawdata( 448 b'', permessage_deflate_request=extension) 449 450 msgutil.send_message(request, '') 451 452 # Payload in binary: 0b00000000 453 # From LSB, 454 # - 1 bit of BFINAL (0) 455 # - 2 bits of BTYPE (no compression) 456 # - 5 bits of padding 457 self.assertEqual(b'\xc1\x01\x00', request.connection.written_data()) 458 459 def test_send_message_with_null_character(self): 460 """Test that a simple payload (one null) is framed correctly.""" 461 462 extension = common.ExtensionParameter( 463 common.PERMESSAGE_DEFLATE_EXTENSION) 464 request = _create_request_from_rawdata( 465 b'', permessage_deflate_request=extension) 466 467 msgutil.send_message(request, '\x00') 468 469 # Payload in binary: 0b01100010 0b00000000 0b00000000 470 # From LSB, 471 # - 1 bit of BFINAL (0) 472 # - 2 bits of BTYPE (01 that means fixed Huffman) 473 # - 8 bits of the first code (00110000 that is the code for the literal 474 # alphabet 0x00) 475 # - 7 bits of the second code (0000000 that is the code for the 476 # end-of-block) 477 # - 1 bit of BFINAL (0) 478 # - 2 bits of BTYPE (no compression) 479 # - 2 bits of padding 480 self.assertEqual(b'\xc1\x03\x62\x00\x00', 481 request.connection.written_data()) 482 483 def test_send_two_messages(self): 484 extension = common.ExtensionParameter( 485 common.PERMESSAGE_DEFLATE_EXTENSION) 486 request = _create_request_from_rawdata( 487 b'', permessage_deflate_request=extension) 488 msgutil.send_message(request, 'Hello') 489 
msgutil.send_message(request, 'World') 490 491 compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, 492 -zlib.MAX_WBITS) 493 494 expected = b'' 495 496 compressed_hello = compress.compress(b'Hello') 497 compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH) 498 compressed_hello = compressed_hello[:-4] 499 expected += b'\xc1%c' % len(compressed_hello) 500 expected += compressed_hello 501 502 compressed_world = compress.compress(b'World') 503 compressed_world += compress.flush(zlib.Z_SYNC_FLUSH) 504 compressed_world = compressed_world[:-4] 505 expected += b'\xc1%c' % len(compressed_world) 506 expected += compressed_world 507 508 self.assertEqual(expected, request.connection.written_data()) 509 510 def test_send_message_fragmented(self): 511 extension = common.ExtensionParameter( 512 common.PERMESSAGE_DEFLATE_EXTENSION) 513 request = _create_request_from_rawdata( 514 b'', permessage_deflate_request=extension) 515 msgutil.send_message(request, 'Hello', end=False) 516 msgutil.send_message(request, 'Goodbye', end=False) 517 msgutil.send_message(request, 'World') 518 519 compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, 520 -zlib.MAX_WBITS) 521 compressed_hello = compress.compress(b'Hello') 522 compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH) 523 expected = b'\x41%c' % len(compressed_hello) 524 expected += compressed_hello 525 compressed_goodbye = compress.compress(b'Goodbye') 526 compressed_goodbye += compress.flush(zlib.Z_SYNC_FLUSH) 527 expected += b'\x00%c' % len(compressed_goodbye) 528 expected += compressed_goodbye 529 compressed_world = compress.compress(b'World') 530 compressed_world += compress.flush(zlib.Z_SYNC_FLUSH) 531 compressed_world = compressed_world[:-4] 532 expected += b'\x80%c' % len(compressed_world) 533 expected += compressed_world 534 self.assertEqual(expected, request.connection.written_data()) 535 536 def test_send_message_fragmented_empty_first_frame(self): 537 extension = common.ExtensionParameter( 538 
common.PERMESSAGE_DEFLATE_EXTENSION) 539 request = _create_request_from_rawdata( 540 b'', permessage_deflate_request=extension) 541 msgutil.send_message(request, '', end=False) 542 msgutil.send_message(request, 'Hello') 543 544 compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, 545 -zlib.MAX_WBITS) 546 compressed_hello = compress.compress(b'') 547 compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH) 548 expected = b'\x41%c' % len(compressed_hello) 549 expected += compressed_hello 550 compressed_empty = compress.compress(b'Hello') 551 compressed_empty += compress.flush(zlib.Z_SYNC_FLUSH) 552 compressed_empty = compressed_empty[:-4] 553 expected += b'\x80%c' % len(compressed_empty) 554 expected += compressed_empty 555 self.assertEqual(expected, request.connection.written_data()) 556 557 def test_send_message_fragmented_empty_last_frame(self): 558 extension = common.ExtensionParameter( 559 common.PERMESSAGE_DEFLATE_EXTENSION) 560 request = _create_request_from_rawdata( 561 b'', permessage_deflate_request=extension) 562 msgutil.send_message(request, 'Hello', end=False) 563 msgutil.send_message(request, '') 564 565 compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, 566 -zlib.MAX_WBITS) 567 compressed_hello = compress.compress(b'Hello') 568 compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH) 569 expected = b'\x41%c' % len(compressed_hello) 570 expected += compressed_hello 571 compressed_empty = compress.compress(b'') 572 compressed_empty += compress.flush(zlib.Z_SYNC_FLUSH) 573 compressed_empty = compressed_empty[:-4] 574 expected += b'\x80%c' % len(compressed_empty) 575 expected += compressed_empty 576 self.assertEqual(expected, request.connection.written_data()) 577 578 def test_send_message_using_small_window(self): 579 common_part = 'abcdefghijklmnopqrstuvwxyz' 580 test_message = common_part + '-' * 30000 + common_part 581 582 extension = common.ExtensionParameter( 583 common.PERMESSAGE_DEFLATE_EXTENSION) 584 
extension.add_parameter('server_max_window_bits', '8') 585 request = _create_request_from_rawdata( 586 b'', permessage_deflate_request=extension) 587 msgutil.send_message(request, test_message) 588 589 expected_websocket_header_size = 2 590 expected_websocket_payload_size = 91 591 592 actual_frame = request.connection.written_data() 593 self.assertEqual( 594 expected_websocket_header_size + expected_websocket_payload_size, 595 len(actual_frame)) 596 actual_header = actual_frame[0:expected_websocket_header_size] 597 actual_payload = actual_frame[expected_websocket_header_size:] 598 599 self.assertEqual(b'\xc1%c' % expected_websocket_payload_size, 600 actual_header) 601 decompress = zlib.decompressobj(-8) 602 decompressed_message = decompress.decompress(actual_payload + 603 b'\x00\x00\xff\xff') 604 decompressed_message += decompress.flush() 605 self.assertEqual(test_message, decompressed_message.decode('UTF-8')) 606 self.assertEqual(0, len(decompress.unused_data)) 607 self.assertEqual(0, len(decompress.unconsumed_tail)) 608 609 def test_send_message_no_context_takeover_parameter(self): 610 extension = common.ExtensionParameter( 611 common.PERMESSAGE_DEFLATE_EXTENSION) 612 extension.add_parameter('server_no_context_takeover', None) 613 request = _create_request_from_rawdata( 614 b'', permessage_deflate_request=extension) 615 for i in range(3): 616 msgutil.send_message(request, 'Hello', end=False) 617 msgutil.send_message(request, 'Hello', end=True) 618 619 compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, 620 -zlib.MAX_WBITS) 621 622 first_hello = compress.compress(b'Hello') 623 first_hello += compress.flush(zlib.Z_SYNC_FLUSH) 624 expected = b'\x41%c' % len(first_hello) 625 expected += first_hello 626 second_hello = compress.compress(b'Hello') 627 second_hello += compress.flush(zlib.Z_SYNC_FLUSH) 628 second_hello = second_hello[:-4] 629 expected += b'\x80%c' % len(second_hello) 630 expected += second_hello 631 632 self.assertEqual(expected + 
expected + expected, 633 request.connection.written_data()) 634 635 def test_send_message_fragmented_bfinal(self): 636 extension = common.ExtensionParameter( 637 common.PERMESSAGE_DEFLATE_EXTENSION) 638 request = _create_request_from_rawdata( 639 b'', permessage_deflate_request=extension) 640 self.assertEqual(1, len(request.ws_extension_processors)) 641 request.ws_extension_processors[0].set_bfinal(True) 642 msgutil.send_message(request, 'Hello', end=False) 643 msgutil.send_message(request, 'World', end=True) 644 645 expected = b'' 646 647 compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, 648 -zlib.MAX_WBITS) 649 compressed_hello = compress.compress(b'Hello') 650 compressed_hello += compress.flush(zlib.Z_FINISH) 651 compressed_hello = compressed_hello + struct.pack('!B', 0) 652 expected += b'\x41%c' % len(compressed_hello) 653 expected += compressed_hello 654 655 compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, 656 -zlib.MAX_WBITS) 657 compressed_world = compress.compress(b'World') 658 compressed_world += compress.flush(zlib.Z_FINISH) 659 compressed_world = compressed_world + struct.pack('!B', 0) 660 expected += b'\x80%c' % len(compressed_world) 661 expected += compressed_world 662 663 self.assertEqual(expected, request.connection.written_data()) 664 665 def test_receive_message_deflate(self): 666 compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, 667 -zlib.MAX_WBITS) 668 669 compressed_hello = compress.compress(b'Hello') 670 compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH) 671 compressed_hello = compressed_hello[:-4] 672 data = b'\xc1%c' % (len(compressed_hello) | 0x80) 673 data += _mask_hybi(compressed_hello) 674 675 # Close frame 676 data += b'\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + b'Good bye') 677 678 extension = common.ExtensionParameter( 679 common.PERMESSAGE_DEFLATE_EXTENSION) 680 request = _create_request_from_rawdata( 681 data, permessage_deflate_request=extension) 682 
self.assertEqual('Hello', msgutil.receive_message(request)) 683 684 self.assertEqual(None, msgutil.receive_message(request)) 685 686 def test_receive_message_random_section(self): 687 """Test that a compressed message fragmented into lots of chunks is 688 correctly received. 689 """ 690 691 random.seed(a=0) 692 payload = b''.join( 693 [struct.pack('!B', random.randint(0, 255)) for i in range(1000)]) 694 695 compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, 696 -zlib.MAX_WBITS) 697 compressed_payload = compress.compress(payload) 698 compressed_payload += compress.flush(zlib.Z_SYNC_FLUSH) 699 compressed_payload = compressed_payload[:-4] 700 701 # Fragment the compressed payload into lots of frames. 702 bytes_chunked = 0 703 data = b'' 704 frame_count = 0 705 706 chunk_sizes = [] 707 708 while bytes_chunked < len(compressed_payload): 709 # Make sure that 710 # - the length of chunks are equal or less than 125 so that we can 711 # use 1 octet length header format for all frames. 712 # - at least 10 chunks are created. 
713 chunk_size = random.randint( 714 1, 715 min(125, 716 len(compressed_payload) // 10, 717 len(compressed_payload) - bytes_chunked)) 718 chunk_sizes.append(chunk_size) 719 chunk = compressed_payload[bytes_chunked:bytes_chunked + 720 chunk_size] 721 bytes_chunked += chunk_size 722 723 first_octet = 0x00 724 if len(data) == 0: 725 first_octet = first_octet | 0x42 726 if bytes_chunked == len(compressed_payload): 727 first_octet = first_octet | 0x80 728 729 data += b'%c%c' % (first_octet, chunk_size | 0x80) 730 data += _mask_hybi(chunk) 731 732 frame_count += 1 733 734 self.assertTrue(len(chunk_sizes) > 10) 735 736 # Close frame 737 data += b'\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + b'Good bye') 738 739 extension = common.ExtensionParameter( 740 common.PERMESSAGE_DEFLATE_EXTENSION) 741 request = _create_request_from_rawdata( 742 data, permessage_deflate_request=extension) 743 self.assertEqual(payload, msgutil.receive_message(request)) 744 745 self.assertEqual(None, msgutil.receive_message(request)) 746 747 def test_receive_two_messages(self): 748 compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, 749 -zlib.MAX_WBITS) 750 751 data = b'' 752 753 compressed_hello = compress.compress(b'HelloWebSocket') 754 compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH) 755 compressed_hello = compressed_hello[:-4] 756 split_position = len(compressed_hello) // 2 757 data += b'\x41%c' % (split_position | 0x80) 758 data += _mask_hybi(compressed_hello[:split_position]) 759 760 data += b'\x80%c' % ((len(compressed_hello) - split_position) | 0x80) 761 data += _mask_hybi(compressed_hello[split_position:]) 762 763 compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, 764 -zlib.MAX_WBITS) 765 766 compressed_world = compress.compress(b'World') 767 compressed_world += compress.flush(zlib.Z_SYNC_FLUSH) 768 compressed_world = compressed_world[:-4] 769 data += b'\xc1%c' % (len(compressed_world) | 0x80) 770 data += _mask_hybi(compressed_world) 771 772 # 
Close frame 773 data += b'\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + b'Good bye') 774 775 extension = common.ExtensionParameter( 776 common.PERMESSAGE_DEFLATE_EXTENSION) 777 request = _create_request_from_rawdata( 778 data, permessage_deflate_request=extension) 779 self.assertEqual('HelloWebSocket', msgutil.receive_message(request)) 780 self.assertEqual('World', msgutil.receive_message(request)) 781 782 self.assertEqual(None, msgutil.receive_message(request)) 783 784 def test_receive_message_mixed_btype(self): 785 """Test that a message compressed using lots of DEFLATE blocks with 786 various flush mode is correctly received. 787 """ 788 789 random.seed(a=0) 790 payload = b''.join( 791 [struct.pack('!B', random.randint(0, 255)) for i in range(1000)]) 792 793 compress = None 794 795 # Fragment the compressed payload into lots of frames. 796 bytes_chunked = 0 797 compressed_payload = b'' 798 799 chunk_sizes = [] 800 methods = [] 801 sync_used = False 802 finish_used = False 803 804 while bytes_chunked < len(payload): 805 # Make sure at least 10 chunks are created. 
806 chunk_size = random.randint(1, 807 min(100, 808 len(payload) - bytes_chunked)) 809 chunk_sizes.append(chunk_size) 810 chunk = payload[bytes_chunked:bytes_chunked + chunk_size] 811 812 bytes_chunked += chunk_size 813 814 if compress is None: 815 compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, 816 zlib.DEFLATED, -zlib.MAX_WBITS) 817 818 if bytes_chunked == len(payload): 819 compressed_payload += compress.compress(chunk) 820 compressed_payload += compress.flush(zlib.Z_SYNC_FLUSH) 821 compressed_payload = compressed_payload[:-4] 822 else: 823 method = random.randint(0, 1) 824 methods.append(method) 825 if method == 0: 826 compressed_payload += compress.compress(chunk) 827 compressed_payload += compress.flush(zlib.Z_SYNC_FLUSH) 828 sync_used = True 829 else: 830 compressed_payload += compress.compress(chunk) 831 compressed_payload += compress.flush(zlib.Z_FINISH) 832 compress = None 833 finish_used = True 834 835 self.assertTrue(len(chunk_sizes) > 10) 836 self.assertTrue(sync_used) 837 self.assertTrue(finish_used) 838 839 self.assertTrue(125 < len(compressed_payload)) 840 self.assertTrue(len(compressed_payload) < 65536) 841 data = b'\xc2\xfe' + struct.pack('!H', len(compressed_payload)) 842 data += _mask_hybi(compressed_payload) 843 844 # Close frame 845 data += b'\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + b'Good bye') 846 847 extension = common.ExtensionParameter( 848 common.PERMESSAGE_DEFLATE_EXTENSION) 849 request = _create_request_from_rawdata( 850 data, permessage_deflate_request=extension) 851 self.assertEqual(payload, msgutil.receive_message(request)) 852 853 self.assertEqual(None, msgutil.receive_message(request)) 854 855 856 class MessageReceiverTest(unittest.TestCase): 857 """Tests the Stream class using MessageReceiver.""" 858 def test_queue(self): 859 request = _create_blocking_request() 860 receiver = msgutil.MessageReceiver(request) 861 862 self.assertEqual(None, receiver.receive_nowait()) 863 864 request.connection.put_bytes(b'\x81\x86' 
+ _mask_hybi(b'Hello!')) 865 self.assertEqual('Hello!', receiver.receive()) 866 867 def test_onmessage(self): 868 onmessage_queue = six.moves.queue.Queue() 869 870 def onmessage_handler(message): 871 onmessage_queue.put(message) 872 873 request = _create_blocking_request() 874 receiver = msgutil.MessageReceiver(request, onmessage_handler) 875 876 request.connection.put_bytes(b'\x81\x86' + _mask_hybi(b'Hello!')) 877 self.assertEqual('Hello!', onmessage_queue.get()) 878 879 880 class MessageSenderTest(unittest.TestCase): 881 """Tests the Stream class using MessageSender.""" 882 def test_send(self): 883 request = _create_blocking_request() 884 sender = msgutil.MessageSender(request) 885 886 sender.send('World') 887 self.assertEqual(b'\x81\x05World', request.connection.written_data()) 888 889 def test_send_nowait(self): 890 # Use a queue to check the bytes written by MessageSender. 891 # request.connection.written_data() cannot be used here because 892 # MessageSender runs in a separate thread. 893 send_queue = six.moves.queue.Queue() 894 895 def write(bytes): 896 send_queue.put(bytes) 897 898 request = _create_blocking_request() 899 request.connection.write = write 900 901 sender = msgutil.MessageSender(request) 902 903 sender.send_nowait('Hello') 904 sender.send_nowait('World') 905 self.assertEqual(b'\x81\x05Hello', send_queue.get()) 906 self.assertEqual(b'\x81\x05World', send_queue.get()) 907 908 909 if __name__ == '__main__': 910 unittest.main() 911 912 # vi:sts=4 sw=4 et