tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

webrtcproxychannel_unittest.cpp (22654B)


      1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim: set ts=2 et sw=2 tw=80: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
#include <algorithm>
#include <atomic>
#include <mutex>
      9 
     10 #include "mozilla/net/WebrtcTCPSocket.h"
     11 #include "mozilla/net/WebrtcTCPSocketCallback.h"
     12 #include "nsISocketTransport.h"
     13 
     14 #define GTEST_HAS_RTTI 0
     15 #include "gtest/gtest.h"
     16 #include "gtest_utils.h"
     17 
     18 static const uint32_t kDefaultTestTimeout = 2000;
     19 static const char kReadData[] = "Hello, World!";
     20 static const size_t kReadDataLength = sizeof(kReadData) - 1;
     21 MOZ_RUNINIT static const std::string kReadDataString =
     22    std::string(kReadData, kReadDataLength);
     23 static int kDataLargeOuterLoopCount = 128;
     24 static int kDataLargeInnerLoopCount = 1024;
     25 
namespace mozilla {

// The mozilla/net/ headers included above declare WebrtcTCPSocket and its
// callback interface (presumably in mozilla::net); `testing` is gtest's
// namespace.
using namespace net;
using namespace testing;

// Forward declaration: WebrtcTCPSocketTest stores a
// RefPtr<WebrtcTCPSocketTestCallback> before that class is defined.
class WebrtcTCPSocketTestCallback;
     32 
// Fake nsISocketTransport handed to WebrtcTCPSocket::OnTransportAvailable()
// by the tests. Only SetRecvBufferSize(), SetSendBufferSize() and Close() are
// expected to be called. Every other method hits MOZ_ASSERT(false), so an
// unexpected call from WebrtcTCPSocket is caught.
// NOTE(review): MOZ_ASSERT is debug-only. In builds without assertions these
// stubs return NS_OK without writing their out-parameters.
class FakeSocketTransportProvider : public nsISocketTransport {
 public:
  NS_DECL_THREADSAFE_ISUPPORTS

  // nsISocketTransport
  NS_IMETHOD GetHost(nsACString& aHost) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetPort(int32_t* aPort) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetScriptableOriginAttributes(
      JSContext* cx, JS::MutableHandle<JS::Value> aOriginAttributes) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD SetScriptableOriginAttributes(
      JSContext* cx, JS::Handle<JS::Value> aOriginAttributes) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  virtual nsresult GetOriginAttributes(
      mozilla::OriginAttributes* _retval) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  virtual nsresult SetOriginAttributes(
      const mozilla::OriginAttributes& aOriginAttrs) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetPeerAddr(mozilla::net::NetAddr* _retval) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetSelfAddr(mozilla::net::NetAddr* _retval) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD Bind(mozilla::net::NetAddr* aLocalAddr) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetScriptablePeerAddr(nsINetAddr** _retval) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetScriptableSelfAddr(nsINetAddr** _retval) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetTlsSocketControl(
      nsITLSSocketControl** aTLSSocketControl) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetSecurityCallbacks(
      nsIInterfaceRequestor** aSecurityCallbacks) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD SetSecurityCallbacks(
      nsIInterfaceRequestor* aSecurityCallbacks) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD IsAlive(bool* _retval) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetTimeout(uint32_t aType, uint32_t* _retval) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD SetTimeout(uint32_t aType, uint32_t aValue) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD SetLinger(bool aPolarity, int16_t aTimeout) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD SetReuseAddrPort(bool reuseAddrPort) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetConnectionFlags(uint32_t* aConnectionFlags) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD SetConnectionFlags(uint32_t aConnectionFlags) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD SetIsPrivate(bool) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetTlsFlags(uint32_t* aTlsFlags) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD SetTlsFlags(uint32_t aTlsFlags) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetQoSBits(uint8_t* aQoSBits) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD SetQoSBits(uint8_t aQoSBits) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetRecvBufferSize(uint32_t* aRecvBufferSize) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetSendBufferSize(uint32_t* aSendBufferSize) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetKeepaliveEnabled(bool* aKeepaliveEnabled) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD SetKeepaliveEnabled(bool aKeepaliveEnabled) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD SetKeepaliveVals(int32_t keepaliveIdleTime,
                              int32_t keepaliveRetryInterval) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetResetIPFamilyPreference(
      bool* aResetIPFamilyPreference) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetEchConfigUsed(bool* aEchConfigUsed) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD SetEchConfig(const nsACString& aEchConfig) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD ResolvedByTRR(bool* _retval) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetEffectiveTRRMode(
      nsIRequest::TRRMode* aEffectiveTRRMode) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetTrrSkipReason(nsITRRSkipReason::value* aSkipReason) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetRetryDnsIfPossible(bool* aRetryDns) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD GetStatus(nsresult* aStatus) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }

