diff options
| author | vitalyisaev <[email protected]> | 2023-11-14 09:58:56 +0300 |
|---|---|---|
| committer | vitalyisaev <[email protected]> | 2023-11-14 10:20:20 +0300 |
| commit | c2b2dfd9827a400a8495e172a56343462e3ceb82 (patch) | |
| tree | cd4e4f597d01bede4c82dffeb2d780d0a9046bd0 /contrib/clickhouse/src/Functions/FunctionsTextClassification.h | |
| parent | d4ae8f119e67808cb0cf776ba6e0cf95296f2df7 (diff) | |
YQ Connector: move tests from yql to ydb (OSS)
Перенос папки с тестами на Коннектор из папки yql в папку ydb (синхронизируется с github).
Diffstat (limited to 'contrib/clickhouse/src/Functions/FunctionsTextClassification.h')
| -rw-r--r-- | contrib/clickhouse/src/Functions/FunctionsTextClassification.h | 124 |
1 files changed, 124 insertions, 0 deletions
diff --git a/contrib/clickhouse/src/Functions/FunctionsTextClassification.h b/contrib/clickhouse/src/Functions/FunctionsTextClassification.h new file mode 100644 index 00000000000..8e0f236366d --- /dev/null +++ b/contrib/clickhouse/src/Functions/FunctionsTextClassification.h @@ -0,0 +1,124 @@ +#pragma once + +#include <Columns/ColumnString.h> +#include <Columns/ColumnVector.h> +#include <DataTypes/DataTypesNumber.h> +#include <Functions/FunctionHelpers.h> +#include <Functions/IFunction.h> +#include <Interpreters/Context_fwd.h> +#include <Functions/FunctionFactory.h> +#include <Interpreters/Context.h> + +namespace DB +{ +/// Functions for text classification with different result types + +namespace ErrorCodes +{ +extern const int ILLEGAL_TYPE_OF_ARGUMENT; +extern const int ILLEGAL_COLUMN; +extern const int SUPPORT_IS_DISABLED; +} + +template <typename Impl, typename Name> +class FunctionTextClassificationString : public IFunction +{ +public: + static constexpr auto name = Name::name; + + static FunctionPtr create(ContextPtr context) + { + if (!context->getSettingsRef().allow_experimental_nlp_functions) + throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, + "Natural language processing function '{}' is experimental. " + "Set `allow_experimental_nlp_functions` setting to enable it", name); + + return std::make_shared<FunctionTextClassificationString>(); + } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 1; } + + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } + + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if (!isString(arguments[0])) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} of argument of function {}. Must be String.", + arguments[0]->getName(), getName()); + + return arguments[0]; + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t /*input_rows_count*/) const override + { + const ColumnPtr & column = arguments[0].column; + const ColumnString * col = checkAndGetColumn<ColumnString>(column.get()); + + if (!col) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of argument of function {}", + arguments[0].column->getName(), getName()); + + auto col_res = ColumnString::create(); + Impl::vector(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets()); + return col_res; + } +}; + +template <typename Impl, typename Name> +class FunctionTextClassificationFloat : public IFunction +{ +public: + static constexpr auto name = Name::name; + + static FunctionPtr create(ContextPtr context) + { + if (!context->getSettingsRef().allow_experimental_nlp_functions) + throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, + "Natural language processing function '{}' is experimental. " + "Set `allow_experimental_nlp_functions` setting to enable it", name); + + return std::make_shared<FunctionTextClassificationFloat>(); + } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 1; } + + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } + + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if (!isString(arguments[0])) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} of argument of function {}. Must be String.", + arguments[0]->getName(), getName()); + + return std::make_shared<DataTypeFloat32>(); + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t /*input_rows_count*/) const override + { + const ColumnPtr & column = arguments[0].column; + const ColumnString * col = checkAndGetColumn<ColumnString>(column.get()); + + if (!col) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of argument of function {}", + arguments[0].column->getName(), getName()); + + auto col_res = ColumnVector<Float32>::create(); + ColumnVector<Float32>::Container & vec_res = col_res->getData(); + vec_res.resize(col->size()); + + Impl::vector(col->getChars(), col->getOffsets(), vec_res); + return col_res; + } +}; + +} |
