diff options
author    | mrlolthe1st <mrlolthe1st@yandex-team.com> | 2023-10-05 17:00:50 +0300
committer | mrlolthe1st <mrlolthe1st@yandex-team.com> | 2023-10-05 17:26:16 +0300
commit    | f288403e6cb7cc62bb16f2c707296f096a62df29 (patch)
tree      | 9701c077fe0d0d86fa6fb6bd6a2cee1b5451901c
parent    | 762a22f887e56e471cb1196f503d47588a761490 (diff)
download  | ydb-f288403e6cb7cc62bb16f2c707296f096a62df29.tar.gz
YQL-9517: Implement block RPC reader
YQL-9517: Implement RPC reader
32 files changed, 2509 insertions, 181 deletions
diff --git a/ydb/library/yql/dq/integration/yql_dq_integration.h b/ydb/library/yql/dq/integration/yql_dq_integration.h index d4f45d7f168..65500c10680 100644 --- a/ydb/library/yql/dq/integration/yql_dq_integration.h +++ b/ydb/library/yql/dq/integration/yql_dq_integration.h @@ -53,6 +53,7 @@ public: virtual TMaybe<bool> CanWrite(const TExprNode& write, TExprContext& ctx) = 0; virtual TExprNode::TPtr WrapWrite(const TExprNode::TPtr& write, TExprContext& ctx) = 0; + virtual bool CanBlockRead(const NNodes::TExprBase& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx) = 0; virtual void RegisterMkqlCompiler(NCommon::TMkqlCallableCompilerBase& compiler) = 0; virtual bool CanFallback() = 0; virtual void FillSourceSettings(const TExprNode& node, ::google::protobuf::Any& settings, TString& sourceType) = 0; diff --git a/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.cpp b/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.cpp index a66cac44f2c..24ae595d66a 100644 --- a/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.cpp +++ b/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.cpp @@ -22,6 +22,20 @@ TMaybe<ui64> TDqIntegrationBase::EstimateReadSize(ui64, ui32, const TVector<cons return Nothing(); } +bool TDqIntegrationBase::CanBlockReadTypes(const TStructExprType* node) { + for (const auto& e: node->GetItems()) { + // Check type + auto type = e->GetItemType(); + while (ETypeAnnotationKind::Optional == type->GetKind()) { + type = type->Cast<TOptionalExprType>()->GetItemType(); + } + if (ETypeAnnotationKind::Data != type->GetKind()) { + return false; + } + } + return true; +} + TExprNode::TPtr TDqIntegrationBase::WrapRead(const TDqSettings&, const TExprNode::TPtr& read, TExprContext&) { return read; } @@ -36,6 +50,10 @@ TMaybe<bool> TDqIntegrationBase::CanWrite(const TExprNode&, TExprContext&) { return Nothing(); } +bool TDqIntegrationBase::CanBlockRead(const NNodes::TExprBase&, TExprContext&, TTypeAnnotationContext&) { + return 
false; +} + TExprNode::TPtr TDqIntegrationBase::WrapWrite(const TExprNode::TPtr& write, TExprContext&) { return write; } diff --git a/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.h b/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.h index 7273d5a25ff..fc841ce4c03 100644 --- a/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.h +++ b/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.h @@ -15,6 +15,7 @@ public: TMaybe<TOptimizerStatistics> ReadStatistics(const TExprNode::TPtr& readWrap, TExprContext& ctx) override; void RegisterMkqlCompiler(NCommon::TMkqlCallableCompilerBase& compiler) override; TMaybe<bool> CanWrite(const TExprNode& write, TExprContext& ctx) override; + bool CanBlockRead(const NNodes::TExprBase& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx) override; TExprNode::TPtr WrapWrite(const TExprNode::TPtr& write, TExprContext& ctx) override; bool CanFallback() override; void FillSourceSettings(const TExprNode& node, ::google::protobuf::Any& settings, TString& sourceType) override; @@ -23,6 +24,8 @@ public: void Annotate(const TExprNode& node, THashMap<TString, TString>& params) override; bool PrepareFullResultTableParams(const TExprNode& root, TExprContext& ctx, THashMap<TString, TString>& params, THashMap<TString, TString>& secureParams) override; void WriteFullResultTableRef(NYson::TYsonWriter& writer, const TVector<TString>& columns, const THashMap<TString, TString>& graphParams) override; +protected: + bool CanBlockReadTypes(const TStructExprType* node); }; } // namespace NYql diff --git a/ydb/library/yql/providers/dq/common/yql_dq_settings.cpp b/ydb/library/yql/providers/dq/common/yql_dq_settings.cpp index 7cc6e857308..9fc67895641 100644 --- a/ydb/library/yql/providers/dq/common/yql_dq_settings.cpp +++ b/ydb/library/yql/providers/dq/common/yql_dq_settings.cpp @@ -75,6 +75,7 @@ TDqConfiguration::TDqConfiguration() { REGISTER_SETTING(*this, ExportStats); REGISTER_SETTING(*this, 
TaskRunnerStats).Parser([](const TString& v) { return FromString<ETaskRunnerStats>(v); }); REGISTER_SETTING(*this, _SkipRevisionCheck); + REGISTER_SETTING(*this, UseBlockReader); } } // namespace NYql diff --git a/ydb/library/yql/providers/dq/common/yql_dq_settings.h b/ydb/library/yql/providers/dq/common/yql_dq_settings.h index 613c105901f..75cca0876e2 100644 --- a/ydb/library/yql/providers/dq/common/yql_dq_settings.h +++ b/ydb/library/yql/providers/dq/common/yql_dq_settings.h @@ -114,6 +114,7 @@ struct TDqSettings { NCommon::TConfSetting<bool, false> ExportStats; NCommon::TConfSetting<ETaskRunnerStats, false> TaskRunnerStats; NCommon::TConfSetting<bool, false> _SkipRevisionCheck; + NCommon::TConfSetting<bool, false> UseBlockReader; // This options will be passed to executor_actor and worker_actor template <typename TProtoConfig> diff --git a/ydb/library/yql/providers/dq/expr_nodes/dqs_expr_nodes.json b/ydb/library/yql/providers/dq/expr_nodes/dqs_expr_nodes.json index 7b528c5bf8d..2d3868424c1 100644 --- a/ydb/library/yql/providers/dq/expr_nodes/dqs_expr_nodes.json +++ b/ydb/library/yql/providers/dq/expr_nodes/dqs_expr_nodes.json @@ -27,6 +27,11 @@ "Match": {"Type": "Callable", "Name": "DqReadWideWrap"} }, { + "Name": "TDqReadBlockWideWrap", + "Base": "TDqReadWrapBase", + "Match": {"Type": "Callable", "Name": "DqReadBlockWideWrap"} + }, + { "Name": "TDqWrite", "Base": "TCallable", "Match": {"Type": "Callable", "Name": "DqWrite"}, diff --git a/ydb/library/yql/providers/dq/mkql/dqs_mkql_compiler.cpp b/ydb/library/yql/providers/dq/mkql/dqs_mkql_compiler.cpp index c4f2cd6cbd4..724082b8c78 100644 --- a/ydb/library/yql/providers/dq/mkql/dqs_mkql_compiler.cpp +++ b/ydb/library/yql/providers/dq/mkql/dqs_mkql_compiler.cpp @@ -10,7 +10,7 @@ using namespace NKikimr::NMiniKQL; using namespace NNodes; void RegisterDqsMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler, const TTypeAnnotationContext& ctx) { - compiler.AddCallable({TDqSourceWideWrap::CallableName(), 
TDqSourceWideBlockWrap::CallableName(), TDqReadWideWrap::CallableName()}, + compiler.AddCallable({TDqSourceWideWrap::CallableName(), TDqSourceWideBlockWrap::CallableName(), TDqReadWideWrap::CallableName(), TDqReadBlockWideWrap::CallableName()}, [](const TExprNode& node, NCommon::TMkqlBuildContext&) { YQL_ENSURE(false, "Unsupported reader: " << node.Head().Content()); return TRuntimeNode(); diff --git a/ydb/library/yql/providers/dq/opt/dqs_opt.cpp b/ydb/library/yql/providers/dq/opt/dqs_opt.cpp index 8cba4d0601a..d8ae81155db 100644 --- a/ydb/library/yql/providers/dq/opt/dqs_opt.cpp +++ b/ydb/library/yql/providers/dq/opt/dqs_opt.cpp @@ -3,6 +3,7 @@ #include <ydb/library/yql/providers/dq/expr_nodes/dqs_expr_nodes.h> #include <ydb/library/yql/providers/common/mkql/yql_type_mkql.h> #include <ydb/library/yql/providers/common/codec/yql_codec.h> +#include <ydb/library/yql/providers/common/provider/yql_provider_names.h> #include <ydb/library/yql/core/yql_expr_optimize.h> #include <ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.h> @@ -16,6 +17,7 @@ #include <ydb/library/yql/dq/opt/dq_opt_build.h> #include <ydb/library/yql/dq/opt/dq_opt_peephole.h> #include <ydb/library/yql/dq/type_ann/dq_type_ann.h> +#include <ydb/library/yql/dq/integration/yql_dq_integration.h> #include <ydb/library/yql/utils/log/log.h> @@ -64,6 +66,44 @@ namespace NYql::NDqs { }); } + THolder<IGraphTransformer> CreateDqsRewritePhyBlockReadOnDqIntegrationTransformer(TTypeAnnotationContext& typesCtx) { + return CreateFunctorTransformer([&typesCtx](const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) { + TOptimizeExprSettings optSettings{nullptr}; + optSettings.VisitLambdas = true; + optSettings.VisitTuples = true; + return OptimizeExpr(input, output, + [&typesCtx](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr { + if (!TDqReadWideWrap::Match(node.Get())) { + return node; + } + + auto readWideWrap = TDqReadWideWrap(node); + auto dataSource = 
readWideWrap.Raw()->Child(0)->Child(1); + auto dataSourceName = dataSource->Child(0)->Content(); + if (dataSourceName == DqProviderName || dataSource->IsCallable(ConfigureName)) { + return node; + } + + auto datasource = typesCtx.DataSourceMap.FindPtr(dataSourceName); + YQL_ENSURE(datasource); + auto dqIntegration = (*datasource)->GetDqIntegration(); + if (!dqIntegration || !dqIntegration->CanBlockRead(readWideWrap, ctx, typesCtx)) { + return node; + } + + YQL_CLOG(INFO, ProviderDq) << "DqsRewritePhyBlockReadOnDqIntegration"; + + return Build<TCoWideFromBlocks>(ctx, node->Pos()) + .Input(Build<TDqReadBlockWideWrap>(ctx, node->Pos()) + .Input(readWideWrap.Input()) + .Flags(readWideWrap.Flags()) + .Token(readWideWrap.Token()) + .Done()) + .Done().Ptr(); + }, ctx, optSettings); + }); + } + THolder<IGraphTransformer> CreateDqsReplacePrecomputesTransformer(TTypeAnnotationContext& typesCtx, const NKikimr::NMiniKQL::IFunctionRegistry* funcRegistry) { return CreateFunctorTransformer([&typesCtx, funcRegistry](const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) -> TStatus { TOptimizeExprSettings settings(&typesCtx); diff --git a/ydb/library/yql/providers/dq/opt/dqs_opt.h b/ydb/library/yql/providers/dq/opt/dqs_opt.h index 9beec287f07..7f94e976eed 100644 --- a/ydb/library/yql/providers/dq/opt/dqs_opt.h +++ b/ydb/library/yql/providers/dq/opt/dqs_opt.h @@ -15,6 +15,7 @@ namespace NYql::NDqs { THolder<IGraphTransformer> CreateDqsFinalizingOptTransformer(); THolder<IGraphTransformer> CreateDqsRewritePhyCallablesTransformer(TTypeAnnotationContext& typesCtx); + THolder<IGraphTransformer> CreateDqsRewritePhyBlockReadOnDqIntegrationTransformer(TTypeAnnotationContext& typesCtx); THolder<IGraphTransformer> CreateDqsReplacePrecomputesTransformer(TTypeAnnotationContext& typesCtx, const NKikimr::NMiniKQL::IFunctionRegistry* funcRegistry); } // namespace NYql::NDqs diff --git a/ydb/library/yql/providers/dq/provider/exec/yql_dq_exectransformer.cpp 
b/ydb/library/yql/providers/dq/provider/exec/yql_dq_exectransformer.cpp index fdd6619ae2d..3c7904f4dc9 100644 --- a/ydb/library/yql/providers/dq/provider/exec/yql_dq_exectransformer.cpp +++ b/ydb/library/yql/providers/dq/provider/exec/yql_dq_exectransformer.cpp @@ -233,6 +233,9 @@ private: void AfterTypeAnnotation(TTransformationPipeline* pipeline) const final { pipeline->Add(NDqs::CreateDqsReplacePrecomputesTransformer(*pipeline->GetTypeAnnotationContext(), State_->FunctionRegistry), "ReplacePrecomputes"); + if (State_->Settings->UseBlockReader.Get().GetOrElse(false)) { + pipeline->Add(NDqs::CreateDqsRewritePhyBlockReadOnDqIntegrationTransformer(*pipeline->GetTypeAnnotationContext()), "ReplaceWideReadsWithBlock"); + } bool useWideChannels = State_->Settings->UseWideChannels.Get().GetOrElse(false); bool useChannelBlocks = State_->Settings->UseWideBlockChannels.Get().GetOrElse(false); NDq::EChannelMode mode; diff --git a/ydb/library/yql/providers/dq/provider/yql_dq_datasource_constraints.cpp b/ydb/library/yql/providers/dq/provider/yql_dq_datasource_constraints.cpp index 52399454561..a212e46824f 100644 --- a/ydb/library/yql/providers/dq/provider/yql_dq_datasource_constraints.cpp +++ b/ydb/library/yql/providers/dq/provider/yql_dq_datasource_constraints.cpp @@ -20,6 +20,7 @@ public: TCoConfigure::CallableName(), TDqReadWrap::CallableName(), TDqReadWideWrap::CallableName(), + TDqReadBlockWideWrap::CallableName(), TDqSource::CallableName(), TDqSourceWrap::CallableName(), TDqSourceWideWrap::CallableName(), diff --git a/ydb/library/yql/providers/dq/provider/yql_dq_datasource_type_ann.cpp b/ydb/library/yql/providers/dq/provider/yql_dq_datasource_type_ann.cpp index f385ba37f92..b7e18d36a3b 100644 --- a/ydb/library/yql/providers/dq/provider/yql_dq_datasource_type_ann.cpp +++ b/ydb/library/yql/providers/dq/provider/yql_dq_datasource_type_ann.cpp @@ -23,7 +23,8 @@ public: AddHandler({TDqSourceWideWrap::CallableName()}, 
Hndl(&TDqsDataSourceTypeAnnotationTransformer::HandleSourceWrap<true, false>)); AddHandler({TDqSourceWideBlockWrap::CallableName()}, Hndl(&TDqsDataSourceTypeAnnotationTransformer::HandleSourceWrap<true, true>)); AddHandler({TDqReadWrap::CallableName()}, Hndl(&TDqsDataSourceTypeAnnotationTransformer::HandleReadWrap)); - AddHandler({TDqReadWideWrap::CallableName()}, Hndl(&TDqsDataSourceTypeAnnotationTransformer::HandleWideReadWrap)); + AddHandler({TDqReadWideWrap::CallableName()}, Hndl(&TDqsDataSourceTypeAnnotationTransformer::HandleWideReadWrap<false>)); + AddHandler({TDqReadBlockWideWrap::CallableName()}, Hndl(&TDqsDataSourceTypeAnnotationTransformer::HandleWideReadWrap<true>)); AddHandler({TDqSource::CallableName()}, Hndl(&NDq::AnnotateDqSource)); AddHandler({TDqPhyLength::CallableName()}, Hndl(&NDq::AnnotateDqPhyLength)); @@ -113,6 +114,7 @@ private: return TStatus::Ok; } + template<bool IsBlock> TStatus HandleWideReadWrap(const TExprNode::TPtr& input, TExprContext& ctx) { if (!EnsureMinMaxArgsCount(*input, 1, 3, ctx)) { return TStatus::Error; @@ -145,8 +147,15 @@ private: const auto structType = readerType->Cast<TListExprType>()->GetItemType()->Cast<TStructExprType>(); TTypeAnnotationNode::TListType types; const auto& items = structType->GetItems(); - types.reserve(items.size()); + types.reserve(items.size() + IsBlock); std::transform(items.cbegin(), items.cend(), std::back_inserter(types), std::bind(&TItemExprType::GetItemType, std::placeholders::_1)); + if constexpr (IsBlock) { + for (auto& type : types) { + type = ctx.MakeType<TBlockExprType>(type); + } + + types.push_back(ctx.MakeType<TScalarExprType>(ctx.MakeType<TDataExprType>(EDataSlot::Uint64))); + } input->SetTypeAnn(ctx.MakeType<TFlowExprType>(ctx.MakeType<TMultiExprType>(types))); return TStatus::Ok; diff --git a/ydb/library/yql/providers/yt/common/yql_configuration.h b/ydb/library/yql/providers/yt/common/yql_configuration.h index ca01b761ebd..0b671b7341f 100644 --- 
a/ydb/library/yql/providers/yt/common/yql_configuration.h +++ b/ydb/library/yql/providers/yt/common/yql_configuration.h @@ -56,7 +56,6 @@ constexpr bool DEFAULT_USE_RPC_READER_IN_DQ = false; constexpr size_t DEFAULT_RPC_READER_INFLIGHT = 1; constexpr TDuration DEFAULT_RPC_READER_TIMEOUT = TDuration::Seconds(120); - constexpr auto DEFAULT_SWITCH_MEMORY_LIMIT = 128_MB; constexpr ui32 DEFAULT_MAX_INPUT_TABLES = 3000; diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.darwin-x86_64.txt b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.darwin-x86_64.txt index b1d1973fea3..9082a770b58 100644 --- a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.darwin-x86_64.txt +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.darwin-x86_64.txt @@ -11,6 +11,9 @@ add_library(yt-comp_nodes-dq) target_compile_options(yt-comp_nodes-dq PRIVATE -DUSE_CURRENT_UDF_ABI_VERSION ) +target_include_directories(yt-comp_nodes-dq PRIVATE + ${CMAKE_SOURCE_DIR}/contrib/libs/flatbuffers/include +) target_link_libraries(yt-comp_nodes-dq PUBLIC contrib-libs-cxxsupp yutil @@ -18,10 +21,14 @@ target_link_libraries(yt-comp_nodes-dq PUBLIC providers-yt-comp_nodes providers-yt-codec providers-common-codec + core-formats-arrow cpp-mapreduce-interface cpp-mapreduce-common cpp-yson-node yt-yt-core + public-udf-arrow + libs-apache-arrow + contrib-libs-flatbuffers ) target_sources(yt-comp_nodes-dq PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-aarch64.txt b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-aarch64.txt index 5b2c1550db9..06ae4cb2da3 100644 --- a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-aarch64.txt +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-aarch64.txt @@ -9,8 +9,12 @@ add_library(yt-comp_nodes-dq) target_compile_options(yt-comp_nodes-dq PRIVATE + -Wno-unused-parameter -DUSE_CURRENT_UDF_ABI_VERSION ) 
+target_include_directories(yt-comp_nodes-dq PRIVATE + ${CMAKE_SOURCE_DIR}/contrib/libs/flatbuffers/include +) target_link_libraries(yt-comp_nodes-dq PUBLIC contrib-libs-linux-headers contrib-libs-cxxsupp @@ -19,16 +23,23 @@ target_link_libraries(yt-comp_nodes-dq PUBLIC providers-yt-comp_nodes providers-yt-codec providers-common-codec + core-formats-arrow cpp-mapreduce-interface cpp-mapreduce-common cpp-yson-node yt-yt-core + public-udf-arrow + libs-apache-arrow + contrib-libs-flatbuffers yt-yt-client yt-client-arrow yt-lib-yt_rpc_helpers ) target_sources(yt-comp_nodes-dq PRIVATE + ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_factory.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_writer.cpp diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-x86_64.txt b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-x86_64.txt index 5b2c1550db9..06ae4cb2da3 100644 --- a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-x86_64.txt +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-x86_64.txt @@ -9,8 +9,12 @@ add_library(yt-comp_nodes-dq) target_compile_options(yt-comp_nodes-dq PRIVATE + -Wno-unused-parameter -DUSE_CURRENT_UDF_ABI_VERSION ) +target_include_directories(yt-comp_nodes-dq PRIVATE + ${CMAKE_SOURCE_DIR}/contrib/libs/flatbuffers/include +) target_link_libraries(yt-comp_nodes-dq PUBLIC contrib-libs-linux-headers contrib-libs-cxxsupp @@ -19,16 +23,23 @@ target_link_libraries(yt-comp_nodes-dq PUBLIC providers-yt-comp_nodes providers-yt-codec 
providers-common-codec + core-formats-arrow cpp-mapreduce-interface cpp-mapreduce-common cpp-yson-node yt-yt-core + public-udf-arrow + libs-apache-arrow + contrib-libs-flatbuffers yt-yt-client yt-client-arrow yt-lib-yt_rpc_helpers ) target_sources(yt-comp_nodes-dq PRIVATE + ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_factory.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_writer.cpp diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.windows-x86_64.txt b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.windows-x86_64.txt index b1d1973fea3..9082a770b58 100644 --- a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.windows-x86_64.txt +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.windows-x86_64.txt @@ -11,6 +11,9 @@ add_library(yt-comp_nodes-dq) target_compile_options(yt-comp_nodes-dq PRIVATE -DUSE_CURRENT_UDF_ABI_VERSION ) +target_include_directories(yt-comp_nodes-dq PRIVATE + ${CMAKE_SOURCE_DIR}/contrib/libs/flatbuffers/include +) target_link_libraries(yt-comp_nodes-dq PUBLIC contrib-libs-cxxsupp yutil @@ -18,10 +21,14 @@ target_link_libraries(yt-comp_nodes-dq PUBLIC providers-yt-comp_nodes providers-yt-codec providers-common-codec + core-formats-arrow cpp-mapreduce-interface cpp-mapreduce-common cpp-yson-node yt-yt-core + public-udf-arrow + libs-apache-arrow + contrib-libs-flatbuffers ) target_sources(yt-comp_nodes-dq PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp diff --git 
a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.cpp b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.cpp new file mode 100644 index 00000000000..923deacc08c --- /dev/null +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.cpp @@ -0,0 +1,584 @@ +#include "dq_yt_block_reader.h" +#include "stream_decoder.h" +#include "dq_yt_rpc_helpers.h" + +#include <ydb/library/yql/public/udf/arrow/block_builder.h> + +#include <ydb/library/yql/providers/yt/codec/yt_codec.h> +#include <ydb/library/yql/providers/common/codec/yql_codec.h> +#include <ydb/library/yql/minikql/mkql_node_cast.h> +#include <ydb/library/yql/minikql/computation/mkql_computation_node_codegen.h> +#include <ydb/library/yql/minikql/computation/mkql_computation_node.h> +#include <ydb/library/yql/minikql/mkql_stats_registry.h> +#include <ydb/library/yql/minikql/mkql_node.h> +#include <ydb/library/yql/dq/runtime/dq_arrow_helpers.h> +#include <ydb/core/formats/arrow/dictionary/conversion.h> +#include <ydb/core/formats/arrow/arrow_helpers.h> + +#include <yt/yt/core/concurrency/thread_pool.h> +#include <yt/cpp/mapreduce/interface/common.h> +#include <yt/cpp/mapreduce/interface/errors.h> +#include <yt/cpp/mapreduce/interface/client.h> +#include <yt/cpp/mapreduce/interface/serialize.h> +#include <yt/cpp/mapreduce/interface/config.h> +#include <yt/cpp/mapreduce/common/wait_proxy.h> + +#include <arrow/compute/cast.h> +#include <arrow/compute/api_vector.h> +#include <arrow/compute/api.h> +#include <arrow/io/interfaces.h> +#include <arrow/io/memory.h> +#include <arrow/ipc/reader.h> +#include <arrow/array.h> +#include <arrow/record_batch.h> +#include <arrow/type.h> +#include <arrow/result.h> +#include <arrow/buffer.h> + +#include <library/cpp/yson/node/node.h> +#include <library/cpp/yson/node/node_io.h> + +#include <util/generic/size_literals.h> +#include <util/stream/output.h> + +namespace NYql::NDqs { + +using namespace NKikimr::NMiniKQL; + +namespace { +struct TResultBatch { + 
using TPtr = std::shared_ptr<TResultBatch>; + size_t RowsCnt; + std::vector<arrow::Datum> Columns; + TResultBatch(int64_t cnt, decltype(Columns)&& columns) : RowsCnt(cnt), Columns(std::move(columns)) {} + TResultBatch(std::shared_ptr<arrow::RecordBatch> batch) : RowsCnt(batch->num_rows()), Columns(batch->columns().begin(), batch->columns().end()) {} +}; + +template<typename T> +class TBlockingQueueWithLimit { + struct Poison { + TString Error; + }; + using TPoisonOr = std::variant<T, Poison>; +public: + TBlockingQueueWithLimit(size_t limit) : Limit_(limit) {} + void Push(T&& val) { + PushInternal(std::move(val)); + } + + void PushPoison(const TString& err) { + PushInternal(Poison{err}); + } + + T Get() { + auto res = GetInternal(); + if (std::holds_alternative<Poison>(res)) { + throw std::runtime_error(std::get<Poison>(res).Error); + } + return std::move(std::get<T>(res)); + } +private: + template<typename X> + void PushInternal(X&& val) { + NYT::TPromise<void> promise; + { + std::lock_guard _(Mtx_); + if (!Awaiting_.empty()) { + Awaiting_.front().Set(std::move(val)); + Awaiting_.pop(); + return; + } + if (Ready_.size() >= Limit_) { + promise = NYT::NewPromise<void>(); + BlockedPushes_.push(promise); + } else { + Ready_.emplace(std::move(val)); + return; + } + } + YQL_ENSURE(NYT::NConcurrency::WaitFor(promise.ToFuture()).IsOK()); + std::lock_guard _(Mtx_); + Ready_.emplace(std::move(val)); + } + + TPoisonOr GetInternal() { + NYT::TPromise<TPoisonOr> awaiter; + { + std::lock_guard _(Mtx_); + if (!BlockedPushes_.empty()) { + BlockedPushes_.front().Set(); + BlockedPushes_.pop(); + } + if (!Ready_.empty()) { + auto res = std::move(Ready_.front()); + Ready_.pop(); + return res; + } + awaiter = NYT::NewPromise<TPoisonOr>(); + Awaiting_.push(awaiter); + } + auto awaitResult = NYT::NConcurrency::WaitFor(awaiter.ToFuture()); + if (!awaitResult.IsOK()) { + throw std::runtime_error(awaitResult.GetMessage()); + } + return std::move(awaitResult.Value()); + } + std::mutex Mtx_; 
+ std::queue<TPoisonOr> Ready_; + std::queue<NYT::TPromise<void>> BlockedPushes_; + std::queue<NYT::TPromise<TPoisonOr>> Awaiting_; + size_t Limit_; +}; + +class TListener { + using TBatchPtr = std::shared_ptr<arrow::RecordBatch>; +public: + using TPromise = NYT::TPromise<TResultBatch>; + using TPtr = std::shared_ptr<TListener>; + TListener(size_t initLatch, size_t inflight) : Latch_(initLatch), Queue_(inflight) {} + + void OnEOF() { + bool excepted = 0; + if (GotEOF_.compare_exchange_strong(excepted, 1)) { + // block poining to nullptr is marker of EOF + HandleResult(nullptr); + } else { + // can't get EOF more than one time + HandleError("EOS already got"); + } + } + + // Handles result + void HandleResult(TResultBatch::TPtr&& res) { + Queue_.Push(std::move(res)); + } + + void OnRecordBatchDecoded(TBatchPtr record_batch) { + YQL_ENSURE(record_batch); + // decode dictionary + record_batch = NKikimr::NArrow::DictionaryToArray(record_batch); + // and handle result + HandleResult(std::make_shared<TResultBatch>(record_batch)); + } + + TResultBatch::TPtr Get() { + return Queue_.Get(); + } + + void HandleError(const TString& msg) { + Queue_.PushPoison(msg); + } + + void HandleFallback(TResultBatch::TPtr&& block) { + HandleResult(std::move(block)); + } + + void InputDone() { + // EOF comes when all inputs got EOS + if (!--Latch_) { + OnEOF(); + } + } +private: + std::atomic<size_t> Latch_; + std::atomic<bool> GotEOF_; + TBlockingQueueWithLimit<TResultBatch::TPtr> Queue_; +}; + +class TBlockBuilder { +public: + void Init(std::shared_ptr<std::vector<TType*>> columnTypes, arrow::MemoryPool& pool, const NUdf::IPgBuilder* pgBuilder) { + ColumnTypes_ = columnTypes; + ColumnBuilders_.reserve(ColumnTypes_->size()); + size_t maxBlockItemSize = 0; + for (auto& type: *ColumnTypes_) { + maxBlockItemSize = std::max(maxBlockItemSize, CalcMaxBlockItemSize(type)); + } + size_t maxBlockLen = CalcBlockLen(maxBlockItemSize); + for (size_t i = 0; i < ColumnTypes_->size(); ++i) { + 
ColumnBuilders_.push_back( + std::move(NUdf::MakeArrayBuilder( + TTypeInfoHelper(), ColumnTypes_->at(i), + pool, + maxBlockLen, + pgBuilder + )) + ); + } + } + + void Add(const NUdf::TUnboxedValue& val) { + for (ui32 i = 0; i < ColumnBuilders_.size(); ++i) { + auto v = val.GetElement(i); + ColumnBuilders_[i]->Add(v); + } + ++RowsCnt_; + } + + std::vector<TResultBatch::TPtr> Build() { + std::vector<arrow::Datum> columns; + columns.reserve(ColumnBuilders_.size()); + for (size_t i = 0; i < ColumnBuilders_.size(); ++i) { + columns.emplace_back(std::move(ColumnBuilders_[i]->Build(false))); + } + std::vector<std::shared_ptr<TResultBatch>> blocks; + int64_t offset = 0; + std::vector<int64_t> currentChunk(columns.size()), inChunkOffset(columns.size()); + while (RowsCnt_) { + int64_t max_curr_len = RowsCnt_; + for (size_t i = 0; i < columns.size(); ++i) { + if (arrow::Datum::Kind::CHUNKED_ARRAY == columns[i].kind()) { + auto& c_arr = columns[i].chunked_array(); + while (currentChunk[i] < c_arr->num_chunks() && !c_arr->chunk(currentChunk[i])) { + ++currentChunk[i]; + } + YQL_ENSURE(currentChunk[i] < c_arr->num_chunks()); + max_curr_len = std::min(max_curr_len, c_arr->chunk(currentChunk[i])->length() - inChunkOffset[i]); + } + } + RowsCnt_ -= max_curr_len; + decltype(columns) result_columns; + result_columns.reserve(columns.size()); + offset += max_curr_len; + for (size_t i = 0; i < columns.size(); ++i) { + auto& e = columns[i]; + if (arrow::Datum::Kind::CHUNKED_ARRAY == e.kind()) { + result_columns.emplace_back(e.chunked_array()->chunk(currentChunk[i])->Slice(inChunkOffset[i], max_curr_len)); + if (max_curr_len + inChunkOffset[i] == e.chunked_array()->chunk(currentChunk[i])->length()) { + ++currentChunk[i]; + inChunkOffset[i] = 0; + } else { + inChunkOffset[i] += max_curr_len; + } + } else { + result_columns.emplace_back(e.array()->Slice(offset - max_curr_len, max_curr_len)); + } + } + blocks.emplace_back(std::make_shared<TResultBatch>(max_curr_len, 
std::move(result_columns))); + } + return blocks; + } + +private: + int64_t RowsCnt_ = 0; + std::vector<std::unique_ptr<NUdf::IArrayBuilder>> ColumnBuilders_; + std::shared_ptr<std::vector<TType*>> ColumnTypes_; +}; + +class TLocalListener : public arrow::ipc::Listener { +public: + TLocalListener(std::shared_ptr<TListener> consumer) + : Consumer_(consumer) {} + + void Init(std::shared_ptr<TLocalListener> self) { + Self_ = self; + Decoder_ = std::make_shared<arrow::ipc::NDqs::StreamDecoder2>(self, arrow::ipc::IpcReadOptions{.use_threads=false}); + } + + arrow::Status OnEOS() override { + Decoder_->Reset(); + return arrow::Status::OK(); + } + + arrow::Status OnRecordBatchDecoded(std::shared_ptr<arrow::RecordBatch> batch) override { + Consumer_->OnRecordBatchDecoded(batch); + return arrow::Status::OK(); + } + + void Consume(std::shared_ptr<arrow::Buffer> buff) { + ARROW_OK(Decoder_->Consume(buff)); + } + + void Finish() { + Self_ = nullptr; + } +private: + std::shared_ptr<TLocalListener> Self_; + std::shared_ptr<TListener> Consumer_; + std::shared_ptr<arrow::ipc::NDqs::StreamDecoder2> Decoder_; +}; + +class TSource : public TNonCopyable { +public: + using TPtr = std::shared_ptr<TSource>; + TSource(std::unique_ptr<TSettingsHolder>&& settings, + size_t inflight, TType* type, const THolderFactory& holderFactory) + : Settings_(std::move(settings)) + , Inputs_(std::move(Settings_->RawInputs)) + , Listener_(std::make_shared<TListener>(Inputs_.size(), inflight)) + , HolderFactory_(holderFactory) + { + auto structType = AS_TYPE(TStructType, type); + std::vector<TType*> columnTypes_(structType->GetMembersCount()); + for (ui32 i = 0; i < structType->GetMembersCount(); ++i) { + columnTypes_[i] = structType->GetMemberType(i); + } + auto ptr = std::make_shared<decltype(columnTypes_)>(std::move(columnTypes_)); + Inflight_ = std::min(inflight, Inputs_.size()); + + LocalListeners_.reserve(Inputs_.size()); + for (size_t i = 0; i < Inputs_.size(); ++i) { + InputsQueue_.emplace(i); + 
LocalListeners_.emplace_back(std::make_shared<TLocalListener>(Listener_)); + LocalListeners_.back()->Init(LocalListeners_.back()); + } + BlockBuilder_.Init(ptr, *Settings_->Pool, Settings_->PgBuilder); + FallbackReader_.SetSpecs(*Settings_->Specs, HolderFactory_); + } + + void RunRead() { + size_t inputIdx; + { + std::lock_guard _(Mtx_); + if (InputsQueue_.empty()) { + return; + } + inputIdx = InputsQueue_.front(); + InputsQueue_.pop(); + } + Inputs_[inputIdx]->Read().SubscribeUnique(BIND([inputIdx = inputIdx, self = Self_](NYT::TErrorOr<NYT::TSharedRef>&& res) { + self->Pool_->GetInvoker()->Invoke(BIND([inputIdx, self, res = std::move(res)] () mutable { + try { + self->Accept(inputIdx, std::move(res)); + self->RunRead(); + } catch (std::exception& e) { + self->Listener_->HandleError(e.what()); + } + })); + })); + } + + void Accept(size_t inputIdx, NYT::TErrorOr<NYT::TSharedRef>&& res) { + if (res.IsOK() && !res.Value()) { + // End Of Stream + Listener_->InputDone(); + return; + } + + if (!res.IsOK()) { + // Propagate error + Listener_->HandleError(res.GetMessage()); + return; + } + + NYT::NApi::NRpcProxy::NProto::TRowsetDescriptor descriptor; + NYT::NApi::NRpcProxy::NProto::TRowsetStatistics statistics; + NYT::TSharedRef currentPayload = NYT::NApi::NRpcProxy::DeserializeRowStreamBlockEnvelope(res.Value(), &descriptor, &statistics); + if (descriptor.rowset_format() != NYT::NApi::NRpcProxy::NProto::RF_ARROW) { + auto promise = NYT::NewPromise<std::vector<TResultBatch::TPtr>>(); + MainInvoker_->Invoke(BIND([inputIdx, currentPayload, self = Self_, promise] { + try { + promise.Set(self->FallbackHandler(inputIdx, currentPayload)); + } catch (std::exception& e) { + promise.Set(NYT::TError(e.what())); + } + })); + auto result = NYT::NConcurrency::WaitFor(promise.ToFuture()); + if (!result.IsOK()) { + Listener_->HandleError(result.GetMessage()); + return; + } + for (auto& e: result.Value()) { + Listener_->HandleFallback(std::move(e)); + } + InputDone(inputIdx); + return; + 
} + + if (!currentPayload.Size()) { + // EOS + Listener_->InputDone(); + return; + } + // TODO(): support row and range indexes + auto payload = TMemoryInput(currentPayload.Begin(), currentPayload.Size()); + arrow::BufferBuilder bb; + ARROW_OK(bb.Reserve(currentPayload.Size())); + ARROW_OK(bb.Append((const uint8_t*)payload.Buf(), currentPayload.Size())); + LocalListeners_[inputIdx]->Consume(*bb.Finish()); + InputDone(inputIdx); + } + + // Return input back to queue + void InputDone(auto input) { + std::lock_guard _(Mtx_); + InputsQueue_.emplace(input); + } + + TResultBatch::TPtr Next() { + auto result = Listener_->Get(); + return result; + } + + std::vector<TResultBatch::TPtr> FallbackHandler(size_t idx, NYT::TSharedRef payload) { + if (!payload.Size()) { + return {}; + } + // We're have only one mkql reader, protect it if 2 fallbacks happen at the same time + std::lock_guard _(FallbackMtx_); + auto currentReader_ = std::make_shared<TPayloadRPCReader>(std::move(payload)); + + // TODO(): save and recover row indexes + FallbackReader_.SetReader(*currentReader_, 1, 4_MB, ui32(Settings_->OriginalIndexes[idx]), true); + // If we don't save the reader, after exiting FallbackHandler it will be destroyed, + // but FallbackReader points on it yet. 
+ Reader_ = currentReader_; + FallbackReader_.Next(); + while (FallbackReader_.IsValid()) { + auto currentRow = std::move(FallbackReader_.GetRow()); + if (!Settings_->Specs->InputGroups.empty()) { + currentRow = std::move(HolderFactory_.CreateVariantHolder(currentRow.Release(), Settings_->Specs->InputGroups.at(Settings_->OriginalIndexes[idx]))); + } + BlockBuilder_.Add(currentRow); + FallbackReader_.Next(); + } + return BlockBuilder_.Build(); + } + + void Finish() { + FallbackReader_.Finish(); + Pool_->Shutdown(); + for (auto& e: LocalListeners_) { + e->Finish(); + } + Self_ = nullptr; + } + + void SetSelfAndRun(TPtr self) { + MainInvoker_ = NYT::GetCurrentInvoker(); + Self_ = self; + Pool_ = NYT::NConcurrency::CreateThreadPool(Inflight_, "rpc_reader_inflight"); + // Run Inflight_ reads at the same time + for (size_t i = 0; i < Inflight_; ++i) { + RunRead(); + } + } + +private: + NYT::IInvoker* MainInvoker_; + NYT::NConcurrency::IThreadPoolPtr Pool_; + std::mutex Mtx_; + std::mutex FallbackMtx_; + std::unique_ptr<TSettingsHolder> Settings_; + std::vector<std::shared_ptr<TLocalListener>> LocalListeners_; + std::vector<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr> Inputs_; + TMkqlReaderImpl FallbackReader_; + TBlockBuilder BlockBuilder_; + std::shared_ptr<TPayloadRPCReader> Reader_; + std::queue<size_t> InputsQueue_; + TListener::TPtr Listener_; + TPtr Self_; + size_t Inflight_; + const THolderFactory& HolderFactory_; +}; + +class TState: public TComputationValue<TState> { + using TBase = TComputationValue<TState>; +public: + TState(TMemoryUsageInfo* memInfo, TSource::TPtr source, size_t width, TType* type) + : TBase(memInfo) + , Source_(std::move(source)) + , Width_(width) + , Type_(type) + , Types_(width) + { + for (size_t i = 0; i < Width_; ++i) { + Types_[i] = NArrow::GetArrowType(AS_TYPE(TStructType, Type_)->GetMemberType(i)); + } + } + + EFetchResult FetchValues(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) { + auto batch = Source_->Next(); + 
if (!batch) { + Source_->Finish(); + return EFetchResult::Finish; + } + for (size_t i = 0; i < Width_; ++i) { + if (!output[i]) { + continue; + } + if(!batch->Columns[i].type()->Equals(Types_[i])) { + *(output[i]) = ctx.HolderFactory.CreateArrowBlock(ARROW_RESULT(arrow::compute::Cast(batch->Columns[i], Types_[i]))); + continue; + } + *(output[i]) = ctx.HolderFactory.CreateArrowBlock(std::move(batch->Columns[i])); + } + if (output[Width_]) { + *(output[Width_]) = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(ui64(batch->RowsCnt))); + } + return EFetchResult::One; + } + +private: + TSource::TPtr Source_; + const size_t Width_; + TType* Type_; + std::vector<std::shared_ptr<arrow::DataType>> Types_; +}; +}; + +class TDqYtReadBlockWrapper : public TStatefulWideFlowComputationNode<TDqYtReadBlockWrapper> { +using TBaseComputation = TStatefulWideFlowComputationNode<TDqYtReadBlockWrapper>; +public: + + TDqYtReadBlockWrapper(const TComputationNodeFactoryContext& ctx, const TString& clusterName, + const TString& token, const NYT::TNode& inputSpec, const NYT::TNode& samplingSpec, + const TVector<ui32>& inputGroups, + TType* itemType, const TVector<TString>& tableNames, TVector<std::pair<NYT::TRichYPath, NYT::TFormat>>&& tables, NKikimr::NMiniKQL::IStatsRegistry* jobStats, size_t inflight, + size_t timeout) : TBaseComputation(ctx.Mutables, this, EValueRepresentation::Boxed) + , Width(AS_TYPE(TStructType, itemType)->GetMembersCount()) + , CodecCtx(ctx.Env, ctx.FunctionRegistry, &ctx.HolderFactory) + , ClusterName(clusterName) + , Token(token) + , SamplingSpec(samplingSpec) + , Tables(std::move(tables)) + , Inflight(inflight) + , Timeout(timeout) + , Type(itemType) + { + // TODO() Enable range indexes + Specs.SetUseSkiff("", TMkqlIOSpecs::ESystemField::RowIndex); + Specs.Init(CodecCtx, inputSpec, inputGroups, tableNames, itemType, {}, {}, jobStats); + } + + void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const { + auto settings = CreateInputStreams(true, 
Token, ClusterName, Timeout, Inflight > 1, Tables, SamplingSpec); + settings->Specs = &Specs; + settings->Pool = arrow::default_memory_pool(); + settings->PgBuilder = &ctx.Builder->GetPgBuilder(); + auto source = std::make_shared<TSource>(std::move(settings), Inflight, Type, ctx.HolderFactory); + source->SetSelfAndRun(source); + state = ctx.HolderFactory.Create<TState>(source, Width, Type); + } + + EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { + if (!state.HasValue()) { + MakeState(ctx, state); + } + return static_cast<TState&>(*state.AsBoxed()).FetchValues(ctx, output); + } + + void RegisterDependencies() const final {} + + const ui32 Width; + NCommon::TCodecContext CodecCtx; + TMkqlIOSpecs Specs; + + TString ClusterName; + TString Token; + NYT::TNode SamplingSpec; + TVector<std::pair<NYT::TRichYPath, NYT::TFormat>> Tables; + size_t Inflight; + size_t Timeout; + TType* Type; +}; + +IComputationNode* CreateDqYtReadBlockWrapper(const TComputationNodeFactoryContext& ctx, const TString& clusterName, + const TString& token, const NYT::TNode& inputSpec, const NYT::TNode& samplingSpec, + const TVector<ui32>& inputGroups, + TType* itemType, const TVector<TString>& tableNames, TVector<std::pair<NYT::TRichYPath, NYT::TFormat>>&& tables, NKikimr::NMiniKQL::IStatsRegistry* jobStats, size_t inflight, + size_t timeout) +{ + return new TDqYtReadBlockWrapper(ctx, clusterName, token, inputSpec, samplingSpec, inputGroups, itemType, tableNames, std::move(tables), jobStats, inflight, timeout); +} +}
\ No newline at end of file diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.h b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.h new file mode 100644 index 00000000000..0bcce64b114 --- /dev/null +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.h @@ -0,0 +1,17 @@ +#include <ydb/library/yql/minikql/computation/mkql_computation_node.h> +#include <ydb/library/yql/minikql/mkql_stats_registry.h> +#include <ydb/library/yql/minikql/mkql_node.h> + +#include <ydb/library/yql/providers/yt/comp_nodes/yql_mkql_file_input_state.h> +#include <ydb/library/yql/providers/yt/codec/yt_codec.h> +#include <ydb/library/yql/providers/common/codec/yql_codec.h> +#include <ydb/library/yql/minikql/mkql_node_cast.h> +#include <ydb/library/yql/minikql/computation/mkql_computation_node_codegen.h> + +namespace NYql::NDqs { +NKikimr::NMiniKQL::IComputationNode* CreateDqYtReadBlockWrapper(const NKikimr::NMiniKQL::TComputationNodeFactoryContext& ctx, const TString& clusterName, + const TString& token, const NYT::TNode& inputSpec, const NYT::TNode& samplingSpec, + const TVector<ui32>& inputGroups, + NKikimr::NMiniKQL::TType* itemType, const TVector<TString>& tableNames, TVector<std::pair<NYT::TRichYPath, NYT::TFormat>>&& tables, NKikimr::NMiniKQL::IStatsRegistry* jobStats, size_t inflight, + size_t timeout); +} diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_factory.cpp b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_factory.cpp index c02a693f699..4d88adfc6eb 100644 --- a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_factory.cpp +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_factory.cpp @@ -9,8 +9,8 @@ using namespace NKikimr::NMiniKQL; TComputationNodeFactory GetDqYtFactory(NKikimr::NMiniKQL::IStatsRegistry* jobStats) { return [=] (TCallable& callable, const TComputationNodeFactoryContext& ctx) -> IComputationNode* { TStringBuf name = callable.GetType()->GetName(); - if (name == "DqYtRead") { - return 
NDqs::WrapDqYtRead(callable, jobStats, ctx); + if (name == "DqYtRead" || name == "DqYtBlockRead") { + return NDqs::WrapDqYtRead(callable, jobStats, ctx, name == "DqYtBlockRead"); } if (name == "YtDqRowsWideWrite") { diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp index 8461b8e495b..578c5115079 100644 --- a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp @@ -1,7 +1,7 @@ #include "dq_yt_reader.h" #include "dq_yt_reader_impl.h" - +#include "dq_yt_block_reader.h" #include "dq_yt_rpc_reader.h" namespace NYql::NDqs { @@ -63,7 +63,7 @@ using TInputType = NYT::TRawTableReaderPtr; } }; -IComputationNode* WrapDqYtRead(TCallable& callable, NKikimr::NMiniKQL::IStatsRegistry* jobStats, const TComputationNodeFactoryContext& ctx) { +IComputationNode* WrapDqYtRead(TCallable& callable, NKikimr::NMiniKQL::IStatsRegistry* jobStats, const TComputationNodeFactoryContext& ctx, bool useBlocks) { MKQL_ENSURE(callable.GetInputsCount() == 8 || callable.GetInputsCount() == 9, "Expected 8 or 9 arguments."); TString clusterName(AS_VALUE(TDataLiteral, callable.GetInput(0))->AsValue().AsStringRef()); @@ -102,15 +102,23 @@ IComputationNode* WrapDqYtRead(TCallable& callable, NKikimr::NMiniKQL::IStatsReg #ifdef __linux__ size_t inflight(AS_VALUE(TDataLiteral, callable.GetInput(6))->AsValue().Get<size_t>()); if (inflight) { - return new TDqYtReadWrapperBase<TDqYtReadWrapperRPC, TParallelFileInputState>(ctx, clusterName, token, - NYT::NodeFromYsonString(inputSpec), samplingSpec ? NYT::NodeFromYsonString(samplingSpec) : NYT::TNode(), - inputGroups, static_cast<TType*>(callable.GetInput(5).GetNode()), tableNames, std::move(tables), jobStats, inflight, timeout); + if (useBlocks) { + return CreateDqYtReadBlockWrapper(ctx, clusterName, token, + NYT::NodeFromYsonString(inputSpec), samplingSpec ? 
NYT::NodeFromYsonString(samplingSpec) : NYT::TNode(), + inputGroups, static_cast<TType*>(callable.GetInput(5).GetNode()), tableNames, std::move(tables), jobStats, inflight, timeout); + } else { + return new TDqYtReadWrapperBase<TDqYtReadWrapperRPC, TParallelFileInputState>(ctx, clusterName, token, + NYT::NodeFromYsonString(inputSpec), samplingSpec ? NYT::NodeFromYsonString(samplingSpec) : NYT::TNode(), + inputGroups, static_cast<TType*>(callable.GetInput(5).GetNode()), tableNames, std::move(tables), jobStats, inflight, timeout); + } } else { + YQL_ENSURE(!useBlocks); return new TDqYtReadWrapperBase<TDqYtReadWrapperHttp, TFileInputState>(ctx, clusterName, token, NYT::NodeFromYsonString(inputSpec), samplingSpec ? NYT::NodeFromYsonString(samplingSpec) : NYT::TNode(), inputGroups, static_cast<TType*>(callable.GetInput(5).GetNode()), tableNames, std::move(tables), jobStats, inflight, timeout); } #else + YQL_ENSURE(!useBlocks); return new TDqYtReadWrapperBase<TDqYtReadWrapperHttp, TFileInputState>(ctx, clusterName, token, NYT::NodeFromYsonString(inputSpec), samplingSpec ? 
NYT::NodeFromYsonString(samplingSpec) : NYT::TNode(), inputGroups, static_cast<TType*>(callable.GetInput(5).GetNode()), tableNames, std::move(tables), jobStats, 0, timeout); diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.h b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.h index 2dd5c2deefd..5f1a471a690 100644 --- a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.h +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.h @@ -6,6 +6,6 @@ namespace NYql::NDqs { -NKikimr::NMiniKQL::IComputationNode* WrapDqYtRead(NKikimr::NMiniKQL::TCallable& callable, NKikimr::NMiniKQL::IStatsRegistry* jobStats, const NKikimr::NMiniKQL::TComputationNodeFactoryContext& ctx); +NKikimr::NMiniKQL::IComputationNode* WrapDqYtRead(NKikimr::NMiniKQL::TCallable& callable, NKikimr::NMiniKQL::IStatsRegistry* jobStats, const NKikimr::NMiniKQL::TComputationNodeFactoryContext& ctx, bool useBlocks); } // NYql diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.cpp b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.cpp new file mode 100644 index 00000000000..735bb9ab9b0 --- /dev/null +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.cpp @@ -0,0 +1,123 @@ +#include "dq_yt_rpc_helpers.h" + +#include <ydb/library/yql/providers/yt/lib/yt_rpc_helpers/yt_convert_helpers.h> + +namespace NYql::NDqs { + +NYT::NYPath::TRichYPath ConvertYPathFromOld(const NYT::TRichYPath& richYPath) { + NYT::NYPath::TRichYPath tableYPath(richYPath.Path_); + const auto& rngs = richYPath.GetRanges(); + if (rngs) { + TVector<NYT::NChunkClient::TReadRange> ranges; + for (const auto& rng: *rngs) { + auto& range = ranges.emplace_back(); + if (rng.LowerLimit_.Offset_) { + range.LowerLimit().SetOffset(*rng.LowerLimit_.Offset_); + } + + if (rng.LowerLimit_.TabletIndex_) { + range.LowerLimit().SetTabletIndex(*rng.LowerLimit_.TabletIndex_); + } + + if (rng.LowerLimit_.RowIndex_) { + range.LowerLimit().SetRowIndex(*rng.LowerLimit_.RowIndex_); + 
} + + if (rng.UpperLimit_.Offset_) { + range.UpperLimit().SetOffset(*rng.UpperLimit_.Offset_); + } + + if (rng.UpperLimit_.TabletIndex_) { + range.UpperLimit().SetTabletIndex(*rng.UpperLimit_.TabletIndex_); + } + + if (rng.UpperLimit_.RowIndex_) { + range.UpperLimit().SetRowIndex(*rng.UpperLimit_.RowIndex_); + } + } + tableYPath.SetRanges(std::move(ranges)); + } + + if (richYPath.Columns_) { + tableYPath.SetColumns(richYPath.Columns_->Parts_); + } + + return tableYPath; +} + + +std::unique_ptr<TSettingsHolder> CreateInputStreams(bool isArrow, const TString& token, const TString& clusterName, const ui64 timeout, bool unordered, const TVector<std::pair<NYT::TRichYPath, NYT::TFormat>>& tables, NYT::TNode samplingSpec) { + auto connectionConfig = NYT::New<NYT::NApi::NRpcProxy::TConnectionConfig>(); + connectionConfig->ClusterUrl = clusterName; + connectionConfig->DefaultTotalStreamingTimeout = TDuration::MilliSeconds(timeout); + auto connection = CreateConnection(connectionConfig); + auto clientOptions = NYT::NApi::TClientOptions(); + + if (token) { + clientOptions.Token = token; + } + + auto client = DynamicPointerCast<NYT::NApi::NRpcProxy::TClient>(connection->CreateClient(clientOptions)); + Y_VERIFY(client); + auto apiServiceProxy = client->CreateApiServiceProxy(); + + TVector<NYT::TFuture<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr>> waitFor; + + size_t inputIdx = 0; + TVector<size_t> originalIndexes; + + for (auto [richYPath, format]: tables) { + if (richYPath.GetRanges() && richYPath.GetRanges()->empty()) { + ++inputIdx; + continue; + } + originalIndexes.emplace_back(inputIdx); + + auto request = apiServiceProxy.ReadTable(); + client->InitStreamingRequest(*request); + request->ClientAttachmentsStreamingParameters().ReadTimeout = TDuration::MilliSeconds(timeout); + + TString ppath; + auto tableYPath = ConvertYPathFromOld(richYPath); + + NYT::NYPath::ToProto(&ppath, tableYPath); + request->set_path(ppath); + request->set_desired_rowset_format(isArrow ? 
NYT::NApi::NRpcProxy::NProto::ERowsetFormat::RF_ARROW : NYT::NApi::NRpcProxy::NProto::ERowsetFormat::RF_FORMAT); + if (isArrow) { + request->set_arrow_fallback_rowset_format(NYT::NApi::NRpcProxy::NProto::ERowsetFormat::RF_FORMAT); + } + + request->set_enable_row_index(true); + request->set_enable_table_index(true); + // TODO() Enable range indexes + request->set_enable_range_index(!isArrow); + + request->set_unordered(unordered); + + // https://a.yandex-team.ru/arcadia/yt/yt_proto/yt/client/api/rpc_proxy/proto/api_service.proto?rev=r11519304#L2338 + if (!samplingSpec.IsUndefined()) { + TStringStream ss; + samplingSpec.Save(&ss); + request->set_config(ss.Str()); + } + + ConfigureTransaction(request, richYPath); + + // Get skiff format yson string + TStringStream fmt; + format.Config.Save(&fmt); + request->set_format(fmt.Str()); + + waitFor.emplace_back(std::move(CreateRpcClientInputStream(std::move(request)).ApplyUnique(BIND([](NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr&& stream) { + // first packet contains meta, skip it + return stream->Read().ApplyUnique(BIND([stream = std::move(stream)](NYT::TSharedRef&&) { + return std::move(stream); + })); + })))); + } + TVector<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr> rawInputs; + NYT::NConcurrency::WaitFor(NYT::AllSucceeded(waitFor)).ValueOrThrow().swap(rawInputs); + return std::make_unique<TSettingsHolder>(std::move(connection), std::move(client), std::move(rawInputs), std::move(originalIndexes)); +} + +} diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.h b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.h new file mode 100644 index 00000000000..7bab55334a0 --- /dev/null +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.h @@ -0,0 +1,67 @@ +#pragma once + +#include <ydb/library/yql/providers/yt/comp_nodes/yql_mkql_file_input_state.h> +#include <yt/cpp/mapreduce/common/helpers.h> +#include <yt/cpp/mapreduce/interface/client.h> +#include 
<yt/cpp/mapreduce/interface/serialize.h> +#include <yt/cpp/mapreduce/interface/config.h> +#include <yt/cpp/mapreduce/interface/common.h> + +#include <yt/yt/library/auth/auth.h> +#include <yt/yt/client/api/client.h> +#include <yt/yt/client/api/rpc_proxy/client_impl.h> +#include <yt/yt/client/api/rpc_proxy/config.h> +#include <yt/yt/client/api/rpc_proxy/connection.h> +#include <yt/yt/client/api/rpc_proxy/row_stream.h> + +namespace NYql::NDqs { +NYT::NYPath::TRichYPath ConvertYPathFromOld(const NYT::TRichYPath& richYPath); +class TPayloadRPCReader : public NYT::TRawTableReader { +public: + TPayloadRPCReader(NYT::TSharedRef&& payload) : Payload_(std::move(payload)), PayloadStream_(Payload_.Begin(), Payload_.Size()) {} + + bool Retry(const TMaybe<ui32>&, const TMaybe<ui64>&) override { + return false; + } + + void ResetRetries() override { + + } + + bool HasRangeIndices() const override { + return true; + }; + + size_t DoRead(void* buf, size_t len) override { + if (!PayloadStream_.Exhausted()) { + return PayloadStream_.Read(buf, len); + } + return 0; + }; + + virtual ~TPayloadRPCReader() override { + } +private: + NYT::TSharedRef Payload_; + TMemoryInput PayloadStream_; +}; + +struct TSettingsHolder : public TNonCopyable { + TSettingsHolder(NYT::NApi::IConnectionPtr&& connection, NYT::TIntrusivePtr<NYT::NApi::NRpcProxy::TClient>&& client, + TVector<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr>&& inputs, TVector<size_t>&& originalIndexes) + : Connection(std::move(connection)) + , Client(std::move(client)) + , RawInputs(std::move(inputs)) + , OriginalIndexes(std::move(originalIndexes)) {}; + NYT::NApi::IConnectionPtr Connection; + NYT::TIntrusivePtr<NYT::NApi::NRpcProxy::TClient> Client; + const TMkqlIOSpecs* Specs = nullptr; + arrow::MemoryPool* Pool = nullptr; + const NUdf::IPgBuilder* PgBuilder = nullptr; + TVector<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr> RawInputs; + TVector<size_t> OriginalIndexes; +}; + +std::unique_ptr<TSettingsHolder> 
CreateInputStreams(bool isArrow, const TString& token, const TString& clusterName, const ui64 timeout, bool unordered, const TVector<std::pair<NYT::TRichYPath, NYT::TFormat>>& tables, NYT::TNode samplingSpec); + +}; diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.cpp b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.cpp index d9cd70fa492..2bdd1e41a6c 100644 --- a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.cpp +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.cpp @@ -1,4 +1,5 @@ #include "dq_yt_rpc_reader.h" +#include "dq_yt_rpc_helpers.h" #include "yt/cpp/mapreduce/common/helpers.h" @@ -11,86 +12,12 @@ #include <yt/yt/client/api/rpc_proxy/connection.h> #include <yt/yt/client/api/rpc_proxy/row_stream.h> -#include <ydb/library/yql/providers/yt/lib/yt_rpc_helpers/yt_convert_helpers.h> - namespace NYql::NDqs { using namespace NKikimr::NMiniKQL; namespace { TStatKey RPCReaderAwaitingStallTime("Job_RPCReaderAwaitingStallTime", true); -NYT::NYPath::TRichYPath ConvertYPathFromOld(const NYT::TRichYPath& richYPath) { - NYT::NYPath::TRichYPath tableYPath(richYPath.Path_); - const auto& rngs = richYPath.GetRanges(); - if (rngs) { - TVector<NYT::NChunkClient::TReadRange> ranges; - for (const auto& rng: *rngs) { - auto& range = ranges.emplace_back(); - if (rng.LowerLimit_.Offset_) { - range.LowerLimit().SetOffset(*rng.LowerLimit_.Offset_); - } - - if (rng.LowerLimit_.TabletIndex_) { - range.LowerLimit().SetTabletIndex(*rng.LowerLimit_.TabletIndex_); - } - - if (rng.LowerLimit_.RowIndex_) { - range.LowerLimit().SetRowIndex(*rng.LowerLimit_.RowIndex_); - } - - if (rng.UpperLimit_.Offset_) { - range.UpperLimit().SetOffset(*rng.UpperLimit_.Offset_); - } - - if (rng.UpperLimit_.TabletIndex_) { - range.UpperLimit().SetTabletIndex(*rng.UpperLimit_.TabletIndex_); - } - - if (rng.UpperLimit_.RowIndex_) { - range.UpperLimit().SetRowIndex(*rng.UpperLimit_.RowIndex_); - } - } - tableYPath.SetRanges(std::move(ranges)); - } 
- - if (richYPath.Columns_) { - tableYPath.SetColumns(richYPath.Columns_->Parts_); - } - - return tableYPath; -} - -class TFakeRPCReader : public NYT::TRawTableReader { -public: - TFakeRPCReader(NYT::TSharedRef&& payload) : Payload_(std::move(payload)), PayloadStream_(Payload_.Begin(), Payload_.Size()) {} - - bool Retry(const TMaybe<ui32>& rangeIndex, const TMaybe<ui64>& rowIndex) override { - Y_UNUSED(rangeIndex); - Y_UNUSED(rowIndex); - return false; - } - - void ResetRetries() override { - - } - - bool HasRangeIndices() const override { - return true; - }; - - size_t DoRead(void* buf, size_t len) override { - if (!PayloadStream_.Exhausted()) { - return PayloadStream_.Read(buf, len); - } - return 0; - }; - - virtual ~TFakeRPCReader() override { - } -private: - NYT::TSharedRef Payload_; - TMemoryInput PayloadStream_; -}; } #ifdef RPC_PRINT_TIME int cnt = 0; @@ -102,28 +29,21 @@ void print_add(int x) { } #endif -struct TSettingsHolder { - NYT::NApi::IConnectionPtr Connection; - NYT::TIntrusivePtr<NYT::NApi::NRpcProxy::TClient> Client; -}; - TParallelFileInputState::TParallelFileInputState(const TMkqlIOSpecs& spec, const NKikimr::NMiniKQL::THolderFactory& holderFactory, - TVector<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr>&& rawInputs, size_t blockSize, size_t inflight, - std::unique_ptr<TSettingsHolder>&& settings, - TVector<size_t>&& originalIndexes) - : InnerState_(new TInnerState (rawInputs.size())) - , StateByReader_(rawInputs.size()) + std::unique_ptr<TSettingsHolder>&& settings) + : InnerState_(new TInnerState (settings->RawInputs.size())) + , StateByReader_(settings->RawInputs.size()) , Spec_(&spec) , HolderFactory_(holderFactory) - , RawInputs_(std::move(rawInputs)) + , RawInputs_(std::move(settings->RawInputs)) , BlockSize_(blockSize) , Inflight_(inflight) - , Settings_(std::move(settings)) , TimerAwaiting_(RPCReaderAwaitingStallTime, 100) - , OriginalIndexes_(std::move(originalIndexes)) + , OriginalIndexes_(std::move(settings->OriginalIndexes)) + , 
Settings_(std::move(settings)) { #ifdef RPC_PRINT_TIME print_add(1); @@ -277,79 +197,14 @@ bool TParallelFileInputState::NextValue() { InnerState_->WaitPromise = NYT::NewPromise<void>(); } CurrentInput_ = result.Input_; - CurrentReader_ = MakeIntrusive<TFakeRPCReader>(std::move(result.Value_)); + CurrentReader_ = MakeIntrusive<TPayloadRPCReader>(std::move(result.Value_)); MkqlReader_.SetReader(*CurrentReader_, 1, BlockSize_, ui32(OriginalIndexes_[CurrentInput_]), true, StateByReader_[CurrentInput_].CurrentRow, StateByReader_[CurrentInput_].CurrentRange); MkqlReader_.Next(); } } void TDqYtReadWrapperRPC::MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const { - auto connectionConfig = NYT::New<NYT::NApi::NRpcProxy::TConnectionConfig>(); - connectionConfig->ClusterUrl = ClusterName; - connectionConfig->DefaultTotalStreamingTimeout = TDuration::MilliSeconds(Timeout); - auto connection = CreateConnection(connectionConfig); - auto clientOptions = NYT::NApi::TClientOptions(); - - if (Token) { - clientOptions.Token = Token; - } - - auto client = DynamicPointerCast<NYT::NApi::NRpcProxy::TClient>(connection->CreateClient(clientOptions)); - Y_VERIFY(client); - auto apiServiceProxy = client->CreateApiServiceProxy(); - - TVector<NYT::TFuture<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr>> waitFor; - TVector<size_t> originalIndexes; - size_t inputIdx = 0; - for (auto [richYPath, format]: Tables) { - if (richYPath.GetRanges() && richYPath.GetRanges()->empty()) { - ++inputIdx; - continue; - } - - originalIndexes.push_back(inputIdx++); - auto request = apiServiceProxy.ReadTable(); - client->InitStreamingRequest(*request); - request->ClientAttachmentsStreamingParameters().ReadTimeout = TDuration::MilliSeconds(Timeout); - - TString ppath; - auto tableYPath = ConvertYPathFromOld(richYPath); - - NYT::NYPath::ToProto(&ppath, tableYPath); - request->set_path(ppath); - request->set_desired_rowset_format(NYT::NApi::NRpcProxy::NProto::ERowsetFormat::RF_FORMAT); - - 
request->set_enable_table_index(true); - request->set_enable_range_index(true); - request->set_enable_row_index(true); - request->set_unordered(Inflight > 1); - - // https://a.yandex-team.ru/arcadia/yt/yt_proto/yt/client/api/rpc_proxy/proto/api_service.proto?rev=r11519304#L2338 - if (!SamplingSpec.IsUndefined()) { - TStringStream ss; - SamplingSpec.Save(&ss); - request->set_config(ss.Str()); - } - ConfigureTransaction(request, richYPath); - - // Get skiff format yson string - TStringStream fmt; - format.Config.Save(&fmt); - request->set_format(fmt.Str()); - - waitFor.emplace_back(std::move(CreateRpcClientInputStream(std::move(request)).ApplyUnique(BIND([](NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr&& stream) { - // first packet contains meta, skip it - return stream->Read().ApplyUnique(BIND([stream = std::move(stream)](NYT::TSharedRef&&) { - return std::move(stream); - })); - })))); - } - - TVector<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr> rawInputs; - NYT::NConcurrency::WaitFor(NYT::AllSucceeded(waitFor)).ValueOrThrow().swap(rawInputs); - state = ctx.HolderFactory.Create<TDqYtReadWrapperBase<TDqYtReadWrapperRPC, TParallelFileInputState>::TState>( - Specs, ctx.HolderFactory, std::move(rawInputs), 4_MB, Inflight, - std::make_unique<TSettingsHolder>(TSettingsHolder{std::move(connection), std::move(client)}), std::move(originalIndexes) - ); + auto settings = CreateInputStreams(false, Token, ClusterName, Timeout, Inflight > 1, Tables, SamplingSpec); + state = ctx.HolderFactory.Create<TDqYtReadWrapperBase<TDqYtReadWrapperRPC, TParallelFileInputState>::TState>(Specs, ctx.HolderFactory, 4_MB, Inflight, std::move(settings)); } } diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.h b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.h index 84d7a31b79a..0219ac56450 100644 --- a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.h +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.h @@ -22,11 +22,9 @@ struct 
TReaderState { public: TParallelFileInputState(const TMkqlIOSpecs& spec, const NKikimr::NMiniKQL::THolderFactory& holderFactory, - TVector<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr>&& rawInputs, size_t blockSize, size_t inflight, - std::unique_ptr<TSettingsHolder>&& client, - TVector<size_t>&& originalIndexes); + std::unique_ptr<TSettingsHolder>&& settings); size_t GetTableIndex() const; @@ -71,12 +69,11 @@ private: size_t CurrentRecord_ = 1; size_t Inflight_ = 1; bool Valid_ = true; - std::unique_ptr<TSettingsHolder> Settings_; NUdf::TUnboxedValue CurrentValue_; std::function<void()> OnNextBlockCallback_; NKikimr::NMiniKQL::TSamplingStatTimer TimerAwaiting_; TVector<size_t> OriginalIndexes_; - + std::unique_ptr<TSettingsHolder> Settings_; }; diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.cpp b/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.cpp new file mode 100644 index 00000000000..920575a5b84 --- /dev/null +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.cpp @@ -0,0 +1,1404 @@ +#include "stream_decoder.h" + +#include <algorithm> +#include <climits> +#include <cstdint> +#include <cstring> +#include <string> +#include <type_traits> +#include <utility> +#include <vector> + +#include <flatbuffers/flatbuffers.h> // IWYU pragma: export + +#include <arrow/array.h> +#include <arrow/buffer.h> +#include <arrow/extension_type.h> +#include <arrow/io/caching.h> +#include <arrow/io/interfaces.h> +#include <arrow/io/memory.h> +#include <arrow/ipc/message.h> +#include <arrow/ipc/metadata_internal.h> +#include <arrow/ipc/util.h> +#include <arrow/ipc/writer.h> +#include <arrow/record_batch.h> +#include <arrow/sparse_tensor.h> +#include <arrow/status.h> +#include <arrow/type.h> +#include <arrow/type_traits.h> +#include <arrow/util/bit_util.h> +#include <arrow/util/bitmap_ops.h> +#include <arrow/util/checked_cast.h> +#include <arrow/util/compression.h> +#include <arrow/util/endian.h> +#include <arrow/util/key_value_metadata.h> 
#include <arrow/util/parallel.h>
#include <arrow/util/string.h>
#include <arrow/util/thread_pool.h>
#include <arrow/util/ubsan.h>
#include <arrow/visitor_inline.h>

#include <generated/File.fbs.h>  // IWYU pragma: export
#include <generated/Message.fbs.h>
#include <generated/Schema.fbs.h>
#include <generated/SparseTensor.fbs.h>

namespace arrow {

namespace flatbuf = org::apache::arrow::flatbuf;

using internal::checked_cast;
using internal::checked_pointer_cast;
using internal::GetByteWidth;

namespace ipc {

using internal::FileBlock;
using internal::kArrowMagicBytes;

namespace NDqs {

// NOTE(review): this code appears to be a vendored copy of Apache Arrow's IPC
// message machinery, re-homed under the NDqs namespace (the "2" suffix on
// MessageDecoder2 avoids clashing with the upstream class) — confirm
// provenance and keep in sync with the upstream version when upgrading Arrow.

// Flatbuffers verification requires the message header to start on an 8-byte
// aligned address; if the buffer is misaligned, replace it with an aligned copy.
Status MaybeAlignMetadata(std::shared_ptr<Buffer>* metadata) {
  if (reinterpret_cast<uintptr_t>((*metadata)->data()) % 8 != 0) {
    // Allocate a new copy of the metadata; memory allocations are aligned.
    ARROW_ASSIGN_OR_RAISE(*metadata, (*metadata)->CopySlice(0, (*metadata)->size()));
  }
  return Status::OK();
}

// Verify the flatbuffer-encoded Message header and extract its bodyLength
// field. Returns IOError for a negative body length.
Status CheckMetadataAndGetBodyLength(const Buffer& metadata, int64_t* body_length) {
  const flatbuf::Message* fb_message = nullptr;
  RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &fb_message));
  *body_length = fb_message->bodyLength();
  if (*body_length < 0) {
    return Status::IOError("Invalid IPC message: negative bodyLength");
  }
  return Status::OK();
}

// Human-readable name of an IPC message type, used in error messages.
std::string FormatMessageType(MessageType type) {
  switch (type) {
    case MessageType::SCHEMA:
      return "schema";
    case MessageType::RECORD_BATCH:
      return "record batch";
    case MessageType::DICTIONARY_BATCH:
      return "dictionary";
    case MessageType::TENSOR:
      return "tensor";
    case MessageType::SPARSE_TENSOR:
      return "sparse tensor";
    default:
      break;
  }
  return "unknown";
}

// Write an encapsulated IPC message (prefix + flatbuffer + padding) to `file`.
// Layout: [4-byte continuation token (unless legacy format)] [4-byte LE size]
// [flatbuffer bytes] [zero padding up to options.alignment].
// On success *message_length holds the total padded length written.
Status WriteMessage(const Buffer& message, const IpcWriteOptions& options,
                    io::OutputStream* file, int32_t* message_length) {
  // The legacy (pre-0.15) format has no continuation token, so the prefix is
  // only the 4-byte size field.
  const int32_t prefix_size = options.write_legacy_ipc_format ? 4 : 8;
  const int32_t flatbuffer_size = static_cast<int32_t>(message.size());

  int32_t padded_message_length = static_cast<int32_t>(
      PaddedLength(flatbuffer_size + prefix_size, options.alignment));

  int32_t padding = padded_message_length - flatbuffer_size - prefix_size;

  // The returned message size includes the length prefix, the flatbuffer,
  // plus padding.
  *message_length = padded_message_length;

  // Write the flatbuffer size prefix including padding.
  if (!options.write_legacy_ipc_format) {
    RETURN_NOT_OK(file->Write(&internal::kIpcContinuationToken, sizeof(int32_t)));
  }

  // The size recorded in the prefix excludes the prefix itself.
  int32_t padded_flatbuffer_size =
      BitUtil::ToLittleEndian(padded_message_length - prefix_size);
  RETURN_NOT_OK(file->Write(&padded_flatbuffer_size, sizeof(int32_t)));

  // Write the flatbuffer, then pad out to the alignment boundary.
  RETURN_NOT_OK(file->Write(message.data(), flatbuffer_size));
  if (padding > 0) {
    RETURN_NOT_OK(file->Write(kPaddingBytes, padding));
  }

  return Status::OK();
}

// Bytes needed before the decoder can make progress from each starting state:
// both the continuation token and the metadata length are 4-byte ints.
static constexpr auto kMessageDecoderNextRequiredSizeInitial = sizeof(int32_t);
static constexpr auto kMessageDecoderNextRequiredSizeMetadataLength = sizeof(int32_t);

// Push-based IPC message decoder. Callers feed arbitrary-sized byte slices
// via ConsumeData/ConsumeBuffer; the impl runs the state machine
// INITIAL -> METADATA_LENGTH -> METADATA -> BODY -> INITIAL (or EOS) and
// notifies listener_ at each transition and on every fully decoded Message.
// Partial input is accumulated in chunks_ and stitched together on demand.
class MessageDecoder2::MessageDecoderImpl {
 public:
  explicit MessageDecoderImpl(std::shared_ptr<MessageDecoderListener> listener,
                              State initial_state, int64_t initial_next_required_size,
                              MemoryPool* pool)
      : listener_(std::move(listener)),
        pool_(pool),
        state_(initial_state),
        next_required_size_(initial_next_required_size),
        save_initial_size_(initial_next_required_size),
        chunks_(),
        buffered_size_(0),
        metadata_(nullptr) {}

  // Consume raw bytes the decoder does NOT own. Complete pieces are processed
  // in place; any trailing partial piece is copied into chunks_ (the caller's
  // buffer may not outlive this call).
  Status ConsumeData(const uint8_t* data, int64_t size) {
    // Fast path: nothing buffered, so process directly from `data` while a
    // whole state's worth of bytes is available.
    if (buffered_size_ == 0) {
      while (size > 0 && size >= next_required_size_) {
        auto used_size = next_required_size_;
        switch (state_) {
          case State::INITIAL:
            RETURN_NOT_OK(ConsumeInitialData(data, next_required_size_));
            break;
          case State::METADATA_LENGTH:
            RETURN_NOT_OK(ConsumeMetadataLengthData(data, next_required_size_));
            break;
          case State::METADATA: {
            // NOTE(review): this wraps `data` without copying; ConsumeMetadata
            // only keeps metadata_ until the body is consumed within this same
            // loop iteration chain — confirm no reference escapes the call.
            auto buffer = std::make_shared<Buffer>(data, next_required_size_);
            RETURN_NOT_OK(ConsumeMetadataBuffer(buffer));
          } break;
          case State::BODY: {
            auto buffer = std::make_shared<Buffer>(data, next_required_size_);
            RETURN_NOT_OK(ConsumeBodyBuffer(buffer));
          } break;
          case State::EOS:
            return Status::OK();
        }
        data += used_size;
        size -= used_size;
      }
    }

    if (size == 0) {
      return Status::OK();
    }

    // Leftover partial piece: copy it and retry via the chunked path.
    chunks_.push_back(std::make_shared<Buffer>(data, size));
    buffered_size_ += size;
    return ConsumeChunks();
  }

  // Consume a Buffer the decoder may keep (zero-copy where possible).
  Status ConsumeBuffer(std::shared_ptr<Buffer> buffer) {
    // Fast path: process directly from `buffer`, slicing off each consumed
    // prefix, while it still holds a whole state's worth of bytes.
    if (buffered_size_ == 0) {
      while (buffer->size() >= next_required_size_) {
        auto used_size = next_required_size_;
        switch (state_) {
          case State::INITIAL:
            RETURN_NOT_OK(ConsumeInitialBuffer(buffer));
            break;
          case State::METADATA_LENGTH:
            RETURN_NOT_OK(ConsumeMetadataLengthBuffer(buffer));
            break;
          case State::METADATA:
            if (buffer->size() == next_required_size_) {
              return ConsumeMetadataBuffer(buffer);
            } else {
              auto sliced_buffer = SliceBuffer(buffer, 0, next_required_size_);
              RETURN_NOT_OK(ConsumeMetadataBuffer(sliced_buffer));
            }
            break;
          case State::BODY:
            if (buffer->size() == next_required_size_) {
              return ConsumeBodyBuffer(buffer);
            } else {
              auto sliced_buffer = SliceBuffer(buffer, 0, next_required_size_);
              RETURN_NOT_OK(ConsumeBodyBuffer(sliced_buffer));
            }
            break;
          case State::EOS:
            return Status::OK();
        }
        if (buffer->size() == used_size) {
          return Status::OK();
        }
        buffer = SliceBuffer(buffer, used_size);
      }
    }

    if (buffer->size() == 0) {
      return Status::OK();
    }

    // Leftover partial piece: keep the buffer (no copy) and retry via the
    // chunked path.
    buffered_size_ += buffer->size();
    chunks_.push_back(std::move(buffer));
    return ConsumeChunks();
  }

  // Additional bytes needed before the next state transition can complete.
  int64_t next_required_size() const { return next_required_size_ - buffered_size_; }

  MessageDecoder2::State state() const { return state_; }

  // Restore the decoder to its initial state, dropping buffered input.
  void Reset() {
    state_ = State::INITIAL;
    next_required_size_ = save_initial_size_;
    chunks_.clear();
    buffered_size_ = 0;
    metadata_ = nullptr;
  }

 private:
  // Drive the state machine over the accumulated chunks_ until more input is
  // needed or EOS is reached.
  Status ConsumeChunks() {
    while (state_ != State::EOS) {
      if (buffered_size_ < next_required_size_) {
        return Status::OK();
      }

      switch (state_) {
        case State::INITIAL:
          RETURN_NOT_OK(ConsumeInitialChunks());
          break;
        case State::METADATA_LENGTH:
          RETURN_NOT_OK(ConsumeMetadataLengthChunks());
          break;
        case State::METADATA:
          RETURN_NOT_OK(ConsumeMetadataChunks());
          break;
        case State::BODY:
          RETURN_NOT_OK(ConsumeBodyChunks());
          break;
        case State::EOS:
          return Status::OK();
      }
    }

    return Status::OK();
  }

  Status ConsumeInitialData(const uint8_t* data, int64_t) {
    return ConsumeInitial(BitUtil::FromLittleEndian(util::SafeLoadAs<int32_t>(data)));
  }

  Status ConsumeInitialBuffer(const std::shared_ptr<Buffer>& buffer) {
    ARROW_ASSIGN_OR_RAISE(auto continuation, ConsumeDataBufferInt32(buffer));
    return ConsumeInitial(BitUtil::FromLittleEndian(continuation));
  }

  Status ConsumeInitialChunks() {
    int32_t continuation = 0;
    RETURN_NOT_OK(ConsumeDataChunks(sizeof(int32_t), &continuation));
    return ConsumeInitial(BitUtil::FromLittleEndian(continuation));
  }

  // Interpret the first 4 bytes of a message: the continuation token means a
  // length field follows (modern format); 0 means end-of-stream; a positive
  // value is a legacy-format metadata length itself.
  Status ConsumeInitial(int32_t continuation) {
    if (continuation == internal::kIpcContinuationToken) {
      state_ = State::METADATA_LENGTH;
      next_required_size_ = kMessageDecoderNextRequiredSizeMetadataLength;
      RETURN_NOT_OK(listener_->OnMetadataLength());
      // Valid IPC message, read the message length now
      return Status::OK();
    } else if (continuation == 0) {
      state_ = State::EOS;
      next_required_size_ = 0;
      RETURN_NOT_OK(listener_->OnEOS());
      return Status::OK();
    } else if (continuation > 0) {
      state_ = State::METADATA;
      next_required_size_ = continuation;
      RETURN_NOT_OK(listener_->OnMetadata());
      return Status::OK();
    } else {
      return Status::IOError("Invalid IPC stream: negative continuation token");
    }
  }

  Status ConsumeMetadataLengthData(const uint8_t* data, int64_t) {
    return ConsumeMetadataLength(
        BitUtil::FromLittleEndian(util::SafeLoadAs<int32_t>(data)));
  }

  Status ConsumeMetadataLengthBuffer(const std::shared_ptr<Buffer>& buffer) {
    ARROW_ASSIGN_OR_RAISE(auto metadata_length, ConsumeDataBufferInt32(buffer));
    return ConsumeMetadataLength(BitUtil::FromLittleEndian(metadata_length));
  }

  Status ConsumeMetadataLengthChunks() {
    int32_t metadata_length = 0;
    RETURN_NOT_OK(ConsumeDataChunks(sizeof(int32_t), &metadata_length));
    return ConsumeMetadataLength(BitUtil::FromLittleEndian(metadata_length));
  }

  // A zero metadata length is the modern end-of-stream marker.
  Status ConsumeMetadataLength(int32_t metadata_length) {
    if (metadata_length == 0) {
      state_ = State::EOS;
      next_required_size_ = 0;
      RETURN_NOT_OK(listener_->OnEOS());
      return Status::OK();
    } else if (metadata_length > 0) {
      state_ = State::METADATA;
      next_required_size_ = metadata_length;
      RETURN_NOT_OK(listener_->OnMetadata());
      return Status::OK();
    } else {
      return Status::IOError("Invalid IPC message: negative metadata length");
    }
  }

  // metadata_ must end up on CPU-accessible memory (flatbuffers verification
  // reads it directly), hence the ViewOrCopy for device buffers.
  Status ConsumeMetadataBuffer(const std::shared_ptr<Buffer>& buffer) {
    if (buffer->is_cpu()) {
      metadata_ = buffer;
    } else {
      ARROW_ASSIGN_OR_RAISE(metadata_,
                            Buffer::ViewOrCopy(buffer, CPUDevice::memory_manager(pool_)));
    }
    return ConsumeMetadata();
  }

  Status ConsumeMetadataChunks() {
    // If the first chunk alone covers the metadata, take (a slice of) it
    // zero-copy; otherwise allocate and gather across chunks.
    if (chunks_[0]->size() >= next_required_size_) {
      if (chunks_[0]->size() == next_required_size_) {
        if (chunks_[0]->is_cpu()) {
          metadata_ = std::move(chunks_[0]);
        } else {
          ARROW_ASSIGN_OR_RAISE(
              metadata_,
              Buffer::ViewOrCopy(chunks_[0], CPUDevice::memory_manager(pool_)));
        }
        chunks_.erase(chunks_.begin());
      } else {
        metadata_ = SliceBuffer(chunks_[0], 0, next_required_size_);
        if (!chunks_[0]->is_cpu()) {
          ARROW_ASSIGN_OR_RAISE(
              metadata_, Buffer::ViewOrCopy(metadata_, CPUDevice::memory_manager(pool_)));
        }
        chunks_[0] = SliceBuffer(chunks_[0], next_required_size_);
      }
      buffered_size_ -= next_required_size_;
    } else {
      ARROW_ASSIGN_OR_RAISE(auto metadata, AllocateBuffer(next_required_size_, pool_));
      metadata_ = std::shared_ptr<Buffer>(metadata.release());
      RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, metadata_->mutable_data()));
    }
    return ConsumeMetadata();
  }

  // Parse the metadata just collected, learn the body length, and move to the
  // BODY state. A zero-length body completes the message immediately.
  Status ConsumeMetadata() {
    RETURN_NOT_OK(MaybeAlignMetadata(&metadata_));
    int64_t body_length = -1;
    RETURN_NOT_OK(CheckMetadataAndGetBodyLength(*metadata_, &body_length));

    state_ = State::BODY;
    next_required_size_ = body_length;
    RETURN_NOT_OK(listener_->OnBody());
    if (next_required_size_ == 0) {
      ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(0, pool_));
      std::shared_ptr<Buffer> shared_body(body.release());
      return ConsumeBody(&shared_body);
    } else {
      return Status::OK();
    }
  }

  Status ConsumeBodyBuffer(std::shared_ptr<Buffer> buffer) {
    return ConsumeBody(&buffer);
  }

  Status ConsumeBodyChunks() {
    // Same zero-copy-vs-gather strategy as ConsumeMetadataChunks, but the body
    // may stay on non-CPU memory.
    if (chunks_[0]->size() >= next_required_size_) {
      auto used_size = next_required_size_;
      if (chunks_[0]->size() == next_required_size_) {
        RETURN_NOT_OK(ConsumeBody(&chunks_[0]));
        chunks_.erase(chunks_.begin());
      } else {
        auto body = SliceBuffer(chunks_[0], 0, next_required_size_);
        RETURN_NOT_OK(ConsumeBody(&body));
        chunks_[0] = SliceBuffer(chunks_[0], used_size);
      }
      buffered_size_ -= used_size;
      return Status::OK();
    } else {
      ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(next_required_size_, pool_));
      RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, body->mutable_data()));
      std::shared_ptr<Buffer> shared_body(body.release());
      return ConsumeBody(&shared_body);
    }
  }

  // Assemble metadata_ + body into a Message, hand it to the listener, and
  // rewind the state machine for the next message.
  Status ConsumeBody(std::shared_ptr<Buffer>* buffer) {
    ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message,
                          Message::Open(metadata_, *buffer));

    RETURN_NOT_OK(listener_->OnMessageDecoded(std::move(message)));
    state_ = State::INITIAL;
    next_required_size_ = kMessageDecoderNextRequiredSizeInitial;
    RETURN_NOT_OK(listener_->OnInitial());
    return Status::OK();
  }

  // Read a little 4-byte value out of a buffer, viewing/copying to CPU first
  // if it lives on a device.
  Result<int32_t> ConsumeDataBufferInt32(const std::shared_ptr<Buffer>& buffer) {
    if (buffer->is_cpu()) {
      return util::SafeLoadAs<int32_t>(buffer->data());
    } else {
      ARROW_ASSIGN_OR_RAISE(auto cpu_buffer,
                            Buffer::ViewOrCopy(buffer, CPUDevice::memory_manager(pool_)));
      return util::SafeLoadAs<int32_t>(cpu_buffer->data());
    }
  }

  // Copy `nbytes` out of the front of chunks_ into `out`, consuming whole
  // chunks and re-slicing the last partially-used one back onto the queue.
  // Precondition (callers ensure): buffered_size_ >= nbytes.
  Status ConsumeDataChunks(int64_t nbytes, void* out) {
    size_t offset = 0;
    size_t n_used_chunks = 0;
    auto required_size = nbytes;
    std::shared_ptr<Buffer> last_chunk;
    for (auto& chunk : chunks_) {
      if (!chunk->is_cpu()) {
        ARROW_ASSIGN_OR_RAISE(
            chunk, Buffer::ViewOrCopy(chunk, CPUDevice::memory_manager(pool_)));
      }
      auto data = chunk->data();
      auto data_size = chunk->size();
      auto copy_size = std::min(required_size, data_size);
      memcpy(static_cast<uint8_t*>(out) + offset, data, copy_size);
      n_used_chunks++;
      offset += copy_size;
      required_size -= copy_size;
      if (required_size == 0) {
        // Keep the unread tail of the final chunk.
        if (data_size != copy_size) {
          last_chunk = SliceBuffer(chunk, copy_size);
        }
        break;
      }
    }
    chunks_.erase(chunks_.begin(), chunks_.begin() + n_used_chunks);
    if (last_chunk.get() != nullptr) {
      chunks_.insert(chunks_.begin(), std::move(last_chunk));
    }
    buffered_size_ -= offset;
    return Status::OK();
  }

  std::shared_ptr<MessageDecoderListener> listener_;
  MemoryPool* pool_;
  State state_;
  // next_required_size_ is the total bytes the current state needs;
  // save_initial_size_ preserves the construction-time value for Reset().
  int64_t next_required_size_, save_initial_size_;
  // Queue of not-yet-consumed input pieces; buffered_size_ is their total.
  std::vector<std::shared_ptr<Buffer>> chunks_;
  int64_t buffered_size_;
  std::shared_ptr<Buffer> metadata_;  // Must be CPU buffer
};

