tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

test_msgutil.py (37829B)


      1 #!/usr/bin/env python
      2 #
      3 # Copyright 2012, Google Inc.
      4 # All rights reserved.
      5 #
      6 # Redistribution and use in source and binary forms, with or without
      7 # modification, are permitted provided that the following conditions are
      8 # met:
      9 #
     10 #     * Redistributions of source code must retain the above copyright
     11 # notice, this list of conditions and the following disclaimer.
     12 #     * Redistributions in binary form must reproduce the above
     13 # copyright notice, this list of conditions and the following disclaimer
     14 # in the documentation and/or other materials provided with the
     15 # distribution.
     16 #     * Neither the name of Google Inc. nor the names of its
     17 # contributors may be used to endorse or promote products derived from
     18 # this software without specific prior written permission.
     19 #
     20 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     21 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     22 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     23 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     24 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     25 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     26 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     27 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     28 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     29 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     30 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     31 """Tests for msgutil module."""
     32 
     33 from __future__ import absolute_import
     34 from __future__ import print_function
     35 from __future__ import division
     36 
     37 import random
     38 import struct
     39 import unittest
     40 import zlib
     41 
     42 from six import iterbytes
     43 from six.moves import map
     44 from six.moves import range
     45 import six.moves.queue
     46 
     47 import set_sys_path  # Update sys.path to locate pywebsocket3 module.
     48 from pywebsocket3 import common, msgutil, util
     49 from pywebsocket3.extensions import PerMessageDeflateExtensionProcessor
     50 from pywebsocket3.stream import (
     51    InvalidUTF8Exception,
     52    Stream,
     53    StreamOptions,
     54 )
     55 from test import mock
     56 
