aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/Storages/MergeTree/MergeTreeIndexUSearch.h
blob: 586393aa2485c0c6ce40cec70ece0d6e6bf41f46 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#pragma once

#ifdef ENABLE_USEARCH

#include <Storages/MergeTree/ApproximateNearestNeighborIndexesCommon.h>

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#error #include <usearch/index_dense.hpp>
#pragma clang diagnostic pop

namespace DB
{

template <unum::usearch::metric_kind_t Metric>
class USearchIndexWithSerialization : public unum::usearch::index_dense_t
{
    using Base = unum::usearch::index_dense_t;

public:
    explicit USearchIndexWithSerialization(size_t dimensions);
    void serialize(WriteBuffer & ostr) const;
    void deserialize(ReadBuffer & istr);
    size_t getDimensions() const;
};

template <unum::usearch::metric_kind_t Metric>
using USearchIndexWithSerializationPtr = std::shared_ptr<USearchIndexWithSerialization<Metric>>;


template <unum::usearch::metric_kind_t Metric>
struct MergeTreeIndexGranuleUSearch final : public IMergeTreeIndexGranule
{
    MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_);
    MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, USearchIndexWithSerializationPtr<Metric> index_);

    ~MergeTreeIndexGranuleUSearch() override = default;

    void serializeBinary(WriteBuffer & ostr) const override;
    void deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion version) override;

    bool empty() const override { return !index.get(); }

    const String index_name;
    const Block index_sample_block;
    USearchIndexWithSerializationPtr<Metric> index;
};


template <unum::usearch::metric_kind_t Metric>
struct MergeTreeIndexAggregatorUSearch final : IMergeTreeIndexAggregator
{
    MergeTreeIndexAggregatorUSearch(const String & index_name_, const Block & index_sample_block);
    ~MergeTreeIndexAggregatorUSearch() override = default;

    bool empty() const override { return !index || index->size() == 0; }
    MergeTreeIndexGranulePtr getGranuleAndReset() override;
    void update(const Block & block, size_t * pos, size_t limit) override;

    const String index_name;
    const Block index_sample_block;
    USearchIndexWithSerializationPtr<Metric> index;
};


class MergeTreeIndexConditionUSearch final : public IMergeTreeIndexConditionApproximateNearestNeighbor
{
public:
    MergeTreeIndexConditionUSearch(
        const IndexDescription & index_description,
        const SelectQueryInfo & query,
        const String & distance_function,
        ContextPtr context);

    ~MergeTreeIndexConditionUSearch() override = default;

    bool alwaysUnknownOrTrue() const override;
    bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr idx_granule) const override;
    std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const override;

private:
    template <unum::usearch::metric_kind_t Metric>
    std::vector<size_t> getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const;

    const ApproximateNearestNeighborCondition ann_condition;
    const String distance_function;
};


class MergeTreeIndexUSearch : public IMergeTreeIndex
{
public:
    MergeTreeIndexUSearch(const IndexDescription & index_, const String & distance_function_);

    ~MergeTreeIndexUSearch() override = default;

    MergeTreeIndexGranulePtr createIndexGranule() const override;
    MergeTreeIndexAggregatorPtr createIndexAggregator() const override;
    MergeTreeIndexConditionPtr createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const override;

    bool mayBenefitFromIndexForIn(const ASTPtr & /*node*/) const override { return false; }

private:
    const String distance_function;
};

}


#endif