tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

test_io.cc (6953B)


      1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim: set ts=2 et sw=2 tw=80: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
      5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
#include "test_io.h"

#include <algorithm>
#include <cassert>
#include <iostream>
#include <memory>
#include <vector>

#include "prerror.h"
#include "prlog.h"
#include "prthread.h"
     17 
// Verbosity switch defined in another translation unit. In this file it gates
// LOGV output and the per-call tracing at the top of Poller::Poll().
extern bool g_ssl_gtest_verbose;
     19 
     20 namespace nss_test {
     21 
     22 #define LOG(a) std::cerr << name_ << ": " << a << std::endl
     23 #define LOGV(a)                      \
     24  do {                               \
     25    if (g_ssl_gtest_verbose) LOG(a); \
     26  } while (false)
     27 
     28 PRDescIdentity DummyPrSocket::LayerId() {
     29  static PRDescIdentity id = PR_GetUniqueIdentity("dummysocket");
     30  return id;
     31 }
     32 
     33 ScopedPRFileDesc DummyPrSocket::CreateFD() {
     34  return DummyIOLayerMethods::CreateFD(DummyPrSocket::LayerId(), this);
     35 }
     36 
     37 void DummyPrSocket::Reset() {
     38  auto p = peer_.lock();
     39  peer_.reset();
     40  if (p) {
     41    p->peer_.reset();
     42    p->Reset();
     43  }
     44  while (!input_.empty()) {
     45    input_.pop();
     46  }
     47  filter_ = nullptr;
     48  write_error_ = 0;
     49 }
     50 
     51 void DummyPrSocket::PacketReceived(const DataBuffer &packet) {
     52  input_.push(Packet(packet));
     53 }
     54 
     55 int32_t DummyPrSocket::Read(PRFileDesc *f, void *data, int32_t len) {
     56  PR_ASSERT(variant_ == ssl_variant_stream);
     57  if (variant_ != ssl_variant_stream) {
     58    PR_SetError(PR_INVALID_METHOD_ERROR, 0);
     59    return -1;
     60  }
     61 
     62  auto dst = peer_.lock();
     63  if (!dst) {
     64    PR_SetError(PR_NOT_CONNECTED_ERROR, 0);
     65    return -1;
     66  }
     67 
     68  if (input_.empty()) {
     69    LOGV("Read --> wouldblock " << len);
     70    PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
     71    return -1;
     72  }
     73 
     74  auto &front = input_.front();
     75  size_t to_read =
     76      std::min(static_cast<size_t>(len), front.len() - front.offset());
     77  memcpy(data, static_cast<const void *>(front.data() + front.offset()),
     78         to_read);
     79  front.Advance(to_read);
     80 
     81  if (!front.remaining()) {
     82    input_.pop();
     83  }
     84 
     85  return static_cast<int32_t>(to_read);
     86 }
     87 
     88 int32_t DummyPrSocket::Recv(PRFileDesc *f, void *buf, int32_t buflen,
     89                            int32_t flags, PRIntervalTime to) {
     90  PR_ASSERT(flags == 0);
     91  if (flags != 0) {
     92    PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
     93    return -1;
     94  }
     95 
     96  if (variant() != ssl_variant_datagram) {
     97    return Read(f, buf, buflen);
     98  }
     99 
    100  auto dst = peer_.lock();
    101  if (!dst) {
    102    PR_SetError(PR_NOT_CONNECTED_ERROR, 0);
    103    return -1;
    104  }
    105 
    106  if (input_.empty()) {
    107    PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
    108    return -1;
    109  }
    110 
    111  auto &front = input_.front();
    112  if (static_cast<size_t>(buflen) < front.len()) {
    113    PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0);
    114    return -1;
    115  }
    116 
    117  size_t count = front.len();
    118  memcpy(buf, front.data(), count);
    119 
    120  input_.pop();
    121  return static_cast<int32_t>(count);
    122 }
    123 
    124 int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) {
    125  if (write_error_) {
    126    PR_SetError(write_error_, 0);
    127    return -1;
    128  }
    129 
    130  auto dst = peer_.lock();
    131  if (!dst) {
    132    PR_SetError(PR_NOT_CONNECTED_ERROR, 0);
    133    return -1;
    134  }
    135 
    136  DataBuffer packet(static_cast<const uint8_t *>(buf),
    137                    static_cast<size_t>(length));
    138  DataBuffer filtered;
    139  PacketFilter::Action action = PacketFilter::KEEP;
    140  if (filter_) {
    141    LOGV("Original packet: " << packet);
    142    action = filter_->Process(packet, &filtered);
    143  }
    144  switch (action) {
    145    case PacketFilter::CHANGE:
    146      LOG("Filtered packet: " << filtered);
    147      dst->PacketReceived(filtered);
    148      break;
    149    case PacketFilter::DROP:
    150      LOG("Drop packet");
    151      break;
    152    case PacketFilter::KEEP:
    153      dst->PacketReceived(packet);
    154      break;
    155  }
    156  // libssl can't handle it if this reports something other than the length
    157  // of what was passed in (or less, but we're not doing partial writes).
    158  return static_cast<int32_t>(packet.len());
    159 }
    160 
    161 Poller *Poller::instance;
    162 
    163 Poller *Poller::Instance() {
    164  if (!instance) instance = new Poller();
    165 
    166  return instance;
    167 }
    168 
    169 void Poller::Shutdown() {
    170  delete instance;
    171  instance = nullptr;
    172 }
    173 
    174 void Poller::Wait(Event event, std::shared_ptr<DummyPrSocket> &adapter,
    175                  PollTarget *target, PollCallback cb) {
    176  assert(event < TIMER_EVENT);
    177  if (event >= TIMER_EVENT) return;
    178 
    179  std::unique_ptr<Waiter> waiter;
    180  auto it = waiters_.find(adapter);
    181  if (it == waiters_.end()) {
    182    waiter.reset(new Waiter(adapter));
    183  } else {
    184    waiter = std::move(it->second);
    185  }
    186 
    187  waiter->targets_[event] = target;
    188  waiter->callbacks_[event] = cb;
    189  waiters_[adapter] = std::move(waiter);
    190 }
    191 
    192 void Poller::Cancel(Event event, std::shared_ptr<DummyPrSocket> &adapter) {
    193  auto it = waiters_.find(adapter);
    194  if (it == waiters_.end()) {
    195    return;
    196  }
    197 
    198  auto &waiter = it->second;
    199  waiter->targets_[event] = nullptr;
    200  waiter->callbacks_[event] = nullptr;
    201 
    202  // Clean up if there are no callbacks.
    203  for (size_t i = 0; i < TIMER_EVENT; ++i) {
    204    if (waiter->callbacks_[i]) return;
    205  }
    206 
    207  waiters_.erase(adapter);
    208 }
    209 
    210 void Poller::SetTimer(uint32_t timer_ms, PollTarget *target, PollCallback cb,
    211                      std::shared_ptr<Timer> *timer) {
    212  auto t = std::make_shared<Timer>(PR_Now() + timer_ms * 1000, target, cb);
    213  timers_.push(t);
    214  if (timer) *timer = t;
    215 }
    216 
    217 bool Poller::Poll() {
    218  if (g_ssl_gtest_verbose) {
    219    std::cerr << "Poll() waiters = " << waiters_.size()
    220              << " timers = " << timers_.size() << std::endl;
    221  }
    222  PRIntervalTime timeout = PR_INTERVAL_NO_TIMEOUT;
    223  PRTime now = PR_Now();
    224  bool fired = false;
    225 
    226  // Figure out the timer for the select.
    227  if (!timers_.empty()) {
    228    auto first_timer = timers_.top();
    229    if (now >= first_timer->deadline_) {
    230      // Timer expired.
    231      timeout = PR_INTERVAL_NO_WAIT;
    232    } else {
    233      timeout =
    234          PR_MillisecondsToInterval((first_timer->deadline_ - now) / 1000);
    235    }
    236  }
    237 
    238  for (auto it = waiters_.begin(); it != waiters_.end(); ++it) {
    239    auto &waiter = it->second;
    240 
    241    if (waiter->callbacks_[READABLE_EVENT]) {
    242      if (waiter->io_->readable()) {
    243        PollCallback callback = waiter->callbacks_[READABLE_EVENT];
    244        PollTarget *target = waiter->targets_[READABLE_EVENT];
    245        waiter->callbacks_[READABLE_EVENT] = nullptr;
    246        waiter->targets_[READABLE_EVENT] = nullptr;
    247        callback(target, READABLE_EVENT);
    248        fired = true;
    249      }
    250    }
    251  }
    252 
    253  if (fired) timeout = PR_INTERVAL_NO_WAIT;
    254 
    255  // Can't wait forever and also have nothing readable now.
    256  if (timeout == PR_INTERVAL_NO_TIMEOUT) return false;
    257 
    258  // Sleep.
    259  if (timeout != PR_INTERVAL_NO_WAIT) {
    260    PR_Sleep(timeout);
    261  }
    262 
    263  // Now process anything that timed out.
    264  now = PR_Now();
    265  while (!timers_.empty()) {
    266    if (now < timers_.top()->deadline_) break;
    267 
    268    auto timer = timers_.top();
    269    timers_.pop();
    270    if (timer->callback_) {
    271      timer->callback_(timer->target_, TIMER_EVENT);
    272    }
    273  }
    274 
    275  return true;
    276 }
    277 
    278 }  // namespace nss_test