  // nsITransport
  NS_IMETHOD OpenInputStream(uint32_t aFlags, uint32_t aSegmentSize,
                             uint32_t aSegmentCount,
                             nsIInputStream** _retval) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD OpenOutputStream(uint32_t aFlags, uint32_t aSegmentSize,
                              uint32_t aSegmentCount,
                              nsIOutputStream** _retval) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }
  NS_IMETHOD SetEventSink(nsITransportEventSink* aSink,
                          nsIEventTarget* aEventTarget) override {
    MOZ_ASSERT(false);
    return NS_OK;
  }

  // fake except for these methods which are OK to call
  // (they accept any argument and report success without doing anything)
  // nsISocketTransport
  NS_IMETHOD SetRecvBufferSize(uint32_t aRecvBufferSize) override {
    return NS_OK;
  }
  NS_IMETHOD SetSendBufferSize(uint32_t aSendBufferSize) override {
    return NS_OK;
  }
  // nsITransport
  NS_IMETHOD Close(nsresult aReason) override { return NS_OK; }

 protected:
  // Refcounted: destroyed only through Release().
  virtual ~FakeSocketTransportProvider() = default;
};

// Both interfaces are listed because the class also implements the
// nsITransport methods (OpenInputStream, OpenOutputStream, SetEventSink,
// Close) inherited through nsISocketTransport.
NS_IMPL_ISUPPORTS(FakeSocketTransportProvider, nsISocketTransport, nsITransport)
    240 
    241 // Implements some common elements to WebrtcTCPSocketTestOutputStream and
    242 // WebrtcTCPSocketTestInputStream.
    243 class WebrtcTCPSocketTestStream {
    244 public:
    245  WebrtcTCPSocketTestStream();
    246 
    247  void Fail() { mMustFail = true; }
    248 
    249  size_t DataLength();
    250  template <typename T>
    251  void AppendElements(const T* aBuffer, size_t aLength);
    252 
    253 protected:
    254  virtual ~WebrtcTCPSocketTestStream() = default;
    255 
    256  nsTArray<uint8_t> mData;
    257  std::mutex mDataMutex;
    258 
    259  bool mMustFail;
    260 };
    261 
    262 WebrtcTCPSocketTestStream::WebrtcTCPSocketTestStream() : mMustFail(false) {}
    263 
    264 template <typename T>
    265 void WebrtcTCPSocketTestStream::AppendElements(const T* aBuffer,
    266                                               size_t aLength) {
    267  std::lock_guard<std::mutex> guard(mDataMutex);
    268  mData.AppendElements(aBuffer, aLength);
    269 }
    270 
    271 size_t WebrtcTCPSocketTestStream::DataLength() {
    272  std::lock_guard<std::mutex> guard(mDataMutex);
    273  return mData.Length();
    274 }
    275 
    276 class WebrtcTCPSocketTestInputStream : public nsIAsyncInputStream,
    277                                       public WebrtcTCPSocketTestStream {
    278 public:
    279  NS_DECL_THREADSAFE_ISUPPORTS
    280  NS_DECL_NSIASYNCINPUTSTREAM
    281  NS_DECL_NSIINPUTSTREAM
    282 
    283  WebrtcTCPSocketTestInputStream()
    284      : mMaxReadSize(1024 * 1024),
    285        mMutex("WebrtcTCPSocketTestInputStream::mMutex"),
    286        mAllowCallbacks(false) {}
    287 
    288  void DoCallback();
    289  void CallCallback(const nsCOMPtr<nsIInputStreamCallback>& aCallback);
    290  void AllowCallbacks() { mAllowCallbacks = true; }
    291 
    292  size_t mMaxReadSize;
    293 
    294 protected:
    295  virtual ~WebrtcTCPSocketTestInputStream() = default;
    296 
    297 private:
    298  mutable Mutex mMutex;
    299  nsCOMPtr<nsIInputStreamCallback> mCallback;
    300  nsCOMPtr<nsIEventTarget> mCallbackTarget;
    301 
    302  bool mAllowCallbacks;
    303 };
    304 
    305 NS_IMPL_ISUPPORTS(WebrtcTCPSocketTestInputStream, nsIAsyncInputStream,
    306                  nsIInputStream)
    307 
    308 nsresult WebrtcTCPSocketTestInputStream::AsyncWait(
    309    nsIInputStreamCallback* aCallback, uint32_t aFlags,
    310    uint32_t aRequestedCount, nsIEventTarget* aEventTarget) {
    311  MOZ_ASSERT(!aEventTarget, "no event target should be set");
    312 
    313  {
    314    MutexAutoLock lock(mMutex);
    315    mCallback = aCallback;
    316    mCallbackTarget = NS_GetCurrentThread();
    317  }
    318 
    319  if (mAllowCallbacks && DataLength() > 0) {
    320    DoCallback();
    321  }
    322 
    323  return NS_OK;
    324 }
    325 
    326 nsresult WebrtcTCPSocketTestInputStream::CloseWithStatus(nsresult aStatus) {
    327  return Close();
    328 }
    329 
    330 nsresult WebrtcTCPSocketTestInputStream::Close() { return NS_OK; }
    331 
    332 nsresult WebrtcTCPSocketTestInputStream::Available(uint64_t* aAvailable) {
    333  *aAvailable = DataLength();
    334  return NS_OK;
    335 }
    336 
    337 nsresult WebrtcTCPSocketTestInputStream::StreamStatus() { return NS_OK; }
    338 
    339 nsresult WebrtcTCPSocketTestInputStream::Read(char* aBuffer, uint32_t aCount,
    340                                              uint32_t* aRead) {
    341  std::lock_guard<std::mutex> guard(mDataMutex);
    342  if (mMustFail) {
    343    return NS_ERROR_FAILURE;
    344  }
    345  *aRead = std::min({(size_t)aCount, mData.Length(), mMaxReadSize});
    346  memcpy(aBuffer, mData.Elements(), *aRead);
    347  mData.RemoveElementsAt(0, *aRead);
    348  return *aRead > 0 ? NS_OK : NS_BASE_STREAM_WOULD_BLOCK;
    349 }
    350 
    351 nsresult WebrtcTCPSocketTestInputStream::ReadSegments(nsWriteSegmentFun aWriter,
    352                                                      void* aClosure,
    353                                                      uint32_t aCount,
    354                                                      uint32_t* _retval) {
    355  MOZ_ASSERT(false);
    356  return NS_OK;
    357 }
    358 
    359 nsresult WebrtcTCPSocketTestInputStream::IsNonBlocking(bool* aIsNonBlocking) {
    360  *aIsNonBlocking = true;
    361  return NS_OK;
    362 }
    363 
    364 void WebrtcTCPSocketTestInputStream::CallCallback(
    365    const nsCOMPtr<nsIInputStreamCallback>& aCallback) {
    366  aCallback->OnInputStreamReady(this);
    367 }
    368 
    369 void WebrtcTCPSocketTestInputStream::DoCallback() {
    370  MutexAutoLock lock(mMutex);
    371  if (mCallback) {
    372    mCallbackTarget->Dispatch(
    373        NewRunnableMethod<const nsCOMPtr<nsIInputStreamCallback>&>(
    374            "WebrtcTCPSocketTestInputStream::DoCallback", this,
    375            &WebrtcTCPSocketTestInputStream::CallCallback,
    376            std::move(mCallback)));
    377 
    378    mCallbackTarget = nullptr;
    379  }
    380 }
    381 
    382 class WebrtcTCPSocketTestOutputStream : public nsIAsyncOutputStream,
    383                                        public WebrtcTCPSocketTestStream {
    384 public:
    385  NS_DECL_THREADSAFE_ISUPPORTS
    386  NS_DECL_NSIASYNCOUTPUTSTREAM
    387  NS_DECL_NSIOUTPUTSTREAM
    388 
    389  WebrtcTCPSocketTestOutputStream()
    390      : mMaxWriteSize(1024 * 1024),
    391        mMutex("WebrtcTCPSocketTestOutputStream::mMutex") {}
    392 
    393  void DoCallback();
    394  void CallCallback(const nsCOMPtr<nsIOutputStreamCallback>& aCallback);
    395 
    396  std::string DataString();
    397 
    398  uint32_t mMaxWriteSize;
    399 
    400 protected:
    401  virtual ~WebrtcTCPSocketTestOutputStream() = default;
    402 
    403 private:
    404  mutable Mutex mMutex;
    405  nsCOMPtr<nsIOutputStreamCallback> mCallback;
    406  nsCOMPtr<nsIEventTarget> mCallbackTarget;
    407 };
    408 
    409 NS_IMPL_ISUPPORTS(WebrtcTCPSocketTestOutputStream, nsIAsyncOutputStream,
    410                  nsIOutputStream)
    411 
    412 nsresult WebrtcTCPSocketTestOutputStream::AsyncWait(
    413    nsIOutputStreamCallback* aCallback, uint32_t aFlags,
    414    uint32_t aRequestedCount, nsIEventTarget* aEventTarget) {
    415  MOZ_ASSERT(!aEventTarget, "no event target should be set");
    416 
    417  {
    418    MutexAutoLock lock(mMutex);
    419    mCallback = aCallback;
    420    mCallbackTarget = NS_GetCurrentThread();
    421  }
    422 
    423  return NS_OK;
    424 }
    425 
    426 nsresult WebrtcTCPSocketTestOutputStream::CloseWithStatus(nsresult aStatus) {
    427  return Close();
    428 }
    429 
    430 nsresult WebrtcTCPSocketTestOutputStream::Close() { return NS_OK; }
    431 
    432 nsresult WebrtcTCPSocketTestOutputStream::Flush() { return NS_OK; }
    433 
    434 nsresult WebrtcTCPSocketTestOutputStream::StreamStatus() {
    435  return mMustFail ? NS_ERROR_FAILURE : NS_OK;
    436 }
    437 
    438 nsresult WebrtcTCPSocketTestOutputStream::Write(const char* aBuffer,
    439                                                uint32_t aCount,
    440                                                uint32_t* aWrote) {
    441  if (mMustFail) {
    442    return NS_ERROR_FAILURE;
    443  }
    444  *aWrote = std::min(aCount, mMaxWriteSize);
    445  AppendElements(aBuffer, *aWrote);
    446  return NS_OK;
    447 }
    448 
    449 nsresult WebrtcTCPSocketTestOutputStream::WriteSegments(
    450    nsReadSegmentFun aReader, void* aClosure, uint32_t aCount,
    451    uint32_t* _retval) {
    452  MOZ_ASSERT(false);
    453  return NS_OK;
    454 }
    455 
    456 nsresult WebrtcTCPSocketTestOutputStream::WriteFrom(nsIInputStream* aFromStream,
    457                                                    uint32_t aCount,
    458                                                    uint32_t* _retval) {
    459  MOZ_ASSERT(false);
    460  return NS_OK;
    461 }
    462 
    463 nsresult WebrtcTCPSocketTestOutputStream::IsNonBlocking(bool* aIsNonBlocking) {
    464  *aIsNonBlocking = true;
    465  return NS_OK;
    466 }
    467 
    468 void WebrtcTCPSocketTestOutputStream::CallCallback(
    469    const nsCOMPtr<nsIOutputStreamCallback>& aCallback) {
    470  aCallback->OnOutputStreamReady(this);
    471 }
    472 
    473 void WebrtcTCPSocketTestOutputStream::DoCallback() {
    474  MutexAutoLock lock(mMutex);
    475  if (mCallback) {
    476    mCallbackTarget->Dispatch(
    477        NewRunnableMethod<const nsCOMPtr<nsIOutputStreamCallback>&>(
    478            "WebrtcTCPSocketTestOutputStream::CallCallback", this,
    479            &WebrtcTCPSocketTestOutputStream::CallCallback,
    480            std::move(mCallback)));
    481 
    482    mCallbackTarget = nullptr;
    483  }
    484 }
    485 
    486 std::string WebrtcTCPSocketTestOutputStream::DataString() {
    487  std::lock_guard<std::mutex> guard(mDataMutex);
    488  return std::string((char*)mData.Elements(), mData.Length());
    489 }
    490 
    491 // Fake as in not the real WebrtcTCPSocket but real enough
    492 class FakeWebrtcTCPSocket : public WebrtcTCPSocket {
    493 public:
    494  explicit FakeWebrtcTCPSocket(WebrtcTCPSocketCallback* aCallback)
    495      : WebrtcTCPSocket(aCallback) {}
    496 
    497 protected:
    498  virtual ~FakeWebrtcTCPSocket() = default;
    499 
    500  void InvokeOnClose(nsresult aReason) override;
    501  void InvokeOnConnected() override;
    502  void InvokeOnRead(nsTArray<uint8_t>&& aReadData) override;
    503 };
    504 
    505 void FakeWebrtcTCPSocket::InvokeOnClose(nsresult aReason) {
    506  mProxyCallbacks->OnClose(aReason);
    507 }
    508 
    509 void FakeWebrtcTCPSocket::InvokeOnConnected() {
    510  mProxyCallbacks->OnConnected("http"_ns);
    511 }
    512 
    513 void FakeWebrtcTCPSocket::InvokeOnRead(nsTArray<uint8_t>&& aReadData) {
    514  mProxyCallbacks->OnRead(std::move(aReadData));
    515 }
    516 
    517 class WebrtcTCPSocketTest : public MtransportTest {
    518 public:
    519  WebrtcTCPSocketTest()
    520      : mSocketThread(nullptr),
    521        mSocketTransport(nullptr),
    522        mInputStream(nullptr),
    523        mOutputStream(nullptr),
    524        mChannel(nullptr),
    525        mCallback(nullptr),
    526        mOnCloseCalled(false),
    527        mOnConnectedCalled(false) {}
    528 
    529  // WebrtcTCPSocketCallback forwards from mCallback
    530  void OnClose(nsresult aReason);
    531  void OnConnected(const nsACString& aProxyType);
    532  void OnRead(nsTArray<uint8_t>&& aReadData);
    533 
    534  void SetUp() override;
    535  void TearDown() override;
    536  size_t CountUnwrittenBytes() {
    537    size_t result;
    538    test_utils_->SyncDispatchToSTS(WrapRunnableRet(
    539        &result, this, &WebrtcTCPSocketTest::CountUnwrittenBytes_s));
    540    return result;
    541  }
    542  size_t CountUnwrittenBytes_s() { return mChannel->CountUnwrittenBytes(); }
    543 
    544  void DoTransportAvailable();
    545 
    546  std::string ReadDataAsString();
    547  std::string GetDataLarge();
    548 
    549  nsCOMPtr<nsIEventTarget> mSocketThread;
    550 
    551  nsCOMPtr<nsISocketTransport> mSocketTransport;
    552  RefPtr<WebrtcTCPSocketTestInputStream> mInputStream;
    553  RefPtr<WebrtcTCPSocketTestOutputStream> mOutputStream;
    554  RefPtr<FakeWebrtcTCPSocket> mChannel;
    555  RefPtr<WebrtcTCPSocketTestCallback> mCallback;
    556 
    557  std::atomic<bool> mOnCloseCalled;
    558  std::atomic<bool> mOnConnectedCalled;
    559 
    560  size_t ReadDataLength();
    561  template <typename T>
    562  void AppendReadData(const T* aBuffer, size_t aLength);
    563 
    564 private:
    565  nsTArray<uint8_t> mReadData;
    566  std::mutex mReadDataMutex;
    567 };
    568 
    569 class WebrtcTCPSocketTestCallback : public WebrtcTCPSocketCallback {
    570 public:
    571  NS_INLINE_DECL_THREADSAFE_REFCOUNTING(WebrtcTCPSocketTestCallback, override)
    572 
    573  explicit WebrtcTCPSocketTestCallback(WebrtcTCPSocketTest* aTest)
    574      : mTest(aTest) {}
    575 
    576  // WebrtcTCPSocketCallback
    577  void OnClose(nsresult aReason) override;
    578  void OnConnected(const nsACString& aProxyType) override;
    579  void OnRead(nsTArray<uint8_t>&& aReadData) override;
    580 
    581 protected:
    582  virtual ~WebrtcTCPSocketTestCallback() = default;
    583 
    584 private:
    585  WebrtcTCPSocketTest* mTest;
    586 };
    587 
    588 void WebrtcTCPSocketTest::SetUp() {
    589  MtransportTest::SetUp();
    590 
    591  nsresult rv;
    592  // WebrtcTCPSocket's threading model is the same as mtransport
    593  // all socket operations are done on the socket thread
    594  // callbacks are invoked on the main thread
    595  mSocketThread = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv);
    596  ASSERT_TRUE(NS_SUCCEEDED(rv));
    597 
    598  mSocketTransport = new FakeSocketTransportProvider();
    599  mInputStream = new WebrtcTCPSocketTestInputStream();
    600  mOutputStream = new WebrtcTCPSocketTestOutputStream();
    601  mCallback = new WebrtcTCPSocketTestCallback(this);
    602  mChannel = new FakeWebrtcTCPSocket(mCallback.get());
    603 }
    604 
    605 void WebrtcTCPSocketTest::TearDown() { MtransportTest::TearDown(); }
    606 
    607 // WebrtcTCPSocketCallback
    608 void WebrtcTCPSocketTest::OnRead(nsTArray<uint8_t>&& aReadData) {
    609  AppendReadData(aReadData.Elements(), aReadData.Length());
    610 }
    611 
    612 void WebrtcTCPSocketTest::OnConnected(const nsACString& aProxyType) {
    613  mOnConnectedCalled = true;
    614 }
    615 
    616 void WebrtcTCPSocketTest::OnClose(nsresult aReason) { mOnCloseCalled = true; }
    617 
    618 void WebrtcTCPSocketTest::DoTransportAvailable() {
    619  if (!mSocketThread->IsOnCurrentThread()) {
    620    mSocketThread->Dispatch(
    621        NS_NewRunnableFunction("DoTransportAvailable", [this]() -> void {
    622          nsresult rv;
    623          rv = mChannel->OnTransportAvailable(mSocketTransport, mInputStream,
    624                                              mOutputStream);
    625          ASSERT_EQ(NS_OK, rv);
    626        }));
    627  } else {
    628    // should always be called on the main thread
    629    MOZ_ASSERT(0);
    630  }
    631 }
    632 
    633 std::string WebrtcTCPSocketTest::ReadDataAsString() {
    634  std::lock_guard<std::mutex> guard(mReadDataMutex);
    635  return std::string((char*)mReadData.Elements(), mReadData.Length());
    636 }
    637 
    638 std::string WebrtcTCPSocketTest::GetDataLarge() {
    639  std::string data;
    640  for (int i = 0; i < kDataLargeOuterLoopCount * kDataLargeInnerLoopCount;
    641       ++i) {
    642    data += kReadData;
    643  }
    644  return data;
    645 }
    646 
    647 template <typename T>
    648 void WebrtcTCPSocketTest::AppendReadData(const T* aBuffer, size_t aLength) {
    649  std::lock_guard<std::mutex> guard(mReadDataMutex);
    650  mReadData.AppendElements(aBuffer, aLength);
    651 }
    652 
    653 size_t WebrtcTCPSocketTest::ReadDataLength() {
    654  std::lock_guard<std::mutex> guard(mReadDataMutex);
    655  return mReadData.Length();
    656 }
    657 
    658 void WebrtcTCPSocketTestCallback::OnClose(nsresult aReason) {
    659  mTest->OnClose(aReason);
    660 }
    661 
    662 void WebrtcTCPSocketTestCallback::OnConnected(const nsACString& aProxyType) {
    663  mTest->OnConnected(aProxyType);
    664 }
    665 
    666 void WebrtcTCPSocketTestCallback::OnRead(nsTArray<uint8_t>&& aReadData) {
    667  mTest->OnRead(std::move(aReadData));
    668 }
    669 
    670 }  // namespace mozilla
    671 
    672 typedef mozilla::WebrtcTCPSocketTest WebrtcTCPSocketTest;
    673 
