#pragma once
#include "helpers.h"
#include <library/cpp/sse/sse.h>
#include <library/cpp/yt/assert/assert.h>
#include <util/generic/array_ref.h>
#include <algorithm>
#include <optional>
namespace NErasure {
//! Returns the bytewise xor of all #refs.
/*!
 * The result has the size of refs.front(); every blob in #refs is read for exactly that
 * many bytes, so each must be at least that long (bytes beyond it are ignored).
 * An empty #refs yields an empty blob.
 */
template <class TCodecTraits, class TBlobType = typename TCodecTraits::TBlobType>
static inline TBlobType Xor(const std::vector<TBlobType>& refs) {
    using TBufferType = typename TCodecTraits::TBufferType;

    if (refs.empty()) {
        // Nothing to combine and no length to take: refs.front() would be undefined behavior.
        return TCodecTraits::FromBufferToBlob(TCodecTraits::AllocateBuffer(0));
    }

    const size_t size = refs.front().Size();
    TBufferType result = TCodecTraits::AllocateBuffer(size); // this also fills the buffer with zeros
    for (const TBlobType& ref : refs) {
        const char* data = reinterpret_cast<const char*>(ref.Begin());
        size_t pos = 0;
#ifdef ARCADIA_SSE
        // Whole 16-byte chunks via unaligned SSE loads/stores.
        for (; pos + sizeof(__m128i) <= size; pos += sizeof(__m128i)) {
            __m128i* dst = reinterpret_cast<__m128i*>(result.Begin() + pos);
            const __m128i* src = reinterpret_cast<const __m128i*>(data + pos);
            _mm_storeu_si128(dst, _mm_xor_si128(_mm_loadu_si128(src), _mm_loadu_si128(dst)));
        }
#endif
        // Byte-wise tail (the whole buffer when SSE is not enabled).
        for (; pos < size; ++pos) {
            *(result.Begin() + pos) ^= data[pos];
        }
    }
    return TCodecTraits::FromBufferToBlob(std::move(result));
}
//! Local Reconstruction Codes (LRC).
/*!
 * See "Erasure Coding in Windows Azure Storage" (Huang et al., USENIX ATC 2012):
 * https://www.usenix.org/conference/usenixfederatedconferencesweek/erasure-coding-windows-azure-storage
 * for more details.
 */
template <int DataPartCount, int ParityPartCount, int WordSize, class TCodecTraits>
class TLrcCodecBase
    : public ICodec<typename TCodecTraits::TBlobType>
{
    static_assert(DataPartCount % 2 == 0, "Data part count must be even.");
    static_assert(ParityPartCount == 4, "Now we only support n-2-2 scheme for LRC codec");
    static_assert(1 + DataPartCount / 2 < (1 << (WordSize / 2)), "Data part count should be enough small to construct proper matrix.");

public:
    //! Main blob for storing data.
    using TBlobType = typename TCodecTraits::TBlobType;
    //! Main mutable blob for decoding data.
    using TMutableBlobType = typename TCodecTraits::TMutableBlobType;

    static constexpr ui64 RequiredDataAlignment = alignof(ui64);

    //! Builds the two xor groups and, for small codecs, precomputes all CanRepair answers.
    /*!
     * Groups_[0] is the first half of the data parts plus part DataPartCount (its xor parity);
     * Groups_[1] is the second half of the data parts plus part DataPartCount + 1.
     */
    TLrcCodecBase() {
        Groups_[0] = MakeSegment(0, DataPartCount / 2);
        // Xor.
        Groups_[0].push_back(DataPartCount);

        Groups_[1] = MakeSegment(DataPartCount / 2, DataPartCount);
        // Xor.
        Groups_[1].push_back(DataPartCount + 1);

        constexpr int totalPartCount = DataPartCount + ParityPartCount;
        if constexpr (totalPartCount <= BitmaskOptimizationThreshold) {
            // CanRepair_ is indexed by the mask of *available* parts: bit i set means part i is present.
            CanRepair_.resize(1 << totalPartCount);
            for (int mask = 0; mask < (1 << totalPartCount); ++mask) {
                TPartIndexList erasedIndices;
                for (int i = 0; i < totalPartCount; ++i) {
                    if ((mask & (1 << i)) == 0) {
                        erasedIndices.push_back(i);
                    }
                }
                CanRepair_[mask] = CalculateCanRepair(erasedIndices);
            }
        }
    }

    //! Restores the parts listed in #erasedIndices.
    /*!
     * #blocks must hold the parts returned by GetRepairIndices(#erasedIndices), in that order,
     * and all blocks must have the same size. Results are returned in ascending index order
     * of the (deduplicated) erased parts.
     *
     * Note that if you want to restore any internal data, block offsets must be WordSize * sizeof(long) aligned.
     * It is still possible to restore unaligned data if no more than one index in each group is erased.
     * See unittests for this case.
     *
     * NOTE(review): for an unrepairable pattern GetRepairIndices returns std::nullopt and the
     * .value() call below throws std::bad_optional_access; callers should check CanRepair first.
     */
    std::vector<TBlobType> Decode(
        const std::vector<TBlobType>& blocks,
        const TPartIndexList& erasedIndices) const override
    {
        if (erasedIndices.empty()) {
            return std::vector<TBlobType>();
        }

        size_t blockLength = blocks.front().Size();
        for (size_t i = 1; i < blocks.size(); ++i) {
            YT_VERIFY(blocks[i].Size() == blockLength);
        }

        TPartIndexList indices = UniqueSortedIndices(erasedIndices);

        // A single erased member of a xor group is the xor of the rest of that group,
        // which is exactly the set GetRepairIndices asks the caller to supply.
        if (indices.size() == 1) {
            int index = indices.front();
            for (size_t i = 0; i < 2; ++i) {
                if (Contains(Groups_[i], index)) {
                    return std::vector<TBlobType>(1, Xor<TCodecTraits>(blocks));
                }
            }
        }

        TPartIndexList recoveryIndices = GetRepairIndices(indices).value();

        // Two erasures, one in each xor group: GetRepairIndices then asks only for data and
        // xor parts (nothing at or above DataPartCount + 2), and each erased part is the xor
        // of the surviving members of its own group.
        if (indices.size() == 2 &&
            indices.back() < DataPartCount + 2 &&
            recoveryIndices.back() < DataPartCount + 2)
        {
            std::vector<TBlobType> result;
            for (int index : indices) {
                for (const TPartIndexList& group : Groups_) {
                    if (!Contains(group, index)) {
                        continue;
                    }
                    // blocks[i] holds part recoveryIndices[i]; keep the ones from this group.
                    // Xor is order-independent, so a single pass over blocks suffices.
                    std::vector<TBlobType> correspondingBlocks;
                    for (size_t i = 0; i < blocks.size(); ++i) {
                        if (Contains(group, recoveryIndices[i])) {
                            correspondingBlocks.push_back(blocks[i]);
                        }
                    }
                    result.push_back(Xor<TCodecTraits>(correspondingBlocks));
                }
            }
            return result;
        }

        return FallbackToCodecDecode(blocks, std::move(indices));
    }

    //! Returns false iff the erasure pattern is unrepairable (see CalculateCanRepair).
    bool CanRepair(const TPartIndexList& erasedIndices) const final {
        constexpr int totalPartCount = DataPartCount + ParityPartCount;
        if constexpr (totalPartCount <= BitmaskOptimizationThreshold) {
            int mask = (1 << (totalPartCount)) - 1;
            for (int index : erasedIndices) {
                // Clear the bit instead of subtracting it: a duplicated index must not
                // corrupt the mask (CalculateCanRepair deduplicates too).
                mask &= ~(1 << index);
            }
            return CanRepair_[mask];
        } else {
            return CalculateCanRepair(erasedIndices);
        }
    }

    //! Same as above, with erased parts given as a bitmask (bit i set means part i is erased).
    bool CanRepair(const TPartIndexSet& erasedIndicesMask) const final {
        constexpr int totalPartCount = DataPartCount + ParityPartCount;
        if constexpr (totalPartCount <= BitmaskOptimizationThreshold) {
            // Invert to the "available parts" mask used by CanRepair_, keeping only this codec's
            // parts: flipping the whole bitset would also set every bit above totalPartCount and
            // index past the end of CanRepair_ whenever TPartIndexSet is wider than the codec.
            constexpr ui64 allPartsMask = (static_cast<ui64>(1) << totalPartCount) - 1;
            const ui64 availableMask = ~static_cast<ui64>(erasedIndicesMask.to_ulong()) & allPartsMask;
            return CanRepair_[availableMask];
        } else {
            TPartIndexList erasedIndices;
            for (size_t i = 0; i < erasedIndicesMask.size(); ++i) {
                if (erasedIndicesMask[i]) {
                    erasedIndices.push_back(i);
                }
            }
            return CalculateCanRepair(erasedIndices);
        }
    }

    //! Returns the parts needed to restore #erasedIndices, or std::nullopt if that is impossible.
    std::optional<TPartIndexList> GetRepairIndices(const TPartIndexList& erasedIndices) const final {
        if (erasedIndices.empty()) {
            return TPartIndexList();
        }

        TPartIndexList indices = UniqueSortedIndices(erasedIndices);

        if (indices.size() > ParityPartCount) {
            return std::nullopt;
        }

        // One erasure from data or xor blocks: the rest of its group suffices.
        if (indices.size() == 1) {
            int index = indices.front();
            for (size_t i = 0; i < 2; ++i) {
                if (Contains(Groups_[i], index)) {
                    return Difference(Groups_[i], index);
                }
            }
        }

        // ParityPartCount erasures are unrepairable unless each xor group has at least one of them
        // (e.g. four erasures inside one group fail this check).
        if (indices.size() == ParityPartCount && !IntersectsEveryGroup(indices)) {
            return std::nullopt;
        }

        // Calculate coverage of each group.
        int groupCoverage[2] = {};
        for (int index : indices) {
            for (size_t i = 0; i < 2; ++i) {
                if (Contains(Groups_[i], index)) {
                    ++groupCoverage[i];
                }
            }
        }

        // Two erasures, one in each group.
        if (indices.size() == 2 && groupCoverage[0] == 1 && groupCoverage[1] == 1) {
            return Difference(Union(Groups_[0], Groups_[1]), indices);
        }

        // Erasures in only parity blocks.
        if (indices.front() >= DataPartCount) {
            return MakeSegment(0, DataPartCount);
        }

        // Remove unnecessary xor parities.
        TPartIndexList result = Difference(0, DataPartCount + ParityPartCount, indices);
        for (size_t i = 0; i < 2; ++i) {
            if (groupCoverage[i] == 0 && indices.size() <= 3) {
                result = Difference(result, DataPartCount + i);
            }
        }
        return result;
    }

    int GetDataPartCount() const override {
        return DataPartCount;
    }

    int GetParityPartCount() const override {
        return ParityPartCount;
    }

    int GetGuaranteedRepairablePartCount() const override {
        return ParityPartCount - 1;
    }

    int GetWordSize() const override {
        return WordSize * sizeof(long);
    }

    virtual ~TLrcCodecBase() = default;

protected:
    // Indices of data blocks and corresponding xor (we have two xor parities).
    TConstArrayRef<TPartIndexList> GetXorGroups() const {
        return Groups_;
    }

    //! Decodes patterns that plain xor within one group cannot handle.
    virtual std::vector<TBlobType> FallbackToCodecDecode(
        const std::vector<TBlobType>& /* blocks */,
        TPartIndexList /* erasedIndices */) const = 0;

    //! Fills the ParityPartCount x DataPartCount row-major #generatorMatrix.
    /*!
     * Row 0 selects the first half of the data parts, row 1 the second half, row 2 holds the
     * coefficients described below, and row 3 is #GFSquare applied to row 2.
     */
    template <typename T>
    void InitializeGeneratorMatrix(T* generatorMatrix, const std::function<T(T)>& GFSquare) {
        for (int row = 0; row < ParityPartCount; ++row) {
            for (int column = 0; column < DataPartCount; ++column) {
                int index = row * DataPartCount + column;

                bool isFirstHalf = column < DataPartCount / 2;
                if (row == 0) generatorMatrix[index] = isFirstHalf ? 1 : 0;
                if (row == 1) generatorMatrix[index] = isFirstHalf ? 0 : 1;

                // Let alpha_i be coefficient of first half and beta_i of the second half.
                // Then matrix is non-singular iff:
                //   a) alpha_i, beta_j != 0
                //   b) alpha_i != beta_j
                //   c) alpha_i + alpha_k != beta_j + beta_l
                // for any i, j, k, l.
                if (row == 2) {
                    int shift = isFirstHalf ? 1 : (1 << (WordSize / 2));
                    int relativeColumn = isFirstHalf ? column : (column - (DataPartCount / 2));
                    generatorMatrix[index] = shift * (1 + relativeColumn);
                }

                // The last row is the square of the row before last.
                if (row == 3) {
                    auto prev = generatorMatrix[index - DataPartCount];
                    generatorMatrix[index] = GFSquare(prev);
                }
            }
        }
    }

private:
    //! Returns true iff every xor group contains at least one of #indices.
    bool IntersectsEveryGroup(const TPartIndexList& indices) const {
        for (const TPartIndexList& group : Groups_) {
            if (Intersection(indices, group).empty()) {
                return false;
            }
        }
        return true;
    }

    //! Uncached repairability check; duplicates in #erasedIndices are ignored.
    bool CalculateCanRepair(const TPartIndexList& erasedIndices) const {
        TPartIndexList indices = UniqueSortedIndices(erasedIndices);

        if (indices.size() > ParityPartCount) {
            return false;
        }

        if (indices.size() == 1) {
            int index = indices.front();
            for (size_t i = 0; i < 2; ++i) {
                if (Contains(Groups_[i], index)) {
                    return true;
                }
            }
        }

        // ParityPartCount erasures are unrepairable if some xor group has none of them.
        if (indices.size() == ParityPartCount && !IntersectsEveryGroup(indices)) {
            return false;
        }

        return true;
    }

    TPartIndexList Groups_[2];
    std::vector<bool> CanRepair_;
};
} // namespace NErasure