ssltunnel.cpp (57047B)
/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

/*
 * WARNING: DO NOT USE THIS CODE IN PRODUCTION SYSTEMS. It is highly likely to
 * be plagued with the usual problems endemic to C (buffer overflows
 * and the like). We don't especially care here (but would accept
 * patches!) because this is only intended for use in our test
 * harnesses in controlled situations where input is guaranteed not to
 * be malicious.
 */

#include "ScopedNSSTypes.h"
#include <assert.h>
#include <stdio.h>
#include <string>
#include <vector>
#include <stdarg.h>
#include "prinit.h"
#include "prerror.h"
#include "prenv.h"
#include "prnetdb.h"
#include "prtpool.h"
#include "nss.h"
#include "keyhi.h"
#include "ssl.h"
#include "sslproto.h"
#include "plhash.h"
#include "mozilla/Sprintf.h"

using namespace mozilla;
using namespace mozilla::psm;
using std::string;
using std::vector;

// Delimiter bitmap helpers. `m` is a table of DELIM_TABLE_SIZE bytes holding
// one bit per possible byte value (32 bytes * 8 bits = 256 values):
// byte index is (c >> 3), bit index is (c & 7). `c` must be an unsigned
// byte value (callers cast to uint8_t).
//   IS_DELIM  - non-zero if the bit for byte `c` is set.
//   SET_DELIM - sets the bit for byte `c`.
#define IS_DELIM(m, c) ((m)[(c) >> 3] & (1 << ((c) & 7)))
#define SET_DELIM(m, c) ((m)[(c) >> 3] |= (1 << ((c) & 7)))
#define DELIM_TABLE_SIZE 32

