aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorbabenko <babenko@yandex-team.com>2024-08-31 10:34:35 +0300
committerbabenko <babenko@yandex-team.com>2024-08-31 10:53:58 +0300
commit936bbf50940b23dafb36eac2117611136fa94b1e (patch)
tree3a16087d3b93508ecb0df3b0bf49fce4329099e4
parent14ddd77a270f3fd464f24ebc25f142c04d6269a8 (diff)
downloadydb-936bbf50940b23dafb36eac2117611136fa94b1e.tar.gz
YT-22642: Fix unaligned access UB
378099ca41e7698fba0ceda68b8d2b554e61b6ea
-rw-r--r--library/cpp/yt/misc/unaligned-inl.h31
-rw-r--r--library/cpp/yt/misc/unaligned.h23
-rw-r--r--yt/yt/core/bus/tcp/packet.cpp55
3 files changed, 94 insertions, 15 deletions
diff --git a/library/cpp/yt/misc/unaligned-inl.h b/library/cpp/yt/misc/unaligned-inl.h
new file mode 100644
index 0000000000..68e1c9b499
--- /dev/null
+++ b/library/cpp/yt/misc/unaligned-inl.h
@@ -0,0 +1,31 @@
+#ifndef UNALIGNED_INL_H_
+#error "Direct inclusion of this file is not allowed, include unaligned.h"
+// For the sake of sane code completion.
+#include "unaligned.h"
+#endif
+
+#include <cstring>
+
+namespace NYT {
+
+////////////////////////////////////////////////////////////////////////////////
+
+template <class T>
+ requires std::is_trivial_v<T>
+T UnalignedLoad(const T* ptr)
+{
+ T value;
+ std::memcpy(&value, ptr, sizeof(T));
+ return value;
+}
+
+template <class T>
+ requires std::is_trivial_v<T>
+void UnalignedStore(T* ptr, const T& value)
+{
+ std::memcpy(ptr, &value, sizeof(T));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT
diff --git a/library/cpp/yt/misc/unaligned.h b/library/cpp/yt/misc/unaligned.h
new file mode 100644
index 0000000000..68c124183f
--- /dev/null
+++ b/library/cpp/yt/misc/unaligned.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <type_traits>
+
+namespace NYT {
+
+////////////////////////////////////////////////////////////////////////////////
+
+template <class T>
+ requires std::is_trivial_v<T>
+T UnalignedLoad(const T* ptr);
+
+template <class T>
+ requires std::is_trivial_v<T>
+void UnalignedStore(T* ptr, const T& value);
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT
+
+#define UNALIGNED_INL_H_
+#include "unaligned-inl.h"
+#undef UNALIGNED_INL_H_
diff --git a/yt/yt/core/bus/tcp/packet.cpp b/yt/yt/core/bus/tcp/packet.cpp
index 039c005e8a..0d673b66bd 100644
--- a/yt/yt/core/bus/tcp/packet.cpp
+++ b/yt/yt/core/bus/tcp/packet.cpp
@@ -6,6 +6,8 @@
#include <library/cpp/yt/string/guid.h>
+#include <library/cpp/yt/misc/unaligned.h>
+
namespace NYT::NBus {
////////////////////////////////////////////////////////////////////////////////
@@ -82,8 +84,6 @@ protected:
TCompactVector<char, TypicalVariableHeaderSize> VariableHeader_;
size_t VariableHeaderSize_;
- ui32* PartSizes_;
- ui64* PartChecksums_;
int PartIndex_ = -1;
TSharedRefArray Message_;
@@ -148,6 +148,32 @@ protected:
{
return static_cast<TDerived*>(this);
}
+
+
+ ui32 GetPartSize(int index) const
+ {
+ return UnalignedLoad(PartSizes_ + index);
+ }
+
+ void SetPartSize(int index, ui32 size)
+ {
+ UnalignedStore(PartSizes_ + index, size);
+ }
+
+
+ ui64 GetPartChecksum(int index) const
+ {
+ return UnalignedLoad(PartChecksums_ + index);
+ }
+
+ void SetPartChecksum(int index, ui64 checksum)
+ {
+ UnalignedStore(PartChecksums_ + index, checksum);
+ }
+
+private:
+ ui32* PartSizes_;
+ ui64* PartChecksums_;
};
////////////////////////////////////////////////////////////////////////////////
@@ -283,7 +309,7 @@ private:
bool EndVariableHeaderPhase()
{
if (VerifyChecksum_) {
- auto expectedChecksum = PartChecksums_[FixedHeader_.PartCount];
+ auto expectedChecksum = GetPartChecksum(FixedHeader_.PartCount);
if (expectedChecksum != NullChecksum) {
auto actualChecksum = GetVariableChecksum();
if (expectedChecksum != actualChecksum) {
@@ -295,7 +321,7 @@ private:
}
for (int index = 0; index < static_cast<int>(FixedHeader_.PartCount); ++index) {
- ui32 partSize = PartSizes_[index];
+ ui32 partSize = GetPartSize(index);
if (partSize != NullPacketPartSize && partSize > MaxMessagePartSize) {
YT_LOG_ERROR("Invalid packet part size (PacketId: %v, PartIndex: %v, PartSize: %v)",
FixedHeader_.PacketId,
@@ -312,7 +338,7 @@ private:
bool EndMessagePartPhase()
{
if (VerifyChecksum_) {
- auto expectedChecksum = PartChecksums_[PartIndex_];
+ auto expectedChecksum = GetPartChecksum(PartIndex_);
if (expectedChecksum != NullChecksum) {
auto actualChecksum = GetChecksum(Parts_[PartIndex_]);
if (expectedChecksum != actualChecksum) {
@@ -337,7 +363,7 @@ private:
break;
}
- ui32 partSize = PartSizes_[PartIndex_];
+ ui32 partSize = GetPartSize(PartIndex_);
if (partSize == NullPacketPartSize) {
Parts_.push_back(TSharedRef());
} else if (partSize == 0) {
@@ -411,19 +437,18 @@ public:
AllocateVariableHeader();
for (int index = 0; index < static_cast<int>(Message_.Size()); ++index) {
- const auto& part = Message_[index];
- if (part) {
- PartSizes_[index] = part.Size();
- PartChecksums_[index] = generateChecksums && index < checksummedPartCount
- ? GetChecksum(part)
- : NullChecksum;
+ if (const auto& part = Message_[index]) {
+ SetPartSize(index, part.Size());
+ SetPartChecksum(
+ index,
+ generateChecksums && index < checksummedPartCount ? GetChecksum(part) : NullChecksum);
} else {
- PartSizes_[index] = NullPacketPartSize;
- PartChecksums_[index] = NullChecksum;
+ SetPartSize(index, NullPacketPartSize);
+ SetPartChecksum(index, NullChecksum);
}
}
- PartChecksums_[Message_.Size()] = generateChecksums ? GetVariableChecksum() : NullChecksum;
+ SetPartChecksum(Message_.Size(), generateChecksums ? GetVariableChecksum() : NullChecksum);
}
BeginPhase(EPacketPhase::FixedHeader, &FixedHeader_, sizeof (TPacketHeader));