aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/udfs/common/vector/vector_udf.cpp
diff options
context:
space:
mode:
authorziganshinmr <ziganshinmr@yandex-team.com>2024-11-21 21:45:01 +0300
committerziganshinmr <ziganshinmr@yandex-team.com>2024-11-21 21:57:57 +0300
commitc320ff3884640f83278ad36e5feeed263b523bd4 (patch)
treee2377204a3b9b060188178a1de02641b0da04aa8 /yql/essentials/udfs/common/vector/vector_udf.cpp
parent00bc077e8f2272cd0206de2bca64c53300982883 (diff)
downloadydb-c320ff3884640f83278ad36e5feeed263b523bd4.tar.gz
ListSample/ListSampleN/ListShuffle implementation
commit_hash:987b10b398caa89eee8b94b33f9ea1dc74197223
Diffstat (limited to 'yql/essentials/udfs/common/vector/vector_udf.cpp')
-rw-r--r--yql/essentials/udfs/common/vector/vector_udf.cpp192
1 files changed, 192 insertions, 0 deletions
diff --git a/yql/essentials/udfs/common/vector/vector_udf.cpp b/yql/essentials/udfs/common/vector/vector_udf.cpp
new file mode 100644
index 0000000000..e8b01e5a05
--- /dev/null
+++ b/yql/essentials/udfs/common/vector/vector_udf.cpp
@@ -0,0 +1,192 @@
+#include <yql/essentials/public/udf/udf_type_ops.h>
+#include <yql/essentials/public/udf/udf_helpers.h>
+
+#include <vector>
+
+using namespace NKikimr;
+using namespace NUdf;
+
+namespace {
+
+class TVector {
+private:
+ std::vector<TUnboxedValue, TUnboxedValue::TAllocator> Vector;
+
+public:
+ TVector()
+ : Vector()
+ {}
+
+ TUnboxedValue GetResult(const IValueBuilder* builder) {
+ TUnboxedValue* values = nullptr;
+ auto list = builder->NewArray(Vector.size(), values);
+ std::copy(Vector.begin(), Vector.end(), values);
+
+ return list;
+ }
+
+ void Emplace(const ui64 index, const TUnboxedValuePod& value) {
+ if (index < Vector.size()) {
+ Vector[index] = value;
+ } else {
+ Vector.push_back(value);
+ }
+ }
+
+ void Swap(const ui64 a, const ui64 b) {
+ if (a < Vector.size() && b < Vector.size()) {
+ std::swap(Vector[a], Vector[b]);
+ }
+ }
+
+ void Reserve(ui64 expectedSize) {
+ Vector.reserve(expectedSize);
+ }
+};
+
+extern const char VectorResourceName[] = "Vector.VectorResource";
+class TVectorResource:
+ public TBoxedResource<TVector, VectorResourceName>
+{
+public:
+ template <typename... Args>
+ inline TVectorResource(Args&&... args)
+ : TBoxedResource(std::forward<Args>(args)...)
+ {}
+};
+
+TVectorResource* GetVectorResource(const TUnboxedValuePod& arg) {
+ TVectorResource::Validate(arg);
+ return static_cast<TVectorResource*>(arg.AsBoxed().Get());
+}
+
+class TVectorCreate: public TBoxedValue {
+private:
+ TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
+ auto resource = new TVectorResource;
+ resource->Get()->Reserve(args[0].Get<ui64>());
+ return TUnboxedValuePod(resource);
+ }
+};
+
+class TVectorEmplace: public TBoxedValue {
+private:
+ TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
+ auto resource = GetVectorResource(args[0]);
+ resource->Get()->Emplace(args[1].Get<ui64>(), args[2]);
+ return TUnboxedValuePod(resource);
+ }
+};
+
+class TVectorSwap: public TBoxedValue {
+private:
+ TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
+ auto resource = GetVectorResource(args[0]);
+ resource->Get()->Swap(args[1].Get<ui64>(), args[2].Get<ui64>());
+ return TUnboxedValuePod(resource);
+ }
+};
+
+class TVectorGetResult: public TBoxedValue {
+private:
+ TUnboxedValue Run(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const override {
+ return GetVectorResource(args[0])->Get()->GetResult(valueBuilder);
+ }
+};
+
+static const auto CreateName = TStringRef::Of("Create");
+static const auto EmplaceName = TStringRef::Of("Emplace");
+static const auto SwapName = TStringRef::Of("Swap");
+static const auto GetResultName = TStringRef::Of("GetResult");
+
+class TVectorModule: public IUdfModule {
+public:
+ TStringRef Name() const {
+ return TStringRef::Of("Vector");
+ }
+
+ void CleanupOnTerminate() const final {
+ }
+
+ void GetAllFunctions(IFunctionsSink& sink) const final {
+ sink.Add(CreateName)->SetTypeAwareness();
+ sink.Add(EmplaceName)->SetTypeAwareness();
+ sink.Add(SwapName)->SetTypeAwareness();
+ sink.Add(GetResultName)->SetTypeAwareness();
+ }
+
+ void BuildFunctionTypeInfo(
+ const TStringRef& name,
+ TType* userType,
+ const TStringRef& typeConfig,
+ ui32 flags,
+ IFunctionTypeInfoBuilder& builder) const final
+ {
+ Y_UNUSED(typeConfig);
+
+ try {
+ const bool typesOnly = (flags & TFlags::TypesOnly);
+ builder.UserType(userType);
+
+ auto typeHelper = builder.TypeInfoHelper();
+
+ auto userTypeInspector = TTupleTypeInspector(*typeHelper, userType);
+ if (!userTypeInspector || userTypeInspector.GetElementsCount() != 3) {
+ builder.SetError("User type is not a 3-tuple");
+ return;
+ }
+
+ auto valueType = userTypeInspector.GetElementType(2);
+ TType* vectorType = builder.Resource(VectorResourceName);
+
+ if (name == CreateName) {
+ builder.IsStrict();
+
+ builder.Args()->Add<ui64>().Done().Returns(vectorType);
+
+ if (!typesOnly) {
+ builder.Implementation(new TVectorCreate);
+ }
+ }
+
+ if (name == EmplaceName) {
+ builder.IsStrict();
+
+ builder.Args()->Add(vectorType).Add<ui64>().Add(valueType).Done().Returns(vectorType);
+
+ if (!typesOnly) {
+ builder.Implementation(new TVectorEmplace);
+ }
+ }
+
+ if (name == SwapName) {
+ builder.IsStrict();
+
+ builder.Args()->Add(vectorType).Add<ui64>().Add<ui64>().Done().Returns(vectorType);
+
+ if (!typesOnly) {
+ builder.Implementation(new TVectorSwap);
+ }
+ }
+
+ if (name == GetResultName) {
+ auto resultType = builder.List()->Item(valueType).Build();
+
+ builder.IsStrict();
+
+ builder.Args()->Add(vectorType).Done().Returns(resultType);
+
+ if (!typesOnly) {
+ builder.Implementation(new TVectorGetResult);
+ }
+ }
+
+ } catch (const std::exception& e) {
+ builder.SetError(CurrentExceptionMessage());
+ }
+ }
+};
+
+} // namespace
+
+REGISTER_MODULES(TVectorModule)