tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

system_utils_win32.cpp (6367B)


      1 //
      2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
      3 // Use of this source code is governed by a BSD-style license that can be
      4 // found in the LICENSE file.
      5 //
      6 // system_utils_win32.cpp: Implementation of OS-specific functions for Windows.
      7 
      8 #include "common/FastVector.h"
      9 #include "system_utils.h"
     10 
     11 #include <array>
     12 
     13 // Must be included in this order.
     14 // clang-format off
     15 #include <windows.h>
     16 #include <psapi.h>
     17 // clang-format on
     18 
     19 namespace angle
     20 {
     21 bool UnsetEnvironmentVar(const char *variableName)
     22 {
     23    return (SetEnvironmentVariableW(Widen(variableName).c_str(), nullptr) == TRUE);
     24 }
     25 
     26 bool SetEnvironmentVar(const char *variableName, const char *value)
     27 {
     28    return (SetEnvironmentVariableW(Widen(variableName).c_str(), Widen(value).c_str()) == TRUE);
     29 }
     30 
     31 std::string GetEnvironmentVar(const char *variableName)
     32 {
     33    std::wstring variableNameUtf16 = Widen(variableName);
     34    FastVector<wchar_t, MAX_PATH> value;
     35 
     36    DWORD result;
     37 
     38    // First get the length of the variable, including the null terminator
     39    result = GetEnvironmentVariableW(variableNameUtf16.c_str(), nullptr, 0);
     40 
     41    // Zero means the variable was not found, so return now.
     42    if (result == 0)
     43    {
     44        return std::string();
     45    }
     46 
     47    // Now size the vector to fit the data, and read the environment variable.
     48    value.resize(result, 0);
     49    result = GetEnvironmentVariableW(variableNameUtf16.c_str(), value.data(), result);
     50 
     51    return Narrow(value.data());
     52 }
     53 
     54 void *OpenSystemLibraryWithExtensionAndGetError(const char *libraryName,
     55                                                SearchType searchType,
     56                                                std::string *errorOut)
     57 {
     58    char buffer[MAX_PATH];
     59    int ret = snprintf(buffer, MAX_PATH, "%s.%s", libraryName, GetSharedLibraryExtension());
     60    if (ret <= 0 || ret >= MAX_PATH)
     61    {
     62        fprintf(stderr, "Error loading shared library: 0x%x", ret);
     63        return nullptr;
     64    }
     65 
     66    HMODULE libraryModule = nullptr;
     67 
     68    switch (searchType)
     69    {
     70        case SearchType::ModuleDir:
     71        {
     72            std::string moduleRelativePath = ConcatenatePath(GetModuleDirectory(), libraryName);
     73            if (errorOut)
     74            {
     75                *errorOut = moduleRelativePath;
     76            }
     77            libraryModule = LoadLibraryW(Widen(moduleRelativePath).c_str());
     78            break;
     79        }
     80 
     81        case SearchType::SystemDir:
     82        {
     83            if (errorOut)
     84            {
     85                *errorOut = libraryName;
     86            }
     87            libraryModule =
     88                LoadLibraryExW(Widen(libraryName).c_str(), nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32);
     89            break;
     90        }
     91 
     92        case SearchType::AlreadyLoaded:
     93        {
     94            if (errorOut)
     95            {
     96                *errorOut = libraryName;
     97            }
     98            libraryModule = GetModuleHandleW(Widen(libraryName).c_str());
     99            break;
    100        }
    101    }
    102 
    103    return reinterpret_cast<void *>(libraryModule);
    104 }
    105 
    106 namespace
    107 {
    108 class Win32PageFaultHandler : public PageFaultHandler
    109 {
    110  public:
    111    Win32PageFaultHandler(PageFaultCallback callback) : PageFaultHandler(callback) {}
    112    ~Win32PageFaultHandler() override {}
    113 
    114    bool enable() override;
    115    bool disable() override;
    116 
    117    LONG handle(PEXCEPTION_POINTERS pExceptionInfo);
    118 
    119  private:
    120    void *mVectoredExceptionHandler = nullptr;
    121 };
    122 
    123 Win32PageFaultHandler *gWin32PageFaultHandler = nullptr;
    124 static LONG CALLBACK VectoredExceptionHandler(PEXCEPTION_POINTERS info)
    125 {
    126    return gWin32PageFaultHandler->handle(info);
    127 }
    128 
    129 bool SetMemoryProtection(uintptr_t start, size_t size, DWORD protections)
    130 {
    131    DWORD oldProtect;
    132    BOOL res = VirtualProtect(reinterpret_cast<LPVOID>(start), size, protections, &oldProtect);
    133    if (!res)
    134    {
    135        DWORD lastError = GetLastError();
    136        fprintf(stderr, "VirtualProtect failed: 0x%lx\n", lastError);
    137        return false;
    138    }
    139 
    140    return true;
    141 }
    142 
    143 LONG Win32PageFaultHandler::handle(PEXCEPTION_POINTERS info)
    144 {
    145    bool found = false;
    146 
    147    if (info->ExceptionRecord->ExceptionCode == EXCEPTION_ACCESS_VIOLATION &&
    148        info->ExceptionRecord->NumberParameters >= 2 &&
    149        info->ExceptionRecord->ExceptionInformation[0] == 1)
    150    {
    151        found = mCallback(static_cast<uintptr_t>(info->ExceptionRecord->ExceptionInformation[1])) ==
    152                PageFaultHandlerRangeType::InRange;
    153    }
    154 
    155    if (found)
    156    {
    157        return EXCEPTION_CONTINUE_EXECUTION;
    158    }
    159    else
    160    {
    161        return EXCEPTION_CONTINUE_SEARCH;
    162    }
    163 }
    164 
    165 bool Win32PageFaultHandler::disable()
    166 {
    167    if (mVectoredExceptionHandler)
    168    {
    169        ULONG res                 = RemoveVectoredExceptionHandler(mVectoredExceptionHandler);
    170        mVectoredExceptionHandler = nullptr;
    171        if (res == 0)
    172        {
    173            DWORD lastError = GetLastError();
    174            fprintf(stderr, "RemoveVectoredExceptionHandler failed: 0x%lx\n", lastError);
    175            return false;
    176        }
    177    }
    178    return true;
    179 }
    180 
    181 bool Win32PageFaultHandler::enable()
    182 {
    183    if (mVectoredExceptionHandler)
    184    {
    185        return true;
    186    }
    187 
    188    PVECTORED_EXCEPTION_HANDLER handler =
    189        reinterpret_cast<PVECTORED_EXCEPTION_HANDLER>(&VectoredExceptionHandler);
    190 
    191    mVectoredExceptionHandler = AddVectoredExceptionHandler(1, handler);
    192 
    193    if (!mVectoredExceptionHandler)
    194    {
    195        DWORD lastError = GetLastError();
    196        fprintf(stderr, "AddVectoredExceptionHandler failed: 0x%lx\n", lastError);
    197        return false;
    198    }
    199    return true;
    200 }
    201 }  // namespace
    202 
    203 // Set write protection
    204 bool ProtectMemory(uintptr_t start, size_t size)
    205 {
    206    return SetMemoryProtection(start, size, PAGE_READONLY);
    207 }
    208 
    209 // Allow reading and writing
    210 bool UnprotectMemory(uintptr_t start, size_t size)
    211 {
    212    return SetMemoryProtection(start, size, PAGE_READWRITE);
    213 }
    214 
    215 size_t GetPageSize()
    216 {
    217    SYSTEM_INFO info;
    218    GetSystemInfo(&info);
    219    return static_cast<size_t>(info.dwPageSize);
    220 }
    221 
    222 PageFaultHandler *CreatePageFaultHandler(PageFaultCallback callback)
    223 {
    224    gWin32PageFaultHandler = new Win32PageFaultHandler(callback);
    225    return gWin32PageFaultHandler;
    226 }
    227 
    228 uint64_t GetProcessMemoryUsageKB()
    229 {
    230    PROCESS_MEMORY_COUNTERS_EX pmc;
    231    ::GetProcessMemoryInfo(::GetCurrentProcess(), reinterpret_cast<PROCESS_MEMORY_COUNTERS *>(&pmc),
    232                           sizeof(pmc));
    233    return static_cast<uint64_t>(pmc.PrivateUsage) / 1024ull;
    234 }
    235 }  // namespace angle