agent_improvements.patch (20501B)
1 commit 4ad63eb3aa65ce7baa08190aac2770540dc25f43 2 Author: Greg Stoll <gstoll@mozilla.com> 3 Date: Wed, 27 Mar 2024 12:13:56 -0500 4 5 Mozilla improvements to content_analysis_sdk 6 7 - add ability for demo agent to block/warn/report specific regexes 8 - add ability for demo agent to choose a sequence of delays to apply 9 10 diff --git a/CMakeLists.txt b/CMakeLists.txt 11 index 39477223f031c..5dacc81031117 100644 12 --- a/CMakeLists.txt 13 +++ b/CMakeLists.txt 14 @@ -203,6 +203,7 @@ add_executable(agent 15 ./demo/agent.cc 16 ./demo/handler.h 17 ) 18 +target_compile_features(agent PRIVATE cxx_std_17) 19 target_include_directories(agent PRIVATE ${AGENT_INCLUDES}) 20 target_link_libraries(agent PRIVATE cac_agent) 21 22 diff --git a/agent/src/event_win.h b/agent/src/event_win.h 23 index 9f8b6903566f2..f631f693dcd9c 100644 24 --- a/agent/src/event_win.h 25 +++ b/agent/src/event_win.h 26 @@ -28,6 +28,12 @@ class ContentAnalysisEventWin : public ContentAnalysisEventBase { 27 ResultCode Close() override; 28 ResultCode Send() override; 29 std::string DebugString() const override; 30 + std::string SerializeStringToSendToBrowser() { 31 + return agent_to_chrome()->SerializeAsString(); 32 + } 33 + void SetResponseSent() { response_sent_ = true; } 34 + 35 + HANDLE Pipe() const { return hPipe_; } 36 37 private: 38 void Shutdown(); 39 diff --git a/browser/src/client_win.cc b/browser/src/client_win.cc 40 index 9d3d7e8c52662..039946d131398 100644 41 --- a/browser/src/client_win.cc 42 +++ b/browser/src/client_win.cc 43 @@ -418,7 +418,11 @@ DWORD ClientWin::ConnectToPipe(const std::string& pipename, HANDLE* handle) { 44 45 void ClientWin::Shutdown() { 46 if (hPipe_ != INVALID_HANDLE_VALUE) { 47 - FlushFileBuffers(hPipe_); 48 + // TODO: This trips the LateWriteObserver. We could move this earlier 49 + // (before the LateWriteObserver is created) or just remove it, although 50 + // the latter could mean an ACK message is not processed by the agent 51 + // in time. 
52 + // FlushFileBuffers(hPipe_); 53 CloseHandle(hPipe_); 54 hPipe_ = INVALID_HANDLE_VALUE; 55 } 56 diff --git a/demo/agent.cc b/demo/agent.cc 57 index ff8b93f647ebd..3e168b0915a0c 100644 58 --- a/demo/agent.cc 59 +++ b/demo/agent.cc 60 @@ -2,12 +2,18 @@ 61 // Use of this source code is governed by a BSD-style license that can be 62 // found in the LICENSE file. 63 64 +#include <algorithm> 65 #include <fstream> 66 #include <iostream> 67 #include <string> 68 +#include <regex> 69 +#include <vector> 70 71 #include "content_analysis/sdk/analysis_agent.h" 72 #include "demo/handler.h" 73 +#include "demo/handler_misbehaving.h" 74 + 75 +using namespace content_analysis::sdk; 76 77 // Different paths are used depending on whether this agent should run as a 78 // use specific agent or not. These values are chosen to match the test 79 @@ -19,19 +25,50 @@ constexpr char kPathSystem[] = "brcm_chrm_cas"; 80 std::string path = kPathSystem; 81 bool use_queue = false; 82 bool user_specific = false; 83 -unsigned long delay = 0; // In seconds. 84 +std::vector<unsigned long> delays = {0}; // In seconds. 85 unsigned long num_threads = 8u; 86 std::string save_print_data_path = ""; 87 +RegexArray toBlock, toWarn, toReport; 88 +static bool useMisbehavingHandler = false; 89 +static std::string modeStr; 90 91 // Command line parameters. 
92 -constexpr const char* kArgDelaySpecific = "--delay="; 93 +constexpr const char* kArgDelaySpecific = "--delays="; 94 constexpr const char* kArgPath = "--path="; 95 constexpr const char* kArgQueued = "--queued"; 96 constexpr const char* kArgThreads = "--threads="; 97 constexpr const char* kArgUserSpecific = "--user"; 98 +constexpr const char* kArgToBlock = "--toblock="; 99 +constexpr const char* kArgToWarn = "--towarn="; 100 +constexpr const char* kArgToReport = "--toreport="; 101 +constexpr const char* kArgMisbehave = "--misbehave="; 102 constexpr const char* kArgHelp = "--help"; 103 constexpr const char* kArgSavePrintRequestDataTo = "--save-print-request-data-to="; 104 105 +std::map<std::string, Mode> sStringToMode = { 106 +#define AGENT_MODE(name) {#name, Mode::Mode_##name}, 107 +#include "modes.h" 108 +#undef AGENT_MODE 109 +}; 110 + 111 +std::map<Mode, std::string> sModeToString = { 112 +#define AGENT_MODE(name) {Mode::Mode_##name, #name}, 113 +#include "modes.h" 114 +#undef AGENT_MODE 115 +}; 116 + 117 +std::vector<std::pair<std::string, std::regex>> 118 +ParseRegex(const std::string str) { 119 + std::vector<std::pair<std::string, std::regex>> ret; 120 + for (auto it = str.begin(); it != str.end(); /* nop */) { 121 + auto it2 = std::find(it, str.end(), ','); 122 + ret.push_back(std::make_pair(std::string(it, it2), std::regex(it, it2))); 123 + it = it2 == str.end() ? 
it2 : it2 + 1; 124 + } 125 + 126 + return ret; 127 +} 128 + 129 bool ParseCommandLine(int argc, char* argv[]) { 130 for (int i = 1; i < argc; ++i) { 131 const std::string arg = argv[i]; 132 @@ -44,16 +81,38 @@ bool ParseCommandLine(int argc, char* argv[]) { 133 path = kPathUser; 134 user_specific = true; 135 } else if (arg.find(kArgDelaySpecific) == 0) { 136 - delay = std::stoul(arg.substr(strlen(kArgDelaySpecific))); 137 + std::string delaysStr = arg.substr(strlen(kArgDelaySpecific)); 138 + delays.clear(); 139 + size_t posStart = 0, posEnd; 140 + unsigned long delay; 141 + while ((posEnd = delaysStr.find(',', posStart)) != std::string::npos) { 142 + delay = std::stoul(delaysStr.substr(posStart, posEnd - posStart)); 143 + if (delay > 30) { 144 + delay = 30; 145 + } 146 + delays.push_back(delay); 147 + posStart = posEnd + 1; 148 + } 149 + delay = std::stoul(delaysStr.substr(posStart)); 150 if (delay > 30) { 151 delay = 30; 152 } 153 + delays.push_back(delay); 154 } else if (arg.find(kArgPath) == 0) { 155 path = arg.substr(strlen(kArgPath)); 156 } else if (arg.find(kArgQueued) == 0) { 157 use_queue = true; 158 } else if (arg.find(kArgThreads) == 0) { 159 num_threads = std::stoul(arg.substr(strlen(kArgThreads))); 160 + } else if (arg.find(kArgToBlock) == 0) { 161 + toBlock = ParseRegex(arg.substr(strlen(kArgToBlock))); 162 + } else if (arg.find(kArgToWarn) == 0) { 163 + toWarn = ParseRegex(arg.substr(strlen(kArgToWarn))); 164 + } else if (arg.find(kArgToReport) == 0) { 165 + toReport = ParseRegex(arg.substr(strlen(kArgToReport))); 166 + } else if (arg.find(kArgMisbehave) == 0) { 167 + modeStr = arg.substr(strlen(kArgMisbehave)); 168 + useMisbehavingHandler = true; 169 } else if (arg.find(kArgHelp) == 0) { 170 return false; 171 } else if (arg.find(kArgSavePrintRequestDataTo) == 0) { 172 @@ -72,13 +131,17 @@ void PrintHelp() { 173 << "A simple agent to process content analysis requests." 
<< std::endl 174 << "Data containing the string 'block' blocks the request data from being used." << std::endl 175 << std::endl << "Options:" << std::endl 176 - << kArgDelaySpecific << "<delay> : Add a delay to request processing in seconds (max 30)." << std::endl 177 + << kArgDelaySpecific << "<delay1,delay2,...> : Add delays to request processing in seconds. Delays are limited to 30 seconds and are applied round-robin to requests. Default is 0." << std::endl 178 << kArgPath << " <path> : Used the specified path instead of default. Must come after --user." << std::endl 179 << kArgQueued << " : Queue requests for processing in a background thread" << std::endl 180 << kArgThreads << " : When queued, number of threads in the request processing thread pool" << std::endl 181 << kArgUserSpecific << " : Make agent OS user specific." << std::endl 182 << kArgHelp << " : prints this help message" << std::endl 183 - << kArgSavePrintRequestDataTo << " : saves the PDF data to the given file path for print requests"; 184 + << kArgSavePrintRequestDataTo << " : saves the PDF data to the given file path for print requests" << std::endl 185 + << kArgToBlock << "<regex> : Regular expression matching file and text content to block." << std::endl 186 + << kArgToWarn << "<regex> : Regular expression matching file and text content to warn about." << std::endl 187 + << kArgToReport << "<regex> : Regular expression matching file and text content to report." << std::endl 188 + << kArgMisbehave << "<mode> : Use 'misbehaving' agent in given mode for testing purposes." << std::endl; 189 } 190 191 int main(int argc, char* argv[]) { 192 @@ -87,9 +150,17 @@ int main(int argc, char* argv[]) { 193 return 1; 194 } 195 196 - auto handler = use_queue 197 - ? std::make_unique<QueuingHandler>(num_threads, delay, save_print_data_path) 198 - : std::make_unique<Handler>(delay, save_print_data_path); 199 + auto handler = 200 + useMisbehavingHandler 201 + ? 
MisbehavingHandler::Create(modeStr, std::move(delays), save_print_data_path, std::move(toBlock), std::move(toWarn), std::move(toReport)) 202 + : use_queue 203 + ? std::make_unique<QueuingHandler>(num_threads, std::move(delays), save_print_data_path, std::move(toBlock), std::move(toWarn), std::move(toReport)) 204 + : std::make_unique<Handler>(std::move(delays), save_print_data_path, std::move(toBlock), std::move(toWarn), std::move(toReport)); 205 + 206 + if (!handler) { 207 + std::cout << "[Demo] Failed to construct handler." << std::endl; 208 + return 1; 209 + } 210 211 // Each agent uses a unique name to identify itself with Google Chrome. 212 content_analysis::sdk::ResultCode rc; 213 diff --git a/demo/handler.h b/demo/handler.h 214 index 9d1ccfdf9857a..88599963c51b0 100644 215 --- a/demo/handler.h 216 +++ b/demo/handler.h 217 @@ -7,31 +7,51 @@ 218 219 #include <time.h> 220 221 +#include <algorithm> 222 +#include <atomic> 223 #include <chrono> 224 #include <cstdio> 225 #include <fstream> 226 #include <iostream> 227 +#include <optional> 228 #include <thread> 229 #include <utility> 230 +#include <regex> 231 #include <vector> 232 233 #include "content_analysis/sdk/analysis_agent.h" 234 #include "demo/atomic_output.h" 235 #include "demo/request_queue.h" 236 237 +using RegexArray = std::vector<std::pair<std::string, std::regex>>; 238 + 239 // An AgentEventHandler that dumps requests information to stdout and blocks 240 // any requests that have the keyword "block" in their data 241 class Handler : public content_analysis::sdk::AgentEventHandler { 242 public: 243 using Event = content_analysis::sdk::ContentAnalysisEvent; 244 245 - Handler(unsigned long delay, const std::string& print_data_file_path) : 246 - delay_(delay), print_data_file_path_(print_data_file_path) { 247 - } 248 + Handler(std::vector<unsigned long>&& delays, const std::string& print_data_file_path, 249 + RegexArray&& toBlock = RegexArray(), 250 + RegexArray&& toWarn = RegexArray(), 251 + RegexArray&& 
toReport = RegexArray()) : 252 + toBlock_(std::move(toBlock)), toWarn_(std::move(toWarn)), toReport_(std::move(toReport)), 253 + delays_(std::move(delays)), print_data_file_path_(print_data_file_path) {} 254 255 - unsigned long delay() { return delay_; } 256 + const std::vector<unsigned long> delays() { return delays_; } 257 + size_t nextDelayIndex() const { return nextDelayIndex_; } 258 259 protected: 260 + // subclasses can override this 261 + // returns whether the response has been set 262 + virtual bool SetCustomResponse(AtomicCout& aout, std::unique_ptr<Event>& event) { 263 + return false; 264 + } 265 + // subclasses can override this 266 + // returns whether the response has been sent 267 + virtual bool SendCustomResponse(std::unique_ptr<Event>& event) { 268 + return false; 269 + } 270 // Analyzes one request from Google Chrome and responds back to the browser 271 // with either an allow or block verdict. 272 void AnalyzeContent(AtomicCout& aout, std::unique_ptr<Event> event) { 273 @@ -43,29 +63,25 @@ class Handler : public content_analysis::sdk::AgentEventHandler { 274 275 DumpEvent(aout.stream(), event.get()); 276 277 - bool block = false; 278 bool success = true; 279 - unsigned long delay = delay_; 280 - 281 - if (event->GetRequest().has_text_content()) { 282 - block = ShouldBlockRequest( 283 - event->GetRequest().text_content()); 284 - GetFileSpecificDelay(event->GetRequest().text_content(), &delay); 285 - } else if (event->GetRequest().has_file_path()) { 286 - std::string content; 287 - success = 288 - ReadContentFromFile(event->GetRequest().file_path(), 289 - &content); 290 - if (success) { 291 - block = ShouldBlockRequest(content); 292 - GetFileSpecificDelay(content, &delay); 293 + std::optional<content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action> caResponse; 294 + bool setResponse = SetCustomResponse(aout, event); 295 + if (!setResponse) { 296 + caResponse = 
content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_BLOCK; 297 + if (event->GetRequest().has_text_content()) { 298 + caResponse = DecideCAResponse( 299 + event->GetRequest().text_content(), aout.stream()); 300 + } else if (event->GetRequest().has_file_path()) { 301 + // TODO: Fix downloads to store file *first* so we can check contents. 302 + // Until then, just check the file name: 303 + caResponse = DecideCAResponse( 304 + event->GetRequest().file_path(), aout.stream()); 305 + } else if (event->GetRequest().has_print_data()) { 306 + // In the case of a print request, normally the PDF bytes would be parsed 307 + // for sensitive data violations. To keep this class simple, only the 308 + // URL is checked against the configured regexes. 309 + caResponse = DecideCAResponse(event->GetRequest().request_data().url(), aout.stream()); 310 } 311 - } else if (event->GetRequest().has_print_data()) { 312 - // In the case of print request, normally the PDF bytes would be parsed 313 - // for sensitive data violations. To keep this class simple, only the 314 - // URL is checked for the word "block". 
315 - block = ShouldBlockRequest(event->GetRequest().request_data().url()); 316 - GetFileSpecificDelay(event->GetRequest().request_data().url(), &delay); 317 } 318 319 if (!success) { 320 @@ -75,22 +91,44 @@ class Handler : public content_analysis::sdk::AgentEventHandler { 321 content_analysis::sdk::ContentAnalysisResponse::Result::FAILURE); 322 aout.stream() << " Verdict: failed to reach verdict: "; 323 aout.stream() << event->DebugString() << std::endl; 324 - } else if (block) { 325 - auto rc = content_analysis::sdk::SetEventVerdictToBlock(event.get()); 326 - aout.stream() << " Verdict: block"; 327 - if (rc != content_analysis::sdk::ResultCode::OK) { 328 - aout.stream() << " error: " 329 - << content_analysis::sdk::ResultCodeToString(rc) << std::endl; 330 - aout.stream() << " " << event->DebugString() << std::endl; 331 + } else { 332 + aout.stream() << " Verdict: "; 333 + if (caResponse) { 334 + switch (caResponse.value()) { 335 + case content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_BLOCK: 336 + aout.stream() << "BLOCK"; 337 + break; 338 + case content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_WARN: 339 + aout.stream() << "WARN"; 340 + break; 341 + case content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_REPORT_ONLY: 342 + aout.stream() << "REPORT_ONLY"; 343 + break; 344 + case content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_ACTION_UNSPECIFIED: 345 + aout.stream() << "ACTION_UNSPECIFIED"; 346 + break; 347 + default: 348 + aout.stream() << "<error>"; 349 + break; 350 + } 351 + auto rc = 352 + content_analysis::sdk::SetEventVerdictTo(event.get(), caResponse.value()); 353 + if (rc != content_analysis::sdk::ResultCode::OK) { 354 + aout.stream() << " error: " 355 + << content_analysis::sdk::ResultCodeToString(rc) << std::endl; 356 + aout.stream() << " " << event->DebugString() << std::endl; 357 + } 358 + aout.stream() << std::endl; 359 + } else { 360 + aout.stream() 
<< " Verdict: allow" << std::endl; 361 } 362 aout.stream() << std::endl; 363 - } else { 364 - aout.stream() << " Verdict: allow" << std::endl; 365 } 366 - 367 aout.stream() << std::endl; 368 369 // If a delay is specified, wait that much. 370 + size_t nextDelayIndex = nextDelayIndex_.fetch_add(1); 371 + unsigned long delay = delays_[nextDelayIndex % delays_.size()]; 372 if (delay > 0) { 373 aout.stream() << "Delaying response to " << event->GetRequest().request_token() 374 << " for " << delay << "s" << std::endl<< std::endl; 375 @@ -99,16 +137,19 @@ class Handler : public content_analysis::sdk::AgentEventHandler { 376 } 377 378 // Send the response back to Google Chrome. 379 - auto rc = event->Send(); 380 - if (rc != content_analysis::sdk::ResultCode::OK) { 381 - aout.stream() << "[Demo] Error sending response: " 382 - << content_analysis::sdk::ResultCodeToString(rc) 383 - << std::endl; 384 - aout.stream() << event->DebugString() << std::endl; 385 + bool sentCustomResponse = SendCustomResponse(event); 386 + if (!sentCustomResponse) { 387 + auto rc = event->Send(); 388 + if (rc != content_analysis::sdk::ResultCode::OK) { 389 + aout.stream() << "[Demo] Error sending response: " 390 + << content_analysis::sdk::ResultCodeToString(rc) 391 + << std::endl; 392 + aout.stream() << event->DebugString() << std::endl; 393 + } 394 } 395 } 396 397 - private: 398 + protected: 399 void OnBrowserConnected( 400 const content_analysis::sdk::BrowserInfo& info) override { 401 AtomicCout aout; 402 @@ -362,21 +403,40 @@ class Handler : public content_analysis::sdk::AgentEventHandler { 403 return true; 404 } 405 406 - bool ShouldBlockRequest(const std::string& content) { 407 - // Determines if the request should be blocked. For this simple example 408 - // the content is blocked if the string "block" is found. Otherwise the 409 - // content is allowed. 
410 - return content.find("block") != std::string::npos; 411 - } 412 - 413 - void GetFileSpecificDelay(const std::string& content, unsigned long* delay) { 414 - auto pos = content.find("delay="); 415 - if (pos != std::string::npos) { 416 - std::sscanf(content.substr(pos).c_str(), "delay=%lu", delay); 417 + std::optional<content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action> 418 + DecideCAResponse(const std::string& content, std::stringstream& stream) { 419 + for (auto& r : toBlock_) { 420 + if (std::regex_search(content, r.second)) { 421 + stream << "'" << content << "' matches BLOCK regex '" 422 + << r.first << "'" << std::endl; 423 + return content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_BLOCK; 424 + } 425 } 426 + for (auto& r : toWarn_) { 427 + if (std::regex_search(content, r.second)) { 428 + stream << "'" << content << "' matches WARN regex '" 429 + << r.first << "'" << std::endl; 430 + return content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_WARN; 431 + } 432 + } 433 + for (auto& r : toReport_) { 434 + if (std::regex_search(content, r.second)) { 435 + stream << "'" << content << "' matches REPORT_ONLY regex '" 436 + << r.first << "'" << std::endl; 437 + return content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_REPORT_ONLY; 438 + } 439 + } 440 + stream << "'" << content << "' was ALLOWed\n"; 441 + return {}; 442 } 443 444 - unsigned long delay_; 445 + // For the demo, block, warn about, or report any content that matches these regular expressions. 
446 + RegexArray toBlock_; 447 + RegexArray toWarn_; 448 + RegexArray toReport_; 449 + 450 + std::vector<unsigned long> delays_; 451 + std::atomic<size_t> nextDelayIndex_; 452 std::string print_data_file_path_; 453 }; 454 455 @@ -384,8 +444,11 @@ class Handler : public content_analysis::sdk::AgentEventHandler { 456 // any requests that have the keyword "block" in their data 457 class QueuingHandler : public Handler { 458 public: 459 - QueuingHandler(unsigned long threads, unsigned long delay, const std::string& print_data_file_path) 460 - : Handler(delay, print_data_file_path) { 461 + QueuingHandler(unsigned long threads, std::vector<unsigned long>&& delays, const std::string& print_data_file_path, 462 + RegexArray&& toBlock = RegexArray(), 463 + RegexArray&& toWarn = RegexArray(), 464 + RegexArray&& toReport = RegexArray()) 465 + : Handler(std::move(delays), print_data_file_path, std::move(toBlock), std::move(toWarn), std::move(toReport)) { 466 StartBackgroundThreads(threads); 467 } 468 469 @@ -421,6 +484,8 @@ class QueuingHandler : public Handler { 470 aout.stream() << std::endl << "----------" << std::endl; 471 aout.stream() << "Thread: " << std::this_thread::get_id() 472 << std::endl; 473 + aout.stream() << "Delaying request processing for " 474 + << handler->delays()[handler->nextDelayIndex() % handler->delays().size()] << "s" << std::endl << std::endl; 475 aout.flush(); 476 477 handler->AnalyzeContent(aout, std::move(event)); 478 -- 479 2.42.0.windows.2