diff options
| author | yegorskii <[email protected]> | 2025-10-22 13:52:37 +0300 |
|---|---|---|
| committer | yegorskii <[email protected]> | 2025-10-22 14:43:37 +0300 |
| commit | 46d3b0e28ab2e1fc227fd4f22bb12d3263f9c4bc (patch) | |
| tree | e8d0f627df4c13815c8655f0747d66a9fedff3eb /library/cpp | |
| parent | b099573c0c9ebca4782d99a50f2b302413ae7dfd (diff) | |
fix data race in lwtrace
commit_hash:a3736570695184ad17b24a2a6da642132020197c
Diffstat (limited to 'library/cpp')
| -rw-r--r-- | library/cpp/lwtrace/shuttle.h | 16 | ||||
| -rw-r--r-- | library/cpp/lwtrace/shuttle_race_ut.cpp | 283 | ||||
| -rw-r--r-- | library/cpp/lwtrace/ut/ya.make | 6 |
3 files changed, 302 insertions, 3 deletions
diff --git a/library/cpp/lwtrace/shuttle.h b/library/cpp/lwtrace/shuttle.h index c3e31a8223b..686f63ac85f 100644 --- a/library/cpp/lwtrace/shuttle.h +++ b/library/cpp/lwtrace/shuttle.h @@ -198,9 +198,19 @@ namespace NLWTrace { } // Checks if there is at least one shuttle in orbit - // NOTE: used by every LWTRACK macro check, so keep it optimized - do not lock - bool HasShuttles() const { - return HeadNoLock.Get(); + // Uses atomic read to synchronize with NotConcurrent() and avoid data + // races, still optimized and no lock + bool HasShuttles() const + { + static_assert(sizeof(HeadNoLock) == sizeof(TAtomic)); + const TAtomic* headPtr = + reinterpret_cast<const TAtomic*>(&HeadNoLock); + TAtomicBase value = AtomicGet(*headPtr); + // Return true for any non-null value (including lock sentinel 0x1) + // Lock sentinel means some operation is in progress + // This allows harmless false positives but prevents data loss from + // false negatives + return value != 0; } void AddShuttle(const TShuttlePtr& shuttle) { diff --git a/library/cpp/lwtrace/shuttle_race_ut.cpp b/library/cpp/lwtrace/shuttle_race_ut.cpp new file mode 100644 index 00000000000..4291d879423 --- /dev/null +++ b/library/cpp/lwtrace/shuttle_race_ut.cpp @@ -0,0 +1,283 @@ +#include "shuttle.h" + +#include "probe.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/vector.h> +#include <util/system/thread.h> + +#include <atomic> + +using namespace NLWTrace; + +class TMockShuttle: public IShuttle +{ +public: + TMockShuttle(ui64 traceIdx, ui64 spanId) + : IShuttle(traceIdx, spanId) + {} + +protected: + bool DoAddProbe(TProbe*, const TParams&, ui64) override + { + return true; + } + + void DoEndOfTrack() override + {} + + void DoDrop() override + {} + + void DoSerialize(TShuttleTrace&) override + {} + + bool DoFork(TShuttlePtr&) override + { + return true; + } + + bool DoJoin(const TShuttlePtr&) override + { + return true; + } +}; + +Y_UNIT_TEST_SUITE(TOrbitMultithreadedUsage) +{ + // Test HasShuttles() calls while modifying orbit + Y_UNIT_TEST(HasShuttlesAndAddShuttle) + { + TOrbit orbit; + std::atomic<bool> stopFlag{false}; + + // user branching and atomic counters to increase race conditions + // probability + std::atomic<size_t> hasShuttlesCount{0}; + + constexpr size_t numShuttles = 100; + + // Reader thread: continuously calls HasShuttles() + auto reader = [&]() + { + while (!stopFlag.load()) { + bool result = orbit.HasShuttles(); + if (result) { + hasShuttlesCount.fetch_add(1); + } + } + }; + + // Writer thread: continuously adds shuttles + auto writer = [&]() + { + for (size_t i = 0; i < numShuttles; ++i) { + orbit.AddShuttle(TShuttlePtr(new TMockShuttle(1, i))); + } + }; + + TVector<THolder<TThread>> threads; + threads.emplace_back(MakeHolder<TThread>(reader)); + threads.emplace_back(MakeHolder<TThread>(reader)); + threads.emplace_back(MakeHolder<TThread>(writer)); + + for (auto& t: threads) { + t->Start(); + } + + // Let writer finish + threads[2]->Join(); + + // Stop readers + stopFlag.store(true); + threads[0]->Join(); + threads[1]->Join(); + + UNIT_ASSERT_LT(0, hasShuttlesCount.load()); + UNIT_ASSERT(orbit.HasShuttles()); + } + + // Test the race condition from the tsan tests crash: Fork() vs + // HasShuttles() + Y_UNIT_TEST(ForkAndHasShuttles) + { + TOrbit orbit; + + constexpr size_t numShuttles = 10; + constexpr size_t numForks = 100; + + // Add some shuttles to orbit + for (size_t i = 0; i < numShuttles; ++i) { + orbit.AddShuttle(TShuttlePtr(new TMockShuttle(1, i))); + } + + std::atomic<bool> stopFlag{false}; + + // user branching and atomic counters to increase race conditions + // probability + std::atomic<size_t> forkCount{0}; + std::atomic<size_t> checkCount{0}; + + // Thread 1: Continuously calls Fork() + auto forker = [&]() + { + for (size_t i = 0; i < numForks; ++i) { + TOrbit tempOrbit; + if (orbit.Fork(tempOrbit)) { + forkCount.fetch_add(1); + } + } + }; + + // Thread 2: Continuously calls HasShuttles() + auto checker = [&]() + { + while (!stopFlag.load()) { + bool result = orbit.HasShuttles(); + if (result) { + checkCount.fetch_add(1); + } + } + }; + + TThread t1(forker); + TThread t2(checker); + + t1.Start(); + t2.Start(); + + t1.Join(); + stopFlag.store(true); + t2.Join(); + + UNIT_ASSERT_EQUAL(numForks, forkCount.load()); + UNIT_ASSERT_LT(0, checkCount.load()); + UNIT_ASSERT(orbit.HasShuttles()); + } + + // Test the Serialize() race condition + Y_UNIT_TEST(SerializeAndHasShuttles) + { + TOrbit orbit; + + constexpr size_t numShuttles = 10; + constexpr size_t numIterations = 100; + constexpr size_t shuttlesPerIteration = 2; + + // Add shuttles + for (size_t i = 0; i < numShuttles; ++i) { + auto shuttle = new TMockShuttle(1, i); + orbit.AddShuttle(TShuttlePtr(shuttle)); + } + + std::atomic<bool> stopFlag{false}; + + // user branching and atomic counters to increase race conditions + // probability + std::atomic<size_t> serializeCount{0}; + std::atomic<size_t> checkCount{0}; + + // Thread 1: Serialize (modifies shuttle chain via Drop/Detach/Swap) + auto serializer = [&]() + { + for (size_t i = 0; i < numIterations; ++i) { + TShuttleTrace trace; + orbit.Serialize(1, trace); + serializeCount.fetch_add(1); + + // Re-add shuttles for next iteration + for (size_t j = 0; j < shuttlesPerIteration; ++j) { + auto shuttle = new TMockShuttle(1, i * numIterations + j); + orbit.AddShuttle(TShuttlePtr(shuttle)); + } + } + }; + + // Thread 2: Check HasShuttles + auto checker = [&]() + { + while (!stopFlag.load()) { + bool hasShuttles = orbit.HasShuttles(); + if (hasShuttles) { + checkCount.fetch_add(1); + } + } + }; + + TThread t1(serializer); + TThread t2(checker); + + t1.Start(); + t2.Start(); + + t1.Join(); + stopFlag.store(true); + t2.Join(); + + UNIT_ASSERT(orbit.HasShuttles()); + } + + // Stress test: Many threads doing many operations + Y_UNIT_TEST(StressTestMultipleOperations) + { + TOrbit orbit; + + // user branching and atomic counters to increase race conditions + // probability + std::atomic<size_t> totalOperations{0}; + std::atomic<size_t> addCount{0}; + std::atomic<size_t> checkCount{0}; + std::atomic<size_t> serializeCount{0}; + const size_t numThreads = 8; + const size_t operationsPerThread = 1000; + + auto worker = [&](size_t threadId) + { + for (size_t i = 0; i < operationsPerThread; ++i) { + // Mix of operations based on iteration + if (i % 5 == 0) { + auto shuttle = new TMockShuttle(threadId, i); + orbit.AddShuttle(TShuttlePtr(shuttle)); + addCount.fetch_add(1); + } else if (i % 5 == 1) { + bool hasShuttles = orbit.HasShuttles(); + if (hasShuttles) { + checkCount.fetch_add(1); + } + } else if (i % 5 == 2) { + TOrbit childOrbit; + orbit.Fork(childOrbit); + } else if (i % 5 == 3) { + TShuttleTrace trace; + orbit.Serialize(threadId, trace); + serializeCount.fetch_add(1); + } else { + bool hasShuttle = orbit.HasShuttle(threadId); + if (hasShuttle) { + checkCount.fetch_add(1); + } + } + totalOperations.fetch_add(1); + } + }; + + TVector<THolder<TThread>> threads; + for (size_t i = 0; i < numThreads; ++i) { + threads.emplace_back( + MakeHolder<TThread>([&worker, i]() { worker(i); })); + } + + for (auto& t: threads) { + t->Start(); + } + + for (auto& t: threads) { + t->Join(); + } + + UNIT_ASSERT_EQUAL( + numThreads * operationsPerThread, + totalOperations.load()); + } +} diff --git a/library/cpp/lwtrace/ut/ya.make b/library/cpp/lwtrace/ut/ya.make index 8b7c885786b..8028bf9d6da 100644 --- a/library/cpp/lwtrace/ut/ya.make +++ b/library/cpp/lwtrace/ut/ya.make @@ -2,6 +2,12 @@ UNITTEST_FOR(library/cpp/lwtrace) FORK_SUBTESTS() +IF (SANITIZER_TYPE == "thread") + SRCS( + shuttle_race_ut.cpp + ) +ENDIF() + SRCS( log_ut.cpp trace_ut.cpp |
