// path: root/yql/essentials/udfs/common/vector/vector_udf.cpp
// blob: e8b01e5a05ccfc5e32ba0c9ba79dadce0184ef4e






























































































































































































                                                                                                       
#include <yql/essentials/public/udf/udf_type_ops.h>
#include <yql/essentials/public/udf/udf_helpers.h>

#include <vector>

using namespace NKikimr;
using namespace NUdf;

namespace {

class TVector {
private:
    std::vector<TUnboxedValue, TUnboxedValue::TAllocator> Vector;

public:
    TVector()
        : Vector()
    {}

    TUnboxedValue GetResult(const IValueBuilder* builder) {
        TUnboxedValue* values = nullptr;
        auto list = builder->NewArray(Vector.size(), values);
        std::copy(Vector.begin(), Vector.end(), values);

        return list;
    }

    void Emplace(const ui64 index, const TUnboxedValuePod& value) {
        if (index < Vector.size()) {
            Vector[index] = value;
        } else {
            Vector.push_back(value);
        }
    }

    void Swap(const ui64 a, const ui64 b) {
        if (a < Vector.size() && b < Vector.size()) {
            std::swap(Vector[a], Vector[b]);
        }
    }

    void Reserve(ui64 expectedSize) {
        Vector.reserve(expectedSize);
    }
};

extern const char VectorResourceName[] = "Vector.VectorResource";

// Boxed resource holding a TVector, tagged with VectorResourceName.
// All constructor arguments are forwarded unchanged to the TBoxedResource base.
class TVectorResource: public TBoxedResource<TVector, VectorResourceName> {
    using TBase = TBoxedResource<TVector, VectorResourceName>;

public:
    template <typename... TArgs>
    inline TVectorResource(TArgs&&... args)
        : TBase(std::forward<TArgs>(args)...)
    {
    }
};

// Returns the TVectorResource boxed inside `arg`, after running
// TVectorResource::Validate on it. The returned pointer does not own the
// object; `arg` keeps it alive.
TVectorResource* GetVectorResource(const TUnboxedValuePod& arg) {
    TVectorResource::Validate(arg);
    auto* boxed = arg.AsBoxed().Get();
    return static_cast<TVectorResource*>(boxed);
}

// Vector::Create(ui64 expectedSize) -> Resource<Vector.VectorResource>.
// Allocates an empty vector resource and reserves `expectedSize` slots in it.
class TVectorCreate: public TBoxedValue {
private:
    TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
        auto resource = new TVectorResource;
        // Give the new object to an owning value *before* reserving. The size
        // comes from the caller and Reserve may throw (allocation failure,
        // length_error). Without an owner at that point, the resource would leak.
        TUnboxedValue result = TUnboxedValuePod(resource);
        resource->Get()->Reserve(args[0].Get<ui64>());
        return result;
    }
};

// Vector::Emplace(vector, ui64 index, value) -> vector.
// Stores `value` via TVector::Emplace and returns the same resource.
class TVectorEmplace: public TBoxedValue {
private:
    TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
        auto* target = GetVectorResource(args[0]);
        const auto position = args[1].Get<ui64>();
        const TUnboxedValuePod& item = args[2];
        target->Get()->Emplace(position, item);
        return TUnboxedValuePod(target);
    }
};

// Vector::Swap(vector, ui64 a, ui64 b) -> vector.
// Exchanges two elements via TVector::Swap and returns the same resource.
class TVectorSwap: public TBoxedValue {
private:
    TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
        auto* target = GetVectorResource(args[0]);
        const auto first = args[1].Get<ui64>();
        const auto second = args[2].Get<ui64>();
        target->Get()->Swap(first, second);
        return TUnboxedValuePod(target);
    }
};

// Vector::GetResult(vector) -> List<T>.
// Materialises the stored values as an array built by the value builder.
class TVectorGetResult: public TBoxedValue {
private:
    TUnboxedValue Run(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const override {
        auto* source = GetVectorResource(args[0]);
        return source->Get()->GetResult(valueBuilder);
    }
};

static const auto CreateName = TStringRef::Of("Create");
static const auto EmplaceName = TStringRef::Of("Emplace");
static const auto SwapName = TStringRef::Of("Swap");
static const auto GetResultName = TStringRef::Of("GetResult");

class TVectorModule: public IUdfModule {
public:
    TStringRef Name() const {
        return TStringRef::Of("Vector");
    }

    void CleanupOnTerminate() const final {
    }

    void GetAllFunctions(IFunctionsSink& sink) const final {
        sink.Add(CreateName)->SetTypeAwareness();
        sink.Add(EmplaceName)->SetTypeAwareness();
        sink.Add(SwapName)->SetTypeAwareness();
        sink.Add(GetResultName)->SetTypeAwareness();
    }

    void BuildFunctionTypeInfo(
        const TStringRef& name,
        TType* userType,
        const TStringRef& typeConfig,
        ui32 flags,
        IFunctionTypeInfoBuilder& builder) const final
    {
        Y_UNUSED(typeConfig);

        try {
            const bool typesOnly = (flags & TFlags::TypesOnly);
            builder.UserType(userType);

            auto typeHelper = builder.TypeInfoHelper();

            auto userTypeInspector = TTupleTypeInspector(*typeHelper, userType);
            if (!userTypeInspector || userTypeInspector.GetElementsCount() != 3) {
                builder.SetError("User type is not a 3-tuple");
                return;
            }

            auto valueType = userTypeInspector.GetElementType(2);
            TType* vectorType = builder.Resource(VectorResourceName);

            if (name == CreateName) {
                builder.IsStrict();

                builder.Args()->Add<ui64>().Done().Returns(vectorType);

                if (!typesOnly) {
                    builder.Implementation(new TVectorCreate);
                }
            }

            if (name == EmplaceName) {
                builder.IsStrict();

                builder.Args()->Add(vectorType).Add<ui64>().Add(valueType).Done().Returns(vectorType);

                if (!typesOnly) {
                    builder.Implementation(new TVectorEmplace);
                }
            }

            if (name == SwapName) {
                builder.IsStrict();

                builder.Args()->Add(vectorType).Add<ui64>().Add<ui64>().Done().Returns(vectorType);

                if (!typesOnly) {
                    builder.Implementation(new TVectorSwap);
                }
            }

            if (name == GetResultName) {
                auto resultType = builder.List()->Item(valueType).Build();

                builder.IsStrict();

                builder.Args()->Add(vectorType).Done().Returns(resultType);

                if (!typesOnly) {
                    builder.Implementation(new TVectorGetResult);
                }
            }

        } catch (const std::exception& e) {
            builder.SetError(CurrentExceptionMessage());
        }
    }
};

} // namespace

// Registration entry point for this UDF library (macro from the YQL UDF
// helper headers); TVectorModule is the only module registered here.
REGISTER_MODULES(TVectorModule)