tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

sctp_unittest.cpp (11476B)


      1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim: set ts=2 et sw=2 tw=80: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
      7 // Original author: ekr@rtfm.com
      8 
#include <cstdlib>
#include <iostream>
#include <string>
     11 
     12 #include "nsITimer.h"
     13 #include "runnable_utils.h"
     14 #include "sigslot.h"
     15 #include "transportflow.h"
     16 #include "transportlayer.h"
     17 #include "transportlayerloopback.h"
     18 #include "usrsctp.h"
     19 
     20 #define GTEST_HAS_RTTI 0
     21 #include "gtest/gtest.h"
     22 #include "gtest_utils.h"
     23 
     24 using namespace mozilla;
     25 
// When true, SctpTransportTest::SetUpTestCase() initialises usrsctp with a
// debug printf hook and turns on every SCTP debug flag.
static bool sctp_logging = false;
// Next port handed out by SctpTransportTest::ConnectSocket() when a test does
// not request a specific one; incremented on each use.
static int port_number = 5000;
     28 
     29 namespace {
     30 
     31 class TransportTestPeer;
     32 
     33 class SendPeriodic : public nsITimerCallback, public nsINamed {
     34 public:
     35  SendPeriodic(TransportTestPeer* peer, int to_send)
     36      : peer_(peer), to_send_(to_send) {}
     37 
     38  NS_DECL_THREADSAFE_ISUPPORTS
     39  NS_DECL_NSITIMERCALLBACK
     40  NS_DECL_NSINAMED
     41 
     42 protected:
     43  virtual ~SendPeriodic() = default;
     44 
     45  TransportTestPeer* peer_;
     46  int to_send_;
     47 };
     48 
// Supplies the nsISupports boilerplate declared in the class body, exposing
// SendPeriodic through the nsITimerCallback and nsINamed interfaces.
NS_IMPL_ISUPPORTS(SendPeriodic, nsITimerCallback, nsINamed)
     50 
     51 class TransportTestPeer : public sigslot::has_slots<> {
     52 public:
     53  TransportTestPeer(std::string name, int local_port, int remote_port,
     54                    MtransportTestUtils* utils)
     55      : name_(name),
     56        connected_(false),
     57        sent_(0),
     58        received_(0),
     59        flow_(new TransportFlow()),
     60        loopback_(new TransportLayerLoopback()),
     61        sctp_(usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, receive_cb,
     62                             nullptr, 0, nullptr)),
     63        timer_(NS_NewTimer()),
     64        periodic_(nullptr),
     65        test_utils_(utils) {
     66    std::cerr << "Creating TransportTestPeer; flow="
     67              << static_cast<void*>(flow_.get()) << " local=" << local_port
     68              << " remote=" << remote_port << std::endl;
     69 
     70    usrsctp_register_address(static_cast<void*>(this));
     71    int r = usrsctp_set_non_blocking(sctp_, 1);
     72    EXPECT_GE(r, 0);
     73 
     74    struct linger l;
     75    l.l_onoff = 1;
     76    l.l_linger = 0;
     77    r = usrsctp_setsockopt(sctp_, SOL_SOCKET, SO_LINGER, &l,
     78                           (socklen_t)sizeof(l));
     79    EXPECT_GE(r, 0);
     80 
     81    struct sctp_event subscription;
     82    memset(&subscription, 0, sizeof(subscription));
     83    subscription.se_assoc_id = SCTP_ALL_ASSOC;
     84    subscription.se_on = 1;
     85    subscription.se_type = SCTP_ASSOC_CHANGE;
     86    r = usrsctp_setsockopt(sctp_, IPPROTO_SCTP, SCTP_EVENT, &subscription,
     87                           sizeof(subscription));
     88    EXPECT_GE(r, 0);
     89 
     90    memset(&local_addr_, 0, sizeof(local_addr_));
     91    local_addr_.sconn_family = AF_CONN;
     92 #if !defined(__Userspace_os_Linux) && !defined(__Userspace_os_Windows) && \
     93    !defined(__Userspace_os_Android)
     94    local_addr_.sconn_len = sizeof(struct sockaddr_conn);
     95 #endif
     96    local_addr_.sconn_port = htons(local_port);
     97    local_addr_.sconn_addr = static_cast<void*>(this);
     98 
     99    memset(&remote_addr_, 0, sizeof(remote_addr_));
    100    remote_addr_.sconn_family = AF_CONN;
    101 #if !defined(__Userspace_os_Linux) && !defined(__Userspace_os_Windows) && \
    102    !defined(__Userspace_os_Android)
    103    remote_addr_.sconn_len = sizeof(struct sockaddr_conn);
    104 #endif
    105    remote_addr_.sconn_port = htons(remote_port);
    106    remote_addr_.sconn_addr = static_cast<void*>(this);
    107 
    108    nsresult res;
    109    res = loopback_->Init();
    110    EXPECT_EQ((nsresult)NS_OK, res);
    111  }
    112 
    113  ~TransportTestPeer() {
    114    std::cerr << "Destroying sctp connection flow="
    115              << static_cast<void*>(flow_.get()) << std::endl;
    116    usrsctp_close(sctp_);
    117    usrsctp_deregister_address(static_cast<void*>(this));
    118 
    119    test_utils_->SyncDispatchToSTS(
    120        WrapRunnable(this, &TransportTestPeer::DeleteFlow_s));
    121    std::cerr << "~TransportTestPeer() completed" << std::endl;
    122  }
    123 
    124  void ConnectSocket(TransportTestPeer* peer) {
    125    test_utils_->SyncDispatchToSTS(
    126        WrapRunnable(this, &TransportTestPeer::ConnectSocket_s, peer));
    127  }
    128 
    129  void ConnectSocket_s(TransportTestPeer* peer) {
    130    loopback_->Connect(peer->loopback_);
    131    ASSERT_EQ((nsresult)NS_OK, loopback_->Init());
    132    flow_->PushLayer(loopback_);
    133 
    134    loopback_->SignalPacketReceived.connect(this,
    135                                            &TransportTestPeer::PacketReceived);
    136 
    137    // SCTP here!
    138    ASSERT_TRUE(sctp_);
    139    std::cerr << "Calling usrsctp_bind()" << std::endl;
    140    int r =
    141        usrsctp_bind(sctp_, reinterpret_cast<struct sockaddr*>(&local_addr_),
    142                     sizeof(local_addr_));
    143    ASSERT_GE(0, r);
    144 
    145    std::cerr << "Calling usrsctp_connect()" << std::endl;
    146    r = usrsctp_connect(sctp_,
    147                        reinterpret_cast<struct sockaddr*>(&remote_addr_),
    148                        sizeof(remote_addr_));
    149    ASSERT_GE(0, r);
    150  }
    151 
    152  void DeleteFlow_s() {
    153    if (flow_) {
    154      flow_ = nullptr;
    155    }
    156  }
    157 
    158  void Disconnect_s() {
    159    loopback_->Disconnect();
    160    disconnect_all();
    161  }
    162 
    163  void Disconnect() {
    164    test_utils_->SyncDispatchToSTS(
    165        WrapRunnable(this, &TransportTestPeer::Disconnect_s));
    166  }
    167 
    168  void StartTransfer(size_t to_send) {
    169    periodic_ = new SendPeriodic(this, to_send);
    170    timer_->SetTarget(test_utils_->sts_target());
    171    timer_->InitWithCallback(periodic_, 10, nsITimer::TYPE_REPEATING_SLACK);
    172  }
    173 
    174  void SendOne() {
    175    unsigned char buf[100];
    176    memset(buf, sent_ & 0xff, sizeof(buf));
    177 
    178    struct sctp_sndinfo info;
    179    info.snd_sid = 1;
    180    info.snd_flags = 0;
    181    info.snd_ppid = 50;  // What the heck is this?
    182    info.snd_context = 0;
    183    info.snd_assoc_id = 0;
    184 
    185    int r = usrsctp_sendv(sctp_, buf, sizeof(buf), nullptr, 0,
    186                          static_cast<void*>(&info), sizeof(info),
    187                          SCTP_SENDV_SNDINFO, 0);
    188    ASSERT_TRUE(r >= 0);
    189    ASSERT_EQ(sizeof(buf), (size_t)r);
    190 
    191    ++sent_;
    192  }
    193 
    194  int sent() const { return sent_; }
    195  int received() const { return received_; }
    196  bool connected() const { return connected_; }
    197 
    198  static TransportResult SendPacket_s(UniquePtr<MediaPacket> packet,
    199                                      RefPtr<TransportFlow> flow,
    200                                      TransportLayer* layer) {
    201    return layer->SendPacket(*packet);
    202  }
    203 
    204  TransportResult SendPacket(const unsigned char* data, size_t len) {
    205    UniquePtr<MediaPacket> packet(new MediaPacket);
    206    packet->Copy(data, len);
    207 
    208    // Uses DISPATCH_NORMAL to avoid possible deadlocks when we're called
    209    // from MainThread especially during shutdown (same as DataChannels).
    210    // RUN_ON_THREAD short-circuits if already on the STS thread, which is
    211    // normal for most transfers outside of connect() and close().  Passes
    212    // a refptr to flow_ to avoid any async deletion issues (since we can't
    213    // make 'this' into a refptr as it isn't refcounted)
    214    RUN_ON_THREAD(test_utils_->sts_target(),
    215                  WrapRunnableNM(&TransportTestPeer::SendPacket_s,
    216                                 std::move(packet), flow_, loopback_),
    217                  NS_DISPATCH_NORMAL);
    218 
    219    return 0;
    220  }
    221 
    222  void PacketReceived(TransportLayer* layer, MediaPacket& packet) {
    223    std::cerr << "Received " << packet.len() << " bytes" << std::endl;
    224 
    225    // Pass the data to SCTP
    226 
    227    usrsctp_conninput(static_cast<void*>(this), packet.data(), packet.len(), 0);
    228  }
    229 
    230  // Process SCTP notification
    231  void Notification(union sctp_notification* msg, size_t len) {
    232    ASSERT_EQ(msg->sn_header.sn_length, len);
    233 
    234    if (msg->sn_header.sn_type == SCTP_ASSOC_CHANGE) {
    235      struct sctp_assoc_change* change = &msg->sn_assoc_change;
    236 
    237      if (change->sac_state == SCTP_COMM_UP) {
    238        std::cerr << "Connection up" << std::endl;
    239        SetConnected(true);
    240      } else {
    241        std::cerr << "Connection down" << std::endl;
    242        SetConnected(false);
    243      }
    244    }
    245  }
    246 
    247  void SetConnected(bool state) { connected_ = state; }
    248 
    249  static int conn_output(void* addr, void* buffer, size_t length, uint8_t tos,
    250                         uint8_t set_df) {
    251    TransportTestPeer* peer = static_cast<TransportTestPeer*>(addr);
    252 
    253    peer->SendPacket(static_cast<unsigned char*>(buffer), length);
    254 
    255    return 0;
    256  }
    257 
    258  static int receive_cb(struct socket* sock, union sctp_sockstore addr,
    259                        void* data, size_t datalen, struct sctp_rcvinfo rcv,
    260                        int flags, void* ulp_info) {
    261    TransportTestPeer* me =
    262        static_cast<TransportTestPeer*>(addr.sconn.sconn_addr);
    263    MOZ_ASSERT(me);
    264 
    265    if (flags & MSG_NOTIFICATION) {
    266      union sctp_notification* notif =
    267          static_cast<union sctp_notification*>(data);
    268 
    269      me->Notification(notif, datalen);
    270      return 0;
    271    }
    272 
    273    me->received_ += datalen;
    274 
    275    std::cerr << "receive_cb: sock " << sock << " data " << data << "("
    276              << datalen << ") total received bytes = " << me->received_
    277              << std::endl;
    278 
    279    return 0;
    280  }
    281 
    282 private:
    283  std::string name_;
    284  std::atomic<bool> connected_;
    285  std::atomic<size_t> sent_;
    286  std::atomic<size_t> received_;
    287  // Owns the TransportLayerLoopback, but basically does nothing else.
    288  RefPtr<TransportFlow> flow_;
    289  TransportLayerLoopback* loopback_;
    290 
    291  struct sockaddr_conn local_addr_;
    292  struct sockaddr_conn remote_addr_;
    293  struct socket* sctp_;
    294  nsCOMPtr<nsITimer> timer_;
    295  RefPtr<SendPeriodic> periodic_;
    296  MtransportTestUtils* test_utils_;
    297 };
    298 
    299 // Implemented here because it calls a method of TransportTestPeer
    300 NS_IMETHODIMP SendPeriodic::Notify(nsITimer* timer) {
    301  peer_->SendOne();
    302  --to_send_;
    303  if (!to_send_) {
    304    timer->Cancel();
    305  }
    306  return NS_OK;
    307 }
    308 
    309 NS_IMETHODIMP
    310 SendPeriodic::GetName(nsACString& aName) {
    311  aName.AssignLiteral("SendPeriodic");
    312  return NS_OK;
    313 }
    314 
    315 class SctpTransportTest : public MtransportTest {
    316 public:
    317  SctpTransportTest() = default;
    318 
    319  ~SctpTransportTest() = default;
    320 
    321  static void debug_printf(const char* format, ...) {
    322    va_list ap;
    323 
    324    va_start(ap, format);
    325    vprintf(format, ap);
    326    va_end(ap);
    327  }
    328 
    329  static void SetUpTestCase() {
    330    if (sctp_logging) {
    331      usrsctp_init(0, &TransportTestPeer::conn_output, debug_printf);
    332      usrsctp_sysctl_set_sctp_debug_on(0xffffffff);
    333    } else {
    334      usrsctp_init(0, &TransportTestPeer::conn_output, nullptr);
    335    }
    336  }
    337 
    338  void TearDown() override {
    339    if (p1_) p1_->Disconnect();
    340    if (p2_) p2_->Disconnect();
    341    delete p1_;
    342    delete p2_;
    343 
    344    MtransportTest::TearDown();
    345  }
    346 
    347  void ConnectSocket(int p1port = 0, int p2port = 0) {
    348    if (!p1port) p1port = port_number++;
    349    if (!p2port) p2port = port_number++;
    350 
    351    p1_ = new TransportTestPeer("P1", p1port, p2port, test_utils_);
    352    p2_ = new TransportTestPeer("P2", p2port, p1port, test_utils_);
    353 
    354    p1_->ConnectSocket(p2_);
    355    p2_->ConnectSocket(p1_);
    356    ASSERT_TRUE_WAIT(p1_->connected(), 2000);
    357    ASSERT_TRUE_WAIT(p2_->connected(), 2000);
    358  }
    359 
    360  void TestTransfer(int expected = 1) {
    361    std::cerr << "Starting trasnsfer test" << std::endl;
    362    p1_->StartTransfer(expected);
    363    ASSERT_TRUE_WAIT(p1_->sent() == expected, 10000);
    364    ASSERT_TRUE_WAIT(p2_->received() == (expected * 100), 10000);
    365    std::cerr << "P2 received " << p2_->received() << std::endl;
    366  }
    367 
    368 protected:
    369  TransportTestPeer* p1_ = nullptr;
    370  TransportTestPeer* p2_ = nullptr;
    371 };
    372 
    373 TEST_F(SctpTransportTest, TestConnect) { ConnectSocket(); }
    374 
    375 TEST_F(SctpTransportTest, TestConnectSymmetricalPorts) {
    376  ConnectSocket(5002, 5002);
    377 }
    378 
    379 TEST_F(SctpTransportTest, TestTransfer) {
    380  ConnectSocket();
    381  TestTransfer(50);
    382 }
    383 
    384 }  // end namespace