BaseWebSocketChannel.cpp (11681B)
1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 2 /* vim: set sw=2 ts=8 et tw=80 : */ 3 /* This Source Code Form is subject to the terms of the Mozilla Public 4 * License, v. 2.0. If a copy of the MPL was not distributed with this 5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ 6 7 #include "WebSocketLog.h" 8 #include "BaseWebSocketChannel.h" 9 #include "mozilla/dom/Document.h" 10 #include "MainThreadUtils.h" 11 #include "nsContentUtils.h" 12 #include "nsIClassifiedChannel.h" 13 #include "nsILoadGroup.h" 14 #include "nsINode.h" 15 #include "nsIInterfaceRequestor.h" 16 #include "nsProxyRelease.h" 17 #include "nsStandardURL.h" 18 #include "LoadInfo.h" 19 #include "mozilla/dom/ContentChild.h" 20 #include "nsITransportProvider.h" 21 22 using mozilla::dom::ContentChild; 23 24 namespace mozilla { 25 namespace net { 26 27 LazyLogModule webSocketLog("nsWebSocket"); 28 static uint64_t gNextWebSocketID = 0; 29 30 // We use only 53 bits for the WebSocket serial ID so that it can be converted 31 // to and from a JS value without loss of precision. The upper bits of the 32 // WebSocket serial ID hold the process ID. The lower bits identify the 33 // WebSocket. 34 static const uint64_t kWebSocketIDTotalBits = 53; 35 static const uint64_t kWebSocketIDProcessBits = 22; 36 static const uint64_t kWebSocketIDWebSocketBits = 37 kWebSocketIDTotalBits - kWebSocketIDProcessBits; 38 39 BaseWebSocketChannel::BaseWebSocketChannel() 40 : mWasOpened(0), 41 mClientSetPingInterval(0), 42 mClientSetPingTimeout(0), 43 mEncrypted(false), 44 mPingForced(false), 45 mIsServerSide(false), 46 mPingInterval(0), 47 mPingResponseTimeout(10000), 48 mHttpChannelId(0) { 49 // Generation of a unique serial ID. 50 uint64_t processID = 0; 51 if (XRE_IsContentProcess()) { 52 ContentChild* cc = ContentChild::GetSingleton(); 53 processID = cc->GetID(); 54 } 55 56 uint64_t processBits = 57 processID & ((uint64_t(1) << kWebSocketIDProcessBits) - 1); 58 59 // Make sure no actual webSocket ends up with mWebSocketID == 0 but less then 60 // what the kWebSocketIDProcessBits allows. 61 if (++gNextWebSocketID >= (uint64_t(1) << kWebSocketIDWebSocketBits)) { 62 gNextWebSocketID = 1; 63 } 64 65 uint64_t webSocketBits = 66 gNextWebSocketID & ((uint64_t(1) << kWebSocketIDWebSocketBits) - 1); 67 mSerial = (processBits << kWebSocketIDWebSocketBits) | webSocketBits; 68 } 69 70 BaseWebSocketChannel::~BaseWebSocketChannel() { 71 NS_ReleaseOnMainThread("BaseWebSocketChannel::mLoadGroup", 72 mLoadGroup.forget()); 73 NS_ReleaseOnMainThread("BaseWebSocketChannel::mLoadInfo", mLoadInfo.forget()); 74 nsCOMPtr<nsISerialEventTarget> target; 75 { 76 auto lock = mTargetThread.Lock(); 77 target.swap(*lock); 78 } 79 NS_ReleaseOnMainThread("BaseWebSocketChannel::mTargetThread", 80 target.forget()); 81 } 82 83 //----------------------------------------------------------------------------- 84 // BaseWebSocketChannel::nsIWebSocketChannel 85 //----------------------------------------------------------------------------- 86 87 NS_IMETHODIMP 88 BaseWebSocketChannel::GetOriginalURI(nsIURI** aOriginalURI) { 89 LOG(("BaseWebSocketChannel::GetOriginalURI() %p\n", this)); 90 91 if (!mOriginalURI) return NS_ERROR_NOT_INITIALIZED; 92 *aOriginalURI = do_AddRef(mOriginalURI).take(); 93 return NS_OK; 94 } 95 96 NS_IMETHODIMP 97 BaseWebSocketChannel::GetURI(nsIURI** aURI) { 98 LOG(("BaseWebSocketChannel::GetURI() %p\n", this)); 99 100 if (!mOriginalURI) return NS_ERROR_NOT_INITIALIZED; 101 if (mURI) { 102 *aURI = do_AddRef(mURI).take(); 103 } else { 104 *aURI = do_AddRef(mOriginalURI).take(); 105 } 106 return NS_OK; 107 } 108 109 NS_IMETHODIMP 110 BaseWebSocketChannel::GetNotificationCallbacks( 111 nsIInterfaceRequestor** aNotificationCallbacks) { 112 LOG(("BaseWebSocketChannel::GetNotificationCallbacks() %p\n", this)); 113 *aNotificationCallbacks = do_AddRef(mCallbacks).take(); 114 return NS_OK; 115 } 116 117 NS_IMETHODIMP 118 BaseWebSocketChannel::SetNotificationCallbacks( 119 nsIInterfaceRequestor* aNotificationCallbacks) { 120 LOG(("BaseWebSocketChannel::SetNotificationCallbacks() %p\n", this)); 121 mCallbacks = aNotificationCallbacks; 122 return NS_OK; 123 } 124 125 NS_IMETHODIMP 126 BaseWebSocketChannel::GetLoadGroup(nsILoadGroup** aLoadGroup) { 127 LOG(("BaseWebSocketChannel::GetLoadGroup() %p\n", this)); 128 *aLoadGroup = do_AddRef(mLoadGroup).take(); 129 return NS_OK; 130 } 131 132 NS_IMETHODIMP 133 BaseWebSocketChannel::SetLoadGroup(nsILoadGroup* aLoadGroup) { 134 LOG(("BaseWebSocketChannel::SetLoadGroup() %p\n", this)); 135 mLoadGroup = aLoadGroup; 136 return NS_OK; 137 } 138 139 NS_IMETHODIMP 140 BaseWebSocketChannel::SetLoadInfo(nsILoadInfo* aLoadInfo) { 141 MOZ_RELEASE_ASSERT(aLoadInfo, "loadinfo can't be null"); 142 mLoadInfo = aLoadInfo; 143 return NS_OK; 144 } 145 146 NS_IMETHODIMP 147 BaseWebSocketChannel::GetLoadInfo(nsILoadInfo** aLoadInfo) { 148 *aLoadInfo = do_AddRef(mLoadInfo).take(); 149 return NS_OK; 150 } 151 152 NS_IMETHODIMP 153 BaseWebSocketChannel::GetExtensions(nsACString& aExtensions) { 154 LOG(("BaseWebSocketChannel::GetExtensions() %p\n", this)); 155 aExtensions = mNegotiatedExtensions; 156 return NS_OK; 157 } 158 159 NS_IMETHODIMP 160 BaseWebSocketChannel::GetProtocol(nsACString& aProtocol) { 161 LOG(("BaseWebSocketChannel::GetProtocol() %p\n", this)); 162 aProtocol = mProtocol; 163 return NS_OK; 164 } 165 166 NS_IMETHODIMP 167 BaseWebSocketChannel::SetProtocol(const nsACString& aProtocol) { 168 LOG(("BaseWebSocketChannel::SetProtocol() %p\n", this)); 169 mProtocol = aProtocol; /* the sub protocol */ 170 return NS_OK; 171 } 172 173 NS_IMETHODIMP 174 BaseWebSocketChannel::GetPingInterval(uint32_t* aSeconds) { 175 // stored in ms but should only have second resolution 176 MOZ_ASSERT(!(mPingInterval % 1000)); 177 178 *aSeconds = mPingInterval / 1000; 179 return NS_OK; 180 } 181 182 NS_IMETHODIMP 183 BaseWebSocketChannel::SetPingInterval(uint32_t aSeconds) { 184 MOZ_ASSERT(NS_IsMainThread()); 185 186 if (mWasOpened) { 187 return NS_ERROR_IN_PROGRESS; 188 } 189 190 mPingInterval = aSeconds * 1000; 191 mClientSetPingInterval = 1; 192 193 return NS_OK; 194 } 195 196 NS_IMETHODIMP 197 BaseWebSocketChannel::GetPingTimeout(uint32_t* aSeconds) { 198 // stored in ms but should only have second resolution 199 MOZ_ASSERT(!(mPingResponseTimeout % 1000)); 200 201 *aSeconds = mPingResponseTimeout / 1000; 202 return NS_OK; 203 } 204 205 NS_IMETHODIMP 206 BaseWebSocketChannel::SetPingTimeout(uint32_t aSeconds) { 207 MOZ_ASSERT(NS_IsMainThread()); 208 209 if (mWasOpened) { 210 return NS_ERROR_IN_PROGRESS; 211 } 212 213 mPingResponseTimeout = aSeconds * 1000; 214 mClientSetPingTimeout = 1; 215 216 return NS_OK; 217 } 218 219 NS_IMETHODIMP 220 BaseWebSocketChannel::InitLoadInfoNative( 221 nsINode* aLoadingNode, nsIPrincipal* aLoadingPrincipal, 222 nsIPrincipal* aTriggeringPrincipal, 223 nsICookieJarSettings* aCookieJarSettings, uint32_t aSecurityFlags, 224 nsContentPolicyType aContentPolicyType, uint32_t aSandboxFlags) { 225 mLoadInfo = MOZ_TRY(LoadInfo::Create( 226 aLoadingPrincipal, aTriggeringPrincipal, aLoadingNode, aSecurityFlags, 227 aContentPolicyType, Maybe<mozilla::dom::ClientInfo>(), 228 Maybe<mozilla::dom::ServiceWorkerDescriptor>(), aSandboxFlags)); 229 if (aCookieJarSettings) { 230 mLoadInfo->SetCookieJarSettings(aCookieJarSettings); 231 } 232 233 if (aLoadingNode) { 234 RefPtr<dom::Document> doc = aLoadingNode->OwnerDoc(); 235 if (doc) { 236 ClassificationFlags flags = doc->GetScriptTrackingFlags(); 237 mLoadInfo->SetTriggeringFirstPartyClassificationFlags( 238 flags.firstPartyFlags); 239 mLoadInfo->SetTriggeringThirdPartyClassificationFlags( 240 flags.thirdPartyFlags); 241 } 242 } 243 return NS_OK; 244 } 245 246 NS_IMETHODIMP 247 BaseWebSocketChannel::InitLoadInfo(nsINode* aLoadingNode, 248 nsIPrincipal* aLoadingPrincipal, 249 nsIPrincipal* aTriggeringPrincipal, 250 uint32_t aSecurityFlags, 251 nsContentPolicyType aContentPolicyType) { 252 return InitLoadInfoNative(aLoadingNode, aLoadingPrincipal, 253 aTriggeringPrincipal, nullptr, aSecurityFlags, 254 aContentPolicyType, 0); 255 } 256 257 NS_IMETHODIMP 258 BaseWebSocketChannel::GetSerial(uint32_t* aSerial) { 259 if (!aSerial) { 260 return NS_ERROR_FAILURE; 261 } 262 263 *aSerial = mSerial; 264 return NS_OK; 265 } 266 267 NS_IMETHODIMP 268 BaseWebSocketChannel::SetSerial(uint32_t aSerial) { 269 mSerial = aSerial; 270 return NS_OK; 271 } 272 273 NS_IMETHODIMP 274 BaseWebSocketChannel::SetServerParameters( 275 nsITransportProvider* aProvider, const nsACString& aNegotiatedExtensions) { 276 MOZ_ASSERT(aProvider); 277 mServerTransportProvider = aProvider; 278 mNegotiatedExtensions = aNegotiatedExtensions; 279 mIsServerSide = true; 280 return NS_OK; 281 } 282 283 NS_IMETHODIMP 284 BaseWebSocketChannel::GetHttpChannelId(uint64_t* aHttpChannelId) { 285 *aHttpChannelId = mHttpChannelId; 286 return NS_OK; 287 } 288 289 //----------------------------------------------------------------------------- 290 // BaseWebSocketChannel::nsIProtocolHandler 291 //----------------------------------------------------------------------------- 292 293 NS_IMETHODIMP 294 BaseWebSocketChannel::GetScheme(nsACString& aScheme) { 295 LOG(("BaseWebSocketChannel::GetScheme() %p\n", this)); 296 297 if (mEncrypted) { 298 aScheme.AssignLiteral("wss"); 299 } else { 300 aScheme.AssignLiteral("ws"); 301 } 302 return NS_OK; 303 } 304 305 NS_IMETHODIMP 306 BaseWebSocketChannel::NewChannel(nsIURI* aURI, nsILoadInfo* aLoadInfo, 307 nsIChannel** outChannel) { 308 LOG(("BaseWebSocketChannel::NewChannel() %p\n", this)); 309 return NS_ERROR_NOT_IMPLEMENTED; 310 } 311 312 NS_IMETHODIMP 313 BaseWebSocketChannel::AllowPort(int32_t port, const char* scheme, 314 bool* _retval) { 315 LOG(("BaseWebSocketChannel::AllowPort() %p\n", this)); 316 317 // do not override any blacklisted ports 318 *_retval = false; 319 return NS_OK; 320 } 321 322 //----------------------------------------------------------------------------- 323 // BaseWebSocketChannel::nsIThreadRetargetableRequest 324 //----------------------------------------------------------------------------- 325 326 NS_IMETHODIMP 327 BaseWebSocketChannel::RetargetDeliveryTo(nsISerialEventTarget* aTargetThread) { 328 MOZ_ASSERT(NS_IsMainThread()); 329 MOZ_ASSERT(aTargetThread); 330 MOZ_ASSERT(!mWasOpened, "Should not be called after AsyncOpen!"); 331 MOZ_ASSERT(aTargetThread); 332 333 auto lock = mTargetThread.Lock(); 334 MOZ_ASSERT(!lock.ref(), 335 "Delivery target should be set once, before AsyncOpen"); 336 lock.ref() = aTargetThread; 337 return NS_OK; 338 } 339 340 NS_IMETHODIMP 341 BaseWebSocketChannel::GetDeliveryTarget(nsISerialEventTarget** aTargetThread) { 342 MOZ_ASSERT(NS_IsMainThread()); 343 344 nsCOMPtr<nsISerialEventTarget> target = GetTargetThread(); 345 if (!target) { 346 target = GetCurrentSerialEventTarget(); 347 } 348 target.forget(aTargetThread); 349 return NS_OK; 350 } 351 352 already_AddRefed<nsISerialEventTarget> BaseWebSocketChannel::GetTargetThread() { 353 nsCOMPtr<nsISerialEventTarget> target; 354 auto lock = mTargetThread.Lock(); 355 target = *lock; 356 return target.forget(); 357 } 358 359 bool BaseWebSocketChannel::IsOnTargetThread() { 360 nsCOMPtr<nsISerialEventTarget> target = GetTargetThread(); 361 if (!target) { 362 MOZ_ASSERT(false); 363 return false; 364 } 365 366 bool isOnTargetThread = false; 367 nsresult rv = target->IsOnCurrentThread(&isOnTargetThread); 368 MOZ_ASSERT(NS_SUCCEEDED(rv)); 369 return NS_FAILED(rv) ? false : isOnTargetThread; 370 } 371 372 BaseWebSocketChannel::ListenerAndContextContainer::ListenerAndContextContainer( 373 nsIWebSocketListener* aListener, nsISupports* aContext) 374 : mListener(aListener), mContext(aContext) { 375 MOZ_ASSERT(NS_IsMainThread()); 376 MOZ_ASSERT(mListener); 377 } 378 379 BaseWebSocketChannel::ListenerAndContextContainer:: 380 ~ListenerAndContextContainer() { 381 MOZ_ASSERT(mListener); 382 383 NS_ReleaseOnMainThread( 384 "BaseWebSocketChannel::ListenerAndContextContainer::mListener", 385 mListener.forget()); 386 NS_ReleaseOnMainThread( 387 "BaseWebSocketChannel::ListenerAndContextContainer::mContext", 388 mContext.forget()); 389 } 390 391 } // namespace net 392 } // namespace mozilla