// Fixture setup and teardown alone must succeed.
TEST_F(WebrtcTCPSocketTest, SetUp) {}

// Handing the fake transport to the socket results in OnConnected().
TEST_F(WebrtcTCPSocketTest, TransportAvailable) {
  DoTransportAvailable();
  ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);
}

// Bytes placed in the input stream are delivered through OnRead().
TEST_F(WebrtcTCPSocketTest, Read) {
  DoTransportAvailable();
  ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);

  // Buffer the payload first, then fire the pending AsyncWait() callback.
  mInputStream->AppendElements(kReadData, kReadDataLength);
  mInputStream->DoCallback();

  ASSERT_TRUE_WAIT(ReadDataAsString() == kReadDataString, kDefaultTestTimeout);
}

// Bytes passed to mChannel->Write() stay queued until the output stream's
// callback fires, and then reach the output stream intact.
TEST_F(WebrtcTCPSocketTest, Write) {
  DoTransportAvailable();
  ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);

  nsTArray<uint8_t> data;
  data.AppendElements(kReadData, kReadDataLength);
  mChannel->Write(std::move(data));

  // Still queued: the output stream has not signalled readiness yet.
  ASSERT_TRUE_WAIT(CountUnwrittenBytes() == kReadDataLength,
                   kDefaultTestTimeout);

  mOutputStream->DoCallback();

  ASSERT_TRUE_WAIT(mOutputStream->DataString() == kReadDataString,
                   kDefaultTestTimeout);
}

