summaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp
diff options
context:
space:
mode:
authorvitalyisaev <[email protected]>2023-11-14 09:58:56 +0300
committervitalyisaev <[email protected]>2023-11-14 10:20:20 +0300
commitc2b2dfd9827a400a8495e172a56343462e3ceb82 (patch)
treecd4e4f597d01bede4c82dffeb2d780d0a9046bd0 /contrib/clickhouse/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp
parentd4ae8f119e67808cb0cf776ba6e0cf95296f2df7 (diff)
YQ Connector: move tests from yql to ydb (OSS)
Перенос папки с тестами на Коннектор из папки yql в папку ydb (синхронизируется с github).
Diffstat (limited to 'contrib/clickhouse/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp')
-rw-r--r--contrib/clickhouse/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp141
1 files changed, 141 insertions, 0 deletions
diff --git a/contrib/clickhouse/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp b/contrib/clickhouse/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp
new file mode 100644
index 00000000000..7bb19b13ca0
--- /dev/null
+++ b/contrib/clickhouse/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp
@@ -0,0 +1,141 @@
+#include <AggregateFunctions/AggregateFunctionFactory.h>
+#include <AggregateFunctions/AggregateFunctionSequenceNextNode.h>
+#include <AggregateFunctions/Helpers.h>
+#include <AggregateFunctions/FactoryHelpers.h>
+#include <Core/Settings.h>
+#include <DataTypes/DataTypeDate.h>
+#include <DataTypes/DataTypeDateTime.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <Interpreters/Context.h>
+#include <Common/CurrentThread.h>
+#include <base/range.h>
+
+
+namespace DB
+{
+
+constexpr size_t max_events_size = 64;
+
+constexpr size_t min_required_args = 3;
+
+namespace ErrorCodes
+{
+ extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+ extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+ extern const int BAD_ARGUMENTS;
+ extern const int UNKNOWN_AGGREGATE_FUNCTION;
+}
+
+namespace
+{
+
+template <typename T>
+inline AggregateFunctionPtr createAggregateFunctionSequenceNodeImpl(
+ const DataTypePtr data_type, const DataTypes & argument_types, const Array & parameters, SequenceDirection direction, SequenceBase base)
+{
+ return std::make_shared<SequenceNextNodeImpl<T, NodeString<max_events_size>>>(
+ data_type, argument_types, parameters, base, direction, min_required_args);
+}
+
+AggregateFunctionPtr
+createAggregateFunctionSequenceNode(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
+{
+ if (settings == nullptr || !settings->allow_experimental_funnel_functions)
+ {
+ throw Exception(ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION, "Aggregate function {} is experimental. "
+ "Set `allow_experimental_funnel_functions` setting to enable it", name);
+ }
+
+ if (parameters.size() < 2)
+ throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
+ "Aggregate function '{}' requires 2 parameters (direction, head)", name);
+ auto expected_param_type = Field::Types::Which::String;
+ if (parameters.at(0).getType() != expected_param_type || parameters.at(1).getType() != expected_param_type)
+ throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function '{}' requires 'String' parameters", name);
+
+ String param_dir = parameters.at(0).safeGet<String>();
+ std::unordered_map<std::string, SequenceDirection> seq_dir_mapping{
+ {"forward", SequenceDirection::Forward},
+ {"backward", SequenceDirection::Backward},
+ };
+ if (!seq_dir_mapping.contains(param_dir))
+ throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} doesn't support a parameter: {}", name, param_dir);
+ SequenceDirection direction = seq_dir_mapping[param_dir];
+
+ String param_base = parameters.at(1).safeGet<String>();
+ std::unordered_map<std::string, SequenceBase> seq_base_mapping{
+ {"head", SequenceBase::Head},
+ {"tail", SequenceBase::Tail},
+ {"first_match", SequenceBase::FirstMatch},
+ {"last_match", SequenceBase::LastMatch},
+ };
+ if (!seq_base_mapping.contains(param_base))
+ throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} doesn't support a parameter: {}", name, param_base);
+ SequenceBase base = seq_base_mapping[param_base];
+
+ if ((base == SequenceBase::Head && direction == SequenceDirection::Backward) ||
+ (base == SequenceBase::Tail && direction == SequenceDirection::Forward))
+ throw Exception(ErrorCodes::BAD_ARGUMENTS, "Invalid argument combination of '{}' with '{}'", param_base, param_dir);
+
+ if (argument_types.size() < min_required_args)
+ throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
+ "Aggregate function {} requires at least {} arguments.", name, toString(min_required_args));
+
+ bool is_base_match_type = base == SequenceBase::FirstMatch || base == SequenceBase::LastMatch;
+ if (is_base_match_type && argument_types.size() < min_required_args + 1)
+ throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
+ "Aggregate function {} requires at least {} arguments when base is first_match or last_match.",
+ name, toString(min_required_args + 1));
+
+ if (argument_types.size() > max_events_size + min_required_args)
+ throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
+ "Aggregate function '{}' requires at most {} (timestamp, value_column, ...{} events) arguments.",
+ name, max_events_size + min_required_args, max_events_size);
+
+ if (const auto * cond_arg = argument_types[2].get(); cond_arg && !isUInt8(cond_arg))
+ throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of third argument of aggregate function {}, "
+ "must be UInt8", cond_arg->getName(), name);
+
+ for (const auto i : collections::range(min_required_args, argument_types.size()))
+ {
+ const auto * cond_arg = argument_types[i].get();
+ if (!isUInt8(cond_arg))
+ throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
+ "Illegal type '{}' of {} argument of aggregate function '{}', must be UInt8", cond_arg->getName(), i + 1, name);
+ }
+
+ if (WhichDataType(argument_types[1].get()).idx != TypeIndex::String)
+ throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
+ "Illegal type {} of second argument of aggregate function {}, must be String",
+ argument_types[1].get()->getName(), name);
+
+ DataTypePtr data_type = makeNullable(argument_types[1]);
+
+ WhichDataType timestamp_type(argument_types[0].get());
+ if (timestamp_type.idx == TypeIndex::UInt8)
+ return createAggregateFunctionSequenceNodeImpl<UInt8>(data_type, argument_types, parameters, direction, base);
+ if (timestamp_type.idx == TypeIndex::UInt16)
+ return createAggregateFunctionSequenceNodeImpl<UInt16>(data_type, argument_types, parameters, direction, base);
+ if (timestamp_type.idx == TypeIndex::UInt32)
+ return createAggregateFunctionSequenceNodeImpl<UInt32>(data_type, argument_types, parameters, direction, base);
+ if (timestamp_type.idx == TypeIndex::UInt64)
+ return createAggregateFunctionSequenceNodeImpl<UInt64>(data_type, argument_types, parameters, direction, base);
+ if (timestamp_type.isDate())
+ return createAggregateFunctionSequenceNodeImpl<DataTypeDate::FieldType>(data_type, argument_types, parameters, direction, base);
+ if (timestamp_type.isDateTime())
+ return createAggregateFunctionSequenceNodeImpl<DataTypeDateTime::FieldType>(data_type, argument_types, parameters, direction, base);
+
+ throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
+ "Illegal type {} of first argument of aggregate function {}, must "
+ "be Unsigned Number, Date, DateTime", argument_types.front().get()->getName(), name);
+}
+
+}
+
+void registerAggregateFunctionSequenceNextNode(AggregateFunctionFactory & factory)
+{
+ AggregateFunctionProperties properties = { .returns_default_when_only_null = true, .is_order_dependent = false };
+ factory.registerFunction("sequenceNextNode", { createAggregateFunctionSequenceNode, properties });
+}
+
+}