tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

TestUDPSocket.cpp (11571B)


      1 /* This Source Code Form is subject to the terms of the Mozilla Public
      2 * License, v. 2.0. If a copy of the MPL was not distributed with this
      3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
      4 
#include "TestCommon.h"
#include "gtest/gtest.h"
#include "nsIUDPSocket.h"
#include "nsISocketTransport.h"
#include "nsIOutputStream.h"
#include "nsINetAddr.h"
#include "nsITimer.h"
#include "nsContentUtils.h"
#include "mozilla/gtest/MozAssertions.h"
#include "mozilla/net/DNS.h"
#include "prerror.h"
#include "nsComponentManagerUtils.h"

#include <cstring>
     17 
     18 #define REQUEST 0x68656c6f
     19 #define RESPONSE 0x6f6c6568
     20 #define MULTICAST_TIMEOUT 2000
     21 
     22 enum TestPhase { TEST_OUTPUT_STREAM, TEST_SEND_API, TEST_MULTICAST, TEST_NONE };
     23 
     24 static TestPhase phase = TEST_NONE;
     25 
     26 static bool CheckMessageContent(nsIUDPMessage* aMessage,
     27                                uint32_t aExpectedContent) {
     28  nsCString data;
     29  aMessage->GetData(data);
     30 
     31  const char* buffer = data.get();
     32  uint32_t len = data.Length();
     33 
     34  FallibleTArray<uint8_t>& rawData = aMessage->GetDataAsTArray();
     35  uint32_t rawLen = rawData.Length();
     36 
     37  if (len != rawLen) {
     38    ADD_FAILURE() << "Raw data length " << rawLen
     39                  << " does not match String data length " << len;
     40    return false;
     41  }
     42 
     43  for (uint32_t i = 0; i < len; i++) {
     44    if (buffer[i] != rawData[i]) {
     45      ADD_FAILURE();
     46      return false;
     47    }
     48  }
     49 
     50  uint32_t input = 0;
     51  for (uint32_t i = 0; i < len; i++) {
     52    input += buffer[i] << (8 * i);
     53  }
     54 
     55  if (len != sizeof(uint32_t)) {
     56    ADD_FAILURE() << "Message length mismatch, expected " << sizeof(uint32_t)
     57                  << " got " << len;
     58    return false;
     59  }
     60  if (input != aExpectedContent) {
     61    ADD_FAILURE() << "Message content mismatch, expected 0x" << std::hex
     62                  << aExpectedContent << " got 0x" << input;
     63    return false;
     64  }
     65 
     66  return true;
     67 }
     68 
     69 /*
     70 * UDPClientListener: listens for incomming UDP packets
     71 */
     72 class UDPClientListener : public nsIUDPSocketListener {
     73 protected:
     74  virtual ~UDPClientListener();
     75 
     76 public:
     77  explicit UDPClientListener(WaitForCondition* waiter) : mWaiter(waiter) {}
     78 
     79  NS_DECL_THREADSAFE_ISUPPORTS
     80  NS_DECL_NSIUDPSOCKETLISTENER
     81  nsresult mResult = NS_ERROR_FAILURE;
     82  RefPtr<WaitForCondition> mWaiter;
     83 };
     84 
     85 NS_IMPL_ISUPPORTS(UDPClientListener, nsIUDPSocketListener)
     86 
     87 UDPClientListener::~UDPClientListener() = default;
     88 
     89 NS_IMETHODIMP
     90 UDPClientListener::OnPacketReceived(nsIUDPSocket* socket,
     91                                    nsIUDPMessage* message) {
     92  mResult = NS_OK;
     93 
     94  uint16_t port;
     95  nsCString ip;
     96  nsCOMPtr<nsINetAddr> fromAddr;
     97  message->GetFromAddr(getter_AddRefs(fromAddr));
     98  fromAddr->GetPort(&port);
     99  fromAddr->GetAddress(ip);
    100 
    101  if (TEST_SEND_API == phase && CheckMessageContent(message, REQUEST)) {
    102    uint32_t count;
    103    nsTArray<uint8_t> data;
    104    const uint32_t dataBuffer = RESPONSE;
    105    data.AppendElements((const uint8_t*)&dataBuffer, sizeof(uint32_t));
    106    mResult = socket->SendWithAddr(fromAddr, data, &count);
    107    if (mResult == NS_OK && count == sizeof(uint32_t)) {
    108      SUCCEED();
    109    } else {
    110      ADD_FAILURE();
    111    }
    112    return NS_OK;
    113  }
    114  if (TEST_OUTPUT_STREAM != phase || !CheckMessageContent(message, RESPONSE)) {
    115    mResult = NS_ERROR_FAILURE;
    116  }
    117 
    118  // Notify thread
    119  mWaiter->Notify();
    120  return NS_OK;
    121 }
    122 
    123 NS_IMETHODIMP
    124 UDPClientListener::OnStopListening(nsIUDPSocket*, nsresult) {
    125  mWaiter->Notify();
    126  return NS_OK;
    127 }
    128 
    129 /*
    130 * UDPServerListener: listens for incomming UDP packets
    131 */
    132 class UDPServerListener : public nsIUDPSocketListener {
    133 protected:
    134  virtual ~UDPServerListener();
    135 
    136 public:
    137  explicit UDPServerListener(WaitForCondition* waiter) : mWaiter(waiter) {}
    138 
    139  NS_DECL_THREADSAFE_ISUPPORTS
    140  NS_DECL_NSIUDPSOCKETLISTENER
    141 
    142  nsresult mResult = NS_ERROR_FAILURE;
    143  RefPtr<WaitForCondition> mWaiter;
    144 };
    145 
    146 NS_IMPL_ISUPPORTS(UDPServerListener, nsIUDPSocketListener)
    147 
    148 UDPServerListener::~UDPServerListener() = default;
    149 
    150 NS_IMETHODIMP
    151 UDPServerListener::OnPacketReceived(nsIUDPSocket* socket,
    152                                    nsIUDPMessage* message) {
    153  mResult = NS_OK;
    154 
    155  uint16_t port;
    156  nsCString ip;
    157  nsCOMPtr<nsINetAddr> fromAddr;
    158  message->GetFromAddr(getter_AddRefs(fromAddr));
    159  fromAddr->GetPort(&port);
    160  fromAddr->GetAddress(ip);
    161  SUCCEED();
    162 
    163  if (TEST_OUTPUT_STREAM == phase && CheckMessageContent(message, REQUEST)) {
    164    nsCOMPtr<nsIOutputStream> outstream;
    165    message->GetOutputStream(getter_AddRefs(outstream));
    166 
    167    uint32_t count;
    168    const uint32_t data = RESPONSE;
    169    mResult = outstream->Write((const char*)&data, sizeof(uint32_t), &count);
    170 
    171    if (mResult == NS_OK && count == sizeof(uint32_t)) {
    172      SUCCEED();
    173    } else {
    174      ADD_FAILURE();
    175    }
    176    return NS_OK;
    177  }
    178  if (TEST_MULTICAST == phase && CheckMessageContent(message, REQUEST)) {
    179    mResult = NS_OK;
    180  } else if (TEST_SEND_API != phase ||
    181             !CheckMessageContent(message, RESPONSE)) {
    182    mResult = NS_ERROR_FAILURE;
    183  }
    184 
    185  // Notify thread
    186  mWaiter->Notify();
    187  return NS_OK;
    188 }
    189 
    190 NS_IMETHODIMP
    191 UDPServerListener::OnStopListening(nsIUDPSocket*, nsresult) {
    192  mWaiter->Notify();
    193  return NS_OK;
    194 }
    195 
    196 /**
    197 * Multicast timer callback: detects delivery failure
    198 */
    199 class MulticastTimerCallback : public nsITimerCallback, public nsINamed {
    200 protected:
    201  virtual ~MulticastTimerCallback();
    202 
    203 public:
    204  explicit MulticastTimerCallback(WaitForCondition* waiter)
    205      : mResult(NS_ERROR_NOT_INITIALIZED), mWaiter(waiter) {}
    206 
    207  NS_DECL_THREADSAFE_ISUPPORTS
    208  NS_DECL_NSITIMERCALLBACK
    209  NS_DECL_NSINAMED
    210 
    211  nsresult mResult;
    212  RefPtr<WaitForCondition> mWaiter;
    213 };
    214 
    215 NS_IMPL_ISUPPORTS(MulticastTimerCallback, nsITimerCallback, nsINamed)
    216 
    217 MulticastTimerCallback::~MulticastTimerCallback() = default;
    218 
    219 NS_IMETHODIMP
    220 MulticastTimerCallback::Notify(nsITimer* timer) {
    221  if (TEST_MULTICAST != phase) {
    222    return NS_OK;
    223  }
    224  // Multicast ping failed
    225  printf("Multicast ping timeout expired\n");
    226  mResult = NS_ERROR_FAILURE;
    227  mWaiter->Notify();
    228  return NS_OK;
    229 }
    230 
    231 NS_IMETHODIMP
    232 MulticastTimerCallback::GetName(nsACString& aName) {
    233  aName.AssignLiteral("MulticastTimerCallback");
    234  return NS_OK;
    235 }
    236 
    237 /**** Main ****/
    238 
    239 TEST(TestUDPSocket, TestUDPSocketMain)
    240 {
    241  nsresult rv;
    242 
    243  // Create UDPSocket
    244  nsCOMPtr<nsIUDPSocket> server, client;
    245  server = do_CreateInstance("@mozilla.org/network/udp-socket;1", &rv);
    246  ASSERT_NS_SUCCEEDED(rv);
    247 
    248  client = do_CreateInstance("@mozilla.org/network/udp-socket;1", &rv);
    249  ASSERT_NS_SUCCEEDED(rv);
    250 
    251  RefPtr<WaitForCondition> waiter = new WaitForCondition();
    252 
    253  // Create UDPServerListener to process UDP packets
    254  RefPtr<UDPServerListener> serverListener = new UDPServerListener(waiter);
    255 
    256  nsCOMPtr<nsIPrincipal> systemPrincipal = nsContentUtils::GetSystemPrincipal();
    257 
    258  // Bind server socket to 0.0.0.0
    259  rv = server->Init(0, false, systemPrincipal, true, 0);
    260  ASSERT_NS_SUCCEEDED(rv);
    261  int32_t serverPort;
    262  server->GetPort(&serverPort);
    263  server->AsyncListen(serverListener);
    264 
    265  // Bind clinet on arbitrary port
    266  RefPtr<UDPClientListener> clientListener = new UDPClientListener(waiter);
    267  client->Init(0, false, systemPrincipal, true, 0);
    268  client->AsyncListen(clientListener);
    269 
    270  // Write data to server
    271  uint32_t count;
    272  nsTArray<uint8_t> data;
    273  const uint32_t dataBuffer = REQUEST;
    274  data.AppendElements((const uint8_t*)&dataBuffer, sizeof(uint32_t));
    275 
    276  phase = TEST_OUTPUT_STREAM;
    277  rv = client->Send("127.0.0.1"_ns, serverPort, data, &count);
    278  ASSERT_NS_SUCCEEDED(rv);
    279  EXPECT_EQ(count, sizeof(uint32_t));
    280 
    281  // Wait for server
    282  waiter->Wait(1);
    283  ASSERT_NS_SUCCEEDED(serverListener->mResult);
    284 
    285  // Read response from server
    286  ASSERT_NS_SUCCEEDED(clientListener->mResult);
    287 
    288  mozilla::net::NetAddr clientAddr;
    289  rv = client->GetAddress(&clientAddr);
    290  ASSERT_NS_SUCCEEDED(rv);
    291  // The client address is 0.0.0.0, but Windows won't receive packets there, so
    292  // use 127.0.0.1 explicitly
    293  clientAddr.inet.ip = PR_htonl(127 << 24 | 1);
    294 
    295  phase = TEST_SEND_API;
    296  rv = server->SendWithAddress(&clientAddr, data.Elements(), data.Length(),
    297                               &count);
    298  ASSERT_NS_SUCCEEDED(rv);
    299  EXPECT_EQ(count, sizeof(uint32_t));
    300 
    301  // Wait for server
    302  waiter->Wait(1);
    303  ASSERT_NS_SUCCEEDED(serverListener->mResult);
    304 
    305  // Read response from server
    306  ASSERT_NS_SUCCEEDED(clientListener->mResult);
    307 
    308  // Setup timer to detect multicast failure
    309  nsCOMPtr<nsITimer> timer = NS_NewTimer();
    310  ASSERT_TRUE(timer);
    311  RefPtr<MulticastTimerCallback> timerCb = new MulticastTimerCallback(waiter);
    312 
    313  // Join multicast group
    314  printf("Joining multicast group\n");
    315  phase = TEST_MULTICAST;
    316  mozilla::net::NetAddr multicastAddr;
    317  multicastAddr.inet.family = AF_INET;
    318  multicastAddr.inet.ip = PR_htonl(224 << 24 | 255);
    319  multicastAddr.inet.port = PR_htons(serverPort);
    320  rv = server->JoinMulticastAddr(multicastAddr, nullptr);
    321  ASSERT_NS_SUCCEEDED(rv);
    322 
    323  // Send multicast ping
    324  timerCb->mResult = NS_OK;
    325  timer->InitWithCallback(timerCb, MULTICAST_TIMEOUT, nsITimer::TYPE_ONE_SHOT);
    326  rv = client->SendWithAddress(&multicastAddr, data.Elements(), data.Length(),
    327                               &count);
    328  ASSERT_NS_SUCCEEDED(rv);
    329  EXPECT_EQ(count, sizeof(uint32_t));
    330 
    331  // Wait for server to receive successfully
    332  waiter->Wait(1);
    333  ASSERT_NS_SUCCEEDED(serverListener->mResult);
    334  ASSERT_NS_SUCCEEDED(timerCb->mResult);
    335  timer->Cancel();
    336 
    337  // Disable multicast loopback
    338  printf("Disable multicast loopback\n");
    339  client->SetMulticastLoopback(false);
    340  server->SetMulticastLoopback(false);
    341 
    342  // Send multicast ping
    343  timerCb->mResult = NS_OK;
    344  timer->InitWithCallback(timerCb, MULTICAST_TIMEOUT, nsITimer::TYPE_ONE_SHOT);
    345  rv = client->SendWithAddress(&multicastAddr, data.Elements(), data.Length(),
    346                               &count);
    347  ASSERT_NS_SUCCEEDED(rv);
    348  EXPECT_EQ(count, sizeof(uint32_t));
    349 
    350  // Wait for server to fail to receive
    351  waiter->Wait(1);
    352  ASSERT_FALSE(NS_SUCCEEDED(timerCb->mResult));
    353  timer->Cancel();
    354 
    355  // Reset state
    356  client->SetMulticastLoopback(true);
    357  server->SetMulticastLoopback(true);
    358 
    359  // Change multicast interface
    360  mozilla::net::NetAddr loopbackAddr;
    361  loopbackAddr.inet.family = AF_INET;
    362  loopbackAddr.inet.ip = PR_htonl(INADDR_LOOPBACK);
    363  client->SetMulticastInterfaceAddr(loopbackAddr);
    364 
    365  // Send multicast ping
    366  timerCb->mResult = NS_OK;
    367  timer->InitWithCallback(timerCb, MULTICAST_TIMEOUT, nsITimer::TYPE_ONE_SHOT);
    368  rv = client->SendWithAddress(&multicastAddr, data.Elements(), data.Length(),
    369                               &count);
    370  ASSERT_NS_SUCCEEDED(rv);
    371  EXPECT_EQ(count, sizeof(uint32_t));
    372 
    373  // Wait for server to fail to receive
    374  waiter->Wait(1);
    375  ASSERT_FALSE(NS_SUCCEEDED(timerCb->mResult));
    376  timer->Cancel();
    377 
    378  // Reset state
    379  mozilla::net::NetAddr anyAddr;
    380  anyAddr.inet.family = AF_INET;
    381  anyAddr.inet.ip = PR_htonl(INADDR_ANY);
    382  client->SetMulticastInterfaceAddr(anyAddr);
    383 
    384  // Leave multicast group
    385  rv = server->LeaveMulticastAddr(multicastAddr, nullptr);
    386  ASSERT_NS_SUCCEEDED(rv);
    387 
    388  // Send multicast ping
    389  timerCb->mResult = NS_OK;
    390  timer->InitWithCallback(timerCb, MULTICAST_TIMEOUT, nsITimer::TYPE_ONE_SHOT);
    391  rv = client->SendWithAddress(&multicastAddr, data.Elements(), data.Length(),
    392                               &count);
    393  ASSERT_NS_SUCCEEDED(rv);
    394  EXPECT_EQ(count, sizeof(uint32_t));
    395 
    396  // Wait for server to fail to receive
    397  waiter->Wait(1);
    398  ASSERT_FALSE(NS_SUCCEEDED(timerCb->mResult));
    399  timer->Cancel();
    400 
    401  goto close;  // suppress warning about unused label
    402 
    403 close:
    404  // Close server
    405  server->Close();
    406  client->Close();
    407 
    408  // Wait for client and server to see closing
    409  waiter->Wait(2);
    410 }