diff options
author | babenko <babenko@yandex-team.com> | 2024-08-31 10:34:35 +0300 |
---|---|---|
committer | babenko <babenko@yandex-team.com> | 2024-08-31 10:53:58 +0300 |
commit | 936bbf50940b23dafb36eac2117611136fa94b1e (patch) | |
tree | 3a16087d3b93508ecb0df3b0bf49fce4329099e4 | |
parent | 14ddd77a270f3fd464f24ebc25f142c04d6269a8 (diff) | |
download | ydb-936bbf50940b23dafb36eac2117611136fa94b1e.tar.gz |
YT-22642: Fix unaligned access UB
378099ca41e7698fba0ceda68b8d2b554e61b6ea
-rw-r--r-- | library/cpp/yt/misc/unaligned-inl.h | 31 | ||||
-rw-r--r-- | library/cpp/yt/misc/unaligned.h | 23 | ||||
-rw-r--r-- | yt/yt/core/bus/tcp/packet.cpp | 55 |
3 files changed, 94 insertions, 15 deletions
diff --git a/library/cpp/yt/misc/unaligned-inl.h b/library/cpp/yt/misc/unaligned-inl.h new file mode 100644 index 0000000000..68e1c9b499 --- /dev/null +++ b/library/cpp/yt/misc/unaligned-inl.h @@ -0,0 +1,31 @@ +#ifndef UNALIGNED_INL_H_ +#error "Direct inclusion of this file is not allowed, include unaligned.h" +// For the sake of sane code completion. +#include "unaligned.h" +#endif + +#include <cstring> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +template <class T> + requires std::is_trivial_v<T> +T UnalignedLoad(const T* ptr) +{ + T value; + std::memcpy(&value, ptr, sizeof(T)); + return value; +} + +template <class T> + requires std::is_trivial_v<T> +void UnalignedStore(T* ptr, const T& value) +{ + std::memcpy(ptr, &value, sizeof(T)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/library/cpp/yt/misc/unaligned.h b/library/cpp/yt/misc/unaligned.h new file mode 100644 index 0000000000..68c124183f --- /dev/null +++ b/library/cpp/yt/misc/unaligned.h @@ -0,0 +1,23 @@ +#pragma once + +#include <type_traits> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +template <class T> + requires std::is_trivial_v<T> +T UnalignedLoad(const T* ptr); + +template <class T> + requires std::is_trivial_v<T> +void UnalignedStore(T* ptr, const T& value); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT + +#define UNALIGNED_INL_H_ +#include "unaligned-inl.h" +#undef UNALIGNED_INL_H_ diff --git a/yt/yt/core/bus/tcp/packet.cpp b/yt/yt/core/bus/tcp/packet.cpp index 039c005e8a..0d673b66bd 100644 --- a/yt/yt/core/bus/tcp/packet.cpp +++ b/yt/yt/core/bus/tcp/packet.cpp @@ -6,6 +6,8 @@ #include <library/cpp/yt/string/guid.h> +#include <library/cpp/yt/misc/unaligned.h> + namespace NYT::NBus { //////////////////////////////////////////////////////////////////////////////// @@ -82,8 +84,6 @@ protected: TCompactVector<char, TypicalVariableHeaderSize> VariableHeader_; size_t VariableHeaderSize_; - ui32* PartSizes_; - ui64* PartChecksums_; int PartIndex_ = -1; TSharedRefArray Message_; @@ -148,6 +148,32 @@ protected: { return static_cast<TDerived*>(this); } + + + ui32 GetPartSize(int index) const + { + return UnalignedLoad(PartSizes_ + index); + } + + void SetPartSize(int index, ui32 size) + { + UnalignedStore(PartSizes_ + index, size); + } + + + ui64 GetPartChecksum(int index) const + { + return UnalignedLoad(PartChecksums_ + index); + } + + void SetPartChecksum(int index, ui64 checksum) + { + UnalignedStore(PartChecksums_ + index, checksum); + } + +private: + ui32* PartSizes_; + ui64* PartChecksums_; }; //////////////////////////////////////////////////////////////////////////////// @@ -283,7 +309,7 @@ private: bool EndVariableHeaderPhase() { if (VerifyChecksum_) { - auto expectedChecksum = PartChecksums_[FixedHeader_.PartCount]; + auto expectedChecksum = GetPartChecksum(FixedHeader_.PartCount); if (expectedChecksum != NullChecksum) { auto actualChecksum = GetVariableChecksum(); if (expectedChecksum != actualChecksum) { @@ -295,7 +321,7 @@ private: } for (int index = 0; index < static_cast<int>(FixedHeader_.PartCount); ++index) { - ui32 partSize = PartSizes_[index]; + ui32 partSize = GetPartSize(index); if (partSize != NullPacketPartSize && partSize > MaxMessagePartSize) { YT_LOG_ERROR("Invalid packet part size (PacketId: %v, PartIndex: %v, PartSize: %v)", FixedHeader_.PacketId, @@ -312,7 +338,7 @@ private: bool EndMessagePartPhase() { if (VerifyChecksum_) { - auto expectedChecksum = PartChecksums_[PartIndex_]; + auto expectedChecksum = GetPartChecksum(PartIndex_); if (expectedChecksum != NullChecksum) { auto actualChecksum = GetChecksum(Parts_[PartIndex_]); if (expectedChecksum != actualChecksum) { @@ -337,7 +363,7 @@ private: break; } - ui32 partSize = PartSizes_[PartIndex_]; + ui32 partSize = GetPartSize(PartIndex_); if (partSize == NullPacketPartSize) { Parts_.push_back(TSharedRef()); } else if (partSize == 0) { @@ -411,19 +437,18 @@ public: AllocateVariableHeader(); for (int index = 0; index < static_cast<int>(Message_.Size()); ++index) { - const auto& part = Message_[index]; - if (part) { - PartSizes_[index] = part.Size(); - PartChecksums_[index] = generateChecksums && index < checksummedPartCount - ? GetChecksum(part) - : NullChecksum; + if (const auto& part = Message_[index]) { + SetPartSize(index, part.Size()); + SetPartChecksum( + index, + generateChecksums && index < checksummedPartCount ? GetChecksum(part) : NullChecksum); } else { - PartSizes_[index] = NullPacketPartSize; - PartChecksums_[index] = NullChecksum; + SetPartSize(index, NullPacketPartSize); + SetPartChecksum(index, NullChecksum); } } - PartChecksums_[Message_.Size()] = generateChecksums ? GetVariableChecksum() : NullChecksum; + SetPartChecksum(Message_.Size(), generateChecksums ? GetVariableChecksum() : NullChecksum); } BeginPhase(EPacketPhase::FixedHeader, &FixedHeader_, sizeof (TPacketHeader)); |