aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoraidarsamer <aidarsamer@yandex-team.ru>2022-05-03 22:05:20 +0300
committeraidarsamer <aidarsamer@yandex-team.ru>2022-05-03 22:05:20 +0300
commitf35dbbdae6519b9ab9cabb2250dd4d734c12035b (patch)
tree1dbb5aa57af6a224c14cf771c17a48261f513e52
parentfe43d898800daf570d5e787fd9a54732712990f6 (diff)
downloadydb-f35dbbdae6519b9ab9cabb2250dd4d734c12035b.tar.gz
KIKIMR-13107. Cast pushdown to OLAP.
KIKIMR-13107. Add cast pushdown to OLAP. Fix tuple comparisons conversion. ref:3330b9b0d89d17bf7778b6bd0d566408d08c64b2
-rw-r--r--ydb/core/driver_lib/run/run.cpp2
-rw-r--r--ydb/core/formats/CMakeLists.txt2
-rw-r--r--ydb/core/formats/custom_registry.cpp20
-rw-r--r--ydb/core/formats/custom_registry.h10
-rw-r--r--ydb/core/formats/execs.h7
-rw-r--r--ydb/core/formats/func_cast.cpp150
-rw-r--r--ydb/core/formats/func_cast.h22
-rw-r--r--ydb/core/formats/func_gcd.h17
-rw-r--r--ydb/core/formats/func_math.h3
-rw-r--r--ydb/core/formats/func_round.h16
-rw-r--r--ydb/core/formats/functions.h1
-rw-r--r--ydb/core/formats/program.cpp35
-rw-r--r--ydb/core/formats/program.h20
-rw-r--r--ydb/core/kqp/compile/kqp_olap_compiler.cpp51
-rw-r--r--ydb/core/kqp/opt/physical/kqp_opt_phy_olap_filter.cpp232
-rw-r--r--ydb/core/kqp/prepare/kqp_type_ann.cpp5
-rw-r--r--ydb/core/kqp/ut/kqp_olap_ut.cpp28
-rw-r--r--ydb/core/protos/ssa.proto14
-rw-r--r--ydb/core/tx/columnshard/columnshard__stats_scan.h3
-rw-r--r--ydb/core/tx/columnshard/columnshard_common.cpp25
-rw-r--r--ydb/core/tx/columnshard/engines/indexed_read_data.cpp3
21 files changed, 555 insertions, 111 deletions
diff --git a/ydb/core/driver_lib/run/run.cpp b/ydb/core/driver_lib/run/run.cpp
index 5e8a9f7ffe..d39be96ca5 100644
--- a/ydb/core/driver_lib/run/run.cpp
+++ b/ydb/core/driver_lib/run/run.cpp
@@ -934,7 +934,7 @@ void TKikimrRunner::InitializeAppData(const TKikimrRunConfig& runConfig)
if (runConfig.AppConfig.GetBootstrapConfig().HasEnableIntrospection())
AppData->EnableIntrospection = runConfig.AppConfig.GetBootstrapConfig().GetEnableIntrospection();
-
+
TAppDataInitializersList appDataInitializers;
// setup domain info
appDataInitializers.AddAppDataInitializer(new TDomainsInitializer(runConfig));
diff --git a/ydb/core/formats/CMakeLists.txt b/ydb/core/formats/CMakeLists.txt
index 0338d2a440..f61a7c54d1 100644
--- a/ydb/core/formats/CMakeLists.txt
+++ b/ydb/core/formats/CMakeLists.txt
@@ -18,6 +18,8 @@ target_sources(ydb-core-formats PRIVATE
${CMAKE_SOURCE_DIR}/ydb/core/formats/arrow_batch_builder.cpp
${CMAKE_SOURCE_DIR}/ydb/core/formats/arrow_helpers.cpp
${CMAKE_SOURCE_DIR}/ydb/core/formats/clickhouse_block.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/core/formats/custom_registry.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/core/formats/func_cast.cpp
${CMAKE_SOURCE_DIR}/ydb/core/formats/merging_sorted_input_stream.cpp
${CMAKE_SOURCE_DIR}/ydb/core/formats/program.cpp
)
diff --git a/ydb/core/formats/custom_registry.cpp b/ydb/core/formats/custom_registry.cpp
index 404d01d6a9..347e833f9f 100644
--- a/ydb/core/formats/custom_registry.cpp
+++ b/ydb/core/formats/custom_registry.cpp
@@ -1,7 +1,9 @@
+#include "custom_registry.h"
+
#include "functions.h"
#include "func_common.h"
#include <util/system/yassert.h>
-#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.cc>
+#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/registry_internal.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h>
namespace cp = ::arrow::compute;
@@ -18,7 +20,10 @@ static void RegisterMath(cp::FunctionRegistry* registry) {
Y_VERIFY(registry->AddFunction(MakeMathUnary<TErfc>(TErfc::Name)).ok());
Y_VERIFY(registry->AddFunction(MakeMathUnary<TExp>(TExp::Name)).ok());
Y_VERIFY(registry->AddFunction(MakeMathUnary<TExp2>(TExp2::Name)).ok());
+ // Temporarily disabled because of compilation error on Windows.
+#if 0
Y_VERIFY(registry->AddFunction(MakeMathUnary<TExp10>(TExp10::Name)).ok());
+#endif
Y_VERIFY(registry->AddFunction(MakeMathBinary<THypot>(THypot::Name)).ok());
Y_VERIFY(registry->AddFunction(MakeMathUnary<TLgamma>(TLgamma::Name)).ok());
Y_VERIFY(registry->AddFunction(MakeConstNullary<TPi>(TPi::Name)).ok());
@@ -40,24 +45,31 @@ static void RegisterArithmetic(cp::FunctionRegistry* registry) {
Y_VERIFY(registry->AddFunction(MakeArithmeticBinary<TModuloOrZero>(TModuloOrZero::Name)).ok());
}
+// Registers cast support in `registry`: first calls Arrow's
+// cp::internal::RegisterScalarCast, then adds the YdbCastMetaFunction.
+// Aborts via Y_VERIFY if adding the meta-function does not succeed.
+static void RegisterYdbCast(cp::FunctionRegistry* registry) {
+ cp::internal::RegisterScalarCast(registry);
+ Y_VERIFY(registry->AddFunction(std::make_shared<YdbCastMetaFunction>()).ok());
+}
+
static std::unique_ptr<cp::FunctionRegistry> CreateCustomRegistry() {
auto registry = cp::FunctionRegistry::Make();
RegisterMath(registry.get());
RegisterRound(registry.get());
RegisterArithmetic(registry.get());
- cp::internal::RegisterScalarCast(registry.get());
+ RegisterYdbCast(registry.get());
return registry;
}
+// Creates singleton custom registry
cp::FunctionRegistry* GetCustomFunctionRegistry() {
static auto g_registry = CreateCustomRegistry();
return g_registry.get();
}
+// Each thread gets its own ExecContext; all of these contexts share the single custom registry.
cp::ExecContext* GetCustomExecContext() {
- static auto context = std::make_unique<cp::ExecContext>(arrow::default_memory_pool(), NULLPTR, GetCustomFunctionRegistry());
- return context.get();
+ static thread_local cp::ExecContext context(arrow::default_memory_pool(), nullptr, GetCustomFunctionRegistry());
+ return &context;
}
}
diff --git a/ydb/core/formats/custom_registry.h b/ydb/core/formats/custom_registry.h
index 8442e44eee..77f419d33d 100644
--- a/ydb/core/formats/custom_registry.h
+++ b/ydb/core/formats/custom_registry.h
@@ -1,9 +1,11 @@
#pragma once
-#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h>
-namespace cp = ::arrow::compute;
+namespace arrow::compute {
+ class FunctionRegistry;
+ class ExecContext;
+}
namespace NKikimr::NArrow {
- cp::FunctionRegistry* GetCustomFunctionRegistry();
- cp::ExecContext* GetCustomExecContext();
+ arrow::compute::FunctionRegistry* GetCustomFunctionRegistry();
+ arrow::compute::ExecContext* GetCustomExecContext();
}
diff --git a/ydb/core/formats/execs.h b/ydb/core/formats/execs.h
index bf98b70927..253bae4831 100644
--- a/ydb/core/formats/execs.h
+++ b/ydb/core/formats/execs.h
@@ -5,7 +5,12 @@
#include <contrib/libs/apache/arrow/cpp/src/arrow/type_fwd.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/type_traits.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/builder.h>
+
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunused-parameter"
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h>
+#pragma clang diagnostic pop
+
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/function.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.h>
#include <util/datetime/base.h>
@@ -17,8 +22,6 @@
#include "switch_type.h"
namespace cp = arrow::compute;
-using cp::internal::applicator::ScalarBinary;
-using cp::internal::applicator::ScalarUnary;
namespace NKikimr::NArrow {
diff --git a/ydb/core/formats/func_cast.cpp b/ydb/core/formats/func_cast.cpp
new file mode 100644
index 0000000000..7d480bdac0
--- /dev/null
+++ b/ydb/core/formats/func_cast.cpp
@@ -0,0 +1,150 @@
+#include "func_cast.h"
+
+#include <util/system/yassert.h>
+
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunused-parameter"
+#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/kernel.h>
+#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.h>
+#include <contrib/libs/apache/arrow/cpp/src/arrow/util/time.h>
+#pragma clang diagnostic pop
+
+#include <unordered_map>
+
+namespace cp = ::arrow::compute;
+
+namespace NKikimr::NArrow {
+
+namespace {
+
+// Builds the cast function named "ydb_cast_timestamp" whose output type id is
+// TIMESTAMP and adds one kernel to it that accepts UINT32 input.
+// NOTE(review): the kernel body is presumably the
+// CastFunctor<TimestampType, UInt32Type> specialization at the bottom of this
+// file; that specialization is declared after this use, which only works if the
+// template is instantiated at the end of the translation unit - confirm.
+std::shared_ptr<cp::CastFunction> GetYdbTimestampCast() {
+ auto func = std::make_shared<cp::CastFunction>("ydb_cast_timestamp", ::arrow::Type::TIMESTAMP);
+ cp::internal::AddSimpleCast<arrow::UInt32Type, arrow::TimestampType>(
+ cp::InputType(arrow::Type::UINT32),
+ cp::internal::kOutputTargetType,
+ func.get()
+ );
+ return func;
+}
+
+// Returns every YDB-specific temporal cast function; at the moment that is
+// only the timestamp cast.
+std::vector<std::shared_ptr<cp::CastFunction>> GetYdbTemporalCasts() {
+    return {GetYdbTimestampCast()};
+}
+
+std::unordered_map<int, std::shared_ptr<cp::CastFunction>> ydbCastTable;
+std::once_flag ydbCastTableInitialized;
+
+// Stores each function in ydbCastTable under its output type id. A later
+// function with the same output type id replaces the earlier entry.
+void AddCastFunctions(const std::vector<std::shared_ptr<cp::CastFunction>>& funcs) {
+    for (const auto& castFunc : funcs) {
+        const auto outTypeId = static_cast<int>(castFunc->out_type_id());
+        ydbCastTable[outTypeId] = castFunc;
+    }
+}
+
+// Fills ydbCastTable with the YDB temporal cast functions.
+// Intended to be run through EnsureInitYdbCastTable rather than called directly.
+void InitYdbCastTable() {
+ AddCastFunctions(GetYdbTemporalCasts());
+}
+
+// Runs InitYdbCastTable through std::call_once on ydbCastTableInitialized, so
+// the table is filled at most once no matter how often this is called.
+// NOTE(review): std::once_flag / std::call_once are declared in <mutex>, which
+// this file does not include directly - consider adding the include.
+void EnsureInitYdbCastTable() {
+ std::call_once(ydbCastTableInitialized, InitYdbCastTable);
+}
+
+// Finds the cast function that produces `to_type`.
+// An entry in ydbCastTable for the target type id wins; otherwise the lookup
+// falls back to cp::GetCastFunction. If that fails too, a NotImplemented status
+// is returned, and it names `from_type` when the caller supplied one.
+::arrow::Result<std::shared_ptr<cp::CastFunction>> GetYdbCastFunctionInternal(
+    const std::shared_ptr<::arrow::DataType>& to_type, const::arrow::DataType* from_type = nullptr) {
+    EnsureInitYdbCastTable();
+
+    // YDB-specific cast registered for this target type.
+    const auto found = ydbCastTable.find(static_cast<int>(to_type->id()));
+    if (found != ydbCastTable.end()) {
+        return found->second;
+    }
+
+    // No YDB entry: use whatever Arrow provides for the target type.
+    auto arrowCast = cp::GetCastFunction(to_type);
+    if (arrowCast.ok()) {
+        return std::move(arrowCast).ValueUnsafe();
+    }
+
+    if (from_type == nullptr) {
+        return ::arrow::Status::NotImplemented("Unsupported cast to ", *to_type,
+                                               " (no available cast function for target type)");
+    }
+    return ::arrow::Status::NotImplemented("Unsupported cast from ", *from_type, " to ",
+                                           *to_type,
+                                           " (no available cast function for target type)");
+}
+
+} // namespace
+
+// Documentation attached to the "ydb.cast" meta-function: summary, description,
+// argument names and the name of the options class.
+static const cp::FunctionDoc ydbCastDoc{
+    // Summary. Adjacent literals are concatenated as-is, so the first one must
+    // end with a space to keep the two sentences apart.
+    "YDB special cast function. Uses Arrow's cast and adds casting support for some types. "
+    "Cast values to another data type",
+    // Description.
+    ("Behavior when values wouldn't fit in the target type\n"
+     "can be controlled through CastOptions."),
+    {"input"},
+    "CastOptions"};
+
+// Constructs the unary meta-function under the name "ydb.cast" and points it
+// at ydbCastDoc for its documentation.
+YdbCastMetaFunction::YdbCastMetaFunction()
+ : ::arrow::compute::MetaFunction("ydb.cast", ::arrow::compute::Arity::Unary(), &ydbCastDoc)
+ {}
+
+// Interprets `options` as CastOptions and checks that a target type is set.
+// Returns the typed pointer on success, or an Invalid status when `options` is
+// null or its to_type is null. The static_cast itself performs no type check.
+::arrow::Result<const cp::CastOptions*> YdbCastMetaFunction::ValidateOptions(const cp::FunctionOptions* options) const {
+    const auto* castOptions = static_cast<const cp::CastOptions*>(options);
+
+    const bool hasTargetType = castOptions != nullptr && castOptions->to_type != nullptr;
+    if (!hasTargetType) {
+        return ::arrow::Status::Invalid(
+            "Cast requires that options be passed with "
+            "the to_type populated");
+    }
+
+    return castOptions;
+}
+
+// Casts args[0] to the type requested in `options`.
+// Returns the validation error if the options are unusable, returns the input
+// unchanged if it already has the target type, and otherwise executes the cast
+// function chosen by GetYdbCastFunctionInternal (or returns its lookup error).
+// NOTE(review): args[0] is read without a size check here - presumably the
+// arity is verified by the caller before ExecuteImpl is invoked; confirm.
+::arrow::Result<::arrow::Datum> YdbCastMetaFunction::ExecuteImpl(const std::vector<::arrow::Datum>& args,
+ const cp::FunctionOptions* options,
+ cp::ExecContext* ctx) const
+{
+ auto&& optsResult = ValidateOptions(options);
+ if (!optsResult.ok()) {
+ return optsResult.status();
+ }
+ auto cast_options = std::move(optsResult).ValueUnsafe();
+ // Nothing to do when the input already has the requested type.
+ if (args[0].type()->Equals(*cast_options->to_type)) {
+ return args[0];
+ }
+ auto&& castFuncResult = GetYdbCastFunctionInternal(cast_options->to_type, args[0].type().get());
+ if (!castFuncResult.ok()) {
+ return castFuncResult.status();
+ }
+ std::shared_ptr<cp::CastFunction> castFunc = std::move(castFuncResult).ValueUnsafe();
+ return castFunc->Execute(args, options, ctx);
+}
+
+} // NKikimr::NArrow
+
+namespace arrow::compute::internal {
+
+// Cast kernel that produces timestamp values from uint32 input.
+// Each input number is multiplied by the MICRO -> target-unit factor, i.e. the
+// input is treated as a count of microseconds. Only the output value buffer
+// (buffer index 1) is written by this functor.
+template <>
+struct CastFunctor<TimestampType, UInt32Type> {
+    // Returns IndexError for an empty batch and OK otherwise. Aborts via Y_VERIFY
+    // when the input is not an array or when the target unit would require a
+    // division instead of a multiplication.
+    static Status Exec(KernelContext* /*ctx*/, const ExecBatch& batch, Datum* out) {
+        if (batch.num_values() == 0) {
+            return ::arrow::Status::IndexError("Cast from uint32 to timestamp received empty batch.");
+        }
+        Y_VERIFY(batch[0].kind() == Datum::ARRAY, "Cast from uint32 to timestamp expected ARRAY as input.");
+
+        const auto& out_type = checked_cast<const ::arrow::TimestampType&>(*out->type());
+        // get conversion MICROSECONDS -> unit
+        const auto conversion = ::arrow::util::GetTimestampConversion(::arrow::TimeUnit::MICRO, out_type.unit());
+        // The check compares against microseconds, so the message names that unit.
+        Y_VERIFY(conversion.first == ::arrow::util::MULTIPLY, "Cast from uint32 to timestamp failed because timestamp unit is coarser than microseconds.");
+        const auto multiplier = conversion.second;
+
+        auto input = batch[0].array();
+        auto output = out->mutable_array();
+        auto in_data = input->GetValues<uint32_t>(1);
+        auto out_data = output->GetMutableValues<int64_t>(1);
+
+        for (int64_t i = 0; i < input->length; i++) {
+            out_data[i] = static_cast<int64_t>(in_data[i] * multiplier);
+        }
+        return ::arrow::Status::OK();
+    }
+};
+
+} // namespace arrow::compute::internal \ No newline at end of file
diff --git a/ydb/core/formats/func_cast.h b/ydb/core/formats/func_cast.h
new file mode 100644
index 0000000000..532f3164d7
--- /dev/null
+++ b/ydb/core/formats/func_cast.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/function.h>
+
+#include <type_traits>
+
+namespace NKikimr::NArrow {
+
+// Metafunction for dispatching to appropriate CastFunction. This corresponds
+// to the standard SQL CAST(expr AS target_type)
+class YdbCastMetaFunction : public ::arrow::compute::MetaFunction {
+ public:
+ // Creates the meta-function; name, arity and documentation are set in the .cpp file.
+ YdbCastMetaFunction();
+
+ // Returns `options` as CastOptions, or an error status if they are unusable
+ // (see the definition for the exact conditions).
+ ::arrow::Result<const ::arrow::compute::CastOptions*> ValidateOptions(const ::arrow::compute::FunctionOptions* options) const;
+
+ // MetaFunction hook: performs the cast of `args` according to `options`
+ // using the execution context `ctx`.
+ ::arrow::Result<::arrow::Datum> ExecuteImpl(const std::vector<::arrow::Datum>& args,
+ const ::arrow::compute::FunctionOptions* options,
+ ::arrow::compute::ExecContext* ctx) const override;
+};
+
+} // NKikimr::NArrow
diff --git a/ydb/core/formats/func_gcd.h b/ydb/core/formats/func_gcd.h
index 9d4b72e4b0..0f6fe4bc8d 100644
--- a/ydb/core/formats/func_gcd.h
+++ b/ydb/core/formats/func_gcd.h
@@ -1,17 +1,18 @@
#pragma once
-#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_base.h>
-#include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_base.h>
+
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/kernel.h>
-#include <contrib/libs/apache/arrow/cpp/src/arrow/datum.h>
-#include <contrib/libs/apache/arrow/cpp/src/arrow/type_traits.h>
-#include "func_common.h"
-#include <cstdlib>
#include <type_traits>
+#ifdef WIN32
+#define INLINE inline
+#else
+#define INLINE __attribute__((always_inline))
+#endif
+
namespace NKikimr::NArrow {
template<typename T>
-__attribute__((always_inline)) void FastIntSwap(T& lhs, T& rhs) {
+INLINE void FastIntSwap(T& lhs, T& rhs) {
lhs ^= rhs;
rhs ^= lhs;
lhs ^= rhs;
@@ -35,3 +36,5 @@ struct TGreatestCommonDivisor {
};
}
+
+#undef INLINE
diff --git a/ydb/core/formats/func_math.h b/ydb/core/formats/func_math.h
index 74cf8694ff..13474e0018 100644
--- a/ydb/core/formats/func_math.h
+++ b/ydb/core/formats/func_math.h
@@ -95,6 +95,8 @@ struct TExp2 {
}
};
+#if 0
+// Temporarily disable function because it doesn't compile on Windows.
struct TExp10 {
static constexpr const char * Name = "exp10";
@@ -104,6 +106,7 @@ struct TExp10 {
return exp10(arg);
}
};
+#endif
struct THypot {
diff --git a/ydb/core/formats/func_round.h b/ydb/core/formats/func_round.h
index a65e893622..3b82591c45 100644
--- a/ydb/core/formats/func_round.h
+++ b/ydb/core/formats/func_round.h
@@ -9,6 +9,15 @@
#include <fenv.h>
#include <type_traits>
+#ifdef WIN32
+#include <intrin.h>
+#define CLZ __lzcnt
+#define CLZLL __lzcnt64
+#else
+#define CLZ __builtin_clz
+#define CLZLL __builtin_clzll
+#endif
+
namespace NKikimr::NArrow {
struct TRound {
@@ -42,7 +51,7 @@ struct TRoundToExp2 {
(sizeof(TRes) <= sizeof(uint32_t)), TRes>
Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) {
static_assert(std::is_same_v<TRes, TArg>, "");
- return arg <= 0 ? 0 : (TRes(1) << (31 - __builtin_clz(arg)));
+ return arg <= 0 ? 0 : (TRes(1) << (31 - CLZ(arg)));
}
template <typename TRes, typename TArg>
@@ -50,7 +59,7 @@ struct TRoundToExp2 {
(sizeof(TRes) == sizeof(uint64_t)), TRes>
Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) {
static_assert(std::is_same_v<TRes, TArg>, "");
- return arg <= 0 ? 0 : (TRes(1) << (63 - __builtin_clzll(arg)));
+ return arg <= 0 ? 0 : (TRes(1) << (63 - CLZLL(arg)));
}
template <typename TRes, typename TArg>
@@ -67,3 +76,6 @@ struct TRoundToExp2 {
};
}
+
+#undef CLZ
+#undef CLZLL
diff --git a/ydb/core/formats/functions.h b/ydb/core/formats/functions.h
index 4c0ecda6be..2f4523a4fe 100644
--- a/ydb/core/formats/functions.h
+++ b/ydb/core/formats/functions.h
@@ -1,5 +1,6 @@
#pragma once
+#include "func_cast.h"
#include "func_gcd.h"
#include "func_lcm.h"
#include "func_modulo.h"
diff --git a/ydb/core/formats/program.cpp b/ydb/core/formats/program.cpp
index 3bc39be6fe..f8d8287ab5 100644
--- a/ydb/core/formats/program.cpp
+++ b/ydb/core/formats/program.cpp
@@ -20,33 +20,21 @@ namespace NKikimr::NArrow {
const char * GetFunctionName(EOperation op) {
switch (op) {
case EOperation::CastBoolean:
- return "cast_boolean";
case EOperation::CastInt8:
- return "cast_int8";
case EOperation::CastInt16:
- return "cast_int16";
case EOperation::CastInt32:
- return "cast_int32";
case EOperation::CastInt64:
- return "cast_int64";
case EOperation::CastUInt8:
- return "cast_uint8";
case EOperation::CastUInt16:
- return "cast_uint16";
case EOperation::CastUInt32:
- return "cast_uint32";
case EOperation::CastUInt64:
- return "cast_uint64";
case EOperation::CastFloat:
- return "cast_float";
case EOperation::CastDouble:
- return "cast_double";
case EOperation::CastBinary:
- return "cast_binary";
case EOperation::CastFixedSizeBinary:
- return "cast_fixed_size_binary";
case EOperation::CastString:
- return "cast_string";
+ case EOperation::CastTimestamp:
+ return "ydb.cast";
case EOperation::IsValid:
return "is_valid";
@@ -263,7 +251,7 @@ std::shared_ptr<arrow::Scalar> CallScalarFunction(EOperation funcId, const std::
return result->scalar();
}
-arrow::Datum CallFunctionById(EOperation funcId, const std::vector<std::string>& args,
+arrow::Datum CallFunctionById(EOperation funcId, const std::vector<std::string>& args, const arrow::compute::FunctionOptions* funcOpts,
std::shared_ptr<TProgramStep::TDatumBatch> batch, arrow::compute::ExecContext* ctx) {
std::vector<arrow::Datum> arguments;
arguments.reserve(args.size());
@@ -274,17 +262,20 @@ arrow::Datum CallFunctionById(EOperation funcId, const std::vector<std::string>&
arguments.push_back(*column);
}
std::string funcName = GetFunctionName(funcId);
+
arrow::Result<arrow::Datum> result;
if (ctx != nullptr && ctx->func_registry()->GetFunction(funcName).ok()) {
- result = arrow::compute::CallFunction(GetFunctionName(funcId), arguments, ctx);
+ result = arrow::compute::CallFunction(GetFunctionName(funcId), arguments, funcOpts, ctx);
} else {
- result = arrow::compute::CallFunction(GetFunctionName(funcId), arguments);
+ result = arrow::compute::CallFunction(GetFunctionName(funcId), arguments, funcOpts);
}
Y_VERIFY(result.ok());
return result.ValueOrDie();
}
-
+// Runs the function described by `assign` against `batch`: forwards the
+// assignment's operation, argument names and function options (may be null)
+// to CallFunctionById and returns its result.
+arrow::Datum CallFunctionByAssign(const TAssign& assign, std::shared_ptr<TProgramStep::TDatumBatch> batch, arrow::compute::ExecContext* ctx) {
+ return CallFunctionById(assign.GetOperation(), assign.GetArguments(), assign.GetFunctionOptions(), batch, ctx);
+}
void TProgramStep::ApplyAssignes(std::shared_ptr<TProgramStep::TDatumBatch>& batch, arrow::compute::ExecContext* ctx) const {
if (Assignes.empty()) {
@@ -298,7 +289,7 @@ void TProgramStep::ApplyAssignes(std::shared_ptr<TProgramStep::TDatumBatch>& bat
if (assign.IsConstant()) {
column = assign.GetConstant();
} else {
- column = CallFunctionById(assign.GetOperation(), assign.GetArguments(), batch, ctx);
+ column = CallFunctionByAssign(assign, batch, ctx);
}
AddColumn(batch, assign.GetName(), column);
}
@@ -313,9 +304,9 @@ void TProgramStep::ApplyFilters(std::shared_ptr<TDatumBatch>& batch) const {
filters.reserve(Filters.size());
for (auto& colName : Filters) {
auto column = GetColumnByName(batch, colName);
- Y_VERIFY(column.ok());
- Y_VERIFY(column->is_array());
- Y_VERIFY(column->type() == arrow::boolean());
+ Y_VERIFY_S(column.ok(), TStringBuilder() << "Column " << colName << " is not ok.");
+ Y_VERIFY_S(column->is_array(), TStringBuilder() << "Column " << colName << " is not an array.");
+ Y_VERIFY_S(column->type() == arrow::boolean(), TStringBuilder() << "Column " << colName << " type is not bool.");
auto boolColumn = std::static_pointer_cast<arrow::BooleanArray>(column->make_array());
filters.push_back(std::vector<bool>(boolColumn->length()));
auto& bits = filters.back();
diff --git a/ydb/core/formats/program.h b/ydb/core/formats/program.h
index ff15d30ebf..f4b7d466ab 100644
--- a/ydb/core/formats/program.h
+++ b/ydb/core/formats/program.h
@@ -23,6 +23,7 @@ enum class EOperation {
CastBinary,
CastFixedSizeBinary,
CastString,
+ CastTimestamp,
//
IsValid,
IsNull,
@@ -90,60 +91,77 @@ public:
: Name(name)
, Operation(op)
, Arguments(std::move(args))
+ , FuncOpts(nullptr)
+ {}
+
+ TAssign(const std::string& name, EOperation op, std::vector<std::string>&& args, std::shared_ptr<arrow::compute::FunctionOptions> funcOpts)
+ : Name(name)
+ , Operation(op)
+ , Arguments(std::move(args))
+ , FuncOpts(funcOpts)
{}
explicit TAssign(const std::string& name, bool value)
: Name(name)
, Operation(EOperation::Constant)
, Constant(std::make_shared<arrow::BooleanScalar>(value))
+ , FuncOpts(nullptr)
{}
explicit TAssign(const std::string& name, i32 value)
: Name(name)
, Operation(EOperation::Constant)
, Constant(std::make_shared<arrow::Int32Scalar>(value))
+ , FuncOpts(nullptr)
{}
explicit TAssign(const std::string& name, ui32 value)
: Name(name)
, Operation(EOperation::Constant)
, Constant(std::make_shared<arrow::UInt32Scalar>(value))
+ , FuncOpts(nullptr)
{}
explicit TAssign(const std::string& name, i64 value)
: Name(name)
, Operation(EOperation::Constant)
, Constant(std::make_shared<arrow::Int64Scalar>(value))
+ , FuncOpts(nullptr)
{}
explicit TAssign(const std::string& name, ui64 value)
: Name(name)
, Operation(EOperation::Constant)
, Constant(std::make_shared<arrow::UInt64Scalar>(value))
+ , FuncOpts(nullptr)
{}
explicit TAssign(const std::string& name, float value)
: Name(name)
, Operation(EOperation::Constant)
, Constant(std::make_shared<arrow::FloatScalar>(value))
+ , FuncOpts(nullptr)
{}
explicit TAssign(const std::string& name, double value)
: Name(name)
, Operation(EOperation::Constant)
, Constant(std::make_shared<arrow::DoubleScalar>(value))
+ , FuncOpts(nullptr)
{}
explicit TAssign(const std::string& name, const std::string& value)
: Name(name)
, Operation(EOperation::Constant)
, Constant(std::make_shared<arrow::StringScalar>(value))
+ , FuncOpts(nullptr)
{}
TAssign(const std::string& name, const std::shared_ptr<arrow::Scalar>& value)
: Name(name)
, Operation(EOperation::Constant)
, Constant(value)
+ , FuncOpts(nullptr)
{}
bool IsConstant() const { return Operation == EOperation::Constant; }
@@ -151,12 +169,14 @@ public:
const std::vector<std::string>& GetArguments() const { return Arguments; }
std::shared_ptr<arrow::Scalar> GetConstant() const { return Constant; }
const std::string& GetName() const { return Name; }
+ const arrow::compute::FunctionOptions* GetFunctionOptions() const { return FuncOpts.get(); }
private:
std::string Name;
EOperation Operation{EOperation::Unspecified};
std::vector<std::string> Arguments;
std::shared_ptr<arrow::Scalar> Constant;
+ std::shared_ptr<arrow::compute::FunctionOptions> FuncOpts;
};
/// Group of commands that finishes with projection. Steps add locality for columns definition.
diff --git a/ydb/core/kqp/compile/kqp_olap_compiler.cpp b/ydb/core/kqp/compile/kqp_olap_compiler.cpp
index 8f16b9345c..f880273706 100644
--- a/ydb/core/kqp/compile/kqp_olap_compiler.cpp
+++ b/ydb/core/kqp/compile/kqp_olap_compiler.cpp
@@ -77,6 +77,7 @@ private:
TProgram::TAssignment* CompileCondition(const TExprBase& condition, TKqpOlapCompileContext& ctx);
+ui64 GetOrCreateColumnId(const TExprBase& node, TKqpOlapCompileContext& ctx);
ui32 ConvertValueToColumn(const TCoDataCtor& value, TKqpOlapCompileContext& ctx)
{
@@ -145,6 +146,52 @@ ui32 ConvertParameterToColumn(const TCoParameter& parameter, TKqpOlapCompileCont
return ssaValue->GetColumn().GetId();
}
+// Compiles a pushed-down SafeCast into an SSA assignment that applies the
+// FUNC_CAST_TO_* function matching the target data type to the operand column.
+// Returns the id of the column of the new assignment. Fails through YQL_ENSURE
+// when the target is not a data type or has no matching cast function.
+ui32 ConvertSafeCastToColumn(const TCoSafeCast& cast, TKqpOlapCompileContext& ctx)
+{
+    struct TCastTarget {
+        const char* TypeName;
+        ui32 Function;
+    };
+    // Target data type name -> SSA cast function id.
+    static const TCastTarget castTargets[] = {
+        {"Boolean", TProgram::TAssignment::FUNC_CAST_TO_BOOLEAN},
+        {"Int8", TProgram::TAssignment::FUNC_CAST_TO_INT8},
+        {"Int16", TProgram::TAssignment::FUNC_CAST_TO_INT16},
+        {"Int32", TProgram::TAssignment::FUNC_CAST_TO_INT32},
+        {"Int64", TProgram::TAssignment::FUNC_CAST_TO_INT64},
+        {"Uint8", TProgram::TAssignment::FUNC_CAST_TO_UINT8},
+        {"Uint16", TProgram::TAssignment::FUNC_CAST_TO_UINT16},
+        {"Uint32", TProgram::TAssignment::FUNC_CAST_TO_UINT32},
+        {"Uint64", TProgram::TAssignment::FUNC_CAST_TO_UINT64},
+        {"Float", TProgram::TAssignment::FUNC_CAST_TO_FLOAT},
+        {"Double", TProgram::TAssignment::FUNC_CAST_TO_DOUBLE},
+        {"Timestamp", TProgram::TAssignment::FUNC_CAST_TO_TIMESTAMP},
+    };
+
+    // The operand column is resolved before the cast's own assignment is created.
+    const auto argumentId = GetOrCreateColumnId(cast.Value(), ctx);
+
+    TProgram::TAssignment* assignment = ctx.CreateAssignCmd();
+    auto function = assignment->MutableFunction();
+
+    auto maybeDataType = cast.Type().Maybe<TCoDataType>();
+    YQL_ENSURE(maybeDataType.IsValid());
+
+    auto dataType = maybeDataType.Cast();
+    const auto typeName = dataType.Type().Value();
+
+    ui32 castFunction = TProgram::TAssignment::FUNC_UNSPECIFIED;
+    bool supported = false;
+    for (const auto& target : castTargets) {
+        if (typeName == target.TypeName) {
+            castFunction = target.Function;
+            supported = true;
+            break;
+        }
+    }
+    YQL_ENSURE(supported, "Unsupported data type for pushed down safe cast: " << typeName);
+
+    function->SetId(castFunction);
+    function->AddArguments()->SetId(argumentId);
+    return assignment->GetColumn().GetId();
+}
+
ui64 GetOrCreateColumnId(const TExprBase& node, TKqpOlapCompileContext& ctx) {
if (auto maybeData = node.Maybe<TCoDataCtor>()) {
return ConvertValueToColumn(maybeData.Cast(), ctx);
@@ -158,6 +205,10 @@ ui64 GetOrCreateColumnId(const TExprBase& node, TKqpOlapCompileContext& ctx) {
return ConvertParameterToColumn(maybeParameter.Cast(), ctx);
}
+ if (auto maybeCast = node.Maybe<TCoSafeCast>()) {
+ return ConvertSafeCastToColumn(maybeCast.Cast(), ctx);
+ }
+
YQL_ENSURE(false, "Unknown node in OLAP comparison compiler: " << node.Ptr()->Content());
}
diff --git a/ydb/core/kqp/opt/physical/kqp_opt_phy_olap_filter.cpp b/ydb/core/kqp/opt/physical/kqp_opt_phy_olap_filter.cpp
index fa0bb9cc80..2dded70c9d 100644
--- a/ydb/core/kqp/opt/physical/kqp_opt_phy_olap_filter.cpp
+++ b/ydb/core/kqp/opt/physical/kqp_opt_phy_olap_filter.cpp
@@ -103,6 +103,19 @@ bool IsSupportedDataType(const TCoDataCtor& node) {
return false;
}
+// Tells whether a SafeCast may be pushed down. Two shapes are accepted:
+// target "Int32" with a TCoString operand, and target "Timestamp" with a
+// TCoUint32 operand. Fails through YQL_ENSURE if the target is not a data type.
+bool IsSupportedCast(const TCoSafeCast& cast) {
+    auto maybeDataType = cast.Type().Maybe<TCoDataType>();
+    YQL_ENSURE(maybeDataType.IsValid());
+
+    auto dataType = maybeDataType.Cast();
+    const auto targetType = dataType.Type().Value();
+
+    if (targetType == "Int32") {
+        return cast.Value().Maybe<TCoString>().IsValid();
+    }
+    if (targetType == "Timestamp") {
+        return cast.Value().Maybe<TCoUint32>().IsValid();
+    }
+    return false;
+}
+
bool IsComparableTypes(const TExprBase& leftNode, const TExprBase& rightNode, bool equality,
const TTypeAnnotationNode* inputType)
{
@@ -228,43 +241,94 @@ bool IsComparableTypes(const TExprBase& leftNode, const TExprBase& rightNode, bo
return true;
}
-
-TVector<std::pair<TExprBase, TExprBase>> ExtractComparisonParameters(const TCoCompare& predicate,
- const TExprNode* rawLambdaArg, const TExprBase& input)
+TVector<TExprBase> ConvertComparisonNode(const TExprBase& nodeIn, const TExprNode* rawLambdaArg)
{
- TVector<std::pair<TExprBase, TExprBase>> out;
+ TVector<TExprBase> out;
auto convertNode = [rawLambdaArg](const TExprBase& node) -> TMaybeNode<TExprBase> {
if (node.Maybe<TCoNull>()) {
return node;
}
+ if (auto maybeSafeCast = node.Maybe<TCoSafeCast>()) {
+ if (!IsSupportedCast(maybeSafeCast.Cast())) {
+ return NullNode;
+ }
+
+ return node;
+ }
+
if (auto maybeParameter = node.Maybe<TCoParameter>()) {
return maybeParameter.Cast();
}
if (auto maybeData = node.Maybe<TCoDataCtor>()) {
- if (IsSupportedDataType(maybeData.Cast())) {
- return node;
+ if (!IsSupportedDataType(maybeData.Cast())) {
+ return NullNode;
}
- return NullNode;
+ return node;
}
if (auto maybeMember = node.Maybe<TCoMember>()) {
- if (maybeMember.Cast().Struct().Raw() == rawLambdaArg) {
- return maybeMember.Cast().Name();
+ if (maybeMember.Cast().Struct().Raw() != rawLambdaArg) {
+ return NullNode;
}
- return NullNode;
+ return maybeMember.Cast().Name();
}
return NullNode;
};
// Columns & values may be single element
- TMaybeNode<TExprBase> left = convertNode(predicate.Left());
- TMaybeNode<TExprBase> right = convertNode(predicate.Right());
+ TMaybeNode<TExprBase> node = convertNode(nodeIn);
+
+ if (node.IsValid()) {
+ out.emplace_back(std::move(node.Cast()));
+ return out;
+ }
+
+ // Or columns and values can be Tuple
+ if (!nodeIn.Maybe<TExprList>()) {
+ // something unusual found, return empty vector
+ return out;
+ }
+
+ auto tuple = nodeIn.Cast<TExprList>();
+
+ out.reserve(tuple.Size());
+
+ for (ui32 i = 0; i < tuple.Size(); ++i) {
+ TMaybeNode<TExprBase> node = convertNode(tuple.Item(i));
+
+ if (!node.IsValid()) {
+ // Return empty vector
+ return TVector<TExprBase>();
+ }
+
+ out.emplace_back(node.Cast());
+ }
+
+ return out;
+}
+
+TVector<std::pair<TExprBase, TExprBase>> ExtractComparisonParameters(const TCoCompare& predicate,
+ const TExprNode* rawLambdaArg, const TExprBase& input)
+{
+ TVector<std::pair<TExprBase, TExprBase>> out;
+
+ auto left = ConvertComparisonNode(predicate.Left(), rawLambdaArg);
+
+ if (left.empty()) {
+ return out;
+ }
+
+ auto right = ConvertComparisonNode(predicate.Right(), rawLambdaArg);
+
+ if (left.size() != right.size()) {
+ return out;
+ }
TMaybeNode<TCoCmpEqual> maybeEqual = predicate.Maybe<TCoCmpEqual>();
TMaybeNode<TCoCmpNotEqual> maybeNotEqual = predicate.Maybe<TCoCmpNotEqual>();
@@ -292,52 +356,19 @@ TVector<std::pair<TExprBase, TExprBase>> ExtractComparisonParameters(const TCoCo
return out;
}
- if (left.IsValid() && right.IsValid()) {
- if (!IsComparableTypes(left.Cast(), right.Cast(), equality, inputType)) {
- return out;
- }
-
- out.emplace_back(std::move(std::make_pair(left.Cast(), right.Cast())));
- return out;
- }
-
- // Or columns and values can be Tuple
- if (!predicate.Left().Maybe<TExprList>() || !predicate.Right().Maybe<TExprList>()) {
- // something unusual found, return empty vector
- return out;
- }
-
- auto tupleLeft = predicate.Left().Cast<TExprList>();
- auto tupleRight = predicate.Right().Cast<TExprList>();
-
- if (tupleLeft.Size() != tupleRight.Size()) {
- return out;
- }
-
- out.reserve(tupleLeft.Size());
-
- for (ui32 i = 0; i < tupleLeft.Size(); ++i) {
- TMaybeNode<TExprBase> left = convertNode(tupleLeft.Item(i));
- TMaybeNode<TExprBase> right = convertNode(tupleRight.Item(i));
-
- if (!left.IsValid() || !right.IsValid()) {
- // Return empty vector
- return TVector<std::pair<TExprBase, TExprBase>>();
- }
-
- if (!IsComparableTypes(left.Cast(), right.Cast(), equality, inputType)) {
+ for (ui32 i = 0; i < left.size(); ++i) {
+ if (!IsComparableTypes(left[i], right[i], equality, inputType)) {
// Return empty vector
return TVector<std::pair<TExprBase, TExprBase>>();
}
-
- out.emplace_back(std::move(std::make_pair(left.Cast(), right.Cast())));
+ out.emplace_back(std::move(std::make_pair(left[i], right[i])));
}
return out;
}
TExprBase BuildOneElementComparison(const std::pair<TExprBase, TExprBase>& parameter, const TCoCompare& predicate,
- TExprContext& ctx, TPositionHandle pos, const TExprBase& input)
+ TExprContext& ctx, TPositionHandle pos, const TExprBase& input, bool forceStrictComparison)
{
auto isNull = [](const TExprBase& node) {
if (node.Maybe<TCoNull>()) {
@@ -368,7 +399,7 @@ TExprBase BuildOneElementComparison(const std::pair<TExprBase, TExprBase>& param
.Done();
}
- if (predicate.Maybe<TCoCmpLess>()) {
+ if (predicate.Maybe<TCoCmpLess>() || (predicate.Maybe<TCoCmpLessOrEqual>() && forceStrictComparison)) {
return Build<TKqpOlapFilterLess>(ctx, pos)
.Input(input)
.Left(parameter.first)
@@ -376,7 +407,7 @@ TExprBase BuildOneElementComparison(const std::pair<TExprBase, TExprBase>& param
.Done();
}
- if (predicate.Maybe<TCoCmpLessOrEqual>()) {
+ if (predicate.Maybe<TCoCmpLessOrEqual>() && !forceStrictComparison) {
return Build<TKqpOlapFilterLessOrEqual>(ctx, pos)
.Input(input)
.Left(parameter.first)
@@ -384,7 +415,7 @@ TExprBase BuildOneElementComparison(const std::pair<TExprBase, TExprBase>& param
.Done();
}
- if (predicate.Maybe<TCoCmpGreater>()) {
+ if (predicate.Maybe<TCoCmpGreater>() || (predicate.Maybe<TCoCmpGreaterOrEqual>() && forceStrictComparison)) {
return Build<TKqpOlapFilterGreater>(ctx, pos)
.Input(input)
.Left(parameter.first)
@@ -392,7 +423,7 @@ TExprBase BuildOneElementComparison(const std::pair<TExprBase, TExprBase>& param
.Done();
}
- if (predicate.Maybe<TCoCmpGreaterOrEqual>()) {
+ if (predicate.Maybe<TCoCmpGreaterOrEqual>() && !forceStrictComparison) {
return Build<TKqpOlapFilterGreaterOrEqual>(ctx, pos)
.Input(input)
.Left(parameter.first)
@@ -417,7 +448,7 @@ TExprBase ComparisonPushdown(const TVector<std::pair<TExprBase, TExprBase>>& par
ui32 conditionsCount = parameters.size();
if (conditionsCount == 1) {
- return BuildOneElementComparison(parameters[0], predicate, ctx, pos, input);
+ return BuildOneElementComparison(parameters[0], predicate, ctx, pos, input, false);
}
if (predicate.Maybe<TCoCmpEqual>() || predicate.Maybe<TCoCmpNotEqual>()) {
@@ -425,7 +456,7 @@ TExprBase ComparisonPushdown(const TVector<std::pair<TExprBase, TExprBase>>& par
conditions.reserve(conditionsCount);
for (ui32 i = 0; i < conditionsCount; ++i) {
- conditions.emplace_back(BuildOneElementComparison(parameters[i], predicate, ctx, pos, input));
+ conditions.emplace_back(BuildOneElementComparison(parameters[i], predicate, ctx, pos, input, false));
}
if (predicate.Maybe<TCoCmpEqual>()) {
@@ -447,7 +478,9 @@ TExprBase ComparisonPushdown(const TVector<std::pair<TExprBase, TExprBase>>& par
TVector<TExprBase> andConditions;
andConditions.reserve(conditionsCount);
- andConditions.emplace_back(BuildOneElementComparison(parameters[i], predicate, ctx, pos, input));
+ // Every tuple element except the last one must be compared strictly (< or >), even for <= and >= predicates.
+ // For example: (c1, c2, c3) >= (1, 2, 3) ==> (c1 > 1) OR (c2 > 2 AND c1 = 1) OR (c3 >= 3 AND c2 = 2 AND c1 = 1)
+ andConditions.emplace_back(BuildOneElementComparison(parameters[i], predicate, ctx, pos, input, i < conditionsCount - 1));
for (ui32 j = 0; j < i; ++j) {
andConditions.emplace_back(Build<TKqpOlapFilterEqual>(ctx, pos)
@@ -516,10 +549,46 @@ TMaybeNode<TExprBase> ExistsPushdown(const TCoExists& exists, TExprContext& ctx,
.Done();
}
-TMaybeNode<TExprBase> CoalescePushdown(const TCoCoalesce& coalesce, TExprContext& ctx, TPositionHandle pos,
- const TExprNode* lambdaArg, const TExprBase& input)
+TMaybeNode<TExprBase> SafeCastPredicatePushdown(const TCoFlatMap& flatmap,
+ TExprContext& ctx, TPositionHandle pos, const TExprNode* lambdaArg, const TExprBase& input)
{
- auto maybePredicate = coalesce.Predicate().Maybe<TCoCompare>();
+ /*
+ * Pushes down a comparison whose operands are unwrapped by nested FlatMaps:
+ * FlatMap(LeftArgument, FlatMap(RightArgument, Just(Predicate)))
+ * where either argument may be a SafeCast() or a Member(), e.g.:
+ * FlatMap(SafeCast(), FlatMap(Member(), Just(Comparison)))
+ * FlatMap(Member(), FlatMap(SafeCast(), Just(Comparison)))
+ * FlatMap(SafeCast(), FlatMap(SafeCast(), Just(Comparison)))
+ * Each side is converted to a list of operands; the lists are paired element-wise
+ * and handed to ComparisonPushdown(). Returns NullNode if the shape does not match.
+ */
+ TVector<std::pair<TExprBase, TExprBase>> out;
+
+ auto maybeFlatmap = flatmap.Lambda().Body().Maybe<TCoFlatMap>();
+
+ if (!maybeFlatmap.IsValid()) {
+ return NullNode;
+ }
+
+ auto right = ConvertComparisonNode(maybeFlatmap.Cast().Input(), lambdaArg);
+
+ if (right.empty()) {
+ return NullNode;
+ }
+
+ auto left = ConvertComparisonNode(flatmap.Input(), lambdaArg);
+
+ if (left.empty()) {
+ return NullNode;
+ }
+
+ auto maybeJust = maybeFlatmap.Cast().Lambda().Body().Maybe<TCoJust>();
+
+ if (!maybeJust.IsValid()) {
+ return NullNode;
+ }
+
+ auto maybePredicate = maybeJust.Cast().Input().Maybe<TCoCompare>();
if (!maybePredicate.IsValid()) {
return NullNode;
@@ -531,14 +600,26 @@ TMaybeNode<TExprBase> CoalescePushdown(const TCoCoalesce& coalesce, TExprContext
return NullNode;
}
- if (!coalesce.Value().Maybe<TCoBool>()) {
+ // Both operand lists must have the same arity to be paired element-wise
+
+ if (left.size() != right.size()) {
return NullNode;
}
- if (coalesce.Value().Cast<TCoBool>().Literal().Value() != "false") {
- return NullNode;
+ for (ui32 i = 0; i < left.size(); ++i) {
+ out.emplace_back(left[i], right[i]);
}
+ return ComparisonPushdown(out, predicate, ctx, pos, input);
+}
+
+// Pushes down a plain comparison predicate.
+// Returns NullNode for predicates rejected by IsSupportedPredicate(); otherwise the operand pairs
+// produced by ExtractComparisonParameters() are handed to ComparisonPushdown().
+// NOTE(review): presumably an empty parameter list also yields NullNode - that branch is not shown here, confirm.
+TMaybeNode<TExprBase> SimplePredicatePushdown(const TCoCompare& predicate, TExprContext& ctx, TPositionHandle pos,
+ const TExprNode* lambdaArg, const TExprBase& input)
+{
+ if (!IsSupportedPredicate(predicate)) {
+ return NullNode;
+ }
+
+ auto parameters = ExtractComparisonParameters(predicate, lambdaArg, input);
if (parameters.empty()) {
@@ -548,6 +629,33 @@ TMaybeNode<TExprBase> CoalescePushdown(const TCoCoalesce& coalesce, TExprContext
return ComparisonPushdown(parameters, predicate, ctx, pos, input);
}
+
+// Pushes down Coalesce(<predicate>, false). The wrapped predicate is either a FlatMap chain
+// (delegated to SafeCastPredicatePushdown) or a plain comparison (delegated to
+// SimplePredicatePushdown); anything else yields NullNode.
+TMaybeNode<TExprBase> CoalescePushdown(const TCoCoalesce& coalesce, TExprContext& ctx, TPositionHandle pos,
+ const TExprNode* lambdaArg, const TExprBase& input)
+{
+ // Only a literal `false` default value is accepted
+ auto defaultValue = coalesce.Value().Maybe<TCoBool>();
+
+ if (!defaultValue.IsValid() || defaultValue.Cast().Literal().Value() != "false") {
+ return NullNode;
+ }
+
+ auto wrapped = coalesce.Predicate();
+ auto nestedFlatmap = wrapped.Maybe<TCoFlatMap>();
+
+ if (nestedFlatmap.IsValid()) {
+ return SafeCastPredicatePushdown(nestedFlatmap.Cast(), ctx, pos, lambdaArg, input);
+ }
+
+ auto comparison = wrapped.Maybe<TCoCompare>();
+
+ if (!comparison.IsValid()) {
+ return NullNode;
+ }
+
+ return SimplePredicatePushdown(comparison.Cast(), ctx, pos, lambdaArg, input);
+}
+
TMaybeNode<TExprBase> PredicatePushdown(const TExprBase& predicate, TExprContext& ctx, TPositionHandle pos,
const TExprNode* lambdaArg, const TExprBase& input)
{
diff --git a/ydb/core/kqp/prepare/kqp_type_ann.cpp b/ydb/core/kqp/prepare/kqp_type_ann.cpp
index ee51c7349c..f111792122 100644
--- a/ydb/core/kqp/prepare/kqp_type_ann.cpp
+++ b/ydb/core/kqp/prepare/kqp_type_ann.cpp
@@ -769,6 +769,11 @@ TStatus AnnotateOlapFilterCompare(const TExprNode::TPtr& node, TExprContext& ctx
return true;
}
+ // SafeCast is accepted here as-is; checks of its validity belong in kqp_opt_phy_olap_filter.cpp
+ if (TCoSafeCast::Match(node)) {
+ return true;
+ }
+
ctx.AddError(TIssue(
ctx.GetPosition(node->Pos()),
TStringBuilder()
diff --git a/ydb/core/kqp/ut/kqp_olap_ut.cpp b/ydb/core/kqp/ut/kqp_olap_ut.cpp
index 0b4b65aac4..1c47fd7550 100644
--- a/ydb/core/kqp/ut/kqp_olap_ut.cpp
+++ b/ydb/core/kqp/ut/kqp_olap_ut.cpp
@@ -1018,9 +1018,11 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
CreateTestOlapTable(kikimr);
WriteTestData(kikimr, "/Root/olapStore/olapTable", 10000, 3000000, 5);
+ EnableDebugLogging(kikimr);
auto tableClient = kikimr.GetTableClient();
+ // TODO: Add support for DqPhyPrecompute push-down: Cast((2+2) as Uint64)
std::vector<TString> testData = {
R"(`resource_id` = `uid`)",
R"(`resource_id` = "10001")",
@@ -1037,7 +1039,14 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
R"((`level`, `uid`, `resource_id`) > (Int32("1"), "uid_3000001", "10001"))",
R"((`level`, `uid`, `resource_id`) > (Int32("1"), "uid_3000000", "10001"))",
R"((`level`, `uid`, `resource_id`) < (Int32("1"), "uid_3000002", "10001"))",
- R"((`level`, `uid`, `resource_id`) >= (Int32("2"), "uid_3000000", "10001"))",
+ R"((`level`, `uid`, `resource_id`) >= (Int32("2"), "uid_3000001", "10001"))",
+ R"((`level`, `uid`, `resource_id`) >= (Int32("1"), "uid_3000002", "10001"))",
+ R"((`level`, `uid`, `resource_id`) >= (Int32("1"), "uid_3000001", "10002"))",
+ R"((`level`, `uid`, `resource_id`) >= (Int32("1"), "uid_3000001", "10001"))",
+ R"((`level`, `uid`, `resource_id`) <= (Int32("2"), "uid_3000001", "10001"))",
+ R"((`level`, `uid`, `resource_id`) <= (Int32("1"), "uid_3000002", "10001"))",
+ R"((`level`, `uid`, `resource_id`) <= (Int32("1"), "uid_3000001", "10002"))",
+ R"((`level`, `uid`, `resource_id`) <= (Int32("1"), "uid_3000001", "10001"))",
R"((`level`, `uid`, `resource_id`) != (Int32("1"), "uid_3000001", "10001"))",
R"((`level`, `uid`, `resource_id`) != (Int32("0"), "uid_3000001", "10011"))",
R"(`level` = 0 OR `level` = 2 OR `level` = 1)",
@@ -1045,8 +1054,7 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
R"(`level` = 0 OR `uid` = "uid_3000003")",
R"(`level` = 0 AND `uid` = "uid_3000003")",
R"(`level` = 0 AND `uid` = "uid_3000000")",
- // Timestamp will be removed by predicate extraction now.
- R"(`timestamp` >= CAST(3000001 AS Timestamp) AND `level` > 3)",
+ R"(`timestamp` >= CAST(3000001u AS Timestamp) AND `level` > 3)",
R"((`level`, `uid`) > (Int32("2"), "uid_3000004") OR (`level`, `uid`) < (Int32("1"), "uid_3000002"))",
R"(Int32("3") > `level`)",
R"((Int32("1"), "uid_3000001", "10001") = (`level`, `uid`, `resource_id`))",
@@ -1057,12 +1065,17 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
R"(`level` IS NOT NULL)",
R"((`level`, `uid`) > (Int32("1"), NULL))",
R"((`level`, `uid`) != (Int32("1"), NULL))",
- //R"((`timestamp`, `level`) >= (CAST(3000001 AS Timestamp), 3))",
+ R"(`level` >= CAST("2" As Int32))",
+ R"(CAST("2" As Int32) >= `level`)",
+ R"(`timestamp` >= CAST(3000001u AS Timestamp))",
+ R"((`timestamp`, `level`) >= (CAST(3000001u AS Timestamp), 3))",
};
std::vector<TString> testDataNoPush = {
R"(`level` != NULL)",
R"(`level` > NULL)",
+ R"(`timestamp` >= CAST(3000001 AS Timestamp))",
+ R"(`level` >= CAST("2" As Uint32))",
};
auto buildQuery = [](const TString& predicate, bool pushEnabled) {
@@ -1073,7 +1086,8 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
if (pushEnabled) {
qBuilder << R"(PRAGMA Kikimr.KqpPushOlapProcess = "true";)" << Endl;
}
-
+
+ qBuilder << R"(PRAGMA Kikimr.OptEnablePredicateExtract = "false";)" << Endl;
qBuilder << "SELECT `timestamp` FROM `/Root/olapStore/olapTable` WHERE ";
qBuilder << predicate;
qBuilder << " ORDER BY `timestamp`";
@@ -1085,10 +1099,14 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
auto normalQuery = buildQuery(predicate, false);
auto pushQuery = buildQuery(predicate, true);
+ Cerr << "--- Run normal query ---\n";
+ Cerr << normalQuery << Endl;
auto it = tableClient.StreamExecuteScanQuery(normalQuery).GetValueSync();
UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString());
auto goodResult = CollectStreamResult(it);
+ Cerr << "--- Run pushed down query ---\n";
+ Cerr << pushQuery << Endl;
it = tableClient.StreamExecuteScanQuery(pushQuery).GetValueSync();
UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString());
auto pushResult = CollectStreamResult(it);
diff --git a/ydb/core/protos/ssa.proto b/ydb/core/protos/ssa.proto
index c035c9f5df..eebf965f9a 100644
--- a/ydb/core/protos/ssa.proto
+++ b/ydb/core/protos/ssa.proto
@@ -59,6 +59,20 @@ message TProgram {
FUNC_MATH_SUBTRACT = 15;
FUNC_MATH_MULTIPLY = 16;
FUNC_MATH_DIVIDE = 17;
+ FUNC_CAST_TO_BOOLEAN = 18;
+ FUNC_CAST_TO_INT8 = 19;
+ FUNC_CAST_TO_INT16 = 20;
+ FUNC_CAST_TO_INT32 = 21;
+ FUNC_CAST_TO_INT64 = 22;
+ FUNC_CAST_TO_UINT8 = 23;
+ FUNC_CAST_TO_UINT16 = 24;
+ FUNC_CAST_TO_UINT32 = 25;
+ FUNC_CAST_TO_UINT64 = 26;
+ FUNC_CAST_TO_FLOAT = 27;
+ FUNC_CAST_TO_DOUBLE = 28;
+ FUNC_CAST_TO_BINARY = 29;
+ FUNC_CAST_TO_FIXED_SIZE_BINARY = 30;
+ FUNC_CAST_TO_TIMESTAMP = 31;
}
message TFunction {
diff --git a/ydb/core/tx/columnshard/columnshard__stats_scan.h b/ydb/core/tx/columnshard/columnshard__stats_scan.h
index 770d683895..28605b4a98 100644
--- a/ydb/core/tx/columnshard/columnshard__stats_scan.h
+++ b/ydb/core/tx/columnshard/columnshard__stats_scan.h
@@ -4,6 +4,7 @@
#include "columnshard_common.h"
#include <ydb/core/tablet_flat/flat_cxx_database.h>
#include <ydb/core/sys_view/common/schema.h>
+#include <ydb/core/formats/custom_registry.h>
namespace NKikimr::NColumnShard {
@@ -53,7 +54,7 @@ public:
ApplyRangePredicates(batch);
if (!ReadMetadata->Program.empty()) {
- ApplyProgram(batch, ReadMetadata->Program);
+ ApplyProgram(batch, ReadMetadata->Program, NArrow::GetCustomExecContext());
}
// Leave only requested columns
diff --git a/ydb/core/tx/columnshard/columnshard_common.cpp b/ydb/core/tx/columnshard/columnshard_common.cpp
index 1e76247da8..5b66370044 100644
--- a/ydb/core/tx/columnshard/columnshard_common.cpp
+++ b/ydb/core/tx/columnshard/columnshard_common.cpp
@@ -121,6 +121,31 @@ NArrow::TAssign MakeFunction(TContext& info, const std::string& name,
case TId::FUNC_MATH_DIVIDE:
Y_VERIFY(args.size() == 2);
return TAssign(name, EOperation::Divide, std::move(arguments));
+ case TId::FUNC_CAST_TO_INT32:
+ {
+ Y_VERIFY(args.size() == 1); // TODO: support CAST with OrDefault/OrNull logic (second argument is default value)
+ auto castOpts = std::make_shared<arrow::compute::CastOptions>(false);
+ castOpts->to_type = std::make_shared<arrow::Int32Type>();
+ return TAssign(name, EOperation::CastInt32, std::move(arguments), castOpts);
+ }
+ case TId::FUNC_CAST_TO_TIMESTAMP:
+ {
+ Y_VERIFY(args.size() == 1);
+ auto castOpts = std::make_shared<arrow::compute::CastOptions>(false);
+ castOpts->to_type = std::make_shared<arrow::TimestampType>(arrow::TimeUnit::MICRO);
+ return TAssign(name, EOperation::CastTimestamp, std::move(arguments), castOpts);
+ }
+ // Cast targets that have a function id but no handling here: they leave the switch like FUNC_UNSPECIFIED
+ case TId::FUNC_CAST_TO_BOOLEAN:
+ case TId::FUNC_CAST_TO_INT8:
+ case TId::FUNC_CAST_TO_INT16:
+ case TId::FUNC_CAST_TO_INT64:
+ case TId::FUNC_CAST_TO_UINT8:
+ case TId::FUNC_CAST_TO_UINT16:
+ case TId::FUNC_CAST_TO_UINT32:
+ case TId::FUNC_CAST_TO_UINT64:
+ case TId::FUNC_CAST_TO_FLOAT:
+ case TId::FUNC_CAST_TO_DOUBLE:
+ case TId::FUNC_CAST_TO_BINARY:
+ case TId::FUNC_CAST_TO_FIXED_SIZE_BINARY:
case TId::FUNC_UNSPECIFIED:
break;
}
diff --git a/ydb/core/tx/columnshard/engines/indexed_read_data.cpp b/ydb/core/tx/columnshard/engines/indexed_read_data.cpp
index d9c3a4f63e..da01c5979b 100644
--- a/ydb/core/tx/columnshard/engines/indexed_read_data.cpp
+++ b/ydb/core/tx/columnshard/engines/indexed_read_data.cpp
@@ -6,6 +6,7 @@
#include <ydb/core/tx/columnshard/columnshard__stats_scan.h>
#include <ydb/core/formats/one_batch_input_stream.h>
#include <ydb/core/formats/merging_sorted_input_stream.h>
+#include <ydb/core/formats/custom_registry.h>
namespace NKikimr::NOlap {
@@ -504,7 +505,7 @@ TIndexedReadData::MakeResult(TVector<std::vector<std::shared_ptr<arrow::RecordBa
if (ReadMetadata->HasProgram()) {
for (auto& batch : out) {
- ApplyProgram(batch.ResultBatch, ReadMetadata->Program);
+ ApplyProgram(batch.ResultBatch, ReadMetadata->Program, NArrow::GetCustomExecContext());
}
}
return out;