diff options
author | nsofya <nsofya@yandex-team.com> | 2023-07-04 13:11:46 +0300 |
---|---|---|
committer | nsofya <nsofya@yandex-team.com> | 2023-07-04 13:11:46 +0300 |
commit | 296bc24152bf6bf84380189da1de1f92d441dbac (patch) | |
tree | cc9c57e5649b271efdcb56e259db378859e3f6c0 | |
parent | 6d0d6465b0cc13fa19e019b9f717be4d1bcb6b38 (diff) | |
download | ydb-296bc24152bf6bf84380189da1de1f92d441dbac.tar.gz |
Cast ui8 to bool
Нужно для работы с yql-ядрами, которые возвращают ui8 вместо bool
-rw-r--r-- | ydb/core/formats/arrow/program.cpp | 4 | ||||
-rw-r--r-- | ydb/core/tx/columnshard/engines/ut_program.cpp | 138 | ||||
-rw-r--r-- | ydb/core/tx/program/program.cpp | 5 |
3 files changed, 120 insertions, 27 deletions
diff --git a/ydb/core/formats/arrow/program.cpp b/ydb/core/formats/arrow/program.cpp index 487ba5e03e..e8db218f8f 100644 --- a/ydb/core/formats/arrow/program.cpp +++ b/ydb/core/formats/arrow/program.cpp @@ -57,7 +57,7 @@ public: auto funcNames = GetRegistryFunctionNames(assign.GetOperation()); arrow::Result<arrow::Datum> result = arrow::Status::UnknownError<std::string>("unknown function"); - for (const auto& funcName : funcNames) { + for (const auto& funcName : funcNames) { if (TBase::Ctx && TBase::Ctx->func_registry()->GetFunction(funcName).ok()) { result = arrow::compute::CallFunction(funcName, *arguments, assign.GetOptions(), TBase::Ctx); } else { @@ -131,7 +131,7 @@ template <class TAssignObject> class TKernelFunction : public IStepFunction<TAssignObject> { using TBase = IStepFunction<TAssignObject>; const TFunctionPtr Function; - + public: TKernelFunction(const TFunctionPtr kernelsFunction, arrow::compute::ExecContext* ctx) : TBase(ctx) diff --git a/ydb/core/tx/columnshard/engines/ut_program.cpp b/ydb/core/tx/columnshard/engines/ut_program.cpp index 0897a2f147..230de9d1bb 100644 --- a/ydb/core/tx/columnshard/engines/ut_program.cpp +++ b/ydb/core/tx/columnshard/engines/ut_program.cpp @@ -56,15 +56,16 @@ Y_UNIT_TEST_SUITE(TestProgram) { return ReqBuilder->AddBinaryOp(NYql::TKernelRequestBuilder::EBinaryOp::Add, blockInt32Type, blockInt32Type, blockInt32Type); } case NYql::TKernelRequestBuilder::EBinaryOp::StartsWith: + case NYql::TKernelRequestBuilder::EBinaryOp::EndsWith: { NYql::TExprContext ctx; auto blockStringType = ctx.template MakeType<NYql::TBlockExprType>(ctx.template MakeType<NYql::TDataExprType>(NYql::EDataSlot::Utf8)); auto blockBoolType = ctx.template MakeType<NYql::TBlockExprType>(ctx.template MakeType<NYql::TDataExprType>(NYql::EDataSlot::Bool)); if (scalar) { auto scalarStringType = ctx.template MakeType<NYql::TScalarExprType>(ctx.template MakeType<NYql::TDataExprType>(NYql::EDataSlot::String)); - return ReqBuilder->AddBinaryOp(NYql::TKernelRequestBuilder::EBinaryOp::StartsWith, blockStringType, scalarStringType, blockBoolType); + return ReqBuilder->AddBinaryOp(operation, blockStringType, scalarStringType, blockBoolType); } else { - return ReqBuilder->AddBinaryOp(NYql::TKernelRequestBuilder::EBinaryOp::StartsWith, blockStringType, blockStringType, blockBoolType); + return ReqBuilder->AddBinaryOp(operation, blockStringType, blockStringType, blockBoolType); } } case NYql::TKernelRequestBuilder::EBinaryOp::StringContains: @@ -76,9 +77,9 @@ Y_UNIT_TEST_SUITE(TestProgram) { } default: Y_FAIL("Not implemented"); - + } - } + } ui32 AddJsonExists(bool isBinaryType = true) { NYql::TExprContext ctx; @@ -108,7 +109,7 @@ Y_UNIT_TEST_SUITE(TestProgram) { TString Serialize() { return ReqBuilder->Serialize(); - } + } }; TString SerializeProgram(const NKikimrSSA::TProgram& programProto) { @@ -122,7 +123,7 @@ Y_UNIT_TEST_SUITE(TestProgram) { Y_PROTOBUF_SUPPRESS_NODISCARD olapProgramProto.SerializeToString(&programSerialized); return programSerialized; } - + Y_UNIT_TEST(YqlKernel) { TIndexInfo indexInfo = BuildTableInfo(testColumns, testKey); TIndexColumnResolver columnResolver(indexInfo); @@ -148,7 +149,7 @@ Y_UNIT_TEST_SUITE(TestProgram) { kernels.Add(NYql::TKernelRequestBuilder::EBinaryOp::Add); programProto.SetKernels(kernels.Serialize()); const auto programSerialized = SerializeProgram(programProto); - + TProgramContainer program; TString errors; UNIT_ASSERT_C(program.Init(columnResolver, NKikimrSchemeOp::EOlapProgramType::OLAP_PROGRAM_SSA_PROGRAM_WITH_PARAMETERS, programSerialized, errors), errors); @@ -158,7 +159,8 @@ Y_UNIT_TEST_SUITE(TestProgram) { updates.AddRow().Add<int32_t>(100).Add<int32_t>(0); auto batch = updates.BuildArrow(); - UNIT_ASSERT(program.ApplyProgram(batch).ok()); + auto res = program.ApplyProgram(batch); + UNIT_ASSERT_C(res.ok(), res.ToString()); TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Int32)) })); result.AddRow().Add<int32_t>(2); @@ -200,7 +202,7 @@ Y_UNIT_TEST_SUITE(TestProgram) { kernels.Add(NYql::TKernelRequestBuilder::EBinaryOp::StartsWith, true); programProto.SetKernels(kernels.Serialize()); const auto programSerialized = SerializeProgram(programProto); - + TProgramContainer program; TString errors; UNIT_ASSERT_C(program.Init(columnResolver, NKikimrSchemeOp::EOlapProgramType::OLAP_PROGRAM_SSA_PROGRAM_WITH_PARAMETERS, programSerialized, errors), errors); @@ -211,7 +213,8 @@ Y_UNIT_TEST_SUITE(TestProgram) { auto batch = updates.BuildArrow(); Cerr << batch->ToString() << Endl; - UNIT_ASSERT(program.ApplyProgram(batch).ok()); + auto res = program.ApplyProgram(batch); + UNIT_ASSERT_C(res.ok(), res.ToString()); TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Uint8)) })); result.AddRow().Add<ui8>(1); @@ -248,7 +251,7 @@ Y_UNIT_TEST_SUITE(TestProgram) { kernels.Add(NYql::TKernelRequestBuilder::EBinaryOp::StartsWith); programProto.SetKernels(kernels.Serialize()); const auto programSerialized = SerializeProgram(programProto); - + TProgramContainer program; TString errors; UNIT_ASSERT_C(program.Init(columnResolver, NKikimrSchemeOp::EOlapProgramType::OLAP_PROGRAM_SSA_PROGRAM_WITH_PARAMETERS, programSerialized, errors), errors); @@ -258,7 +261,8 @@ Y_UNIT_TEST_SUITE(TestProgram) { updates.AddRow().Add<std::string>("Lorem ipsum dolor sit amet.").Add<std::string>("amet."); auto batch = updates.BuildArrow(); - UNIT_ASSERT(program.ApplyProgram(batch).ok()); + auto res = program.ApplyProgram(batch); + UNIT_ASSERT_C(res.ok(), res.ToString()); TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Uint8)) })); result.AddRow().Add<ui8>(1); @@ -274,12 +278,11 @@ Y_UNIT_TEST_SUITE(TestProgram) { kernels.Add(NYql::TKernelRequestBuilder::EBinaryOp::StringContains); programProto.SetKernels(kernels.Serialize()); const auto programSerialized = SerializeProgram(programProto); - + TProgramContainer program; TString errors; UNIT_ASSERT_C(program.Init(columnResolver, NKikimrSchemeOp::EOlapProgramType::OLAP_PROGRAM_SSA_PROGRAM_WITH_PARAMETERS, programSerialized, errors), errors); - TTableUpdatesBuilder updates(NArrow::MakeArrowSchema({{"string", TTypeInfo(NTypeIds::Bytes) }, {"substring", TTypeInfo(NTypeIds::Bytes) }})); updates.AddRow().Add<std::string>("Lorem ipsum \xC0 dolor\f sit amet.").Add<std::string>("dolor"); updates.AddRow().Add<std::string>("Lorem ipsum dolor sit \amet.").Add<std::string>("amet."); @@ -333,7 +336,7 @@ Y_UNIT_TEST_SUITE(TestProgram) { kernels.AddJsonExists(isBinaryType); programProto.SetKernels(kernels.Serialize()); const auto programSerialized = SerializeProgram(programProto); - + TProgramContainer program; TString errors; UNIT_ASSERT_C(program.Init(columnResolver, NKikimrSchemeOp::EOlapProgramType::OLAP_PROGRAM_SSA_PROGRAM_WITH_PARAMETERS, programSerialized, errors), errors); @@ -364,6 +367,92 @@ Y_UNIT_TEST_SUITE(TestProgram) { UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected->ToString()); } + Y_UNIT_TEST(StartsAndEnds) { + TIndexInfo indexInfo = BuildTableInfo(testColumns, testKey); + TIndexColumnResolver columnResolver(indexInfo); + + NKikimrSSA::TProgram programProto; + { + auto* command = programProto.AddCommand(); + auto* functionProto = command->MutableAssign()->MutableFunction(); + functionProto->SetFunctionType(NKikimrSSA::TProgram::EFunctionType::TProgram_EFunctionType_YQL_KERNEL); + functionProto->SetKernelIdx(0); + functionProto->AddArguments()->SetName("string"); + functionProto->AddArguments()->SetName("substring"); + functionProto->SetId(NKikimrSSA::TProgram::TAssignment::EFunction::TProgram_TAssignment_EFunction_FUNC_STR_STARTS_WITH); + command->MutableAssign()->MutableColumn()->SetName("start_with"); + } + { + auto* command = programProto.AddCommand(); + auto* functionProto = command->MutableAssign()->MutableFunction(); + functionProto->SetFunctionType(NKikimrSSA::TProgram::EFunctionType::TProgram_EFunctionType_YQL_KERNEL); + functionProto->SetKernelIdx(1); + functionProto->AddArguments()->SetName("string"); + functionProto->AddArguments()->SetName("substring"); + functionProto->SetId(NKikimrSSA::TProgram::TAssignment::EFunction::TProgram_TAssignment_EFunction_FUNC_STR_ENDS_WITH); + command->MutableAssign()->MutableColumn()->SetName("ends_with"); + } + { + auto* command = programProto.AddCommand(); + auto* functionProto = command->MutableAssign()->MutableFunction(); + functionProto->SetFunctionType(NKikimrSSA::TProgram::EFunctionType::TProgram_EFunctionType_SIMPLE_ARROW); + functionProto->AddArguments()->SetName("start_with"); + functionProto->SetId(NKikimrSSA::TProgram::TAssignment::EFunction::TProgram_TAssignment_EFunction_FUNC_CAST_TO_BOOLEAN); + command->MutableAssign()->MutableColumn()->SetName("start_with_bool"); + } + { + auto* command = programProto.AddCommand(); + auto* functionProto = command->MutableAssign()->MutableFunction(); + functionProto->SetFunctionType(NKikimrSSA::TProgram::EFunctionType::TProgram_EFunctionType_SIMPLE_ARROW); + functionProto->AddArguments()->SetName("ends_with"); + functionProto->SetId(NKikimrSSA::TProgram::TAssignment::EFunction::TProgram_TAssignment_EFunction_FUNC_CAST_TO_BOOLEAN); + command->MutableAssign()->MutableColumn()->SetName("ends_with_bool"); + } + { + auto* command = programProto.AddCommand(); + auto* functionProto = command->MutableAssign()->MutableFunction(); + functionProto->SetFunctionType(NKikimrSSA::TProgram::EFunctionType::TProgram_EFunctionType_SIMPLE_ARROW); + functionProto->AddArguments()->SetName("start_with_bool"); + functionProto->AddArguments()->SetName("ends_with_bool"); + functionProto->SetId(NKikimrSSA::TProgram::TAssignment::EFunction::TProgram_TAssignment_EFunction_FUNC_BINARY_AND); + command->MutableAssign()->MutableColumn()->SetName("result"); + } + { + auto* command = programProto.AddCommand(); + auto* prjectionProto = command->MutableProjection(); + auto* column = prjectionProto->AddColumns(); + column->SetName("result"); + } + + { + TKernelsWrapper kernels; + kernels.Add(NYql::TKernelRequestBuilder::EBinaryOp::StartsWith); + kernels.Add(NYql::TKernelRequestBuilder::EBinaryOp::EndsWith); + programProto.SetKernels(kernels.Serialize()); + const auto programSerialized = SerializeProgram(programProto); + + TProgramContainer program; + TString errors; + UNIT_ASSERT_C(program.Init(columnResolver, NKikimrSchemeOp::EOlapProgramType::OLAP_PROGRAM_SSA_PROGRAM_WITH_PARAMETERS, programSerialized, errors), errors); + + TTableUpdatesBuilder updates(NArrow::MakeArrowSchema({{"string", TTypeInfo(NTypeIds::Utf8) }, {"substring", TTypeInfo(NTypeIds::Utf8) }})); + updates.AddRow().Add<std::string>("Lorem ipsum dolor sit Lorem").Add<std::string>("Lorem"); + updates.AddRow().Add<std::string>("Lorem ipsum dolor sit amet.").Add<std::string>("amet."); + + auto batch = updates.BuildArrow(); + auto res = program.ApplyProgram(batch); + UNIT_ASSERT_C(res.ok(), res.ToString()); + + TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("result", TTypeInfo(NTypeIds::Bool)) })); + result.AddRow().Add<bool>(true); + result.AddRow().Add<bool>(false); + + auto expected = result.BuildArrow(); + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected->ToString()); + } + + } + Y_UNIT_TEST(JsonExists) { JsonExistsImpl(false); } @@ -403,7 +492,7 @@ Y_UNIT_TEST_SUITE(TestProgram) { kernels.AddJsonValue(isBinaryType, resultType); programProto.SetKernels(kernels.Serialize()); const auto programSerialized = SerializeProgram(programProto); - + TProgramContainer program; TString errors; UNIT_ASSERT_C(program.Init(columnResolver, NKikimrSchemeOp::EOlapProgramType::OLAP_PROGRAM_SSA_PROGRAM_WITH_PARAMETERS, programSerialized, errors), errors); @@ -454,43 +543,43 @@ Y_UNIT_TEST_SUITE(TestProgram) { Cerr << "Check output for " << resultType << Endl; if (resultType == NYql::EDataSlot::Utf8) { TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Utf8)) })); - + result.AddRow().Add<std::string>("value"); result.AddRow().Add<std::string>("10"); result.AddRow().Add<std::string>("0.1"); result.AddRow().Add<std::string>("false"); result.AddRow().AddNull(); result.AddRow().AddNull(); - + auto expected = result.BuildArrow(); UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected->ToString()); } else if (resultType == NYql::EDataSlot::Bool) { TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Uint8)) })); - + result.AddRow().AddNull(); result.AddRow().AddNull(); result.AddRow().AddNull(); result.AddRow().Add<ui8>(0); result.AddRow().AddNull(); result.AddRow().AddNull(); - + auto expected = result.BuildArrow(); UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected->ToString()); } else if (resultType == NYql::EDataSlot::Int64 || resultType == NYql::EDataSlot::Uint64) { TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Int64)) })); - + result.AddRow().AddNull(); result.AddRow().Add<i64>(10); result.AddRow().AddNull(); result.AddRow().AddNull(); result.AddRow().AddNull(); result.AddRow().AddNull(); - + auto expected = result.BuildArrow(); UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected->ToString()); } else if (resultType == NYql::EDataSlot::Double || resultType == NYql::EDataSlot::Float) { TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Double)) })); - + result.AddRow().AddNull(); result.AddRow().Add<double>(10); result.AddRow().Add<double>(0.1); @@ -553,7 +642,8 @@ Y_UNIT_TEST_SUITE(TestProgram) { updates.AddRow().Add(""); auto batch = updates.BuildArrow(); - UNIT_ASSERT(program.ApplyProgram(batch).ok()); + auto res = program.ApplyProgram(batch); + UNIT_ASSERT_C(res.ok(), res.ToString()); TTableUpdatesBuilder result(NArrow::MakeArrowSchema( { std::make_pair("0", TTypeInfo(NTypeIds::Uint64)) })); result.AddRow().Add<uint64_t>(3); diff --git a/ydb/core/tx/program/program.cpp b/ydb/core/tx/program/program.cpp index 99d75de4db..a2f5c6d505 100644 --- a/ydb/core/tx/program/program.cpp +++ b/ydb/core/tx/program/program.cpp @@ -183,6 +183,9 @@ TAssign TProgramBuilder::MakeFunction(const std::string& name, case TId::FUNC_CAST_TO_INT8: return TAssign(name, EOperation::CastInt8, std::move(arguments), mkCastOptions(std::make_shared<arrow::Int8Type>())); + case TId::FUNC_CAST_TO_BOOLEAN: + return TAssign(name, EOperation::CastBoolean, std::move(arguments), + mkCastOptions(std::make_shared<arrow::BooleanType>())); case TId::FUNC_CAST_TO_INT16: return TAssign(name, EOperation::CastInt16, std::move(arguments), mkCastOptions(std::make_shared<arrow::Int16Type>())); @@ -471,7 +474,7 @@ bool TProgramContainer::Init(const IColumnResolver& columnResolver, NKikimrSchem ::google::protobuf::TextFormat::PrintToString(programProto, &out); AFL_DEBUG(NKikimrServices::TX_COLUMNSHARD)("program", out); } - + if (olapProgramProto.HasParameters()) { Y_VERIFY(olapProgramProto.HasParametersSchema(), "Parameters are present, but there is no schema."); |