aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/Interpreters/RemoveInjectiveFunctionsVisitor.cpp
blob: 8d0303799095ab1271f802599233918512261e93 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include <Common/typeid_cast.h>
#include <Parsers/ASTSubquery.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Interpreters/RemoveInjectiveFunctionsVisitor.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Functions/FunctionFactory.h>

namespace DB
{

/// Returns true if `func` is one of the uniq* functions whose arguments this visitor
/// rewrites: uniq, uniqExact, uniqHLL12, uniqCombined, uniqCombined64, uniqTheta.
/// The match is an exact, case-sensitive comparison against `func.name`.
static bool isUniq(const ASTFunction & func)
{
    for (const char * uniq_name : {"uniq", "uniqExact", "uniqHLL12", "uniqCombined", "uniqCombined64", "uniqTheta"})
    {
        if (func.name == uniq_name)
            return true;
    }
    return false;
}

/// Remove an injective function of exactly one argument by replacing the function node
/// with that argument, e.g. `f(x)` -> `x` when `f` reports itself as injective.
///
/// @param ast               node to inspect; on success it is reassigned to the function's only argument.
/// @param context           query context forwarded to the function factory lookup.
/// @param function_factory  factory used to resolve `func->name` to a regular function.
/// @return true if `ast` was replaced, false if it was left unchanged.
static bool removeInjectiveFunction(ASTPtr & ast, ContextPtr context, const FunctionFactory & function_factory)
{
    const ASTFunction * func = ast->as<ASTFunction>();
    if (!func)
        return false;

    if (!func->arguments || func->arguments->children.size() != 1)
        return false;

    /// tryGet() returns nullptr for names not registered in FunctionFactory, whereas get()
    /// throws. This rewrite is only an optimization, so an unresolvable name must not fail
    /// the query here. Leave the node as is and let later stages resolve or report it.
    /// NOTE(review): e.g. functions served by a different factory (executable UDFs?) — confirm.
    const auto resolver = function_factory.tryGet(func->name, context);
    if (!resolver || !resolver->isInjective({}))
        return false;

    ast = func->arguments->children[0];
    return true;
}

/// Dispatches to the ASTFunction overload when `ast` is a function node.
/// Every other node kind is ignored.
void RemoveInjectiveFunctionsMatcher::visit(ASTPtr & ast, const Data & data)
{
    auto * function_node = ast->as<ASTFunction>();
    if (function_node == nullptr)
        return;

    visit(*function_node, ast, data);
}

/// If `func` is one of the uniq* functions (see isUniq), strip injective single-argument
/// wrappers from each of its arguments, e.g. `uniq(f(g(x)))` -> `uniq(x)` when both `f`
/// and `g` are injective. Each argument is rewritten repeatedly until its top-level node
/// is no longer such a function.
void RemoveInjectiveFunctionsMatcher::visit(ASTFunction & func, ASTPtr &, const Data & data)
{
    if (!isUniq(func))
        return;

    /// A function node may have no argument list. Guard it the same way
    /// removeInjectiveFunction does, and skip the work entirely when there is nothing to rewrite.
    if (!func.arguments || func.arguments->children.empty())
        return;

    const FunctionFactory & function_factory = FunctionFactory::instance();
    /// Loop-invariant: fetch the context once rather than once per argument.
    const auto context = data.getContext();

    for (auto & arg : func.arguments->children)
    {
        while (removeInjectiveFunction(arg, context, function_factory))
            ;
    }
}

bool RemoveInjectiveFunctionsMatcher::needChildVisit(const ASTPtr & node, const ASTPtr &)
{
    if (node->as<ASTSubquery>() ||
        node->as<ASTTableExpression>())
        return false; // NOLINT
    return true;
}

}