diff options
author | ziganshinmr <ziganshinmr@yandex-team.com> | 2024-11-21 21:45:01 +0300 |
---|---|---|
committer | ziganshinmr <ziganshinmr@yandex-team.com> | 2024-11-21 21:57:57 +0300 |
commit | c320ff3884640f83278ad36e5feeed263b523bd4 (patch) | |
tree | e2377204a3b9b060188178a1de02641b0da04aa8 /yql/essentials/udfs/common/vector/vector_udf.cpp | |
parent | 00bc077e8f2272cd0206de2bca64c53300982883 (diff) | |
download | ydb-c320ff3884640f83278ad36e5feeed263b523bd4.tar.gz |
ListSample/ListSampleN/ListShuffle implementation
commit_hash:987b10b398caa89eee8b94b33f9ea1dc74197223
Diffstat (limited to 'yql/essentials/udfs/common/vector/vector_udf.cpp')
-rw-r--r-- | yql/essentials/udfs/common/vector/vector_udf.cpp | 192 |
1 files changed, 192 insertions, 0 deletions
diff --git a/yql/essentials/udfs/common/vector/vector_udf.cpp b/yql/essentials/udfs/common/vector/vector_udf.cpp new file mode 100644 index 0000000000..e8b01e5a05 --- /dev/null +++ b/yql/essentials/udfs/common/vector/vector_udf.cpp @@ -0,0 +1,192 @@ +#include <yql/essentials/public/udf/udf_type_ops.h> +#include <yql/essentials/public/udf/udf_helpers.h> + +#include <vector> + +using namespace NKikimr; +using namespace NUdf; + +namespace { + +class TVector { +private: + std::vector<TUnboxedValue, TUnboxedValue::TAllocator> Vector; + +public: + TVector() + : Vector() + {} + + TUnboxedValue GetResult(const IValueBuilder* builder) { + TUnboxedValue* values = nullptr; + auto list = builder->NewArray(Vector.size(), values); + std::copy(Vector.begin(), Vector.end(), values); + + return list; + } + + void Emplace(const ui64 index, const TUnboxedValuePod& value) { + if (index < Vector.size()) { + Vector[index] = value; + } else { + Vector.push_back(value); + } + } + + void Swap(const ui64 a, const ui64 b) { + if (a < Vector.size() && b < Vector.size()) { + std::swap(Vector[a], Vector[b]); + } + } + + void Reserve(ui64 expectedSize) { + Vector.reserve(expectedSize); + } +}; + +extern const char VectorResourceName[] = "Vector.VectorResource"; +class TVectorResource: + public TBoxedResource<TVector, VectorResourceName> +{ +public: + template <typename... Args> + inline TVectorResource(Args&&... args) + : TBoxedResource(std::forward<Args>(args)...) + {} +}; + +TVectorResource* GetVectorResource(const TUnboxedValuePod& arg) { + TVectorResource::Validate(arg); + return static_cast<TVectorResource*>(arg.AsBoxed().Get()); +} + +class TVectorCreate: public TBoxedValue { +private: + TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override { + auto resource = new TVectorResource; + resource->Get()->Reserve(args[0].Get<ui64>()); + return TUnboxedValuePod(resource); + } +}; + +class TVectorEmplace: public TBoxedValue { +private: + TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override { + auto resource = GetVectorResource(args[0]); + resource->Get()->Emplace(args[1].Get<ui64>(), args[2]); + return TUnboxedValuePod(resource); + } +}; + +class TVectorSwap: public TBoxedValue { +private: + TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override { + auto resource = GetVectorResource(args[0]); + resource->Get()->Swap(args[1].Get<ui64>(), args[2].Get<ui64>()); + return TUnboxedValuePod(resource); + } +}; + +class TVectorGetResult: public TBoxedValue { +private: + TUnboxedValue Run(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const override { + return GetVectorResource(args[0])->Get()->GetResult(valueBuilder); + } +}; + +static const auto CreateName = TStringRef::Of("Create"); +static const auto EmplaceName = TStringRef::Of("Emplace"); +static const auto SwapName = TStringRef::Of("Swap"); +static const auto GetResultName = TStringRef::Of("GetResult"); + +class TVectorModule: public IUdfModule { +public: + TStringRef Name() const { + return TStringRef::Of("Vector"); + } + + void CleanupOnTerminate() const final { + } + + void GetAllFunctions(IFunctionsSink& sink) const final { + sink.Add(CreateName)->SetTypeAwareness(); + sink.Add(EmplaceName)->SetTypeAwareness(); + sink.Add(SwapName)->SetTypeAwareness(); + sink.Add(GetResultName)->SetTypeAwareness(); + } + + void BuildFunctionTypeInfo( + const TStringRef& name, + TType* userType, + const TStringRef& typeConfig, + ui32 flags, + IFunctionTypeInfoBuilder& builder) const final + { + Y_UNUSED(typeConfig); + + try { + const bool typesOnly = (flags & TFlags::TypesOnly); + builder.UserType(userType); + + auto typeHelper = builder.TypeInfoHelper(); + + auto userTypeInspector = TTupleTypeInspector(*typeHelper, userType); + if (!userTypeInspector || userTypeInspector.GetElementsCount() != 3) { + builder.SetError("User type is not a 3-tuple"); + return; + } + + auto valueType = userTypeInspector.GetElementType(2); + TType* vectorType = builder.Resource(VectorResourceName); + + if (name == CreateName) { + builder.IsStrict(); + + builder.Args()->Add<ui64>().Done().Returns(vectorType); + + if (!typesOnly) { + builder.Implementation(new TVectorCreate); + } + } + + if (name == EmplaceName) { + builder.IsStrict(); + + builder.Args()->Add(vectorType).Add<ui64>().Add(valueType).Done().Returns(vectorType); + + if (!typesOnly) { + builder.Implementation(new TVectorEmplace); + } + } + + if (name == SwapName) { + builder.IsStrict(); + + builder.Args()->Add(vectorType).Add<ui64>().Add<ui64>().Done().Returns(vectorType); + + if (!typesOnly) { + builder.Implementation(new TVectorSwap); + } + } + + if (name == GetResultName) { + auto resultType = builder.List()->Item(valueType).Build(); + + builder.IsStrict(); + + builder.Args()->Add(vectorType).Done().Returns(resultType); + + if (!typesOnly) { + builder.Implementation(new TVectorGetResult); + } + } + + } catch (const std::exception& e) { + builder.SetError(CurrentExceptionMessage()); + } + } +}; + +} // namespace + +REGISTER_MODULES(TVectorModule) |