// Web-viewer navigation residue (not part of the source file): about / summary / refs / log / blame / commit / diff / stats
// path: root/library/cpp/string_utils/base32/base32.cpp
// blob: 64730fe7169830f101402b97252c5cee1890e6cd (plain) (tree)








































































































































































                                                                                                                  
#include "base32.h"

#include <util/generic/yexception.h>

#include <algorithm>
#include <array>
#include <limits>

namespace {

    // RFC 4648 Base32 alphabet
    //
    // 0    A        9    J       18    S       27    3
    // 1    B       10    K       19    T       28    4
    // 2    C       11    L       20    U       29    5
    // 3    D       12    M       21    V       30    6
    // 4    E       13    N       22    W       31    7
    // 5    F       14    O       23    X
    // 6    G       15    P       24    Y
    // 7    H       16    Q       25    Z
    // 8    I       17    R       26    2       pad    =

    // Encoding alphabet: the character at index v represents the 5-bit value v
    // (26 upper-case letters followed by the digits '2'..'7', 32 entries in total).
    constexpr std::string_view BASE32_TABLE = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                                              "234567";

    // Marker stored in BASE32_DECODE_TABLE for byte values that have no Base32 meaning.
    // It cannot collide with a real value, since decoded values only span 0x00..0x1f.
    constexpr uint8_t BAD = 0xff;

    // Reverse lookup table: indexed by the byte value of an input character, it yields
    // the 5-bit value of that character or BAD when the byte is not in the alphabet.
    //
    // The literal is laid out ten entries per row, so row k (counting from 0) covers
    // byte values 10*k .. 10*k+9. Only '2'..'7' (50..55) and 'A'..'Z' (65..90) are mapped;
    // every other byte, including lower-case letters and '=', is BAD.
    // clang-format off
    constexpr std::array<uint8_t, 256> BASE32_DECODE_TABLE = {{
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        // 50..59: '2'..'7' map to 26..31
        0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, BAD,  BAD,  BAD,  BAD,
        // 60..69: 'A' (65) starts the letter range at value 0
        BAD,  BAD,  BAD,  BAD,  BAD,  0x0,  0x1,  0x2,  0x3,  0x4,
        0x5,  0x6,  0x7,  0x8,  0x9,  0xa,  0xb,  0xc,  0xd,  0xe,
        0xf,  0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
        // 90..99: 'Z' (90) is the last mapped byte
        0x19, BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
        BAD,  BAD,  BAD,  BAD,  BAD,  BAD,
    }};
    // clang-format on

    // Returns the Base32 alphabet character that represents a 5-bit value.
    //
    // sz - value to encode; it must be in [0, 31], because it is used directly as an
    //      index into the 32-character BASE32_TABLE and is not range-checked at run time.
    char encodeBits(unsigned char sz) {
        // This function indexes the encoding alphabet (not the decode table), so the
        // invariant worth pinning here is that the alphabet covers every 5-bit value.
        static_assert(BASE32_TABLE.size() == 32);
        return BASE32_TABLE[sz];
    }

    // Converts one input character into its 5-bit Base32 value.
    //
    // ch       - input byte; used directly as an index into BASE32_DECODE_TABLE.
    // isStrict - when true, a byte outside the alphabet raises yexception;
    //            when false, such a byte is silently decoded as the value 0.
    //
    // Returns a value in [0, 31].
    uint8_t decodeChar(unsigned char ch, bool isStrict) {
        // The table lookup below is unchecked, so every possible value of `ch`
        // has to be a valid index.
        static_assert(static_cast<size_t>(std::numeric_limits<decltype(ch)>::max()) < BASE32_DECODE_TABLE.size());

        const uint8_t val = BASE32_DECODE_TABLE[ch];
        if (val != BAD) {
            return val;
        }

        if (isStrict) {
            ythrow yexception() << "Error during decode symbol from Base32: character is not in Base32 set";
        }
        // Lenient mode: treat the unknown character as if it were 'A'.
        return 0;
    }

    // Appends `n` bits taken from `src` to the low end of `dst`.
    //
    // Bits of `src` are counted from the most significant one; the bits copied are those
    // at positions [i, i + n). `dst` is shifted left by `n` beforehand, so its previous
    // content moves up and whatever leaves the 8-bit range is discarded.
    // The shift amounts are not validated: `n` must not exceed 8.
    void shiftBitsFrom(unsigned char& dst, const unsigned char& src, size_t i, size_t n) {
        // Drop the `i` leading bits of src while staying inside one byte...
        const unsigned leftAligned = (static_cast<unsigned>(src) << i) & 0xFFu;
        // ...then keep only the top `n` bits, moved down to the low end.
        const unsigned char extracted = static_cast<unsigned char>(leftAligned >> (8 - n));
        dst = static_cast<unsigned char>((dst << n) | extracted);
    }

