websocket_server.py (11783B)
1 # Copyright 2020, Google Inc. 2 # All rights reserved. 3 # 4 # Redistribution and use in source and binary forms, with or without 5 # modification, are permitted provided that the following conditions are 6 # met: 7 # 8 # * Redistributions of source code must retain the above copyright 9 # notice, this list of conditions and the following disclaimer. 10 # * Redistributions in binary form must reproduce the above 11 # copyright notice, this list of conditions and the following disclaimer 12 # in the documentation and/or other materials provided with the 13 # distribution. 14 # * Neither the name of Google Inc. nor the names of its 15 # contributors may be used to endorse or promote products derived from 16 # this software without specific prior written permission. 17 # 18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 """Standalone WebsocketServer 30 31 This file deals with the main module of standalone server. Although it is fine 32 to import this file directly to use WebSocketServer, it is strongly recommended 33 to use standalone.py, since it is intended to act as a skeleton of this module. 34 """ 35 36 from __future__ import absolute_import 37 from six.moves import BaseHTTPServer 38 from six.moves import socketserver 39 import logging 40 import re 41 import select 42 import socket 43 import ssl 44 import threading 45 import traceback 46 47 from mod_pywebsocket import dispatch 48 from mod_pywebsocket import util 49 from mod_pywebsocket.request_handler import WebSocketRequestHandler 50 51 52 def _alias_handlers(dispatcher, websock_handlers_map_file): 53 """Set aliases specified in websock_handler_map_file in dispatcher. 54 55 Args: 56 dispatcher: dispatch.Dispatcher instance 57 websock_handler_map_file: alias map file 58 """ 59 60 with open(websock_handlers_map_file) as f: 61 for line in f: 62 if line[0] == '#' or line.isspace(): 63 continue 64 m = re.match(r'(\S+)\s+(\S+)$', line) 65 if not m: 66 logging.warning('Wrong format in map file:' + line) 67 continue 68 try: 69 dispatcher.add_resource_path_alias(m.group(1), m.group(2)) 70 except dispatch.DispatchException as e: 71 logging.error(str(e)) 72 73 74 class WebSocketServer(socketserver.ThreadingMixIn, BaseHTTPServer.HTTPServer): 75 """HTTPServer specialized for WebSocket.""" 76 77 # Overrides SocketServer.ThreadingMixIn.daemon_threads 78 daemon_threads = True 79 # Overrides BaseHTTPServer.HTTPServer.allow_reuse_address 80 allow_reuse_address = True 81 82 def __init__(self, options): 83 """Override SocketServer.TCPServer.__init__ to set SSL enabled 84 socket object to self.socket before server_bind and server_activate, 85 if necessary. 86 """ 87 88 # Fall back to None for embedders that don't know about the 89 # handler_encoding option. 90 handler_encoding = getattr(options, "handler_encoding", None) 91 92 # Share a Dispatcher among request handlers to save time for 93 # instantiation. Dispatcher can be shared because it is thread-safe. 94 options.dispatcher = dispatch.Dispatcher( 95 options.websock_handlers, options.scan_dir, 96 options.allow_handlers_outside_root_dir, handler_encoding) 97 if options.websock_handlers_map_file: 98 _alias_handlers(options.dispatcher, 99 options.websock_handlers_map_file) 100 warnings = options.dispatcher.source_warnings() 101 if warnings: 102 for warning in warnings: 103 logging.warning('Warning in source loading: %s' % warning) 104 105 self._logger = util.get_class_logger(self) 106 107 self.request_queue_size = options.request_queue_size 108 self.__ws_is_shut_down = threading.Event() 109 self.__ws_serving = False 110 111 socketserver.BaseServer.__init__(self, 112 (options.server_host, options.port), 113 WebSocketRequestHandler) 114 115 # Expose the options object to allow handler objects access it. We name 116 # it with websocket_ prefix to avoid conflict. 117 self.websocket_server_options = options 118 119 self._create_sockets() 120 self.server_bind() 121 self.server_activate() 122 123 def _create_sockets(self): 124 self.server_name, self.server_port = self.server_address 125 self._sockets = [] 126 if not self.server_name: 127 # On platforms that doesn't support IPv6, the first bind fails. 128 # On platforms that supports IPv6 129 # - If it binds both IPv4 and IPv6 on call with AF_INET6, the 130 # first bind succeeds and the second fails (we'll see 'Address 131 # already in use' error). 132 # - If it binds only IPv6 on call with AF_INET6, both call are 133 # expected to succeed to listen both protocol. 134 addrinfo_array = [(socket.AF_INET6, socket.SOCK_STREAM, '', '', 135 ''), 136 (socket.AF_INET, socket.SOCK_STREAM, '', '', '')] 137 else: 138 addrinfo_array = socket.getaddrinfo(self.server_name, 139 self.server_port, 140 socket.AF_UNSPEC, 141 socket.SOCK_STREAM, 142 socket.IPPROTO_TCP) 143 for addrinfo in addrinfo_array: 144 self._logger.info('Create socket on: %r', addrinfo) 145 family, socktype, proto, canonname, sockaddr = addrinfo 146 try: 147 socket_ = socket.socket(family, socktype) 148 except Exception as e: 149 self._logger.info('Skip by failure: %r', e) 150 continue 151 server_options = self.websocket_server_options 152 if server_options.use_tls: 153 if server_options.tls_client_auth: 154 if server_options.tls_client_cert_optional: 155 client_cert_ = ssl.CERT_OPTIONAL 156 else: 157 client_cert_ = ssl.CERT_REQUIRED 158 else: 159 client_cert_ = ssl.CERT_NONE 160 ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) 161 ssl_context.load_cert_chain(keyfile=server_options.private_key, certfile=server_options.certificate) 162 ssl_context.load_verify_locations(cafile=server_options.tls_client_ca) 163 ssl_context.verify_mode = client_cert_ 164 socket_ = ssl_context.wrap_socket(socket_) 165 self._sockets.append((socket_, addrinfo)) 166 167 def server_bind(self): 168 """Override SocketServer.TCPServer.server_bind to enable multiple 169 sockets bind. 170 """ 171 172 failed_sockets = [] 173 174 for socketinfo in self._sockets: 175 socket_, addrinfo = socketinfo 176 self._logger.info('Bind on: %r', addrinfo) 177 if self.allow_reuse_address: 178 socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 179 try: 180 socket_.bind(self.server_address) 181 except Exception as e: 182 self._logger.info('Skip by failure: %r', e) 183 socket_.close() 184 failed_sockets.append(socketinfo) 185 if self.server_address[1] == 0: 186 # The operating system assigns the actual port number for port 187 # number 0. This case, the second and later sockets should use 188 # the same port number. Also self.server_port is rewritten 189 # because it is exported, and will be used by external code. 190 self.server_address = (self.server_name, 191 socket_.getsockname()[1]) 192 self.server_port = self.server_address[1] 193 self._logger.info('Port %r is assigned', self.server_port) 194 195 for socketinfo in failed_sockets: 196 self._sockets.remove(socketinfo) 197 198 def server_activate(self): 199 """Override SocketServer.TCPServer.server_activate to enable multiple 200 sockets listen. 201 """ 202 203 failed_sockets = [] 204 205 for socketinfo in self._sockets: 206 socket_, addrinfo = socketinfo 207 self._logger.info('Listen on: %r', addrinfo) 208 try: 209 socket_.listen(self.request_queue_size) 210 except Exception as e: 211 self._logger.info('Skip by failure: %r', e) 212 socket_.close() 213 failed_sockets.append(socketinfo) 214 215 for socketinfo in failed_sockets: 216 self._sockets.remove(socketinfo) 217 218 if len(self._sockets) == 0: 219 self._logger.critical( 220 'No sockets activated. Use info log level to see the reason.') 221 222 def server_close(self): 223 """Override SocketServer.TCPServer.server_close to enable multiple 224 sockets close. 225 """ 226 227 for socketinfo in self._sockets: 228 socket_, addrinfo = socketinfo 229 self._logger.info('Close on: %r', addrinfo) 230 socket_.close() 231 232 def fileno(self): 233 """Override SocketServer.TCPServer.fileno.""" 234 235 self._logger.critical('Not supported: fileno') 236 return self._sockets[0][0].fileno() 237 238 def handle_error(self, request, client_address): 239 """Override SocketServer.handle_error.""" 240 241 self._logger.error('Exception in processing request from: %r\n%s', 242 client_address, traceback.format_exc()) 243 # Note: client_address is a tuple. 244 245 def get_request(self): 246 """Override TCPServer.get_request.""" 247 248 accepted_socket, client_address = self.socket.accept() 249 250 server_options = self.websocket_server_options 251 if server_options.use_tls: 252 # Print cipher in use. Handshake is done on accept. 253 self._logger.debug('Cipher: %s', accepted_socket.cipher()) 254 self._logger.debug('Client cert: %r', 255 accepted_socket.getpeercert()) 256 257 return accepted_socket, client_address 258 259 def serve_forever(self, poll_interval=0.5): 260 """Override SocketServer.BaseServer.serve_forever.""" 261 262 self.__ws_serving = True 263 self.__ws_is_shut_down.clear() 264 handle_request = self.handle_request 265 if hasattr(self, '_handle_request_noblock'): 266 handle_request = self._handle_request_noblock 267 else: 268 self._logger.warning('Fallback to blocking request handler') 269 try: 270 while self.__ws_serving: 271 r, w, e = select.select( 272 [socket_[0] for socket_ in self._sockets], [], [], 273 poll_interval) 274 for socket_ in r: 275 self.socket = socket_ 276 handle_request() 277 self.socket = None 278 finally: 279 self.__ws_is_shut_down.set() 280 281 def shutdown(self): 282 """Override SocketServer.BaseServer.shutdown.""" 283 284 self.__ws_serving = False 285 self.__ws_is_shut_down.wait() 286 287 288 # vi:sts=4 sw=4 et