tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

cdm-test-decryptor.cpp (14858B)


      1 /* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* This Source Code Form is subject to the terms of the Mozilla Public
      3 * License, v. 2.0. If a copy of the MPL was not distributed with this
      4 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
      5 
      6 #include "cdm-test-decryptor.h"
      7 
#include <istream>
#include <iterator>
#include <mutex>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
     15 
     16 #include "cdm-test-output-protection.h"
     17 #include "cdm-test-storage.h"
     18 #include "mozilla/Assertions.h"
     19 
     20 FakeDecryptor* FakeDecryptor::sInstance = nullptr;
     21 
     22 class TestManager {
     23 public:
     24  TestManager() = default;
     25 
     26  // Register a test with the test manager.
     27  void BeginTest(const std::string& aTestID) {
     28    std::lock_guard<std::mutex> lock(mMutex);
     29    auto found = mTestIDs.find(aTestID);
     30    if (found == mTestIDs.end()) {
     31      mTestIDs.insert(aTestID);
     32    } else {
     33      Error("FAIL BeginTest test already existed: " + aTestID);
     34    }
     35  }
     36 
     37  // Notify the test manager that the test is finished. If all tests are done,
     38  // test manager will send "test-storage complete" to notify the parent that
     39  // all tests are finished and also delete itself.
     40  void EndTest(const std::string& aTestID) {
     41    bool isEmpty = false;
     42    {
     43      std::lock_guard<std::mutex> lock(mMutex);
     44      auto found = mTestIDs.find(aTestID);
     45      if (found != mTestIDs.end()) {
     46        mTestIDs.erase(aTestID);
     47        isEmpty = mTestIDs.empty();
     48      } else {
     49        Error("FAIL EndTest test not existed: " + aTestID);
     50        return;
     51      }
     52    }
     53    if (isEmpty) {
     54      Finish();
     55      delete this;
     56    }
     57  }
     58 
     59 private:
     60  ~TestManager() = default;
     61 
     62  static void Error(const std::string& msg) { FakeDecryptor::Message(msg); }
     63 
     64  static void Finish() { FakeDecryptor::Message("test-storage complete"); }
     65 
     66  std::mutex mMutex;
     67  std::set<std::string> mTestIDs;
     68 };
     69 
     70 FakeDecryptor::FakeDecryptor(cdm::Host_11* aHost) : mHost(aHost) {
     71  MOZ_ASSERT(!sInstance);
     72  sInstance = this;
     73 }
     74 
     75 void FakeDecryptor::Message(const std::string& aMessage) {
     76  MOZ_ASSERT(sInstance);
     77  const static std::string sid("fake-session-id");
     78  sInstance->mHost->OnSessionMessage(sid.c_str(), sid.size(),
     79                                     cdm::MessageType::kLicenseRequest,
     80                                     aMessage.c_str(), aMessage.size());
     81 }
     82 
     83 std::vector<std::string> Tokenize(const std::string& aString) {
     84  std::stringstream strstr(aString);
     85  std::istream_iterator<std::string> it(strstr), end;
     86  return std::vector<std::string>(it, end);
     87 }
     88 
     89 static const char TruncateRecordId[] = "truncate-record-id";
     90 static const char TruncateRecordData[] = "I will soon be truncated";
     91 
     92 template <class Continuation>
     93 class WriteRecordSuccessTask {
     94 public:
     95  WriteRecordSuccessTask(std::string aId, Continuation aThen)
     96      : mId(aId), mThen(std::move(aThen)) {}
     97 
     98  void operator()() { ReadRecord(FakeDecryptor::sInstance->mHost, mId, mThen); }
     99 
    100  std::string mId;
    101  Continuation mThen;
    102 };
    103 
    104 class WriteRecordFailureTask {
    105 public:
    106  explicit WriteRecordFailureTask(const std::string& aMessage,
    107                                  TestManager* aTestManager = nullptr,
    108                                  const std::string& aTestID = "")
    109      : mMessage(aMessage), mTestmanager(aTestManager), mTestID(aTestID) {}
    110 
    111  void operator()() {
    112    FakeDecryptor::Message(mMessage);
    113    if (mTestmanager) {
    114      mTestmanager->EndTest(mTestID);
    115    }
    116  }
    117 
    118 private:
    119  std::string mMessage;
    120  TestManager* const mTestmanager;
    121  const std::string mTestID;
    122 };
    123 
    124 class TestEmptyContinuation : public ReadContinuation {
    125 public:
    126  TestEmptyContinuation(TestManager* aTestManager, const std::string& aTestID)
    127      : mTestmanager(aTestManager), mTestID(aTestID) {}
    128 
    129  virtual void operator()(bool aSuccess, const uint8_t* aData,
    130                          uint32_t aDataSize) override {
    131    if (aDataSize) {
    132      FakeDecryptor::Message(
    133          "FAIL TestEmptyContinuation record was not truncated");
    134    }
    135    mTestmanager->EndTest(mTestID);
    136  }
    137 
    138 private:
    139  TestManager* const mTestmanager;
    140  const std::string mTestID;
    141 };
    142 
    143 class TruncateContinuation : public ReadContinuation {
    144 public:
    145  TruncateContinuation(const std::string& aID, TestManager* aTestManager,
    146                       const std::string& aTestID)
    147      : mID(aID), mTestmanager(aTestManager), mTestID(aTestID) {}
    148 
    149  virtual void operator()(bool aSuccess, const uint8_t* aData,
    150                          uint32_t aDataSize) override {
    151    if (std::string(reinterpret_cast<const char*>(aData), aDataSize) !=
    152        TruncateRecordData) {
    153      FakeDecryptor::Message(
    154          "FAIL TruncateContinuation read data doesn't match written data");
    155    }
    156    auto cont = TestEmptyContinuation(mTestmanager, mTestID);
    157    auto msg = "FAIL in TruncateContinuation write.";
    158    WriteRecord(FakeDecryptor::sInstance->mHost, mID, nullptr, 0,
    159                WriteRecordSuccessTask<TestEmptyContinuation>(mID, cont),
    160                WriteRecordFailureTask(msg, mTestmanager, mTestID));
    161  }
    162 
    163 private:
    164  const std::string mID;
    165  TestManager* const mTestmanager;
    166  const std::string mTestID;
    167 };
    168 
    169 class VerifyAndFinishContinuation : public ReadContinuation {
    170 public:
    171  explicit VerifyAndFinishContinuation(std::string aValue,
    172                                       TestManager* aTestManager,
    173                                       const std::string& aTestID)
    174      : mValue(aValue), mTestmanager(aTestManager), mTestID(aTestID) {}
    175 
    176  virtual void operator()(bool aSuccess, const uint8_t* aData,
    177                          uint32_t aDataSize) override {
    178    if (std::string(reinterpret_cast<const char*>(aData), aDataSize) !=
    179        mValue) {
    180      FakeDecryptor::Message(
    181          "FAIL VerifyAndFinishContinuation read data doesn't match expected "
    182          "data");
    183    }
    184    mTestmanager->EndTest(mTestID);
    185  }
    186 
    187 private:
    188  std::string mValue;
    189  TestManager* const mTestmanager;
    190  const std::string mTestID;
    191 };
    192 
    193 class VerifyAndOverwriteContinuation : public ReadContinuation {
    194 public:
    195  VerifyAndOverwriteContinuation(std::string aId, std::string aValue,
    196                                 std::string aOverwrite,
    197                                 TestManager* aTestManager,
    198                                 const std::string& aTestID)
    199      : mId(aId),
    200        mValue(aValue),
    201        mOverwrite(aOverwrite),
    202        mTestmanager(aTestManager),
    203        mTestID(aTestID) {}
    204 
    205  virtual void operator()(bool aSuccess, const uint8_t* aData,
    206                          uint32_t aDataSize) override {
    207    if (std::string(reinterpret_cast<const char*>(aData), aDataSize) !=
    208        mValue) {
    209      FakeDecryptor::Message(
    210          "FAIL VerifyAndOverwriteContinuation read data doesn't match "
    211          "expected data");
    212    }
    213    auto cont = VerifyAndFinishContinuation(mOverwrite, mTestmanager, mTestID);
    214    auto msg = "FAIL in VerifyAndOverwriteContinuation write.";
    215    WriteRecord(FakeDecryptor::sInstance->mHost, mId, mOverwrite,
    216                WriteRecordSuccessTask<VerifyAndFinishContinuation>(mId, cont),
    217                WriteRecordFailureTask(msg, mTestmanager, mTestID));
    218  }
    219 
    220 private:
    221  std::string mId;
    222  std::string mValue;
    223  std::string mOverwrite;
    224  TestManager* const mTestmanager;
    225  const std::string mTestID;
    226 };
    227 
    228 static const char OpenAgainRecordId[] = "open-again-record-id";
    229 
    230 class OpenedSecondTimeContinuation : public OpenContinuation {
    231 public:
    232  explicit OpenedSecondTimeContinuation(TestManager* aTestManager,
    233                                        const std::string& aTestID)
    234      : mTestmanager(aTestManager), mTestID(aTestID) {}
    235 
    236  void operator()(bool aSuccess) override {
    237    if (!aSuccess) {
    238      FakeDecryptor::Message(
    239          "FAIL OpenSecondTimeContinuation should not be able to re-open "
    240          "record.");
    241    }
    242    // Succeeded, open should have failed.
    243    mTestmanager->EndTest(mTestID);
    244  }
    245 
    246 private:
    247  TestManager* const mTestmanager;
    248  const std::string mTestID;
    249 };
    250 
    251 class OpenedFirstTimeContinuation : public OpenContinuation {
    252 public:
    253  OpenedFirstTimeContinuation(const std::string& aID, TestManager* aTestManager,
    254                              const std::string& aTestID)
    255      : mID(aID), mTestmanager(aTestManager), mTestID(aTestID) {}
    256 
    257  void operator()(bool aSuccess) override {
    258    if (!aSuccess) {
    259      FakeDecryptor::Message(
    260          "FAIL OpenAgainContinuation to open record initially.");
    261      mTestmanager->EndTest(mTestID);
    262      return;
    263    }
    264 
    265    auto cont = OpenedSecondTimeContinuation(mTestmanager, mTestID);
    266    OpenRecord(FakeDecryptor::sInstance->mHost, mID, cont);
    267  }
    268 
    269 private:
    270  const std::string mID;
    271  TestManager* const mTestmanager;
    272  const std::string mTestID;
    273 };
    274 
    275 static void DoTestStorage(const std::string& aPrefix,
    276                          TestManager* aTestManager) {
    277  MOZ_ASSERT(FakeDecryptor::sInstance->mHost,
    278             "FakeDecryptor::sInstance->mHost should not be null");
    279  // Basic I/O tests. We run three cases concurrently. The tests, like
    280  // CDMStorage run asynchronously. When they've all passed, we send
    281  // a message back to the parent process, or a failure message if not.
    282 
    283  // Test 1: Basic I/O test, and test that writing 0 bytes in a record
    284  // deletes record.
    285  //
    286  // Write data to truncate record, then
    287  // read data, verify that we read what we wrote, then
    288  // write 0 bytes to truncate record, then
    289  // read data, verify that 0 bytes was read
    290  const std::string id1 = aPrefix + TruncateRecordId;
    291  const std::string testID1 = aPrefix + "write-test-1";
    292  aTestManager->BeginTest(testID1);
    293  auto cont1 = TruncateContinuation(id1, aTestManager, testID1);
    294  auto msg1 = "FAIL in TestStorage writing TruncateRecord.";
    295  WriteRecord(FakeDecryptor::sInstance->mHost, id1, TruncateRecordData,
    296              WriteRecordSuccessTask<TruncateContinuation>(id1, cont1),
    297              WriteRecordFailureTask(msg1, aTestManager, testID1));
    298 
    299  // Test 2: Test that overwriting a record with a shorter record truncates
    300  // the record to the shorter record.
    301  //
    302  // Write record, then
    303  // read and verify record, then
    304  // write a shorter record to same record.
    305  // read and verify
    306  std::string id2 = aPrefix + "record1";
    307  std::string record1 = "This is the first write to a record.";
    308  std::string overwrite = "A shorter record";
    309  const std::string testID2 = aPrefix + "write-test-2";
    310  aTestManager->BeginTest(testID2);
    311  auto task2 = VerifyAndOverwriteContinuation(id2, record1, overwrite,
    312                                              aTestManager, testID2);
    313  auto msg2 = "FAIL in TestStorage writing record1.";
    314  WriteRecord(
    315      FakeDecryptor::sInstance->mHost, id2, record1,
    316      WriteRecordSuccessTask<VerifyAndOverwriteContinuation>(id2, task2),
    317      WriteRecordFailureTask(msg2, aTestManager, testID2));
    318 
    319  // Test 3: Test that opening a record while it's already open fails.
    320  //
    321  // Open record1, then
    322  // open record1, should fail.
    323  // close record1
    324  const std::string id3 = aPrefix + OpenAgainRecordId;
    325  const std::string testID3 = aPrefix + "open-test-1";
    326  aTestManager->BeginTest(testID3);
    327  auto task3 = OpenedFirstTimeContinuation(id3, aTestManager, testID3);
    328  OpenRecord(FakeDecryptor::sInstance->mHost, id3, task3);
    329 }
    330 
    331 void FakeDecryptor::TestStorage() {
    332  auto* testManager = new TestManager();
    333  // Main thread tests.
    334  DoTestStorage("mt1-", testManager);
    335  DoTestStorage("mt2-", testManager);
    336 
    337  // Note: Once all tests finish, TestManager will dispatch "test-pass" message,
    338  // which ends the test for the parent.
    339 }
    340 
    341 class ReportWritten {
    342 public:
    343  ReportWritten(const std::string& aRecordId, const std::string& aValue)
    344      : mRecordId(aRecordId), mValue(aValue) {}
    345  void operator()() {
    346    FakeDecryptor::Message("stored " + mRecordId + " " + mValue);
    347  }
    348 
    349  const std::string mRecordId;
    350  const std::string mValue;
    351 };
    352 
    353 class ReportReadStatusContinuation : public ReadContinuation {
    354 public:
    355  explicit ReportReadStatusContinuation(const std::string& aRecordId)
    356      : mRecordId(aRecordId) {}
    357  void operator()(bool aSuccess, const uint8_t* aData,
    358                  uint32_t aDataSize) override {
    359    if (!aSuccess) {
    360      FakeDecryptor::Message("retrieve " + mRecordId + " failed");
    361    } else {
    362      std::stringstream ss;
    363      ss << aDataSize;
    364      std::string len;
    365      ss >> len;
    366      FakeDecryptor::Message("retrieve " + mRecordId + " succeeded (length " +
    367                             len + " bytes)");
    368    }
    369  }
    370  std::string mRecordId;
    371 };
    372 
    373 class ReportReadRecordContinuation : public ReadContinuation {
    374 public:
    375  explicit ReportReadRecordContinuation(const std::string& aRecordId)
    376      : mRecordId(aRecordId) {}
    377  void operator()(bool aSuccess, const uint8_t* aData,
    378                  uint32_t aDataSize) override {
    379    if (!aSuccess) {
    380      FakeDecryptor::Message("retrieved " + mRecordId + " failed");
    381    } else {
    382      FakeDecryptor::Message(
    383          "retrieved " + mRecordId + " " +
    384          std::string(reinterpret_cast<const char*>(aData), aDataSize));
    385    }
    386  }
    387  std::string mRecordId;
    388 };
    389 
    390 enum ShutdownMode { ShutdownNormal, ShutdownTimeout, ShutdownStoreToken };
    391 
    392 static ShutdownMode sShutdownMode = ShutdownNormal;
    393 
    394 void FakeDecryptor::UpdateSession(uint32_t aPromiseId, const char* aSessionId,
    395                                  uint32_t aSessionIdLength,
    396                                  const uint8_t* aResponse,
    397                                  uint32_t aResponseSize) {
    398  MOZ_ASSERT(FakeDecryptor::sInstance->mHost,
    399             "FakeDecryptor::sInstance->mHost should not be null");
    400  std::string response((const char*)aResponse,
    401                       (const char*)(aResponse) + aResponseSize);
    402  std::vector<std::string> tokens = Tokenize(response);
    403  const std::string& task = tokens[0];
    404  if (task == "test-storage") {
    405    TestStorage();
    406  } else if (task == "store") {
    407    // send "stored record" message on complete.
    408    const std::string& id = tokens[1];
    409    const std::string& value = tokens[2];
    410    WriteRecord(FakeDecryptor::sInstance->mHost, id, value,
    411                ReportWritten(id, value),
    412                WriteRecordFailureTask("FAIL in writing record."));
    413  } else if (task == "retrieve") {
    414    const std::string& id = tokens[1];
    415    ReadRecord(FakeDecryptor::sInstance->mHost, id,
    416               ReportReadStatusContinuation(id));
    417  } else if (task == "shutdown-mode") {
    418    const std::string& mode = tokens[1];
    419    if (mode == "timeout") {
    420      sShutdownMode = ShutdownTimeout;
    421    } else if (mode == "token") {
    422      sShutdownMode = ShutdownStoreToken;
    423      Message("shutdown-token received " + tokens[2]);
    424    }
    425  } else if (task == "retrieve-shutdown-token") {
    426    ReadRecord(FakeDecryptor::sInstance->mHost, "shutdown-token",
    427               ReportReadRecordContinuation("shutdown-token"));
    428  } else if (task == "test-op-apis") {
    429    mozilla::cdmtest::TestOuputProtectionAPIs();
    430  }
    431 }