    // Decodes Base32 text from `src` into `dst` and returns the number of bytes written.
    //
    // Decoding stops at the first '=' character. When isStrict is set the input is
    // additionally validated and a violation raises an exception:
    //   * every character before the padding must belong to the alphabet (see decodeChar);
    //   * nothing but '=' may follow the first '=';
    //   * the bits left over after the last complete byte must be zero, and their number
    //     must be consistent with the number of bytes decoded.
    // When isStrict is not set none of these checks apply and leftover bits are dropped.
    //
    // `dst` is written without a bounds check; the caller has to provide a large enough buffer.
    size_t Base32DecodeImpl(std::string_view src, char* dst, bool isStrict) {
        size_t written = 0;        // bytes stored into dst so far
        size_t usedBits = 0;       // how many high bits of `pending` are already filled, always in [0, 7]
        unsigned char pending = 0; // output byte currently being assembled

        for (size_t pos = 0; pos < src.size(); ++pos) {
            const char symbol = src[pos];
            if (symbol == '=') {
                Y_ENSURE(
                    !isStrict || src.find_first_not_of('=', pos) == std::string_view::npos,
                    "Unexpected character after padding");
                break;
            }

            // Left-align the 5 payload bits inside an 8-bit value: vvvvv000.
            const uint8_t quintet = decodeChar(symbol, isStrict) << 3;
            pending |= quintet >> usedBits;

            // The 5 new bits complete the pending byte exactly when they reach bit 8.
            if (usedBits + 5 >= 8) {
                dst[written++] = pending;
                // Whatever did not fit starts the next byte (zero when the fit was exact).
                pending = quintet << (8 - usedBits);
            }
            usedBits = (usedBits + 5) % 8;
        }

        // The number of leftover bits is fully determined by the number of decoded bytes.
        // For example, the single byte \x00 must be encoded as "AA======": two characters
        // carry 10 bits, i.e. one byte plus 2 leftover bits. "AAA=" is rejected, because a
        // third character would leave 7 bits that cannot belong to any complete byte.
        const size_t tailBits = (written * 8) % 5;
        const size_t expectedUsedBits = (tailBits == 0) ? 0 : 5 - tailBits;
        if (isStrict) {
            Y_ENSURE(pending == 0 && usedBits == expectedUsedBits, "Invalid Base32 string format");
        }
        return written;
    }

} // namespace

// Encodes the bytes of `src` as Base32 text into `dst` and returns the number of
// characters written, padding included.
//
// The input is consumed as a bit stream, five bits per output character; the last
// character is completed with zero bits when the input runs out. The output is then
// padded with '=' up to a multiple of 8 characters. An empty input produces no output.
//
// `dst` is written without a bounds check; the caller has to provide a large enough buffer.
size_t Base32Encode(std::string_view src, char* dst) {
    const size_t srcSize = src.size();
    if (srcSize == 0) {
        return 0;
    }

    size_t written = 0;            // characters stored into dst so far
    size_t bytePos = 0;            // index of the input byte being consumed
    unsigned char current = src[0]; // value of that byte (0 once the input is exhausted)
    unsigned char bitOffset = 0;   // bits of `current` already consumed, counted from the MSB

    while (bytePos < srcSize) {
        unsigned char quintet = 0;

        // Gather 5 bits, possibly spanning the boundary between two input bytes.
        for (unsigned char gathered = 0; gathered < 5;) {
            const unsigned char chunk = std::min<unsigned char>(8 - bitOffset, 5 - gathered);

            shiftBitsFrom(quintet, current, bitOffset, chunk);
            gathered += chunk;

            const unsigned nextOffset = bitOffset + chunk;
            if (nextOffset >= 8) {
                // The current byte is used up: advance, feeding zeros past the end of input.
                ++bytePos;
                current = (bytePos == srcSize) ? static_cast<unsigned char>(0)
                                               : static_cast<unsigned char>(src[bytePos]);
            }
            bitOffset = nextOffset % 8;
        }

        dst[written++] = encodeBits(quintet);
    }

    // Pad the output to a whole number of 8-character groups.
    while (written % 8 != 0) {
        dst[written++] = '=';
    }

    return written;
}

// Lenient Base32 decoding of `src` into `dst`; returns the number of bytes written.
// Input validation is switched off: decoding is delegated to Base32DecodeImpl
// with isStrict == false.
size_t Base32Decode(std::string_view src, char* dst) {
    constexpr bool strict = false;
    return Base32DecodeImpl(src, dst, strict);
}

// Validating Base32 decoding of `src` into `dst`; returns the number of bytes written.
// Decoding is delegated to Base32DecodeImpl with isStrict == true, so malformed input
// is reported with an exception instead of being decoded on a best-effort basis.
size_t Base32StrictDecode(std::string_view src, char* dst) {
    constexpr bool strict = true;
    return Base32DecodeImpl(src, dst, strict);
}