MessageDecoder2::MessageDecoder2(std::shared_ptr<MessageDecoderListener> listener,
                                 MemoryPool* pool) {
  impl_.reset(new MessageDecoderImpl(std::move(listener), State::INITIAL,
                                     kMessageDecoderNextRequiredSizeInitial, pool));
}

MessageDecoder2::MessageDecoder2(std::shared_ptr<MessageDecoderListener> listener,
                                 State initial_state, int64_t initial_next_required_size,
                                 MemoryPool* pool) {
  impl_.reset(new MessageDecoderImpl(std::move(listener), initial_state,
                                     initial_next_required_size, pool));
}

MessageDecoder2::~MessageDecoder2() {}

+Status MessageDecoder2::Consume(const uint8_t* data, int64_t size) { + return impl_->ConsumeData(data, size); +} + +void MessageDecoder2::Reset() { + impl_->Reset(); +} + +Status MessageDecoder2::Consume(std::shared_ptr<Buffer> buffer) { + return impl_->ConsumeBuffer(buffer); +} + +int64_t MessageDecoder2::next_required_size() const { return impl_->next_required_size(); } + +MessageDecoder2::State MessageDecoder2::state() const { return impl_->state(); } + +enum class DictionaryKind { New, Delta, Replacement }; + +Status InvalidMessageType(MessageType expected, MessageType actual) { + return Status::IOError("Expected IPC message of type ", ::arrow::ipc::FormatMessageType(expected), + " but got ", ::arrow::ipc::FormatMessageType(actual)); +} + +#define CHECK_MESSAGE_TYPE(expected, actual) \ + do { \ + if ((actual) != (expected)) { \ + return InvalidMessageType((expected), (actual)); \ + } \ + } while (0) + +#define CHECK_HAS_BODY(message) \ + do { \ + if ((message).body() == nullptr) { \ + return Status::IOError("Expected body in IPC message of type ", \ + ::arrow::ipc::FormatMessageType((message).type())); \ + } \ + } while (0) + +#define CHECK_HAS_NO_BODY(message) \ + do { \ + if ((message).body_length() != 0) { \ + return Status::IOError("Unexpected body in IPC message of type ", \ + ::arrow::ipc::FormatMessageType((message).type())); \ + } \ + } while (0) +struct IpcReadContext { + IpcReadContext(DictionaryMemo* memo, const IpcReadOptions& option, bool swap, + MetadataVersion version = MetadataVersion::V5, + Compression::type kind = Compression::UNCOMPRESSED) + : dictionary_memo(memo), + options(option), + metadata_version(version), + compression(kind), + swap_endian(swap) {} + + DictionaryMemo* dictionary_memo; + + const IpcReadOptions& options; + + MetadataVersion metadata_version; + + Compression::type compression; + + const bool swap_endian; +}; + + + +Result<std::shared_ptr<Buffer>> DecompressBuffer(const std::shared_ptr<Buffer>& buf, + const 
IpcReadOptions& options, + util::Codec* codec) { + if (buf == nullptr || buf->size() == 0) { + return buf; + } + + if (buf->size() < 8) { + return Status::Invalid( + "Likely corrupted message, compressed buffers " + "are larger than 8 bytes by construction"); + } + + const uint8_t* data = buf->data(); + int64_t compressed_size = buf->size() - sizeof(int64_t); + int64_t uncompressed_size = BitUtil::FromLittleEndian(util::SafeLoadAs<int64_t>(data)); + + ARROW_ASSIGN_OR_RAISE(auto uncompressed, + AllocateBuffer(uncompressed_size, options.memory_pool)); + + ARROW_ASSIGN_OR_RAISE( + int64_t actual_decompressed, + codec->Decompress(compressed_size, data + sizeof(int64_t), uncompressed_size, + uncompressed->mutable_data())); + if (actual_decompressed != uncompressed_size) { + return Status::Invalid("Failed to fully decompress buffer, expected ", + uncompressed_size, " bytes but decompressed ", + actual_decompressed); + } + + return std::move(uncompressed); +} + +Status DecompressBuffers(Compression::type compression, const IpcReadOptions& options, + ArrayDataVector* fields) { + struct BufferAccumulator { + using BufferPtrVector = std::vector<std::shared_ptr<Buffer>*>; + + void AppendFrom(const ArrayDataVector& fields) { + for (const auto& field : fields) { + for (auto& buffer : field->buffers) { + buffers_.push_back(&buffer); + } + AppendFrom(field->child_data); + } + } + + BufferPtrVector Get(const ArrayDataVector& fields) && { + AppendFrom(fields); + return std::move(buffers_); + } + + BufferPtrVector buffers_; + }; + + auto buffers = BufferAccumulator{}.Get(*fields); + + std::unique_ptr<util::Codec> codec; + ARROW_ASSIGN_OR_RAISE(codec, util::Codec::Create(compression)); + + return ::arrow::internal::OptionalParallelFor( + options.use_threads, static_cast<int>(buffers.size()), [&](int i) { + ARROW_ASSIGN_OR_RAISE(*buffers[i], + DecompressBuffer(*buffers[i], options, codec.get())); + return Status::OK(); + }); +} +class ArrayLoader { + public: + explicit 
ArrayLoader(const flatbuf::RecordBatch* metadata, + MetadataVersion metadata_version, const IpcReadOptions& options, + io::RandomAccessFile* file) + : metadata_(metadata), + metadata_version_(metadata_version), + file_(file), + max_recursion_depth_(options.max_recursion_depth) {} + + Status ReadBuffer(int64_t offset, int64_t length, std::shared_ptr<Buffer>* out) { + if (skip_io_) { + return Status::OK(); + } + if (offset < 0) { + return Status::Invalid("Negative offset for reading buffer ", buffer_index_); + } + if (length < 0) { + return Status::Invalid("Negative length for reading buffer ", buffer_index_); + } + if (!BitUtil::IsMultipleOf8(offset)) { + return Status::Invalid("Buffer ", buffer_index_, + " did not start on 8-byte aligned offset: ", offset); + } + return file_->ReadAt(offset, length).Value(out); + } + + Status LoadType(const DataType& type) { return VisitTypeInline(type, this); } + + Status Load(const Field* field, ArrayData* out) { + if (max_recursion_depth_ <= 0) { + return Status::Invalid("Max recursion depth reached"); + } + + field_ = field; + out_ = out; + out_->type = field_->type(); + return LoadType(*field_->type()); + } + + Status SkipField(const Field* field) { + ArrayData dummy; + skip_io_ = true; + Status status = Load(field, &dummy); + skip_io_ = false; + return status; + } + + Status GetBuffer(int buffer_index, std::shared_ptr<Buffer>* out) { + auto buffers = metadata_->buffers(); + CHECK_FLATBUFFERS_NOT_NULL(buffers, "RecordBatch.buffers"); + if (buffer_index >= static_cast<int>(buffers->size())) { + return Status::IOError("buffer_index out of range."); + } + const flatbuf::Buffer* buffer = buffers->Get(buffer_index); + if (buffer->length() == 0) { + // Should never return a null buffer here. 
+ // (zero-sized buffer allocations are cheap) + return AllocateBuffer(0).Value(out); + } else { + return ReadBuffer(buffer->offset(), buffer->length(), out); + } + } + + Status GetFieldMetadata(int field_index, ArrayData* out) { + auto nodes = metadata_->nodes(); + CHECK_FLATBUFFERS_NOT_NULL(nodes, "Table.nodes"); + if (field_index >= static_cast<int>(nodes->size())) { + return Status::Invalid("Ran out of field metadata, likely malformed"); + } + const flatbuf::FieldNode* node = nodes->Get(field_index); + + out->length = node->length(); + out->null_count = node->null_count(); + out->offset = 0; + return Status::OK(); + } + + Status LoadCommon(Type::type type_id) { + RETURN_NOT_OK(GetFieldMetadata(field_index_++, out_)); + + if (internal::HasValidityBitmap(type_id, metadata_version_)) { + // Extract null_bitmap which is common to all arrays except for unions + // and nulls. + if (out_->null_count != 0) { + RETURN_NOT_OK(GetBuffer(buffer_index_, &out_->buffers[0])); + } + buffer_index_++; + } + return Status::OK(); + } + + template <typename TYPE> + Status LoadPrimitive(Type::type type_id) { + out_->buffers.resize(2); + + RETURN_NOT_OK(LoadCommon(type_id)); + if (out_->length > 0) { + RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1])); + } else { + buffer_index_++; + out_->buffers[1].reset(new Buffer(nullptr, 0)); + } + return Status::OK(); + } + + template <typename TYPE> + Status LoadBinary(Type::type type_id) { + out_->buffers.resize(3); + + RETURN_NOT_OK(LoadCommon(type_id)); + RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1])); + return GetBuffer(buffer_index_++, &out_->buffers[2]); + } + + template <typename TYPE> + Status LoadList(const TYPE& type) { + out_->buffers.resize(2); + + RETURN_NOT_OK(LoadCommon(type.id())); + RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1])); + + const int num_children = type.num_fields(); + if (num_children != 1) { + return Status::Invalid("Wrong number of children: ", num_children); + } + + return 
LoadChildren(type.fields()); + } + + Status LoadChildren(const std::vector<std::shared_ptr<Field>>& child_fields) { + ArrayData* parent = out_; + + parent->child_data.resize(child_fields.size()); + for (int i = 0; i < static_cast<int>(child_fields.size()); ++i) { + parent->child_data[i] = std::make_shared<ArrayData>(); + --max_recursion_depth_; + RETURN_NOT_OK(Load(child_fields[i].get(), parent->child_data[i].get())); + ++max_recursion_depth_; + } + out_ = parent; + return Status::OK(); + } + + Status Visit(const NullType&) { + out_->buffers.resize(1); + + return GetFieldMetadata(field_index_++, out_); + } + + template <typename T> + enable_if_t<std::is_base_of<FixedWidthType, T>::value && + !std::is_base_of<FixedSizeBinaryType, T>::value && + !std::is_base_of<DictionaryType, T>::value, + Status> + Visit(const T& type) { + return LoadPrimitive<T>(type.id()); + } + + template <typename T> + enable_if_base_binary<T, Status> Visit(const T& type) { + return LoadBinary<T>(type.id()); + } + + Status Visit(const FixedSizeBinaryType& type) { + out_->buffers.resize(2); + RETURN_NOT_OK(LoadCommon(type.id())); + return GetBuffer(buffer_index_++, &out_->buffers[1]); + } + + template <typename T> + enable_if_var_size_list<T, Status> Visit(const T& type) { + return LoadList(type); + } + + Status Visit(const MapType& type) { + RETURN_NOT_OK(LoadList(type)); + return MapArray::ValidateChildData(out_->child_data); + } + + Status Visit(const FixedSizeListType& type) { + out_->buffers.resize(1); + + RETURN_NOT_OK(LoadCommon(type.id())); + + const int num_children = type.num_fields(); + if (num_children != 1) { + return Status::Invalid("Wrong number of children: ", num_children); + } + + return LoadChildren(type.fields()); + } + + Status Visit(const StructType& type) { + out_->buffers.resize(1); + RETURN_NOT_OK(LoadCommon(type.id())); + return LoadChildren(type.fields()); + } + + Status Visit(const UnionType& type) { + int n_buffers = type.mode() == UnionMode::SPARSE ? 
2 : 3; + out_->buffers.resize(n_buffers); + + RETURN_NOT_OK(LoadCommon(type.id())); + + if (out_->null_count != 0 && out_->buffers[0] != nullptr) { + return Status::Invalid( + "Cannot read pre-1.0.0 Union array with top-level validity bitmap"); + } + out_->buffers[0] = nullptr; + out_->null_count = 0; + + if (out_->length > 0) { + RETURN_NOT_OK(GetBuffer(buffer_index_, &out_->buffers[1])); + if (type.mode() == UnionMode::DENSE) { + RETURN_NOT_OK(GetBuffer(buffer_index_ + 1, &out_->buffers[2])); + } + } + buffer_index_ += n_buffers - 1; + return LoadChildren(type.fields()); + } + + Status Visit(const DictionaryType& type) { + // out_->dictionary will be filled later in ResolveDictionaries() + return LoadType(*type.index_type()); + } + + Status Visit(const ExtensionType& type) { return LoadType(*type.storage_type()); } + + private: + const flatbuf::RecordBatch* metadata_; + const MetadataVersion metadata_version_; + io::RandomAccessFile* file_; + int max_recursion_depth_; + int buffer_index_ = 0; + int field_index_ = 0; + bool skip_io_ = false; + + const Field* field_; + ArrayData* out_; +}; + +Result<std::shared_ptr<RecordBatch>> LoadRecordBatchSubset( + const flatbuf::RecordBatch* metadata, const std::shared_ptr<Schema>& schema, + const std::vector<bool>* inclusion_mask, const IpcReadContext& context, + io::RandomAccessFile* file) { + ArrayLoader loader(metadata, context.metadata_version, context.options, file); + + ArrayDataVector columns(schema->num_fields()); + ArrayDataVector filtered_columns; + FieldVector filtered_fields; + std::shared_ptr<Schema> filtered_schema; + + for (int i = 0; i < schema->num_fields(); ++i) { + const Field& field = *schema->field(i); + if (!inclusion_mask || (*inclusion_mask)[i]) { + // Read field + auto column = std::make_shared<ArrayData>(); + RETURN_NOT_OK(loader.Load(&field, column.get())); + if (metadata->length() != column->length) { + return Status::IOError("Array length did not match record batch length"); + } + columns[i] = 
std::move(column); + if (inclusion_mask) { + filtered_columns.push_back(columns[i]); + filtered_fields.push_back(schema->field(i)); + } + } else { + RETURN_NOT_OK(loader.SkipField(&field)); + } + } + + RETURN_NOT_OK(ResolveDictionaries(columns, *context.dictionary_memo, + context.options.memory_pool)); + + if (inclusion_mask) { + filtered_schema = ::arrow::schema(std::move(filtered_fields), schema->metadata()); + columns.clear(); + } else { + filtered_schema = schema; + filtered_columns = std::move(columns); + } + if (context.compression != Compression::UNCOMPRESSED) { + RETURN_NOT_OK( + DecompressBuffers(context.compression, context.options, &filtered_columns)); + } + + if (context.swap_endian) { + for (int i = 0; i < static_cast<int>(filtered_columns.size()); ++i) { + ARROW_ASSIGN_OR_RAISE(filtered_columns[i], + arrow::internal::SwapEndianArrayData(filtered_columns[i])); + } + } + return RecordBatch::Make(std::move(filtered_schema), metadata->length(), + std::move(filtered_columns)); +} + +Result<std::shared_ptr<RecordBatch>> LoadRecordBatch( + const flatbuf::RecordBatch* metadata, const std::shared_ptr<Schema>& schema, + const std::vector<bool>& inclusion_mask, const IpcReadContext& context, + io::RandomAccessFile* file) { + if (inclusion_mask.size() > 0) { + return LoadRecordBatchSubset(metadata, schema, &inclusion_mask, context, file); + } else { + return LoadRecordBatchSubset(metadata, schema, /*param_name=*/nullptr, context, file); + } +} + +Status GetCompression(const flatbuf::RecordBatch* batch, Compression::type* out) { + *out = Compression::UNCOMPRESSED; + const flatbuf::BodyCompression* compression = batch->compression(); + if (compression != nullptr) { + if (compression->method() != flatbuf::BodyCompressionMethod::BUFFER) { + return Status::Invalid("This library only supports BUFFER compression method"); + } + + if (compression->codec() == flatbuf::CompressionType::LZ4_FRAME) { + *out = Compression::LZ4_FRAME; + } else if (compression->codec() == 
flatbuf::CompressionType::ZSTD) { + *out = Compression::ZSTD; + } else { + return Status::Invalid("Unsupported codec in RecordBatch::compression metadata"); + } + return Status::OK(); + } + return Status::OK(); +} + +Status GetCompressionExperimental(const flatbuf::Message* message, + Compression::type* out) { + *out = Compression::UNCOMPRESSED; + if (message->custom_metadata() != nullptr) { + std::shared_ptr<KeyValueMetadata> metadata; + RETURN_NOT_OK(internal::GetKeyValueMetadata(message->custom_metadata(), &metadata)); + int index = metadata->FindKey("ARROW:experimental_compression"); + if (index != -1) { + auto name = arrow::internal::AsciiToLower(metadata->value(index)); + ARROW_ASSIGN_OR_RAISE(*out, util::Codec::GetCompressionType(name)); + } + return internal::CheckCompressionSupported(*out); + } + return Status::OK(); +} + +static Status ReadContiguousPayload(io::InputStream* file, + std::unique_ptr<Message>* message) { + ARROW_ASSIGN_OR_RAISE(*message, ReadMessage(file)); + if (*message == nullptr) { + return Status::Invalid("Unable to read metadata at offset"); + } + return Status::OK(); +} + +Result<std::shared_ptr<RecordBatch>> ReadRecordBatch( + const std::shared_ptr<Schema>& schema, const DictionaryMemo* dictionary_memo, + const IpcReadOptions& options, io::InputStream* file) { + std::unique_ptr<Message> message; + RETURN_NOT_OK(ReadContiguousPayload(file, &message)); + CHECK_HAS_BODY(*message); + ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); + return ReadRecordBatch(*message->metadata(), schema, dictionary_memo, options, + reader.get()); +} + +Result<std::shared_ptr<RecordBatch>> ReadRecordBatch( + const Message& message, const std::shared_ptr<Schema>& schema, + const DictionaryMemo* dictionary_memo, const IpcReadOptions& options) { + CHECK_MESSAGE_TYPE(MessageType::RECORD_BATCH, message.type()); + CHECK_HAS_BODY(message); + ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body())); + return 
ReadRecordBatch(*message.metadata(), schema, dictionary_memo, options, + reader.get()); +} + +Result<std::shared_ptr<RecordBatch>> ReadRecordBatchInternal( + const Buffer& metadata, const std::shared_ptr<Schema>& schema, + const std::vector<bool>& inclusion_mask, IpcReadContext& context, + io::RandomAccessFile* file) { + const flatbuf::Message* message = nullptr; + RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message)); + auto batch = message->header_as_RecordBatch(); + if (batch == nullptr) { + return Status::IOError( + "Header-type of flatbuffer-encoded Message is not RecordBatch."); + } + + Compression::type compression; + RETURN_NOT_OK(GetCompression(batch, &compression)); + if (context.compression == Compression::UNCOMPRESSED && + message->version() == flatbuf::MetadataVersion::V4) { + RETURN_NOT_OK(GetCompressionExperimental(message, &compression)); + } + context.compression = compression; + context.metadata_version = internal::GetMetadataVersion(message->version()); + return LoadRecordBatch(batch, schema, inclusion_mask, context, file); +} + +Status GetInclusionMaskAndOutSchema(const std::shared_ptr<Schema>& full_schema, + const std::vector<int>& included_indices, + std::vector<bool>* inclusion_mask, + std::shared_ptr<Schema>* out_schema) { + inclusion_mask->clear(); + if (included_indices.empty()) { + *out_schema = full_schema; + return Status::OK(); + } + + inclusion_mask->resize(full_schema->num_fields(), false); + + auto included_indices_sorted = included_indices; + std::sort(included_indices_sorted.begin(), included_indices_sorted.end()); + + FieldVector included_fields; + for (int i : included_indices_sorted) { + if (i < 0 || i >= full_schema->num_fields()) { + return Status::Invalid("Out of bounds field index: ", i); + } + + if (inclusion_mask->at(i)) continue; + + inclusion_mask->at(i) = true; + included_fields.push_back(full_schema->field(i)); + } + + *out_schema = schema(std::move(included_fields), 
full_schema->endianness(), + full_schema->metadata()); + return Status::OK(); +} + +Status UnpackSchemaMessage(const void* opaque_schema, const IpcReadOptions& options, + DictionaryMemo* dictionary_memo, + std::shared_ptr<Schema>* schema, + std::shared_ptr<Schema>* out_schema, + std::vector<bool>* field_inclusion_mask, bool* swap_endian) { + RETURN_NOT_OK(internal::GetSchema(opaque_schema, dictionary_memo, schema)); + + RETURN_NOT_OK(GetInclusionMaskAndOutSchema(*schema, options.included_fields, + field_inclusion_mask, out_schema)); + *swap_endian = options.ensure_native_endian && !out_schema->get()->is_native_endian(); + if (*swap_endian) { + *schema = schema->get()->WithEndianness(Endianness::Native); + *out_schema = out_schema->get()->WithEndianness(Endianness::Native); + } + return Status::OK(); +} + +Status UnpackSchemaMessage(const Message& message, const IpcReadOptions& options, + DictionaryMemo* dictionary_memo, + std::shared_ptr<Schema>* schema, + std::shared_ptr<Schema>* out_schema, + std::vector<bool>* field_inclusion_mask, bool* swap_endian) { + CHECK_MESSAGE_TYPE(MessageType::SCHEMA, message.type()); + CHECK_HAS_NO_BODY(message); + + return UnpackSchemaMessage(message.header(), options, dictionary_memo, schema, + out_schema, field_inclusion_mask, swap_endian); +} + +Result<std::shared_ptr<RecordBatch>> ReadRecordBatch( + const Buffer& metadata, const std::shared_ptr<Schema>& schema, + const DictionaryMemo* dictionary_memo, const IpcReadOptions& options, + io::RandomAccessFile* file) { + std::shared_ptr<Schema> out_schema; + // Empty means do not use + std::vector<bool> inclusion_mask; + IpcReadContext context(const_cast<DictionaryMemo*>(dictionary_memo), options, false); + RETURN_NOT_OK(GetInclusionMaskAndOutSchema(schema, context.options.included_fields, + &inclusion_mask, &out_schema)); + return ReadRecordBatchInternal(metadata, schema, inclusion_mask, context, file); +} + +Status ReadDictionary(const Buffer& metadata, const IpcReadContext& context, 
+ DictionaryKind* kind, io::RandomAccessFile* file) { + const flatbuf::Message* message = nullptr; + RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message)); + const auto dictionary_batch = message->header_as_DictionaryBatch(); + if (dictionary_batch == nullptr) { + return Status::IOError( + "Header-type of flatbuffer-encoded Message is not DictionaryBatch."); + } + + // The dictionary is embedded in a record batch with a single column + const auto batch_meta = dictionary_batch->data(); + + CHECK_FLATBUFFERS_NOT_NULL(batch_meta, "DictionaryBatch.data"); + + Compression::type compression; + RETURN_NOT_OK(GetCompression(batch_meta, &compression)); + if (compression == Compression::UNCOMPRESSED && + message->version() == flatbuf::MetadataVersion::V4) { + RETURN_NOT_OK(GetCompressionExperimental(message, &compression)); + } + + const int64_t id = dictionary_batch->id(); + + ARROW_ASSIGN_OR_RAISE(auto value_type, context.dictionary_memo->GetDictionaryType(id)); + + ArrayLoader loader(batch_meta, internal::GetMetadataVersion(message->version()), + context.options, file); + auto dict_data = std::make_shared<ArrayData>(); + const Field dummy_field("", value_type); + RETURN_NOT_OK(loader.Load(&dummy_field, dict_data.get())); + + if (compression != Compression::UNCOMPRESSED) { + ArrayDataVector dict_fields{dict_data}; + RETURN_NOT_OK(DecompressBuffers(compression, context.options, &dict_fields)); + } + + if (context.swap_endian) { + ARROW_ASSIGN_OR_RAISE(dict_data, ::arrow::internal::SwapEndianArrayData(dict_data)); + } + + if (dictionary_batch->isDelta()) { + if (kind != nullptr) { + *kind = DictionaryKind::Delta; + } + return context.dictionary_memo->AddDictionaryDelta(id, dict_data); + } + ARROW_ASSIGN_OR_RAISE(bool inserted, + context.dictionary_memo->AddOrReplaceDictionary(id, dict_data)); + if (kind != nullptr) { + *kind = inserted ? 
DictionaryKind::New : DictionaryKind::Replacement; + } + return Status::OK(); +} + +Status ReadDictionary(const Message& message, const IpcReadContext& context, + DictionaryKind* kind) { + DCHECK_EQ(message.type(), MessageType::DICTIONARY_BATCH); + CHECK_HAS_BODY(message); + ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body())); + return ReadDictionary(*message.metadata(), context, kind, reader.get()); +} + + +class StreamDecoder2::StreamDecoder2Impl : public MessageDecoderListener { + private: + enum State { + SCHEMA, + INITIAL_DICTIONARIES, + RECORD_BATCHES, + EOS, + }; + + public: + explicit StreamDecoder2Impl(std::shared_ptr<Listener> listener, IpcReadOptions options) + : listener_(std::move(listener)), + options_(std::move(options)), + state_(State::SCHEMA), + message_decoder_(std::shared_ptr<StreamDecoder2Impl>(this, [](void*) {}), + options_.memory_pool), + n_required_dictionaries_(0), + dictionary_memo_(std::make_unique<DictionaryMemo>()) {} + + void Reset() { + state_ = State::SCHEMA; + field_inclusion_mask_.clear(); + n_required_dictionaries_ = 0; + dictionary_memo_ = std::make_unique<DictionaryMemo>(); + schema_ = out_schema_ = nullptr; + message_decoder_.Reset(); + } + + Status OnMessageDecoded(std::unique_ptr<Message> message) override { + ++stats_.num_messages; + switch (state_) { + case State::SCHEMA: + ARROW_RETURN_NOT_OK(OnSchemaMessageDecoded(std::move(message))); + break; + case State::INITIAL_DICTIONARIES: + ARROW_RETURN_NOT_OK(OnInitialDictionaryMessageDecoded(std::move(message))); + break; + case State::RECORD_BATCHES: + ARROW_RETURN_NOT_OK(OnRecordBatchMessageDecoded(std::move(message))); + break; + case State::EOS: + break; + } + return Status::OK(); + } + + Status OnEOS() override { + state_ = State::EOS; + return listener_->OnEOS(); + } + Status Consume(const uint8_t* data, int64_t size) { + return message_decoder_.Consume(data, size); + } + + Status Consume(std::shared_ptr<Buffer> buffer) { + return 
message_decoder_.Consume(std::move(buffer)); + } + + std::shared_ptr<Schema> schema() const { return out_schema_; } + + int64_t next_required_size() const { return message_decoder_.next_required_size(); } + + ReadStats stats() const { return stats_; } + + private: + Status OnSchemaMessageDecoded(std::unique_ptr<Message> message) { + RETURN_NOT_OK(UnpackSchemaMessage(*message, options_, dictionary_memo_.get(), &schema_, + &out_schema_, &field_inclusion_mask_, + &swap_endian_)); + + n_required_dictionaries_ = dictionary_memo_->fields().num_fields(); + if (n_required_dictionaries_ == 0) { + state_ = State::RECORD_BATCHES; + RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_)); + } else { + state_ = State::INITIAL_DICTIONARIES; + } + return Status::OK(); + } + + Status OnInitialDictionaryMessageDecoded(std::unique_ptr<Message> message) { + if (message->type() != MessageType::DICTIONARY_BATCH) { + return Status::Invalid("IPC stream did not have the expected number (", + dictionary_memo_->fields().num_fields(), + ") of dictionaries at the start of the stream"); + } + RETURN_NOT_OK(ReadDictionary(*message)); + n_required_dictionaries_--; + if (n_required_dictionaries_ == 0) { + state_ = State::RECORD_BATCHES; + ARROW_RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_)); + } + return Status::OK(); + } + + Status OnRecordBatchMessageDecoded(std::unique_ptr<Message> message) { + IpcReadContext context(dictionary_memo_.get(), options_, swap_endian_); + if (message->type() == MessageType::DICTIONARY_BATCH) { + return ReadDictionary(*message); + } else { + CHECK_HAS_BODY(*message); + ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); + IpcReadContext context(dictionary_memo_.get(), options_, swap_endian_); + ARROW_ASSIGN_OR_RAISE( + auto batch, + ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_, + context, reader.get())); + ++stats_.num_record_batches; + return listener_->OnRecordBatchDecoded(std::move(batch)); + } + } + + Status 
ReadDictionary(const Message& message) { + DictionaryKind kind; + IpcReadContext context(dictionary_memo_.get(), options_, swap_endian_); + RETURN_NOT_OK(::arrow::ipc::NDqs::ReadDictionary(message, context, &kind)); + ++stats_.num_dictionary_batches; + switch (kind) { + case DictionaryKind::New: + break; + case DictionaryKind::Delta: + ++stats_.num_dictionary_deltas; + break; + case DictionaryKind::Replacement: + ++stats_.num_replaced_dictionaries; + break; + } + return Status::OK(); + } + + std::shared_ptr<Listener> listener_; + const IpcReadOptions options_; + State state_; + MessageDecoder2 message_decoder_; + std::vector<bool> field_inclusion_mask_; + int n_required_dictionaries_; + std::unique_ptr<DictionaryMemo> dictionary_memo_; + std::shared_ptr<Schema> schema_, out_schema_; + ReadStats stats_; + bool swap_endian_; +}; + +Result<std::shared_ptr<Schema>> ReadSchema(io::InputStream* stream, + DictionaryMemo* dictionary_memo) { + std::unique_ptr<MessageReader> reader = MessageReader::Open(stream); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message, reader->ReadNextMessage()); + if (!message) { + return Status::Invalid("Tried reading schema message, was null or length 0"); + } + CHECK_MESSAGE_TYPE(MessageType::SCHEMA, message->type()); + return ReadSchema(*message, dictionary_memo); +} + +Result<std::shared_ptr<Schema>> ReadSchema(const Message& message, + DictionaryMemo* dictionary_memo) { + std::shared_ptr<Schema> result; + RETURN_NOT_OK(internal::GetSchema(message.header(), dictionary_memo, &result)); + return result; +} + +Result<std::shared_ptr<Tensor>> ReadTensor(io::InputStream* file) { + std::unique_ptr<Message> message; + RETURN_NOT_OK(ReadContiguousPayload(file, &message)); + return ReadTensor(*message); +} + +Result<std::shared_ptr<Tensor>> ReadTensor(const Message& message) { + std::shared_ptr<DataType> type; + std::vector<int64_t> shape; + std::vector<int64_t> strides; + std::vector<std::string> dim_names; + CHECK_HAS_BODY(message); + 
RETURN_NOT_OK(internal::GetTensorMetadata(*message.metadata(), &type, &shape, &strides, + &dim_names)); + return Tensor::Make(type, message.body(), shape, strides, dim_names); +} + + +StreamDecoder2::StreamDecoder2(std::shared_ptr<Listener> listener, IpcReadOptions options) { + impl_.reset(new StreamDecoder2::StreamDecoder2Impl(std::move(listener), options)); +} + +StreamDecoder2::~StreamDecoder2() {} + +Status StreamDecoder2::Consume(const uint8_t* data, int64_t size) { + return impl_->Consume(data, size); +} + +void StreamDecoder2::Reset() { + impl_->Reset(); +} + +Status StreamDecoder2::Consume(std::shared_ptr<Buffer> buffer) { + return impl_->Consume(std::move(buffer)); +} + +std::shared_ptr<Schema> StreamDecoder2::schema() const { return impl_->schema(); } + +int64_t StreamDecoder2::next_required_size() const { return impl_->next_required_size(); } + +ReadStats StreamDecoder2::stats() const { return impl_->stats(); } + +class InputStreamMessageReader : public MessageReader, public MessageDecoderListener { + public: + explicit InputStreamMessageReader(io::InputStream* stream) + : stream_(stream), + owned_stream_(), + message_(), + decoder_(std::shared_ptr<InputStreamMessageReader>(this, [](void*) {})) {} + + explicit InputStreamMessageReader(const std::shared_ptr<io::InputStream>& owned_stream) + : InputStreamMessageReader(owned_stream.get()) { + owned_stream_ = owned_stream; + } + + ~InputStreamMessageReader() {} + + Status OnMessageDecoded(std::unique_ptr<Message> message) override { + message_ = std::move(message); + return Status::OK(); + } + + Result<std::unique_ptr<Message>> ReadNextMessage() override { + ARROW_RETURN_NOT_OK(DecodeMessage(&decoder_, stream_)); + return std::move(message_); + } + + private: + io::InputStream* stream_; + std::shared_ptr<io::InputStream> owned_stream_; + std::unique_ptr<Message> message_; + MessageDecoder decoder_; +}; + + +std::unique_ptr<MessageReader> MessageReader::Open(io::InputStream* stream) { + return 
std::unique_ptr<MessageReader>(new InputStreamMessageReader(stream)); +} + +std::unique_ptr<MessageReader> MessageReader::Open( + const std::shared_ptr<io::InputStream>& owned_stream) { + return std::unique_ptr<MessageReader>(new InputStreamMessageReader(owned_stream)); +} + +} + +} // namespace ipc::NDqs +} // namespace arrow diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.h b/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.h new file mode 100644 index 00000000000..f5ba1be7fda --- /dev/null +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.h @@ -0,0 +1,85 @@ +// almost copy of reader.cc + message.cc without comments +// TODO(): Remove when .Reset() will be added in contrib version +#pragma once + +#include <arrow/ipc/reader.h> +#include <arrow/ipc/message.h> +#include <cstddef> +#include <cstdint> +#include <memory> +#include <utility> +#include <vector> + +#include <arrow/io/caching.h> +#include <arrow/io/type_fwd.h> +#include <arrow/ipc/message.h> +#include <arrow/ipc/options.h> +#include <arrow/record_batch.h> +#include <arrow/result.h> +#include <arrow/type_fwd.h> +#include <arrow/util/async_generator.h> +#include <arrow/util/macros.h> +#include <arrow/util/visibility.h> + +namespace arrow { +namespace ipc::NDqs { + +class ARROW_EXPORT MessageDecoder2 { + public: + enum State { + INITIAL, + METADATA_LENGTH, + METADATA, + BODY, + EOS, + }; + + explicit MessageDecoder2(std::shared_ptr<MessageDecoderListener> listener, + MemoryPool* pool = default_memory_pool()); + + MessageDecoder2(std::shared_ptr<MessageDecoderListener> listener, State initial_state, + int64_t initial_next_required_size, + MemoryPool* pool = default_memory_pool()); + + virtual ~MessageDecoder2(); + Status Consume(const uint8_t* data, int64_t size); + Status Consume(std::shared_ptr<Buffer> buffer); + int64_t next_required_size() const; + State state() const; + void Reset(); + + private: + class MessageDecoderImpl; + 
std::unique_ptr<MessageDecoderImpl> impl_; + + ARROW_DISALLOW_COPY_AND_ASSIGN(MessageDecoder2); +}; + +class ARROW_EXPORT MessageReader { + public: + virtual ~MessageReader() = default; + static std::unique_ptr<MessageReader> Open(io::InputStream* stream); + static std::unique_ptr<MessageReader> Open( + const std::shared_ptr<io::InputStream>& owned_stream); + virtual Result<std::unique_ptr<Message>> ReadNextMessage() = 0; +}; + +class ARROW_EXPORT StreamDecoder2 { + public: + StreamDecoder2(std::shared_ptr<Listener> listener, + IpcReadOptions options = IpcReadOptions::Defaults()); + + virtual ~StreamDecoder2(); + Status Consume(const uint8_t* data, int64_t size); + Status Consume(std::shared_ptr<Buffer> buffer); + std::shared_ptr<Schema> schema() const; + int64_t next_required_size() const; + ReadStats stats() const; + void Reset(); + + private: + class StreamDecoder2Impl; + std::unique_ptr<StreamDecoder2Impl> impl_; +}; +} // namespace ipc::NDqs +} // namespace arrow diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/ya.make b/ydb/library/yql/providers/yt/comp_nodes/dq/ya.make index 6f96e18fe33..48ae5d6ea10 100644 --- a/ydb/library/yql/providers/yt/comp_nodes/dq/ya.make +++ b/ydb/library/yql/providers/yt/comp_nodes/dq/ya.make @@ -5,10 +5,18 @@ PEERDIR( ydb/library/yql/providers/yt/comp_nodes ydb/library/yql/providers/yt/codec ydb/library/yql/providers/common/codec + ydb/core/formats/arrow yt/cpp/mapreduce/interface yt/cpp/mapreduce/common library/cpp/yson/node yt/yt/core + ydb/library/yql/public/udf/arrow + contrib/libs/apache/arrow + contrib/libs/flatbuffers +) + +ADDINCL( + contrib/libs/flatbuffers/include ) IF(LINUX) @@ -19,7 +27,13 @@ IF(LINUX) ) SRCS( + stream_decoder.cpp dq_yt_rpc_reader.cpp + dq_yt_rpc_helpers.cpp + dq_yt_block_reader.cpp + ) + CFLAGS( + -Wno-unused-parameter ) ENDIF() @@ -29,7 +43,6 @@ SRCS( dq_yt_writer.cpp ) - YQL_LAST_ABI_VERSION() diff --git a/ydb/library/yql/providers/yt/mkql_dq/yql_yt_dq_transform.cpp 
b/ydb/library/yql/providers/yt/mkql_dq/yql_yt_dq_transform.cpp index 41d652354cf..a29a5878630 100644 --- a/ydb/library/yql/providers/yt/mkql_dq/yql_yt_dq_transform.cpp +++ b/ydb/library/yql/providers/yt/mkql_dq/yql_yt_dq_transform.cpp @@ -31,7 +31,7 @@ public: } NMiniKQL::TCallableVisitFunc operator()(NMiniKQL::TInternName name) { - if (TaskParams.contains("yt") && name == "DqYtRead") { + if (TaskParams.contains("yt") && (name == "DqYtRead" || name == "DqYtBlockRead")) { return [this](NMiniKQL::TCallable& callable, const NMiniKQL::TTypeEnvironment& env) { using namespace NMiniKQL; diff --git a/ydb/library/yql/providers/yt/provider/yql_yt_dq_integration.cpp b/ydb/library/yql/providers/yt/provider/yql_yt_dq_integration.cpp index 8f54076f94b..c83a855c270 100644 --- a/ydb/library/yql/providers/yt/provider/yql_yt_dq_integration.cpp +++ b/ydb/library/yql/providers/yt/provider/yql_yt_dq_integration.cpp @@ -344,6 +344,33 @@ public: return false; } + bool CanBlockRead(const NNodes::TExprBase& node, TExprContext&, TTypeAnnotationContext&) override { + auto wrap = node.Cast<TDqReadWideWrap>(); + auto maybeRead = wrap.Input().Maybe<TYtReadTable>(); + if (!maybeRead) { + return false; + } + + + if (!State_->Configuration->UseRPCReaderInDQ.Get(maybeRead.Cast().DataSource().Cluster().StringValue()).GetOrElse(DEFAULT_USE_RPC_READER_IN_DQ)) { + return false; + } + + const auto structType = GetSeqItemType(maybeRead.Raw()->GetTypeAnn()->Cast<TTupleExprType>()->GetItems().back())->Cast<TStructExprType>(); + if (!CanBlockReadTypes(structType)) { + return false; + } + + const TYtSectionList& sectionList = wrap.Input().Cast<TYtReadTable>().Input(); + for (size_t i = 0; i < sectionList.Size(); ++i) { + auto section = sectionList.Item(i); + if (!NYql::GetSettingAsColumnList(section.Settings().Ref(), EYtSettingType::SysColumns).empty()) { + return false; + } + } + return true; + } + TMaybe<TOptimizerStatistics> ReadStatistics(const TExprNode::TPtr& read, TExprContext& ctx) override { 
Y_UNUSED(ctx); TOptimizerStatistics stat(0, 0); diff --git a/ydb/library/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp b/ydb/library/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp index 3a5c0733650..e4d97ea43be 100644 --- a/ydb/library/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp +++ b/ydb/library/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp @@ -288,7 +288,8 @@ TRuntimeNode BuildDqYtInputCall( const TYtState::TPtr& state, NCommon::TMkqlBuildContext& ctx, size_t inflight, - size_t timeout) + size_t timeout, + bool enableBlockReader) { NYT::TNode specNode = NYT::TNode::CreateMap(); NYT::TNode& tablesNode = specNode[YqlIOSpecTables]; @@ -363,12 +364,12 @@ TRuntimeNode BuildDqYtInputCall( auto res = uniqSpecs.emplace(NYT::NodeToCanonicalYsonString(specNode), refName); if (res.second) { registryNode[refName] = specNode; - } - else { + } else { refName = res.first->second; } tablesNode.Add(refName); - auto skiffNode = SingleTableSpecToInputSkiff(specNode, structColumns, true, true, false); + // TODO() Enable range indexes + auto skiffNode = SingleTableSpecToInputSkiff(specNode, structColumns, true, !enableBlockReader, false); const auto tmpFolder = GetTablesTmpFolder(*state->Configuration); auto tableName = pathInfo.Table->Name; if (pathInfo.Table->IsAnonymous && !TYtTableInfo::HasSubstAnonymousLabel(pathInfo.Table->FromNode.Cast())) { @@ -402,7 +403,7 @@ TRuntimeNode BuildDqYtInputCall( auto server = state->Gateway->GetClusterServer(clusterName); YQL_ENSURE(server, "Invalid YT cluster: " << clusterName); - TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), "DqYtRead", outputType); + TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), enableBlockReader ? 
"DqYtBlockRead" : "DqYtRead", outputType); call.Add(ctx.ProgramBuilder.NewDataLiteral<NUdf::EDataSlot::String>(server)); call.Add(ctx.ProgramBuilder.NewDataLiteral<NUdf::EDataSlot::String>(tokenName)); @@ -475,6 +476,35 @@ void RegisterYtMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler) { void RegisterDqYtMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler, const TYtState::TPtr& state) { + compiler.ChainCallable(TDqReadBlockWideWrap::CallableName(), + [state](const TExprNode& node, NCommon::TMkqlBuildContext& ctx) { + if (const auto& wrapper = TDqReadBlockWideWrap(&node); wrapper.Input().Maybe<TYtReadTable>().IsValid()) { + const auto ytRead = wrapper.Input().Cast<TYtReadTable>(); + const auto readType = ytRead.Ref().GetTypeAnn()->Cast<TTupleExprType>()->GetItems().back(); + const auto inputItemType = NCommon::BuildType(wrapper.Input().Ref(), GetSeqItemType(*readType), ctx.ProgramBuilder); + const auto cluster = ytRead.DataSource().Cluster().StringValue(); + size_t inflight = state->Configuration->UseRPCReaderInDQ.Get(cluster).GetOrElse(DEFAULT_USE_RPC_READER_IN_DQ) ? 
state->Configuration->DQRPCReaderInflight.Get(cluster).GetOrElse(DEFAULT_RPC_READER_INFLIGHT) : 0; + size_t timeout = state->Configuration->DQRPCReaderTimeout.Get(cluster).GetOrElse(DEFAULT_RPC_READER_TIMEOUT).MilliSeconds(); + const auto outputType = NCommon::BuildType(wrapper.Ref(), *wrapper.Ref().GetTypeAnn(), ctx.ProgramBuilder); + TString tokenName; + if (auto secureParams = wrapper.Token()) { + tokenName = secureParams.Cast().Name().StringValue(); + } + + bool solid = false; + for (const auto& flag : wrapper.Flags()) + if (solid = flag.Value() == "Solid") + break; + + if (solid) + return BuildDqYtInputCall<false>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, true && inflight); + else + return BuildDqYtInputCall<true>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, true && inflight); + } + + return TRuntimeNode(); + }); + compiler.ChainCallable(TDqReadWideWrap::CallableName(), [state](const TExprNode& node, NCommon::TMkqlBuildContext& ctx) { if (const auto& wrapper = TDqReadWideWrap(&node); wrapper.Input().Maybe<TYtReadTable>().IsValid()) { @@ -496,9 +526,9 @@ void RegisterDqYtMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler, con break; if (solid) - return BuildDqYtInputCall<false>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, isRPC, timeout); + return BuildDqYtInputCall<false>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, isRPC, timeout, false); else - return BuildDqYtInputCall<true>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, isRPC, timeout); + return BuildDqYtInputCall<true>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, isRPC, timeout, false); } return TRuntimeNode(); |