From e46bed95ee43ea70afccfa413ea7e9f9e088cc33 Mon Sep 17 00:00:00 2001
From: atarasov5 <atarasov5@yandex-team.com>
Date: Mon, 10 Feb 2025 13:41:59 +0300
Subject: YQL-19535: Provide block implementations for some functions

YQL-19535: Provide block operations

YQL-19535: Specify tests for blocked operations
commit_hash:032aa58fc3f44f0eba3d9b38def021178da949ce
---
 .../common/unicode_base/lib/unicode_base_udf.h     | 269 ++++++++++++++-------
 1 file changed, 177 insertions(+), 92 deletions(-)

(limited to 'yql/essentials/udfs/common/unicode_base/lib/unicode_base_udf.h')

diff --git a/yql/essentials/udfs/common/unicode_base/lib/unicode_base_udf.h b/yql/essentials/udfs/common/unicode_base/lib/unicode_base_udf.h
index 4a852a5a6f..a16582fb4e 100644
--- a/yql/essentials/udfs/common/unicode_base/lib/unicode_base_udf.h
+++ b/yql/essentials/udfs/common/unicode_base/lib/unicode_base_udf.h
@@ -3,6 +3,7 @@
 #include <yql/essentials/public/udf/udf_allocator.h>
 #include <yql/essentials/public/udf/udf_helpers.h>
 #include <yql/essentials/utils/utf8.h>
+#include <yql/essentials/public/udf/arrow/udf_arrow_helpers.h>
 
 #include <library/cpp/string_utils/levenshtein_diff/levenshtein_diff.h>
 #include <library/cpp/unicode/normalization/normalization.h>