// You can set the level of logging by env var SSLTUNNEL_LOG_LEVEL=n, where n
// is 0 through 3. The default is 1, INFO level logging.
44 enum LogLevel { 45 LEVEL_DEBUG = 0, 46 LEVEL_INFO = 1, 47 LEVEL_ERROR = 2, 48 LEVEL_SILENT = 3 49 } gLogLevel, 50 gLastLogLevel; 51 52 #define _LOG_OUTPUT(level, func, params) \ 53 PR_BEGIN_MACRO \ 54 if (level >= gLogLevel) { \ 55 gLastLogLevel = level; \ 56 func params; \ 57 } \ 58 PR_END_MACRO 59 60 // The most verbose output 61 #define LOG_DEBUG(params) _LOG_OUTPUT(LEVEL_DEBUG, printf, params) 62 63 // Top level informative messages 64 #define LOG_INFO(params) _LOG_OUTPUT(LEVEL_INFO, printf, params) 65 66 // Serious errors that must be logged always until completely gag 67 #define LOG_ERROR(params) _LOG_OUTPUT(LEVEL_ERROR, eprintf, params) 68 69 // Same as LOG_ERROR, but when logging is set to LEVEL_DEBUG, the message 70 // will be put to the stdout instead of stderr to keep continuity with other 71 // LOG_DEBUG message output 72 #define LOG_ERRORD(params) \ 73 PR_BEGIN_MACRO \ 74 if (gLogLevel == LEVEL_DEBUG) \ 75 _LOG_OUTPUT(LEVEL_ERROR, printf, params); \ 76 else \ 77 _LOG_OUTPUT(LEVEL_ERROR, eprintf, params); \ 78 PR_END_MACRO 79 80 // If there is any output written between LOG_BEGIN_BLOCK() and 81 // LOG_END_BLOCK() then a new line will be put to the proper output (out/err) 82 #define LOG_BEGIN_BLOCK() gLastLogLevel = LEVEL_SILENT; 83 84 #define LOG_END_BLOCK() \ 85 PR_BEGIN_MACRO \ 86 if (gLastLogLevel == LEVEL_ERROR) LOG_ERROR(("\n")); \ 87 if (gLastLogLevel < LEVEL_ERROR) _LOG_OUTPUT(gLastLogLevel, printf, ("\n")); \ 88 PR_END_MACRO 89 90 int eprintf(const char* str, ...) 
{ 91 va_list ap; 92 va_start(ap, str); 93 int result = vfprintf(stderr, str, ap); 94 va_end(ap); 95 return result; 96 } 97 98 // Copied from nsCRT 99 char* strtok2(char* string, const char* delims, char** newStr) { 100 PR_ASSERT(string); 101 102 char delimTable[DELIM_TABLE_SIZE]; 103 uint32_t i; 104 char* result; 105 char* str = string; 106 107 for (i = 0; i < DELIM_TABLE_SIZE; i++) delimTable[i] = '\0'; 108 109 for (i = 0; delims[i]; i++) { 110 SET_DELIM(delimTable, static_cast<uint8_t>(delims[i])); 111 } 112 113 // skip to beginning 114 while (*str && IS_DELIM(delimTable, static_cast<uint8_t>(*str))) { 115 str++; 116 } 117 result = str; 118 119 // fix up the end of the token 120 while (*str) { 121 if (IS_DELIM(delimTable, static_cast<uint8_t>(*str))) { 122 *str++ = '\0'; 123 break; 124 } 125 str++; 126 } 127 *newStr = str; 128 129 return str == result ? nullptr : result; 130 } 131 132 enum client_auth_option { caNone = 0, caRequire = 1, caRequest = 2 }; 133 134 // Structs for passing data into jobs on the thread pool 135 struct server_info_t { 136 int32_t listen_port; 137 string cert_nickname; 138 PLHashTable* host_cert_table; 139 PLHashTable* host_clientauth_table; 140 PLHashTable* host_redir_table; 141 PLHashTable* host_ssl3_table; 142 PLHashTable* host_tls1_table; 143 PLHashTable* host_tls11_table; 144 PLHashTable* host_tls12_table; 145 PLHashTable* host_tls13_table; 146 PLHashTable* host_3des_table; 147 PLHashTable* host_failhandshake_table; 148 }; 149 150 struct connection_info_t { 151 PRFileDesc* client_sock; 152 PRNetAddr client_addr; 153 server_info_t* server_info; 154 // the original host in the Host: header for this connection is 155 // stored here, for proxied connections 156 string original_host; 157 // true if no SSL should be used for this connection 158 bool http_proxy_only; 159 // true if this connection is for a WebSocket 160 bool iswebsocket; 161 }; 162 163 struct server_match_t { 164 string fullHost; 165 bool matched; 166 }; 167 168 const 
int32_t BUF_SIZE = 16384; 169 const int32_t BUF_MARGIN = 1024; 170 const int32_t BUF_TOTAL = BUF_SIZE + BUF_MARGIN; 171 172 struct relayBuffer { 173 char *buffer, *bufferhead, *buffertail, *bufferend; 174 175 relayBuffer() { 176 // Leave 1024 bytes more for request line manipulations 177 bufferhead = buffertail = buffer = new char[BUF_TOTAL]; 178 bufferend = buffer + BUF_SIZE; 179 } 180 181 ~relayBuffer() { delete[] buffer; } 182 183 void compact() { 184 if (buffertail == bufferhead) buffertail = bufferhead = buffer; 185 } 186 187 bool empty() { return bufferhead == buffertail; } 188 size_t areafree() { return bufferend - buffertail; } 189 size_t margin() { return areafree() + BUF_MARGIN; } 190 size_t present() { return buffertail - bufferhead; } 191 }; 192 193 // These numbers are multiplied by the number of listening ports (actual 194 // servers running). According the thread pool implementation there is no 195 // need to limit the number of threads initially, threads are allocated 196 // dynamically and stored in a linked list. Initial number of 2 is chosen 197 // to allocate a thread for socket accept and preallocate one for the first 198 // connection that is with high probability expected to come. 
199 const uint32_t INITIAL_THREADS = 2; 200 const uint32_t MAX_THREADS = 100; 201 const uint32_t DEFAULT_STACKSIZE = (512 * 1024); 202 203 // global data 204 MOZ_RUNINIT string nssconfigdir; 205 MOZ_RUNINIT vector<server_info_t> servers; 206 PRNetAddr remote_addr; 207 PRNetAddr websocket_server; 208 PRThreadPool* threads = nullptr; 209 PRLock* shutdown_lock = nullptr; 210 PRCondVar* shutdown_condvar = nullptr; 211 // Not really used, unless something fails to start 212 bool shutdown_server = false; 213 bool do_http_proxy = false; 214 bool any_host_spec_config = false; 215 bool listen_public = false; 216 217 int ClientAuthValueComparator(const void* v1, const void* v2) { 218 int a = *static_cast<const client_auth_option*>(v1) - 219 *static_cast<const client_auth_option*>(v2); 220 if (a == 0) return 0; 221 if (a > 0) return 1; 222 // (a < 0) 223 return -1; 224 } 225 226 static int match_hostname(PLHashEntry* he, int index, void* arg) { 227 server_match_t* match = (server_match_t*)arg; 228 if (match->fullHost.find((char*)he->key) != string::npos) 229 match->matched = true; 230 return HT_ENUMERATE_NEXT; 231 } 232 233 /* 234 * Signal the main thread that the application should shut down. 235 */ 236 void SignalShutdown() { 237 PR_Lock(shutdown_lock); 238 PR_NotifyCondVar(shutdown_condvar); 239 PR_Unlock(shutdown_lock); 240 } 241 242 // available flags 243 enum { 244 USE_SSL3 = 1 << 0, 245 USE_3DES = 1 << 1, 246 FAIL_HANDSHAKE = 1 << 2, 247 USE_TLS1 = 1 << 3, 248 USE_TLS1_1 = 1 << 4, 249 USE_TLS1_2 = 1 << 5, 250 USE_TLS1_3 = 1 << 6 251 }; 252 253 bool ReadConnectRequest(server_info_t* server_info, relayBuffer& buffer, 254 int32_t* result, string& certificate, 255 client_auth_option* clientauth, string& host, 256 string& location, int32_t* flags) { 257 if (buffer.present() < 4) { 258 LOG_DEBUG( 259 (" !! only %d bytes present in the buffer", (int)buffer.present())); 260 return false; 261 } 262 if (strncmp(buffer.buffertail - 4, "\r\n\r\n", 4)) { 263 LOG_ERRORD((" !! 
request is not tailed with CRLFCRLF but with %x %x %x %x", 264 *(buffer.buffertail - 4), *(buffer.buffertail - 3), 265 *(buffer.buffertail - 2), *(buffer.buffertail - 1))); 266 return false; 267 } 268 269 LOG_DEBUG((" parsing initial connect request, dump:\n%.*s\n", 270 (int)buffer.present(), buffer.bufferhead)); 271 272 *result = 400; 273 274 char* token; 275 char* _caret; 276 token = strtok2(buffer.bufferhead, " ", &_caret); 277 if (!token) { 278 LOG_ERRORD((" no space found")); 279 return true; 280 } 281 if (strcmp(token, "CONNECT")) { 282 LOG_ERRORD((" not CONNECT request but %s", token)); 283 return true; 284 } 285 286 token = strtok2(_caret, " ", &_caret); 287 void* c = PL_HashTableLookup(server_info->host_cert_table, token); 288 if (c) certificate = static_cast<char*>(c); 289 290 host = "https://"; 291 host += token; 292 293 c = PL_HashTableLookup(server_info->host_clientauth_table, token); 294 if (c) 295 *clientauth = *static_cast<client_auth_option*>(c); 296 else 297 *clientauth = caNone; 298 299 void* redir = PL_HashTableLookup(server_info->host_redir_table, token); 300 if (redir) location = static_cast<char*>(redir); 301 302 if (PL_HashTableLookup(server_info->host_ssl3_table, token)) { 303 *flags |= USE_SSL3; 304 } 305 306 if (PL_HashTableLookup(server_info->host_3des_table, token)) { 307 *flags |= USE_3DES; 308 } 309 310 if (PL_HashTableLookup(server_info->host_tls1_table, token)) { 311 *flags |= USE_TLS1; 312 } 313 314 if (PL_HashTableLookup(server_info->host_tls11_table, token)) { 315 *flags |= USE_TLS1_1; 316 } 317 318 if (PL_HashTableLookup(server_info->host_tls12_table, token)) { 319 *flags |= USE_TLS1_2; 320 } 321 322 if (PL_HashTableLookup(server_info->host_tls13_table, token)) { 323 *flags |= USE_TLS1_3; 324 } 325 326 if (PL_HashTableLookup(server_info->host_failhandshake_table, token)) { 327 *flags |= FAIL_HANDSHAKE; 328 } 329 330 token = strtok2(_caret, "/", &_caret); 331 if (strcmp(token, "HTTP")) { 332 LOG_ERRORD((" not tailed with HTTP but 
with %s", token)); 333 return true; 334 } 335 336 *result = (redir) ? 302 : 200; 337 return true; 338 } 339 340 bool ConfigureSSLServerSocket(PRFileDesc* socket, server_info_t* si, 341 const string& certificate, 342 const client_auth_option clientAuth, 343 int32_t flags) { 344 const char* certnick = 345 certificate.empty() ? si->cert_nickname.c_str() : certificate.c_str(); 346 347 UniqueCERTCertificate cert(PK11_FindCertFromNickname(certnick, nullptr)); 348 if (!cert) { 349 LOG_ERROR(("Failed to find cert %s\n", certnick)); 350 return false; 351 } 352 353 UniqueSECKEYPrivateKey privKey(PK11_FindKeyByAnyCert(cert.get(), nullptr)); 354 if (!privKey) { 355 LOG_ERROR(("Failed to find private key\n")); 356 return false; 357 } 358 359 PRFileDesc* ssl_socket = SSL_ImportFD(nullptr, socket); 360 if (!ssl_socket) { 361 LOG_ERROR(("Error importing SSL socket\n")); 362 return false; 363 } 364 365 if (flags & FAIL_HANDSHAKE) { 366 // deliberately cause handshake to fail by sending the client a client hello 367 SSL_ResetHandshake(ssl_socket, false); 368 return true; 369 } 370 371 SSLKEAType certKEA = NSS_FindCertKEAType(cert.get()); 372 if (SSL_ConfigSecureServer(ssl_socket, cert.get(), privKey.get(), certKEA) != 373 SECSuccess) { 374 LOG_ERROR(("Error configuring SSL server socket\n")); 375 return false; 376 } 377 378 SSL_OptionSet(ssl_socket, SSL_SECURITY, true); 379 SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_CLIENT, false); 380 SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_SERVER, true); 381 SSL_OptionSet(ssl_socket, SSL_ENABLE_SESSION_TICKETS, true); 382 383 if (clientAuth != caNone) { 384 // If we're requesting or requiring a client certificate, we should 385 // configure NSS to include the "certificate_authorities" field in the 386 // certificate request message. That way we can test that gecko properly 387 // takes note of it. 
388 UniqueCERTCertificate issuer( 389 CERT_FindCertIssuer(cert.get(), PR_Now(), certUsageAnyCA)); 390 if (!issuer) { 391 LOG_DEBUG(("Failed to find issuer for %s\n", certnick)); 392 return false; 393 } 394 UniqueCERTCertList issuerList(CERT_NewCertList()); 395 if (!issuerList) { 396 LOG_ERROR(("Failed to allocate new CERTCertList\n")); 397 return false; 398 } 399 if (CERT_AddCertToListTail(issuerList.get(), issuer.get()) != SECSuccess) { 400 LOG_ERROR(("Failed to add issuer to issuerList\n")); 401 return false; 402 } 403 (void)issuer.release(); // Ownership transferred to issuerList. 404 if (SSL_SetTrustAnchors(ssl_socket, issuerList.get()) != SECSuccess) { 405 LOG_ERROR( 406 ("Failed to set certificate_authorities list for client " 407 "authentication\n")); 408 return false; 409 } 410 SSL_OptionSet(ssl_socket, SSL_REQUEST_CERTIFICATE, true); 411 SSL_OptionSet(ssl_socket, SSL_REQUIRE_CERTIFICATE, clientAuth == caRequire); 412 } 413 414 SSLVersionRange range = {SSL_LIBRARY_VERSION_TLS_1_3, 415 SSL_LIBRARY_VERSION_3_0}; 416 if (flags & USE_SSL3) { 417 range.min = PR_MIN(range.min, SSL_LIBRARY_VERSION_3_0); 418 range.max = PR_MAX(range.max, SSL_LIBRARY_VERSION_3_0); 419 } 420 if (flags & USE_TLS1) { 421 range.min = PR_MIN(range.min, SSL_LIBRARY_VERSION_TLS_1_0); 422 range.max = PR_MAX(range.max, SSL_LIBRARY_VERSION_TLS_1_0); 423 } 424 if (flags & USE_TLS1_1) { 425 range.min = PR_MIN(range.min, SSL_LIBRARY_VERSION_TLS_1_1); 426 range.max = PR_MAX(range.max, SSL_LIBRARY_VERSION_TLS_1_1); 427 } 428 if (flags & USE_TLS1_2) { 429 range.min = PR_MIN(range.min, SSL_LIBRARY_VERSION_TLS_1_2); 430 range.max = PR_MAX(range.max, SSL_LIBRARY_VERSION_TLS_1_2); 431 } 432 if (flags & USE_TLS1_3) { 433 range.min = PR_MIN(range.min, SSL_LIBRARY_VERSION_TLS_1_3); 434 range.max = PR_MAX(range.max, SSL_LIBRARY_VERSION_TLS_1_3); 435 } 436 // Set the valid range, if any were specified (if not, skip 437 // when the default range is invalid, i.e. 
max > min) 438 if (range.min <= range.max && 439 SSL_VersionRangeSet(ssl_socket, &range) != SECSuccess) { 440 LOG_ERROR(("Error configuring SSL socket version range\n")); 441 return false; 442 } 443 444 if (flags & USE_3DES) { 445 for (uint16_t i = 0; i < SSL_NumImplementedCiphers; ++i) { 446 uint16_t cipher_id = SSL_ImplementedCiphers[i]; 447 if (cipher_id == TLS_RSA_WITH_3DES_EDE_CBC_SHA) { 448 SSL_CipherPrefSet(ssl_socket, cipher_id, true); 449 } else { 450 SSL_CipherPrefSet(ssl_socket, cipher_id, false); 451 } 452 } 453 } 454 455 SSL_ResetHandshake(ssl_socket, true); 456 457 return true; 458 } 459 460 /** 461 * This function examines the buffer for a Sec-WebSocket-Location: field, 462 * and if it's present, it replaces the hostname in that field with the 463 * value in the server's original_host field. This function works 464 * in the reverse direction as AdjustWebSocketHost(), replacing the real 465 * hostname of a response with the potentially fake hostname that is expected 466 * by the browser (e.g., mochi.test). 
467 * 468 * @return true if the header was adjusted successfully, or not found, false 469 * if the header is present but the url is not, which should indicate 470 * that more data needs to be read from the socket 471 */ 472 bool AdjustWebSocketLocation(relayBuffer& buffer, connection_info_t* ci) { 473 assert(buffer.margin()); 474 buffer.buffertail[1] = '\0'; 475 476 char* wsloc_header = strstr(buffer.bufferhead, "Sec-WebSocket-Location:"); 477 if (!wsloc_header) { 478 return true; 479 } 480 // advance pointer to the start of the hostname 481 char* wsloc = strstr(wsloc_header, "ws://"); 482 if (!wsloc) { 483 wsloc = strstr(wsloc_header, "wss://"); 484 } 485 if (!wsloc) return false; 486 wsloc += 5; 487 // find the end of the hostname 488 char* wslocend = strchr(wsloc + 1, '/'); 489 if (!wslocend) return false; 490 char* crlf = strstr(wsloc, "\r\n"); 491 if (!crlf) return false; 492 if (ci->original_host.empty()) return true; 493 494 int diff = ci->original_host.length() - (wslocend - wsloc); 495 if (diff > 0) assert(size_t(diff) <= buffer.margin()); 496 memmove(wslocend + diff, wslocend, buffer.buffertail - wsloc - diff); 497 buffer.buffertail += diff; 498 499 memcpy(wsloc, ci->original_host.c_str(), ci->original_host.length()); 500 return true; 501 } 502 503 /** 504 * This function examines the buffer for a Host: field, and if it's present, 505 * it replaces the hostname in that field with the hostname in the server's 506 * remote_addr field. This is needed because proxy requests may be coming 507 * from mochitest with fake hosts, like mochi.test, and these need to be 508 * replaced with the host that the destination server is actually running 509 * on. 510 */ 511 bool AdjustWebSocketHost(relayBuffer& buffer, connection_info_t* ci) { 512 const char HEADER_UPGRADE[] = "Upgrade:"; 513 const char HEADER_HOST[] = "Host:"; 514 515 PRNetAddr inet_addr = 516 (websocket_server.inet.port ? 
websocket_server : remote_addr); 517 518 assert(buffer.margin()); 519 520 // Cannot use strnchr so add a null char at the end. There is always some 521 // space left because we preserve a margin. 522 buffer.buffertail[1] = '\0'; 523 524 // Verify this is a WebSocket header. 525 char* h1 = strstr(buffer.bufferhead, HEADER_UPGRADE); 526 if (!h1) return false; 527 h1 += strlen(HEADER_UPGRADE); 528 h1 += strspn(h1, " \t"); 529 char* h2 = strstr(h1, "WebSocket\r\n"); 530 if (!h2) h2 = strstr(h1, "websocket\r\n"); 531 if (!h2) h2 = strstr(h1, "Websocket\r\n"); 532 if (!h2) return false; 533 534 char* host = strstr(buffer.bufferhead, HEADER_HOST); 535 if (!host) return false; 536 // advance pointer to beginning of hostname 537 host += strlen(HEADER_HOST); 538 host += strspn(host, " \t"); 539 540 char* endhost = strstr(host, "\r\n"); 541 if (!endhost) return false; 542 543 // Save the original host, so we can use it later on responses from the 544 // server. 545 ci->original_host.assign(host, endhost - host); 546 547 char newhost[40]; 548 PR_NetAddrToString(&inet_addr, newhost, sizeof(newhost)); 549 assert(strlen(newhost) < sizeof(newhost) - 7); 550 SprintfLiteral(newhost, "%s:%d", newhost, PR_ntohs(inet_addr.inet.port)); 551 552 int diff = strlen(newhost) - (endhost - host); 553 if (diff > 0) assert(size_t(diff) <= buffer.margin()); 554 memmove(endhost + diff, endhost, buffer.buffertail - host - diff); 555 buffer.buffertail += diff; 556 557 memcpy(host, newhost, strlen(newhost)); 558 return true; 559 } 560 561 /** 562 * This function prefixes Request-URI path with a full scheme-host-port 563 * string. 564 */ 565 bool AdjustRequestURI(relayBuffer& buffer, string* host) { 566 assert(buffer.margin()); 567 568 // Cannot use strnchr so add a null char at the end. There is always some 569 // space left because we preserve a margin. 
570 buffer.buffertail[1] = '\0'; 571 LOG_DEBUG((" incoming request to adjust:\n%s\n", buffer.bufferhead)); 572 573 char *token, *path; 574 path = strchr(buffer.bufferhead, ' ') + 1; 575 if (!path) return false; 576 577 // If the path doesn't start with a slash don't change it, it is probably '*' 578 // or a full path already. Return true, we are done with this request 579 // adjustment. 580 if (*path != '/') return true; 581 582 token = strchr(path, ' ') + 1; 583 if (!token) return false; 584 585 if (strncmp(token, "HTTP/", 5)) return false; 586 587 size_t hostlength = host->length(); 588 assert(hostlength <= buffer.margin()); 589 590 memmove(path + hostlength, path, buffer.buffertail - path); 591 memcpy(path, host->c_str(), hostlength); 592 buffer.buffertail += hostlength; 593 594 return true; 595 } 596 597 bool ConnectSocket(UniquePRFileDesc& fd, const PRNetAddr* addr, 598 PRIntervalTime timeout) { 599 PRStatus stat = PR_Connect(fd.get(), addr, timeout); 600 if (stat != PR_SUCCESS) return false; 601 602 PRSocketOptionData option; 603 option.option = PR_SockOpt_Nonblocking; 604 option.value.non_blocking = true; 605 PR_SetSocketOption(fd.get(), &option); 606 607 return true; 608 } 609 610 /* 611 * Handle an incoming client connection. The server thread has already 612 * accepted the connection, so we just need to connect to the remote 613 * port and then proxy data back and forth. 614 * The data parameter is a connection_info_t*, and must be deleted 615 * by this function. 
616 */ 617 void HandleConnection(void* data) { 618 connection_info_t* ci = static_cast<connection_info_t*>(data); 619 PRIntervalTime connect_timeout = PR_SecondsToInterval(30); 620 621 UniquePRFileDesc other_sock(PR_NewTCPSocket()); 622 bool client_done = false; 623 bool client_error = false; 624 bool connect_accepted = !do_http_proxy; 625 bool ssl_updated = !do_http_proxy; 626 bool expect_request_start = do_http_proxy; 627 string certificateToUse; 628 string locationHeader; 629 client_auth_option clientAuth; 630 string fullHost; 631 int32_t flags = 0; 632 633 LOG_DEBUG(("SSLTUNNEL(%p)): incoming connection csock(0)=%p, ssock(1)=%p\n", 634 static_cast<void*>(data), static_cast<void*>(ci->client_sock), 635 static_cast<void*>(other_sock.get()))); 636 if (other_sock) { 637 int32_t numberOfSockets = 1; 638 639 relayBuffer buffers[2]; 640 641 if (!do_http_proxy) { 642 if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info, 643 certificateToUse, caNone, flags)) 644 client_error = true; 645 else if (!ConnectSocket(other_sock, &remote_addr, connect_timeout)) 646 client_error = true; 647 else 648 numberOfSockets = 2; 649 } 650 651 PRPollDesc sockets[2] = {{ci->client_sock, PR_POLL_READ, 0}, 652 {other_sock.get(), PR_POLL_READ, 0}}; 653 bool socketErrorState[2] = {false, false}; 654 655 while (!((client_error || client_done) && buffers[0].empty() && 656 buffers[1].empty())) { 657 sockets[0].in_flags |= PR_POLL_EXCEPT; 658 sockets[1].in_flags |= PR_POLL_EXCEPT; 659 LOG_DEBUG(("SSLTUNNEL(%p)): polling flags csock(0)=%c%c, ssock(1)=%c%c\n", 660 static_cast<void*>(data), 661 sockets[0].in_flags & PR_POLL_READ ? 'R' : '-', 662 sockets[0].in_flags & PR_POLL_WRITE ? 'W' : '-', 663 sockets[1].in_flags & PR_POLL_READ ? 'R' : '-', 664 sockets[1].in_flags & PR_POLL_WRITE ? 
'W' : '-')); 665 int32_t pollStatus = 666 PR_Poll(sockets, numberOfSockets, PR_MillisecondsToInterval(1000)); 667 if (pollStatus < 0) { 668 LOG_DEBUG(("SSLTUNNEL(%p)): pollStatus=%d, exiting\n", 669 static_cast<void*>(data), pollStatus)); 670 client_error = true; 671 break; 672 } 673 674 if (pollStatus == 0) { 675 // timeout 676 LOG_DEBUG(("SSLTUNNEL(%p)): poll timeout, looping\n", 677 static_cast<void*>(data))); 678 continue; 679 } 680 681 for (int32_t s = 0; s < numberOfSockets; ++s) { 682 int32_t s2 = s == 1 ? 0 : 1; 683 int16_t out_flags = sockets[s].out_flags; 684 int16_t& in_flags = sockets[s].in_flags; 685 int16_t& in_flags2 = sockets[s2].in_flags; 686 sockets[s].out_flags = 0; 687 688 LOG_BEGIN_BLOCK(); 689 LOG_DEBUG(("SSLTUNNEL(%p)): %csock(%d)=%p out_flags=%d", 690 static_cast<void*>(data), s == 0 ? 'c' : 's', s, 691 static_cast<void*>(sockets[s].fd), out_flags)); 692 if (out_flags & (PR_POLL_EXCEPT | PR_POLL_ERR | PR_POLL_HUP)) { 693 LOG_DEBUG((" :exception\n")); 694 client_error = true; 695 socketErrorState[s] = true; 696 // We got a fatal error state on the socket. Clear the output buffer 697 // for this socket to break the main loop, we will never more be able 698 // to send those data anyway. 
699 buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer; 700 continue; 701 } // PR_POLL_EXCEPT, PR_POLL_ERR, PR_POLL_HUP handling 702 703 if (out_flags & PR_POLL_READ && !buffers[s].areafree()) { 704 LOG_DEBUG( 705 (" no place in read buffer but got read flag, dropping it now!")); 706 in_flags &= ~PR_POLL_READ; 707 } 708 709 if (out_flags & PR_POLL_READ && buffers[s].areafree()) { 710 LOG_DEBUG((" :reading")); 711 int32_t bytesRead = 712 PR_Recv(sockets[s].fd, buffers[s].buffertail, 713 buffers[s].areafree(), 0, PR_INTERVAL_NO_TIMEOUT); 714 715 if (bytesRead == 0) { 716 LOG_DEBUG((" socket gracefully closed")); 717 client_done = true; 718 in_flags &= ~PR_POLL_READ; 719 } else if (bytesRead < 0) { 720 if (PR_GetError() != PR_WOULD_BLOCK_ERROR) { 721 LOG_DEBUG((" error=%d", PR_GetError())); 722 // We are in error state, indicate that the connection was 723 // not closed gracefully 724 client_error = true; 725 socketErrorState[s] = true; 726 // Wipe out our send buffer, we cannot send it anyway. 727 buffers[s2].bufferhead = buffers[s2].buffertail = 728 buffers[s2].buffer; 729 } else 730 LOG_DEBUG((" would block")); 731 } else { 732 // If the other socket is in error state (unable to send/receive) 733 // throw this data away and continue loop 734 if (socketErrorState[s2]) { 735 LOG_DEBUG((" have read but other socket is in error state\n")); 736 continue; 737 } 738 739 buffers[s].buffertail += bytesRead; 740 LOG_DEBUG((", read %d bytes", bytesRead)); 741 742 // We have to accept and handle the initial CONNECT request here 743 int32_t response; 744 if (!connect_accepted && 745 ReadConnectRequest(ci->server_info, buffers[s], &response, 746 certificateToUse, &clientAuth, fullHost, 747 locationHeader, &flags)) { 748 // Mark this as a proxy-only connection (no SSL) if the CONNECT 749 // request didn't come for port 443 or from any of the server's 750 // cert or clientauth hostnames. 
751 if (fullHost.find(":443") == string::npos) { 752 server_match_t match; 753 match.fullHost = fullHost; 754 match.matched = false; 755 PL_HashTableEnumerateEntries(ci->server_info->host_cert_table, 756 match_hostname, &match); 757 PL_HashTableEnumerateEntries( 758 ci->server_info->host_clientauth_table, match_hostname, 759 &match); 760 PL_HashTableEnumerateEntries(ci->server_info->host_ssl3_table, 761 match_hostname, &match); 762 PL_HashTableEnumerateEntries(ci->server_info->host_tls1_table, 763 match_hostname, &match); 764 PL_HashTableEnumerateEntries(ci->server_info->host_tls11_table, 765 match_hostname, &match); 766 PL_HashTableEnumerateEntries(ci->server_info->host_tls12_table, 767 match_hostname, &match); 768 PL_HashTableEnumerateEntries(ci->server_info->host_tls13_table, 769 match_hostname, &match); 770 PL_HashTableEnumerateEntries(ci->server_info->host_3des_table, 771 match_hostname, &match); 772 PL_HashTableEnumerateEntries( 773 ci->server_info->host_failhandshake_table, match_hostname, 774 &match); 775 ci->http_proxy_only = !match.matched; 776 } else { 777 ci->http_proxy_only = false; 778 } 779 780 // Clean the request as it would be read 781 buffers[s].bufferhead = buffers[s].buffertail = buffers[s].buffer; 782 in_flags |= PR_POLL_WRITE; 783 connect_accepted = true; 784 785 // Store response to the oposite buffer 786 if (response == 200) { 787 LOG_DEBUG( 788 (" accepted CONNECT request, connected to the server, " 789 "sending OK to the client\n")); 790 strcpy( 791 buffers[s2].buffer, 792 "HTTP/1.1 200 Connected\r\nConnection: keep-alive\r\n\r\n"); 793 } else if (response == 302) { 794 LOG_DEBUG( 795 (" accepted CONNECT request with redirection, " 796 "sending location and 302 to the client\n")); 797 client_done = true; 798 snprintf(buffers[s2].buffer, 799 buffers[s2].bufferend - buffers[s2].buffer, 800 "HTTP/1.1 302 Moved\r\n" 801 "Location: https://%s/\r\n" 802 "Connection: close\r\n\r\n", 803 locationHeader.c_str()); 804 } else { 805 LOG_ERRORD( 806 
(" could not read the connect request, closing connection " 807 "with %d", 808 response)); 809 client_done = true; 810 snprintf(buffers[s2].buffer, 811 buffers[s2].bufferend - buffers[s2].buffer, 812 "HTTP/1.1 %d ERROR\r\nConnection: close\r\n\r\n", 813 response); 814 815 break; 816 } 817 818 buffers[s2].buffertail = 819 buffers[s2].buffer + strlen(buffers[s2].buffer); 820 821 // Send the response to the client socket 822 break; 823 } // end of CONNECT handling 824 825 if (!buffers[s].areafree()) { 826 // Do not poll for read when the buffer is full 827 LOG_DEBUG((" no place in our read buffer, stop reading")); 828 in_flags &= ~PR_POLL_READ; 829 } 830 831 if (ssl_updated) { 832 if (s == 0 && expect_request_start) { 833 if (!strstr(buffers[s].bufferhead, "\r\n\r\n")) { 834 // We haven't received the complete header yet, so wait. 835 continue; 836 } 837 ci->iswebsocket = AdjustWebSocketHost(buffers[s], ci); 838 expect_request_start = !( 839 ci->iswebsocket || AdjustRequestURI(buffers[s], &fullHost)); 840 PRNetAddr* addr = &remote_addr; 841 if (ci->iswebsocket && websocket_server.inet.port) 842 addr = &websocket_server; 843 if (!ConnectSocket(other_sock, addr, connect_timeout)) { 844 LOG_ERRORD( 845 (" could not open connection to the real server\n")); 846 client_error = true; 847 break; 848 } 849 LOG_DEBUG(("\n connected to remote server\n")); 850 numberOfSockets = 2; 851 } else if (s == 1 && ci->iswebsocket) { 852 if (!AdjustWebSocketLocation(buffers[s], ci)) continue; 853 } 854 855 in_flags2 |= PR_POLL_WRITE; 856 LOG_DEBUG((" telling the other socket to write")); 857 } else 858 LOG_DEBUG( 859 (" we have something for the other socket to write, but ssl " 860 "has not been administered on it")); 861 } 862 } // PR_POLL_READ handling 863 864 if (out_flags & PR_POLL_WRITE) { 865 LOG_DEBUG((" :writing")); 866 int32_t bytesWrite = 867 PR_Send(sockets[s].fd, buffers[s2].bufferhead, 868 buffers[s2].present(), 0, PR_INTERVAL_NO_TIMEOUT); 869 870 if (bytesWrite < 0) { 871 if 
(PR_GetError() != PR_WOULD_BLOCK_ERROR) { 872 LOG_DEBUG((" error=%d", PR_GetError())); 873 client_error = true; 874 socketErrorState[s] = true; 875 // We got a fatal error while writting the buffer. Clear it to 876 // break the main loop, we will never more be able to send it. 877 buffers[s2].bufferhead = buffers[s2].buffertail = 878 buffers[s2].buffer; 879 } else 880 LOG_DEBUG((" would block")); 881 } else { 882 LOG_DEBUG((", written %d bytes", bytesWrite)); 883 buffers[s2].buffertail[1] = '\0'; 884 LOG_DEBUG((" dump:\n%.*s\n", bytesWrite, buffers[s2].bufferhead)); 885 886 buffers[s2].bufferhead += bytesWrite; 887 if (buffers[s2].present()) { 888 LOG_DEBUG((" still have to write %d bytes", 889 (int)buffers[s2].present())); 890 in_flags |= PR_POLL_WRITE; 891 } else { 892 if (!ssl_updated) { 893 LOG_DEBUG((" proxy response sent to the client")); 894 // Proxy response has just been writen, update to ssl 895 ssl_updated = true; 896 if (ci->http_proxy_only) { 897 LOG_DEBUG( 898 (" not updating to SSL based on http_proxy_only for this " 899 "socket")); 900 } else if (!ConfigureSSLServerSocket( 901 ci->client_sock, ci->server_info, 902 certificateToUse, clientAuth, flags)) { 903 LOG_ERRORD((" failed to config server socket\n")); 904 client_error = true; 905 break; 906 } else { 907 LOG_DEBUG((" client socket updated to SSL")); 908 } 909 } // sslUpdate 910 911 LOG_DEBUG( 912 (" dropping our write flag and setting other socket read " 913 "flag")); 914 in_flags &= ~PR_POLL_WRITE; 915 in_flags2 |= PR_POLL_READ; 916 buffers[s2].compact(); 917 } 918 } 919 } // PR_POLL_WRITE handling 920 LOG_END_BLOCK(); // end the log 921 } // for... 
922 } // while, poll 923 } else 924 client_error = true; 925 926 LOG_DEBUG(("SSLTUNNEL(%p)): exiting root function for csock=%p, ssock=%p\n", 927 static_cast<void*>(data), static_cast<void*>(ci->client_sock), 928 static_cast<void*>(other_sock.get()))); 929 if (!client_error) PR_Shutdown(ci->client_sock, PR_SHUTDOWN_SEND); 930 PR_Close(ci->client_sock); 931 932 delete ci; 933 } 934 935 /* 936 * Start listening for SSL connections on a specified port, handing 937 * them off to client threads after accepting the connection. 938 * The data parameter is a server_info_t*, owned by the calling 939 * function. 940 */ 941 void StartServer(void* data) { 942 server_info_t* si = static_cast<server_info_t*>(data); 943 944 // TODO: select ciphers? 945 UniquePRFileDesc listen_socket(PR_NewTCPSocket()); 946 if (!listen_socket) { 947 LOG_ERROR(("failed to create socket\n")); 948 SignalShutdown(); 949 return; 950 } 951 952 // In case the socket is still open in the TIME_WAIT state from a previous 953 // instance of ssltunnel we ask to reuse the port. 
954 PRSocketOptionData socket_option; 955 socket_option.option = PR_SockOpt_Reuseaddr; 956 socket_option.value.reuse_addr = true; 957 PR_SetSocketOption(listen_socket.get(), &socket_option); 958 959 PRNetAddr server_addr; 960 PRNetAddrValue listen_addr; 961 if (listen_public) { 962 listen_addr = PR_IpAddrAny; 963 } else { 964 listen_addr = PR_IpAddrLoopback; 965 } 966 PR_InitializeNetAddr(listen_addr, si->listen_port, &server_addr); 967 968 if (PR_Bind(listen_socket.get(), &server_addr) != PR_SUCCESS) { 969 LOG_ERROR(("failed to bind socket on port %d: error %d\n", si->listen_port, 970 PR_GetError())); 971 SignalShutdown(); 972 return; 973 } 974 975 if (PR_Listen(listen_socket.get(), 1) != PR_SUCCESS) { 976 LOG_ERROR(("failed to listen on socket\n")); 977 SignalShutdown(); 978 return; 979 } 980 981 LOG_INFO(("Server listening on port %d with cert %s\n", si->listen_port, 982 si->cert_nickname.c_str())); 983 984 while (!shutdown_server) { 985 connection_info_t* ci = new connection_info_t(); 986 ci->server_info = si; 987 ci->http_proxy_only = do_http_proxy; 988 // block waiting for connections 989 ci->client_sock = PR_Accept(listen_socket.get(), &ci->client_addr, 990 PR_INTERVAL_NO_TIMEOUT); 991 992 PRSocketOptionData option; 993 option.option = PR_SockOpt_Nonblocking; 994 option.value.non_blocking = true; 995 PR_SetSocketOption(ci->client_sock, &option); 996 997 if (ci->client_sock) 998 // Not actually using this PRJob*... 999 // PRJob* job = 1000 PR_QueueJob(threads, HandleConnection, ci, true); 1001 else 1002 delete ci; 1003 } 1004 } 1005 1006 // bogus password func, just don't use passwords. 
:-P 1007 char* password_func(PK11SlotInfo* slot, PRBool retry, void* arg) { 1008 if (retry) return nullptr; 1009 1010 return PL_strdup(""); 1011 } 1012 1013 server_info_t* findServerInfo(int portnumber) { 1014 for (auto& server : servers) { 1015 if (server.listen_port == portnumber) return &server; 1016 } 1017 1018 return nullptr; 1019 } 1020 1021 PLHashTable* get_ssl3_table(server_info_t* server) { 1022 return server->host_ssl3_table; 1023 } 1024 1025 PLHashTable* get_tls1_table(server_info_t* server) { 1026 return server->host_tls1_table; 1027 } 1028 1029 PLHashTable* get_tls11_table(server_info_t* server) { 1030 return server->host_tls11_table; 1031 } 1032 1033 PLHashTable* get_tls12_table(server_info_t* server) { 1034 return server->host_tls12_table; 1035 } 1036 1037 PLHashTable* get_tls13_table(server_info_t* server) { 1038 return server->host_tls13_table; 1039 } 1040 1041 PLHashTable* get_3des_table(server_info_t* server) { 1042 return server->host_3des_table; 1043 } 1044 1045 PLHashTable* get_failhandshake_table(server_info_t* server) { 1046 return server->host_failhandshake_table; 1047 } 1048 1049 int parseWeakCryptoConfig(char* const& keyword, char*& _caret, 1050 PLHashTable* (*get_table)(server_info_t*)) { 1051 char* hostname = strtok2(_caret, ":", &_caret); 1052 char* hostportstring = strtok2(_caret, ":", &_caret); 1053 char* serverportstring = strtok2(_caret, "\n", &_caret); 1054 1055 int port = atoi(serverportstring); 1056 if (port <= 0) { 1057 LOG_ERROR(("Invalid port specified: %s\n", serverportstring)); 1058 return 1; 1059 } 1060 1061 if (server_info_t* existingServer = findServerInfo(port)) { 1062 any_host_spec_config = true; 1063 1064 char* hostname_copy = 1065 new char[strlen(hostname) + strlen(hostportstring) + 2]; 1066 if (!hostname_copy) { 1067 LOG_ERROR(("Out of memory")); 1068 return 1; 1069 } 1070 1071 strcpy(hostname_copy, hostname); 1072 strcat(hostname_copy, ":"); 1073 strcat(hostname_copy, hostportstring); 1074 1075 PLHashEntry* entry = 
1076 PL_HashTableAdd(get_table(existingServer), hostname_copy, keyword); 1077 if (!entry) { 1078 LOG_ERROR(("Out of memory")); 1079 return 1; 1080 } 1081 } else { 1082 LOG_ERROR( 1083 ("Server on port %d for redirhost option is not defined, use 'listen' " 1084 "option first", 1085 port)); 1086 return 1; 1087 } 1088 1089 return 0; 1090 } 1091 1092 int processConfigLine(char* configLine) { 1093 if (*configLine == 0 || *configLine == '#') return 0; 1094 1095 char* _caret; 1096 char* keyword = strtok2(configLine, ":", &_caret); 1097 // Configure usage of http/ssl tunneling proxy behavior 1098 if (!strcmp(keyword, "httpproxy")) { 1099 char* value = strtok2(_caret, ":", &_caret); 1100 if (!strcmp(value, "1")) do_http_proxy = true; 1101 1102 return 0; 1103 } 1104 1105 if (!strcmp(keyword, "websocketserver")) { 1106 char* ipstring = strtok2(_caret, ":", &_caret); 1107 if (PR_StringToNetAddr(ipstring, &websocket_server) != PR_SUCCESS) { 1108 LOG_ERROR(("Invalid IP address in proxy config: %s\n", ipstring)); 1109 return 1; 1110 } 1111 char* remoteport = strtok2(_caret, ":", &_caret); 1112 int port = atoi(remoteport); 1113 if (port <= 0) { 1114 LOG_ERROR(("Invalid remote port in proxy config: %s\n", remoteport)); 1115 return 1; 1116 } 1117 websocket_server.inet.port = PR_htons(port); 1118 return 0; 1119 } 1120 1121 // Configure the forward address of the target server 1122 if (!strcmp(keyword, "forward")) { 1123 char* ipstring = strtok2(_caret, ":", &_caret); 1124 if (PR_StringToNetAddr(ipstring, &remote_addr) != PR_SUCCESS) { 1125 LOG_ERROR(("Invalid remote IP address: %s\n", ipstring)); 1126 return 1; 1127 } 1128 char* serverportstring = strtok2(_caret, ":", &_caret); 1129 int port = atoi(serverportstring); 1130 if (port <= 0) { 1131 LOG_ERROR(("Invalid remote port: %s\n", serverportstring)); 1132 return 1; 1133 } 1134 remote_addr.inet.port = PR_htons(port); 1135 1136 return 0; 1137 } 1138 1139 // Configure all listen sockets and port+certificate bindings. 
1140 // Listen on the public address if "*" was specified as the listen 1141 // address or listen on the loopback address if "127.0.0.1" was 1142 // specified. Using loopback will prevent users getting errors from 1143 // their firewalls about ssltunnel needing permission. A public 1144 // address is required when proxying ssl traffic from a physical or 1145 // emulated Android device since it has a different ip address from 1146 // the host. 1147 if (!strcmp(keyword, "listen")) { 1148 char* hostname = strtok2(_caret, ":", &_caret); 1149 char* hostportstring = nullptr; 1150 if (!strcmp(hostname, "*")) { 1151 listen_public = true; 1152 } else if (strcmp(hostname, "127.0.0.1")) { 1153 any_host_spec_config = true; 1154 hostportstring = strtok2(_caret, ":", &_caret); 1155 } 1156 1157 char* serverportstring = strtok2(_caret, ":", &_caret); 1158 char* certnick = strtok2(_caret, ":", &_caret); 1159 1160 int port = atoi(serverportstring); 1161 if (port <= 0) { 1162 LOG_ERROR(("Invalid port specified: %s\n", serverportstring)); 1163 return 1; 1164 } 1165 1166 if (server_info_t* existingServer = findServerInfo(port)) { 1167 if (!hostportstring) { 1168 LOG_ERROR( 1169 ("Null hostportstring specified for hostname %s\n", hostname)); 1170 return 1; 1171 } 1172 char* certnick_copy = new char[strlen(certnick) + 1]; 1173 char* hostname_copy = 1174 new char[strlen(hostname) + strlen(hostportstring) + 2]; 1175 1176 strcpy(hostname_copy, hostname); 1177 strcat(hostname_copy, ":"); 1178 strcat(hostname_copy, hostportstring); 1179 strcpy(certnick_copy, certnick); 1180 1181 PLHashEntry* entry = PL_HashTableAdd(existingServer->host_cert_table, 1182 hostname_copy, certnick_copy); 1183 if (!entry) { 1184 LOG_ERROR(("Out of memory")); 1185 return 1; 1186 } 1187 } else { 1188 server_info_t server; 1189 server.cert_nickname = certnick; 1190 server.listen_port = port; 1191 server.host_cert_table = 1192 PL_NewHashTable(0, PL_HashString, PL_CompareStrings, 1193 PL_CompareStrings, nullptr, 
nullptr); 1194 if (!server.host_cert_table) { 1195 LOG_ERROR(("Internal, could not create hash table\n")); 1196 return 1; 1197 } 1198 server.host_clientauth_table = 1199 PL_NewHashTable(0, PL_HashString, PL_CompareStrings, 1200 ClientAuthValueComparator, nullptr, nullptr); 1201 if (!server.host_clientauth_table) { 1202 LOG_ERROR(("Internal, could not create hash table\n")); 1203 return 1; 1204 } 1205 server.host_redir_table = 1206 PL_NewHashTable(0, PL_HashString, PL_CompareStrings, 1207 PL_CompareStrings, nullptr, nullptr); 1208 if (!server.host_redir_table) { 1209 LOG_ERROR(("Internal, could not create hash table\n")); 1210 return 1; 1211 } 1212 1213 server.host_ssl3_table = 1214 PL_NewHashTable(0, PL_HashString, PL_CompareStrings, 1215 PL_CompareStrings, nullptr, nullptr); 1216 1217 if (!server.host_ssl3_table) { 1218 LOG_ERROR(("Internal, could not create hash table\n")); 1219 return 1; 1220 } 1221 1222 server.host_tls1_table = 1223 PL_NewHashTable(0, PL_HashString, PL_CompareStrings, 1224 PL_CompareStrings, nullptr, nullptr); 1225 1226 if (!server.host_tls1_table) { 1227 LOG_ERROR(("Internal, could not create hash table\n")); 1228 return 1; 1229 } 1230 1231 server.host_tls11_table = 1232 PL_NewHashTable(0, PL_HashString, PL_CompareStrings, 1233 PL_CompareStrings, nullptr, nullptr); 1234 1235 if (!server.host_tls11_table) { 1236 LOG_ERROR(("Internal, could not create hash table\n")); 1237 return 1; 1238 } 1239 1240 server.host_tls12_table = 1241 PL_NewHashTable(0, PL_HashString, PL_CompareStrings, 1242 PL_CompareStrings, nullptr, nullptr); 1243 1244 if (!server.host_tls12_table) { 1245 LOG_ERROR(("Internal, could not create hash table\n")); 1246 return 1; 1247 } 1248 1249 server.host_tls13_table = 1250 PL_NewHashTable(0, PL_HashString, PL_CompareStrings, 1251 PL_CompareStrings, nullptr, nullptr); 1252 1253 if (!server.host_tls13_table) { 1254 LOG_ERROR(("Internal, could not create hash table\n")); 1255 return 1; 1256 } 1257 1258 server.host_3des_table = 1259 
PL_NewHashTable(0, PL_HashString, PL_CompareStrings, 1260 PL_CompareStrings, nullptr, nullptr); 1261 ; 1262 if (!server.host_3des_table) { 1263 LOG_ERROR(("Internal, could not create hash table\n")); 1264 return 1; 1265 } 1266 1267 server.host_failhandshake_table = 1268 PL_NewHashTable(0, PL_HashString, PL_CompareStrings, 1269 PL_CompareStrings, nullptr, nullptr); 1270 ; 1271 if (!server.host_failhandshake_table) { 1272 LOG_ERROR(("Internal, could not create hash table\n")); 1273 return 1; 1274 } 1275 1276 servers.push_back(server); 1277 } 1278 1279 return 0; 1280 } 1281 1282 if (!strcmp(keyword, "clientauth")) { 1283 char* hostname = strtok2(_caret, ":", &_caret); 1284 char* hostportstring = strtok2(_caret, ":", &_caret); 1285 char* serverportstring = strtok2(_caret, ":", &_caret); 1286 1287 int port = atoi(serverportstring); 1288 if (port <= 0) { 1289 LOG_ERROR(("Invalid port specified: %s\n", serverportstring)); 1290 return 1; 1291 } 1292 1293 if (server_info_t* existingServer = findServerInfo(port)) { 1294 char* authoptionstring = strtok2(_caret, ":", &_caret); 1295 client_auth_option* authoption = new client_auth_option; 1296 if (!authoption) { 1297 LOG_ERROR(("Out of memory")); 1298 return 1; 1299 } 1300 1301 if (!strcmp(authoptionstring, "require")) 1302 *authoption = caRequire; 1303 else if (!strcmp(authoptionstring, "request")) 1304 *authoption = caRequest; 1305 else if (!strcmp(authoptionstring, "none")) 1306 *authoption = caNone; 1307 else { 1308 LOG_ERROR( 1309 ("Incorrect client auth option modifier for host '%s'", hostname)); 1310 delete authoption; 1311 return 1; 1312 } 1313 1314 any_host_spec_config = true; 1315 1316 char* hostname_copy = 1317 new char[strlen(hostname) + strlen(hostportstring) + 2]; 1318 if (!hostname_copy) { 1319 LOG_ERROR(("Out of memory")); 1320 delete authoption; 1321 return 1; 1322 } 1323 1324 strcpy(hostname_copy, hostname); 1325 strcat(hostname_copy, ":"); 1326 strcat(hostname_copy, hostportstring); 1327 1328 PLHashEntry* 
entry = PL_HashTableAdd( 1329 existingServer->host_clientauth_table, hostname_copy, authoption); 1330 if (!entry) { 1331 LOG_ERROR(("Out of memory")); 1332 delete authoption; 1333 return 1; 1334 } 1335 } else { 1336 LOG_ERROR( 1337 ("Server on port %d for client authentication option is not defined, " 1338 "use 'listen' option first", 1339 port)); 1340 return 1; 1341 } 1342 1343 return 0; 1344 } 1345 1346 if (!strcmp(keyword, "redirhost")) { 1347 char* hostname = strtok2(_caret, ":", &_caret); 1348 char* hostportstring = strtok2(_caret, ":", &_caret); 1349 char* serverportstring = strtok2(_caret, ":", &_caret); 1350 1351 int port = atoi(serverportstring); 1352 if (port <= 0) { 1353 LOG_ERROR(("Invalid port specified: %s\n", serverportstring)); 1354 return 1; 1355 } 1356 1357 if (server_info_t* existingServer = findServerInfo(port)) { 1358 char* redirhoststring = strtok2(_caret, ":", &_caret); 1359 1360 any_host_spec_config = true; 1361 1362 char* hostname_copy = 1363 new char[strlen(hostname) + strlen(hostportstring) + 2]; 1364 if (!hostname_copy) { 1365 LOG_ERROR(("Out of memory")); 1366 return 1; 1367 } 1368 1369 strcpy(hostname_copy, hostname); 1370 strcat(hostname_copy, ":"); 1371 strcat(hostname_copy, hostportstring); 1372 1373 char* redir_copy = new char[strlen(redirhoststring) + 1]; 1374 strcpy(redir_copy, redirhoststring); 1375 PLHashEntry* entry = PL_HashTableAdd(existingServer->host_redir_table, 1376 hostname_copy, redir_copy); 1377 if (!entry) { 1378 LOG_ERROR(("Out of memory")); 1379 delete[] hostname_copy; 1380 delete[] redir_copy; 1381 return 1; 1382 } 1383 } else { 1384 LOG_ERROR( 1385 ("Server on port %d for redirhost option is not defined, use " 1386 "'listen' option first", 1387 port)); 1388 return 1; 1389 } 1390 1391 return 0; 1392 } 1393 1394 if (!strcmp(keyword, "ssl3")) { 1395 return parseWeakCryptoConfig(keyword, _caret, get_ssl3_table); 1396 } 1397 if (!strcmp(keyword, "tls1")) { 1398 return parseWeakCryptoConfig(keyword, _caret, 
get_tls1_table); 1399 } 1400 if (!strcmp(keyword, "tls1_1")) { 1401 return parseWeakCryptoConfig(keyword, _caret, get_tls11_table); 1402 } 1403 if (!strcmp(keyword, "tls1_2")) { 1404 return parseWeakCryptoConfig(keyword, _caret, get_tls12_table); 1405 } 1406 if (!strcmp(keyword, "tls1_3")) { 1407 return parseWeakCryptoConfig(keyword, _caret, get_tls13_table); 1408 } 1409 1410 if (!strcmp(keyword, "3des")) { 1411 return parseWeakCryptoConfig(keyword, _caret, get_3des_table); 1412 } 1413 1414 if (!strcmp(keyword, "failHandshake")) { 1415 return parseWeakCryptoConfig(keyword, _caret, get_failhandshake_table); 1416 } 1417 1418 // Configure the NSS certificate database directory 1419 if (!strcmp(keyword, "certdbdir")) { 1420 nssconfigdir = strtok2(_caret, "\n", &_caret); 1421 return 0; 1422 } 1423 1424 LOG_ERROR(("Error: keyword \"%s\" unexpected\n", keyword)); 1425 return 1; 1426 } 1427 1428 int parseConfigFile(const char* filePath) { 1429 FILE* f = fopen(filePath, "r"); 1430 if (!f) return 1; 1431 1432 char buffer[1024], *b = buffer; 1433 while (!feof(f)) { 1434 char c; 1435 1436 if (fscanf(f, "%c", &c) != 1) { 1437 break; 1438 } 1439 1440 switch (c) { 1441 case '\n': 1442 *b++ = 0; 1443 if (processConfigLine(buffer)) { 1444 fclose(f); 1445 return 1; 1446 } 1447 b = buffer; 1448 continue; 1449 1450 case '\r': 1451 continue; 1452 1453 default: 1454 *b++ = c; 1455 } 1456 } 1457 1458 fclose(f); 1459 1460 // Check mandatory items 1461 if (nssconfigdir.empty()) { 1462 LOG_ERROR( 1463 ("Error: missing path to NSS certification database\n,use " 1464 "certdbdir:<path> in the config file\n")); 1465 return 1; 1466 } 1467 1468 if (any_host_spec_config && !do_http_proxy) { 1469 LOG_ERROR( 1470 ("Warning: any host-specific configurations are ignored, add " 1471 "httpproxy:1 to allow them\n")); 1472 } 1473 1474 return 0; 1475 } 1476 1477 int freeHostCertHashItems(PLHashEntry* he, int i, void* arg) { 1478 delete[] (char*)he->key; 1479 delete[] (char*)he->value; 1480 return 
HT_ENUMERATE_REMOVE; 1481 } 1482 1483 int freeHostRedirHashItems(PLHashEntry* he, int i, void* arg) { 1484 delete[] (char*)he->key; 1485 delete[] (char*)he->value; 1486 return HT_ENUMERATE_REMOVE; 1487 } 1488 1489 int freeClientAuthHashItems(PLHashEntry* he, int i, void* arg) { 1490 delete[] (char*)he->key; 1491 delete (client_auth_option*)he->value; 1492 return HT_ENUMERATE_REMOVE; 1493 } 1494 1495 int freeSSL3HashItems(PLHashEntry* he, int i, void* arg) { 1496 delete[] (char*)he->key; 1497 return HT_ENUMERATE_REMOVE; 1498 } 1499 1500 int freeTLSHashItems(PLHashEntry* he, int i, void* arg) { 1501 delete[] (char*)he->key; 1502 return HT_ENUMERATE_REMOVE; 1503 } 1504 1505 int free3DESHashItems(PLHashEntry* he, int i, void* arg) { 1506 delete[] (char*)he->key; 1507 return HT_ENUMERATE_REMOVE; 1508 } 1509 1510 int main(int argc, char** argv) { 1511 const char* configFilePath; 1512 1513 const char* logLevelEnv = PR_GetEnv("SSLTUNNEL_LOG_LEVEL"); 1514 gLogLevel = logLevelEnv ? (LogLevel)atoi(logLevelEnv) : LEVEL_INFO; 1515 1516 if (argc == 1) 1517 configFilePath = "ssltunnel.cfg"; 1518 else 1519 configFilePath = argv[1]; 1520 1521 memset(&websocket_server, 0, sizeof(PRNetAddr)); 1522 1523 if (parseConfigFile(configFilePath)) { 1524 LOG_ERROR(( 1525 "Error: config file \"%s\" missing or formating incorrect\n" 1526 "Specify path to the config file as parameter to ssltunnel or \n" 1527 "create ssltunnel.cfg in the working directory.\n\n" 1528 "Example format of the config file:\n\n" 1529 " # Enable http/ssl tunneling proxy-like behavior.\n" 1530 " # If not specified ssltunnel simply does direct forward.\n" 1531 " httpproxy:1\n\n" 1532 " # Specify path to the certification database used.\n" 1533 " certdbdir:/path/to/certdb\n\n" 1534 " # Forward/proxy all requests in raw to 127.0.0.1:8888.\n" 1535 " forward:127.0.0.1:8888\n\n" 1536 " # Accept connections on port 4443 or 5678 resp. 
and " 1537 "authenticate\n" 1538 " # to any host ('*') using the 'server cert' or 'server cert 2' " 1539 "resp.\n" 1540 " listen:*:4443:server cert\n" 1541 " listen:*:5678:server cert 2\n\n" 1542 " # Accept connections on port 4443 and authenticate using\n" 1543 " # 'a different cert' when target host is 'my.host.name:443'.\n" 1544 " # This only works in httpproxy mode and has higher priority\n" 1545 " # than the previous option.\n" 1546 " listen:my.host.name:443:4443:a different cert\n\n" 1547 " # To make a specific host require or just request a client " 1548 "certificate\n" 1549 " # to authenticate use the following options. This can only be " 1550 "used\n" 1551 " # in httpproxy mode and only after the 'listen' option has " 1552 "been\n" 1553 " # specified. You also have to specify the tunnel listen port.\n" 1554 " clientauth:requesting-client-cert.host.com:443:4443:request\n" 1555 " clientauth:requiring-client-cert.host.com:443:4443:require\n" 1556 " # Proxy WebSocket traffic to the server at 127.0.0.1:9999,\n" 1557 " # instead of the server specified in the 'forward' option.\n" 1558 " websocketserver:127.0.0.1:9999\n", 1559 configFilePath)); 1560 return 1; 1561 } 1562 1563 // create a thread pool to handle connections 1564 threads = 1565 PR_CreateThreadPool(INITIAL_THREADS * servers.size(), 1566 MAX_THREADS * servers.size(), DEFAULT_STACKSIZE); 1567 if (!threads) { 1568 LOG_ERROR(("Failed to create thread pool\n")); 1569 return 1; 1570 } 1571 1572 shutdown_lock = PR_NewLock(); 1573 if (!shutdown_lock) { 1574 LOG_ERROR(("Failed to create lock\n")); 1575 PR_ShutdownThreadPool(threads); 1576 return 1; 1577 } 1578 shutdown_condvar = PR_NewCondVar(shutdown_lock); 1579 if (!shutdown_condvar) { 1580 LOG_ERROR(("Failed to create condvar\n")); 1581 PR_ShutdownThreadPool(threads); 1582 PR_DestroyLock(shutdown_lock); 1583 return 1; 1584 } 1585 1586 PK11_SetPasswordFunc(password_func); 1587 1588 // Initialize NSS 1589 if (NSS_Init(nssconfigdir.c_str()) != SECSuccess) { 
1590 int32_t errorlen = PR_GetErrorTextLength(); 1591 if (errorlen) { 1592 auto err = mozilla::MakeUnique<char[]>(errorlen + 1); 1593 PR_GetErrorText(err.get()); 1594 LOG_ERROR(("Failed to init NSS: %s", err.get())); 1595 } else { 1596 LOG_ERROR(("Failed to init NSS: Cannot get error from NSPR.")); 1597 } 1598 PR_ShutdownThreadPool(threads); 1599 PR_DestroyCondVar(shutdown_condvar); 1600 PR_DestroyLock(shutdown_lock); 1601 return 1; 1602 } 1603 1604 if (NSS_SetDomesticPolicy() != SECSuccess) { 1605 LOG_ERROR(("NSS_SetDomesticPolicy failed\n")); 1606 PR_ShutdownThreadPool(threads); 1607 PR_DestroyCondVar(shutdown_condvar); 1608 PR_DestroyLock(shutdown_lock); 1609 NSS_Shutdown(); 1610 return 1; 1611 } 1612 1613 // these values should make NSS use the defaults 1614 if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) { 1615 LOG_ERROR(("SSL_ConfigServerSessionIDCache failed\n")); 1616 PR_ShutdownThreadPool(threads); 1617 PR_DestroyCondVar(shutdown_condvar); 1618 PR_DestroyLock(shutdown_lock); 1619 NSS_Shutdown(); 1620 return 1; 1621 } 1622 1623 for (auto& server : servers) { 1624 // Not actually using this PRJob*... 
1625 // PRJob* server_job = 1626 PR_QueueJob(threads, StartServer, &server, true); 1627 } 1628 // now wait for someone to tell us to quit 1629 PR_Lock(shutdown_lock); 1630 PR_WaitCondVar(shutdown_condvar, PR_INTERVAL_NO_TIMEOUT); 1631 PR_Unlock(shutdown_lock); 1632 shutdown_server = true; 1633 LOG_INFO(("Shutting down...\n")); 1634 // cleanup 1635 PR_ShutdownThreadPool(threads); 1636 PR_JoinThreadPool(threads); 1637 PR_DestroyCondVar(shutdown_condvar); 1638 PR_DestroyLock(shutdown_lock); 1639 if (NSS_Shutdown() == SECFailure) { 1640 LOG_DEBUG(("Leaked NSS objects!\n")); 1641 } 1642 1643 for (auto& server : servers) { 1644 PL_HashTableEnumerateEntries(server.host_cert_table, freeHostCertHashItems, 1645 nullptr); 1646 PL_HashTableEnumerateEntries(server.host_clientauth_table, 1647 freeClientAuthHashItems, nullptr); 1648 PL_HashTableEnumerateEntries(server.host_redir_table, 1649 freeHostRedirHashItems, nullptr); 1650 PL_HashTableEnumerateEntries(server.host_ssl3_table, freeSSL3HashItems, 1651 nullptr); 1652 PL_HashTableEnumerateEntries(server.host_tls1_table, freeTLSHashItems, 1653 nullptr); 1654 PL_HashTableEnumerateEntries(server.host_tls11_table, freeTLSHashItems, 1655 nullptr); 1656 PL_HashTableEnumerateEntries(server.host_tls12_table, freeTLSHashItems, 1657 nullptr); 1658 PL_HashTableEnumerateEntries(server.host_tls13_table, freeTLSHashItems, 1659 nullptr); 1660 PL_HashTableEnumerateEntries(server.host_3des_table, free3DESHashItems, 1661 nullptr); 1662 PL_HashTableEnumerateEntries(server.host_failhandshake_table, 1663 free3DESHashItems, nullptr); 1664 PL_HashTableDestroy(server.host_cert_table); 1665 PL_HashTableDestroy(server.host_clientauth_table); 1666 PL_HashTableDestroy(server.host_redir_table); 1667 PL_HashTableDestroy(server.host_ssl3_table); 1668 PL_HashTableDestroy(server.host_tls1_table); 1669 PL_HashTableDestroy(server.host_tls11_table); 1670 PL_HashTableDestroy(server.host_tls12_table); 1671 PL_HashTableDestroy(server.host_tls13_table); 1672 
PL_HashTableDestroy(server.host_3des_table); 1673 PL_HashTableDestroy(server.host_failhandshake_table); 1674 } 1675 1676 PR_Cleanup(); 1677 return 0; 1678 }