summaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/Functions/LeftRight.h
blob: 2ea41b5525234808c87fe74ab14fac851d7b1d04 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#pragma once

#include <DataTypes/DataTypeString.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnConst.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Functions/GatherUtils/GatherUtils.h>
#include <Functions/GatherUtils/Sources.h>
#include <Functions/GatherUtils/Sinks.h>
#include <Functions/GatherUtils/Slices.h>
#include <Functions/GatherUtils/Algorithms.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/Context_fwd.h>


namespace DB
{

/// Pulls the GatherUtils sources, sinks and slice algorithms into DB.
/// Note: being a using-directive in a header, it is also visible to every file that includes this one.
using namespace GatherUtils;

/// Error codes thrown below; they are only declared here and defined elsewhere in the project.
namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int ILLEGAL_COLUMN;
}

enum class SubstringDirection
{
    Left,
    Right
};

/** Implements left(s, n), right(s, n), leftUTF8(s, n) and rightUTF8(s, n).
  *
  * Template parameters:
  *  - is_utf8:   when true, the first argument must be a String and is read through UTF8StringSource;
  *               when false, String and FixedString are accepted and read through the byte-oriented
  *               StringSource / FixedStringSource.
  *  - direction: Left dispatches to the sliceFromLeft* GatherUtils algorithms,
  *               Right dispatches to the sliceFromRight* ones.
  *
  * The second argument must be a native number. How negative or very large lengths are treated is
  * decided by the GatherUtils slice algorithms, not by this class.
  */
template <bool is_utf8, SubstringDirection direction>
class FunctionLeftRight : public IFunction
{
public:
    static constexpr auto name = direction == SubstringDirection::Right
        ? (is_utf8 ? "rightUTF8" : "right")
        : (is_utf8 ? "leftUTF8" : "left");

    /// Factory; the ContextPtr argument is ignored.
    static FunctionPtr create(ContextPtr)
    {
        return std::make_shared<FunctionLeftRight>();
    }

    String getName() const override { return name; }

    /// Exactly two arguments: the string and the length.
    bool isVariadic() const override { return false; }
    size_t getNumberOfArguments() const override { return 2; }

    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
    bool useDefaultImplementationForConstants() const override { return true; }

    /// Validates the argument types; the result type is always String.
    /// Throws ILLEGAL_TYPE_OF_ARGUMENT when the first argument is not a String (or, for the
    /// non-UTF8 variants, not a String/FixedString), or when the second is not a native number.
    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
    {
        const auto & string_type = arguments[0];
        const auto & length_type = arguments[1];

        /// UTF-8 variants additionally insist on a plain String.
        const bool string_type_is_valid = (!is_utf8 || isString(string_type)) && isStringOrFixedString(string_type);
        if (!string_type_is_valid)
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}", string_type->getName(), getName());

        if (!isNativeNumber(length_type))
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of second argument of function {}",
                    length_type->getName(), getName());

        return std::make_shared<DataTypeString>();
    }

    /// Slices every row of `source` into a freshly created ColumnString.
    /// If `column_length_const` is non-null, the pre-extracted `length_value` is used for all rows;
    /// otherwise the per-row lengths are taken from `column_length`.
    template <typename Source>
    ColumnPtr executeForSource(const ColumnPtr & column_length,
                          const ColumnConst * column_length_const,
                          Int64 length_value, Source && source,
                          size_t input_rows_count) const
    {
        auto result_column = ColumnString::create();
        const bool is_constant_length = column_length_const != nullptr;

        if constexpr (direction == SubstringDirection::Right)
        {
            if (!is_constant_length)
                sliceFromRightDynamicLength(source, StringSink(*result_column, input_rows_count), *column_length);
            else
                /// Presumably: everything from `length_value` positions before the end, with no upper bound.
                sliceFromRightConstantOffsetUnbounded(source, StringSink(*result_column, input_rows_count), length_value);
        }
        else
        {
            if (!is_constant_length)
                sliceFromLeftDynamicLength(source, StringSink(*result_column, input_rows_count), *column_length);
            else
                /// Offset 0 from the left, bounded by `length_value`.
                sliceFromLeftConstantOffsetBounded(source, StringSink(*result_column, input_rows_count), 0, length_value);
        }

        return result_column;
    }

    /// Picks the GatherUtils source matching the concrete column of the first argument and slices it.
    /// Throws ILLEGAL_COLUMN when that column is of an unsupported kind.
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
    {
        const ColumnPtr & column_string = arguments[0].column;
        const ColumnPtr & column_length = arguments[1].column;

        const ColumnConst * column_length_const = checkAndGetColumn<ColumnConst>(column_length.get());
        const Int64 length_value = column_length_const ? column_length_const->getInt(0) : 0;

        const auto * string_column = column_string.get();

        if constexpr (is_utf8)
        {
            if (const auto * col = checkAndGetColumn<ColumnString>(string_column))
                return executeForSource(column_length, column_length_const,
                    length_value, UTF8StringSource(*col), input_rows_count);

            if (const auto * col_const = checkAndGetColumnConst<ColumnString>(string_column))
                return executeForSource(column_length, column_length_const,
                    length_value, ConstSource<UTF8StringSource>(*col_const), input_rows_count);
        }
        else
        {
            if (const auto * col = checkAndGetColumn<ColumnString>(string_column))
                return executeForSource(column_length, column_length_const,
                    length_value, StringSource(*col), input_rows_count);

            if (const auto * col_fixed = checkAndGetColumn<ColumnFixedString>(string_column))
                return executeForSource(column_length, column_length_const,
                    length_value, FixedStringSource(*col_fixed), input_rows_count);

            if (const auto * col_const = checkAndGetColumnConst<ColumnString>(string_column))
                return executeForSource(column_length, column_length_const,
                    length_value, ConstSource<StringSource>(*col_const), input_rows_count);

            if (const auto * col_const_fixed = checkAndGetColumnConst<ColumnFixedString>(string_column))
                return executeForSource(column_length, column_length_const,
                    length_value, ConstSource<FixedStringSource>(*col_const_fixed), input_rows_count);
        }

        /// None of the supported column kinds matched.
        throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function {}",
            arguments[0].column->getName(), getName());
    }
};

}