aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/public/purecalc/common/transformations/extract_used_columns.cpp
blob: 9ff7a0df63877b0d9c7d0696216d11fb24e8bd52 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include "extract_used_columns.h"

#include <yql/essentials/public/purecalc/common/inspect_input.h>

#include <yql/essentials/core/yql_expr_optimize.h>
#include <yql/essentials/core/expr_nodes/yql_expr_nodes.h>

using namespace NYql;
using namespace NYql::NPureCalc;

namespace {
    /// Graph transformer that computes, for every purecalc input, the set of
    /// columns the program actually reads.
    ///
    /// The result is stored in `*Destination_` as a vector indexed by input
    /// index. Every input starts with its full column set (copied from
    /// `AllColumns_`). If the expression graph contains `ExtractMembers`
    /// callables applied directly to that input's node (a callable named
    /// `NodeName_`), the set is narrowed to the union of the members those
    /// callables extract.
    ///
    /// NOTE(review): an input that is also referenced outside `ExtractMembers`
    /// is not detected here and would still be narrowed. Presumably earlier
    /// pipeline stages guarantee all reads go through `ExtractMembers`; confirm.
    ///
    /// The computation runs once. Later `DoTransform` calls are no-ops until
    /// `Rewind()` is called.
    class TUsedColumnsExtractor : public TSyncTransformerBase {
    private:
        // Output storage (not owned). Replaced as a whole on every successful computation.
        TVector<THashSet<TString>>* const Destination_;
        // Full column set of every input. Held by reference, so the referenced
        // vector must outlive this transformer (see the deleted rvalue overload below).
        const TVector<THashSet<TString>>& AllColumns_;
        // Name of the callable that denotes a reference to an input.
        TString NodeName_;

        // Set once the result has been published; reset by Rewind().
        bool CalculatedUsedFields_ = false;

    public:
        TUsedColumnsExtractor(
            TVector<THashSet<TString>>* destination,
            const TVector<THashSet<TString>>& allColumns,
            TString nodeName
        )
            : Destination_(destination)
            , AllColumns_(allColumns)
            , NodeName_(std::move(nodeName))
        {
        }

        // Forbid binding AllColumns_ to a temporary, which would leave a dangling reference.
        TUsedColumnsExtractor(TVector<THashSet<TString>>*, TVector<THashSet<TString>>&&, TString) = delete;

    public:
        /// Walks the graph rooted at `input` and publishes the per-input used
        /// column sets to `*Destination_`. The graph itself is never modified
        /// (`output` is always `input`).
        ///
        /// @return `Ok` on success or if already computed; `Error` if an input
        ///         index could not be fetched from an input node (the error is
        ///         reported through `ctx`). In the error case `*Destination_`
        ///         is left unchanged.
        TStatus DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) final {
            output = input;

            if (CalculatedUsedFields_) {
                return IGraphTransformer::TStatus::Ok;
            }

            // Build the result in a local copy and publish it only on success.
            // An error, or an exception from YQL_ENSURE, therefore never leaves
            // `*Destination_` half-updated. This also stays correct if
            // `Destination_` happens to alias `AllColumns_`.
            TVector<THashSet<TString>> usedColumns = AllColumns_;
            // narrowed[i] becomes true the first time an ExtractMembers over
            // input i is seen. Only then is the full set cleared. Further
            // ExtractMembers over the same input add to the set (union) instead
            // of replacing it, so no branch loses the columns it needs.
            TVector<bool> narrowed(AllColumns_.size(), false);
            bool hasError = false;

            VisitExpr(input, [&](const TExprNode::TPtr& inputExpr) {
                if (hasError) {
                    // An error was already reported; do not descend further,
                    // to avoid cascading diagnostics.
                    return false;
                }

                NNodes::TExprBase node(inputExpr);
                if (auto maybeExtract = node.Maybe<NNodes::TCoExtractMembers>()) {
                    auto extract = maybeExtract.Cast();
                    const auto& arg = extract.Input().Ref();
                    if (arg.IsCallable(NodeName_)) {
                        ui32 inputIndex;
                        if (!TryFetchInputIndexFromSelf(arg, ctx, AllColumns_.size(), inputIndex)) {
                            hasError = true;
                            return false;
                        }

                        YQL_ENSURE(inputIndex < AllColumns_.size());

                        auto& usedColumnsSet = usedColumns[inputIndex];
                        const auto& allColumnsSet = AllColumns_[inputIndex];

                        if (!narrowed[inputIndex]) {
                            usedColumnsSet.clear();
                            narrowed[inputIndex] = true;
                        }

                        for (const auto& columnAtom : extract.Members()) {
                            TString name = TString(columnAtom.Value());
                            YQL_ENSURE(allColumnsSet.contains(name), "unexpected column in the input struct");
                            usedColumnsSet.insert(std::move(name));
                        }
                    }
                }

                return true;
            });

            if (hasError) {
                return IGraphTransformer::TStatus::Error;
            }

            *Destination_ = std::move(usedColumns);
            CalculatedUsedFields_ = true;

            return IGraphTransformer::TStatus::Ok;
        }

        /// Allows the next DoTransform call to recompute and republish the result.
        void Rewind() final {
            CalculatedUsedFields_ = false;
        }
    };
}

/// Creates the graph transformer that computes which columns of each input
/// the program reads.
///
/// @param destination  Output vector (not owned). Receives one column set per
///                     input whenever the transformer completes successfully.
/// @param allColumns   Full column set of every input. It is captured by
///                     reference, so it must outlive the returned transformer.
/// @param nodeName     Name of the callable that denotes a reference to an input.
/// @return An owning pointer to the newly created transformer.
TAutoPtr<IGraphTransformer> NYql::NPureCalc::MakeUsedColumnsExtractor(
    TVector<THashSet<TString>>* destination,
    const TVector<THashSet<TString>>& allColumns,
    const TString& nodeName
) {
    IGraphTransformer* const extractor = new TUsedColumnsExtractor(destination, allColumns, nodeName);
    return extractor;
}