# Fixed 4-byte masking key prepended and applied by _mask_hybi. Tests use
# this one constant nonce, instead of a cryptographically secure PRNG, so that
# masked client frames are deterministic.
_MASKING_NONCE = b'ABCD'
     59 
     60 
     61 def _mask_hybi(frame):
     62    if isinstance(frame, six.text_type):
     63        Exception('masking does not accept Texts')
     64 
     65    frame_key = list(iterbytes(_MASKING_NONCE))
     66    frame_key_len = len(frame_key)
     67    result = bytearray(frame)
     68    count = 0
     69 
     70    for i in range(len(result)):
     71        result[i] ^= frame_key[count]
     72        count = (count + 1) % frame_key_len
     73 
     74    return _MASKING_NONCE + bytes(result)
     75 
     76 
     77 def _install_extension_processor(processor, request, stream_options):
     78    response = processor.get_extension_response()
     79    if response is not None:
     80        processor.setup_stream_options(stream_options)
     81        request.ws_extension_processors.append(processor)
     82 
     83 
     84 def _create_request_from_rawdata(read_data, permessage_deflate_request=None):
     85    req = mock.MockRequest(connection=mock.MockConn(read_data))
     86    req.ws_version = common.VERSION_HYBI_LATEST
     87    req.ws_extension_processors = []
     88 
     89    processor = None
     90    if permessage_deflate_request is not None:
     91        processor = PerMessageDeflateExtensionProcessor(
     92            permessage_deflate_request)
     93 
     94    stream_options = StreamOptions()
     95    if processor is not None:
     96        _install_extension_processor(processor, req, stream_options)
     97    req.ws_stream = Stream(req, stream_options)
     98 
     99    return req
    100 
    101 
    102 def _create_request(*frames):
    103    """Creates MockRequest using data given as frames.
    104 
    105    frames will be returned on calling request.connection.read() where request
    106    is MockRequest returned by this function.
    107    """
    108 
    109    read_data = []
    110    for (header, body) in frames:
    111        read_data.append(header + _mask_hybi(body))
    112 
    113    return _create_request_from_rawdata(b''.join(read_data))
    114 
    115 
    116 def _create_blocking_request():
    117    """Creates MockRequest.
    118 
    119    Data written to a MockRequest can be read out by calling
    120    request.connection.written_data().
    121    """
    122 
    123    req = mock.MockRequest(connection=mock.MockBlockingConn())
    124    req.ws_version = common.VERSION_HYBI_LATEST
    125    stream_options = StreamOptions()
    126    req.ws_stream = Stream(req, stream_options)
    127    return req
    128 
    129 
    130 class BasicMessageTest(unittest.TestCase):
    131    """Basic tests for Stream."""
    132    def test_send_message(self):
    133        request = _create_request()
    134        msgutil.send_message(request, 'Hello')
    135        self.assertEqual(b'\x81\x05Hello', request.connection.written_data())
    136 
    137        payload = 'a' * 125
    138        request = _create_request()
    139        msgutil.send_message(request, payload)
    140        self.assertEqual(b'\x81\x7d' + payload.encode('UTF-8'),
    141                         request.connection.written_data())
    142 
    143    def test_send_medium_message(self):
    144        payload = 'a' * 126
    145        request = _create_request()
    146        msgutil.send_message(request, payload)
    147        self.assertEqual(b'\x81\x7e\x00\x7e' + payload.encode('UTF-8'),
    148                         request.connection.written_data())
    149 
    150        payload = 'a' * ((1 << 16) - 1)
    151        request = _create_request()
    152        msgutil.send_message(request, payload)
    153        self.assertEqual(b'\x81\x7e\xff\xff' + payload.encode('UTF-8'),
    154                         request.connection.written_data())
    155 
    156    def test_send_large_message(self):
    157        payload = 'a' * (1 << 16)
    158        request = _create_request()
    159        msgutil.send_message(request, payload)
    160        self.assertEqual(
    161            b'\x81\x7f\x00\x00\x00\x00\x00\x01\x00\x00' +
    162            payload.encode('UTF-8'), request.connection.written_data())
    163 
    164    def test_send_message_unicode(self):
    165        request = _create_request()
    166        msgutil.send_message(request, u'\u65e5')
    167        # U+65e5 is encoded as e6,97,a5 in UTF-8
    168        self.assertEqual(b'\x81\x03\xe6\x97\xa5',
    169                         request.connection.written_data())
    170 
    171    def test_send_message_fragments(self):
    172        request = _create_request()
    173        msgutil.send_message(request, 'Hello', False)
    174        msgutil.send_message(request, ' ', False)
    175        msgutil.send_message(request, 'World', False)
    176        msgutil.send_message(request, '!', True)
    177        self.assertEqual(b'\x01\x05Hello\x00\x01 \x00\x05World\x80\x01!',
    178                         request.connection.written_data())
    179 
    180    def test_send_fragments_immediate_zero_termination(self):
    181        request = _create_request()
    182        msgutil.send_message(request, 'Hello World!', False)
    183        msgutil.send_message(request, '', True)
    184        self.assertEqual(b'\x01\x0cHello World!\x80\x00',
    185                         request.connection.written_data())
    186 
    187    def test_receive_message(self):
    188        request = _create_request((b'\x81\x85', b'Hello'),
    189                                  (b'\x81\x86', b'World!'))
    190        self.assertEqual('Hello', msgutil.receive_message(request))
    191        self.assertEqual('World!', msgutil.receive_message(request))
    192 
    193        payload = b'a' * 125
    194        request = _create_request((b'\x81\xfd', payload))
    195        self.assertEqual(payload.decode('UTF-8'),
    196                         msgutil.receive_message(request))
    197 
    198    def test_receive_medium_message(self):
    199        payload = b'a' * 126
    200        request = _create_request((b'\x81\xfe\x00\x7e', payload))
    201        self.assertEqual(payload.decode('UTF-8'),
    202                         msgutil.receive_message(request))
    203 
    204        payload = b'a' * ((1 << 16) - 1)
    205        request = _create_request((b'\x81\xfe\xff\xff', payload))
    206        self.assertEqual(payload.decode('UTF-8'),
    207                         msgutil.receive_message(request))
    208 
    209    def test_receive_large_message(self):
    210        payload = b'a' * (1 << 16)
    211        request = _create_request(
    212            (b'\x81\xff\x00\x00\x00\x00\x00\x01\x00\x00', payload))
    213        self.assertEqual(payload.decode('UTF-8'),
    214                         msgutil.receive_message(request))
    215 
    216    def test_receive_length_not_encoded_using_minimal_number_of_bytes(self):
    217        # Log warning on receiving bad payload length field that doesn't use
    218        # minimal number of bytes but continue processing.
    219 
    220        payload = b'a'
    221        # 1 byte can be represented without extended payload length field.
    222        request = _create_request(
    223            (b'\x81\xff\x00\x00\x00\x00\x00\x00\x00\x01', payload))
    224        self.assertEqual(payload.decode('UTF-8'),
    225                         msgutil.receive_message(request))
    226 
    227    def test_receive_message_unicode(self):
    228        request = _create_request((b'\x81\x83', b'\xe6\x9c\xac'))
    229        # U+672c is encoded as e6,9c,ac in UTF-8
    230        self.assertEqual(u'\u672c', msgutil.receive_message(request))
    231 
    232    def test_receive_message_erroneous_unicode(self):
    233        # \x80 and \x81 are invalid as UTF-8.
    234        request = _create_request((b'\x81\x82', b'\x80\x81'))
    235        # Invalid characters should raise InvalidUTF8Exception
    236        self.assertRaises(InvalidUTF8Exception, msgutil.receive_message,
    237                          request)
    238 
    239    def test_receive_fragments(self):
    240        request = _create_request((b'\x01\x85', b'Hello'), (b'\x00\x81', b' '),
    241                                  (b'\x00\x85', b'World'), (b'\x80\x81', b'!'))
    242        self.assertEqual('Hello World!', msgutil.receive_message(request))
    243 
    244    def test_receive_fragments_unicode(self):
    245        # UTF-8 encodes U+6f22 into e6bca2 and U+5b57 into e5ad97.
    246        request = _create_request((b'\x01\x82', b'\xe6\xbc'),
    247                                  (b'\x00\x82', b'\xa2\xe5'),
    248                                  (b'\x80\x82', b'\xad\x97'))
    249        self.assertEqual(u'\u6f22\u5b57', msgutil.receive_message(request))
    250 
    251    def test_receive_fragments_immediate_zero_termination(self):
    252        request = _create_request((b'\x01\x8c', b'Hello World!'),
    253                                  (b'\x80\x80', b''))
    254        self.assertEqual('Hello World!', msgutil.receive_message(request))
    255 
    256    def test_receive_fragments_duplicate_start(self):
    257        request = _create_request((b'\x01\x85', b'Hello'),
    258                                  (b'\x01\x85', b'World'))
    259        self.assertRaises(msgutil.InvalidFrameException,
    260                          msgutil.receive_message, request)
    261 
    262    def test_receive_fragments_intermediate_but_not_started(self):
    263        request = _create_request((b'\x00\x85', b'Hello'))
    264        self.assertRaises(msgutil.InvalidFrameException,
    265                          msgutil.receive_message, request)
    266 
    267    def test_receive_fragments_end_but_not_started(self):
    268        request = _create_request((b'\x80\x85', b'Hello'))
    269        self.assertRaises(msgutil.InvalidFrameException,
    270                          msgutil.receive_message, request)
    271 
    272    def test_receive_message_discard(self):
    273        request = _create_request(
    274            (b'\x8f\x86', b'IGNORE'), (b'\x81\x85', b'Hello'),
    275            (b'\x8f\x89', b'DISREGARD'), (b'\x81\x86', b'World!'))
    276        self.assertRaises(msgutil.UnsupportedFrameException,
    277                          msgutil.receive_message, request)
    278        self.assertEqual('Hello', msgutil.receive_message(request))
    279        self.assertRaises(msgutil.UnsupportedFrameException,
    280                          msgutil.receive_message, request)
    281        self.assertEqual('World!', msgutil.receive_message(request))
    282 
    283    def test_receive_close(self):
    284        request = _create_request(
    285            (b'\x88\x8a', struct.pack('!H', 1000) + b'Good bye'))
    286        self.assertEqual(None, msgutil.receive_message(request))
    287        self.assertEqual(1000, request.ws_close_code)
    288        self.assertEqual('Good bye', request.ws_close_reason)
    289 
    290    def test_send_longest_close(self):
    291        reason = 'a' * 123
    292        request = _create_request(
    293            (b'\x88\xfd', struct.pack('!H', common.STATUS_NORMAL_CLOSURE) +
    294             reason.encode('UTF-8')))
    295        request.ws_stream.close_connection(common.STATUS_NORMAL_CLOSURE,
    296                                           reason)
    297        self.assertEqual(request.ws_close_code, common.STATUS_NORMAL_CLOSURE)
    298        self.assertEqual(request.ws_close_reason, reason)
    299 
    300    def test_send_close_too_long(self):
    301        request = _create_request()
    302        self.assertRaises(msgutil.BadOperationException,
    303                          Stream.close_connection, request.ws_stream,
    304                          common.STATUS_NORMAL_CLOSURE, 'a' * 124)
    305 
    306    def test_send_close_inconsistent_code_and_reason(self):
    307        request = _create_request()
    308        # reason parameter must not be specified when code is None.
    309        self.assertRaises(msgutil.BadOperationException,
    310                          Stream.close_connection, request.ws_stream, None,
    311                          'a')
    312 
    313    def test_send_ping(self):
    314        request = _create_request()
    315        msgutil.send_ping(request, 'Hello World!')
    316        self.assertEqual(b'\x89\x0cHello World!',
    317                         request.connection.written_data())
    318 
    319    def test_send_longest_ping(self):
    320        request = _create_request()
    321        msgutil.send_ping(request, 'a' * 125)
    322        self.assertEqual(b'\x89\x7d' + b'a' * 125,
    323                         request.connection.written_data())
    324 
    325    def test_send_ping_too_long(self):
    326        request = _create_request()
    327        self.assertRaises(msgutil.BadOperationException, msgutil.send_ping,
    328                          request, 'a' * 126)
    329 
    330    def test_receive_ping(self):
    331        """Tests receiving a ping control frame."""
    332        def handler(request, message):
    333            request.called = True
    334 
    335        # Stream automatically respond to ping with pong without any action
    336        # by application layer.
    337        request = _create_request((b'\x89\x85', b'Hello'),
    338                                  (b'\x81\x85', b'World'))
    339        self.assertEqual('World', msgutil.receive_message(request))
    340        self.assertEqual(b'\x8a\x05Hello', request.connection.written_data())
    341 
    342        request = _create_request((b'\x89\x85', b'Hello'),
    343                                  (b'\x81\x85', b'World'))
    344        request.on_ping_handler = handler
    345        self.assertEqual('World', msgutil.receive_message(request))
    346        self.assertTrue(request.called)
    347 
    348    def test_receive_longest_ping(self):
    349        request = _create_request((b'\x89\xfd', b'a' * 125),
    350                                  (b'\x81\x85', b'World'))
    351        self.assertEqual('World', msgutil.receive_message(request))
    352        self.assertEqual(b'\x8a\x7d' + b'a' * 125,
    353                         request.connection.written_data())
    354 
    355    def test_receive_ping_too_long(self):
    356        request = _create_request((b'\x89\xfe\x00\x7e', b'a' * 126))
    357        self.assertRaises(msgutil.InvalidFrameException,
    358                          msgutil.receive_message, request)
    359 
    360    def test_receive_pong(self):
    361        """Tests receiving a pong control frame."""
    362        def handler(request, message):
    363            request.called = True
    364 
    365        request = _create_request((b'\x8a\x85', b'Hello'),
    366                                  (b'\x81\x85', b'World'))
    367        request.on_pong_handler = handler
    368        msgutil.send_ping(request, 'Hello')
    369        self.assertEqual(b'\x89\x05Hello', request.connection.written_data())
    370        # Valid pong is received, but receive_message won't return for it.
    371        self.assertEqual('World', msgutil.receive_message(request))
    372        # Check that nothing was written after receive_message call.
    373        self.assertEqual(b'\x89\x05Hello', request.connection.written_data())
    374 
    375        self.assertTrue(request.called)
    376 
    377    def test_receive_unsolicited_pong(self):
    378        # Unsolicited pong is allowed from HyBi 07.
    379        request = _create_request((b'\x8a\x85', b'Hello'),
    380                                  (b'\x81\x85', b'World'))
    381        msgutil.receive_message(request)
    382 
    383        request = _create_request((b'\x8a\x85', b'Hello'),
    384                                  (b'\x81\x85', b'World'))
    385        msgutil.send_ping(request, 'Jumbo')
    386        # Body mismatch.
    387        msgutil.receive_message(request)
    388 
    389    def test_ping_cannot_be_fragmented(self):
    390        request = _create_request((b'\x09\x85', b'Hello'))
    391        self.assertRaises(msgutil.InvalidFrameException,
    392                          msgutil.receive_message, request)
    393 
    394    def test_ping_with_too_long_payload(self):
    395        request = _create_request((b'\x89\xfe\x01\x00', b'a' * 256))
    396        self.assertRaises(msgutil.InvalidFrameException,
    397                          msgutil.receive_message, request)
    398 
    399 
    400 class PerMessageDeflateTest(unittest.TestCase):
    401    """Tests for permessage-deflate extension."""
    def test_response_parameters(self):
        # A requested server_no_context_takeover is present in the response
        # as a value-less parameter.
        offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
        offer.add_parameter('server_no_context_takeover', None)
        deflate = PerMessageDeflateExtensionProcessor(offer)
        accepted = deflate.get_extension_response()
        self.assertTrue(accepted.has_parameter('server_no_context_takeover'))
        self.assertEqual(
            None, accepted.get_parameter_value('server_no_context_takeover'))

        # Client-side settings made on the processor appear in the response:
        # client_max_window_bits with its value, client_no_context_takeover
        # without one.
        offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
        offer.add_parameter('client_max_window_bits', None)
        deflate = PerMessageDeflateExtensionProcessor(offer)

        deflate.set_client_max_window_bits(8)
        deflate.set_client_no_context_takeover(True)
        accepted = deflate.get_extension_response()
        self.assertEqual(
            '8', accepted.get_parameter_value('client_max_window_bits'))
        self.assertTrue(accepted.has_parameter('client_no_context_takeover'))
        self.assertEqual(
            None, accepted.get_parameter_value('client_no_context_takeover'))
    425 
    def test_send_message(self):
        offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
        request = _create_request_from_rawdata(
            b'', permessage_deflate_request=offer)
        msgutil.send_message(request, 'Hello')

        # Reference payload: raw DEFLATE, sync-flushed, minus the 4-byte
        # 0x00 0x00 0xff 0xff trailer. 0xc1 is 0x81 (FIN + text) with RSV1
        # (0x40) set to mark the frame as compressed.
        deflater = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                    -zlib.MAX_WBITS)
        body = (deflater.compress(b'Hello') +
                deflater.flush(zlib.Z_SYNC_FLUSH))[:-4]
        expected = struct.pack('!BB', 0xc1, len(body)) + body
        self.assertEqual(expected, request.connection.written_data())
    441 
    def test_send_empty_message(self):
        """Test that an empty message is compressed correctly."""

        offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
        request = _create_request_from_rawdata(
            b'', permessage_deflate_request=offer)

        msgutil.send_message(request, '')

        # The payload is the single byte 0b00000000. Reading from the LSB:
        # - 1 bit of BFINAL (0)
        # - 2 bits of BTYPE (00, no compression)
        # - 5 bits of padding
        # The stored block's trailing 0x00 0x00 0xff 0xff is stripped.
        header = b'\xc1\x01'
        payload = b'\x00'
        self.assertEqual(header + payload, request.connection.written_data())
    458 
    def test_send_message_with_null_character(self):
        """Test that a simple payload (one null) is framed correctly."""

        offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
        request = _create_request_from_rawdata(
            b'', permessage_deflate_request=offer)

        msgutil.send_message(request, '\x00')

        # Payload bits 0b01100010 0b00000000 0b00000000, read from the LSB:
        # - 1 bit of BFINAL (0)
        # - 2 bits of BTYPE (01, fixed Huffman)
        # - 8 bits of the first code (00110000, the literal 0x00)
        # - 7 bits of the second code (0000000, end-of-block)
        # - 1 bit of BFINAL (0)
        # - 2 bits of BTYPE (00, no compression)
        # - 2 bits of padding
        header = b'\xc1\x03'
        payload = b'\x62\x00\x00'
        self.assertEqual(header + payload, request.connection.written_data())
    482 
    def test_send_two_messages(self):
        offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
        request = _create_request_from_rawdata(
            b'', permessage_deflate_request=offer)
        msgutil.send_message(request, 'Hello')
        msgutil.send_message(request, 'World')

        # No *_no_context_takeover parameter is offered, so the expected
        # bytes come from a single compressor shared by both messages.
        deflater = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                    -zlib.MAX_WBITS)
        frames = []
        for text in (b'Hello', b'World'):
            body = (deflater.compress(text) +
                    deflater.flush(zlib.Z_SYNC_FLUSH))[:-4]
            frames.append(struct.pack('!BB', 0xc1, len(body)) + body)

        self.assertEqual(b''.join(frames), request.connection.written_data())
    509 
    def test_send_message_fragmented(self):
        offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
        request = _create_request_from_rawdata(
            b'', permessage_deflate_request=offer)
        msgutil.send_message(request, 'Hello', end=False)
        msgutil.send_message(request, 'Goodbye', end=False)
        msgutil.send_message(request, 'World')

        # One compressor spans all fragments. Only the final fragment has its
        # sync-flush trailer stripped. The first frame has RSV1 set (0x41);
        # continuation frames use opcode 0 (0x00, then 0x80 with FIN).
        deflater = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                    -zlib.MAX_WBITS)
        fragments = ((0x41, b'Hello', False),
                     (0x00, b'Goodbye', False),
                     (0x80, b'World', True))
        expected = b''
        for first_byte, text, is_last in fragments:
            body = deflater.compress(text) + deflater.flush(zlib.Z_SYNC_FLUSH)
            if is_last:
                body = body[:-4]
            expected += struct.pack('!BB', first_byte, len(body)) + body
        self.assertEqual(expected, request.connection.written_data())
    535 
    def test_send_message_fragmented_empty_first_frame(self):
        """An empty non-final fragment followed by a final 'Hello' fragment."""
        extension = common.ExtensionParameter(
            common.PERMESSAGE_DEFLATE_EXTENSION)
        request = _create_request_from_rawdata(
            b'', permessage_deflate_request=extension)
        msgutil.send_message(request, '', end=False)
        msgutil.send_message(request, 'Hello')

        compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                    -zlib.MAX_WBITS)
        # First (non-final) fragment: the empty string. Its sync-flush
        # trailer is kept because the message continues.
        compressed_empty = compress.compress(b'')
        compressed_empty += compress.flush(zlib.Z_SYNC_FLUSH)
        expected = b'\x41%c' % len(compressed_empty)
        expected += compressed_empty
        # Final fragment: 'Hello', with the 4-byte sync-flush trailer removed.
        compressed_hello = compress.compress(b'Hello')
        compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
        compressed_hello = compressed_hello[:-4]
        expected += b'\x80%c' % len(compressed_hello)
        expected += compressed_hello
        self.assertEqual(expected, request.connection.written_data())
    556 
    def test_send_message_fragmented_empty_last_frame(self):
        offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
        request = _create_request_from_rawdata(
            b'', permessage_deflate_request=offer)
        msgutil.send_message(request, 'Hello', end=False)
        msgutil.send_message(request, '')

        # The non-final 'Hello' fragment keeps its sync-flush trailer. The
        # final, empty fragment has the 4-byte trailer stripped.
        deflater = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                    -zlib.MAX_WBITS)
        head = (deflater.compress(b'Hello') +
                deflater.flush(zlib.Z_SYNC_FLUSH))
        tail = (deflater.compress(b'') +
                deflater.flush(zlib.Z_SYNC_FLUSH))[:-4]
        expected = (struct.pack('!BB', 0x41, len(head)) + head +
                    struct.pack('!BB', 0x80, len(tail)) + tail)
        self.assertEqual(expected, request.connection.written_data())
    577 
    def test_send_message_using_small_window(self):
        # The two alphabet runs are ~30000 bytes apart, far beyond a 2**8-byte
        # window, so a sender honouring server_max_window_bits=8 cannot encode
        # the second run as a back-reference to the first.
        alphabet = 'abcdefghijklmnopqrstuvwxyz'
        test_message = alphabet + '-' * 30000 + alphabet

        offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
        offer.add_parameter('server_max_window_bits', '8')
        request = _create_request_from_rawdata(
            b'', permessage_deflate_request=offer)
        msgutil.send_message(request, test_message)

        header_size = 2
        payload_size = 91

        frame = request.connection.written_data()
        self.assertEqual(header_size + payload_size, len(frame))
        header, payload = frame[:header_size], frame[header_size:]

        self.assertEqual(b'\xc1%c' % payload_size, header)
        # Inflate with the same small window after restoring the stripped
        # sync-flush trailer; all input must be consumed.
        inflater = zlib.decompressobj(-8)
        decoded = inflater.decompress(payload + b'\x00\x00\xff\xff')
        decoded += inflater.flush()
        self.assertEqual(test_message, decoded.decode('UTF-8'))
        self.assertEqual(0, len(inflater.unused_data))
        self.assertEqual(0, len(inflater.unconsumed_tail))
    608 
    def test_send_message_no_context_takeover_parameter(self):
        offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
        offer.add_parameter('server_no_context_takeover', None)
        request = _create_request_from_rawdata(
            b'', permessage_deflate_request=offer)
        for _ in range(3):
            msgutil.send_message(request, 'Hello', end=False)
            msgutil.send_message(request, 'Hello', end=True)

        # With server_no_context_takeover every two-fragment message is
        # expected to compress identically. Within one message the context is
        # shared, hence a single compressor covers both fragments below.
        deflater = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                    -zlib.MAX_WBITS)
        first = (deflater.compress(b'Hello') +
                 deflater.flush(zlib.Z_SYNC_FLUSH))
        second = (deflater.compress(b'Hello') +
                  deflater.flush(zlib.Z_SYNC_FLUSH))[:-4]
        one_message = (struct.pack('!BB', 0x41, len(first)) + first +
                       struct.pack('!BB', 0x80, len(second)) + second)

        self.assertEqual(one_message * 3, request.connection.written_data())
    634 
    def test_send_message_fragmented_bfinal(self):
        offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
        request = _create_request_from_rawdata(
            b'', permessage_deflate_request=offer)
        self.assertEqual(1, len(request.ws_extension_processors))
        request.ws_extension_processors[0].set_bfinal(True)
        msgutil.send_message(request, 'Hello', end=False)
        msgutil.send_message(request, 'World', end=True)

        # With BFINAL set, each fragment is expected to be a finished DEFLATE
        # stream from a fresh compressor (Z_FINISH), followed by one extra
        # zero byte.
        expected = b''
        for first_byte, text in ((0x41, b'Hello'), (0x80, b'World')):
            deflater = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
                                        zlib.DEFLATED, -zlib.MAX_WBITS)
            body = deflater.compress(text) + deflater.flush(zlib.Z_FINISH)
            body += b'\x00'
            expected += struct.pack('!BB', first_byte, len(body)) + body

        self.assertEqual(expected, request.connection.written_data())
    664 
    def test_receive_message_deflate(self):
        """A single masked, compressed text frame is inflated back to the
        original string; the close frame that follows yields None.
        """
        deflater = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
                                    zlib.DEFLATED, -zlib.MAX_WBITS)
        flushed = deflater.compress(b'Hello')
        flushed += deflater.flush(zlib.Z_SYNC_FLUSH)
        # Strip the 4-octet sync-flush trailer.
        message_payload = flushed[:-4]

        # 0xc1 = FIN | RSV1 | text opcode; 0x80 in the length octet = MASK.
        frame = struct.pack('!BB', 0xc1, len(message_payload) | 0x80)
        frame += _mask_hybi(message_payload)
        close_frame = b'\x88\x8a' + _mask_hybi(
            struct.pack('!H', 1000) + b'Good bye')

        deflate_offer = common.ExtensionParameter(
            common.PERMESSAGE_DEFLATE_EXTENSION)
        request = _create_request_from_rawdata(
            frame + close_frame, permessage_deflate_request=deflate_offer)

        self.assertEqual('Hello', msgutil.receive_message(request))
        self.assertEqual(None, msgutil.receive_message(request))
    685 
    def test_receive_message_random_section(self):
        """Test that a compressed message fragmented into lots of chunks is
        correctly received.
        """
        # Use a private generator seeded with 0 so the test is reproducible
        # without reseeding the global random module state that other tests
        # share.
        rng = random.Random(0)
        payload = b''.join(
            [struct.pack('!B', rng.randint(0, 255)) for _ in range(1000)])

        compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                    -zlib.MAX_WBITS)
        compressed_payload = compress.compress(payload)
        compressed_payload += compress.flush(zlib.Z_SYNC_FLUSH)
        # Drop the 4-octet trailer produced by the sync flush.
        compressed_payload = compressed_payload[:-4]
        total = len(compressed_payload)

        # Make sure that
        # - the length of chunks are equal or less than 125 so that we can
        #   use 1 octet length header format for all frames.
        # - at least 10 chunks are created.
        max_chunk_size = min(125, total // 10)

        # Fragment the compressed payload into lots of frames.
        bytes_chunked = 0
        frames = []
        chunk_sizes = []

        while bytes_chunked < total:
            chunk_size = rng.randint(
                1, min(max_chunk_size, total - bytes_chunked))
            chunk_sizes.append(chunk_size)
            chunk = compressed_payload[bytes_chunked:bytes_chunked +
                                       chunk_size]
            bytes_chunked += chunk_size

            first_octet = 0x00
            if not frames:
                # First frame: RSV1 (compressed) | binary opcode.
                first_octet |= 0x42
            if bytes_chunked == total:
                # Last frame: FIN.
                first_octet |= 0x80

            # 0x80 in the second octet sets the MASK bit.
            frames.append(b'%c%c' % (first_octet, chunk_size | 0x80))
            frames.append(_mask_hybi(chunk))

        # Matches the "at least 10 chunks" guarantee of max_chunk_size.
        self.assertGreaterEqual(len(chunk_sizes), 10)

        data = b''.join(frames)
        # Close frame
        data += b'\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + b'Good bye')

        extension = common.ExtensionParameter(
            common.PERMESSAGE_DEFLATE_EXTENSION)
        request = _create_request_from_rawdata(
            data, permessage_deflate_request=extension)
        self.assertEqual(payload, msgutil.receive_message(request))

        self.assertEqual(None, msgutil.receive_message(request))
    746 
    def test_receive_two_messages(self):
        """Two consecutive compressed messages, the first split across two
        frames, are each inflated and returned in order.
        """
        def deflate_message(text):
            # Fresh compressor per message; strip the sync-flush trailer.
            deflater = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
                                        zlib.DEFLATED, -zlib.MAX_WBITS)
            flushed = deflater.compress(text)
            flushed += deflater.flush(zlib.Z_SYNC_FLUSH)
            return flushed[:-4]

        def masked_frame(first_octet, frame_payload):
            # Short (7-bit) length with the MASK bit set, then masked bytes.
            header = struct.pack('!BB', first_octet,
                                 len(frame_payload) | 0x80)
            return header + _mask_hybi(frame_payload)

        first_message = deflate_message(b'HelloWebSocket')
        middle = len(first_message) // 2
        second_message = deflate_message(b'World')

        stream = b''.join([
            masked_frame(0x41, first_message[:middle]),
            masked_frame(0x80, first_message[middle:]),
            masked_frame(0xc1, second_message),
            # Close frame
            b'\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + b'Good bye'),
        ])

        deflate_offer = common.ExtensionParameter(
            common.PERMESSAGE_DEFLATE_EXTENSION)
        request = _create_request_from_rawdata(
            stream, permessage_deflate_request=deflate_offer)
        self.assertEqual('HelloWebSocket', msgutil.receive_message(request))
        self.assertEqual('World', msgutil.receive_message(request))

        self.assertEqual(None, msgutil.receive_message(request))
    783 
    def test_receive_message_mixed_btype(self):
        """Test that a message compressed using lots of DEFLATE blocks with
        various flush mode is correctly received.
        """
        # Private seeded generator: reproducible without reseeding the
        # global random module state shared with other tests.
        rng = random.Random(0)
        payload = b''.join(
            [struct.pack('!B', rng.randint(0, 255)) for _ in range(1000)])

        # Compress the payload piece by piece, ending each piece with a
        # randomly chosen flush mode, and concatenate all of the output into
        # the payload of a single frame.
        compress = None
        bytes_chunked = 0
        compressed_payload = b''

        chunk_sizes = []
        sync_used = False
        finish_used = False

        while bytes_chunked < len(payload):
            # Chunks hold at most 100 octets, so at least 10 are created.
            chunk_size = rng.randint(1,
                                     min(100, len(payload) - bytes_chunked))
            chunk_sizes.append(chunk_size)
            chunk = payload[bytes_chunked:bytes_chunked + chunk_size]
            bytes_chunked += chunk_size

            if compress is None:
                # Start a new DEFLATE stream: either the first chunk, or the
                # previous stream was terminated with Z_FINISH.
                compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
                                            zlib.DEFLATED, -zlib.MAX_WBITS)

            compressed_payload += compress.compress(chunk)
            if bytes_chunked == len(payload):
                # Last chunk: sync flush, then strip its 4-octet trailer.
                compressed_payload += compress.flush(zlib.Z_SYNC_FLUSH)
                compressed_payload = compressed_payload[:-4]
            elif rng.randint(0, 1) == 0:
                compressed_payload += compress.flush(zlib.Z_SYNC_FLUSH)
                sync_used = True
            else:
                compressed_payload += compress.flush(zlib.Z_FINISH)
                compress = None
                finish_used = True

        self.assertGreaterEqual(len(chunk_sizes), 10)
        self.assertTrue(sync_used)
        self.assertTrue(finish_used)

        # The frame below uses the 16-bit extended payload length, which is
        # only valid for lengths in (125, 65536).
        self.assertGreater(len(compressed_payload), 125)
        self.assertLess(len(compressed_payload), 65536)
        # 0xc2 = FIN | RSV1 | binary opcode; 0xfe = MASK | 126 (16-bit
        # length follows).
        data = b'\xc2\xfe' + struct.pack('!H', len(compressed_payload))
        data += _mask_hybi(compressed_payload)

        # Close frame
        data += b'\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + b'Good bye')

        extension = common.ExtensionParameter(
            common.PERMESSAGE_DEFLATE_EXTENSION)
        request = _create_request_from_rawdata(
            data, permessage_deflate_request=extension)
        self.assertEqual(payload, msgutil.receive_message(request))

        self.assertEqual(None, msgutil.receive_message(request))
    854 
    855 
    856 class MessageReceiverTest(unittest.TestCase):
    857    """Tests the Stream class using MessageReceiver."""
    858    def test_queue(self):
    859        request = _create_blocking_request()
    860        receiver = msgutil.MessageReceiver(request)
    861 
    862        self.assertEqual(None, receiver.receive_nowait())
    863 
    864        request.connection.put_bytes(b'\x81\x86' + _mask_hybi(b'Hello!'))
    865        self.assertEqual('Hello!', receiver.receive())
    866 
    867    def test_onmessage(self):
    868        onmessage_queue = six.moves.queue.Queue()
    869 
    870        def onmessage_handler(message):
    871            onmessage_queue.put(message)
    872 
    873        request = _create_blocking_request()
    874        receiver = msgutil.MessageReceiver(request, onmessage_handler)
    875 
    876        request.connection.put_bytes(b'\x81\x86' + _mask_hybi(b'Hello!'))
    877        self.assertEqual('Hello!', onmessage_queue.get())
    878 
    879 
    880 class MessageSenderTest(unittest.TestCase):
    881    """Tests the Stream class using MessageSender."""
    882    def test_send(self):
    883        request = _create_blocking_request()
    884        sender = msgutil.MessageSender(request)
    885 
    886        sender.send('World')
    887        self.assertEqual(b'\x81\x05World', request.connection.written_data())
    888 
    889    def test_send_nowait(self):
    890        # Use a queue to check the bytes written by MessageSender.
    891        # request.connection.written_data() cannot be used here because
    892        # MessageSender runs in a separate thread.
    893        send_queue = six.moves.queue.Queue()
    894 
    895        def write(bytes):
    896            send_queue.put(bytes)
    897 
    898        request = _create_blocking_request()
    899        request.connection.write = write
    900 
    901        sender = msgutil.MessageSender(request)
    902 
    903        sender.send_nowait('Hello')
    904        sender.send_nowait('World')
    905        self.assertEqual(b'\x81\x05Hello', send_queue.get())
    906        self.assertEqual(b'\x81\x05World', send_queue.get())
    907 
    908 
    909 if __name__ == '__main__':
    910    unittest.main()
    911 
    912 # vi:sts=4 sw=4 et