summaryrefslogtreecommitdiffstats
path: root/library/cpp
diff options
context:
space:
mode:
authoryegorskii <[email protected]>2025-10-22 13:52:37 +0300
committeryegorskii <[email protected]>2025-10-22 14:43:37 +0300
commit46d3b0e28ab2e1fc227fd4f22bb12d3263f9c4bc (patch)
treee8d0f627df4c13815c8655f0747d66a9fedff3eb /library/cpp
parentb099573c0c9ebca4782d99a50f2b302413ae7dfd (diff)
fix data race in lwtrace
commit_hash:a3736570695184ad17b24a2a6da642132020197c
Diffstat (limited to 'library/cpp')
-rw-r--r--library/cpp/lwtrace/shuttle.h16
-rw-r--r--library/cpp/lwtrace/shuttle_race_ut.cpp283
-rw-r--r--library/cpp/lwtrace/ut/ya.make6
3 files changed, 302 insertions, 3 deletions
diff --git a/library/cpp/lwtrace/shuttle.h b/library/cpp/lwtrace/shuttle.h
index c3e31a8223b..686f63ac85f 100644
--- a/library/cpp/lwtrace/shuttle.h
+++ b/library/cpp/lwtrace/shuttle.h
@@ -198,9 +198,19 @@ namespace NLWTrace {
}
// Checks if there is at least one shuttle in orbit
- // NOTE: used by every LWTRACK macro check, so keep it optimized - do not lock
- bool HasShuttles() const {
- return HeadNoLock.Get();
+ // Uses atomic read to synchronize with NotConcurrent() and avoid data
+ // races, still optimized and no lock
+ bool HasShuttles() const
+ {
+ static_assert(sizeof(HeadNoLock) == sizeof(TAtomic));
+ const TAtomic* headPtr =
+ reinterpret_cast<const TAtomic*>(&HeadNoLock);
+ TAtomicBase value = AtomicGet(*headPtr);
+ // Return true for any non-null value (including lock sentinel 0x1)
+ // Lock sentinel means some operation is in progress
+ // This allows harmless false positives but prevents data loss from
+ // false negatives
+ return value != 0;
}
void AddShuttle(const TShuttlePtr& shuttle) {
diff --git a/library/cpp/lwtrace/shuttle_race_ut.cpp b/library/cpp/lwtrace/shuttle_race_ut.cpp
new file mode 100644
index 00000000000..4291d879423
--- /dev/null
+++ b/library/cpp/lwtrace/shuttle_race_ut.cpp
@@ -0,0 +1,283 @@
+#include "shuttle.h"
+
+#include "probe.h"
+
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <util/generic/vector.h>
+#include <util/system/thread.h>
+
+#include <atomic>
+
+using namespace NLWTrace;
+
+class TMockShuttle: public IShuttle
+{
+public:
+ TMockShuttle(ui64 traceIdx, ui64 spanId)
+ : IShuttle(traceIdx, spanId)
+ {}
+
+protected:
+ bool DoAddProbe(TProbe*, const TParams&, ui64) override
+ {
+ return true;
+ }
+
+ void DoEndOfTrack() override
+ {}
+
+ void DoDrop() override
+ {}
+
+ void DoSerialize(TShuttleTrace&) override
+ {}
+
+ bool DoFork(TShuttlePtr&) override
+ {
+ return true;
+ }
+
+ bool DoJoin(const TShuttlePtr&) override
+ {
+ return true;
+ }
+};
+
+Y_UNIT_TEST_SUITE(TOrbitMultithreadedUsage)
+{
+ // Test HasShuttles() calls while modifying orbit
+ Y_UNIT_TEST(HasShuttlesAndAddShuttle)
+ {
+ TOrbit orbit;
+ std::atomic<bool> stopFlag{false};
+
+ // user branching and atomic counters to increase race conditions
+ // probability
+ std::atomic<size_t> hasShuttlesCount{0};
+
+ constexpr size_t numShuttles = 100;
+
+ // Reader thread: continuously calls HasShuttles()
+ auto reader = [&]()
+ {
+ while (!stopFlag.load()) {
+ bool result = orbit.HasShuttles();
+ if (result) {
+ hasShuttlesCount.fetch_add(1);
+ }
+ }
+ };
+
+ // Writer thread: continuously adds shuttles
+ auto writer = [&]()
+ {
+ for (size_t i = 0; i < numShuttles; ++i) {
+ orbit.AddShuttle(TShuttlePtr(new TMockShuttle(1, i)));
+ }
+ };
+
+ TVector<THolder<TThread>> threads;
+ threads.emplace_back(MakeHolder<TThread>(reader));
+ threads.emplace_back(MakeHolder<TThread>(reader));
+ threads.emplace_back(MakeHolder<TThread>(writer));
+
+ for (auto& t: threads) {
+ t->Start();
+ }
+
+ // Let writer finish
+ threads[2]->Join();
+
+ // Stop readers
+ stopFlag.store(true);
+ threads[0]->Join();
+ threads[1]->Join();
+
+ UNIT_ASSERT_LT(0, hasShuttlesCount.load());
+ UNIT_ASSERT(orbit.HasShuttles());
+ }
+
+ // Test the race condition from the tsan tests crash: Fork() vs
+ // HasShuttles()
+ Y_UNIT_TEST(ForkAndHasShuttles)
+ {
+ TOrbit orbit;
+
+ constexpr size_t numShuttles = 10;
+ constexpr size_t numForks = 100;
+
+ // Add some shuttles to orbit
+ for (size_t i = 0; i < numShuttles; ++i) {
+ orbit.AddShuttle(TShuttlePtr(new TMockShuttle(1, i)));
+ }
+
+ std::atomic<bool> stopFlag{false};
+
+ // user branching and atomic counters to increase race conditions
+ // probability
+ std::atomic<size_t> forkCount{0};
+ std::atomic<size_t> checkCount{0};
+
+ // Thread 1: Continuously calls Fork()
+ auto forker = [&]()
+ {
+ for (size_t i = 0; i < numForks; ++i) {
+ TOrbit tempOrbit;
+ if (orbit.Fork(tempOrbit)) {
+ forkCount.fetch_add(1);
+ }
+ }
+ };
+
+ // Thread 2: Continuously calls HasShuttles()
+ auto checker = [&]()
+ {
+ while (!stopFlag.load()) {
+ bool result = orbit.HasShuttles();
+ if (result) {
+ checkCount.fetch_add(1);
+ }
+ }
+ };
+
+ TThread t1(forker);
+ TThread t2(checker);
+
+ t1.Start();
+ t2.Start();
+
+ t1.Join();
+ stopFlag.store(true);
+ t2.Join();
+
+ UNIT_ASSERT_EQUAL(numForks, forkCount.load());
+ UNIT_ASSERT_LT(0, checkCount.load());
+ UNIT_ASSERT(orbit.HasShuttles());
+ }
+
+ // Test the Serialize() race condition
+ Y_UNIT_TEST(SerializeAndHasShuttles)
+ {
+ TOrbit orbit;
+
+ constexpr size_t numShuttles = 10;
+ constexpr size_t numIterations = 100;
+ constexpr size_t shuttlesPerIteration = 2;
+
+ // Add shuttles
+ for (size_t i = 0; i < numShuttles; ++i) {
+ auto shuttle = new TMockShuttle(1, i);
+ orbit.AddShuttle(TShuttlePtr(shuttle));
+ }
+
+ std::atomic<bool> stopFlag{false};
+
+ // user branching and atomic counters to increase race conditions
+ // probability
+ std::atomic<size_t> serializeCount{0};
+ std::atomic<size_t> checkCount{0};
+
+ // Thread 1: Serialize (modifies shuttle chain via Drop/Detach/Swap)
+ auto serializer = [&]()
+ {
+ for (size_t i = 0; i < numIterations; ++i) {
+ TShuttleTrace trace;
+ orbit.Serialize(1, trace);
+ serializeCount.fetch_add(1);
+
+ // Re-add shuttles for next iteration
+ for (size_t j = 0; j < shuttlesPerIteration; ++j) {
+ auto shuttle = new TMockShuttle(1, i * numIterations + j);
+ orbit.AddShuttle(TShuttlePtr(shuttle));
+ }
+ }
+ };
+
+ // Thread 2: Check HasShuttles
+ auto checker = [&]()
+ {
+ while (!stopFlag.load()) {
+ bool hasShuttles = orbit.HasShuttles();
+ if (hasShuttles) {
+ checkCount.fetch_add(1);
+ }
+ }
+ };
+
+ TThread t1(serializer);
+ TThread t2(checker);
+
+ t1.Start();
+ t2.Start();
+
+ t1.Join();
+ stopFlag.store(true);
+ t2.Join();
+
+ UNIT_ASSERT(orbit.HasShuttles());
+ }
+
+ // Stress test: Many threads doing many operations
+ Y_UNIT_TEST(StressTestMultipleOperations)
+ {
+ TOrbit orbit;
+
+ // user branching and atomic counters to increase race conditions
+ // probability
+ std::atomic<size_t> totalOperations{0};
+ std::atomic<size_t> addCount{0};
+ std::atomic<size_t> checkCount{0};
+ std::atomic<size_t> serializeCount{0};
+ const size_t numThreads = 8;
+ const size_t operationsPerThread = 1000;
+
+ auto worker = [&](size_t threadId)
+ {
+ for (size_t i = 0; i < operationsPerThread; ++i) {
+ // Mix of operations based on iteration
+ if (i % 5 == 0) {
+ auto shuttle = new TMockShuttle(threadId, i);
+ orbit.AddShuttle(TShuttlePtr(shuttle));
+ addCount.fetch_add(1);
+ } else if (i % 5 == 1) {
+ bool hasShuttles = orbit.HasShuttles();
+ if (hasShuttles) {
+ checkCount.fetch_add(1);
+ }
+ } else if (i % 5 == 2) {
+ TOrbit childOrbit;
+ orbit.Fork(childOrbit);
+ } else if (i % 5 == 3) {
+ TShuttleTrace trace;
+ orbit.Serialize(threadId, trace);
+ serializeCount.fetch_add(1);
+ } else {
+ bool hasShuttle = orbit.HasShuttle(threadId);
+ if (hasShuttle) {
+ checkCount.fetch_add(1);
+ }
+ }
+ totalOperations.fetch_add(1);
+ }
+ };
+
+ TVector<THolder<TThread>> threads;
+ for (size_t i = 0; i < numThreads; ++i) {
+ threads.emplace_back(
+ MakeHolder<TThread>([&worker, i]() { worker(i); }));
+ }
+
+ for (auto& t: threads) {
+ t->Start();
+ }
+
+ for (auto& t: threads) {
+ t->Join();
+ }
+
+ UNIT_ASSERT_EQUAL(
+ numThreads * operationsPerThread,
+ totalOperations.load());
+ }
+}
diff --git a/library/cpp/lwtrace/ut/ya.make b/library/cpp/lwtrace/ut/ya.make
index 8b7c885786b..8028bf9d6da 100644
--- a/library/cpp/lwtrace/ut/ya.make
+++ b/library/cpp/lwtrace/ut/ya.make
@@ -2,6 +2,12 @@ UNITTEST_FOR(library/cpp/lwtrace)
FORK_SUBTESTS()
+IF (SANITIZER_TYPE == "thread")
+ SRCS(
+ shuttle_race_ut.cpp
+ )
+ENDIF()
+
SRCS(
log_ut.cpp
trace_ut.cpp