@@ -24,6 +25,9 @@ using namespace NUdf;
 using namespace NUnicode;
 
 namespace {
+    inline constexpr bool IsAscii(wchar32 c) noexcept {
+        return ::IsAscii(c);
+    }
 
     template <class It>
     struct TIsUnicodeSpaceAdapter {
@@ -37,51 +41,144 @@ namespace {
         return {};
     }
 
-#define NORMALIZE_UDF_MAP(XX) \
-    XX(Normalize, NFC)        \
-    XX(NormalizeNFD, NFD)     \
-    XX(NormalizeNFC, NFC)     \
-    XX(NormalizeNFKD, NFKD)   \
-    XX(NormalizeNFKC, NFKC)
-
-#define IS_CATEGORY_UDF_MAP(XX) \
-    XX(IsAscii, IsAscii)   \
-    XX(IsSpace, IsSpace)        \
-    XX(IsUpper, IsUpper)        \
-    XX(IsLower, IsLower)        \
-    XX(IsDigit, IsDigit)        \
-    XX(IsAlpha, IsAlpha)        \
-    XX(IsAlnum, IsAlnum)        \
-    XX(IsHex, IsHexdigit)
-
-#define NORMALIZE_UDF(name, mode)                                                 \
-    SIMPLE_UDF(T##name, TUtf8(TAutoMap<TUtf8>)) {                                 \
-        const auto& inputRef = args[0].AsStringRef();                             \
-        const TUtf16String& input = UTF8ToWide(inputRef.Data(), inputRef.Size()); \
-        const TString& output = WideToUTF8(Normalize<mode>(input));               \
-        return valueBuilder->NewString(output);                                   \
-    }
+    struct TNoChangesTag {};
 
-#define IS_CATEGORY_UDF(udfName, function)                                                \
-    SIMPLE_UDF(T##udfName, bool(TAutoMap<TUtf8>)) {                                       \
-        Y_UNUSED(valueBuilder);                                                           \
-        const TStringBuf input(args[0].AsStringRef());                                    \
-        bool result = true;                                                               \
-        wchar32 rune;                                                                     \
-        const unsigned char* cur = reinterpret_cast<const unsigned char*>(input.begin()); \
-        const unsigned char* last = reinterpret_cast<const unsigned char*>(input.end());  \
-        while (cur != last) {                                                             \
-            ReadUTF8CharAndAdvance(rune, cur, last);                                      \
-            if (!function(rune)) {                                                        \
-                result = false;                                                           \
-                break;                                                                    \
-            }                                                                             \
-        }                                                                                 \
-        return TUnboxedValuePod(result);                                                  \
-    }
+    template <typename TDerived>
+    struct TScalarOperationMixin {
+        static TUnboxedValue DoExecute(const IValueBuilder* builder, const TUnboxedValuePod* args) {
+            Y_DEBUG_ABORT_UNLESS(IsUtf8(args[0].AsStringRef()));
+            auto&& executeResult = TDerived::Execute(args[0].AsStringRef());
+            return ProcessResult(builder, std::move(executeResult), args);
+        }
+
+    private:
+        static TUnboxedValue ProcessResult(const IValueBuilder* builder, TString&& newString, const TUnboxedValuePod*) {
+            return builder->NewString(std::move(newString));
+        }
+
+        template <typename T>
+        static TUnboxedValue ProcessResult(const IValueBuilder* builder, std::variant<TNoChangesTag, T> newValue, const TUnboxedValuePod* initialArg) {
+            if (std::holds_alternative<T>(newValue)) {
+                return ProcessResult(builder, std::move(std::get<T>(newValue)), initialArg);
+            } else {
+                return initialArg[0];
+            }
+        }
+
+        static TUnboxedValue ProcessResult(const IValueBuilder* builder, bool result, const TUnboxedValuePod*) {
+            Y_UNUSED(builder);
+            return TUnboxedValuePod(result);
+        }
+    };
+
+    template <typename TDerived>
+    struct TBlockOperationMixin {
+        template <typename Sync>
+        static void DoExecute(const TBlockItem arg, const Sync& sync) {
+            Y_DEBUG_ABORT_UNLESS(IsUtf8(arg.AsStringRef()));
+            auto&& executeResult = TDerived::Execute(arg.AsStringRef());
+            TBlockItem boxedValue = ProcessResult(std::move(executeResult), arg);
+            sync(boxedValue);
+        }
+
+    private:
+        static TBlockItem ProcessResult(const TString& newString, const TBlockItem arg) {
+            Y_UNUSED(arg);
+            return TBlockItem(std::move(newString));
+        }
+
+        template <typename T>
+        static TBlockItem ProcessResult(const std::variant<TNoChangesTag, T>& newValue, const TBlockItem arg) {
+            if (std::holds_alternative<T>(newValue)) {
+                return ProcessResult(std::get<T>(newValue), arg);
+            } else {
+                return arg;
+            }
+        }
+
+        static TBlockItem ProcessResult(bool result, const TBlockItem arg) {
+            Y_UNUSED(arg);
+            return TBlockItem(result);
+        }
+    };
 
-    NORMALIZE_UDF_MAP(NORMALIZE_UDF)
-    IS_CATEGORY_UDF_MAP(IS_CATEGORY_UDF)
+    template <typename TDerived>
+    struct TOperationMixin: public TBlockOperationMixin<TDerived>, public TScalarOperationMixin<TDerived> {
+        using TBlockOperationMixin<TDerived>::DoExecute;
+        using TScalarOperationMixin<TDerived>::DoExecute;
+    };
+
+    template <auto mode>
+    struct TNormalizeUTF8: public TOperationMixin<TNormalizeUTF8<mode>> {
+        static TString Execute(TStringRef arg) {
+            const TUtf16String& input = UTF8ToWide(arg.Data(), arg.Size());
+            return WideToUTF8(Normalize<mode>(input));
+        }
+    };
+
+    template <bool (*Function)(wchar32)>
+    struct TCheckAllChars: public TOperationMixin<TCheckAllChars<Function>> {
+        static bool Execute(TStringRef arg) {
+            const TStringBuf input(arg);
+            wchar32 rune;
+            const unsigned char* cur = reinterpret_cast<const unsigned char*>(input.begin());
+            const unsigned char* last = reinterpret_cast<const unsigned char*>(input.end());
+            while (cur != last) {
+                ReadUTF8CharAndAdvance(rune, cur, last);
+                if (!static_cast<bool (*)(wchar32)>(Function)(rune)) {
+                    return false;
+                }
+            }
+            return true;
+        }
+    };
+
+    template <bool (*Function)(TUtf16String&, size_t pos, size_t count)>
+    struct TStringToStringMapper: public TOperationMixin<TStringToStringMapper<Function>> {
+        static std::variant<TNoChangesTag, TString> Execute(TStringRef arg) {
+            if (auto wide = UTF8ToWide(arg);
+                static_cast<bool (*)(TUtf16String&, size_t pos, size_t count)>(Function)(wide, 0, TUtf16String::npos)) {
+                return WideToUTF8(std::move(wide));
+            } else {
+                return TNoChangesTag{};
+            }
+        }
+    };
+
+#define DEFINE_UTF8_OPERATION(udfName, Executor, signature)                                          \
+    BEGIN_SIMPLE_STRICT_ARROW_UDF(T##udfName, signature) {                                           \
+        return Executor::DoExecute(valueBuilder, args);                                              \
+    }                                                                                                \
+                                                                                                     \
+    struct T##udfName##KernelExec                                                                    \
+        : public TUnaryKernelExec<T##udfName##KernelExec> {                                          \
+        template <typename TSink>                                                                    \
+        static void Process(const IValueBuilder* valueBuilder, TBlockItem arg1, const TSink& sink) { \
+            Y_UNUSED(valueBuilder);                                                                  \
+            Executor::DoExecute(arg1, sink);                                                         \
+        }                                                                                            \
+    };                                                                                               \
+                                                                                                     \
+    END_SIMPLE_ARROW_UDF(T##udfName, T##udfName##KernelExec::Do)
+
+    DEFINE_UTF8_OPERATION(Normalize, TNormalizeUTF8<NFC>, TUtf8(TAutoMap<TUtf8>));
+    DEFINE_UTF8_OPERATION(NormalizeNFD, TNormalizeUTF8<NFD>, TUtf8(TAutoMap<TUtf8>));
+    DEFINE_UTF8_OPERATION(NormalizeNFC, TNormalizeUTF8<NFC>, TUtf8(TAutoMap<TUtf8>));
+    DEFINE_UTF8_OPERATION(NormalizeNFKD, TNormalizeUTF8<NFKD>, TUtf8(TAutoMap<TUtf8>));
+    DEFINE_UTF8_OPERATION(NormalizeNFKC, TNormalizeUTF8<NFKC>, TUtf8(TAutoMap<TUtf8>));
+
+    DEFINE_UTF8_OPERATION(IsAscii, TCheckAllChars<IsAscii>, bool(TAutoMap<TUtf8>));
+    DEFINE_UTF8_OPERATION(IsSpace, TCheckAllChars<IsSpace>, bool(TAutoMap<TUtf8>));
+    DEFINE_UTF8_OPERATION(IsUpper, TCheckAllChars<IsUpper>, bool(TAutoMap<TUtf8>));
+    DEFINE_UTF8_OPERATION(IsLower, TCheckAllChars<IsLower>, bool(TAutoMap<TUtf8>));
+    DEFINE_UTF8_OPERATION(IsDigit, TCheckAllChars<IsDigit>, bool(TAutoMap<TUtf8>));
+    DEFINE_UTF8_OPERATION(IsAlpha, TCheckAllChars<IsAlpha>, bool(TAutoMap<TUtf8>));
+    DEFINE_UTF8_OPERATION(IsAlnum, TCheckAllChars<IsAlnum>, bool(TAutoMap<TUtf8>));
+    DEFINE_UTF8_OPERATION(IsHex, TCheckAllChars<IsHexdigit>, bool(TAutoMap<TUtf8>));
+
+    DEFINE_UTF8_OPERATION(ToTitle, TStringToStringMapper<ToTitle>, TUtf8(TAutoMap<TUtf8>));
+    DEFINE_UTF8_OPERATION(ToUpper, TStringToStringMapper<ToUpper>, TUtf8(TAutoMap<TUtf8>));
+    DEFINE_UTF8_OPERATION(ToLower, TStringToStringMapper<ToLower>, TUtf8(TAutoMap<TUtf8>));
 
     SIMPLE_UDF(TIsUtf, bool(TOptional<char*>)) {
         Y_UNUSED(valueBuilder);
@@ -461,27 +558,6 @@ namespace {
         return valueBuilder->NewString(WideToUTF8(wide));
     }
 
-    SIMPLE_UDF(TToLower, TUtf8(TAutoMap<TUtf8>)) {
-        if (auto wide = UTF8ToWide(args->AsStringRef()); ToLower(wide))
-            return valueBuilder->NewString(WideToUTF8(wide));
-        else
-            return *args;
-    }
-
-    SIMPLE_UDF(TToUpper, TUtf8(TAutoMap<TUtf8>)) {
-        if (auto wide = UTF8ToWide(args->AsStringRef()); ToUpper(wide))
-            return valueBuilder->NewString(WideToUTF8(wide));
-        else
-            return *args;
-    }
-
-    SIMPLE_UDF(TToTitle, TUtf8(TAutoMap<TUtf8>)) {
-        if (auto wide = UTF8ToWide(args->AsStringRef()); ToTitle(wide))
-            return valueBuilder->NewString(WideToUTF8(wide));
-        else
-            return *args;
-    }
-
     SIMPLE_UDF(TStrip, TUtf8(TAutoMap<TUtf8>)) {
         const TUtf32String input = UTF8ToUTF32<true>(args[0].AsStringRef());
         const auto& result = StripString(input, IsUnicodeSpaceAdapter(input.begin()));
@@ -512,33 +588,42 @@ namespace {
         return TUnboxedValuePod(result);
     }
 
-#define REGISTER_NORMALIZE_UDF(name, mode) T##name,
-#define REGISTER_IS_CATEGORY_UDF(name, function) T##name,
 #define EXPORTED_UNICODE_BASE_UDF \
-    NORMALIZE_UDF_MAP(REGISTER_NORMALIZE_UDF) \
-    IS_CATEGORY_UDF_MAP(REGISTER_IS_CATEGORY_UDF) \
-    TIsUtf, \
-    TGetLength, \
-    TSubstring, \
-    TFind, \
-    TRFind, \
-    TSplitToList, \
-    TJoinFromList, \
-    TLevensteinDistance, \
-    TReplaceAll, \
-    TReplaceFirst, \
-    TReplaceLast, \
-    TRemoveAll, \
-    TRemoveFirst, \
-    TRemoveLast, \
-    TToCodePointList, \
-    TFromCodePointList, \
-    TReverse, \
-    TToLower, \
-    TToUpper, \
-    TToTitle, \
-    TToUint64, \
-    TTryToUint64, \
-    TStrip, \
-    TIsUnicodeSet
+        TIsUtf,                   \
+        TGetLength,               \
+        TSubstring,               \
+        TFind,                    \
+        TRFind,                   \
+        TSplitToList,             \
+        TJoinFromList,            \
+        TLevensteinDistance,      \
+        TReplaceAll,              \
+        TReplaceFirst,            \
+        TReplaceLast,             \
+        TRemoveAll,               \
+        TRemoveFirst,             \
+        TRemoveLast,              \
+        TToCodePointList,         \
+        TFromCodePointList,       \
+        TReverse,                 \
+        TToLower,                 \
+        TToUpper,                 \
+        TToTitle,                 \
+        TToUint64,                \
+        TTryToUint64,             \
+        TStrip,                   \
+        TIsUnicodeSet,            \
+        TNormalize,               \
+        TNormalizeNFD,            \
+        TNormalizeNFC,            \
+        TNormalizeNFKD,           \
+        TNormalizeNFKC,           \
+        TIsAscii,                 \
+        TIsSpace,                 \
+        TIsUpper,                 \
+        TIsLower,                 \
+        TIsDigit,                 \
+        TIsAlpha,                 \
+        TIsAlnum,                 \
+        TIsHex
 }
-- 
cgit v1.2.3