test_io.h (5436B)
1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 2 /* vim: set ts=2 et sw=2 tw=80: */ 3 /* This Source Code Form is subject to the terms of the Mozilla Public 4 * License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 * You can obtain one at http://mozilla.org/MPL/2.0/. */ 6 7 #ifndef test_io_h_ 8 #define test_io_h_ 9 10 #include <string.h> 11 #include <map> 12 #include <memory> 13 #include <ostream> 14 #include <queue> 15 #include <string> 16 17 #include "databuffer.h" 18 #include "dummy_io.h" 19 #include "prio.h" 20 #include "nss_scoped_ptrs.h" 21 #include "sslt.h" 22 23 namespace nss_test { 24 25 class DataBuffer; 26 class DummyPrSocket; // Fwd decl. 27 28 // Allow us to inspect a packet before it is written. 29 class PacketFilter { 30 public: 31 enum Action { 32 KEEP, // keep the original packet unmodified 33 CHANGE, // change the packet to a different value 34 DROP // drop the packet 35 }; 36 explicit PacketFilter(bool on = true) : enabled_(on) {} 37 virtual ~PacketFilter() {} 38 39 bool enabled() const { return enabled_; } 40 41 virtual Action Process(const DataBuffer& input, DataBuffer* output) { 42 if (!enabled_) { 43 return KEEP; 44 } 45 return Filter(input, output); 46 } 47 void Enable() { enabled_ = true; } 48 void Disable() { enabled_ = false; } 49 50 // The packet filter takes input and has the option of mutating it. 51 // 52 // A filter that modifies the data places the modified data in *output and 53 // returns CHANGE. A filter that does not modify data returns LEAVE, in which 54 // case the value in *output is ignored. A Filter can return DROP, in which 55 // case the packet is dropped (and *output is ignored). 56 virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0; 57 58 private: 59 bool enabled_; 60 }; 61 62 class DummyPrSocket : public DummyIOLayerMethods { 63 public: 64 DummyPrSocket(const std::string& name, SSLProtocolVariant var) 65 : name_(name), 66 variant_(var), 67 peer_(), 68 input_(), 69 filter_(nullptr), 70 write_error_(0) {} 71 virtual ~DummyPrSocket() {} 72 73 static PRDescIdentity LayerId(); 74 75 // Create a file descriptor that will reference this object. The fd must not 76 // live longer than this adapter; call PR_Close() before. 77 ScopedPRFileDesc CreateFD(); 78 79 std::weak_ptr<DummyPrSocket>& peer() { return peer_; } 80 void SetPeer(const std::shared_ptr<DummyPrSocket>& p) { peer_ = p; } 81 void SetPacketFilter(const std::shared_ptr<PacketFilter>& filter) { 82 filter_ = filter; 83 } 84 // Drops peer, packet filter and any outstanding packets. 85 void Reset(); 86 87 void PacketReceived(const DataBuffer& data); 88 int32_t Read(PRFileDesc* f, void* data, int32_t len) override; 89 int32_t Recv(PRFileDesc* f, void* buf, int32_t buflen, int32_t flags, 90 PRIntervalTime to) override; 91 int32_t Write(PRFileDesc* f, const void* buf, int32_t length) override; 92 void SetWriteError(PRErrorCode code) { write_error_ = code; } 93 94 SSLProtocolVariant variant() const { return variant_; } 95 bool readable() const { return !input_.empty(); } 96 97 private: 98 class Packet : public DataBuffer { 99 public: 100 Packet(const DataBuffer& buf) : DataBuffer(buf), offset_(0) {} 101 102 void Advance(size_t delta) { 103 PR_ASSERT(offset_ + delta <= len()); 104 offset_ = std::min(len(), offset_ + delta); 105 } 106 107 size_t offset() const { return offset_; } 108 size_t remaining() const { return len() - offset_; } 109 110 private: 111 size_t offset_; 112 }; 113 114 const std::string name_; 115 SSLProtocolVariant variant_; 116 std::weak_ptr<DummyPrSocket> peer_; 117 std::queue<Packet> input_; 118 std::shared_ptr<PacketFilter> filter_; 119 PRErrorCode write_error_; 120 }; 121 122 // Marker interface. 123 class PollTarget {}; 124 125 enum Event { READABLE_EVENT, TIMER_EVENT /* Must be last */ }; 126 127 typedef void (*PollCallback)(PollTarget*, Event); 128 129 class Poller { 130 public: 131 static Poller* Instance(); // Get a singleton. 132 static void Shutdown(); // Shut it down. 133 134 class Timer { 135 public: 136 Timer(PRTime deadline, PollTarget* target, PollCallback callback) 137 : deadline_(deadline), target_(target), callback_(callback) {} 138 void Cancel() { callback_ = nullptr; } 139 140 PRTime deadline_; 141 PollTarget* target_; 142 PollCallback callback_; 143 }; 144 145 void Wait(Event event, std::shared_ptr<DummyPrSocket>& adapter, 146 PollTarget* target, PollCallback cb); 147 void Cancel(Event event, std::shared_ptr<DummyPrSocket>& adapter); 148 void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb, 149 std::shared_ptr<Timer>* handle); 150 bool Poll(); 151 152 private: 153 Poller() : waiters_(), timers_() {} 154 ~Poller() {} 155 156 class Waiter { 157 public: 158 Waiter(std::shared_ptr<DummyPrSocket> io) : io_(io) { 159 memset(&targets_[0], 0, sizeof(targets_)); 160 memset(&callbacks_[0], 0, sizeof(callbacks_)); 161 } 162 163 void WaitFor(Event event, PollCallback callback); 164 165 std::shared_ptr<DummyPrSocket> io_; 166 PollTarget* targets_[TIMER_EVENT]; 167 PollCallback callbacks_[TIMER_EVENT]; 168 }; 169 170 class TimerComparator { 171 public: 172 bool operator()(const std::shared_ptr<Timer> lhs, 173 const std::shared_ptr<Timer> rhs) { 174 return lhs->deadline_ > rhs->deadline_; 175 } 176 }; 177 178 static Poller* instance; 179 std::map<std::shared_ptr<DummyPrSocket>, std::unique_ptr<Waiter>> waiters_; 180 std::priority_queue<std::shared_ptr<Timer>, 181 std::vector<std::shared_ptr<Timer>>, TimerComparator> 182 timers_; 183 }; 184 185 } // namespace nss_test 186 187 #endif