commit 21c8e2c40719b4b84aa8160a03271f476cc238d8
parent 6def2987b48bccf9876af5eeb63c49e1f9712046
Author: Bob Owen <bobowencode@gmail.com>
Date: Mon, 6 Oct 2025 08:04:25 +0000
Bug 1980886 p2 - Provide a forward_iterator for accessing the ACE_HEADERs in an ACL. r=yjuglaret,win-reviewers,gstoll
Differential Revision: https://phabricator.services.mozilla.com/D266682
Diffstat:
4 files changed, 372 insertions(+), 0 deletions(-)
diff --git a/testing/gtest/mozilla/MozHelpers.h b/testing/gtest/mozilla/MozHelpers.h
@@ -58,8 +58,26 @@ namespace mozilla::gtest {
a; \
}, \
b)
+
+// Wrap EXPECT_DEATH_* macros to also disable the crash reporter.
+# define EXPECT_DEATH_WRAP(a, b) \
+ EXPECT_DEATH_IF_SUPPORTED( \
+ { \
+ mozilla::gtest::DisableCrashReporter(); \
+ a; \
+ }, \
+ b)
+# define EXPECT_DEBUG_DEATH_WRAP(a, b) \
+ EXPECT_DEBUG_DEATH( \
+ { \
+ mozilla::gtest::DisableCrashReporter(); \
+ a; \
+ }, \
+ b)
#else
# define ASSERT_DEATH_WRAP(a, b)
+# define EXPECT_DEATH_WRAP(a, b)
+# define EXPECT_DEBUG_DEATH_WRAP(a, b)
#endif
void DisableCrashReporter();
diff --git a/widget/windows/WinHeaderOnlyUtils.h b/widget/windows/WinHeaderOnlyUtils.h
@@ -21,6 +21,7 @@
#include "mozilla/Attributes.h"
#include "mozilla/DynamicallyLinkedFunctionPtr.h"
#include "mozilla/Maybe.h"
+#include "mozilla/NotNull.h"
#include "mozilla/ResultVariant.h"
#include "mozilla/UniquePtr.h"
#include "nsWindowsHelpers.h"
@@ -815,6 +816,109 @@ int MozPathGetDriveNumber(const T* aPath) {
return ToDriveNumber(aPath);
}
+/**
+ * Class to provide a forward_iterator for accessing the ACE_HEADERs in an ACL.
+ * ACE_HEADERs start after the ACL struct and know the size of their ACE.
+ */
+class AclAceRange {
+ public:
+ explicit AclAceRange(const NotNull<const ACL*> aAcl) : mAcl(aAcl) {}
+
+ class Iterator {
+ public:
+ using iterator_category = std::forward_iterator_tag;
+ using difference_type = WORD;
+ using value_type = const ACE_HEADER;
+ using pointer = value_type*;
+ using reference = value_type&;
+
+ // Constructs an end iterator.
+ Iterator() = default;
+
+ Iterator(const Iterator&) = default;
+ Iterator& operator=(const Iterator& aOther) = default;
+ Iterator(Iterator&&) = default;
+ Iterator& operator=(Iterator&& aOther) = default;
+
+ reference operator*() const {
+ MOZ_RELEASE_ASSERT(mAceCount,
+ "Trying to dereference past end of AclAceRange");
+ return *CurrentAceHeader();
+ }
+ pointer operator->() const {
+ MOZ_RELEASE_ASSERT(mAceCount,
+ "Trying to dereference past end of AclAceRange");
+ return CurrentAceHeader();
+ }
+
+ Iterator& operator++() {
+ MOZ_ASSERT(mAceCount, "Iterating past end of AclAceRange");
+ if (!mAceCount) {
+ return *this;
+ }
+
+ --mAceCount;
+ if (!mAceCount) {
+ return *this;
+ }
+
+ mCharCurrentAceHeader += CurrentAceHeader()->AceSize;
+ SetAtEndIfCurrentAcePastEndOfAcl();
+ return *this;
+ }
+
+ Iterator operator++(int) {
+ auto tmp = *this;
+ ++*this;
+ return tmp;
+ }
+
+ bool operator==(const Iterator& aOther) const {
+ return mAceCount == aOther.mAceCount;
+ }
+ bool operator!=(const Iterator& aOther) const { return !(*this == aOther); }
+
+ private:
+ friend class AclAceRange;
+
+ explicit Iterator(const NotNull<const ACL*> aAcl)
+ : mCharCurrentAceHeader(reinterpret_cast<const char*>(aAcl.get() + 1)),
+ mCharEndAcl(reinterpret_cast<const char*>(aAcl.get()) +
+ aAcl->AclSize),
+ mAceCount(aAcl->AceCount) {
+ if (mAceCount > 0) {
+ SetAtEndIfCurrentAcePastEndOfAcl();
+ } else if (mAceCount < 0) {
+ SetAtEnd();
+ }
+ }
+
+ void SetAtEnd() { mAceCount = 0; }
+
+ void SetAtEndIfCurrentAcePastEndOfAcl() {
+ if (mCharCurrentAceHeader + sizeof(ACE_HEADER) > mCharEndAcl ||
+ mCharCurrentAceHeader + CurrentAceHeader()->AceSize > mCharEndAcl) {
+ SetAtEnd();
+ }
+ }
+
+ pointer CurrentAceHeader() const {
+ return reinterpret_cast<const ACE_HEADER*>(mCharCurrentAceHeader);
+ }
+
+ const char* mCharCurrentAceHeader = nullptr;
+ const char* mCharEndAcl = nullptr;
+ // An mAceCount of 0 means we are at the end.
+ int mAceCount = 0;
+ };
+
+ Iterator begin() { return Iterator(mAcl); }
+ Iterator end() { return Iterator(); }
+
+ private:
+ const NotNull<const ACL*> mAcl;
+};
+
} // namespace mozilla
#endif // mozilla_WinHeaderOnlyUtils_h
diff --git a/widget/windows/tests/gtest/TestAclAceRange.cpp b/widget/windows/tests/gtest/TestAclAceRange.cpp
@@ -0,0 +1,249 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "WinHeaderOnlyUtils.h"
+
+#include <algorithm>
+
+#include "gtest/gtest.h"
+#include "mozilla/gtest/MozHelpers.h"
+
+using namespace mozilla;
+
+struct TestAcl {
+ ACL acl{ACL_REVISION, 0, sizeof(TestAcl), 3, 0};
+ ACCESS_ALLOWED_ACE ace1{
+ {ACCESS_ALLOWED_ACE_TYPE, OBJECT_INHERIT_ACE, sizeof(ACCESS_ALLOWED_ACE)},
+ GENERIC_READ,
+ 0};
+ ACCESS_ALLOWED_OBJECT_ACE ace2{{ACCESS_ALLOWED_OBJECT_ACE_TYPE, INHERITED_ACE,
+ sizeof(ACCESS_ALLOWED_OBJECT_ACE)},
+ GENERIC_READ,
+ 0};
+ ACCESS_DENIED_ACE ace3{
+ {ACCESS_DENIED_ACE_TYPE, INHERITED_ACE, sizeof(ACCESS_DENIED_ACE)},
+ GENERIC_READ,
+ 0};
+ NotNull<ACL*> AsAclPtr() { return WrapNotNull(reinterpret_cast<ACL*>(this)); }
+};
+
+TEST(AclAceRange, SimpleCount)
+{
+ TestAcl testAcl;
+ int aceCount = 0;
+ for (const auto& aceHeader : AclAceRange(testAcl.AsAclPtr())) {
+ Unused << aceHeader;
+ ++aceCount;
+ }
+
+ ASSERT_EQ(aceCount, 3);
+}
+
+TEST(AclAceRange, SameAsGetAce)
+{
+ TestAcl testAcl;
+ int aceIdx = 0;
+ for (const auto& aceHeader : AclAceRange(testAcl.AsAclPtr())) {
+ VOID* pGetAceHeader = nullptr;
+ EXPECT_TRUE(::GetAce(testAcl.AsAclPtr(), aceIdx, &pGetAceHeader));
+ auto* getAceHeader = static_cast<ACE_HEADER*>(pGetAceHeader);
+ EXPECT_EQ(getAceHeader->AceType, aceHeader.AceType);
+ EXPECT_EQ(getAceHeader->AceFlags, aceHeader.AceFlags);
+ EXPECT_EQ(getAceHeader->AceSize, aceHeader.AceSize);
+ ++aceIdx;
+ }
+}
+
+TEST(AclAceRange, WithFlagCount)
+{
+ TestAcl testAcl;
+ int aceCount = 0;
+ for (const auto& aceHeader : AclAceRange(testAcl.AsAclPtr())) {
+ if (aceHeader.AceFlags & INHERITED_ACE) {
+ ++aceCount;
+ }
+ }
+
+ ASSERT_EQ(aceCount, 2);
+}
+
+TEST(AclAceRange, AclSizeCheckedAsWellAsCount)
+{
+ TestAcl testAcl;
+ testAcl.acl.AclSize -= sizeof(ACCESS_DENIED_ACE);
+ int aceCount = 0;
+ for (const auto& aceHeader : AclAceRange(testAcl.AsAclPtr())) {
+ if (aceHeader.AceFlags & INHERITED_ACE) {
+ ++aceCount;
+ }
+ }
+
+ ASSERT_EQ(aceCount, 1);
+}
+
+TEST(AclAceRange, ChecksAceHeaderSizeInAclSize)
+{
+ TestAcl testAcl;
+ testAcl.acl.AclSize -= 1;
+ int aceCount = 0;
+ for (const auto& aceHeader : AclAceRange(testAcl.AsAclPtr())) {
+ if (aceHeader.AceFlags & INHERITED_ACE) {
+ ++aceCount;
+ }
+ }
+
+ ASSERT_EQ(aceCount, 1);
+}
+
+TEST(AclAceRange, AceCountOfZeroResultsInNoIterations)
+{
+ TestAcl testAcl;
+ testAcl.acl.AceCount = 0;
+ int aceCount = 0;
+ for (const auto& aceHeader : AclAceRange(testAcl.AsAclPtr())) {
+ Unused << aceHeader;
+ ++aceCount;
+ }
+
+ ASSERT_EQ(aceCount, 0);
+}
+
+TEST(AclAceRange, AclSizeTooSmallForAnyAcesResultsInNoIterations)
+{
+ TestAcl testAcl;
+ testAcl.acl.AclSize = sizeof(ACCESS_ALLOWED_ACE) - 1;
+ int aceCount = 0;
+ for (const auto& aceHeader : AclAceRange(testAcl.AsAclPtr())) {
+ Unused << aceHeader;
+ ++aceCount;
+ }
+
+ ASSERT_EQ(aceCount, 0);
+}
+
+TEST(AclAceRange, weakly_incrementable)
+{
+ TestAcl testAcl;
+ AclAceRange aclAceRange(testAcl.AsAclPtr());
+ auto iter = aclAceRange.begin();
+
+ EXPECT_TRUE(std::addressof(++iter) == std::addressof(iter))
+ << "addressof pre-increment result should match iterator";
+
+ // pre and post increment advance iterator.
+ EXPECT_EQ(iter->AceType, testAcl.ace2.Header.AceType);
+ EXPECT_EQ(iter->AceFlags, testAcl.ace2.Header.AceFlags);
+ EXPECT_EQ(iter->AceSize, testAcl.ace2.Header.AceSize);
+ iter++;
+ EXPECT_EQ(iter->AceType, testAcl.ace3.Header.AceType);
+ EXPECT_EQ(iter->AceFlags, testAcl.ace3.Header.AceFlags);
+ EXPECT_EQ(iter->AceSize, testAcl.ace3.Header.AceSize);
+
+ // Moveable.
+ auto moveConstructedIter(std::move(iter));
+ EXPECT_EQ(moveConstructedIter->AceType, testAcl.ace3.Header.AceType);
+ EXPECT_EQ(moveConstructedIter->AceFlags, testAcl.ace3.Header.AceFlags);
+ EXPECT_EQ(moveConstructedIter->AceSize, testAcl.ace3.Header.AceSize);
+ auto moveAssignedIter = std::move(iter);
+ EXPECT_EQ(moveAssignedIter->AceType, testAcl.ace3.Header.AceType);
+ EXPECT_EQ(moveAssignedIter->AceFlags, testAcl.ace3.Header.AceFlags);
+ EXPECT_EQ(moveAssignedIter->AceSize, testAcl.ace3.Header.AceSize);
+}
+
+TEST(AclAceRange, incrementable)
+{
+ TestAcl testAcl;
+ AclAceRange aclAceRange1(testAcl.AsAclPtr());
+ AclAceRange aclAceRange2(testAcl.AsAclPtr());
+ auto it1 = aclAceRange1.begin();
+ auto it2 = aclAceRange2.begin();
+
+ // bool(a == b) implies bool(a++ == b)
+ EXPECT_TRUE(it1 == it2) << "begin iterators for same ACL should be equal";
+ EXPECT_TRUE(it1++ == it2);
+ EXPECT_FALSE(it1 == it2);
+ EXPECT_FALSE(it1++ == it2);
+
+ // bool(a == b) implies bool(((void)a++, a) == ++b)
+ it1 = aclAceRange1.begin();
+ EXPECT_TRUE(it1 == it2);
+ EXPECT_TRUE(((void)it1++, it1) == ++it2);
+ it1 = aclAceRange1.begin();
+ EXPECT_FALSE(it1 == it2);
+ EXPECT_FALSE(((void)it1++, it1) == ++it2);
+
+ // Copyable.
+ auto copyConstructedIter(it2);
+ EXPECT_EQ(copyConstructedIter->AceType, testAcl.ace3.Header.AceType);
+ EXPECT_EQ(copyConstructedIter->AceFlags, testAcl.ace3.Header.AceFlags);
+ EXPECT_EQ(copyConstructedIter->AceSize, testAcl.ace3.Header.AceSize);
+ auto copyAssignedIter = it2;
+ EXPECT_EQ(copyAssignedIter->AceType, testAcl.ace3.Header.AceType);
+ EXPECT_EQ(copyAssignedIter->AceFlags, testAcl.ace3.Header.AceFlags);
+ EXPECT_EQ(copyAssignedIter->AceSize, testAcl.ace3.Header.AceSize);
+
+ // Default constructable.
+ AclAceRange::Iterator defaultConstructed;
+ EXPECT_TRUE(defaultConstructed == aclAceRange1.end());
+}
+
+TEST(AclAceRange, AlgorithmCountIf)
+{
+ TestAcl testAcl;
+ AclAceRange aclAceRange(testAcl.AsAclPtr());
+ auto aceCount = std::count_if(
+ aclAceRange.begin(), aclAceRange.end(),
+ [](const auto& hdr) { return hdr.AceFlags & INHERITED_ACE; });
+
+ ASSERT_EQ(aceCount, 2);
+}
+
+TEST(AclAceRange, AlgorithmAnyOf)
+{
+ TestAcl testAcl;
+ AclAceRange aclAceRange(testAcl.AsAclPtr());
+ auto anyInherited =
+ std::any_of(aclAceRange.begin(), aclAceRange.end(),
+ [](const auto& hdr) { return hdr.AceFlags & INHERITED_ACE; });
+
+ ASSERT_TRUE(anyInherited);
+}
+
+TEST(AclAceRange, DereferenceAtEndIsFatal)
+{
+#if DEBUG
+ const auto* msg =
+ "Assertion failure: mAceCount \\(Trying to dereference past end of "
+ "AclAceRange\\)";
+#else
+ const auto* msg = "";
+#endif
+
+ EXPECT_DEATH_WRAP(
+ {
+ TestAcl testAcl;
+ AclAceRange aclAceRange(testAcl.AsAclPtr());
+ auto aceItCurrent = aclAceRange.begin();
+ for (; aceItCurrent != aclAceRange.end(); ++aceItCurrent) {
+ }
+ *aceItCurrent;
+ },
+ msg);
+}
+
+TEST(AclAceRange, DebugAssertForIteratingPastEnd)
+{
+ EXPECT_DEBUG_DEATH_WRAP(
+ {
+ TestAcl testAcl;
+ AclAceRange aclAceRange(testAcl.AsAclPtr());
+ auto aceItCurrent = aclAceRange.begin();
+ for (; aceItCurrent != aclAceRange.end(); ++aceItCurrent) {
+ }
+ ++aceItCurrent;
+ },
+ "Assertion failure: mAceCount \\(Iterating past end of AclAceRange\\)");
+}
diff --git a/widget/windows/tests/gtest/moz.build b/widget/windows/tests/gtest/moz.build
@@ -5,6 +5,7 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
UNIFIED_SOURCES += [
+ "TestAclAceRange.cpp",
"TestJumpListBuilder.cpp",
"TestWinDND.cpp",
"TestWindowGfx.cpp",