// A failing Read() closes the socket and delivers no data.
TEST_F(WebrtcTCPSocketTest, ReadFail) {
  DoTransportAvailable();
  ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);

  // Fail() must precede DoCallback() so the read it triggers fails.
  mInputStream->AppendElements(kReadData, kReadDataLength);
  mInputStream->Fail();
  mInputStream->DoCallback();

  ASSERT_TRUE_WAIT(mOnCloseCalled, kDefaultTestTimeout);
  ASSERT_EQ(0U, ReadDataLength());
}

// A failing output stream closes the socket and nothing gets written.
TEST_F(WebrtcTCPSocketTest, WriteFail) {
  DoTransportAvailable();
  ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);

  nsTArray<uint8_t> array;
  array.AppendElements(kReadData, kReadDataLength);
  mChannel->Write(std::move(array));

  ASSERT_TRUE_WAIT(CountUnwrittenBytes() == kReadDataLength,
                   kDefaultTestTimeout);

  // Fail() must precede DoCallback() so the write it triggers fails.
  mOutputStream->Fail();
  mOutputStream->DoCallback();

  ASSERT_TRUE_WAIT(mOnCloseCalled, kDefaultTestTimeout);
  ASSERT_EQ(0U, mOutputStream->DataLength());
}

// A large payload is read completely even though each Read() returns at most
// mMaxReadSize bytes.
TEST_F(WebrtcTCPSocketTest, ReadLarge) {
  DoTransportAvailable();
  ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);

  const std::string data = GetDataLarge();

  mInputStream->AppendElements(data.c_str(), data.length());
  // make sure reading loops more than once
  mInputStream->mMaxReadSize = 3072;
  // With callbacks allowed, each AsyncWait() re-fires while data remains, so
  // a single DoCallback() is enough to drain the whole buffer.
  mInputStream->AllowCallbacks();
  mInputStream->DoCallback();

  ASSERT_TRUE_WAIT(ReadDataAsString() == data, kDefaultTestTimeout);
}
    752 
    753 TEST_F(WebrtcTCPSocketTest, WriteLarge) {
    754  DoTransportAvailable();
    755  ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);
    756 
    757  const std::string data = GetDataLarge();
    758 
    759  for (int i = 0; i < kDataLargeOuterLoopCount; ++i) {
    760    nsTArray<uint8_t> array;
    761    int chunkSize = kReadDataString.length() * kDataLargeInnerLoopCount;
    762    int offset = i * chunkSize;
    763    array.AppendElements(data.c_str() + offset, chunkSize);
    764    mChannel->Write(std::move(array));
    765  }
    766 
    767  ASSERT_TRUE_WAIT(CountUnwrittenBytes() == data.length(), kDefaultTestTimeout);
    768 
    769  // make sure writing loops more than once per write request
    770  mOutputStream->mMaxWriteSize = 1024;
    771  mOutputStream->DoCallback();
    772 
    773  ASSERT_TRUE_WAIT(mOutputStream->DataString() == data, kDefaultTestTimeout);
    774 }