aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/containers/bitseq/bitvector.h
blob: 9d6471ea9a9db9f55f99c6dea24f834d71bbe454 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#pragma once

#include "traits.h"

#include <library/cpp/pop_count/popcount.h>

#include <util/generic/vector.h>
#include <util/ysaveload.h>

template <typename T>
class TReadonlyBitVector;

template <typename T>
class TBitVector {
public:
    using TWord = T;
    using TTraits = TBitSeqTraits<TWord>;

private:
    friend class TReadonlyBitVector<T>;
    ui64 Size_;
    TVector<TWord> Data_;

public:
    TBitVector()
        : Size_(0)
        , Data_(0)
    {
    }

    TBitVector(ui64 size)
        : Size_(size)
        , Data_(static_cast<size_t>((Size_ + TTraits::ModMask) >> TTraits::DivShift), 0)
    {
    }

    virtual ~TBitVector() = default;

    void Clear() {
        Size_ = 0;
        Data_.clear();
    }

    void Resize(ui64 size) {
        Size_ = size;
        Data_.resize((Size_ + TTraits::ModMask) >> TTraits::DivShift);
    }

    void Swap(TBitVector& other) {
        DoSwap(Size_, other.Size_);
        DoSwap(Data_, other.Data_);
    }

    bool Set(ui64 pos) {
        Y_ASSERT(pos < Size_); 
        TWord& val = Data_[pos >> TTraits::DivShift];
        if (val & TTraits::BitMask(pos & TTraits::ModMask))
            return false;
        val |= TTraits::BitMask(pos & TTraits::ModMask);
        return true;
    }

    bool Test(ui64 pos) const {
        return TTraits::Test(Data(), pos, Size_);
    }

    void Reset(ui64 pos) {
        Y_ASSERT(pos < Size_); 
        Data_[pos >> TTraits::DivShift] &= ~TTraits::BitMask(pos & TTraits::ModMask);
    }

    TWord Get(ui64 pos, ui8 width, TWord mask) const {
        return TTraits::Get(Data(), pos, width, mask, Size_);
    }

    TWord Get(ui64 pos, ui8 width) const {
        return Get(pos, width, TTraits::ElemMask(width));
    }

    void Set(ui64 pos, TWord value, ui8 width, TWord mask) {
        if (!width)
            return;
        Y_ASSERT((pos + width) <= Size_); 
        size_t word = pos >> TTraits::DivShift;
        TWord shift1 = pos & TTraits::ModMask;
        TWord shift2 = TTraits::NumBits - shift1;
        Data_[word] &= ~(mask << shift1);
        Data_[word] |= (value & mask) << shift1;
        if (shift2 < width) {
            Data_[word + 1] &= ~(mask >> shift2);
            Data_[word + 1] |= (value & mask) >> shift2;
        }
    }

    void Set(ui64 pos, TWord value, ui8 width) {
        Set(pos, value, width, TTraits::ElemMask(width));
    }

    void Append(TWord value, ui8 width, TWord mask) {
        if (!width)
            return;
        if (Data_.size() * TTraits::NumBits < Size_ + width) {
            Data_.push_back(0);
        }
        Size_ += width;
        Set(Size_ - width, value, width, mask);
    }

    void Append(TWord value, ui8 width) {
        Append(value, width, TTraits::ElemMask(width));
    }

    size_t Count() const {
        size_t count = 0;
        for (size_t i = 0; i < Data_.size(); ++i) {
            count += (size_t)PopCount(Data_[i]);
        }
        return count;
    }

    ui64 Size() const {
        return Size_;
    }

    size_t Words() const {
        return Data_.size();
    }

    const TWord* Data() const {
        return Data_.data();
    }

    void Save(IOutputStream* out) const { 
        ::Save(out, Size_);
        ::Save(out, Data_);
    }

    void Load(IInputStream* inp) { 
        ::Load(inp, Size_);
        ::Load(inp, Data_);
    }

    ui64 Space() const {
        return CHAR_BIT * (sizeof(Size_) +
                           Data_.size() * sizeof(TWord));
    }

    void Print(IOutputStream& out, size_t truncate = 128) { 
        for (size_t i = 0; i < Data_.size() && i < truncate; ++i) {
            for (int j = TTraits::NumBits - 1; j >= 0; --j) {
                size_t pos = TTraits::NumBits * i + j;
                out << (pos < Size_ && Test(pos) ? '1' : '0');
            }
            out << " ";
        }
        out << Endl;
    }
};