test_io.cc (6953B)
1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 2 /* vim: set ts=2 et sw=2 tw=80: */ 3 /* This Source Code Form is subject to the terms of the Mozilla Public 4 * License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 * You can obtain one at http://mozilla.org/MPL/2.0/. */ 6 7 #include "test_io.h" 8 9 #include <algorithm> 10 #include <cassert> 11 #include <iostream> 12 #include <memory> 13 14 #include "prerror.h" 15 #include "prlog.h" 16 #include "prthread.h" 17 18 extern bool g_ssl_gtest_verbose; 19 20 namespace nss_test { 21 22 #define LOG(a) std::cerr << name_ << ": " << a << std::endl 23 #define LOGV(a) \ 24 do { \ 25 if (g_ssl_gtest_verbose) LOG(a); \ 26 } while (false) 27 28 PRDescIdentity DummyPrSocket::LayerId() { 29 static PRDescIdentity id = PR_GetUniqueIdentity("dummysocket"); 30 return id; 31 } 32 33 ScopedPRFileDesc DummyPrSocket::CreateFD() { 34 return DummyIOLayerMethods::CreateFD(DummyPrSocket::LayerId(), this); 35 } 36 37 void DummyPrSocket::Reset() { 38 auto p = peer_.lock(); 39 peer_.reset(); 40 if (p) { 41 p->peer_.reset(); 42 p->Reset(); 43 } 44 while (!input_.empty()) { 45 input_.pop(); 46 } 47 filter_ = nullptr; 48 write_error_ = 0; 49 } 50 51 void DummyPrSocket::PacketReceived(const DataBuffer &packet) { 52 input_.push(Packet(packet)); 53 } 54 55 int32_t DummyPrSocket::Read(PRFileDesc *f, void *data, int32_t len) { 56 PR_ASSERT(variant_ == ssl_variant_stream); 57 if (variant_ != ssl_variant_stream) { 58 PR_SetError(PR_INVALID_METHOD_ERROR, 0); 59 return -1; 60 } 61 62 auto dst = peer_.lock(); 63 if (!dst) { 64 PR_SetError(PR_NOT_CONNECTED_ERROR, 0); 65 return -1; 66 } 67 68 if (input_.empty()) { 69 LOGV("Read --> wouldblock " << len); 70 PR_SetError(PR_WOULD_BLOCK_ERROR, 0); 71 return -1; 72 } 73 74 auto &front = input_.front(); 75 size_t to_read = 76 std::min(static_cast<size_t>(len), front.len() - front.offset()); 77 memcpy(data, static_cast<const void *>(front.data() + front.offset()), 78 to_read); 79 front.Advance(to_read); 80 81 if (!front.remaining()) { 82 input_.pop(); 83 } 84 85 return static_cast<int32_t>(to_read); 86 } 87 88 int32_t DummyPrSocket::Recv(PRFileDesc *f, void *buf, int32_t buflen, 89 int32_t flags, PRIntervalTime to) { 90 PR_ASSERT(flags == 0); 91 if (flags != 0) { 92 PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0); 93 return -1; 94 } 95 96 if (variant() != ssl_variant_datagram) { 97 return Read(f, buf, buflen); 98 } 99 100 auto dst = peer_.lock(); 101 if (!dst) { 102 PR_SetError(PR_NOT_CONNECTED_ERROR, 0); 103 return -1; 104 } 105 106 if (input_.empty()) { 107 PR_SetError(PR_WOULD_BLOCK_ERROR, 0); 108 return -1; 109 } 110 111 auto &front = input_.front(); 112 if (static_cast<size_t>(buflen) < front.len()) { 113 PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0); 114 return -1; 115 } 116 117 size_t count = front.len(); 118 memcpy(buf, front.data(), count); 119 120 input_.pop(); 121 return static_cast<int32_t>(count); 122 } 123 124 int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) { 125 if (write_error_) { 126 PR_SetError(write_error_, 0); 127 return -1; 128 } 129 130 auto dst = peer_.lock(); 131 if (!dst) { 132 PR_SetError(PR_NOT_CONNECTED_ERROR, 0); 133 return -1; 134 } 135 136 DataBuffer packet(static_cast<const uint8_t *>(buf), 137 static_cast<size_t>(length)); 138 DataBuffer filtered; 139 PacketFilter::Action action = PacketFilter::KEEP; 140 if (filter_) { 141 LOGV("Original packet: " << packet); 142 action = filter_->Process(packet, &filtered); 143 } 144 switch (action) { 145 case PacketFilter::CHANGE: 146 LOG("Filtered packet: " << filtered); 147 dst->PacketReceived(filtered); 148 break; 149 case PacketFilter::DROP: 150 LOG("Drop packet"); 151 break; 152 case PacketFilter::KEEP: 153 dst->PacketReceived(packet); 154 break; 155 } 156 // libssl can't handle it if this reports something other than the length 157 // of what was passed in (or less, but we're not doing partial writes). 158 return static_cast<int32_t>(packet.len()); 159 } 160 161 Poller *Poller::instance; 162 163 Poller *Poller::Instance() { 164 if (!instance) instance = new Poller(); 165 166 return instance; 167 } 168 169 void Poller::Shutdown() { 170 delete instance; 171 instance = nullptr; 172 } 173 174 void Poller::Wait(Event event, std::shared_ptr<DummyPrSocket> &adapter, 175 PollTarget *target, PollCallback cb) { 176 assert(event < TIMER_EVENT); 177 if (event >= TIMER_EVENT) return; 178 179 std::unique_ptr<Waiter> waiter; 180 auto it = waiters_.find(adapter); 181 if (it == waiters_.end()) { 182 waiter.reset(new Waiter(adapter)); 183 } else { 184 waiter = std::move(it->second); 185 } 186 187 waiter->targets_[event] = target; 188 waiter->callbacks_[event] = cb; 189 waiters_[adapter] = std::move(waiter); 190 } 191 192 void Poller::Cancel(Event event, std::shared_ptr<DummyPrSocket> &adapter) { 193 auto it = waiters_.find(adapter); 194 if (it == waiters_.end()) { 195 return; 196 } 197 198 auto &waiter = it->second; 199 waiter->targets_[event] = nullptr; 200 waiter->callbacks_[event] = nullptr; 201 202 // Clean up if there are no callbacks. 203 for (size_t i = 0; i < TIMER_EVENT; ++i) { 204 if (waiter->callbacks_[i]) return; 205 } 206 207 waiters_.erase(adapter); 208 } 209 210 void Poller::SetTimer(uint32_t timer_ms, PollTarget *target, PollCallback cb, 211 std::shared_ptr<Timer> *timer) { 212 auto t = std::make_shared<Timer>(PR_Now() + timer_ms * 1000, target, cb); 213 timers_.push(t); 214 if (timer) *timer = t; 215 } 216 217 bool Poller::Poll() { 218 if (g_ssl_gtest_verbose) { 219 std::cerr << "Poll() waiters = " << waiters_.size() 220 << " timers = " << timers_.size() << std::endl; 221 } 222 PRIntervalTime timeout = PR_INTERVAL_NO_TIMEOUT; 223 PRTime now = PR_Now(); 224 bool fired = false; 225 226 // Figure out the timer for the select. 227 if (!timers_.empty()) { 228 auto first_timer = timers_.top(); 229 if (now >= first_timer->deadline_) { 230 // Timer expired. 231 timeout = PR_INTERVAL_NO_WAIT; 232 } else { 233 timeout = 234 PR_MillisecondsToInterval((first_timer->deadline_ - now) / 1000); 235 } 236 } 237 238 for (auto it = waiters_.begin(); it != waiters_.end(); ++it) { 239 auto &waiter = it->second; 240 241 if (waiter->callbacks_[READABLE_EVENT]) { 242 if (waiter->io_->readable()) { 243 PollCallback callback = waiter->callbacks_[READABLE_EVENT]; 244 PollTarget *target = waiter->targets_[READABLE_EVENT]; 245 waiter->callbacks_[READABLE_EVENT] = nullptr; 246 waiter->targets_[READABLE_EVENT] = nullptr; 247 callback(target, READABLE_EVENT); 248 fired = true; 249 } 250 } 251 } 252 253 if (fired) timeout = PR_INTERVAL_NO_WAIT; 254 255 // Can't wait forever and also have nothing readable now. 256 if (timeout == PR_INTERVAL_NO_TIMEOUT) return false; 257 258 // Sleep. 259 if (timeout != PR_INTERVAL_NO_WAIT) { 260 PR_Sleep(timeout); 261 } 262 263 // Now process anything that timed out. 264 now = PR_Now(); 265 while (!timers_.empty()) { 266 if (now < timers_.top()->deadline_) break; 267 268 auto timer = timers_.top(); 269 timers_.pop(); 270 if (timer->callback_) { 271 timer->callback_(timer->target_, TIMER_EVENT); 272 } 273 } 274 275 return true; 276 } 277 278 } // namespace nss_test