base.py (14591B)
# Copyright 2012, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Common functions and exceptions used by WebSocket opening handshake
processors.
"""

from __future__ import absolute_import

from mod_pywebsocket import common
from mod_pywebsocket import http_header_util
from mod_pywebsocket.extensions import get_extension_processor
from mod_pywebsocket.stream import StreamOptions
from mod_pywebsocket.stream import Stream
from mod_pywebsocket import util

from six.moves import map
from six.moves import range

# Defining aliases for values used frequently.
_VERSION_LATEST = common.VERSION_HYBI_LATEST
_VERSION_LATEST_STRING = str(_VERSION_LATEST)
_SUPPORTED_VERSIONS = [
    _VERSION_LATEST,
]


class AbortedByUserException(Exception):
    """Exception for aborting a connection intentionally.

    If this exception is raised in do_extra_handshake handler, the connection
    will be abandoned. No other WebSocket or HTTP(S) handler will be invoked.

    If this exception is raised in transfer_data_handler, the connection will
    be closed without closing handshake. No other WebSocket or HTTP(S) handler
    will be invoked.
    """

    pass


class HandshakeException(Exception):
    """This exception will be raised when an error occurred while processing
    WebSocket initial handshake.
    """
    def __init__(self, name, status=None):
        """Construct an instance.

        Args:
            name: the error message passed to Exception.
            status: optional HTTP status to report for this failure.
                HandshakerBase.do_handshake replaces a falsy status with
                common.HTTP_STATUS_BAD_REQUEST.
        """
        super(HandshakeException, self).__init__(name)
        self.status = status


class VersionException(Exception):
    """This exception will be raised when a version of client request does not
    match with version the server supports.
    """
    def __init__(self, name, supported_versions=''):
        """Construct an instance.

        Args:
            name: the error message passed to Exception.
            supported_versions: a str object to show supported hybi versions.
                (e.g. '13')
        """
        super(VersionException, self).__init__(name)
        self.supported_versions = supported_versions


def get_default_port(is_secure):
    """Return the default WebSocket port for a secure or plain connection.

    Args:
        is_secure: truthy to get common.DEFAULT_WEB_SOCKET_SECURE_PORT,
            falsy to get common.DEFAULT_WEB_SOCKET_PORT.
    """
    if is_secure:
        return common.DEFAULT_WEB_SOCKET_SECURE_PORT
    else:
        return common.DEFAULT_WEB_SOCKET_PORT


def validate_subprotocol(subprotocol):
    """Validate a value in the Sec-WebSocket-Protocol field.

    See the Section 4.1., 4.2.2., and 4.3. of RFC 6455.

    Args:
        subprotocol: the subprotocol name to check.

    Raises:
        HandshakeException: when |subprotocol| is empty or is not exactly one
            HTTP token.
    """

    if not subprotocol:
        raise HandshakeException('Invalid subprotocol name: empty')

    # Parameter should be encoded HTTP token.
    state = http_header_util.ParsingState(subprotocol)
    # Only the side effect matters here: consuming the leading token advances
    # |state|, so whatever peek() sees afterwards is the non-token remainder.
    http_header_util.consume_token(state)
    rest = http_header_util.peek(state)
    # If |rest| is not None, |subprotocol| is not one token or invalid. If
    # |rest| is None, the consumed token must not be None because
    # |subprotocol| is concatenation of the token and |rest| and is not None.
    if rest is not None:
        raise HandshakeException('Invalid non-token string in subprotocol '
                                 'name: %r' % rest)


def parse_host_header(request):
    """Split the Host header of |request| into a (host, port) tuple.

    When the header carries no port, the default port for the connection type
    (see get_default_port) is returned. A bracketed IP literal such as
    "[::1]" or "[::1]:8080" is kept together, brackets included, as the host
    part; its inner colons are not treated as the port separator.

    The header is read with a subscript lookup on request.headers_in, so
    whatever error that mapping raises for a missing Host header propagates
    unchanged.

    Raises:
        HandshakeException: when the port is not an integer, or when a
            bracketed IP literal is followed by something other than
            ":port".
    """
    host_value = request.headers_in[common.HOST_HEADER]

    if host_value.startswith('[') and ']' in host_value:
        # Bracketed IP literal (RFC 3986 section 3.2.2). Splitting on the
        # first ':' would cut the address itself in half.
        closing = host_value.index(']')
        host = host_value[:closing + 1]
        remainder = host_value[closing + 1:]
        if not remainder:
            return host, get_default_port(request.is_https())
        if not remainder.startswith(':'):
            raise HandshakeException(
                'Invalid characters after IP literal in Host header: %r' %
                remainder)
        port_text = remainder[1:]
    else:
        fields = host_value.split(':', 1)
        if len(fields) == 1:
            return fields[0], get_default_port(request.is_https())
        host, port_text = fields

    try:
        return host, int(port_text)
    except ValueError as e:
        raise HandshakeException('Invalid port number format: %r' % e)


def get_mandatory_header(request, key):
    """Return the value of header |key| from request.headers_in.

    Raises:
        HandshakeException: when the header is absent (its value is None).
    """
    value = request.headers_in.get(key)
    if value is None:
        raise HandshakeException('Header %s is not defined' % key)
    return value


def validate_mandatory_header(request, key, expected_value, fail_status=None):
    """Check that header |key| exists and equals |expected_value|.

    The comparison is case-insensitive.

    Args:
        request: the request whose headers_in is inspected.
        key: the header name.
        expected_value: the value the header must have.
        fail_status: status attached to the HandshakeException raised on a
            value mismatch.

    Raises:
        HandshakeException: when the header is missing (without
            |fail_status|) or its value differs (with |fail_status|).
    """
    value = get_mandatory_header(request, key)

    if value.lower() != expected_value.lower():
        raise HandshakeException(
            'Expected %r for header %s but found %r (case-insensitive)' %
            (expected_value, key, value),
            status=fail_status)


def parse_token_list(data):
    """Parses a header value which follows 1#token and returns parsed elements
    as a list of strings.

    Leading LWSes must be trimmed.

    Empty list elements (e.g. two consecutive commas) are skipped rather than
    rejected.

    Raises:
        HandshakeException: when elements are not separated by commas or when
            no token is found at all.
    """

    state = http_header_util.ParsingState(data)

    token_list = []

    while True:
        token = http_header_util.consume_token(state)
        if token is not None:
            token_list.append(token)

        http_header_util.consume_lwses(state)

        # End of input: nothing left to parse.
        if http_header_util.peek(state) is None:
            break

        if not http_header_util.consume_string(state, ','):
            raise HandshakeException('Expected a comma but found %r' %
                                     http_header_util.peek(state))

        http_header_util.consume_lwses(state)

    if not token_list:
        raise HandshakeException('No valid token found')

    return token_list


class HandshakerBase(object):
    """Base class that drives the WebSocket opening handshake.

    do_handshake() implements the common sequence (version check, origin,
    subprotocol and extension negotiation); subclasses provide the
    protocol-specific pieces through the hook methods below.
    """
    def __init__(self, request, dispatcher):
        """Construct an instance.

        Args:
            request: the request object; handshake results are stored on it
                as ws_* attributes.
            dispatcher: object whose do_extra_handshake(request) is called
                during the handshake.
        """
        self._logger = util.get_class_logger(self)
        self._request = request
        self._dispatcher = dispatcher

    # Subclasses must implement the five following methods.

    def _protocol_rfc(self):
        """ Return the name of the RFC that the handshake class is implementing.
        """

        raise AssertionError("subclasses should implement this method")

    def _transform_header(self, header):
        """
        :param header: header name

        transform the header name if needed. For example, HTTP/2 subclass will
        return the name of the header in lower case.
        """

        raise AssertionError("subclasses should implement this method")

    def _validate_request(self):
        """ validate that all the mandatory fields are set """

        raise AssertionError("subclasses should implement this method")

    def _set_accept(self):
        """ Computes accept value based on Sec-WebSocket-Accept if needed. """

        raise AssertionError("subclasses should implement this method")

    def _send_handshake(self):
        """ Prepare and send the response after it has been parsed and processed.
        """

        raise AssertionError("subclasses should implement this method")

    def do_handshake(self):
        """Perform the opening handshake for self._request.

        Populates the ws_* attributes on the request (ws_resource,
        ws_version, ws_origin, ws_protocol, ws_requested_protocols,
        ws_requested_extensions, ws_extension_processors, ws_extensions,
        ws_stream), lets the dispatcher's do_extra_handshake adjust them, and
        finally sends the response through _send_handshake().

        Raises:
            HandshakeException: on any handshake failure. Exceptions raised
                inside the negotiation phase get
                common.HTTP_STATUS_BAD_REQUEST as their status when none was
                set.
            VersionException: from _check_version() when the requested
                version is unsupported.
        """
        self._request.ws_close_code = None
        self._request.ws_close_reason = None

        # Parsing.
        # NOTE: request validation and the version check run outside the try
        # block below, so a HandshakeException raised by them does not get
        # the default status filled in.
        self._validate_request()
        self._request.ws_resource = self._request.uri
        self._request.ws_version = self._check_version()

        try:
            self._get_origin()
            self._set_protocol()
            self._parse_extensions()

            self._set_accept()

            self._logger.debug('Protocol version is %s', self._protocol_rfc())

            # Setup extension processors.
            self._request.ws_extension_processors = (
                self._get_extension_processors_requested())

            # List of extra headers. The extra handshake handler may add header
            # data as name/value pairs to this list and pywebsocket appends
            # them to the WebSocket handshake.
            self._request.extra_headers = []

            # Extra handshake handler may modify/remove processors.
            self._dispatcher.do_extra_handshake(self._request)

            stream_options = StreamOptions()
            self._process_extensions(stream_options)

            self._request.ws_stream = Stream(self._request, stream_options)

            if self._request.ws_requested_protocols is not None:
                if self._request.ws_protocol is None:
                    raise HandshakeException(
                        'do_extra_handshake must choose one subprotocol from '
                        'ws_requested_protocols and set it to ws_protocol')
                validate_subprotocol(self._request.ws_protocol)

                self._logger.debug('Subprotocol accepted: %r',
                                   self._request.ws_protocol)
            else:
                if self._request.ws_protocol is not None:
                    raise HandshakeException(
                        'ws_protocol must be None when the client didn\'t '
                        'request any subprotocol')

            self._send_handshake()
        except HandshakeException as e:
            if not e.status:
                # Fallback to 400 bad request by default.
                e.status = common.HTTP_STATUS_BAD_REQUEST
            # Bare raise re-raises the same exception object and keeps the
            # original traceback on both Python 2 and Python 3.
            raise

    def _check_version(self):
        """Return the protocol version requested by the client.

        Returns:
            _VERSION_LATEST when the Sec-WebSocket-Version header equals the
            latest supported version string.

        Raises:
            HandshakeException: when the header is missing, or when it lists
                multiple versions (contains a comma).
            VersionException: when the single requested version is not
                supported.
        """
        sec_websocket_version_header = self._transform_header(
            common.SEC_WEBSOCKET_VERSION_HEADER)
        version = get_mandatory_header(self._request,
                                       sec_websocket_version_header)
        if version == _VERSION_LATEST_STRING:
            return _VERSION_LATEST

        if ',' in version:
            raise HandshakeException(
                'Multiple versions (%r) are not allowed for header %s' %
                (version, sec_websocket_version_header),
                status=common.HTTP_STATUS_BAD_REQUEST)
        raise VersionException('Unsupported version %r for header %s' %
                               (version, sec_websocket_version_header),
                               supported_versions=', '.join(
                                   map(str, _SUPPORTED_VERSIONS)))

    def _get_origin(self):
        """Store the origin header value (or None) in request.ws_origin."""
        origin_header = self._transform_header(common.ORIGIN_HEADER)
        origin = self._request.headers_in.get(origin_header)
        if origin is None:
            self._logger.debug('Client request does not have origin header')
        self._request.ws_origin = origin

    def _set_protocol(self):
        """Reset ws_protocol and record the subprotocols the client requested.

        Sets request.ws_requested_protocols to the parsed token list of the
        Sec-WebSocket-Protocol header, or to None when the header is absent.

        Raises:
            HandshakeException: from parse_token_list() when the header value
                is malformed.
        """
        self._request.ws_protocol = None
        # MOZILLA
        self._request.sts = None
        # /MOZILLA

        sec_websocket_protocol_header = self._transform_header(
            common.SEC_WEBSOCKET_PROTOCOL_HEADER)
        protocol_header = self._request.headers_in.get(
            sec_websocket_protocol_header)

        if protocol_header is None:
            self._request.ws_requested_protocols = None
            return

        self._request.ws_requested_protocols = parse_token_list(
            protocol_header)
        self._logger.debug('Subprotocols requested: %r',
                           self._request.ws_requested_protocols)

    def _parse_extensions(self):
        """Parse the Sec-WebSocket-Extensions header.

        Sets request.ws_requested_extensions to the result of
        common.parse_extensions(), or to None when the header is missing or
        empty.

        Raises:
            HandshakeException: when the header value cannot be parsed.
        """
        sec_websocket_extensions_header = self._transform_header(
            common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
        extensions_header = self._request.headers_in.get(
            sec_websocket_extensions_header)
        if not extensions_header:
            self._request.ws_requested_extensions = None
            return

        try:
            self._request.ws_requested_extensions = common.parse_extensions(
                extensions_header)
        except common.ExtensionParsingException as e:
            raise HandshakeException(
                'Failed to parse sec-websocket-extensions header: %r' % e)

        self._logger.debug(
            'Extensions requested: %r',
            list(
                map(common.ExtensionParameter.name,
                    self._request.ws_requested_extensions)))

    def _get_extension_processors_requested(self):
        """Return a list of processors for the requested extensions.

        Requests for which get_extension_processor() returns None are
        skipped. The list is empty when no extension was requested.
        """
        processors = []
        if self._request.ws_requested_extensions is not None:
            for extension_request in self._request.ws_requested_extensions:
                processor = get_extension_processor(extension_request)
                # Unknown extension requests are just ignored.
                if processor is not None:
                    processors.append(processor)
        return processors

    def _process_extensions(self, stream_options):
        """Decide which extensions are accepted and configure the stream.

        Reads request.ws_extension_processors, and stores the accepted
        extension responses in request.ws_extensions (None when nothing was
        accepted).

        Args:
            stream_options: options object handed to each accepted
                processor's setup_stream_options().
        """
        processors = [
            processor for processor in self._request.ws_extension_processors
            if processor is not None
        ]

        # Ask each processor if there are extensions on the request which
        # cannot co-exist. When processor decided other processors cannot
        # co-exist with it, the processor marks them (or itself) as
        # "inactive". The first extension processor has the right to
        # make the final call.
        for processor in reversed(processors):
            if processor.is_active():
                processor.check_consistency_with_other_processors(processors)
        processors = [
            processor for processor in processors if processor.is_active()
        ]

        accepted_extensions = []

        for index, processor in enumerate(processors):
            # A processor may have been deactivated by an earlier iteration
            # of this loop (see the inner loop below).
            if not processor.is_active():
                continue

            extension_response = processor.get_extension_response()
            if extension_response is None:
                # Rejected.
                continue

            accepted_extensions.append(extension_response)

            processor.setup_stream_options(stream_options)

            # Inactivate all of the following compression extensions.
            for j in range(index + 1, len(processors)):
                processors[j].set_active(False)

        if len(accepted_extensions) > 0:
            self._request.ws_extensions = accepted_extensions
            self._logger.debug(
                'Extensions accepted: %r',
                list(map(common.ExtensionParameter.name, accepted_extensions)))
        else:
            self._request.ws_extensions = None


# vi:sts=4 sw=4 et