about summary refs log tree commit diff stats
diff options
context:
space:
mode:
author nsofya <nsofya@yandex-team.com> 2023-07-04 13:11:46 +0300
committer nsofya <nsofya@yandex-team.com> 2023-07-04 13:11:46 +0300
commit 296bc24152bf6bf84380189da1de1f92d441dbac (patch)
tree cc9c57e5649b271efdcb56e259db378859e3f6c0
parent 6d0d6465b0cc13fa19e019b9f717be4d1bcb6b38 (diff)
download ydb-296bc24152bf6bf84380189da1de1f92d441dbac.tar.gz
Cast ui8 to bool
Нужно для работы с yql-ядрами, которые возвращают ui8 вместо bool
-rw-r--r-- ydb/core/formats/arrow/program.cpp 4
-rw-r--r-- ydb/core/tx/columnshard/engines/ut_program.cpp 138
-rw-r--r-- ydb/core/tx/program/program.cpp 5
3 files changed, 120 insertions, 27 deletions
diff --git a/ydb/core/formats/arrow/program.cpp b/ydb/core/formats/arrow/program.cpp
index 487ba5e03e..e8db218f8f 100644
--- a/ydb/core/formats/arrow/program.cpp
+++ b/ydb/core/formats/arrow/program.cpp
@@ -57,7 +57,7 @@ public:
auto funcNames = GetRegistryFunctionNames(assign.GetOperation());
arrow::Result<arrow::Datum> result = arrow::Status::UnknownError<std::string>("unknown function");
- for (const auto& funcName : funcNames) {
+ for (const auto& funcName : funcNames) {
if (TBase::Ctx && TBase::Ctx->func_registry()->GetFunction(funcName).ok()) {
result = arrow::compute::CallFunction(funcName, *arguments, assign.GetOptions(), TBase::Ctx);
} else {
@@ -131,7 +131,7 @@ template <class TAssignObject>
class TKernelFunction : public IStepFunction<TAssignObject> {
using TBase = IStepFunction<TAssignObject>;
const TFunctionPtr Function;
-
+
public:
TKernelFunction(const TFunctionPtr kernelsFunction, arrow::compute::ExecContext* ctx)
: TBase(ctx)
diff --git a/ydb/core/tx/columnshard/engines/ut_program.cpp b/ydb/core/tx/columnshard/engines/ut_program.cpp
index 0897a2f147..230de9d1bb 100644
--- a/ydb/core/tx/columnshard/engines/ut_program.cpp
+++ b/ydb/core/tx/columnshard/engines/ut_program.cpp
@@ -56,15 +56,16 @@ Y_UNIT_TEST_SUITE(TestProgram) {
return ReqBuilder->AddBinaryOp(NYql::TKernelRequestBuilder::EBinaryOp::Add, blockInt32Type, blockInt32Type, blockInt32Type);
}
case NYql::TKernelRequestBuilder::EBinaryOp::StartsWith:
+ case NYql::TKernelRequestBuilder::EBinaryOp::EndsWith:
{
NYql::TExprContext ctx;
auto blockStringType = ctx.template MakeType<NYql::TBlockExprType>(ctx.template MakeType<NYql::TDataExprType>(NYql::EDataSlot::Utf8));
auto blockBoolType = ctx.template MakeType<NYql::TBlockExprType>(ctx.template MakeType<NYql::TDataExprType>(NYql::EDataSlot::Bool));
if (scalar) {
auto scalarStringType = ctx.template MakeType<NYql::TScalarExprType>(ctx.template MakeType<NYql::TDataExprType>(NYql::EDataSlot::String));
- return ReqBuilder->AddBinaryOp(NYql::TKernelRequestBuilder::EBinaryOp::StartsWith, blockStringType, scalarStringType, blockBoolType);
+ return ReqBuilder->AddBinaryOp(operation, blockStringType, scalarStringType, blockBoolType);
} else {
- return ReqBuilder->AddBinaryOp(NYql::TKernelRequestBuilder::EBinaryOp::StartsWith, blockStringType, blockStringType, blockBoolType);
+ return ReqBuilder->AddBinaryOp(operation, blockStringType, blockStringType, blockBoolType);
}
}
case NYql::TKernelRequestBuilder::EBinaryOp::StringContains:
@@ -76,9 +77,9 @@ Y_UNIT_TEST_SUITE(TestProgram) {
}
default:
Y_FAIL("Not implemented");
-
+
}
- }
+ }
ui32 AddJsonExists(bool isBinaryType = true) {
NYql::TExprContext ctx;
@@ -108,7 +109,7 @@ Y_UNIT_TEST_SUITE(TestProgram) {
TString Serialize() {
return ReqBuilder->Serialize();
- }
+ }
};
TString SerializeProgram(const NKikimrSSA::TProgram& programProto) {
@@ -122,7 +123,7 @@ Y_UNIT_TEST_SUITE(TestProgram) {
Y_PROTOBUF_SUPPRESS_NODISCARD olapProgramProto.SerializeToString(&programSerialized);
return programSerialized;
}
-
+
Y_UNIT_TEST(YqlKernel) {
TIndexInfo indexInfo = BuildTableInfo(testColumns, testKey);
TIndexColumnResolver columnResolver(indexInfo);
@@ -148,7 +149,7 @@ Y_UNIT_TEST_SUITE(TestProgram) {
kernels.Add(NYql::TKernelRequestBuilder::EBinaryOp::Add);
programProto.SetKernels(kernels.Serialize());
const auto programSerialized = SerializeProgram(programProto);
-
+
TProgramContainer program;
TString errors;
UNIT_ASSERT_C(program.Init(columnResolver, NKikimrSchemeOp::EOlapProgramType::OLAP_PROGRAM_SSA_PROGRAM_WITH_PARAMETERS, programSerialized, errors), errors);
@@ -158,7 +159,8 @@ Y_UNIT_TEST_SUITE(TestProgram) {
updates.AddRow().Add<int32_t>(100).Add<int32_t>(0);
auto batch = updates.BuildArrow();
- UNIT_ASSERT(program.ApplyProgram(batch).ok());
+ auto res = program.ApplyProgram(batch);
+ UNIT_ASSERT_C(res.ok(), res.ToString());
TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Int32)) }));
result.AddRow().Add<int32_t>(2);
@@ -200,7 +202,7 @@ Y_UNIT_TEST_SUITE(TestProgram) {
kernels.Add(NYql::TKernelRequestBuilder::EBinaryOp::StartsWith, true);
programProto.SetKernels(kernels.Serialize());
const auto programSerialized = SerializeProgram(programProto);
-
+
TProgramContainer program;
TString errors;
UNIT_ASSERT_C(program.Init(columnResolver, NKikimrSchemeOp::EOlapProgramType::OLAP_PROGRAM_SSA_PROGRAM_WITH_PARAMETERS, programSerialized, errors), errors);
@@ -211,7 +213,8 @@ Y_UNIT_TEST_SUITE(TestProgram) {
auto batch = updates.BuildArrow();
Cerr << batch->ToString() << Endl;
- UNIT_ASSERT(program.ApplyProgram(batch).ok());
+ auto res = program.ApplyProgram(batch);
+ UNIT_ASSERT_C(res.ok(), res.ToString());
TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Uint8)) }));
result.AddRow().Add<ui8>(1);
@@ -248,7 +251,7 @@ Y_UNIT_TEST_SUITE(TestProgram) {
kernels.Add(NYql::TKernelRequestBuilder::EBinaryOp::StartsWith);
programProto.SetKernels(kernels.Serialize());
const auto programSerialized = SerializeProgram(programProto);
-
+
TProgramContainer program;
TString errors;
UNIT_ASSERT_C(program.Init(columnResolver, NKikimrSchemeOp::EOlapProgramType::OLAP_PROGRAM_SSA_PROGRAM_WITH_PARAMETERS, programSerialized, errors), errors);
@@ -258,7 +261,8 @@ Y_UNIT_TEST_SUITE(TestProgram) {
updates.AddRow().Add<std::string>("Lorem ipsum dolor sit amet.").Add<std::string>("amet.");
auto batch = updates.BuildArrow();
- UNIT_ASSERT(program.ApplyProgram(batch).ok());
+ auto res = program.ApplyProgram(batch);
+ UNIT_ASSERT_C(res.ok(), res.ToString());
TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Uint8)) }));
result.AddRow().Add<ui8>(1);
@@ -274,12 +278,11 @@ Y_UNIT_TEST_SUITE(TestProgram) {
kernels.Add(NYql::TKernelRequestBuilder::EBinaryOp::StringContains);
programProto.SetKernels(kernels.Serialize());
const auto programSerialized = SerializeProgram(programProto);
-
+
TProgramContainer program;
TString errors;
UNIT_ASSERT_C(program.Init(columnResolver, NKikimrSchemeOp::EOlapProgramType::OLAP_PROGRAM_SSA_PROGRAM_WITH_PARAMETERS, programSerialized, errors), errors);
-
TTableUpdatesBuilder updates(NArrow::MakeArrowSchema({{"string", TTypeInfo(NTypeIds::Bytes) }, {"substring", TTypeInfo(NTypeIds::Bytes) }}));
updates.AddRow().Add<std::string>("Lorem ipsum \xC0 dolor\f sit amet.").Add<std::string>("dolor");
updates.AddRow().Add<std::string>("Lorem ipsum dolor sit \amet.").Add<std::string>("amet.");
@@ -333,7 +336,7 @@ Y_UNIT_TEST_SUITE(TestProgram) {
kernels.AddJsonExists(isBinaryType);
programProto.SetKernels(kernels.Serialize());
const auto programSerialized = SerializeProgram(programProto);
-
+
TProgramContainer program;
TString errors;
UNIT_ASSERT_C(program.Init(columnResolver, NKikimrSchemeOp::EOlapProgramType::OLAP_PROGRAM_SSA_PROGRAM_WITH_PARAMETERS, programSerialized, errors), errors);
@@ -364,6 +367,92 @@ Y_UNIT_TEST_SUITE(TestProgram) {
UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected->ToString());
}
+ Y_UNIT_TEST(StartsAndEnds) {
+ TIndexInfo indexInfo = BuildTableInfo(testColumns, testKey);
+ TIndexColumnResolver columnResolver(indexInfo);
+
+ NKikimrSSA::TProgram programProto;
+ {
+ auto* command = programProto.AddCommand();
+ auto* functionProto = command->MutableAssign()->MutableFunction();
+ functionProto->SetFunctionType(NKikimrSSA::TProgram::EFunctionType::TProgram_EFunctionType_YQL_KERNEL);
+ functionProto->SetKernelIdx(0);
+ functionProto->AddArguments()->SetName("string");
+ functionProto->AddArguments()->SetName("substring");
+ functionProto->SetId(NKikimrSSA::TProgram::TAssignment::EFunction::TProgram_TAssignment_EFunction_FUNC_STR_STARTS_WITH);
+ command->MutableAssign()->MutableColumn()->SetName("start_with");
+ }
+ {
+ auto* command = programProto.AddCommand();
+ auto* functionProto = command->MutableAssign()->MutableFunction();
+ functionProto->SetFunctionType(NKikimrSSA::TProgram::EFunctionType::TProgram_EFunctionType_YQL_KERNEL);
+ functionProto->SetKernelIdx(1);
+ functionProto->AddArguments()->SetName("string");
+ functionProto->AddArguments()->SetName("substring");
+ functionProto->SetId(NKikimrSSA::TProgram::TAssignment::EFunction::TProgram_TAssignment_EFunction_FUNC_STR_ENDS_WITH);
+ command->MutableAssign()->MutableColumn()->SetName("ends_with");
+ }
+ {
+ auto* command = programProto.AddCommand();
+ auto* functionProto = command->MutableAssign()->MutableFunction();
+ functionProto->SetFunctionType(NKikimrSSA::TProgram::EFunctionType::TProgram_EFunctionType_SIMPLE_ARROW);
+ functionProto->AddArguments()->SetName("start_with");
+ functionProto->SetId(NKikimrSSA::TProgram::TAssignment::EFunction::TProgram_TAssignment_EFunction_FUNC_CAST_TO_BOOLEAN);
+ command->MutableAssign()->MutableColumn()->SetName("start_with_bool");
+ }
+ {
+ auto* command = programProto.AddCommand();
+ auto* functionProto = command->MutableAssign()->MutableFunction();
+ functionProto->SetFunctionType(NKikimrSSA::TProgram::EFunctionType::TProgram_EFunctionType_SIMPLE_ARROW);
+ functionProto->AddArguments()->SetName("ends_with");
+ functionProto->SetId(NKikimrSSA::TProgram::TAssignment::EFunction::TProgram_TAssignment_EFunction_FUNC_CAST_TO_BOOLEAN);
+ command->MutableAssign()->MutableColumn()->SetName("ends_with_bool");
+ }
+ {
+ auto* command = programProto.AddCommand();
+ auto* functionProto = command->MutableAssign()->MutableFunction();
+ functionProto->SetFunctionType(NKikimrSSA::TProgram::EFunctionType::TProgram_EFunctionType_SIMPLE_ARROW);
+ functionProto->AddArguments()->SetName("start_with_bool");
+ functionProto->AddArguments()->SetName("ends_with_bool");
+ functionProto->SetId(NKikimrSSA::TProgram::TAssignment::EFunction::TProgram_TAssignment_EFunction_FUNC_BINARY_AND);
+ command->MutableAssign()->MutableColumn()->SetName("result");
+ }
+ {
+ auto* command = programProto.AddCommand();
+ auto* prjectionProto = command->MutableProjection();
+ auto* column = prjectionProto->AddColumns();
+ column->SetName("result");
+ }
+
+ {
+ TKernelsWrapper kernels;
+ kernels.Add(NYql::TKernelRequestBuilder::EBinaryOp::StartsWith);
+ kernels.Add(NYql::TKernelRequestBuilder::EBinaryOp::EndsWith);
+ programProto.SetKernels(kernels.Serialize());
+ const auto programSerialized = SerializeProgram(programProto);
+
+ TProgramContainer program;
+ TString errors;
+ UNIT_ASSERT_C(program.Init(columnResolver, NKikimrSchemeOp::EOlapProgramType::OLAP_PROGRAM_SSA_PROGRAM_WITH_PARAMETERS, programSerialized, errors), errors);
+
+ TTableUpdatesBuilder updates(NArrow::MakeArrowSchema({{"string", TTypeInfo(NTypeIds::Utf8) }, {"substring", TTypeInfo(NTypeIds::Utf8) }}));
+ updates.AddRow().Add<std::string>("Lorem ipsum dolor sit Lorem").Add<std::string>("Lorem");
+ updates.AddRow().Add<std::string>("Lorem ipsum dolor sit amet.").Add<std::string>("amet.");
+
+ auto batch = updates.BuildArrow();
+ auto res = program.ApplyProgram(batch);
+ UNIT_ASSERT_C(res.ok(), res.ToString());
+
+ TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("result", TTypeInfo(NTypeIds::Bool)) }));
+ result.AddRow().Add<bool>(true);
+ result.AddRow().Add<bool>(false);
+
+ auto expected = result.BuildArrow();
+ UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected->ToString());
+ }
+
+ }
+
Y_UNIT_TEST(JsonExists) {
JsonExistsImpl(false);
}
@@ -403,7 +492,7 @@ Y_UNIT_TEST_SUITE(TestProgram) {
kernels.AddJsonValue(isBinaryType, resultType);
programProto.SetKernels(kernels.Serialize());
const auto programSerialized = SerializeProgram(programProto);
-
+
TProgramContainer program;
TString errors;
UNIT_ASSERT_C(program.Init(columnResolver, NKikimrSchemeOp::EOlapProgramType::OLAP_PROGRAM_SSA_PROGRAM_WITH_PARAMETERS, programSerialized, errors), errors);
@@ -454,43 +543,43 @@ Y_UNIT_TEST_SUITE(TestProgram) {
Cerr << "Check output for " << resultType << Endl;
if (resultType == NYql::EDataSlot::Utf8) {
TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Utf8)) }));
-
+
result.AddRow().Add<std::string>("value");
result.AddRow().Add<std::string>("10");
result.AddRow().Add<std::string>("0.1");
result.AddRow().Add<std::string>("false");
result.AddRow().AddNull();
result.AddRow().AddNull();
-
+
auto expected = result.BuildArrow();
UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected->ToString());
} else if (resultType == NYql::EDataSlot::Bool) {
TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Uint8)) }));
-
+
result.AddRow().AddNull();
result.AddRow().AddNull();
result.AddRow().AddNull();
result.AddRow().Add<ui8>(0);
result.AddRow().AddNull();
result.AddRow().AddNull();
-
+
auto expected = result.BuildArrow();
UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected->ToString());
} else if (resultType == NYql::EDataSlot::Int64 || resultType == NYql::EDataSlot::Uint64) {
TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Int64)) }));
-
+
result.AddRow().AddNull();
result.AddRow().Add<i64>(10);
result.AddRow().AddNull();
result.AddRow().AddNull();
result.AddRow().AddNull();
result.AddRow().AddNull();
-
+
auto expected = result.BuildArrow();
UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected->ToString());
} else if (resultType == NYql::EDataSlot::Double || resultType == NYql::EDataSlot::Float) {
TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Double)) }));
-
+
result.AddRow().AddNull();
result.AddRow().Add<double>(10);
result.AddRow().Add<double>(0.1);
@@ -553,7 +642,8 @@ Y_UNIT_TEST_SUITE(TestProgram) {
updates.AddRow().Add("");
auto batch = updates.BuildArrow();
- UNIT_ASSERT(program.ApplyProgram(batch).ok());
+ auto res = program.ApplyProgram(batch);
+ UNIT_ASSERT_C(res.ok(), res.ToString());
TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Uint64)) }));
result.AddRow().Add<uint64_t>(3);
diff --git a/ydb/core/tx/program/program.cpp b/ydb/core/tx/program/program.cpp
index 99d75de4db..a2f5c6d505 100644
--- a/ydb/core/tx/program/program.cpp
+++ b/ydb/core/tx/program/program.cpp
@@ -183,6 +183,9 @@ TAssign TProgramBuilder::MakeFunction(const std::string& name,
case TId::FUNC_CAST_TO_INT8:
return TAssign(name, EOperation::CastInt8, std::move(arguments),
mkCastOptions(std::make_shared<arrow::Int8Type>()));
+ case TId::FUNC_CAST_TO_BOOLEAN:
+ return TAssign(name, EOperation::CastBoolean, std::move(arguments),
+ mkCastOptions(std::make_shared<arrow::BooleanType>()));
case TId::FUNC_CAST_TO_INT16:
return TAssign(name, EOperation::CastInt16, std::move(arguments),
mkCastOptions(std::make_shared<arrow::Int16Type>()));
@@ -471,7 +474,7 @@ bool TProgramContainer::Init(const IColumnResolver& columnResolver, NKikimrSchem
::google::protobuf::TextFormat::PrintToString(programProto, &out);
AFL_DEBUG(NKikimrServices::TX_COLUMNSHARD)("program", out);
}
-
+
if (olapProgramProto.HasParameters()) {
Y_VERIFY(olapProgramProto.HasParametersSchema(), "Parameters are present, but there is no schema.");