diff options
author | aidarsamer <aidarsamer@yandex-team.ru> | 2022-05-03 22:05:20 +0300 |
---|---|---|
committer | aidarsamer <aidarsamer@yandex-team.ru> | 2022-05-03 22:05:20 +0300 |
commit | f35dbbdae6519b9ab9cabb2250dd4d734c12035b (patch) | |
tree | 1dbb5aa57af6a224c14cf771c17a48261f513e52 | |
parent | fe43d898800daf570d5e787fd9a54732712990f6 (diff) | |
download | ydb-f35dbbdae6519b9ab9cabb2250dd4d734c12035b.tar.gz |
KIKIMR-13107. Cast pushdown to OLAP.
KIKIMR-13107. Add cast pushdown to OLAP. Fix tuple comparisons conversion.
ref:3330b9b0d89d17bf7778b6bd0d566408d08c64b2
-rw-r--r-- | ydb/core/driver_lib/run/run.cpp | 2 | ||||
-rw-r--r-- | ydb/core/formats/CMakeLists.txt | 2 | ||||
-rw-r--r-- | ydb/core/formats/custom_registry.cpp | 20 | ||||
-rw-r--r-- | ydb/core/formats/custom_registry.h | 10 | ||||
-rw-r--r-- | ydb/core/formats/execs.h | 7 | ||||
-rw-r--r-- | ydb/core/formats/func_cast.cpp | 150 | ||||
-rw-r--r-- | ydb/core/formats/func_cast.h | 22 | ||||
-rw-r--r-- | ydb/core/formats/func_gcd.h | 17 | ||||
-rw-r--r-- | ydb/core/formats/func_math.h | 3 | ||||
-rw-r--r-- | ydb/core/formats/func_round.h | 16 | ||||
-rw-r--r-- | ydb/core/formats/functions.h | 1 | ||||
-rw-r--r-- | ydb/core/formats/program.cpp | 35 | ||||
-rw-r--r-- | ydb/core/formats/program.h | 20 | ||||
-rw-r--r-- | ydb/core/kqp/compile/kqp_olap_compiler.cpp | 51 | ||||
-rw-r--r-- | ydb/core/kqp/opt/physical/kqp_opt_phy_olap_filter.cpp | 232 | ||||
-rw-r--r-- | ydb/core/kqp/prepare/kqp_type_ann.cpp | 5 | ||||
-rw-r--r-- | ydb/core/kqp/ut/kqp_olap_ut.cpp | 28 | ||||
-rw-r--r-- | ydb/core/protos/ssa.proto | 14 | ||||
-rw-r--r-- | ydb/core/tx/columnshard/columnshard__stats_scan.h | 3 | ||||
-rw-r--r-- | ydb/core/tx/columnshard/columnshard_common.cpp | 25 | ||||
-rw-r--r-- | ydb/core/tx/columnshard/engines/indexed_read_data.cpp | 3 |
21 files changed, 555 insertions, 111 deletions
diff --git a/ydb/core/driver_lib/run/run.cpp b/ydb/core/driver_lib/run/run.cpp index 5e8a9f7ffe..d39be96ca5 100644 --- a/ydb/core/driver_lib/run/run.cpp +++ b/ydb/core/driver_lib/run/run.cpp @@ -934,7 +934,7 @@ void TKikimrRunner::InitializeAppData(const TKikimrRunConfig& runConfig) if (runConfig.AppConfig.GetBootstrapConfig().HasEnableIntrospection()) AppData->EnableIntrospection = runConfig.AppConfig.GetBootstrapConfig().GetEnableIntrospection(); - + TAppDataInitializersList appDataInitializers; // setup domain info appDataInitializers.AddAppDataInitializer(new TDomainsInitializer(runConfig)); diff --git a/ydb/core/formats/CMakeLists.txt b/ydb/core/formats/CMakeLists.txt index 0338d2a440..f61a7c54d1 100644 --- a/ydb/core/formats/CMakeLists.txt +++ b/ydb/core/formats/CMakeLists.txt @@ -18,6 +18,8 @@ target_sources(ydb-core-formats PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/formats/arrow_batch_builder.cpp ${CMAKE_SOURCE_DIR}/ydb/core/formats/arrow_helpers.cpp ${CMAKE_SOURCE_DIR}/ydb/core/formats/clickhouse_block.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/formats/custom_registry.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/formats/func_cast.cpp ${CMAKE_SOURCE_DIR}/ydb/core/formats/merging_sorted_input_stream.cpp ${CMAKE_SOURCE_DIR}/ydb/core/formats/program.cpp ) diff --git a/ydb/core/formats/custom_registry.cpp b/ydb/core/formats/custom_registry.cpp index 404d01d6a9..347e833f9f 100644 --- a/ydb/core/formats/custom_registry.cpp +++ b/ydb/core/formats/custom_registry.cpp @@ -1,7 +1,9 @@ +#include "custom_registry.h" + #include "functions.h" #include "func_common.h" #include <util/system/yassert.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.cc> +#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/registry_internal.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h> namespace cp = ::arrow::compute; @@ -18,7 +20,10 @@ static void RegisterMath(cp::FunctionRegistry* registry) { Y_VERIFY(registry->AddFunction(MakeMathUnary<TErfc>(TErfc::Name)).ok()); Y_VERIFY(registry->AddFunction(MakeMathUnary<TExp>(TExp::Name)).ok()); Y_VERIFY(registry->AddFunction(MakeMathUnary<TExp2>(TExp2::Name)).ok()); + // Temporarily disabled because of compilation error on Windows. +#if 0 Y_VERIFY(registry->AddFunction(MakeMathUnary<TExp10>(TExp10::Name)).ok()); +#endif Y_VERIFY(registry->AddFunction(MakeMathBinary<THypot>(THypot::Name)).ok()); Y_VERIFY(registry->AddFunction(MakeMathUnary<TLgamma>(TLgamma::Name)).ok()); Y_VERIFY(registry->AddFunction(MakeConstNullary<TPi>(TPi::Name)).ok()); @@ -40,24 +45,31 @@ static void RegisterArithmetic(cp::FunctionRegistry* registry) { Y_VERIFY(registry->AddFunction(MakeArithmeticBinary<TModuloOrZero>(TModuloOrZero::Name)).ok()); } +static void RegisterYdbCast(cp::FunctionRegistry* registry) { + cp::internal::RegisterScalarCast(registry); + Y_VERIFY(registry->AddFunction(std::make_shared<YdbCastMetaFunction>()).ok()); +} + static std::unique_ptr<cp::FunctionRegistry> CreateCustomRegistry() { auto registry = cp::FunctionRegistry::Make(); RegisterMath(registry.get()); RegisterRound(registry.get()); RegisterArithmetic(registry.get()); - cp::internal::RegisterScalarCast(registry.get()); + RegisterYdbCast(registry.get()); return registry; } +// Creates singleton custom registry cp::FunctionRegistry* GetCustomFunctionRegistry() { static auto g_registry = CreateCustomRegistry(); return g_registry.get(); } +// We want to have ExecContext per thread. All these context use one custom registry. cp::ExecContext* GetCustomExecContext() { - static auto context = std::make_unique<cp::ExecContext>(arrow::default_memory_pool(), NULLPTR, GetCustomFunctionRegistry()); - return context.get(); + static thread_local cp::ExecContext context(arrow::default_memory_pool(), nullptr, GetCustomFunctionRegistry()); + return &context; } } diff --git a/ydb/core/formats/custom_registry.h b/ydb/core/formats/custom_registry.h index 8442e44eee..77f419d33d 100644 --- a/ydb/core/formats/custom_registry.h +++ b/ydb/core/formats/custom_registry.h @@ -1,9 +1,11 @@ #pragma once -#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h> -namespace cp = ::arrow::compute; +namespace arrow::compute { + class FunctionRegistry; + class ExecContext; +} namespace NKikimr::NArrow { - cp::FunctionRegistry* GetCustomFunctionRegistry(); - cp::ExecContext* GetCustomExecContext(); + arrow::compute::FunctionRegistry* GetCustomFunctionRegistry(); + arrow::compute::ExecContext* GetCustomExecContext(); } diff --git a/ydb/core/formats/execs.h b/ydb/core/formats/execs.h index bf98b70927..253bae4831 100644 --- a/ydb/core/formats/execs.h +++ b/ydb/core/formats/execs.h @@ -5,7 +5,12 @@ #include <contrib/libs/apache/arrow/cpp/src/arrow/type_fwd.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/type_traits.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/builder.h> + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunused-parameter" #include <contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h> +#pragma clang diagnostic pop + #include <contrib/libs/apache/arrow/cpp/src/arrow/compute/function.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.h> #include <util/datetime/base.h> @@ -17,8 +22,6 @@ #include "switch_type.h" namespace cp = arrow::compute; -using cp::internal::applicator::ScalarBinary; -using cp::internal::applicator::ScalarUnary; namespace NKikimr::NArrow { diff --git a/ydb/core/formats/func_cast.cpp b/ydb/core/formats/func_cast.cpp new file mode 100644 index 0000000000..7d480bdac0 --- /dev/null +++ b/ydb/core/formats/func_cast.cpp @@ -0,0 +1,150 @@ +#include "func_cast.h" + +#include <util/system/yassert.h> + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunused-parameter" +#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/kernel.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/util/time.h> +#pragma clang diagnostic pop + +#include <unordered_map> + +namespace cp = ::arrow::compute; + +namespace NKikimr::NArrow { + +namespace { + +std::shared_ptr<cp::CastFunction> GetYdbTimestampCast() { + auto func = std::make_shared<cp::CastFunction>("ydb_cast_timestamp", ::arrow::Type::TIMESTAMP); + cp::internal::AddSimpleCast<arrow::UInt32Type, arrow::TimestampType>( + cp::InputType(arrow::Type::UINT32), + cp::internal::kOutputTargetType, + func.get() + ); + return func; +} + +std::vector<std::shared_ptr<cp::CastFunction>> GetYdbTemporalCasts() { + std::vector<std::shared_ptr<cp::CastFunction>> functions; + functions.push_back(GetYdbTimestampCast()); + return functions; +} + +std::unordered_map<int, std::shared_ptr<cp::CastFunction>> ydbCastTable; +std::once_flag ydbCastTableInitialized; + +void AddCastFunctions(const std::vector<std::shared_ptr<cp::CastFunction>>& funcs) { + for (const auto& func : funcs) { + ydbCastTable[static_cast<int>(func->out_type_id())] = func; + } +} + +void InitYdbCastTable() { + AddCastFunctions(GetYdbTemporalCasts()); +} + +void EnsureInitYdbCastTable() { + std::call_once(ydbCastTableInitialized, InitYdbCastTable); +} + +// Private version of GetCastFunction with better error reporting +// if the input type is known. +::arrow::Result<std::shared_ptr<cp::CastFunction>> GetYdbCastFunctionInternal( + const std::shared_ptr<::arrow::DataType>& to_type, const::arrow::DataType* from_type = nullptr) { + EnsureInitYdbCastTable(); + auto it = ydbCastTable.find(static_cast<int>(to_type->id())); + if (it == ydbCastTable.end()) { + auto res = cp::GetCastFunction(to_type); + if (!res.ok()) { + if (from_type != nullptr) { + return ::arrow::Status::NotImplemented("Unsupported cast from ", *from_type, " to ", + *to_type, + " (no available cast function for target type)"); + } else { + return ::arrow::Status::NotImplemented("Unsupported cast to ", *to_type, + " (no available cast function for target type)"); + } + } + return std::move(res).ValueUnsafe(); + } + return it->second; +} + +} // namespace + +static const cp::FunctionDoc ydbCastDoc{"YDB special cast function. Uses Arrow's cast and add casting support for some types." + "Cast values to another data type", + ("Behavior when values wouldn't fit in the target type\n" + "can be controlled through CastOptions."), + {"input"}, + "CastOptions"}; + +YdbCastMetaFunction::YdbCastMetaFunction() + : ::arrow::compute::MetaFunction("ydb.cast", ::arrow::compute::Arity::Unary(), &ydbCastDoc) + {} + +::arrow::Result<const cp::CastOptions*> YdbCastMetaFunction::ValidateOptions(const cp::FunctionOptions* options) const { + auto cast_options = static_cast<const cp::CastOptions*>(options); + + if (cast_options == nullptr || cast_options->to_type == nullptr) { + return ::arrow::Status::Invalid( + "Cast requires that options be passed with " + "the to_type populated"); + } + + return cast_options; +} + +::arrow::Result<::arrow::Datum> YdbCastMetaFunction::ExecuteImpl(const std::vector<::arrow::Datum>& args, + const cp::FunctionOptions* options, + cp::ExecContext* ctx) const +{ + auto&& optsResult = ValidateOptions(options); + if (!optsResult.ok()) { + return optsResult.status(); + } + auto cast_options = std::move(optsResult).ValueUnsafe(); + if (args[0].type()->Equals(*cast_options->to_type)) { + return args[0]; + } + auto&& castFuncResult = GetYdbCastFunctionInternal(cast_options->to_type, args[0].type().get()); + if (!castFuncResult.ok()) { + return castFuncResult.status(); + } + std::shared_ptr<cp::CastFunction> castFunc = std::move(castFuncResult).ValueUnsafe(); + return castFunc->Execute(args, options, ctx); +} + +} // NKikimr::NArrow + +namespace arrow::compute::internal { + +template <> +struct CastFunctor<TimestampType, UInt32Type> { + static Status Exec(KernelContext* /*ctx*/, const ExecBatch& batch, Datum* out) { + if (batch.num_values() == 0) { + return ::arrow::Status::IndexError("Cast from uint32 to timestamp received empty batch."); + } + Y_VERIFY(batch[0].kind() == Datum::ARRAY, "Cast from uint32 to timestamp expected ARRAY as input."); + + const auto& out_type = checked_cast<const ::arrow::TimestampType&>(*out->type()); + // get conversion MICROSECONDS -> unit + auto conversion = ::arrow::util::GetTimestampConversion(::arrow::TimeUnit::MICRO, out_type.unit()); + Y_VERIFY(conversion.first == ::arrow::util::MULTIPLY, "Cast from uint32 to timestamp failed because timestamp unit is greater than seconds."); + + auto input = batch[0].array(); + auto output = out->mutable_array(); + auto in_data = input->GetValues<uint32_t>(1); + auto out_data = output->GetMutableValues<int64_t>(1); + + for (int64_t i = 0; i < input->length; i++) { + out_data[i] = static_cast<int64_t>(in_data[i] * conversion.second); + } + return ::arrow::Status::OK(); + } +}; + +} // namespace arrow::compute::internal
\ No newline at end of file diff --git a/ydb/core/formats/func_cast.h b/ydb/core/formats/func_cast.h new file mode 100644 index 0000000000..532f3164d7 --- /dev/null +++ b/ydb/core/formats/func_cast.h @@ -0,0 +1,22 @@ +#pragma once + +#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/function.h> + +#include <type_traits> + +namespace NKikimr::NArrow { + +// Metafunction for dispatching to appropriate CastFunction. This corresponds +// to the standard SQL CAST(expr AS target_type) +class YdbCastMetaFunction : public ::arrow::compute::MetaFunction { + public: + YdbCastMetaFunction(); + + ::arrow::Result<const ::arrow::compute::CastOptions*> ValidateOptions(const ::arrow::compute::FunctionOptions* options) const; + + ::arrow::Result<::arrow::Datum> ExecuteImpl(const std::vector<::arrow::Datum>& args, + const ::arrow::compute::FunctionOptions* options, + ::arrow::compute::ExecContext* ctx) const override; +}; + +} // NKikimr::NArrow diff --git a/ydb/core/formats/func_gcd.h b/ydb/core/formats/func_gcd.h index 9d4b72e4b0..0f6fe4bc8d 100644 --- a/ydb/core/formats/func_gcd.h +++ b/ydb/core/formats/func_gcd.h @@ -1,17 +1,18 @@ #pragma once -#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_base.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_base.h> + #include <contrib/libs/apache/arrow/cpp/src/arrow/compute/kernel.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/datum.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/type_traits.h> -#include "func_common.h" -#include <cstdlib> #include <type_traits> +#ifdef WIN32 +#define INLINE inline +#else +#define INLINE __attribute__((always_inline)) +#endif + namespace NKikimr::NArrow { template<typename T> -__attribute__((always_inline)) void FastIntSwap(T& lhs, T& rhs) { +INLINE void FastIntSwap(T& lhs, T& rhs) { lhs ^= rhs; rhs ^= lhs; lhs ^= rhs; @@ -35,3 +36,5 @@ struct TGreatestCommonDivisor { }; } + +#undef INLINE diff --git a/ydb/core/formats/func_math.h b/ydb/core/formats/func_math.h index 74cf8694ff..13474e0018 100644 --- a/ydb/core/formats/func_math.h +++ b/ydb/core/formats/func_math.h @@ -95,6 +95,8 @@ struct TExp2 { } }; +#if 0 +// Temporarily disable function because it doesn't compile on Windows. struct TExp10 { static constexpr const char * Name = "exp10"; @@ -104,6 +106,7 @@ struct TExp10 { return exp10(arg); } }; +#endif struct THypot { diff --git a/ydb/core/formats/func_round.h b/ydb/core/formats/func_round.h index a65e893622..3b82591c45 100644 --- a/ydb/core/formats/func_round.h +++ b/ydb/core/formats/func_round.h @@ -9,6 +9,15 @@ #include <fenv.h> #include <type_traits> +#ifdef WIN32 +#include <intrin.h> +#define CLZ __lzcnt +#define CLZLL __lzcnt64 +#else +#define CLZ __builtin_clz +#define CLZLL __builtin_clzll +#endif + namespace NKikimr::NArrow { struct TRound { @@ -42,7 +51,7 @@ struct TRoundToExp2 { (sizeof(TRes) <= sizeof(uint32_t)), TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { static_assert(std::is_same_v<TRes, TArg>, ""); - return arg <= 0 ? 0 : (TRes(1) << (31 - __builtin_clz(arg))); + return arg <= 0 ? 0 : (TRes(1) << (31 - CLZ(arg))); } template <typename TRes, typename TArg> @@ -50,7 +59,7 @@ struct TRoundToExp2 { (sizeof(TRes) == sizeof(uint64_t)), TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { static_assert(std::is_same_v<TRes, TArg>, ""); - return arg <= 0 ? 0 : (TRes(1) << (63 - __builtin_clzll(arg))); + return arg <= 0 ? 0 : (TRes(1) << (63 - CLZLL(arg))); } template <typename TRes, typename TArg> @@ -67,3 +76,6 @@ struct TRoundToExp2 { }; } + +#undef CLZ +#undef CLZLL diff --git a/ydb/core/formats/functions.h b/ydb/core/formats/functions.h index 4c0ecda6be..2f4523a4fe 100644 --- a/ydb/core/formats/functions.h +++ b/ydb/core/formats/functions.h @@ -1,5 +1,6 @@ #pragma once +#include "func_cast.h" #include "func_gcd.h" #include "func_lcm.h" #include "func_modulo.h" diff --git a/ydb/core/formats/program.cpp b/ydb/core/formats/program.cpp index 3bc39be6fe..f8d8287ab5 100644 --- a/ydb/core/formats/program.cpp +++ b/ydb/core/formats/program.cpp @@ -20,33 +20,21 @@ namespace NKikimr::NArrow { const char * GetFunctionName(EOperation op) { switch (op) { case EOperation::CastBoolean: - return "cast_boolean"; case EOperation::CastInt8: - return "cast_int8"; case EOperation::CastInt16: - return "cast_int16"; case EOperation::CastInt32: - return "cast_int32"; case EOperation::CastInt64: - return "cast_int64"; case EOperation::CastUInt8: - return "cast_uint8"; case EOperation::CastUInt16: - return "cast_uint16"; case EOperation::CastUInt32: - return "cast_uint32"; case EOperation::CastUInt64: - return "cast_uint64"; case EOperation::CastFloat: - return "cast_float"; case EOperation::CastDouble: - return "cast_double"; case EOperation::CastBinary: - return "cast_binary"; case EOperation::CastFixedSizeBinary: - return "cast_fixed_size_binary"; case EOperation::CastString: - return "cast_string"; + case EOperation::CastTimestamp: + return "ydb.cast"; case EOperation::IsValid: return "is_valid"; @@ -263,7 +251,7 @@ std::shared_ptr<arrow::Scalar> CallScalarFunction(EOperation funcId, const std:: return result->scalar(); } -arrow::Datum CallFunctionById(EOperation funcId, const std::vector<std::string>& args, +arrow::Datum CallFunctionById(EOperation funcId, const std::vector<std::string>& args, const arrow::compute::FunctionOptions* funcOpts, std::shared_ptr<TProgramStep::TDatumBatch> batch, arrow::compute::ExecContext* ctx) { std::vector<arrow::Datum> arguments; arguments.reserve(args.size()); @@ -274,17 +262,20 @@ arrow::Datum CallFunctionById(EOperation funcId, const std::vector<std::string>& arguments.push_back(*column); } std::string funcName = GetFunctionName(funcId); + arrow::Result<arrow::Datum> result; if (ctx != nullptr && ctx->func_registry()->GetFunction(funcName).ok()) { - result = arrow::compute::CallFunction(GetFunctionName(funcId), arguments, ctx); + result = arrow::compute::CallFunction(GetFunctionName(funcId), arguments, funcOpts, ctx); } else { - result = arrow::compute::CallFunction(GetFunctionName(funcId), arguments); + result = arrow::compute::CallFunction(GetFunctionName(funcId), arguments, funcOpts); } Y_VERIFY(result.ok()); return result.ValueOrDie(); } - +arrow::Datum CallFunctionByAssign(const TAssign& assign, std::shared_ptr<TProgramStep::TDatumBatch> batch, arrow::compute::ExecContext* ctx) { + return CallFunctionById(assign.GetOperation(), assign.GetArguments(), assign.GetFunctionOptions(), batch, ctx); +} void TProgramStep::ApplyAssignes(std::shared_ptr<TProgramStep::TDatumBatch>& batch, arrow::compute::ExecContext* ctx) const { if (Assignes.empty()) { @@ -298,7 +289,7 @@ void TProgramStep::ApplyAssignes(std::shared_ptr<TProgramStep::TDatumBatch>& bat if (assign.IsConstant()) { column = assign.GetConstant(); } else { - column = CallFunctionById(assign.GetOperation(), assign.GetArguments(), batch, ctx); + column = CallFunctionByAssign(assign, batch, ctx); } AddColumn(batch, assign.GetName(), column); } @@ -313,9 +304,9 @@ void TProgramStep::ApplyFilters(std::shared_ptr<TDatumBatch>& batch) const { filters.reserve(Filters.size()); for (auto& colName : Filters) { auto column = GetColumnByName(batch, colName); - Y_VERIFY(column.ok()); - Y_VERIFY(column->is_array()); - Y_VERIFY(column->type() == arrow::boolean()); + Y_VERIFY_S(column.ok(), TStringBuilder() << "Column " << colName << " is not ok."); + Y_VERIFY_S(column->is_array(), TStringBuilder() << "Column " << colName << " is not an array."); + Y_VERIFY_S(column->type() == arrow::boolean(), TStringBuilder() << "Column " << colName << " type is not bool."); auto boolColumn = std::static_pointer_cast<arrow::BooleanArray>(column->make_array()); filters.push_back(std::vector<bool>(boolColumn->length())); auto& bits = filters.back(); diff --git a/ydb/core/formats/program.h b/ydb/core/formats/program.h index ff15d30ebf..f4b7d466ab 100644 --- a/ydb/core/formats/program.h +++ b/ydb/core/formats/program.h @@ -23,6 +23,7 @@ enum class EOperation { CastBinary, CastFixedSizeBinary, CastString, + CastTimestamp, // IsValid, IsNull, @@ -90,60 +91,77 @@ public: : Name(name) , Operation(op) , Arguments(std::move(args)) + , FuncOpts(nullptr) + {} + + TAssign(const std::string& name, EOperation op, std::vector<std::string>&& args, std::shared_ptr<arrow::compute::FunctionOptions> funcOpts) + : Name(name) + , Operation(op) + , Arguments(std::move(args)) + , FuncOpts(funcOpts) {} explicit TAssign(const std::string& name, bool value) : Name(name) , Operation(EOperation::Constant) , Constant(std::make_shared<arrow::BooleanScalar>(value)) + , FuncOpts(nullptr) {} explicit TAssign(const std::string& name, i32 value) : Name(name) , Operation(EOperation::Constant) , Constant(std::make_shared<arrow::Int32Scalar>(value)) + , FuncOpts(nullptr) {} explicit TAssign(const std::string& name, ui32 value) : Name(name) , Operation(EOperation::Constant) , Constant(std::make_shared<arrow::UInt32Scalar>(value)) + , FuncOpts(nullptr) {} explicit TAssign(const std::string& name, i64 value) : Name(name) , Operation(EOperation::Constant) , Constant(std::make_shared<arrow::Int64Scalar>(value)) + , FuncOpts(nullptr) {} explicit TAssign(const std::string& name, ui64 value) : Name(name) , Operation(EOperation::Constant) , Constant(std::make_shared<arrow::UInt64Scalar>(value)) + , FuncOpts(nullptr) {} explicit TAssign(const std::string& name, float value) : Name(name) , Operation(EOperation::Constant) , Constant(std::make_shared<arrow::FloatScalar>(value)) + , FuncOpts(nullptr) {} explicit TAssign(const std::string& name, double value) : Name(name) , Operation(EOperation::Constant) , Constant(std::make_shared<arrow::DoubleScalar>(value)) + , FuncOpts(nullptr) {} explicit TAssign(const std::string& name, const std::string& value) : Name(name) , Operation(EOperation::Constant) , Constant(std::make_shared<arrow::StringScalar>(value)) + , FuncOpts(nullptr) {} TAssign(const std::string& name, const std::shared_ptr<arrow::Scalar>& value) : Name(name) , Operation(EOperation::Constant) , Constant(value) + , FuncOpts(nullptr) {} bool IsConstant() const { return Operation == EOperation::Constant; } @@ -151,12 +169,14 @@ public: const std::vector<std::string>& GetArguments() const { return Arguments; } std::shared_ptr<arrow::Scalar> GetConstant() const { return Constant; } const std::string& GetName() const { return Name; } + const arrow::compute::FunctionOptions* GetFunctionOptions() const { return FuncOpts.get(); } private: std::string Name; EOperation Operation{EOperation::Unspecified}; std::vector<std::string> Arguments; std::shared_ptr<arrow::Scalar> Constant; + std::shared_ptr<arrow::compute::FunctionOptions> FuncOpts; }; /// Group of commands that finishes with projection. Steps add locality for columns definition. diff --git a/ydb/core/kqp/compile/kqp_olap_compiler.cpp b/ydb/core/kqp/compile/kqp_olap_compiler.cpp index 8f16b9345c..f880273706 100644 --- a/ydb/core/kqp/compile/kqp_olap_compiler.cpp +++ b/ydb/core/kqp/compile/kqp_olap_compiler.cpp @@ -77,6 +77,7 @@ private: TProgram::TAssignment* CompileCondition(const TExprBase& condition, TKqpOlapCompileContext& ctx); +ui64 GetOrCreateColumnId(const TExprBase& node, TKqpOlapCompileContext& ctx); ui32 ConvertValueToColumn(const TCoDataCtor& value, TKqpOlapCompileContext& ctx) { @@ -145,6 +146,52 @@ ui32 ConvertParameterToColumn(const TCoParameter& parameter, TKqpOlapCompileCont return ssaValue->GetColumn().GetId(); } +ui32 ConvertSafeCastToColumn(const TCoSafeCast& cast, TKqpOlapCompileContext& ctx) +{ + auto columnId = GetOrCreateColumnId(cast.Value(), ctx); + + TProgram::TAssignment* ssaValue = ctx.CreateAssignCmd(); + + auto newCast = ssaValue->MutableFunction(); + + auto maybeDataType = cast.Type().Maybe<TCoDataType>(); + YQL_ENSURE(maybeDataType.IsValid()); + + auto dataType = maybeDataType.Cast(); + ui32 castFunction = TProgram::TAssignment::FUNC_UNSPECIFIED; + if (dataType.Type().Value() == "Boolean") { + castFunction = TProgram::TAssignment::FUNC_CAST_TO_BOOLEAN; + } else if (dataType.Type().Value() == "Int8") { + castFunction = TProgram::TAssignment::FUNC_CAST_TO_INT8; + } else if (dataType.Type().Value() == "Int16") { + castFunction = TProgram::TAssignment::FUNC_CAST_TO_INT16; + } else if (dataType.Type().Value() == "Int32") { + castFunction = TProgram::TAssignment::FUNC_CAST_TO_INT32; + } else if (dataType.Type().Value() == "Int64") { + castFunction = TProgram::TAssignment::FUNC_CAST_TO_INT64; + } else if (dataType.Type().Value() == "Uint8") { + castFunction = TProgram::TAssignment::FUNC_CAST_TO_UINT8; + } else if (dataType.Type().Value() == "Uint16") { + castFunction = TProgram::TAssignment::FUNC_CAST_TO_UINT16; + } else if (dataType.Type().Value() == "Uint32") { + castFunction = TProgram::TAssignment::FUNC_CAST_TO_UINT32; + } else if (dataType.Type().Value() == "Uint64") { + castFunction = TProgram::TAssignment::FUNC_CAST_TO_UINT64; + } else if (dataType.Type().Value() == "Float") { + castFunction = TProgram::TAssignment::FUNC_CAST_TO_FLOAT; + } else if (dataType.Type().Value() == "Double") { + castFunction = TProgram::TAssignment::FUNC_CAST_TO_DOUBLE; + } else if (dataType.Type().Value() == "Timestamp") { + castFunction = TProgram::TAssignment::FUNC_CAST_TO_TIMESTAMP; + } else { + YQL_ENSURE(false, "Unsupported data type for pushed down safe cast: " << dataType.Type().Value()); + } + + newCast->SetId(castFunction); + newCast->AddArguments()->SetId(columnId); + return ssaValue->GetColumn().GetId(); +} + ui64 GetOrCreateColumnId(const TExprBase& node, TKqpOlapCompileContext& ctx) { if (auto maybeData = node.Maybe<TCoDataCtor>()) { return ConvertValueToColumn(maybeData.Cast(), ctx); @@ -158,6 +205,10 @@ ui64 GetOrCreateColumnId(const TExprBase& node, TKqpOlapCompileContext& ctx) { return ConvertParameterToColumn(maybeParameter.Cast(), ctx); } + if (auto maybeCast = node.Maybe<TCoSafeCast>()) { + return ConvertSafeCastToColumn(maybeCast.Cast(), ctx); + } + YQL_ENSURE(false, "Unknown node in OLAP comparison compiler: " << node.Ptr()->Content()); } diff --git a/ydb/core/kqp/opt/physical/kqp_opt_phy_olap_filter.cpp b/ydb/core/kqp/opt/physical/kqp_opt_phy_olap_filter.cpp index fa0bb9cc80..2dded70c9d 100644 --- a/ydb/core/kqp/opt/physical/kqp_opt_phy_olap_filter.cpp +++ b/ydb/core/kqp/opt/physical/kqp_opt_phy_olap_filter.cpp @@ -103,6 +103,19 @@ bool IsSupportedDataType(const TCoDataCtor& node) { return false; } +bool IsSupportedCast(const TCoSafeCast& cast) { + auto maybeDataType = cast.Type().Maybe<TCoDataType>(); + YQL_ENSURE(maybeDataType.IsValid()); + + auto dataType = maybeDataType.Cast(); + if (dataType.Type().Value() == "Int32") { + return cast.Value().Maybe<TCoString>().IsValid(); + } else if (dataType.Type().Value() == "Timestamp") { + return cast.Value().Maybe<TCoUint32>().IsValid(); + } + return false; +} + bool IsComparableTypes(const TExprBase& leftNode, const TExprBase& rightNode, bool equality, const TTypeAnnotationNode* inputType) { @@ -228,43 +241,94 @@ bool IsComparableTypes(const TExprBase& leftNode, const TExprBase& rightNode, bo return true; } - -TVector<std::pair<TExprBase, TExprBase>> ExtractComparisonParameters(const TCoCompare& predicate, - const TExprNode* rawLambdaArg, const TExprBase& input) +TVector<TExprBase> ConvertComparisonNode(const TExprBase& nodeIn, const TExprNode* rawLambdaArg) { - TVector<std::pair<TExprBase, TExprBase>> out; + TVector<TExprBase> out; auto convertNode = [rawLambdaArg](const TExprBase& node) -> TMaybeNode<TExprBase> { if (node.Maybe<TCoNull>()) { return node; } + if (auto maybeSafeCast = node.Maybe<TCoSafeCast>()) { + if (!IsSupportedCast(maybeSafeCast.Cast())) { + return NullNode; + } + + return node; + } + if (auto maybeParameter = node.Maybe<TCoParameter>()) { return maybeParameter.Cast(); } if (auto maybeData = node.Maybe<TCoDataCtor>()) { - if (IsSupportedDataType(maybeData.Cast())) { - return node; + if (!IsSupportedDataType(maybeData.Cast())) { + return NullNode; } - return NullNode; + return node; } if (auto maybeMember = node.Maybe<TCoMember>()) { - if (maybeMember.Cast().Struct().Raw() == rawLambdaArg) { - return maybeMember.Cast().Name(); + if (maybeMember.Cast().Struct().Raw() != rawLambdaArg) { + return NullNode; } - return NullNode; + return maybeMember.Cast().Name(); } return NullNode; }; // Columns & values may be single element - TMaybeNode<TExprBase> left = convertNode(predicate.Left()); - TMaybeNode<TExprBase> right = convertNode(predicate.Right()); + TMaybeNode<TExprBase> node = convertNode(nodeIn); + + if (node.IsValid()) { + out.emplace_back(std::move(node.Cast())); + return out; + } + + // Or columns and values can be Tuple + if (!nodeIn.Maybe<TExprList>()) { + // something unusual found, return empty vector + return out; + } + + auto tuple = nodeIn.Cast<TExprList>(); + + out.reserve(tuple.Size()); + + for (ui32 i = 0; i < tuple.Size(); ++i) { + TMaybeNode<TExprBase> node = convertNode(tuple.Item(i)); + + if (!node.IsValid()) { + // Return empty vector + return TVector<TExprBase>(); + } + + out.emplace_back(node.Cast()); + } + + return out; +} + +TVector<std::pair<TExprBase, TExprBase>> ExtractComparisonParameters(const TCoCompare& predicate, + const TExprNode* rawLambdaArg, const TExprBase& input) +{ + TVector<std::pair<TExprBase, TExprBase>> out; + + auto left = ConvertComparisonNode(predicate.Left(), rawLambdaArg); + + if (left.empty()) { + return out; + } + + auto right = ConvertComparisonNode(predicate.Right(), rawLambdaArg); + + if (left.size() != right.size()) { + return out; + } TMaybeNode<TCoCmpEqual> maybeEqual = predicate.Maybe<TCoCmpEqual>(); TMaybeNode<TCoCmpNotEqual> maybeNotEqual = predicate.Maybe<TCoCmpNotEqual>(); @@ -292,52 +356,19 @@ TVector<std::pair<TExprBase, TExprBase>> ExtractComparisonParameters(const TCoCo return out; } - if (left.IsValid() && right.IsValid()) { - if (!IsComparableTypes(left.Cast(), right.Cast(), equality, inputType)) { - return out; - } - - out.emplace_back(std::move(std::make_pair(left.Cast(), right.Cast()))); - return out; - } - - // Or columns and values can be Tuple - if (!predicate.Left().Maybe<TExprList>() || !predicate.Right().Maybe<TExprList>()) { - // something unusual found, return empty vector - return out; - } - - auto tupleLeft = predicate.Left().Cast<TExprList>(); - auto tupleRight = predicate.Right().Cast<TExprList>(); - - if (tupleLeft.Size() != tupleRight.Size()) { - return out; - } - - out.reserve(tupleLeft.Size()); - - for (ui32 i = 0; i < tupleLeft.Size(); ++i) { - TMaybeNode<TExprBase> left = convertNode(tupleLeft.Item(i)); - TMaybeNode<TExprBase> right = convertNode(tupleRight.Item(i)); - - if (!left.IsValid() || !right.IsValid()) { - // Return empty vector - return TVector<std::pair<TExprBase, TExprBase>>(); - } - - if (!IsComparableTypes(left.Cast(), right.Cast(), equality, inputType)) { + for (ui32 i = 0; i < left.size(); ++i) { + if (!IsComparableTypes(left[i], right[i], equality, inputType)) { // Return empty vector return TVector<std::pair<TExprBase, TExprBase>>(); } - - out.emplace_back(std::move(std::make_pair(left.Cast(), right.Cast()))); + out.emplace_back(std::move(std::make_pair(left[i], right[i]))); } return out; } TExprBase BuildOneElementComparison(const std::pair<TExprBase, TExprBase>& parameter, const TCoCompare& predicate, - TExprContext& ctx, TPositionHandle pos, const TExprBase& input) + TExprContext& ctx, TPositionHandle pos, const TExprBase& input, bool forceStrictComparison) { auto isNull = [](const TExprBase& node) { if (node.Maybe<TCoNull>()) { @@ -368,7 +399,7 @@ TExprBase BuildOneElementComparison(const std::pair<TExprBase, TExprBase>& param .Done(); } - if (predicate.Maybe<TCoCmpLess>()) { + if (predicate.Maybe<TCoCmpLess>() || (predicate.Maybe<TCoCmpLessOrEqual>() && forceStrictComparison)) { return Build<TKqpOlapFilterLess>(ctx, pos) .Input(input) .Left(parameter.first) @@ -376,7 +407,7 @@ TExprBase BuildOneElementComparison(const std::pair<TExprBase, TExprBase>& param .Done(); } - if (predicate.Maybe<TCoCmpLessOrEqual>()) { + if (predicate.Maybe<TCoCmpLessOrEqual>() && !forceStrictComparison) { return Build<TKqpOlapFilterLessOrEqual>(ctx, pos) .Input(input) .Left(parameter.first) @@ -384,7 +415,7 @@ TExprBase BuildOneElementComparison(const std::pair<TExprBase, TExprBase>& param .Done(); } - if (predicate.Maybe<TCoCmpGreater>()) { + if (predicate.Maybe<TCoCmpGreater>() || (predicate.Maybe<TCoCmpGreaterOrEqual>() && forceStrictComparison)) { return Build<TKqpOlapFilterGreater>(ctx, pos) .Input(input) .Left(parameter.first) @@ -392,7 +423,7 @@ TExprBase BuildOneElementComparison(const std::pair<TExprBase, TExprBase>& param .Done(); } - if (predicate.Maybe<TCoCmpGreaterOrEqual>()) { + if (predicate.Maybe<TCoCmpGreaterOrEqual>() && !forceStrictComparison) { return Build<TKqpOlapFilterGreaterOrEqual>(ctx, pos) .Input(input) .Left(parameter.first) @@ -417,7 +448,7 @@ TExprBase ComparisonPushdown(const TVector<std::pair<TExprBase, TExprBase>>& par ui32 conditionsCount = parameters.size(); if (conditionsCount == 1) { - return BuildOneElementComparison(parameters[0], predicate, ctx, pos, input); + return BuildOneElementComparison(parameters[0], predicate, ctx, pos, input, false); } if (predicate.Maybe<TCoCmpEqual>() || predicate.Maybe<TCoCmpNotEqual>()) { @@ -425,7 +456,7 @@ TExprBase ComparisonPushdown(const TVector<std::pair<TExprBase, TExprBase>>& par conditions.reserve(conditionsCount); for (ui32 i = 0; i < conditionsCount; ++i) { - conditions.emplace_back(BuildOneElementComparison(parameters[i], predicate, ctx, pos, input)); + conditions.emplace_back(BuildOneElementComparison(parameters[i], predicate, ctx, pos, input, false)); } if (predicate.Maybe<TCoCmpEqual>()) { @@ -447,7 +478,9 @@ TExprBase ComparisonPushdown(const TVector<std::pair<TExprBase, TExprBase>>& par TVector<TExprBase> andConditions; andConditions.reserve(conditionsCount); - andConditions.emplace_back(BuildOneElementComparison(parameters[i], predicate, ctx, pos, input)); + // We need strict < and > in beginning columns except the last one + // For example: (c1, c2, c3) >= (1, 2, 3) ==> (c1 > 1) OR (c2 > 2 AND c1 = 1) OR (c3 >= 3 AND c2 = 2 AND c1 = 1) + andConditions.emplace_back(BuildOneElementComparison(parameters[i], predicate, ctx, pos, input, i < conditionsCount - 1)); for (ui32 j = 0; j < i; ++j) { andConditions.emplace_back(Build<TKqpOlapFilterEqual>(ctx, pos) @@ -516,10 +549,46 @@ TMaybeNode<TExprBase> ExistsPushdown(const TCoExists& exists, TExprContext& ctx, .Done(); } -TMaybeNode<TExprBase> CoalescePushdown(const TCoCoalesce& coalesce, TExprContext& ctx, TPositionHandle pos, - const TExprNode* lambdaArg, const TExprBase& input) +TMaybeNode<TExprBase> SafeCastPredicatePushdown(const TCoFlatMap& flatmap, + TExprContext& ctx, TPositionHandle pos, const TExprNode* lambdaArg, const TExprBase& input) { - auto maybePredicate = coalesce.Predicate().Maybe<TCoCompare>(); + /* + * There are three ways of comparison in following format: + * + * FlatMap (LeftArgument, FlatMap(RightArgument(), Just(Predicate)) + * + * Examples: + * FlatMap (SafeCast(), FlatMap(Member(), Just(Comparison)) + * FlatMap (Member(), FlatMap(SafeCast(), Just(Comparison)) + * FlatMap (SafeCast(), FlatMap(SafeCast(), Just(Comparison)) + */ + TVector<std::pair<TExprBase, TExprBase>> out; + + auto maybeFlatmap = flatmap.Lambda().Body().Maybe<TCoFlatMap>(); + + if (!maybeFlatmap.IsValid()) { + return NullNode; + } + + auto right = ConvertComparisonNode(maybeFlatmap.Cast().Input(), lambdaArg); + + if (right.empty()) { + return NullNode; + } + + auto left = ConvertComparisonNode(flatmap.Input(), lambdaArg); + + if (left.empty()) { + return NullNode; + } + + auto maybeJust = maybeFlatmap.Cast().Lambda().Body().Maybe<TCoJust>(); + + if (!maybeJust.IsValid()) { + return NullNode; + } + + auto maybePredicate = maybeJust.Cast().Input().Maybe<TCoCompare>(); if (!maybePredicate.IsValid()) { return NullNode; @@ -531,14 +600,26 @@ TMaybeNode<TExprBase> CoalescePushdown(const TCoCoalesce& coalesce, TExprContext return NullNode; } - if (!coalesce.Value().Maybe<TCoBool>()) { + TVector<std::pair<TExprBase, TExprBase>> parameters; + + if (left.size() != right.size()) { return NullNode; } - if (coalesce.Value().Cast<TCoBool>().Literal().Value() != "false") { - return NullNode; + for (ui32 i = 0; i < left.size(); ++i) { + out.emplace_back(std::move(std::make_pair(left[i], right[i]))); } + return ComparisonPushdown(parameters, predicate, ctx, pos, input); +} + +TMaybeNode<TExprBase> SimplePredicatePushdown(const TCoCompare& predicate, TExprContext& ctx, TPositionHandle pos, + const TExprNode* lambdaArg, const TExprBase& input) +{ + if (!IsSupportedPredicate(predicate)) { + return NullNode; + } + auto parameters = ExtractComparisonParameters(predicate, lambdaArg, input); if (parameters.empty()) { @@ -548,6 +629,33 @@ TMaybeNode<TExprBase> CoalescePushdown(const TCoCoalesce& coalesce, TExprContext return ComparisonPushdown(parameters, predicate, ctx, pos, input); } + +TMaybeNode<TExprBase> CoalescePushdown(const TCoCoalesce& coalesce, TExprContext& ctx, TPositionHandle pos, + const TExprNode* lambdaArg, const TExprBase& input) +{ + if (!coalesce.Value().Maybe<TCoBool>()) { + return NullNode; + } + + if (coalesce.Value().Cast<TCoBool>().Literal().Value() != "false") { + return NullNode; + } + + auto maybeFlatmap = coalesce.Predicate().Maybe<TCoFlatMap>(); + + if (maybeFlatmap.IsValid()) { + return SafeCastPredicatePushdown(maybeFlatmap.Cast(), ctx, pos, lambdaArg, input); + } + + auto maybePredicate = coalesce.Predicate().Maybe<TCoCompare>(); + + if (maybePredicate.IsValid()) { + return SimplePredicatePushdown(maybePredicate.Cast(), ctx, pos, lambdaArg, input); + } + + return NullNode; +} + TMaybeNode<TExprBase> PredicatePushdown(const TExprBase& predicate, TExprContext& ctx, TPositionHandle pos, const TExprNode* lambdaArg, const TExprBase& input) { diff --git a/ydb/core/kqp/prepare/kqp_type_ann.cpp b/ydb/core/kqp/prepare/kqp_type_ann.cpp index ee51c7349c..f111792122 100644 --- a/ydb/core/kqp/prepare/kqp_type_ann.cpp +++ b/ydb/core/kqp/prepare/kqp_type_ann.cpp @@ -769,6 +769,11 @@ TStatus AnnotateOlapFilterCompare(const TExprNode::TPtr& node, TExprContext& ctx return true; } + // SafeCast, the checks about validity should be placed in kqp_opt_phy_olap_filter.cpp + if (TCoSafeCast::Match(node)) { + return true; + } + ctx.AddError(TIssue( ctx.GetPosition(node->Pos()), TStringBuilder() diff --git a/ydb/core/kqp/ut/kqp_olap_ut.cpp b/ydb/core/kqp/ut/kqp_olap_ut.cpp index 0b4b65aac4..1c47fd7550 100644 --- a/ydb/core/kqp/ut/kqp_olap_ut.cpp +++ b/ydb/core/kqp/ut/kqp_olap_ut.cpp @@ -1018,9 +1018,11 @@ Y_UNIT_TEST_SUITE(KqpOlap) { CreateTestOlapTable(kikimr); WriteTestData(kikimr, "/Root/olapStore/olapTable", 10000, 3000000, 5); + EnableDebugLogging(kikimr); auto tableClient = kikimr.GetTableClient(); + // TODO: Add support for DqPhyPrecompute push-down: Cast((2+2) as Uint64) std::vector<TString> testData = { R"(`resource_id` = `uid`)", R"(`resource_id` = "10001")", @@ -1037,7 +1039,14 @@ Y_UNIT_TEST_SUITE(KqpOlap) { R"((`level`, `uid`, `resource_id`) > (Int32("1"), "uid_3000001", "10001"))", R"((`level`, `uid`, `resource_id`) > (Int32("1"), "uid_3000000", "10001"))", R"((`level`, `uid`, `resource_id`) < (Int32("1"), "uid_3000002", "10001"))", - R"((`level`, `uid`, `resource_id`) >= (Int32("2"), "uid_3000000", "10001"))", + R"((`level`, `uid`, `resource_id`) >= (Int32("2"), "uid_3000001", "10001"))", + R"((`level`, `uid`, `resource_id`) >= (Int32("1"), "uid_3000002", "10001"))", + R"((`level`, `uid`, `resource_id`) >= (Int32("1"), "uid_3000001", "10002"))", + R"((`level`, `uid`, `resource_id`) >= (Int32("1"), "uid_3000001", "10001"))", + R"((`level`, `uid`, `resource_id`) <= (Int32("2"), "uid_3000001", "10001"))", + R"((`level`, `uid`, `resource_id`) <= (Int32("1"), "uid_3000002", "10001"))", + R"((`level`, `uid`, `resource_id`) <= (Int32("1"), "uid_3000001", "10002"))", + R"((`level`, `uid`, `resource_id`) <= (Int32("1"), "uid_3000001", "10001"))", R"((`level`, `uid`, `resource_id`) != (Int32("1"), "uid_3000001", "10001"))", R"((`level`, `uid`, `resource_id`) != (Int32("0"), "uid_3000001", "10011"))", R"(`level` = 0 OR `level` = 2 OR `level` = 1)", @@ -1045,8 +1054,7 @@ Y_UNIT_TEST_SUITE(KqpOlap) { R"(`level` = 0 OR `uid` = "uid_3000003")", R"(`level` = 0 AND `uid` = "uid_3000003")", R"(`level` = 0 AND `uid` = "uid_3000000")", - // Timestamp will be removed by predicate extraction now. - R"(`timestamp` >= CAST(3000001 AS Timestamp) AND `level` > 3)", + R"(`timestamp` >= CAST(3000001u AS Timestamp) AND `level` > 3)", R"((`level`, `uid`) > (Int32("2"), "uid_3000004") OR (`level`, `uid`) < (Int32("1"), "uid_3000002"))", R"(Int32("3") > `level`)", R"((Int32("1"), "uid_3000001", "10001") = (`level`, `uid`, `resource_id`))", @@ -1057,12 +1065,17 @@ Y_UNIT_TEST_SUITE(KqpOlap) { R"(`level` IS NOT NULL)", R"((`level`, `uid`) > (Int32("1"), NULL))", R"((`level`, `uid`) != (Int32("1"), NULL))", - //R"((`timestamp`, `level`) >= (CAST(3000001 AS Timestamp), 3))", + R"(`level` >= CAST("2" As Int32))", + R"(CAST("2" As Int32) >= `level`)", + R"(`timestamp` >= CAST(3000001u AS Timestamp))", + R"((`timestamp`, `level`) >= (CAST(3000001u AS Timestamp), 3))", }; std::vector<TString> testDataNoPush = { R"(`level` != NULL)", R"(`level` > NULL)", + R"(`timestamp` >= CAST(3000001 AS Timestamp))", + R"(`level` >= CAST("2" As Uint32))", }; auto buildQuery = [](const TString& predicate, bool pushEnabled) { @@ -1073,7 +1086,8 @@ Y_UNIT_TEST_SUITE(KqpOlap) { if (pushEnabled) { qBuilder << R"(PRAGMA Kikimr.KqpPushOlapProcess = "true";)" << Endl; } - + + qBuilder << R"(PRAGMA Kikimr.OptEnablePredicateExtract = "false";)" << Endl; qBuilder << "SELECT `timestamp` FROM `/Root/olapStore/olapTable` WHERE "; qBuilder << predicate; qBuilder << " ORDER BY `timestamp`"; @@ -1085,10 +1099,14 @@ Y_UNIT_TEST_SUITE(KqpOlap) { auto normalQuery = buildQuery(predicate, false); auto pushQuery = buildQuery(predicate, true); + Cerr << "--- Run normal query ---\n"; + Cerr << normalQuery << Endl; auto it = tableClient.StreamExecuteScanQuery(normalQuery).GetValueSync(); UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); auto goodResult = CollectStreamResult(it); + Cerr << "--- Run pushed down query ---\n"; + Cerr << pushQuery << Endl; it = tableClient.StreamExecuteScanQuery(pushQuery).GetValueSync(); UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); auto pushResult = CollectStreamResult(it); diff --git a/ydb/core/protos/ssa.proto b/ydb/core/protos/ssa.proto index c035c9f5df..eebf965f9a 100644 --- a/ydb/core/protos/ssa.proto +++ b/ydb/core/protos/ssa.proto @@ -59,6 +59,20 @@ message TProgram { FUNC_MATH_SUBTRACT = 15; FUNC_MATH_MULTIPLY = 16; FUNC_MATH_DIVIDE = 17; + FUNC_CAST_TO_BOOLEAN = 18; + FUNC_CAST_TO_INT8 = 19; + FUNC_CAST_TO_INT16 = 20; + FUNC_CAST_TO_INT32 = 21; + FUNC_CAST_TO_INT64 = 22; + FUNC_CAST_TO_UINT8 = 23; + FUNC_CAST_TO_UINT16 = 24; + FUNC_CAST_TO_UINT32 = 25; + FUNC_CAST_TO_UINT64 = 26; + FUNC_CAST_TO_FLOAT = 27; + FUNC_CAST_TO_DOUBLE = 28; + FUNC_CAST_TO_BINARY = 29; + FUNC_CAST_TO_FIXED_SIZE_BINARY = 30; + FUNC_CAST_TO_TIMESTAMP = 31; } message TFunction { diff --git a/ydb/core/tx/columnshard/columnshard__stats_scan.h b/ydb/core/tx/columnshard/columnshard__stats_scan.h index 770d683895..28605b4a98 100644 --- a/ydb/core/tx/columnshard/columnshard__stats_scan.h +++ b/ydb/core/tx/columnshard/columnshard__stats_scan.h @@ -4,6 +4,7 @@ #include "columnshard_common.h" #include <ydb/core/tablet_flat/flat_cxx_database.h> #include <ydb/core/sys_view/common/schema.h> +#include <ydb/core/formats/custom_registry.h> namespace NKikimr::NColumnShard { @@ -53,7 +54,7 @@ public: ApplyRangePredicates(batch); if (!ReadMetadata->Program.empty()) { - ApplyProgram(batch, ReadMetadata->Program); + ApplyProgram(batch, ReadMetadata->Program, NArrow::GetCustomExecContext()); } // Leave only requested columns diff --git a/ydb/core/tx/columnshard/columnshard_common.cpp b/ydb/core/tx/columnshard/columnshard_common.cpp index 1e76247da8..5b66370044 100644 --- a/ydb/core/tx/columnshard/columnshard_common.cpp +++ b/ydb/core/tx/columnshard/columnshard_common.cpp @@ -121,6 +121,31 @@ NArrow::TAssign MakeFunction(TContext& info, const std::string& name, case TId::FUNC_MATH_DIVIDE: Y_VERIFY(args.size() == 2); return TAssign(name, EOperation::Divide, std::move(arguments)); + case TId::FUNC_CAST_TO_INT32: + { + Y_VERIFY(args.size() == 1); // TODO: support CAST with OrDefault/OrNull logic (second argument is default value) + auto castOpts = std::make_shared<arrow::compute::CastOptions>(false); + castOpts->to_type = std::make_shared<arrow::Int32Type>(); + return TAssign(name, EOperation::CastInt32, std::move(arguments), castOpts); + } + case TId::FUNC_CAST_TO_TIMESTAMP: + { + Y_VERIFY(args.size() == 1); + auto castOpts = std::make_shared<arrow::compute::CastOptions>(false); + castOpts->to_type = std::make_shared<arrow::TimestampType>(arrow::TimeUnit::MICRO); + return TAssign(name, EOperation::CastTimestamp, std::move(arguments), castOpts); + } + case TId::FUNC_CAST_TO_INT8: + case TId::FUNC_CAST_TO_INT16: + case TId::FUNC_CAST_TO_INT64: + case TId::FUNC_CAST_TO_UINT8: + case TId::FUNC_CAST_TO_UINT16: + case TId::FUNC_CAST_TO_UINT32: + case TId::FUNC_CAST_TO_UINT64: + case TId::FUNC_CAST_TO_FLOAT: + case TId::FUNC_CAST_TO_DOUBLE: + case TId::FUNC_CAST_TO_BINARY: + case TId::FUNC_CAST_TO_FIXED_SIZE_BINARY: case TId::FUNC_UNSPECIFIED: break; } diff --git a/ydb/core/tx/columnshard/engines/indexed_read_data.cpp b/ydb/core/tx/columnshard/engines/indexed_read_data.cpp index d9c3a4f63e..da01c5979b 100644 --- a/ydb/core/tx/columnshard/engines/indexed_read_data.cpp +++ b/ydb/core/tx/columnshard/engines/indexed_read_data.cpp @@ -6,6 +6,7 @@ #include <ydb/core/tx/columnshard/columnshard__stats_scan.h> #include <ydb/core/formats/one_batch_input_stream.h> #include <ydb/core/formats/merging_sorted_input_stream.h> +#include <ydb/core/formats/custom_registry.h> namespace NKikimr::NOlap { @@ -504,7 +505,7 @@ TIndexedReadData::MakeResult(TVector<std::vector<std::shared_ptr<arrow::RecordBa if (ReadMetadata->HasProgram()) { for (auto& batch : out) { - ApplyProgram(batch.ResultBatch, ReadMetadata->Program); + ApplyProgram(batch.ResultBatch, ReadMetadata->Program, NArrow::GetCustomExecContext()); } } return out; |