aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/Storages/MergeTree/MergeTreeIndexAnnoy.h
blob: 9395388fd94e5757a94b5194485368b7cc819cd2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#pragma once

#ifdef ENABLE_ANNOY

#include <Storages/MergeTree/ApproximateNearestNeighborIndexesCommon.h>

#error #include <annoylib.h>
#error #include <kissrandom.h>

namespace DB
{

/// Annoy index extended with (de)serialization through WriteBuffer / ReadBuffer.
/// Derives from Annoy::AnnoyIndex instantiated with UInt64, Float32, the given Distance,
/// the Kiss64Random generator and the multi-threaded build policy.
/// NOTE(review): UInt64 / Float32 are presumably the item-id and vector-component types - confirm against annoylib.h.
/// The member functions are only declared in this header.
template <typename Distance>
class AnnoyIndexWithSerialization : public Annoy::AnnoyIndex<UInt64, Float32, Distance, Annoy::Kiss64Random, Annoy::AnnoyIndexMultiThreadedBuildPolicy>
{
    /// Shorthand for the exact base class instantiation.
    using Base = Annoy::AnnoyIndex<UInt64, Float32, Distance, Annoy::Kiss64Random, Annoy::AnnoyIndexMultiThreadedBuildPolicy>;

public:
    /// Explicit to prevent implicit conversion from an integer.
    /// NOTE(review): `dimensions` is presumably the length of the indexed vectors - confirm in the definition.
    explicit AnnoyIndexWithSerialization(size_t dimensions);
    /// Writes the index to `ostr`; declared const, so it does not modify the index.
    void serialize(WriteBuffer & ostr) const;
    /// Reads the index from `istr`.
    void deserialize(ReadBuffer & istr);
    /// Returns the dimension count of the index (presumably the value passed at construction - TODO confirm).
    size_t getDimensions() const;
};

/// Shared-ownership pointer to an AnnoyIndexWithSerialization for the given Distance.
template <typename Distance>
using AnnoyIndexWithSerializationPtr = std::shared_ptr<AnnoyIndexWithSerialization<Distance>>;


/// Index granule that carries an optional Annoy index for the given Distance.
/// The granule can be created without an index (two-argument constructor) or
/// with an already existing one (three-argument constructor).
template <typename Distance>
struct MergeTreeIndexGranuleAnnoy final : IMergeTreeIndexGranule
{
    MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_);
    MergeTreeIndexGranuleAnnoy(
        const String & index_name_,
        const Block & index_sample_block_,
        AnnoyIndexWithSerializationPtr<Distance> index_);

    ~MergeTreeIndexGranuleAnnoy() override = default;

    /// The granule is empty exactly when it holds no index object.
    bool empty() const override
    {
        return index == nullptr;
    }

    /// Binary (de)serialization of the granule; defined outside this header.
    void serializeBinary(WriteBuffer & ostr) const override;
    void deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion version) override;

    /// Data members: declaration order is kept as-is because it fixes the initialization order.
    const String index_name;
    const Block index_sample_block;
    AnnoyIndexWithSerializationPtr<Distance> index;
};


/// Aggregator that accumulates blocks into an Annoy index for the given Distance
/// and hands the result out as a granule via getGranuleAndReset().
template <typename Distance>
struct MergeTreeIndexAggregatorAnnoy final : public IMergeTreeIndexAggregator
{
    MergeTreeIndexAggregatorAnnoy(const String & index_name_, const Block & index_sample_block, UInt64 trees);

    ~MergeTreeIndexAggregatorAnnoy() override = default;

    /// Empty when there is no index object, or when the index contains zero items.
    bool empty() const override
    {
        if (index == nullptr)
            return true;
        return index->get_n_items() == 0;
    }

    /// Both are defined outside this header.
    void update(const Block & block, size_t * pos, size_t limit) override;
    MergeTreeIndexGranulePtr getGranuleAndReset() override;

    /// Data members: declaration order is kept as-is because it fixes the initialization order.
    const String index_name;
    const Block index_sample_block;
    const UInt64 trees;
    AnnoyIndexWithSerializationPtr<Distance> index;
};


/// Index condition for the Annoy index; implements the
/// IMergeTreeIndexConditionApproximateNearestNeighbor interface.
/// All member functions are only declared in this header.
class MergeTreeIndexConditionAnnoy final : public IMergeTreeIndexConditionApproximateNearestNeighbor
{
public:
    /// NOTE(review): the parameters presumably initialize `ann_condition`, `distance_function`
    /// and `search_k` below - confirm in the constructor definition.
    MergeTreeIndexConditionAnnoy(
        const IndexDescription & index_description,
        const SelectQueryInfo & query,
        const String & distance_function,
        ContextPtr context);

    ~MergeTreeIndexConditionAnnoy() override = default;

    bool alwaysUnknownOrTrue() const override;
    bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr idx_granule) const override;
    /// Returns a list of size_t values for the given granule
    /// (presumably positions of the relevant ranges inside it - TODO confirm against the base interface).
    std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const override;

private:
    /// Distance-templated counterpart of getUsefulRanges()
    /// (presumably selected by `distance_function` - TODO confirm in the definition).
    template <typename Distance>
    std::vector<size_t> getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const;

    const ApproximateNearestNeighborCondition ann_condition;
    /// Name of the distance function, kept as a string.
    const String distance_function;
    /// NOTE(review): looks like Annoy's `search_k` query parameter - confirm where it is initialized.
    const Int64 search_k;
};


/// IMergeTreeIndex implementation for the Annoy index: creates the granule,
/// aggregator and condition objects. The factory methods are only declared here.
class MergeTreeIndexAnnoy : public IMergeTreeIndex
{
public:

    /// NOTE(review): `trees_` and `distance_function_` presumably initialize the
    /// private members of the same names - confirm in the constructor definition.
    MergeTreeIndexAnnoy(const IndexDescription & index_, UInt64 trees_, const String & distance_function_);

    ~MergeTreeIndexAnnoy() override = default;

    MergeTreeIndexGranulePtr createIndexGranule() const override;
    MergeTreeIndexAggregatorPtr createIndexAggregator() const override;
    MergeTreeIndexConditionPtr createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const override;

    /// Always returns false; the AST node argument is ignored.
    bool mayBenefitFromIndexForIn(const ASTPtr & /*node*/) const override { return false; }

private:
    /// NOTE(review): presumably the number of trees Annoy builds - confirm where it is used.
    const UInt64 trees;
    /// Name of the distance function, kept as a string.
    const String distance_function;
};

}

#endif