sctp_unittest.cpp (11476B)
1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 2 /* vim: set ts=2 et sw=2 tw=80: */ 3 /* This Source Code Form is subject to the terms of the Mozilla Public 4 * License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 * You can obtain one at http://mozilla.org/MPL/2.0/. */ 6 7 // Original author: ekr@rtfm.com 8 9 #include <iostream> 10 #include <string> 11 12 #include "nsITimer.h" 13 #include "runnable_utils.h" 14 #include "sigslot.h" 15 #include "transportflow.h" 16 #include "transportlayer.h" 17 #include "transportlayerloopback.h" 18 #include "usrsctp.h" 19 20 #define GTEST_HAS_RTTI 0 21 #include "gtest/gtest.h" 22 #include "gtest_utils.h" 23 24 using namespace mozilla; 25 26 static bool sctp_logging = false; 27 static int port_number = 5000; 28 29 namespace { 30 31 class TransportTestPeer; 32 33 class SendPeriodic : public nsITimerCallback, public nsINamed { 34 public: 35 SendPeriodic(TransportTestPeer* peer, int to_send) 36 : peer_(peer), to_send_(to_send) {} 37 38 NS_DECL_THREADSAFE_ISUPPORTS 39 NS_DECL_NSITIMERCALLBACK 40 NS_DECL_NSINAMED 41 42 protected: 43 virtual ~SendPeriodic() = default; 44 45 TransportTestPeer* peer_; 46 int to_send_; 47 }; 48 49 NS_IMPL_ISUPPORTS(SendPeriodic, nsITimerCallback, nsINamed) 50 51 class TransportTestPeer : public sigslot::has_slots<> { 52 public: 53 TransportTestPeer(std::string name, int local_port, int remote_port, 54 MtransportTestUtils* utils) 55 : name_(name), 56 connected_(false), 57 sent_(0), 58 received_(0), 59 flow_(new TransportFlow()), 60 loopback_(new TransportLayerLoopback()), 61 sctp_(usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, receive_cb, 62 nullptr, 0, nullptr)), 63 timer_(NS_NewTimer()), 64 periodic_(nullptr), 65 test_utils_(utils) { 66 std::cerr << "Creating TransportTestPeer; flow=" 67 << static_cast<void*>(flow_.get()) << " local=" << local_port 68 << " remote=" << remote_port << std::endl; 69 70 usrsctp_register_address(static_cast<void*>(this)); 71 int r = usrsctp_set_non_blocking(sctp_, 1); 72 EXPECT_GE(r, 0); 73 74 struct linger l; 75 l.l_onoff = 1; 76 l.l_linger = 0; 77 r = usrsctp_setsockopt(sctp_, SOL_SOCKET, SO_LINGER, &l, 78 (socklen_t)sizeof(l)); 79 EXPECT_GE(r, 0); 80 81 struct sctp_event subscription; 82 memset(&subscription, 0, sizeof(subscription)); 83 subscription.se_assoc_id = SCTP_ALL_ASSOC; 84 subscription.se_on = 1; 85 subscription.se_type = SCTP_ASSOC_CHANGE; 86 r = usrsctp_setsockopt(sctp_, IPPROTO_SCTP, SCTP_EVENT, &subscription, 87 sizeof(subscription)); 88 EXPECT_GE(r, 0); 89 90 memset(&local_addr_, 0, sizeof(local_addr_)); 91 local_addr_.sconn_family = AF_CONN; 92 #if !defined(__Userspace_os_Linux) && !defined(__Userspace_os_Windows) && \ 93 !defined(__Userspace_os_Android) 94 local_addr_.sconn_len = sizeof(struct sockaddr_conn); 95 #endif 96 local_addr_.sconn_port = htons(local_port); 97 local_addr_.sconn_addr = static_cast<void*>(this); 98 99 memset(&remote_addr_, 0, sizeof(remote_addr_)); 100 remote_addr_.sconn_family = AF_CONN; 101 #if !defined(__Userspace_os_Linux) && !defined(__Userspace_os_Windows) && \ 102 !defined(__Userspace_os_Android) 103 remote_addr_.sconn_len = sizeof(struct sockaddr_conn); 104 #endif 105 remote_addr_.sconn_port = htons(remote_port); 106 remote_addr_.sconn_addr = static_cast<void*>(this); 107 108 nsresult res; 109 res = loopback_->Init(); 110 EXPECT_EQ((nsresult)NS_OK, res); 111 } 112 113 ~TransportTestPeer() { 114 std::cerr << "Destroying sctp connection flow=" 115 << static_cast<void*>(flow_.get()) << std::endl; 116 usrsctp_close(sctp_); 117 usrsctp_deregister_address(static_cast<void*>(this)); 118 119 test_utils_->SyncDispatchToSTS( 120 WrapRunnable(this, &TransportTestPeer::DeleteFlow_s)); 121 std::cerr << "~TransportTestPeer() completed" << std::endl; 122 } 123 124 void ConnectSocket(TransportTestPeer* peer) { 125 test_utils_->SyncDispatchToSTS( 126 WrapRunnable(this, &TransportTestPeer::ConnectSocket_s, peer)); 127 } 128 129 void ConnectSocket_s(TransportTestPeer* peer) { 130 loopback_->Connect(peer->loopback_); 131 ASSERT_EQ((nsresult)NS_OK, loopback_->Init()); 132 flow_->PushLayer(loopback_); 133 134 loopback_->SignalPacketReceived.connect(this, 135 &TransportTestPeer::PacketReceived); 136 137 // SCTP here! 138 ASSERT_TRUE(sctp_); 139 std::cerr << "Calling usrsctp_bind()" << std::endl; 140 int r = 141 usrsctp_bind(sctp_, reinterpret_cast<struct sockaddr*>(&local_addr_), 142 sizeof(local_addr_)); 143 ASSERT_GE(0, r); 144 145 std::cerr << "Calling usrsctp_connect()" << std::endl; 146 r = usrsctp_connect(sctp_, 147 reinterpret_cast<struct sockaddr*>(&remote_addr_), 148 sizeof(remote_addr_)); 149 ASSERT_GE(0, r); 150 } 151 152 void DeleteFlow_s() { 153 if (flow_) { 154 flow_ = nullptr; 155 } 156 } 157 158 void Disconnect_s() { 159 loopback_->Disconnect(); 160 disconnect_all(); 161 } 162 163 void Disconnect() { 164 test_utils_->SyncDispatchToSTS( 165 WrapRunnable(this, &TransportTestPeer::Disconnect_s)); 166 } 167 168 void StartTransfer(size_t to_send) { 169 periodic_ = new SendPeriodic(this, to_send); 170 timer_->SetTarget(test_utils_->sts_target()); 171 timer_->InitWithCallback(periodic_, 10, nsITimer::TYPE_REPEATING_SLACK); 172 } 173 174 void SendOne() { 175 unsigned char buf[100]; 176 memset(buf, sent_ & 0xff, sizeof(buf)); 177 178 struct sctp_sndinfo info; 179 info.snd_sid = 1; 180 info.snd_flags = 0; 181 info.snd_ppid = 50; // What the heck is this? 182 info.snd_context = 0; 183 info.snd_assoc_id = 0; 184 185 int r = usrsctp_sendv(sctp_, buf, sizeof(buf), nullptr, 0, 186 static_cast<void*>(&info), sizeof(info), 187 SCTP_SENDV_SNDINFO, 0); 188 ASSERT_TRUE(r >= 0); 189 ASSERT_EQ(sizeof(buf), (size_t)r); 190 191 ++sent_; 192 } 193 194 int sent() const { return sent_; } 195 int received() const { return received_; } 196 bool connected() const { return connected_; } 197 198 static TransportResult SendPacket_s(UniquePtr<MediaPacket> packet, 199 RefPtr<TransportFlow> flow, 200 TransportLayer* layer) { 201 return layer->SendPacket(*packet); 202 } 203 204 TransportResult SendPacket(const unsigned char* data, size_t len) { 205 UniquePtr<MediaPacket> packet(new MediaPacket); 206 packet->Copy(data, len); 207 208 // Uses DISPATCH_NORMAL to avoid possible deadlocks when we're called 209 // from MainThread especially during shutdown (same as DataChannels). 210 // RUN_ON_THREAD short-circuits if already on the STS thread, which is 211 // normal for most transfers outside of connect() and close(). Passes 212 // a refptr to flow_ to avoid any async deletion issues (since we can't 213 // make 'this' into a refptr as it isn't refcounted) 214 RUN_ON_THREAD(test_utils_->sts_target(), 215 WrapRunnableNM(&TransportTestPeer::SendPacket_s, 216 std::move(packet), flow_, loopback_), 217 NS_DISPATCH_NORMAL); 218 219 return 0; 220 } 221 222 void PacketReceived(TransportLayer* layer, MediaPacket& packet) { 223 std::cerr << "Received " << packet.len() << " bytes" << std::endl; 224 225 // Pass the data to SCTP 226 227 usrsctp_conninput(static_cast<void*>(this), packet.data(), packet.len(), 0); 228 } 229 230 // Process SCTP notification 231 void Notification(union sctp_notification* msg, size_t len) { 232 ASSERT_EQ(msg->sn_header.sn_length, len); 233 234 if (msg->sn_header.sn_type == SCTP_ASSOC_CHANGE) { 235 struct sctp_assoc_change* change = &msg->sn_assoc_change; 236 237 if (change->sac_state == SCTP_COMM_UP) { 238 std::cerr << "Connection up" << std::endl; 239 SetConnected(true); 240 } else { 241 std::cerr << "Connection down" << std::endl; 242 SetConnected(false); 243 } 244 } 245 } 246 247 void SetConnected(bool state) { connected_ = state; } 248 249 static int conn_output(void* addr, void* buffer, size_t length, uint8_t tos, 250 uint8_t set_df) { 251 TransportTestPeer* peer = static_cast<TransportTestPeer*>(addr); 252 253 peer->SendPacket(static_cast<unsigned char*>(buffer), length); 254 255 return 0; 256 } 257 258 static int receive_cb(struct socket* sock, union sctp_sockstore addr, 259 void* data, size_t datalen, struct sctp_rcvinfo rcv, 260 int flags, void* ulp_info) { 261 TransportTestPeer* me = 262 static_cast<TransportTestPeer*>(addr.sconn.sconn_addr); 263 MOZ_ASSERT(me); 264 265 if (flags & MSG_NOTIFICATION) { 266 union sctp_notification* notif = 267 static_cast<union sctp_notification*>(data); 268 269 me->Notification(notif, datalen); 270 return 0; 271 } 272 273 me->received_ += datalen; 274 275 std::cerr << "receive_cb: sock " << sock << " data " << data << "(" 276 << datalen << ") total received bytes = " << me->received_ 277 << std::endl; 278 279 return 0; 280 } 281 282 private: 283 std::string name_; 284 std::atomic<bool> connected_; 285 std::atomic<size_t> sent_; 286 std::atomic<size_t> received_; 287 // Owns the TransportLayerLoopback, but basically does nothing else. 288 RefPtr<TransportFlow> flow_; 289 TransportLayerLoopback* loopback_; 290 291 struct sockaddr_conn local_addr_; 292 struct sockaddr_conn remote_addr_; 293 struct socket* sctp_; 294 nsCOMPtr<nsITimer> timer_; 295 RefPtr<SendPeriodic> periodic_; 296 MtransportTestUtils* test_utils_; 297 }; 298 299 // Implemented here because it calls a method of TransportTestPeer 300 NS_IMETHODIMP SendPeriodic::Notify(nsITimer* timer) { 301 peer_->SendOne(); 302 --to_send_; 303 if (!to_send_) { 304 timer->Cancel(); 305 } 306 return NS_OK; 307 } 308 309 NS_IMETHODIMP 310 SendPeriodic::GetName(nsACString& aName) { 311 aName.AssignLiteral("SendPeriodic"); 312 return NS_OK; 313 } 314 315 class SctpTransportTest : public MtransportTest { 316 public: 317 SctpTransportTest() = default; 318 319 ~SctpTransportTest() = default; 320 321 static void debug_printf(const char* format, ...) { 322 va_list ap; 323 324 va_start(ap, format); 325 vprintf(format, ap); 326 va_end(ap); 327 } 328 329 static void SetUpTestCase() { 330 if (sctp_logging) { 331 usrsctp_init(0, &TransportTestPeer::conn_output, debug_printf); 332 usrsctp_sysctl_set_sctp_debug_on(0xffffffff); 333 } else { 334 usrsctp_init(0, &TransportTestPeer::conn_output, nullptr); 335 } 336 } 337 338 void TearDown() override { 339 if (p1_) p1_->Disconnect(); 340 if (p2_) p2_->Disconnect(); 341 delete p1_; 342 delete p2_; 343 344 MtransportTest::TearDown(); 345 } 346 347 void ConnectSocket(int p1port = 0, int p2port = 0) { 348 if (!p1port) p1port = port_number++; 349 if (!p2port) p2port = port_number++; 350 351 p1_ = new TransportTestPeer("P1", p1port, p2port, test_utils_); 352 p2_ = new TransportTestPeer("P2", p2port, p1port, test_utils_); 353 354 p1_->ConnectSocket(p2_); 355 p2_->ConnectSocket(p1_); 356 ASSERT_TRUE_WAIT(p1_->connected(), 2000); 357 ASSERT_TRUE_WAIT(p2_->connected(), 2000); 358 } 359 360 void TestTransfer(int expected = 1) { 361 std::cerr << "Starting trasnsfer test" << std::endl; 362 p1_->StartTransfer(expected); 363 ASSERT_TRUE_WAIT(p1_->sent() == expected, 10000); 364 ASSERT_TRUE_WAIT(p2_->received() == (expected * 100), 10000); 365 std::cerr << "P2 received " << p2_->received() << std::endl; 366 } 367 368 protected: 369 TransportTestPeer* p1_ = nullptr; 370 TransportTestPeer* p2_ = nullptr; 371 }; 372 373 TEST_F(SctpTransportTest, TestConnect) { ConnectSocket(); } 374 375 TEST_F(SctpTransportTest, TestConnectSymmetricalPorts) { 376 ConnectSocket(5002, 5002); 377 } 378 379 TEST_F(SctpTransportTest, TestTransfer) { 380 ConnectSocket(); 381 TestTransfer(50); 382 } 383 384 } // end namespace