cdm-test-decryptor.cpp (14858B)
1 /* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 2 /* This Source Code Form is subject to the terms of the Mozilla Public 3 * License, v. 2.0. If a copy of the MPL was not distributed with this 4 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ 5 6 #include "cdm-test-decryptor.h" 7 8 #include <istream> 9 #include <iterator> 10 #include <mutex> 11 #include <set> 12 #include <sstream> 13 #include <string> 14 #include <vector> 15 16 #include "cdm-test-output-protection.h" 17 #include "cdm-test-storage.h" 18 #include "mozilla/Assertions.h" 19 20 FakeDecryptor* FakeDecryptor::sInstance = nullptr; 21 22 class TestManager { 23 public: 24 TestManager() = default; 25 26 // Register a test with the test manager. 27 void BeginTest(const std::string& aTestID) { 28 std::lock_guard<std::mutex> lock(mMutex); 29 auto found = mTestIDs.find(aTestID); 30 if (found == mTestIDs.end()) { 31 mTestIDs.insert(aTestID); 32 } else { 33 Error("FAIL BeginTest test already existed: " + aTestID); 34 } 35 } 36 37 // Notify the test manager that the test is finished. If all tests are done, 38 // test manager will send "test-storage complete" to notify the parent that 39 // all tests are finished and also delete itself. 40 void EndTest(const std::string& aTestID) { 41 bool isEmpty = false; 42 { 43 std::lock_guard<std::mutex> lock(mMutex); 44 auto found = mTestIDs.find(aTestID); 45 if (found != mTestIDs.end()) { 46 mTestIDs.erase(aTestID); 47 isEmpty = mTestIDs.empty(); 48 } else { 49 Error("FAIL EndTest test not existed: " + aTestID); 50 return; 51 } 52 } 53 if (isEmpty) { 54 Finish(); 55 delete this; 56 } 57 } 58 59 private: 60 ~TestManager() = default; 61 62 static void Error(const std::string& msg) { FakeDecryptor::Message(msg); } 63 64 static void Finish() { FakeDecryptor::Message("test-storage complete"); } 65 66 std::mutex mMutex; 67 std::set<std::string> mTestIDs; 68 }; 69 70 FakeDecryptor::FakeDecryptor(cdm::Host_11* aHost) : mHost(aHost) { 71 MOZ_ASSERT(!sInstance); 72 sInstance = this; 73 } 74 75 void FakeDecryptor::Message(const std::string& aMessage) { 76 MOZ_ASSERT(sInstance); 77 const static std::string sid("fake-session-id"); 78 sInstance->mHost->OnSessionMessage(sid.c_str(), sid.size(), 79 cdm::MessageType::kLicenseRequest, 80 aMessage.c_str(), aMessage.size()); 81 } 82 83 std::vector<std::string> Tokenize(const std::string& aString) { 84 std::stringstream strstr(aString); 85 std::istream_iterator<std::string> it(strstr), end; 86 return std::vector<std::string>(it, end); 87 } 88 89 static const char TruncateRecordId[] = "truncate-record-id"; 90 static const char TruncateRecordData[] = "I will soon be truncated"; 91 92 template <class Continuation> 93 class WriteRecordSuccessTask { 94 public: 95 WriteRecordSuccessTask(std::string aId, Continuation aThen) 96 : mId(aId), mThen(std::move(aThen)) {} 97 98 void operator()() { ReadRecord(FakeDecryptor::sInstance->mHost, mId, mThen); } 99 100 std::string mId; 101 Continuation mThen; 102 }; 103 104 class WriteRecordFailureTask { 105 public: 106 explicit WriteRecordFailureTask(const std::string& aMessage, 107 TestManager* aTestManager = nullptr, 108 const std::string& aTestID = "") 109 : mMessage(aMessage), mTestmanager(aTestManager), mTestID(aTestID) {} 110 111 void operator()() { 112 FakeDecryptor::Message(mMessage); 113 if (mTestmanager) { 114 mTestmanager->EndTest(mTestID); 115 } 116 } 117 118 private: 119 std::string mMessage; 120 TestManager* const mTestmanager; 121 const std::string mTestID; 122 }; 123 124 class TestEmptyContinuation : public ReadContinuation { 125 public: 126 TestEmptyContinuation(TestManager* aTestManager, const std::string& aTestID) 127 : mTestmanager(aTestManager), mTestID(aTestID) {} 128 129 virtual void operator()(bool aSuccess, const uint8_t* aData, 130 uint32_t aDataSize) override { 131 if (aDataSize) { 132 FakeDecryptor::Message( 133 "FAIL TestEmptyContinuation record was not truncated"); 134 } 135 mTestmanager->EndTest(mTestID); 136 } 137 138 private: 139 TestManager* const mTestmanager; 140 const std::string mTestID; 141 }; 142 143 class TruncateContinuation : public ReadContinuation { 144 public: 145 TruncateContinuation(const std::string& aID, TestManager* aTestManager, 146 const std::string& aTestID) 147 : mID(aID), mTestmanager(aTestManager), mTestID(aTestID) {} 148 149 virtual void operator()(bool aSuccess, const uint8_t* aData, 150 uint32_t aDataSize) override { 151 if (std::string(reinterpret_cast<const char*>(aData), aDataSize) != 152 TruncateRecordData) { 153 FakeDecryptor::Message( 154 "FAIL TruncateContinuation read data doesn't match written data"); 155 } 156 auto cont = TestEmptyContinuation(mTestmanager, mTestID); 157 auto msg = "FAIL in TruncateContinuation write."; 158 WriteRecord(FakeDecryptor::sInstance->mHost, mID, nullptr, 0, 159 WriteRecordSuccessTask<TestEmptyContinuation>(mID, cont), 160 WriteRecordFailureTask(msg, mTestmanager, mTestID)); 161 } 162 163 private: 164 const std::string mID; 165 TestManager* const mTestmanager; 166 const std::string mTestID; 167 }; 168 169 class VerifyAndFinishContinuation : public ReadContinuation { 170 public: 171 explicit VerifyAndFinishContinuation(std::string aValue, 172 TestManager* aTestManager, 173 const std::string& aTestID) 174 : mValue(aValue), mTestmanager(aTestManager), mTestID(aTestID) {} 175 176 virtual void operator()(bool aSuccess, const uint8_t* aData, 177 uint32_t aDataSize) override { 178 if (std::string(reinterpret_cast<const char*>(aData), aDataSize) != 179 mValue) { 180 FakeDecryptor::Message( 181 "FAIL VerifyAndFinishContinuation read data doesn't match expected " 182 "data"); 183 } 184 mTestmanager->EndTest(mTestID); 185 } 186 187 private: 188 std::string mValue; 189 TestManager* const mTestmanager; 190 const std::string mTestID; 191 }; 192 193 class VerifyAndOverwriteContinuation : public ReadContinuation { 194 public: 195 VerifyAndOverwriteContinuation(std::string aId, std::string aValue, 196 std::string aOverwrite, 197 TestManager* aTestManager, 198 const std::string& aTestID) 199 : mId(aId), 200 mValue(aValue), 201 mOverwrite(aOverwrite), 202 mTestmanager(aTestManager), 203 mTestID(aTestID) {} 204 205 virtual void operator()(bool aSuccess, const uint8_t* aData, 206 uint32_t aDataSize) override { 207 if (std::string(reinterpret_cast<const char*>(aData), aDataSize) != 208 mValue) { 209 FakeDecryptor::Message( 210 "FAIL VerifyAndOverwriteContinuation read data doesn't match " 211 "expected data"); 212 } 213 auto cont = VerifyAndFinishContinuation(mOverwrite, mTestmanager, mTestID); 214 auto msg = "FAIL in VerifyAndOverwriteContinuation write."; 215 WriteRecord(FakeDecryptor::sInstance->mHost, mId, mOverwrite, 216 WriteRecordSuccessTask<VerifyAndFinishContinuation>(mId, cont), 217 WriteRecordFailureTask(msg, mTestmanager, mTestID)); 218 } 219 220 private: 221 std::string mId; 222 std::string mValue; 223 std::string mOverwrite; 224 TestManager* const mTestmanager; 225 const std::string mTestID; 226 }; 227 228 static const char OpenAgainRecordId[] = "open-again-record-id"; 229 230 class OpenedSecondTimeContinuation : public OpenContinuation { 231 public: 232 explicit OpenedSecondTimeContinuation(TestManager* aTestManager, 233 const std::string& aTestID) 234 : mTestmanager(aTestManager), mTestID(aTestID) {} 235 236 void operator()(bool aSuccess) override { 237 if (!aSuccess) { 238 FakeDecryptor::Message( 239 "FAIL OpenSecondTimeContinuation should not be able to re-open " 240 "record."); 241 } 242 // Succeeded, open should have failed. 243 mTestmanager->EndTest(mTestID); 244 } 245 246 private: 247 TestManager* const mTestmanager; 248 const std::string mTestID; 249 }; 250 251 class OpenedFirstTimeContinuation : public OpenContinuation { 252 public: 253 OpenedFirstTimeContinuation(const std::string& aID, TestManager* aTestManager, 254 const std::string& aTestID) 255 : mID(aID), mTestmanager(aTestManager), mTestID(aTestID) {} 256 257 void operator()(bool aSuccess) override { 258 if (!aSuccess) { 259 FakeDecryptor::Message( 260 "FAIL OpenAgainContinuation to open record initially."); 261 mTestmanager->EndTest(mTestID); 262 return; 263 } 264 265 auto cont = OpenedSecondTimeContinuation(mTestmanager, mTestID); 266 OpenRecord(FakeDecryptor::sInstance->mHost, mID, cont); 267 } 268 269 private: 270 const std::string mID; 271 TestManager* const mTestmanager; 272 const std::string mTestID; 273 }; 274 275 static void DoTestStorage(const std::string& aPrefix, 276 TestManager* aTestManager) { 277 MOZ_ASSERT(FakeDecryptor::sInstance->mHost, 278 "FakeDecryptor::sInstance->mHost should not be null"); 279 // Basic I/O tests. We run three cases concurrently. The tests, like 280 // CDMStorage run asynchronously. When they've all passed, we send 281 // a message back to the parent process, or a failure message if not. 282 283 // Test 1: Basic I/O test, and test that writing 0 bytes in a record 284 // deletes record. 285 // 286 // Write data to truncate record, then 287 // read data, verify that we read what we wrote, then 288 // write 0 bytes to truncate record, then 289 // read data, verify that 0 bytes was read 290 const std::string id1 = aPrefix + TruncateRecordId; 291 const std::string testID1 = aPrefix + "write-test-1"; 292 aTestManager->BeginTest(testID1); 293 auto cont1 = TruncateContinuation(id1, aTestManager, testID1); 294 auto msg1 = "FAIL in TestStorage writing TruncateRecord."; 295 WriteRecord(FakeDecryptor::sInstance->mHost, id1, TruncateRecordData, 296 WriteRecordSuccessTask<TruncateContinuation>(id1, cont1), 297 WriteRecordFailureTask(msg1, aTestManager, testID1)); 298 299 // Test 2: Test that overwriting a record with a shorter record truncates 300 // the record to the shorter record. 301 // 302 // Write record, then 303 // read and verify record, then 304 // write a shorter record to same record. 305 // read and verify 306 std::string id2 = aPrefix + "record1"; 307 std::string record1 = "This is the first write to a record."; 308 std::string overwrite = "A shorter record"; 309 const std::string testID2 = aPrefix + "write-test-2"; 310 aTestManager->BeginTest(testID2); 311 auto task2 = VerifyAndOverwriteContinuation(id2, record1, overwrite, 312 aTestManager, testID2); 313 auto msg2 = "FAIL in TestStorage writing record1."; 314 WriteRecord( 315 FakeDecryptor::sInstance->mHost, id2, record1, 316 WriteRecordSuccessTask<VerifyAndOverwriteContinuation>(id2, task2), 317 WriteRecordFailureTask(msg2, aTestManager, testID2)); 318 319 // Test 3: Test that opening a record while it's already open fails. 320 // 321 // Open record1, then 322 // open record1, should fail. 323 // close record1 324 const std::string id3 = aPrefix + OpenAgainRecordId; 325 const std::string testID3 = aPrefix + "open-test-1"; 326 aTestManager->BeginTest(testID3); 327 auto task3 = OpenedFirstTimeContinuation(id3, aTestManager, testID3); 328 OpenRecord(FakeDecryptor::sInstance->mHost, id3, task3); 329 } 330 331 void FakeDecryptor::TestStorage() { 332 auto* testManager = new TestManager(); 333 // Main thread tests. 334 DoTestStorage("mt1-", testManager); 335 DoTestStorage("mt2-", testManager); 336 337 // Note: Once all tests finish, TestManager will dispatch "test-pass" message, 338 // which ends the test for the parent. 339 } 340 341 class ReportWritten { 342 public: 343 ReportWritten(const std::string& aRecordId, const std::string& aValue) 344 : mRecordId(aRecordId), mValue(aValue) {} 345 void operator()() { 346 FakeDecryptor::Message("stored " + mRecordId + " " + mValue); 347 } 348 349 const std::string mRecordId; 350 const std::string mValue; 351 }; 352 353 class ReportReadStatusContinuation : public ReadContinuation { 354 public: 355 explicit ReportReadStatusContinuation(const std::string& aRecordId) 356 : mRecordId(aRecordId) {} 357 void operator()(bool aSuccess, const uint8_t* aData, 358 uint32_t aDataSize) override { 359 if (!aSuccess) { 360 FakeDecryptor::Message("retrieve " + mRecordId + " failed"); 361 } else { 362 std::stringstream ss; 363 ss << aDataSize; 364 std::string len; 365 ss >> len; 366 FakeDecryptor::Message("retrieve " + mRecordId + " succeeded (length " + 367 len + " bytes)"); 368 } 369 } 370 std::string mRecordId; 371 }; 372 373 class ReportReadRecordContinuation : public ReadContinuation { 374 public: 375 explicit ReportReadRecordContinuation(const std::string& aRecordId) 376 : mRecordId(aRecordId) {} 377 void operator()(bool aSuccess, const uint8_t* aData, 378 uint32_t aDataSize) override { 379 if (!aSuccess) { 380 FakeDecryptor::Message("retrieved " + mRecordId + " failed"); 381 } else { 382 FakeDecryptor::Message( 383 "retrieved " + mRecordId + " " + 384 std::string(reinterpret_cast<const char*>(aData), aDataSize)); 385 } 386 } 387 std::string mRecordId; 388 }; 389 390 enum ShutdownMode { ShutdownNormal, ShutdownTimeout, ShutdownStoreToken }; 391 392 static ShutdownMode sShutdownMode = ShutdownNormal; 393 394 void FakeDecryptor::UpdateSession(uint32_t aPromiseId, const char* aSessionId, 395 uint32_t aSessionIdLength, 396 const uint8_t* aResponse, 397 uint32_t aResponseSize) { 398 MOZ_ASSERT(FakeDecryptor::sInstance->mHost, 399 "FakeDecryptor::sInstance->mHost should not be null"); 400 std::string response((const char*)aResponse, 401 (const char*)(aResponse) + aResponseSize); 402 std::vector<std::string> tokens = Tokenize(response); 403 const std::string& task = tokens[0]; 404 if (task == "test-storage") { 405 TestStorage(); 406 } else if (task == "store") { 407 // send "stored record" message on complete. 408 const std::string& id = tokens[1]; 409 const std::string& value = tokens[2]; 410 WriteRecord(FakeDecryptor::sInstance->mHost, id, value, 411 ReportWritten(id, value), 412 WriteRecordFailureTask("FAIL in writing record.")); 413 } else if (task == "retrieve") { 414 const std::string& id = tokens[1]; 415 ReadRecord(FakeDecryptor::sInstance->mHost, id, 416 ReportReadStatusContinuation(id)); 417 } else if (task == "shutdown-mode") { 418 const std::string& mode = tokens[1]; 419 if (mode == "timeout") { 420 sShutdownMode = ShutdownTimeout; 421 } else if (mode == "token") { 422 sShutdownMode = ShutdownStoreToken; 423 Message("shutdown-token received " + tokens[2]); 424 } 425 } else if (task == "retrieve-shutdown-token") { 426 ReadRecord(FakeDecryptor::sInstance->mHost, "shutdown-token", 427 ReportReadRecordContinuation("shutdown-token")); 428 } else if (task == "test-op-apis") { 429 mozilla::cdmtest::TestOuputProtectionAPIs(); 430 } 431 }