author    mrlolthe1st <mrlolthe1st@yandex-team.com>  2023-10-05 17:00:50 +0300
committer mrlolthe1st <mrlolthe1st@yandex-team.com>  2023-10-05 17:26:16 +0300
commit    f288403e6cb7cc62bb16f2c707296f096a62df29 (patch)
tree      9701c077fe0d0d86fa6fb6bd6a2cee1b5451901c
parent    762a22f887e56e471cb1196f503d47588a761490 (diff)
YQL-9517: Implement block RPC reader
YQL-9517: Implement RPC reader
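
This commit wires block (Arrow) reads through DQ end to end: a new DqReadBlockWideWrap callable with its type annotation, a UseBlockReader setting that gates a transformer rewriting DqReadWideWrap into WideFromBlocks(DqReadBlockWideWrap(...)) for providers whose IDqIntegration::CanBlockRead agrees, and a YT-side "DqYtBlockRead" computation node that decodes Arrow record batches from RPC table-read streams, falling back to the row-based reader when the proxy cannot answer in Arrow format. Like other TDqSettings fields, UseBlockReader is presumably toggled through the matching dq pragma (an assumption; the diff itself only shows the setting being registered and read in the exec transformer).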
-rw-r--r--  ydb/library/yql/dq/integration/yql_dq_integration.h | 1
-rw-r--r--  ydb/library/yql/providers/common/dq/yql_dq_integration_impl.cpp | 18
-rw-r--r--  ydb/library/yql/providers/common/dq/yql_dq_integration_impl.h | 3
-rw-r--r--  ydb/library/yql/providers/dq/common/yql_dq_settings.cpp | 1
-rw-r--r--  ydb/library/yql/providers/dq/common/yql_dq_settings.h | 1
-rw-r--r--  ydb/library/yql/providers/dq/expr_nodes/dqs_expr_nodes.json | 5
-rw-r--r--  ydb/library/yql/providers/dq/mkql/dqs_mkql_compiler.cpp | 2
-rw-r--r--  ydb/library/yql/providers/dq/opt/dqs_opt.cpp | 40
-rw-r--r--  ydb/library/yql/providers/dq/opt/dqs_opt.h | 1
-rw-r--r--  ydb/library/yql/providers/dq/provider/exec/yql_dq_exectransformer.cpp | 3
-rw-r--r--  ydb/library/yql/providers/dq/provider/yql_dq_datasource_constraints.cpp | 1
-rw-r--r--  ydb/library/yql/providers/dq/provider/yql_dq_datasource_type_ann.cpp | 13
-rw-r--r--  ydb/library/yql/providers/yt/common/yql_configuration.h | 1
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.darwin-x86_64.txt | 7
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-aarch64.txt | 11
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-x86_64.txt | 11
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.windows-x86_64.txt | 7
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.cpp | 584
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.h | 17
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_factory.cpp | 4
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp | 18
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.h | 2
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.cpp | 123
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.h | 67
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.cpp | 165
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.h | 7
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.cpp | 1404
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.h | 85
-rw-r--r--  ydb/library/yql/providers/yt/comp_nodes/dq/ya.make | 15
-rw-r--r--  ydb/library/yql/providers/yt/mkql_dq/yql_yt_dq_transform.cpp | 2
-rw-r--r--  ydb/library/yql/providers/yt/provider/yql_yt_dq_integration.cpp | 27
-rw-r--r--  ydb/library/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp | 44
32 files changed, 2509 insertions(+), 181 deletions(-)
diff --git a/ydb/library/yql/dq/integration/yql_dq_integration.h b/ydb/library/yql/dq/integration/yql_dq_integration.h
index d4f45d7f168..65500c10680 100644
--- a/ydb/library/yql/dq/integration/yql_dq_integration.h
+++ b/ydb/library/yql/dq/integration/yql_dq_integration.h
@@ -53,6 +53,7 @@ public:
virtual TMaybe<bool> CanWrite(const TExprNode& write, TExprContext& ctx) = 0;
virtual TExprNode::TPtr WrapWrite(const TExprNode::TPtr& write, TExprContext& ctx) = 0;
+ virtual bool CanBlockRead(const NNodes::TExprBase& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx) = 0;
virtual void RegisterMkqlCompiler(NCommon::TMkqlCallableCompilerBase& compiler) = 0;
virtual bool CanFallback() = 0;
virtual void FillSourceSettings(const TExprNode& node, ::google::protobuf::Any& settings, TString& sourceType) = 0;
diff --git a/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.cpp b/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.cpp
index a66cac44f2c..24ae595d66a 100644
--- a/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.cpp
+++ b/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.cpp
@@ -22,6 +22,20 @@ TMaybe<ui64> TDqIntegrationBase::EstimateReadSize(ui64, ui32, const TVector<cons
return Nothing();
}
+bool TDqIntegrationBase::CanBlockReadTypes(const TStructExprType* node) {
+ for (const auto& e: node->GetItems()) {
+ // Unwrap Optional wrappers and require a plain Data item type
+ auto type = e->GetItemType();
+ while (ETypeAnnotationKind::Optional == type->GetKind()) {
+ type = type->Cast<TOptionalExprType>()->GetItemType();
+ }
+ if (ETypeAnnotationKind::Data != type->GetKind()) {
+ return false;
+ }
+ }
+ return true;
+}
+
TExprNode::TPtr TDqIntegrationBase::WrapRead(const TDqSettings&, const TExprNode::TPtr& read, TExprContext&) {
return read;
}
@@ -36,6 +50,10 @@ TMaybe<bool> TDqIntegrationBase::CanWrite(const TExprNode&, TExprContext&) {
return Nothing();
}
+bool TDqIntegrationBase::CanBlockRead(const NNodes::TExprBase&, TExprContext&, TTypeAnnotationContext&) {
+ return false;
+}
+
TExprNode::TPtr TDqIntegrationBase::WrapWrite(const TExprNode::TPtr& write, TExprContext&) {
return write;
}
diff --git a/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.h b/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.h
index 7273d5a25ff..fc841ce4c03 100644
--- a/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.h
+++ b/ydb/library/yql/providers/common/dq/yql_dq_integration_impl.h
@@ -15,6 +15,7 @@ public:
TMaybe<TOptimizerStatistics> ReadStatistics(const TExprNode::TPtr& readWrap, TExprContext& ctx) override;
void RegisterMkqlCompiler(NCommon::TMkqlCallableCompilerBase& compiler) override;
TMaybe<bool> CanWrite(const TExprNode& write, TExprContext& ctx) override;
+ bool CanBlockRead(const NNodes::TExprBase& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx) override;
TExprNode::TPtr WrapWrite(const TExprNode::TPtr& write, TExprContext& ctx) override;
bool CanFallback() override;
void FillSourceSettings(const TExprNode& node, ::google::protobuf::Any& settings, TString& sourceType) override;
@@ -23,6 +24,8 @@ public:
void Annotate(const TExprNode& node, THashMap<TString, TString>& params) override;
bool PrepareFullResultTableParams(const TExprNode& root, TExprContext& ctx, THashMap<TString, TString>& params, THashMap<TString, TString>& secureParams) override;
void WriteFullResultTableRef(NYson::TYsonWriter& writer, const TVector<TString>& columns, const THashMap<TString, TString>& graphParams) override;
+protected:
+ bool CanBlockReadTypes(const TStructExprType* node);
};
} // namespace NYql
diff --git a/ydb/library/yql/providers/dq/common/yql_dq_settings.cpp b/ydb/library/yql/providers/dq/common/yql_dq_settings.cpp
index 7cc6e857308..9fc67895641 100644
--- a/ydb/library/yql/providers/dq/common/yql_dq_settings.cpp
+++ b/ydb/library/yql/providers/dq/common/yql_dq_settings.cpp
@@ -75,6 +75,7 @@ TDqConfiguration::TDqConfiguration() {
REGISTER_SETTING(*this, ExportStats);
REGISTER_SETTING(*this, TaskRunnerStats).Parser([](const TString& v) { return FromString<ETaskRunnerStats>(v); });
REGISTER_SETTING(*this, _SkipRevisionCheck);
+ REGISTER_SETTING(*this, UseBlockReader);
}
} // namespace NYql
diff --git a/ydb/library/yql/providers/dq/common/yql_dq_settings.h b/ydb/library/yql/providers/dq/common/yql_dq_settings.h
index 613c105901f..75cca0876e2 100644
--- a/ydb/library/yql/providers/dq/common/yql_dq_settings.h
+++ b/ydb/library/yql/providers/dq/common/yql_dq_settings.h
@@ -114,6 +114,7 @@ struct TDqSettings {
NCommon::TConfSetting<bool, false> ExportStats;
NCommon::TConfSetting<ETaskRunnerStats, false> TaskRunnerStats;
NCommon::TConfSetting<bool, false> _SkipRevisionCheck;
+ NCommon::TConfSetting<bool, false> UseBlockReader;
// This options will be passed to executor_actor and worker_actor
template <typename TProtoConfig>
diff --git a/ydb/library/yql/providers/dq/expr_nodes/dqs_expr_nodes.json b/ydb/library/yql/providers/dq/expr_nodes/dqs_expr_nodes.json
index 7b528c5bf8d..2d3868424c1 100644
--- a/ydb/library/yql/providers/dq/expr_nodes/dqs_expr_nodes.json
+++ b/ydb/library/yql/providers/dq/expr_nodes/dqs_expr_nodes.json
@@ -27,6 +27,11 @@
"Match": {"Type": "Callable", "Name": "DqReadWideWrap"}
},
{
+ "Name": "TDqReadBlockWideWrap",
+ "Base": "TDqReadWrapBase",
+ "Match": {"Type": "Callable", "Name": "DqReadBlockWideWrap"}
+ },
+ {
"Name": "TDqWrite",
"Base": "TCallable",
"Match": {"Type": "Callable", "Name": "DqWrite"},
diff --git a/ydb/library/yql/providers/dq/mkql/dqs_mkql_compiler.cpp b/ydb/library/yql/providers/dq/mkql/dqs_mkql_compiler.cpp
index c4f2cd6cbd4..724082b8c78 100644
--- a/ydb/library/yql/providers/dq/mkql/dqs_mkql_compiler.cpp
+++ b/ydb/library/yql/providers/dq/mkql/dqs_mkql_compiler.cpp
@@ -10,7 +10,7 @@ using namespace NKikimr::NMiniKQL;
using namespace NNodes;
void RegisterDqsMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler, const TTypeAnnotationContext& ctx) {
- compiler.AddCallable({TDqSourceWideWrap::CallableName(), TDqSourceWideBlockWrap::CallableName(), TDqReadWideWrap::CallableName()},
+ compiler.AddCallable({TDqSourceWideWrap::CallableName(), TDqSourceWideBlockWrap::CallableName(), TDqReadWideWrap::CallableName(), TDqReadBlockWideWrap::CallableName()},
[](const TExprNode& node, NCommon::TMkqlBuildContext&) {
YQL_ENSURE(false, "Unsupported reader: " << node.Head().Content());
return TRuntimeNode();
diff --git a/ydb/library/yql/providers/dq/opt/dqs_opt.cpp b/ydb/library/yql/providers/dq/opt/dqs_opt.cpp
index 8cba4d0601a..d8ae81155db 100644
--- a/ydb/library/yql/providers/dq/opt/dqs_opt.cpp
+++ b/ydb/library/yql/providers/dq/opt/dqs_opt.cpp
@@ -3,6 +3,7 @@
#include <ydb/library/yql/providers/dq/expr_nodes/dqs_expr_nodes.h>
#include <ydb/library/yql/providers/common/mkql/yql_type_mkql.h>
#include <ydb/library/yql/providers/common/codec/yql_codec.h>
+#include <ydb/library/yql/providers/common/provider/yql_provider_names.h>
#include <ydb/library/yql/core/yql_expr_optimize.h>
#include <ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.h>
@@ -16,6 +17,7 @@
#include <ydb/library/yql/dq/opt/dq_opt_build.h>
#include <ydb/library/yql/dq/opt/dq_opt_peephole.h>
#include <ydb/library/yql/dq/type_ann/dq_type_ann.h>
+#include <ydb/library/yql/dq/integration/yql_dq_integration.h>
#include <ydb/library/yql/utils/log/log.h>
@@ -64,6 +66,44 @@ namespace NYql::NDqs {
});
}
+ THolder<IGraphTransformer> CreateDqsRewritePhyBlockReadOnDqIntegrationTransformer(TTypeAnnotationContext& typesCtx) {
+ return CreateFunctorTransformer([&typesCtx](const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) {
+ TOptimizeExprSettings optSettings{nullptr};
+ optSettings.VisitLambdas = true;
+ optSettings.VisitTuples = true;
+ return OptimizeExpr(input, output,
+ [&typesCtx](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr {
+ if (!TDqReadWideWrap::Match(node.Get())) {
+ return node;
+ }
+
+ auto readWideWrap = TDqReadWideWrap(node);
+ auto dataSource = readWideWrap.Raw()->Child(0)->Child(1);
+ auto dataSourceName = dataSource->Child(0)->Content();
+ if (dataSourceName == DqProviderName || dataSource->IsCallable(ConfigureName)) {
+ return node;
+ }
+
+ auto datasource = typesCtx.DataSourceMap.FindPtr(dataSourceName);
+ YQL_ENSURE(datasource);
+ auto dqIntegration = (*datasource)->GetDqIntegration();
+ if (!dqIntegration || !dqIntegration->CanBlockRead(readWideWrap, ctx, typesCtx)) {
+ return node;
+ }
+
+ YQL_CLOG(INFO, ProviderDq) << "DqsRewritePhyBlockReadOnDqIntegration";
+
+ return Build<TCoWideFromBlocks>(ctx, node->Pos())
+ .Input(Build<TDqReadBlockWideWrap>(ctx, node->Pos())
+ .Input(readWideWrap.Input())
+ .Flags(readWideWrap.Flags())
+ .Token(readWideWrap.Token())
+ .Done())
+ .Done().Ptr();
+ }, ctx, optSettings);
+ });
+ }
+
THolder<IGraphTransformer> CreateDqsReplacePrecomputesTransformer(TTypeAnnotationContext& typesCtx, const NKikimr::NMiniKQL::IFunctionRegistry* funcRegistry) {
return CreateFunctorTransformer([&typesCtx, funcRegistry](const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) -> TStatus {
TOptimizeExprSettings settings(&typesCtx);
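
The transformer above visits every DqReadWideWrap whose data source is a real provider (it skips reads belonging to the DQ provider itself and Configure! callables), asks that provider's IDqIntegration::CanBlockRead, and on success replaces the node with WideFromBlocks over a DqReadBlockWideWrap built from the same Input, Flags and Token children. Downstream consumers therefore keep seeing an ordinary wide flow, while the read itself is free to produce Arrow blocks.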
diff --git a/ydb/library/yql/providers/dq/opt/dqs_opt.h b/ydb/library/yql/providers/dq/opt/dqs_opt.h
index 9beec287f07..7f94e976eed 100644
--- a/ydb/library/yql/providers/dq/opt/dqs_opt.h
+++ b/ydb/library/yql/providers/dq/opt/dqs_opt.h
@@ -15,6 +15,7 @@ namespace NYql::NDqs {
THolder<IGraphTransformer> CreateDqsFinalizingOptTransformer();
THolder<IGraphTransformer> CreateDqsRewritePhyCallablesTransformer(TTypeAnnotationContext& typesCtx);
+ THolder<IGraphTransformer> CreateDqsRewritePhyBlockReadOnDqIntegrationTransformer(TTypeAnnotationContext& typesCtx);
THolder<IGraphTransformer> CreateDqsReplacePrecomputesTransformer(TTypeAnnotationContext& typesCtx, const NKikimr::NMiniKQL::IFunctionRegistry* funcRegistry);
} // namespace NYql::NDqs
diff --git a/ydb/library/yql/providers/dq/provider/exec/yql_dq_exectransformer.cpp b/ydb/library/yql/providers/dq/provider/exec/yql_dq_exectransformer.cpp
index fdd6619ae2d..3c7904f4dc9 100644
--- a/ydb/library/yql/providers/dq/provider/exec/yql_dq_exectransformer.cpp
+++ b/ydb/library/yql/providers/dq/provider/exec/yql_dq_exectransformer.cpp
@@ -233,6 +233,9 @@ private:
void AfterTypeAnnotation(TTransformationPipeline* pipeline) const final {
pipeline->Add(NDqs::CreateDqsReplacePrecomputesTransformer(*pipeline->GetTypeAnnotationContext(), State_->FunctionRegistry), "ReplacePrecomputes");
+ if (State_->Settings->UseBlockReader.Get().GetOrElse(false)) {
+ pipeline->Add(NDqs::CreateDqsRewritePhyBlockReadOnDqIntegrationTransformer(*pipeline->GetTypeAnnotationContext()), "ReplaceWideReadsWithBlock");
+ }
bool useWideChannels = State_->Settings->UseWideChannels.Get().GetOrElse(false);
bool useChannelBlocks = State_->Settings->UseWideBlockChannels.Get().GetOrElse(false);
NDq::EChannelMode mode;
diff --git a/ydb/library/yql/providers/dq/provider/yql_dq_datasource_constraints.cpp b/ydb/library/yql/providers/dq/provider/yql_dq_datasource_constraints.cpp
index 52399454561..a212e46824f 100644
--- a/ydb/library/yql/providers/dq/provider/yql_dq_datasource_constraints.cpp
+++ b/ydb/library/yql/providers/dq/provider/yql_dq_datasource_constraints.cpp
@@ -20,6 +20,7 @@ public:
TCoConfigure::CallableName(),
TDqReadWrap::CallableName(),
TDqReadWideWrap::CallableName(),
+ TDqReadBlockWideWrap::CallableName(),
TDqSource::CallableName(),
TDqSourceWrap::CallableName(),
TDqSourceWideWrap::CallableName(),
diff --git a/ydb/library/yql/providers/dq/provider/yql_dq_datasource_type_ann.cpp b/ydb/library/yql/providers/dq/provider/yql_dq_datasource_type_ann.cpp
index f385ba37f92..b7e18d36a3b 100644
--- a/ydb/library/yql/providers/dq/provider/yql_dq_datasource_type_ann.cpp
+++ b/ydb/library/yql/providers/dq/provider/yql_dq_datasource_type_ann.cpp
@@ -23,7 +23,8 @@ public:
AddHandler({TDqSourceWideWrap::CallableName()}, Hndl(&TDqsDataSourceTypeAnnotationTransformer::HandleSourceWrap<true, false>));
AddHandler({TDqSourceWideBlockWrap::CallableName()}, Hndl(&TDqsDataSourceTypeAnnotationTransformer::HandleSourceWrap<true, true>));
AddHandler({TDqReadWrap::CallableName()}, Hndl(&TDqsDataSourceTypeAnnotationTransformer::HandleReadWrap));
- AddHandler({TDqReadWideWrap::CallableName()}, Hndl(&TDqsDataSourceTypeAnnotationTransformer::HandleWideReadWrap));
+ AddHandler({TDqReadWideWrap::CallableName()}, Hndl(&TDqsDataSourceTypeAnnotationTransformer::HandleWideReadWrap<false>));
+ AddHandler({TDqReadBlockWideWrap::CallableName()}, Hndl(&TDqsDataSourceTypeAnnotationTransformer::HandleWideReadWrap<true>));
AddHandler({TDqSource::CallableName()}, Hndl(&NDq::AnnotateDqSource));
AddHandler({TDqPhyLength::CallableName()}, Hndl(&NDq::AnnotateDqPhyLength));
@@ -113,6 +114,7 @@ private:
return TStatus::Ok;
}
+ template<bool IsBlock>
TStatus HandleWideReadWrap(const TExprNode::TPtr& input, TExprContext& ctx) {
if (!EnsureMinMaxArgsCount(*input, 1, 3, ctx)) {
return TStatus::Error;
@@ -145,8 +147,15 @@ private:
const auto structType = readerType->Cast<TListExprType>()->GetItemType()->Cast<TStructExprType>();
TTypeAnnotationNode::TListType types;
const auto& items = structType->GetItems();
- types.reserve(items.size());
+ types.reserve(items.size() + IsBlock);
std::transform(items.cbegin(), items.cend(), std::back_inserter(types), std::bind(&TItemExprType::GetItemType, std::placeholders::_1));
+ if constexpr (IsBlock) {
+ for (auto& type : types) {
+ type = ctx.MakeType<TBlockExprType>(type);
+ }
+
+ types.push_back(ctx.MakeType<TScalarExprType>(ctx.MakeType<TDataExprType>(EDataSlot::Uint64)));
+ }
input->SetTypeAnn(ctx.MakeType<TFlowExprType>(ctx.MakeType<TMultiExprType>(types)));
return TStatus::Ok;
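
For the block variant the handler emits the wide-block layout used elsewhere in DQ: each column type is wrapped into a Block type and one extra trailing Scalar<Uint64> column is appended, which by convention carries the row count of the block and is what WideFromBlocks consumes on the other side of the rewrite.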
diff --git a/ydb/library/yql/providers/yt/common/yql_configuration.h b/ydb/library/yql/providers/yt/common/yql_configuration.h
index ca01b761ebd..0b671b7341f 100644
--- a/ydb/library/yql/providers/yt/common/yql_configuration.h
+++ b/ydb/library/yql/providers/yt/common/yql_configuration.h
@@ -56,7 +56,6 @@ constexpr bool DEFAULT_USE_RPC_READER_IN_DQ = false;
constexpr size_t DEFAULT_RPC_READER_INFLIGHT = 1;
constexpr TDuration DEFAULT_RPC_READER_TIMEOUT = TDuration::Seconds(120);
-
constexpr auto DEFAULT_SWITCH_MEMORY_LIMIT = 128_MB;
constexpr ui32 DEFAULT_MAX_INPUT_TABLES = 3000;
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.darwin-x86_64.txt b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.darwin-x86_64.txt
index b1d1973fea3..9082a770b58 100644
--- a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.darwin-x86_64.txt
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.darwin-x86_64.txt
@@ -11,6 +11,9 @@ add_library(yt-comp_nodes-dq)
target_compile_options(yt-comp_nodes-dq PRIVATE
-DUSE_CURRENT_UDF_ABI_VERSION
)
+target_include_directories(yt-comp_nodes-dq PRIVATE
+ ${CMAKE_SOURCE_DIR}/contrib/libs/flatbuffers/include
+)
target_link_libraries(yt-comp_nodes-dq PUBLIC
contrib-libs-cxxsupp
yutil
@@ -18,10 +21,14 @@ target_link_libraries(yt-comp_nodes-dq PUBLIC
providers-yt-comp_nodes
providers-yt-codec
providers-common-codec
+ core-formats-arrow
cpp-mapreduce-interface
cpp-mapreduce-common
cpp-yson-node
yt-yt-core
+ public-udf-arrow
+ libs-apache-arrow
+ contrib-libs-flatbuffers
)
target_sources(yt-comp_nodes-dq PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-aarch64.txt b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-aarch64.txt
index 5b2c1550db9..06ae4cb2da3 100644
--- a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-aarch64.txt
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-aarch64.txt
@@ -9,8 +9,12 @@
add_library(yt-comp_nodes-dq)
target_compile_options(yt-comp_nodes-dq PRIVATE
+ -Wno-unused-parameter
-DUSE_CURRENT_UDF_ABI_VERSION
)
+target_include_directories(yt-comp_nodes-dq PRIVATE
+ ${CMAKE_SOURCE_DIR}/contrib/libs/flatbuffers/include
+)
target_link_libraries(yt-comp_nodes-dq PUBLIC
contrib-libs-linux-headers
contrib-libs-cxxsupp
@@ -19,16 +23,23 @@ target_link_libraries(yt-comp_nodes-dq PUBLIC
providers-yt-comp_nodes
providers-yt-codec
providers-common-codec
+ core-formats-arrow
cpp-mapreduce-interface
cpp-mapreduce-common
cpp-yson-node
yt-yt-core
+ public-udf-arrow
+ libs-apache-arrow
+ contrib-libs-flatbuffers
yt-yt-client
yt-client-arrow
yt-lib-yt_rpc_helpers
)
target_sources(yt-comp_nodes-dq PRIVATE
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_factory.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_writer.cpp
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-x86_64.txt b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-x86_64.txt
index 5b2c1550db9..06ae4cb2da3 100644
--- a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-x86_64.txt
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.linux-x86_64.txt
@@ -9,8 +9,12 @@
add_library(yt-comp_nodes-dq)
target_compile_options(yt-comp_nodes-dq PRIVATE
+ -Wno-unused-parameter
-DUSE_CURRENT_UDF_ABI_VERSION
)
+target_include_directories(yt-comp_nodes-dq PRIVATE
+ ${CMAKE_SOURCE_DIR}/contrib/libs/flatbuffers/include
+)
target_link_libraries(yt-comp_nodes-dq PUBLIC
contrib-libs-linux-headers
contrib-libs-cxxsupp
@@ -19,16 +23,23 @@ target_link_libraries(yt-comp_nodes-dq PUBLIC
providers-yt-comp_nodes
providers-yt-codec
providers-common-codec
+ core-formats-arrow
cpp-mapreduce-interface
cpp-mapreduce-common
cpp-yson-node
yt-yt-core
+ public-udf-arrow
+ libs-apache-arrow
+ contrib-libs-flatbuffers
yt-yt-client
yt-client-arrow
yt-lib-yt_rpc_helpers
)
target_sources(yt-comp_nodes-dq PRIVATE
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_factory.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_writer.cpp
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.windows-x86_64.txt b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.windows-x86_64.txt
index b1d1973fea3..9082a770b58 100644
--- a/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.windows-x86_64.txt
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/CMakeLists.windows-x86_64.txt
@@ -11,6 +11,9 @@ add_library(yt-comp_nodes-dq)
target_compile_options(yt-comp_nodes-dq PRIVATE
-DUSE_CURRENT_UDF_ABI_VERSION
)
+target_include_directories(yt-comp_nodes-dq PRIVATE
+ ${CMAKE_SOURCE_DIR}/contrib/libs/flatbuffers/include
+)
target_link_libraries(yt-comp_nodes-dq PUBLIC
contrib-libs-cxxsupp
yutil
@@ -18,10 +21,14 @@ target_link_libraries(yt-comp_nodes-dq PUBLIC
providers-yt-comp_nodes
providers-yt-codec
providers-common-codec
+ core-formats-arrow
cpp-mapreduce-interface
cpp-mapreduce-common
cpp-yson-node
yt-yt-core
+ public-udf-arrow
+ libs-apache-arrow
+ contrib-libs-flatbuffers
)
target_sources(yt-comp_nodes-dq PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.cpp b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.cpp
new file mode 100644
index 00000000000..923deacc08c
--- /dev/null
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.cpp
@@ -0,0 +1,584 @@
+#include "dq_yt_block_reader.h"
+#include "stream_decoder.h"
+#include "dq_yt_rpc_helpers.h"
+
+#include <ydb/library/yql/public/udf/arrow/block_builder.h>
+
+#include <ydb/library/yql/providers/yt/codec/yt_codec.h>
+#include <ydb/library/yql/providers/common/codec/yql_codec.h>
+#include <ydb/library/yql/minikql/mkql_node_cast.h>
+#include <ydb/library/yql/minikql/computation/mkql_computation_node_codegen.h>
+#include <ydb/library/yql/minikql/computation/mkql_computation_node.h>
+#include <ydb/library/yql/minikql/mkql_stats_registry.h>
+#include <ydb/library/yql/minikql/mkql_node.h>
+#include <ydb/library/yql/dq/runtime/dq_arrow_helpers.h>
+#include <ydb/core/formats/arrow/dictionary/conversion.h>
+#include <ydb/core/formats/arrow/arrow_helpers.h>
+
+#include <yt/yt/core/concurrency/thread_pool.h>
+#include <yt/cpp/mapreduce/interface/common.h>
+#include <yt/cpp/mapreduce/interface/errors.h>
+#include <yt/cpp/mapreduce/interface/client.h>
+#include <yt/cpp/mapreduce/interface/serialize.h>
+#include <yt/cpp/mapreduce/interface/config.h>
+#include <yt/cpp/mapreduce/common/wait_proxy.h>
+
+#include <arrow/compute/cast.h>
+#include <arrow/compute/api_vector.h>
+#include <arrow/compute/api.h>
+#include <arrow/io/interfaces.h>
+#include <arrow/io/memory.h>
+#include <arrow/ipc/reader.h>
+#include <arrow/array.h>
+#include <arrow/record_batch.h>
+#include <arrow/type.h>
+#include <arrow/result.h>
+#include <arrow/buffer.h>
+
+#include <library/cpp/yson/node/node.h>
+#include <library/cpp/yson/node/node_io.h>
+
+#include <util/generic/size_literals.h>
+#include <util/stream/output.h>
+
+namespace NYql::NDqs {
+
+using namespace NKikimr::NMiniKQL;
+
+namespace {
+struct TResultBatch {
+ using TPtr = std::shared_ptr<TResultBatch>;
+ size_t RowsCnt;
+ std::vector<arrow::Datum> Columns;
+ TResultBatch(int64_t cnt, decltype(Columns)&& columns) : RowsCnt(cnt), Columns(std::move(columns)) {}
+ TResultBatch(std::shared_ptr<arrow::RecordBatch> batch) : RowsCnt(batch->num_rows()), Columns(batch->columns().begin(), batch->columns().end()) {}
+};
+
+template<typename T>
+class TBlockingQueueWithLimit {
+ struct Poison {
+ TString Error;
+ };
+ using TPoisonOr = std::variant<T, Poison>;
+public:
+ TBlockingQueueWithLimit(size_t limit) : Limit_(limit) {}
+ void Push(T&& val) {
+ PushInternal(std::move(val));
+ }
+
+ void PushPoison(const TString& err) {
+ PushInternal(Poison{err});
+ }
+
+ T Get() {
+ auto res = GetInternal();
+ if (std::holds_alternative<Poison>(res)) {
+ throw std::runtime_error(std::get<Poison>(res).Error);
+ }
+ return std::move(std::get<T>(res));
+ }
+private:
+ template<typename X>
+ void PushInternal(X&& val) {
+ NYT::TPromise<void> promise;
+ {
+ std::lock_guard _(Mtx_);
+ if (!Awaiting_.empty()) {
+ Awaiting_.front().Set(std::move(val));
+ Awaiting_.pop();
+ return;
+ }
+ if (Ready_.size() >= Limit_) {
+ promise = NYT::NewPromise<void>();
+ BlockedPushes_.push(promise);
+ } else {
+ Ready_.emplace(std::move(val));
+ return;
+ }
+ }
+ YQL_ENSURE(NYT::NConcurrency::WaitFor(promise.ToFuture()).IsOK());
+ std::lock_guard _(Mtx_);
+ Ready_.emplace(std::move(val));
+ }
+
+ TPoisonOr GetInternal() {
+ NYT::TPromise<TPoisonOr> awaiter;
+ {
+ std::lock_guard _(Mtx_);
+ if (!BlockedPushes_.empty()) {
+ BlockedPushes_.front().Set();
+ BlockedPushes_.pop();
+ }
+ if (!Ready_.empty()) {
+ auto res = std::move(Ready_.front());
+ Ready_.pop();
+ return res;
+ }
+ awaiter = NYT::NewPromise<TPoisonOr>();
+ Awaiting_.push(awaiter);
+ }
+ auto awaitResult = NYT::NConcurrency::WaitFor(awaiter.ToFuture());
+ if (!awaitResult.IsOK()) {
+ throw std::runtime_error(awaitResult.GetMessage());
+ }
+ return std::move(awaitResult.Value());
+ }
+ std::mutex Mtx_;
+ std::queue<TPoisonOr> Ready_;
+ std::queue<NYT::TPromise<void>> BlockedPushes_;
+ std::queue<NYT::TPromise<TPoisonOr>> Awaiting_;
+ size_t Limit_;
+};
+
+class TListener {
+ using TBatchPtr = std::shared_ptr<arrow::RecordBatch>;
+public:
+ using TPromise = NYT::TPromise<TResultBatch>;
+ using TPtr = std::shared_ptr<TListener>;
+ TListener(size_t initLatch, size_t inflight) : Latch_(initLatch), Queue_(inflight) {}
+
+ void OnEOF() {
+ bool expected = false;
+ if (GotEOF_.compare_exchange_strong(expected, true)) {
+ // a block pointer equal to nullptr marks EOF
+ HandleResult(nullptr);
+ } else {
+ // EOF must not be signalled more than once
+ HandleError("EOS already got");
+ }
+ }
+
+ // Pushes a decoded batch (or the nullptr EOF marker) into the consumer queue
+ void HandleResult(TResultBatch::TPtr&& res) {
+ Queue_.Push(std::move(res));
+ }
+
+ void OnRecordBatchDecoded(TBatchPtr record_batch) {
+ YQL_ENSURE(record_batch);
+ // decode dictionary
+ record_batch = NKikimr::NArrow::DictionaryToArray(record_batch);
+ // and handle result
+ HandleResult(std::make_shared<TResultBatch>(record_batch));
+ }
+
+ TResultBatch::TPtr Get() {
+ return Queue_.Get();
+ }
+
+ void HandleError(const TString& msg) {
+ Queue_.PushPoison(msg);
+ }
+
+ void HandleFallback(TResultBatch::TPtr&& block) {
+ HandleResult(std::move(block));
+ }
+
+ void InputDone() {
+ // Overall EOF is reported once every input has signalled EOS
+ if (!--Latch_) {
+ OnEOF();
+ }
+ }
+private:
+ std::atomic<size_t> Latch_;
+ std::atomic<bool> GotEOF_;
+ TBlockingQueueWithLimit<TResultBatch::TPtr> Queue_;
+};
+
+class TBlockBuilder {
+public:
+ void Init(std::shared_ptr<std::vector<TType*>> columnTypes, arrow::MemoryPool& pool, const NUdf::IPgBuilder* pgBuilder) {
+ ColumnTypes_ = columnTypes;
+ ColumnBuilders_.reserve(ColumnTypes_->size());
+ size_t maxBlockItemSize = 0;
+ for (auto& type: *ColumnTypes_) {
+ maxBlockItemSize = std::max(maxBlockItemSize, CalcMaxBlockItemSize(type));
+ }
+ size_t maxBlockLen = CalcBlockLen(maxBlockItemSize);
+ for (size_t i = 0; i < ColumnTypes_->size(); ++i) {
+ ColumnBuilders_.push_back(
+ std::move(NUdf::MakeArrayBuilder(
+ TTypeInfoHelper(), ColumnTypes_->at(i),
+ pool,
+ maxBlockLen,
+ pgBuilder
+ ))
+ );
+ }
+ }
+
+ void Add(const NUdf::TUnboxedValue& val) {
+ for (ui32 i = 0; i < ColumnBuilders_.size(); ++i) {
+ auto v = val.GetElement(i);
+ ColumnBuilders_[i]->Add(v);
+ }
+ ++RowsCnt_;
+ }
+
+ std::vector<TResultBatch::TPtr> Build() {
+ std::vector<arrow::Datum> columns;
+ columns.reserve(ColumnBuilders_.size());
+ for (size_t i = 0; i < ColumnBuilders_.size(); ++i) {
+ columns.emplace_back(std::move(ColumnBuilders_[i]->Build(false)));
+ }
+ std::vector<std::shared_ptr<TResultBatch>> blocks;
+ int64_t offset = 0;
+ std::vector<int64_t> currentChunk(columns.size()), inChunkOffset(columns.size());
+ while (RowsCnt_) {
+ int64_t max_curr_len = RowsCnt_;
+ for (size_t i = 0; i < columns.size(); ++i) {
+ if (arrow::Datum::Kind::CHUNKED_ARRAY == columns[i].kind()) {
+ auto& c_arr = columns[i].chunked_array();
+ while (currentChunk[i] < c_arr->num_chunks() && !c_arr->chunk(currentChunk[i])) {
+ ++currentChunk[i];
+ }
+ YQL_ENSURE(currentChunk[i] < c_arr->num_chunks());
+ max_curr_len = std::min(max_curr_len, c_arr->chunk(currentChunk[i])->length() - inChunkOffset[i]);
+ }
+ }
+ RowsCnt_ -= max_curr_len;
+ decltype(columns) result_columns;
+ result_columns.reserve(columns.size());
+ offset += max_curr_len;
+ for (size_t i = 0; i < columns.size(); ++i) {
+ auto& e = columns[i];
+ if (arrow::Datum::Kind::CHUNKED_ARRAY == e.kind()) {
+ result_columns.emplace_back(e.chunked_array()->chunk(currentChunk[i])->Slice(inChunkOffset[i], max_curr_len));
+ if (max_curr_len + inChunkOffset[i] == e.chunked_array()->chunk(currentChunk[i])->length()) {
+ ++currentChunk[i];
+ inChunkOffset[i] = 0;
+ } else {
+ inChunkOffset[i] += max_curr_len;
+ }
+ } else {
+ result_columns.emplace_back(e.array()->Slice(offset - max_curr_len, max_curr_len));
+ }
+ }
+ blocks.emplace_back(std::make_shared<TResultBatch>(max_curr_len, std::move(result_columns)));
+ }
+ return blocks;
+ }
+
+private:
+ int64_t RowsCnt_ = 0;
+ std::vector<std::unique_ptr<NUdf::IArrayBuilder>> ColumnBuilders_;
+ std::shared_ptr<std::vector<TType*>> ColumnTypes_;
+};
+
+class TLocalListener : public arrow::ipc::Listener {
+public:
+ TLocalListener(std::shared_ptr<TListener> consumer)
+ : Consumer_(consumer) {}
+
+ void Init(std::shared_ptr<TLocalListener> self) {
+ Self_ = self;
+ Decoder_ = std::make_shared<arrow::ipc::NDqs::StreamDecoder2>(self, arrow::ipc::IpcReadOptions{.use_threads=false});
+ }
+
+ arrow::Status OnEOS() override {
+ Decoder_->Reset();
+ return arrow::Status::OK();
+ }
+
+ arrow::Status OnRecordBatchDecoded(std::shared_ptr<arrow::RecordBatch> batch) override {
+ Consumer_->OnRecordBatchDecoded(batch);
+ return arrow::Status::OK();
+ }
+
+ void Consume(std::shared_ptr<arrow::Buffer> buff) {
+ ARROW_OK(Decoder_->Consume(buff));
+ }
+
+ void Finish() {
+ Self_ = nullptr;
+ }
+private:
+ std::shared_ptr<TLocalListener> Self_;
+ std::shared_ptr<TListener> Consumer_;
+ std::shared_ptr<arrow::ipc::NDqs::StreamDecoder2> Decoder_;
+};
+
+class TSource : public TNonCopyable {
+public:
+ using TPtr = std::shared_ptr<TSource>;
+ TSource(std::unique_ptr<TSettingsHolder>&& settings,
+ size_t inflight, TType* type, const THolderFactory& holderFactory)
+ : Settings_(std::move(settings))
+ , Inputs_(std::move(Settings_->RawInputs))
+ , Listener_(std::make_shared<TListener>(Inputs_.size(), inflight))
+ , HolderFactory_(holderFactory)
+ {
+ auto structType = AS_TYPE(TStructType, type);
+ std::vector<TType*> columnTypes_(structType->GetMembersCount());
+ for (ui32 i = 0; i < structType->GetMembersCount(); ++i) {
+ columnTypes_[i] = structType->GetMemberType(i);
+ }
+ auto ptr = std::make_shared<decltype(columnTypes_)>(std::move(columnTypes_));
+ Inflight_ = std::min(inflight, Inputs_.size());
+
+ LocalListeners_.reserve(Inputs_.size());
+ for (size_t i = 0; i < Inputs_.size(); ++i) {
+ InputsQueue_.emplace(i);
+ LocalListeners_.emplace_back(std::make_shared<TLocalListener>(Listener_));
+ LocalListeners_.back()->Init(LocalListeners_.back());
+ }
+ BlockBuilder_.Init(ptr, *Settings_->Pool, Settings_->PgBuilder);
+ FallbackReader_.SetSpecs(*Settings_->Specs, HolderFactory_);
+ }
+
+ void RunRead() {
+ size_t inputIdx;
+ {
+ std::lock_guard _(Mtx_);
+ if (InputsQueue_.empty()) {
+ return;
+ }
+ inputIdx = InputsQueue_.front();
+ InputsQueue_.pop();
+ }
+ Inputs_[inputIdx]->Read().SubscribeUnique(BIND([inputIdx = inputIdx, self = Self_](NYT::TErrorOr<NYT::TSharedRef>&& res) {
+ self->Pool_->GetInvoker()->Invoke(BIND([inputIdx, self, res = std::move(res)] () mutable {
+ try {
+ self->Accept(inputIdx, std::move(res));
+ self->RunRead();
+ } catch (std::exception& e) {
+ self->Listener_->HandleError(e.what());
+ }
+ }));
+ }));
+ }
+
+ void Accept(size_t inputIdx, NYT::TErrorOr<NYT::TSharedRef>&& res) {
+ if (res.IsOK() && !res.Value()) {
+ // End Of Stream
+ Listener_->InputDone();
+ return;
+ }
+
+ if (!res.IsOK()) {
+ // Propagate error
+ Listener_->HandleError(res.GetMessage());
+ return;
+ }
+
+ NYT::NApi::NRpcProxy::NProto::TRowsetDescriptor descriptor;
+ NYT::NApi::NRpcProxy::NProto::TRowsetStatistics statistics;
+ NYT::TSharedRef currentPayload = NYT::NApi::NRpcProxy::DeserializeRowStreamBlockEnvelope(res.Value(), &descriptor, &statistics);
+ if (descriptor.rowset_format() != NYT::NApi::NRpcProxy::NProto::RF_ARROW) {
+ auto promise = NYT::NewPromise<std::vector<TResultBatch::TPtr>>();
+ MainInvoker_->Invoke(BIND([inputIdx, currentPayload, self = Self_, promise] {
+ try {
+ promise.Set(self->FallbackHandler(inputIdx, currentPayload));
+ } catch (std::exception& e) {
+ promise.Set(NYT::TError(e.what()));
+ }
+ }));
+ auto result = NYT::NConcurrency::WaitFor(promise.ToFuture());
+ if (!result.IsOK()) {
+ Listener_->HandleError(result.GetMessage());
+ return;
+ }
+ for (auto& e: result.Value()) {
+ Listener_->HandleFallback(std::move(e));
+ }
+ InputDone(inputIdx);
+ return;
+ }
+
+ if (!currentPayload.Size()) {
+ // EOS
+ Listener_->InputDone();
+ return;
+ }
+ // TODO(): support row and range indexes
+ auto payload = TMemoryInput(currentPayload.Begin(), currentPayload.Size());
+ arrow::BufferBuilder bb;
+ ARROW_OK(bb.Reserve(currentPayload.Size()));
+ ARROW_OK(bb.Append((const uint8_t*)payload.Buf(), currentPayload.Size()));
+ LocalListeners_[inputIdx]->Consume(*bb.Finish());
+ InputDone(inputIdx);
+ }
+
+ // Return input back to queue
+ void InputDone(auto input) {
+ std::lock_guard _(Mtx_);
+ InputsQueue_.emplace(input);
+ }
+
+ TResultBatch::TPtr Next() {
+ auto result = Listener_->Get();
+ return result;
+ }
+
+ std::vector<TResultBatch::TPtr> FallbackHandler(size_t idx, NYT::TSharedRef payload) {
+ if (!payload.Size()) {
+ return {};
+ }
+ // We have only one MKQL fallback reader; protect it in case two fallbacks run at the same time
+ std::lock_guard _(FallbackMtx_);
+ auto currentReader_ = std::make_shared<TPayloadRPCReader>(std::move(payload));
+
+ // TODO(): save and recover row indexes
+ FallbackReader_.SetReader(*currentReader_, 1, 4_MB, ui32(Settings_->OriginalIndexes[idx]), true);
+ // If we don't keep the reader alive here, it will be destroyed as soon as FallbackHandler returns,
+ // while FallbackReader still points to it.
+ Reader_ = currentReader_;
+ FallbackReader_.Next();
+ while (FallbackReader_.IsValid()) {
+ auto currentRow = std::move(FallbackReader_.GetRow());
+ if (!Settings_->Specs->InputGroups.empty()) {
+ currentRow = std::move(HolderFactory_.CreateVariantHolder(currentRow.Release(), Settings_->Specs->InputGroups.at(Settings_->OriginalIndexes[idx])));
+ }
+ BlockBuilder_.Add(currentRow);
+ FallbackReader_.Next();
+ }
+ return BlockBuilder_.Build();
+ }
+
+ void Finish() {
+ FallbackReader_.Finish();
+ Pool_->Shutdown();
+ for (auto& e: LocalListeners_) {
+ e->Finish();
+ }
+ Self_ = nullptr;
+ }
+
+ void SetSelfAndRun(TPtr self) {
+ MainInvoker_ = NYT::GetCurrentInvoker();
+ Self_ = self;
+ Pool_ = NYT::NConcurrency::CreateThreadPool(Inflight_, "rpc_reader_inflight");
+ // Run Inflight_ reads at the same time
+ for (size_t i = 0; i < Inflight_; ++i) {
+ RunRead();
+ }
+ }
+
+private:
+ NYT::IInvoker* MainInvoker_;
+ NYT::NConcurrency::IThreadPoolPtr Pool_;
+ std::mutex Mtx_;
+ std::mutex FallbackMtx_;
+ std::unique_ptr<TSettingsHolder> Settings_;
+ std::vector<std::shared_ptr<TLocalListener>> LocalListeners_;
+ std::vector<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr> Inputs_;
+ TMkqlReaderImpl FallbackReader_;
+ TBlockBuilder BlockBuilder_;
+ std::shared_ptr<TPayloadRPCReader> Reader_;
+ std::queue<size_t> InputsQueue_;
+ TListener::TPtr Listener_;
+ TPtr Self_;
+ size_t Inflight_;
+ const THolderFactory& HolderFactory_;
+};
+
+class TState: public TComputationValue<TState> {
+ using TBase = TComputationValue<TState>;
+public:
+ TState(TMemoryUsageInfo* memInfo, TSource::TPtr source, size_t width, TType* type)
+ : TBase(memInfo)
+ , Source_(std::move(source))
+ , Width_(width)
+ , Type_(type)
+ , Types_(width)
+ {
+ for (size_t i = 0; i < Width_; ++i) {
+ Types_[i] = NArrow::GetArrowType(AS_TYPE(TStructType, Type_)->GetMemberType(i));
+ }
+ }
+
+ EFetchResult FetchValues(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) {
+ auto batch = Source_->Next();
+ if (!batch) {
+ Source_->Finish();
+ return EFetchResult::Finish;
+ }
+ for (size_t i = 0; i < Width_; ++i) {
+ if (!output[i]) {
+ continue;
+ }
+ if(!batch->Columns[i].type()->Equals(Types_[i])) {
+ *(output[i]) = ctx.HolderFactory.CreateArrowBlock(ARROW_RESULT(arrow::compute::Cast(batch->Columns[i], Types_[i])));
+ continue;
+ }
+ *(output[i]) = ctx.HolderFactory.CreateArrowBlock(std::move(batch->Columns[i]));
+ }
+ if (output[Width_]) {
+ *(output[Width_]) = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(ui64(batch->RowsCnt)));
+ }
+ return EFetchResult::One;
+ }
+
+private:
+ TSource::TPtr Source_;
+ const size_t Width_;
+ TType* Type_;
+ std::vector<std::shared_ptr<arrow::DataType>> Types_;
+};
+};
+
+class TDqYtReadBlockWrapper : public TStatefulWideFlowComputationNode<TDqYtReadBlockWrapper> {
+using TBaseComputation = TStatefulWideFlowComputationNode<TDqYtReadBlockWrapper>;
+public:
+
+ TDqYtReadBlockWrapper(const TComputationNodeFactoryContext& ctx, const TString& clusterName,
+ const TString& token, const NYT::TNode& inputSpec, const NYT::TNode& samplingSpec,
+ const TVector<ui32>& inputGroups,
+ TType* itemType, const TVector<TString>& tableNames, TVector<std::pair<NYT::TRichYPath, NYT::TFormat>>&& tables, NKikimr::NMiniKQL::IStatsRegistry* jobStats, size_t inflight,
+ size_t timeout) : TBaseComputation(ctx.Mutables, this, EValueRepresentation::Boxed)
+ , Width(AS_TYPE(TStructType, itemType)->GetMembersCount())
+ , CodecCtx(ctx.Env, ctx.FunctionRegistry, &ctx.HolderFactory)
+ , ClusterName(clusterName)
+ , Token(token)
+ , SamplingSpec(samplingSpec)
+ , Tables(std::move(tables))
+ , Inflight(inflight)
+ , Timeout(timeout)
+ , Type(itemType)
+ {
+ // TODO() Enable range indexes
+ Specs.SetUseSkiff("", TMkqlIOSpecs::ESystemField::RowIndex);
+ Specs.Init(CodecCtx, inputSpec, inputGroups, tableNames, itemType, {}, {}, jobStats);
+ }
+
+ void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
+ auto settings = CreateInputStreams(true, Token, ClusterName, Timeout, Inflight > 1, Tables, SamplingSpec);
+ settings->Specs = &Specs;
+ settings->Pool = arrow::default_memory_pool();
+ settings->PgBuilder = &ctx.Builder->GetPgBuilder();
+ auto source = std::make_shared<TSource>(std::move(settings), Inflight, Type, ctx.HolderFactory);
+ source->SetSelfAndRun(source);
+ state = ctx.HolderFactory.Create<TState>(source, Width, Type);
+ }
+
+ EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
+ if (!state.HasValue()) {
+ MakeState(ctx, state);
+ }
+ return static_cast<TState&>(*state.AsBoxed()).FetchValues(ctx, output);
+ }
+
+ void RegisterDependencies() const final {}
+
+ const ui32 Width;
+ NCommon::TCodecContext CodecCtx;
+ TMkqlIOSpecs Specs;
+
+ TString ClusterName;
+ TString Token;
+ NYT::TNode SamplingSpec;
+ TVector<std::pair<NYT::TRichYPath, NYT::TFormat>> Tables;
+ size_t Inflight;
+ size_t Timeout;
+ TType* Type;
+};
+
+IComputationNode* CreateDqYtReadBlockWrapper(const TComputationNodeFactoryContext& ctx, const TString& clusterName,
+ const TString& token, const NYT::TNode& inputSpec, const NYT::TNode& samplingSpec,
+ const TVector<ui32>& inputGroups,
+ TType* itemType, const TVector<TString>& tableNames, TVector<std::pair<NYT::TRichYPath, NYT::TFormat>>&& tables, NKikimr::NMiniKQL::IStatsRegistry* jobStats, size_t inflight,
+ size_t timeout)
+{
+ return new TDqYtReadBlockWrapper(ctx, clusterName, token, inputSpec, samplingSpec, inputGroups, itemType, tableNames, std::move(tables), jobStats, inflight, timeout);
+}
+}
\ No newline at end of file
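
A quick map of the new file: TSource opens one RPC stream per input table and keeps up to Inflight_ reads running on a dedicated YT thread pool; each payload either goes to the per-input TLocalListener (an Arrow IPC stream decoder) or, when the proxy responds in a non-Arrow rowset format, through the row-based MKQL fallback reader, whose rows TBlockBuilder re-encodes into blocks. All decoded batches are funnelled into a single TListener, whose TBlockingQueueWithLimit provides backpressure and propagates errors as a poison value, and TState::FetchValues hands the columns to the computation graph, casting them whenever the decoded Arrow type differs from the expected one. For readers unfamiliar with the pattern, here is a minimal standalone sketch of the same bounded-queue-with-poison idea; it is illustrative only and uses std::condition_variable, whereas the class above blocks on YT fiber promises via WaitFor:

#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <string>
#include <variant>

// Illustrative bounded queue with "poison" error propagation, mirroring the
// shape of TBlockingQueueWithLimit above (a sketch, not the production code).
template <typename T>
class TBoundedQueueSketch {
    struct TPoison { std::string Error; };
    using TItem = std::variant<T, TPoison>;
public:
    explicit TBoundedQueueSketch(size_t limit) : Limit_(limit) {}

    void Push(T value) { PushItem(TItem(std::move(value))); }

    // An error poisons the queue; the consumer that dequeues it rethrows.
    void PushPoison(std::string error) { PushItem(TItem(TPoison{std::move(error)})); }

    T Get() {
        std::unique_lock lock(Mtx_);
        NotEmpty_.wait(lock, [this] { return !Items_.empty(); });
        TItem item = std::move(Items_.front());
        Items_.pop();
        NotFull_.notify_one();
        if (auto* poison = std::get_if<TPoison>(&item)) {
            throw std::runtime_error(poison->Error);
        }
        return std::move(std::get<T>(item));
    }

private:
    void PushItem(TItem item) {
        std::unique_lock lock(Mtx_);
        // Producers block once Limit_ items are pending (backpressure).
        NotFull_.wait(lock, [this] { return Items_.size() < Limit_; });
        Items_.push(std::move(item));
        NotEmpty_.notify_one();
    }

    std::mutex Mtx_;
    std::condition_variable NotEmpty_;
    std::condition_variable NotFull_;
    std::queue<TItem> Items_;
    size_t Limit_;
};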
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.h b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.h
new file mode 100644
index 00000000000..0bcce64b114
--- /dev/null
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_block_reader.h
@@ -0,0 +1,17 @@
+#include <ydb/library/yql/minikql/computation/mkql_computation_node.h>
+#include <ydb/library/yql/minikql/mkql_stats_registry.h>
+#include <ydb/library/yql/minikql/mkql_node.h>
+
+#include <ydb/library/yql/providers/yt/comp_nodes/yql_mkql_file_input_state.h>
+#include <ydb/library/yql/providers/yt/codec/yt_codec.h>
+#include <ydb/library/yql/providers/common/codec/yql_codec.h>
+#include <ydb/library/yql/minikql/mkql_node_cast.h>
+#include <ydb/library/yql/minikql/computation/mkql_computation_node_codegen.h>
+
+namespace NYql::NDqs {
+NKikimr::NMiniKQL::IComputationNode* CreateDqYtReadBlockWrapper(const NKikimr::NMiniKQL::TComputationNodeFactoryContext& ctx, const TString& clusterName,
+ const TString& token, const NYT::TNode& inputSpec, const NYT::TNode& samplingSpec,
+ const TVector<ui32>& inputGroups,
+ NKikimr::NMiniKQL::TType* itemType, const TVector<TString>& tableNames, TVector<std::pair<NYT::TRichYPath, NYT::TFormat>>&& tables, NKikimr::NMiniKQL::IStatsRegistry* jobStats, size_t inflight,
+ size_t timeout);
+}
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_factory.cpp b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_factory.cpp
index c02a693f699..4d88adfc6eb 100644
--- a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_factory.cpp
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_factory.cpp
@@ -9,8 +9,8 @@ using namespace NKikimr::NMiniKQL;
TComputationNodeFactory GetDqYtFactory(NKikimr::NMiniKQL::IStatsRegistry* jobStats) {
return [=] (TCallable& callable, const TComputationNodeFactoryContext& ctx) -> IComputationNode* {
TStringBuf name = callable.GetType()->GetName();
- if (name == "DqYtRead") {
- return NDqs::WrapDqYtRead(callable, jobStats, ctx);
+ if (name == "DqYtRead" || name == "DqYtBlockRead") {
+ return NDqs::WrapDqYtRead(callable, jobStats, ctx, name == "DqYtBlockRead");
}
if (name == "YtDqRowsWideWrite") {
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp
index 8461b8e495b..578c5115079 100644
--- a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.cpp
@@ -1,7 +1,7 @@
#include "dq_yt_reader.h"
#include "dq_yt_reader_impl.h"
-
+#include "dq_yt_block_reader.h"
#include "dq_yt_rpc_reader.h"
namespace NYql::NDqs {
@@ -63,7 +63,7 @@ using TInputType = NYT::TRawTableReaderPtr;
}
};
-IComputationNode* WrapDqYtRead(TCallable& callable, NKikimr::NMiniKQL::IStatsRegistry* jobStats, const TComputationNodeFactoryContext& ctx) {
+IComputationNode* WrapDqYtRead(TCallable& callable, NKikimr::NMiniKQL::IStatsRegistry* jobStats, const TComputationNodeFactoryContext& ctx, bool useBlocks) {
MKQL_ENSURE(callable.GetInputsCount() == 8 || callable.GetInputsCount() == 9, "Expected 8 or 9 arguments.");
TString clusterName(AS_VALUE(TDataLiteral, callable.GetInput(0))->AsValue().AsStringRef());
@@ -102,15 +102,23 @@ IComputationNode* WrapDqYtRead(TCallable& callable, NKikimr::NMiniKQL::IStatsReg
#ifdef __linux__
size_t inflight(AS_VALUE(TDataLiteral, callable.GetInput(6))->AsValue().Get<size_t>());
if (inflight) {
- return new TDqYtReadWrapperBase<TDqYtReadWrapperRPC, TParallelFileInputState>(ctx, clusterName, token,
- NYT::NodeFromYsonString(inputSpec), samplingSpec ? NYT::NodeFromYsonString(samplingSpec) : NYT::TNode(),
- inputGroups, static_cast<TType*>(callable.GetInput(5).GetNode()), tableNames, std::move(tables), jobStats, inflight, timeout);
+ if (useBlocks) {
+ return CreateDqYtReadBlockWrapper(ctx, clusterName, token,
+ NYT::NodeFromYsonString(inputSpec), samplingSpec ? NYT::NodeFromYsonString(samplingSpec) : NYT::TNode(),
+ inputGroups, static_cast<TType*>(callable.GetInput(5).GetNode()), tableNames, std::move(tables), jobStats, inflight, timeout);
+ } else {
+ return new TDqYtReadWrapperBase<TDqYtReadWrapperRPC, TParallelFileInputState>(ctx, clusterName, token,
+ NYT::NodeFromYsonString(inputSpec), samplingSpec ? NYT::NodeFromYsonString(samplingSpec) : NYT::TNode(),
+ inputGroups, static_cast<TType*>(callable.GetInput(5).GetNode()), tableNames, std::move(tables), jobStats, inflight, timeout);
+ }
} else {
+ YQL_ENSURE(!useBlocks);
return new TDqYtReadWrapperBase<TDqYtReadWrapperHttp, TFileInputState>(ctx, clusterName, token,
NYT::NodeFromYsonString(inputSpec), samplingSpec ? NYT::NodeFromYsonString(samplingSpec) : NYT::TNode(),
inputGroups, static_cast<TType*>(callable.GetInput(5).GetNode()), tableNames, std::move(tables), jobStats, inflight, timeout);
}
#else
+ YQL_ENSURE(!useBlocks);
return new TDqYtReadWrapperBase<TDqYtReadWrapperHttp, TFileInputState>(ctx, clusterName, token,
NYT::NodeFromYsonString(inputSpec), samplingSpec ? NYT::NodeFromYsonString(samplingSpec) : NYT::TNode(),
inputGroups, static_cast<TType*>(callable.GetInput(5).GetNode()), tableNames, std::move(tables), jobStats, 0, timeout);
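
WrapDqYtRead now serves both callables: the factory passes useBlocks == (name == "DqYtBlockRead"), and the block path is only taken on Linux with a non-zero inflight, i.e. when the RPC reader is in use; the YQL_ENSURE(!useBlocks) guards make the HTTP and non-Linux branches reject block reads explicitly instead of silently falling back.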
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.h b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.h
index 2dd5c2deefd..5f1a471a690 100644
--- a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.h
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_reader.h
@@ -6,6 +6,6 @@
namespace NYql::NDqs {
-NKikimr::NMiniKQL::IComputationNode* WrapDqYtRead(NKikimr::NMiniKQL::TCallable& callable, NKikimr::NMiniKQL::IStatsRegistry* jobStats, const NKikimr::NMiniKQL::TComputationNodeFactoryContext& ctx);
+NKikimr::NMiniKQL::IComputationNode* WrapDqYtRead(NKikimr::NMiniKQL::TCallable& callable, NKikimr::NMiniKQL::IStatsRegistry* jobStats, const NKikimr::NMiniKQL::TComputationNodeFactoryContext& ctx, bool useBlocks);
} // NYql
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.cpp b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.cpp
new file mode 100644
index 00000000000..735bb9ab9b0
--- /dev/null
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.cpp
@@ -0,0 +1,123 @@
+#include "dq_yt_rpc_helpers.h"
+
+#include <ydb/library/yql/providers/yt/lib/yt_rpc_helpers/yt_convert_helpers.h>
+
+namespace NYql::NDqs {
+
+NYT::NYPath::TRichYPath ConvertYPathFromOld(const NYT::TRichYPath& richYPath) {
+ NYT::NYPath::TRichYPath tableYPath(richYPath.Path_);
+ const auto& rngs = richYPath.GetRanges();
+ if (rngs) {
+ TVector<NYT::NChunkClient::TReadRange> ranges;
+ for (const auto& rng: *rngs) {
+ auto& range = ranges.emplace_back();
+ if (rng.LowerLimit_.Offset_) {
+ range.LowerLimit().SetOffset(*rng.LowerLimit_.Offset_);
+ }
+
+ if (rng.LowerLimit_.TabletIndex_) {
+ range.LowerLimit().SetTabletIndex(*rng.LowerLimit_.TabletIndex_);
+ }
+
+ if (rng.LowerLimit_.RowIndex_) {
+ range.LowerLimit().SetRowIndex(*rng.LowerLimit_.RowIndex_);
+ }
+
+ if (rng.UpperLimit_.Offset_) {
+ range.UpperLimit().SetOffset(*rng.UpperLimit_.Offset_);
+ }
+
+ if (rng.UpperLimit_.TabletIndex_) {
+ range.UpperLimit().SetTabletIndex(*rng.UpperLimit_.TabletIndex_);
+ }
+
+ if (rng.UpperLimit_.RowIndex_) {
+ range.UpperLimit().SetRowIndex(*rng.UpperLimit_.RowIndex_);
+ }
+ }
+ tableYPath.SetRanges(std::move(ranges));
+ }
+
+ if (richYPath.Columns_) {
+ tableYPath.SetColumns(richYPath.Columns_->Parts_);
+ }
+
+ return tableYPath;
+}
+
+
+std::unique_ptr<TSettingsHolder> CreateInputStreams(bool isArrow, const TString& token, const TString& clusterName, const ui64 timeout, bool unordered, const TVector<std::pair<NYT::TRichYPath, NYT::TFormat>>& tables, NYT::TNode samplingSpec) {
+ auto connectionConfig = NYT::New<NYT::NApi::NRpcProxy::TConnectionConfig>();
+ connectionConfig->ClusterUrl = clusterName;
+ connectionConfig->DefaultTotalStreamingTimeout = TDuration::MilliSeconds(timeout);
+ auto connection = CreateConnection(connectionConfig);
+ auto clientOptions = NYT::NApi::TClientOptions();
+
+ if (token) {
+ clientOptions.Token = token;
+ }
+
+ auto client = DynamicPointerCast<NYT::NApi::NRpcProxy::TClient>(connection->CreateClient(clientOptions));
+ Y_VERIFY(client);
+ auto apiServiceProxy = client->CreateApiServiceProxy();
+
+ TVector<NYT::TFuture<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr>> waitFor;
+
+ size_t inputIdx = 0;
+ TVector<size_t> originalIndexes;
+
+ for (auto [richYPath, format]: tables) {
+ if (richYPath.GetRanges() && richYPath.GetRanges()->empty()) {
+ ++inputIdx;
+ continue;
+ }
+ originalIndexes.emplace_back(inputIdx);
+
+ auto request = apiServiceProxy.ReadTable();
+ client->InitStreamingRequest(*request);
+ request->ClientAttachmentsStreamingParameters().ReadTimeout = TDuration::MilliSeconds(timeout);
+
+ TString ppath;
+ auto tableYPath = ConvertYPathFromOld(richYPath);
+
+ NYT::NYPath::ToProto(&ppath, tableYPath);
+ request->set_path(ppath);
+ request->set_desired_rowset_format(isArrow ? NYT::NApi::NRpcProxy::NProto::ERowsetFormat::RF_ARROW : NYT::NApi::NRpcProxy::NProto::ERowsetFormat::RF_FORMAT);
+ if (isArrow) {
+ request->set_arrow_fallback_rowset_format(NYT::NApi::NRpcProxy::NProto::ERowsetFormat::RF_FORMAT);
+ }
+
+ request->set_enable_row_index(true);
+ request->set_enable_table_index(true);
+ // TODO() Enable range indexes
+ request->set_enable_range_index(!isArrow);
+
+ request->set_unordered(unordered);
+
+ // https://a.yandex-team.ru/arcadia/yt/yt_proto/yt/client/api/rpc_proxy/proto/api_service.proto?rev=r11519304#L2338
+ if (!samplingSpec.IsUndefined()) {
+ TStringStream ss;
+ samplingSpec.Save(&ss);
+ request->set_config(ss.Str());
+ }
+
+ ConfigureTransaction(request, richYPath);
+
+ // Serialize the skiff format config to a YSON string
+ TStringStream fmt;
+ format.Config.Save(&fmt);
+ request->set_format(fmt.Str());
+
+ waitFor.emplace_back(std::move(CreateRpcClientInputStream(std::move(request)).ApplyUnique(BIND([](NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr&& stream) {
+ // The first packet carries stream metadata only, skip it
+ return stream->Read().ApplyUnique(BIND([stream = std::move(stream)](NYT::TSharedRef&&) {
+ return std::move(stream);
+ }));
+ }))));
+ }
+ TVector<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr> rawInputs;
+ NYT::NConcurrency::WaitFor(NYT::AllSucceeded(waitFor)).ValueOrThrow().swap(rawInputs);
+ return std::make_unique<TSettingsHolder>(std::move(connection), std::move(client), std::move(rawInputs), std::move(originalIndexes));
+}
+
+}
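
These helpers are factored out of dq_yt_rpc_reader.cpp so the existing row-based RPC reader and the new block reader share them: ConvertYPathFromOld, the payload reader (TFakeRPCReader renamed to TPayloadRPCReader) and TSettingsHolder move here, and CreateInputStreams now owns the connection and stream setup for both the Arrow and skiff formats. The following hunks delete the duplicated code from dq_yt_rpc_reader.cpp, with TParallelFileInputState taking the raw inputs and original table indexes from TSettingsHolder instead of separate constructor arguments.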
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.h b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.h
new file mode 100644
index 00000000000..7bab55334a0
--- /dev/null
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_helpers.h
@@ -0,0 +1,67 @@
+#pragma once
+
+#include <ydb/library/yql/providers/yt/comp_nodes/yql_mkql_file_input_state.h>
+#include <yt/cpp/mapreduce/common/helpers.h>
+#include <yt/cpp/mapreduce/interface/client.h>
+#include <yt/cpp/mapreduce/interface/serialize.h>
+#include <yt/cpp/mapreduce/interface/config.h>
+#include <yt/cpp/mapreduce/interface/common.h>
+
+#include <yt/yt/library/auth/auth.h>
+#include <yt/yt/client/api/client.h>
+#include <yt/yt/client/api/rpc_proxy/client_impl.h>
+#include <yt/yt/client/api/rpc_proxy/config.h>
+#include <yt/yt/client/api/rpc_proxy/connection.h>
+#include <yt/yt/client/api/rpc_proxy/row_stream.h>
+
+namespace NYql::NDqs {
+NYT::NYPath::TRichYPath ConvertYPathFromOld(const NYT::TRichYPath& richYPath);
+class TPayloadRPCReader : public NYT::TRawTableReader {
+public:
+ TPayloadRPCReader(NYT::TSharedRef&& payload) : Payload_(std::move(payload)), PayloadStream_(Payload_.Begin(), Payload_.Size()) {}
+
+ bool Retry(const TMaybe<ui32>&, const TMaybe<ui64>&) override {
+ return false;
+ }
+
+ void ResetRetries() override {
+
+ }
+
+ bool HasRangeIndices() const override {
+ return true;
+ };
+
+ size_t DoRead(void* buf, size_t len) override {
+ if (!PayloadStream_.Exhausted()) {
+ return PayloadStream_.Read(buf, len);
+ }
+ return 0;
+ };
+
+ virtual ~TPayloadRPCReader() override {
+ }
+private:
+ NYT::TSharedRef Payload_;
+ TMemoryInput PayloadStream_;
+};
+
+struct TSettingsHolder : public TNonCopyable {
+ TSettingsHolder(NYT::NApi::IConnectionPtr&& connection, NYT::TIntrusivePtr<NYT::NApi::NRpcProxy::TClient>&& client,
+ TVector<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr>&& inputs, TVector<size_t>&& originalIndexes)
+ : Connection(std::move(connection))
+ , Client(std::move(client))
+ , RawInputs(std::move(inputs))
+ , OriginalIndexes(std::move(originalIndexes)) {};
+ NYT::NApi::IConnectionPtr Connection;
+ NYT::TIntrusivePtr<NYT::NApi::NRpcProxy::TClient> Client;
+ const TMkqlIOSpecs* Specs = nullptr;
+ arrow::MemoryPool* Pool = nullptr;
+ const NUdf::IPgBuilder* PgBuilder = nullptr;
+ TVector<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr> RawInputs;
+ TVector<size_t> OriginalIndexes;
+};
+
+std::unique_ptr<TSettingsHolder> CreateInputStreams(bool isArrow, const TString& token, const TString& clusterName, const ui64 timeout, bool unordered, const TVector<std::pair<NYT::TRichYPath, NYT::TFormat>>& tables, NYT::TNode samplingSpec);
+
+};
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.cpp b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.cpp
index d9cd70fa492..2bdd1e41a6c 100644
--- a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.cpp
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.cpp
@@ -1,4 +1,5 @@
#include "dq_yt_rpc_reader.h"
+#include "dq_yt_rpc_helpers.h"
#include "yt/cpp/mapreduce/common/helpers.h"
@@ -11,86 +12,12 @@
#include <yt/yt/client/api/rpc_proxy/connection.h>
#include <yt/yt/client/api/rpc_proxy/row_stream.h>
-#include <ydb/library/yql/providers/yt/lib/yt_rpc_helpers/yt_convert_helpers.h>
-
namespace NYql::NDqs {
using namespace NKikimr::NMiniKQL;
namespace {
TStatKey RPCReaderAwaitingStallTime("Job_RPCReaderAwaitingStallTime", true);
-NYT::NYPath::TRichYPath ConvertYPathFromOld(const NYT::TRichYPath& richYPath) {
- NYT::NYPath::TRichYPath tableYPath(richYPath.Path_);
- const auto& rngs = richYPath.GetRanges();
- if (rngs) {
- TVector<NYT::NChunkClient::TReadRange> ranges;
- for (const auto& rng: *rngs) {
- auto& range = ranges.emplace_back();
- if (rng.LowerLimit_.Offset_) {
- range.LowerLimit().SetOffset(*rng.LowerLimit_.Offset_);
- }
-
- if (rng.LowerLimit_.TabletIndex_) {
- range.LowerLimit().SetTabletIndex(*rng.LowerLimit_.TabletIndex_);
- }
-
- if (rng.LowerLimit_.RowIndex_) {
- range.LowerLimit().SetRowIndex(*rng.LowerLimit_.RowIndex_);
- }
-
- if (rng.UpperLimit_.Offset_) {
- range.UpperLimit().SetOffset(*rng.UpperLimit_.Offset_);
- }
-
- if (rng.UpperLimit_.TabletIndex_) {
- range.UpperLimit().SetTabletIndex(*rng.UpperLimit_.TabletIndex_);
- }
-
- if (rng.UpperLimit_.RowIndex_) {
- range.UpperLimit().SetRowIndex(*rng.UpperLimit_.RowIndex_);
- }
- }
- tableYPath.SetRanges(std::move(ranges));
- }
-
- if (richYPath.Columns_) {
- tableYPath.SetColumns(richYPath.Columns_->Parts_);
- }
-
- return tableYPath;
-}
-
-class TFakeRPCReader : public NYT::TRawTableReader {
-public:
- TFakeRPCReader(NYT::TSharedRef&& payload) : Payload_(std::move(payload)), PayloadStream_(Payload_.Begin(), Payload_.Size()) {}
-
- bool Retry(const TMaybe<ui32>& rangeIndex, const TMaybe<ui64>& rowIndex) override {
- Y_UNUSED(rangeIndex);
- Y_UNUSED(rowIndex);
- return false;
- }
-
- void ResetRetries() override {
-
- }
-
- bool HasRangeIndices() const override {
- return true;
- };
-
- size_t DoRead(void* buf, size_t len) override {
- if (!PayloadStream_.Exhausted()) {
- return PayloadStream_.Read(buf, len);
- }
- return 0;
- };
-
- virtual ~TFakeRPCReader() override {
- }
-private:
- NYT::TSharedRef Payload_;
- TMemoryInput PayloadStream_;
-};
}
#ifdef RPC_PRINT_TIME
int cnt = 0;
@@ -102,28 +29,21 @@ void print_add(int x) {
}
#endif
-struct TSettingsHolder {
- NYT::NApi::IConnectionPtr Connection;
- NYT::TIntrusivePtr<NYT::NApi::NRpcProxy::TClient> Client;
-};
-
TParallelFileInputState::TParallelFileInputState(const TMkqlIOSpecs& spec,
const NKikimr::NMiniKQL::THolderFactory& holderFactory,
- TVector<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr>&& rawInputs,
size_t blockSize,
size_t inflight,
- std::unique_ptr<TSettingsHolder>&& settings,
- TVector<size_t>&& originalIndexes)
- : InnerState_(new TInnerState (rawInputs.size()))
- , StateByReader_(rawInputs.size())
+ std::unique_ptr<TSettingsHolder>&& settings)
+ : InnerState_(new TInnerState (settings->RawInputs.size()))
+ , StateByReader_(settings->RawInputs.size())
, Spec_(&spec)
, HolderFactory_(holderFactory)
- , RawInputs_(std::move(rawInputs))
+ , RawInputs_(std::move(settings->RawInputs))
, BlockSize_(blockSize)
, Inflight_(inflight)
- , Settings_(std::move(settings))
, TimerAwaiting_(RPCReaderAwaitingStallTime, 100)
- , OriginalIndexes_(std::move(originalIndexes))
+ , OriginalIndexes_(std::move(settings->OriginalIndexes))
+ , Settings_(std::move(settings))
{
#ifdef RPC_PRINT_TIME
print_add(1);
@@ -277,79 +197,14 @@ bool TParallelFileInputState::NextValue() {
InnerState_->WaitPromise = NYT::NewPromise<void>();
}
CurrentInput_ = result.Input_;
- CurrentReader_ = MakeIntrusive<TFakeRPCReader>(std::move(result.Value_));
+ CurrentReader_ = MakeIntrusive<TPayloadRPCReader>(std::move(result.Value_));
MkqlReader_.SetReader(*CurrentReader_, 1, BlockSize_, ui32(OriginalIndexes_[CurrentInput_]), true, StateByReader_[CurrentInput_].CurrentRow, StateByReader_[CurrentInput_].CurrentRange);
MkqlReader_.Next();
}
}
void TDqYtReadWrapperRPC::MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
- auto connectionConfig = NYT::New<NYT::NApi::NRpcProxy::TConnectionConfig>();
- connectionConfig->ClusterUrl = ClusterName;
- connectionConfig->DefaultTotalStreamingTimeout = TDuration::MilliSeconds(Timeout);
- auto connection = CreateConnection(connectionConfig);
- auto clientOptions = NYT::NApi::TClientOptions();
-
- if (Token) {
- clientOptions.Token = Token;
- }
-
- auto client = DynamicPointerCast<NYT::NApi::NRpcProxy::TClient>(connection->CreateClient(clientOptions));
- Y_VERIFY(client);
- auto apiServiceProxy = client->CreateApiServiceProxy();
-
- TVector<NYT::TFuture<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr>> waitFor;
- TVector<size_t> originalIndexes;
- size_t inputIdx = 0;
- for (auto [richYPath, format]: Tables) {
- if (richYPath.GetRanges() && richYPath.GetRanges()->empty()) {
- ++inputIdx;
- continue;
- }
-
- originalIndexes.push_back(inputIdx++);
- auto request = apiServiceProxy.ReadTable();
- client->InitStreamingRequest(*request);
- request->ClientAttachmentsStreamingParameters().ReadTimeout = TDuration::MilliSeconds(Timeout);
-
- TString ppath;
- auto tableYPath = ConvertYPathFromOld(richYPath);
-
- NYT::NYPath::ToProto(&ppath, tableYPath);
- request->set_path(ppath);
- request->set_desired_rowset_format(NYT::NApi::NRpcProxy::NProto::ERowsetFormat::RF_FORMAT);
-
- request->set_enable_table_index(true);
- request->set_enable_range_index(true);
- request->set_enable_row_index(true);
- request->set_unordered(Inflight > 1);
-
- // https://a.yandex-team.ru/arcadia/yt/yt_proto/yt/client/api/rpc_proxy/proto/api_service.proto?rev=r11519304#L2338
- if (!SamplingSpec.IsUndefined()) {
- TStringStream ss;
- SamplingSpec.Save(&ss);
- request->set_config(ss.Str());
- }
- ConfigureTransaction(request, richYPath);
-
- // Get skiff format yson string
- TStringStream fmt;
- format.Config.Save(&fmt);
- request->set_format(fmt.Str());
-
- waitFor.emplace_back(std::move(CreateRpcClientInputStream(std::move(request)).ApplyUnique(BIND([](NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr&& stream) {
- // first packet contains meta, skip it
- return stream->Read().ApplyUnique(BIND([stream = std::move(stream)](NYT::TSharedRef&&) {
- return std::move(stream);
- }));
- }))));
- }
-
- TVector<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr> rawInputs;
- NYT::NConcurrency::WaitFor(NYT::AllSucceeded(waitFor)).ValueOrThrow().swap(rawInputs);
- state = ctx.HolderFactory.Create<TDqYtReadWrapperBase<TDqYtReadWrapperRPC, TParallelFileInputState>::TState>(
- Specs, ctx.HolderFactory, std::move(rawInputs), 4_MB, Inflight,
- std::make_unique<TSettingsHolder>(TSettingsHolder{std::move(connection), std::move(client)}), std::move(originalIndexes)
- );
+ auto settings = CreateInputStreams(false, Token, ClusterName, Timeout, Inflight > 1, Tables, SamplingSpec);
+ state = ctx.HolderFactory.Create<TDqYtReadWrapperBase<TDqYtReadWrapperRPC, TParallelFileInputState>::TState>(Specs, ctx.HolderFactory, 4_MB, Inflight, std::move(settings));
}
}
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.h b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.h
index 84d7a31b79a..0219ac56450 100644
--- a/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.h
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/dq_yt_rpc_reader.h
@@ -22,11 +22,9 @@ struct TReaderState {
public:
TParallelFileInputState(const TMkqlIOSpecs& spec,
const NKikimr::NMiniKQL::THolderFactory& holderFactory,
- TVector<NYT::NConcurrency::IAsyncZeroCopyInputStreamPtr>&& rawInputs,
size_t blockSize,
size_t inflight,
- std::unique_ptr<TSettingsHolder>&& client,
- TVector<size_t>&& originalIndexes);
+ std::unique_ptr<TSettingsHolder>&& settings);
size_t GetTableIndex() const;
@@ -71,12 +69,11 @@ private:
size_t CurrentRecord_ = 1;
size_t Inflight_ = 1;
bool Valid_ = true;
- std::unique_ptr<TSettingsHolder> Settings_;
NUdf::TUnboxedValue CurrentValue_;
std::function<void()> OnNextBlockCallback_;
NKikimr::NMiniKQL::TSamplingStatTimer TimerAwaiting_;
TVector<size_t> OriginalIndexes_;
-
+ std::unique_ptr<TSettingsHolder> Settings_;
};
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.cpp b/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.cpp
new file mode 100644
index 00000000000..920575a5b84
--- /dev/null
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.cpp
@@ -0,0 +1,1404 @@
+#include "stream_decoder.h"
+
+#include <algorithm>
+#include <climits>
+#include <cstdint>
+#include <cstring>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include <flatbuffers/flatbuffers.h> // IWYU pragma: export
+
+#include <arrow/array.h>
+#include <arrow/buffer.h>
+#include <arrow/extension_type.h>
+#include <arrow/io/caching.h>
+#include <arrow/io/interfaces.h>
+#include <arrow/io/memory.h>
+#include <arrow/ipc/message.h>
+#include <arrow/ipc/metadata_internal.h>
+#include <arrow/ipc/util.h>
+#include <arrow/ipc/writer.h>
+#include <arrow/record_batch.h>
+#include <arrow/sparse_tensor.h>
+#include <arrow/status.h>
+#include <arrow/type.h>
+#include <arrow/type_traits.h>
+#include <arrow/util/bit_util.h>
+#include <arrow/util/bitmap_ops.h>
+#include <arrow/util/checked_cast.h>
+#include <arrow/util/compression.h>
+#include <arrow/util/endian.h>
+#include <arrow/util/key_value_metadata.h>
+#include <arrow/util/parallel.h>
+#include <arrow/util/string.h>
+#include <arrow/util/thread_pool.h>
+#include <arrow/util/ubsan.h>
+#include <arrow/visitor_inline.h>
+
+#include <generated/File.fbs.h> // IWYU pragma: export
+#include <generated/Message.fbs.h>
+#include <generated/Schema.fbs.h>
+#include <generated/SparseTensor.fbs.h>
+
+namespace arrow {
+
+namespace flatbuf = org::apache::arrow::flatbuf;
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+using internal::GetByteWidth;
+
+namespace ipc {
+
+using internal::FileBlock;
+using internal::kArrowMagicBytes;
+
+namespace NDqs {
+Status MaybeAlignMetadata(std::shared_ptr<Buffer>* metadata) {
+ if (reinterpret_cast<uintptr_t>((*metadata)->data()) % 8 != 0) {
+ ARROW_ASSIGN_OR_RAISE(*metadata, (*metadata)->CopySlice(0, (*metadata)->size()));
+ }
+ return Status::OK();
+}
+
+Status CheckMetadataAndGetBodyLength(const Buffer& metadata, int64_t* body_length) {
+ const flatbuf::Message* fb_message = nullptr;
+ RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &fb_message));
+ *body_length = fb_message->bodyLength();
+ if (*body_length < 0) {
+ return Status::IOError("Invalid IPC message: negative bodyLength");
+ }
+ return Status::OK();
+}
+
+std::string FormatMessageType(MessageType type) {
+ switch (type) {
+ case MessageType::SCHEMA:
+ return "schema";
+ case MessageType::RECORD_BATCH:
+ return "record batch";
+ case MessageType::DICTIONARY_BATCH:
+ return "dictionary";
+ case MessageType::TENSOR:
+ return "tensor";
+ case MessageType::SPARSE_TENSOR:
+ return "sparse tensor";
+ default:
+ break;
+ }
+ return "unknown";
+}
+
+Status WriteMessage(const Buffer& message, const IpcWriteOptions& options,
+ io::OutputStream* file, int32_t* message_length) {
+ const int32_t prefix_size = options.write_legacy_ipc_format ? 4 : 8;
+ const int32_t flatbuffer_size = static_cast<int32_t>(message.size());
+
+ int32_t padded_message_length = static_cast<int32_t>(
+ PaddedLength(flatbuffer_size + prefix_size, options.alignment));
+
+ int32_t padding = padded_message_length - flatbuffer_size - prefix_size;
+
+ *message_length = padded_message_length;
+
+ if (!options.write_legacy_ipc_format) {
+ RETURN_NOT_OK(file->Write(&internal::kIpcContinuationToken, sizeof(int32_t)));
+ }
+
+ int32_t padded_flatbuffer_size =
+ BitUtil::ToLittleEndian(padded_message_length - prefix_size);
+ RETURN_NOT_OK(file->Write(&padded_flatbuffer_size, sizeof(int32_t)));
+
+ RETURN_NOT_OK(file->Write(message.data(), flatbuffer_size));
+ if (padding > 0) {
+ RETURN_NOT_OK(file->Write(kPaddingBytes, padding));
+ }
+
+ return Status::OK();
+}
+
+static constexpr auto kMessageDecoderNextRequiredSizeInitial = sizeof(int32_t);
+static constexpr auto kMessageDecoderNextRequiredSizeMetadataLength = sizeof(int32_t);
+
+class MessageDecoder2::MessageDecoderImpl {
+ public:
+ explicit MessageDecoderImpl(std::shared_ptr<MessageDecoderListener> listener,
+ State initial_state, int64_t initial_next_required_size,
+ MemoryPool* pool)
+ : listener_(std::move(listener)),
+ pool_(pool),
+ state_(initial_state),
+ next_required_size_(initial_next_required_size),
+ save_initial_size_(initial_next_required_size),
+ chunks_(),
+ buffered_size_(0),
+ metadata_(nullptr) {}
+
+ Status ConsumeData(const uint8_t* data, int64_t size) {
+ if (buffered_size_ == 0) {
+ while (size > 0 && size >= next_required_size_) {
+ auto used_size = next_required_size_;
+ switch (state_) {
+ case State::INITIAL:
+ RETURN_NOT_OK(ConsumeInitialData(data, next_required_size_));
+ break;
+ case State::METADATA_LENGTH:
+ RETURN_NOT_OK(ConsumeMetadataLengthData(data, next_required_size_));
+ break;
+ case State::METADATA: {
+ auto buffer = std::make_shared<Buffer>(data, next_required_size_);
+ RETURN_NOT_OK(ConsumeMetadataBuffer(buffer));
+ } break;
+ case State::BODY: {
+ auto buffer = std::make_shared<Buffer>(data, next_required_size_);
+ RETURN_NOT_OK(ConsumeBodyBuffer(buffer));
+ } break;
+ case State::EOS:
+ return Status::OK();
+ }
+ data += used_size;
+ size -= used_size;
+ }
+ }
+
+ if (size == 0) {
+ return Status::OK();
+ }
+
+ chunks_.push_back(std::make_shared<Buffer>(data, size));
+ buffered_size_ += size;
+ return ConsumeChunks();
+ }
+
+ Status ConsumeBuffer(std::shared_ptr<Buffer> buffer) {
+ if (buffered_size_ == 0) {
+ while (buffer->size() >= next_required_size_) {
+ auto used_size = next_required_size_;
+ switch (state_) {
+ case State::INITIAL:
+ RETURN_NOT_OK(ConsumeInitialBuffer(buffer));
+ break;
+ case State::METADATA_LENGTH:
+ RETURN_NOT_OK(ConsumeMetadataLengthBuffer(buffer));
+ break;
+ case State::METADATA:
+ if (buffer->size() == next_required_size_) {
+ return ConsumeMetadataBuffer(buffer);
+ } else {
+ auto sliced_buffer = SliceBuffer(buffer, 0, next_required_size_);
+ RETURN_NOT_OK(ConsumeMetadataBuffer(sliced_buffer));
+ }
+ break;
+ case State::BODY:
+ if (buffer->size() == next_required_size_) {
+ return ConsumeBodyBuffer(buffer);
+ } else {
+ auto sliced_buffer = SliceBuffer(buffer, 0, next_required_size_);
+ RETURN_NOT_OK(ConsumeBodyBuffer(sliced_buffer));
+ }
+ break;
+ case State::EOS:
+ return Status::OK();
+ }
+ if (buffer->size() == used_size) {
+ return Status::OK();
+ }
+ buffer = SliceBuffer(buffer, used_size);
+ }
+ }
+
+ if (buffer->size() == 0) {
+ return Status::OK();
+ }
+
+ buffered_size_ += buffer->size();
+ chunks_.push_back(std::move(buffer));
+ return ConsumeChunks();
+ }
+
+ int64_t next_required_size() const { return next_required_size_ - buffered_size_; }
+
+ MessageDecoder2::State state() const { return state_; }
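+ // Reset() is the local addition this copy exists for: it returns the decoder to the
+ // INITIAL state and drops any buffered chunks so the same instance can decode the next stream.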
+ void Reset() {
+ state_ = State::INITIAL;
+ next_required_size_ = save_initial_size_;
+ chunks_.clear();
+ buffered_size_ = 0;
+ metadata_ = nullptr;
+ }
+
+ private:
+ Status ConsumeChunks() {
+ while (state_ != State::EOS) {
+ if (buffered_size_ < next_required_size_) {
+ return Status::OK();
+ }
+
+ switch (state_) {
+ case State::INITIAL:
+ RETURN_NOT_OK(ConsumeInitialChunks());
+ break;
+ case State::METADATA_LENGTH:
+ RETURN_NOT_OK(ConsumeMetadataLengthChunks());
+ break;
+ case State::METADATA:
+ RETURN_NOT_OK(ConsumeMetadataChunks());
+ break;
+ case State::BODY:
+ RETURN_NOT_OK(ConsumeBodyChunks());
+ break;
+ case State::EOS:
+ return Status::OK();
+ }
+ }
+
+ return Status::OK();
+ }
+
+ Status ConsumeInitialData(const uint8_t* data, int64_t) {
+ return ConsumeInitial(BitUtil::FromLittleEndian(util::SafeLoadAs<int32_t>(data)));
+ }
+
+ Status ConsumeInitialBuffer(const std::shared_ptr<Buffer>& buffer) {
+ ARROW_ASSIGN_OR_RAISE(auto continuation, ConsumeDataBufferInt32(buffer));
+ return ConsumeInitial(BitUtil::FromLittleEndian(continuation));
+ }
+
+ Status ConsumeInitialChunks() {
+ int32_t continuation = 0;
+ RETURN_NOT_OK(ConsumeDataChunks(sizeof(int32_t), &continuation));
+ return ConsumeInitial(BitUtil::FromLittleEndian(continuation));
+ }
+
+ Status ConsumeInitial(int32_t continuation) {
+ if (continuation == internal::kIpcContinuationToken) {
+ state_ = State::METADATA_LENGTH;
+ next_required_size_ = kMessageDecoderNextRequiredSizeMetadataLength;
+ RETURN_NOT_OK(listener_->OnMetadataLength());
+ // Valid IPC message, read the message length now
+ return Status::OK();
+ } else if (continuation == 0) {
+ state_ = State::EOS;
+ next_required_size_ = 0;
+ RETURN_NOT_OK(listener_->OnEOS());
+ return Status::OK();
+ } else if (continuation > 0) {
+ state_ = State::METADATA;
+ next_required_size_ = continuation;
+ RETURN_NOT_OK(listener_->OnMetadata());
+ return Status::OK();
+ } else {
+ return Status::IOError("Invalid IPC stream: negative continuation token");
+ }
+ }
+
+ Status ConsumeMetadataLengthData(const uint8_t* data, int64_t) {
+ return ConsumeMetadataLength(
+ BitUtil::FromLittleEndian(util::SafeLoadAs<int32_t>(data)));
+ }
+
+ Status ConsumeMetadataLengthBuffer(const std::shared_ptr<Buffer>& buffer) {
+ ARROW_ASSIGN_OR_RAISE(auto metadata_length, ConsumeDataBufferInt32(buffer));
+ return ConsumeMetadataLength(BitUtil::FromLittleEndian(metadata_length));
+ }
+
+ Status ConsumeMetadataLengthChunks() {
+ int32_t metadata_length = 0;
+ RETURN_NOT_OK(ConsumeDataChunks(sizeof(int32_t), &metadata_length));
+ return ConsumeMetadataLength(BitUtil::FromLittleEndian(metadata_length));
+ }
+
+ Status ConsumeMetadataLength(int32_t metadata_length) {
+ if (metadata_length == 0) {
+ state_ = State::EOS;
+ next_required_size_ = 0;
+ RETURN_NOT_OK(listener_->OnEOS());
+ return Status::OK();
+ } else if (metadata_length > 0) {
+ state_ = State::METADATA;
+ next_required_size_ = metadata_length;
+ RETURN_NOT_OK(listener_->OnMetadata());
+ return Status::OK();
+ } else {
+ return Status::IOError("Invalid IPC message: negative metadata length");
+ }
+ }
+
+ Status ConsumeMetadataBuffer(const std::shared_ptr<Buffer>& buffer) {
+ if (buffer->is_cpu()) {
+ metadata_ = buffer;
+ } else {
+ ARROW_ASSIGN_OR_RAISE(metadata_,
+ Buffer::ViewOrCopy(buffer, CPUDevice::memory_manager(pool_)));
+ }
+ return ConsumeMetadata();
+ }
+
+ Status ConsumeMetadataChunks() {
+ if (chunks_[0]->size() >= next_required_size_) {
+ if (chunks_[0]->size() == next_required_size_) {
+ if (chunks_[0]->is_cpu()) {
+ metadata_ = std::move(chunks_[0]);
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ metadata_,
+ Buffer::ViewOrCopy(chunks_[0], CPUDevice::memory_manager(pool_)));
+ }
+ chunks_.erase(chunks_.begin());
+ } else {
+ metadata_ = SliceBuffer(chunks_[0], 0, next_required_size_);
+ if (!chunks_[0]->is_cpu()) {
+ ARROW_ASSIGN_OR_RAISE(
+ metadata_, Buffer::ViewOrCopy(metadata_, CPUDevice::memory_manager(pool_)));
+ }
+ chunks_[0] = SliceBuffer(chunks_[0], next_required_size_);
+ }
+ buffered_size_ -= next_required_size_;
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto metadata, AllocateBuffer(next_required_size_, pool_));
+ metadata_ = std::shared_ptr<Buffer>(metadata.release());
+ RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, metadata_->mutable_data()));
+ }
+ return ConsumeMetadata();
+ }
+
+ Status ConsumeMetadata() {
+ RETURN_NOT_OK(MaybeAlignMetadata(&metadata_));
+ int64_t body_length = -1;
+ RETURN_NOT_OK(CheckMetadataAndGetBodyLength(*metadata_, &body_length));
+
+ state_ = State::BODY;
+ next_required_size_ = body_length;
+ RETURN_NOT_OK(listener_->OnBody());
+ if (next_required_size_ == 0) {
+ ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(0, pool_));
+ std::shared_ptr<Buffer> shared_body(body.release());
+ return ConsumeBody(&shared_body);
+ } else {
+ return Status::OK();
+ }
+ }
+
+ Status ConsumeBodyBuffer(std::shared_ptr<Buffer> buffer) {
+ return ConsumeBody(&buffer);
+ }
+
+ Status ConsumeBodyChunks() {
+ if (chunks_[0]->size() >= next_required_size_) {
+ auto used_size = next_required_size_;
+ if (chunks_[0]->size() == next_required_size_) {
+ RETURN_NOT_OK(ConsumeBody(&chunks_[0]));
+ chunks_.erase(chunks_.begin());
+ } else {
+ auto body = SliceBuffer(chunks_[0], 0, next_required_size_);
+ RETURN_NOT_OK(ConsumeBody(&body));
+ chunks_[0] = SliceBuffer(chunks_[0], used_size);
+ }
+ buffered_size_ -= used_size;
+ return Status::OK();
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(next_required_size_, pool_));
+ RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, body->mutable_data()));
+ std::shared_ptr<Buffer> shared_body(body.release());
+ return ConsumeBody(&shared_body);
+ }
+ }
+
+ Status ConsumeBody(std::shared_ptr<Buffer>* buffer) {
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message,
+ Message::Open(metadata_, *buffer));
+
+ RETURN_NOT_OK(listener_->OnMessageDecoded(std::move(message)));
+ state_ = State::INITIAL;
+ next_required_size_ = kMessageDecoderNextRequiredSizeInitial;
+ RETURN_NOT_OK(listener_->OnInitial());
+ return Status::OK();
+ }
+
+ Result<int32_t> ConsumeDataBufferInt32(const std::shared_ptr<Buffer>& buffer) {
+ if (buffer->is_cpu()) {
+ return util::SafeLoadAs<int32_t>(buffer->data());
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto cpu_buffer,
+ Buffer::ViewOrCopy(buffer, CPUDevice::memory_manager(pool_)));
+ return util::SafeLoadAs<int32_t>(cpu_buffer->data());
+ }
+ }
+
+ Status ConsumeDataChunks(int64_t nbytes, void* out) {
+ size_t offset = 0;
+ size_t n_used_chunks = 0;
+ auto required_size = nbytes;
+ std::shared_ptr<Buffer> last_chunk;
+ for (auto& chunk : chunks_) {
+ if (!chunk->is_cpu()) {
+ ARROW_ASSIGN_OR_RAISE(
+ chunk, Buffer::ViewOrCopy(chunk, CPUDevice::memory_manager(pool_)));
+ }
+ auto data = chunk->data();
+ auto data_size = chunk->size();
+ auto copy_size = std::min(required_size, data_size);
+ memcpy(static_cast<uint8_t*>(out) + offset, data, copy_size);
+ n_used_chunks++;
+ offset += copy_size;
+ required_size -= copy_size;
+ if (required_size == 0) {
+ if (data_size != copy_size) {
+ last_chunk = SliceBuffer(chunk, copy_size);
+ }
+ break;
+ }
+ }
+ chunks_.erase(chunks_.begin(), chunks_.begin() + n_used_chunks);
+ if (last_chunk.get() != nullptr) {
+ chunks_.insert(chunks_.begin(), std::move(last_chunk));
+ }
+ buffered_size_ -= offset;
+ return Status::OK();
+ }
+
+ std::shared_ptr<MessageDecoderListener> listener_;
+ MemoryPool* pool_;
+ State state_;
+ int64_t next_required_size_, save_initial_size_;
+ std::vector<std::shared_ptr<Buffer>> chunks_;
+ int64_t buffered_size_;
+ std::shared_ptr<Buffer> metadata_; // Must be CPU buffer
+};
+
+MessageDecoder2::MessageDecoder2(std::shared_ptr<MessageDecoderListener> listener,
+ MemoryPool* pool) {
+ impl_.reset(new MessageDecoderImpl(std::move(listener), State::INITIAL,
+ kMessageDecoderNextRequiredSizeInitial, pool));
+}
+
+MessageDecoder2::MessageDecoder2(std::shared_ptr<MessageDecoderListener> listener,
+ State initial_state, int64_t initial_next_required_size,
+ MemoryPool* pool) {
+ impl_.reset(new MessageDecoderImpl(std::move(listener), initial_state,
+ initial_next_required_size, pool));
+}
+
+MessageDecoder2::~MessageDecoder2() {}
+
+Status MessageDecoder2::Consume(const uint8_t* data, int64_t size) {
+ return impl_->ConsumeData(data, size);
+}
+
+void MessageDecoder2::Reset() {
+ impl_->Reset();
+}
+
+Status MessageDecoder2::Consume(std::shared_ptr<Buffer> buffer) {
+ return impl_->ConsumeBuffer(buffer);
+}
+
+int64_t MessageDecoder2::next_required_size() const { return impl_->next_required_size(); }
+
+MessageDecoder2::State MessageDecoder2::state() const { return impl_->state(); }
+
+enum class DictionaryKind { New, Delta, Replacement };
+
+Status InvalidMessageType(MessageType expected, MessageType actual) {
+ return Status::IOError("Expected IPC message of type ", ::arrow::ipc::FormatMessageType(expected),
+ " but got ", ::arrow::ipc::FormatMessageType(actual));
+}
+
+#define CHECK_MESSAGE_TYPE(expected, actual) \
+ do { \
+ if ((actual) != (expected)) { \
+ return InvalidMessageType((expected), (actual)); \
+ } \
+ } while (0)
+
+#define CHECK_HAS_BODY(message) \
+ do { \
+ if ((message).body() == nullptr) { \
+ return Status::IOError("Expected body in IPC message of type ", \
+ ::arrow::ipc::FormatMessageType((message).type())); \
+ } \
+ } while (0)
+
+#define CHECK_HAS_NO_BODY(message) \
+ do { \
+ if ((message).body_length() != 0) { \
+ return Status::IOError("Unexpected body in IPC message of type ", \
+ ::arrow::ipc::FormatMessageType((message).type())); \
+ } \
+ } while (0)
+struct IpcReadContext {
+ IpcReadContext(DictionaryMemo* memo, const IpcReadOptions& option, bool swap,
+ MetadataVersion version = MetadataVersion::V5,
+ Compression::type kind = Compression::UNCOMPRESSED)
+ : dictionary_memo(memo),
+ options(option),
+ metadata_version(version),
+ compression(kind),
+ swap_endian(swap) {}
+
+ DictionaryMemo* dictionary_memo;
+
+ const IpcReadOptions& options;
+
+ MetadataVersion metadata_version;
+
+ Compression::type compression;
+
+ const bool swap_endian;
+};
+
+
+
+Result<std::shared_ptr<Buffer>> DecompressBuffer(const std::shared_ptr<Buffer>& buf,
+ const IpcReadOptions& options,
+ util::Codec* codec) {
+ if (buf == nullptr || buf->size() == 0) {
+ return buf;
+ }
+
+ if (buf->size() < 8) {
+ return Status::Invalid(
+ "Likely corrupted message, compressed buffers "
+ "are larger than 8 bytes by construction");
+ }
+
+ const uint8_t* data = buf->data();
+ int64_t compressed_size = buf->size() - sizeof(int64_t);
+ int64_t uncompressed_size = BitUtil::FromLittleEndian(util::SafeLoadAs<int64_t>(data));
+
+ ARROW_ASSIGN_OR_RAISE(auto uncompressed,
+ AllocateBuffer(uncompressed_size, options.memory_pool));
+
+ ARROW_ASSIGN_OR_RAISE(
+ int64_t actual_decompressed,
+ codec->Decompress(compressed_size, data + sizeof(int64_t), uncompressed_size,
+ uncompressed->mutable_data()));
+ if (actual_decompressed != uncompressed_size) {
+ return Status::Invalid("Failed to fully decompress buffer, expected ",
+ uncompressed_size, " bytes but decompressed ",
+ actual_decompressed);
+ }
+
+ return std::move(uncompressed);
+}
+
+Status DecompressBuffers(Compression::type compression, const IpcReadOptions& options,
+ ArrayDataVector* fields) {
+ struct BufferAccumulator {
+ using BufferPtrVector = std::vector<std::shared_ptr<Buffer>*>;
+
+ void AppendFrom(const ArrayDataVector& fields) {
+ for (const auto& field : fields) {
+ for (auto& buffer : field->buffers) {
+ buffers_.push_back(&buffer);
+ }
+ AppendFrom(field->child_data);
+ }
+ }
+
+ BufferPtrVector Get(const ArrayDataVector& fields) && {
+ AppendFrom(fields);
+ return std::move(buffers_);
+ }
+
+ BufferPtrVector buffers_;
+ };
+
+ auto buffers = BufferAccumulator{}.Get(*fields);
+
+ std::unique_ptr<util::Codec> codec;
+ ARROW_ASSIGN_OR_RAISE(codec, util::Codec::Create(compression));
+
+ return ::arrow::internal::OptionalParallelFor(
+ options.use_threads, static_cast<int>(buffers.size()), [&](int i) {
+ ARROW_ASSIGN_OR_RAISE(*buffers[i],
+ DecompressBuffer(*buffers[i], options, codec.get()));
+ return Status::OK();
+ });
+}
+class ArrayLoader {
+ public:
+ explicit ArrayLoader(const flatbuf::RecordBatch* metadata,
+ MetadataVersion metadata_version, const IpcReadOptions& options,
+ io::RandomAccessFile* file)
+ : metadata_(metadata),
+ metadata_version_(metadata_version),
+ file_(file),
+ max_recursion_depth_(options.max_recursion_depth) {}
+
+ Status ReadBuffer(int64_t offset, int64_t length, std::shared_ptr<Buffer>* out) {
+ if (skip_io_) {
+ return Status::OK();
+ }
+ if (offset < 0) {
+ return Status::Invalid("Negative offset for reading buffer ", buffer_index_);
+ }
+ if (length < 0) {
+ return Status::Invalid("Negative length for reading buffer ", buffer_index_);
+ }
+ if (!BitUtil::IsMultipleOf8(offset)) {
+ return Status::Invalid("Buffer ", buffer_index_,
+ " did not start on 8-byte aligned offset: ", offset);
+ }
+ return file_->ReadAt(offset, length).Value(out);
+ }
+
+ Status LoadType(const DataType& type) { return VisitTypeInline(type, this); }
+
+ Status Load(const Field* field, ArrayData* out) {
+ if (max_recursion_depth_ <= 0) {
+ return Status::Invalid("Max recursion depth reached");
+ }
+
+ field_ = field;
+ out_ = out;
+ out_->type = field_->type();
+ return LoadType(*field_->type());
+ }
+
+ Status SkipField(const Field* field) {
+ ArrayData dummy;
+ skip_io_ = true;
+ Status status = Load(field, &dummy);
+ skip_io_ = false;
+ return status;
+ }
+
+ Status GetBuffer(int buffer_index, std::shared_ptr<Buffer>* out) {
+ auto buffers = metadata_->buffers();
+ CHECK_FLATBUFFERS_NOT_NULL(buffers, "RecordBatch.buffers");
+ if (buffer_index >= static_cast<int>(buffers->size())) {
+ return Status::IOError("buffer_index out of range.");
+ }
+ const flatbuf::Buffer* buffer = buffers->Get(buffer_index);
+ if (buffer->length() == 0) {
+ // Should never return a null buffer here.
+ // (zero-sized buffer allocations are cheap)
+ return AllocateBuffer(0).Value(out);
+ } else {
+ return ReadBuffer(buffer->offset(), buffer->length(), out);
+ }
+ }
+
+ Status GetFieldMetadata(int field_index, ArrayData* out) {
+ auto nodes = metadata_->nodes();
+ CHECK_FLATBUFFERS_NOT_NULL(nodes, "Table.nodes");
+ if (field_index >= static_cast<int>(nodes->size())) {
+ return Status::Invalid("Ran out of field metadata, likely malformed");
+ }
+ const flatbuf::FieldNode* node = nodes->Get(field_index);
+
+ out->length = node->length();
+ out->null_count = node->null_count();
+ out->offset = 0;
+ return Status::OK();
+ }
+
+ Status LoadCommon(Type::type type_id) {
+ RETURN_NOT_OK(GetFieldMetadata(field_index_++, out_));
+
+ if (internal::HasValidityBitmap(type_id, metadata_version_)) {
+ // Extract null_bitmap which is common to all arrays except for unions
+ // and nulls.
+ if (out_->null_count != 0) {
+ RETURN_NOT_OK(GetBuffer(buffer_index_, &out_->buffers[0]));
+ }
+ buffer_index_++;
+ }
+ return Status::OK();
+ }
+
+ template <typename TYPE>
+ Status LoadPrimitive(Type::type type_id) {
+ out_->buffers.resize(2);
+
+ RETURN_NOT_OK(LoadCommon(type_id));
+ if (out_->length > 0) {
+ RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));
+ } else {
+ buffer_index_++;
+ out_->buffers[1].reset(new Buffer(nullptr, 0));
+ }
+ return Status::OK();
+ }
+
+ template <typename TYPE>
+ Status LoadBinary(Type::type type_id) {
+ out_->buffers.resize(3);
+
+ RETURN_NOT_OK(LoadCommon(type_id));
+ RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));
+ return GetBuffer(buffer_index_++, &out_->buffers[2]);
+ }
+
+ template <typename TYPE>
+ Status LoadList(const TYPE& type) {
+ out_->buffers.resize(2);
+
+ RETURN_NOT_OK(LoadCommon(type.id()));
+ RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));
+
+ const int num_children = type.num_fields();
+ if (num_children != 1) {
+ return Status::Invalid("Wrong number of children: ", num_children);
+ }
+
+ return LoadChildren(type.fields());
+ }
+
+ Status LoadChildren(const std::vector<std::shared_ptr<Field>>& child_fields) {
+ ArrayData* parent = out_;
+
+ parent->child_data.resize(child_fields.size());
+ for (int i = 0; i < static_cast<int>(child_fields.size()); ++i) {
+ parent->child_data[i] = std::make_shared<ArrayData>();
+ --max_recursion_depth_;
+ RETURN_NOT_OK(Load(child_fields[i].get(), parent->child_data[i].get()));
+ ++max_recursion_depth_;
+ }
+ out_ = parent;
+ return Status::OK();
+ }
+
+ Status Visit(const NullType&) {
+ out_->buffers.resize(1);
+
+ return GetFieldMetadata(field_index_++, out_);
+ }
+
+ template <typename T>
+ enable_if_t<std::is_base_of<FixedWidthType, T>::value &&
+ !std::is_base_of<FixedSizeBinaryType, T>::value &&
+ !std::is_base_of<DictionaryType, T>::value,
+ Status>
+ Visit(const T& type) {
+ return LoadPrimitive<T>(type.id());
+ }
+
+ template <typename T>
+ enable_if_base_binary<T, Status> Visit(const T& type) {
+ return LoadBinary<T>(type.id());
+ }
+
+ Status Visit(const FixedSizeBinaryType& type) {
+ out_->buffers.resize(2);
+ RETURN_NOT_OK(LoadCommon(type.id()));
+ return GetBuffer(buffer_index_++, &out_->buffers[1]);
+ }
+
+ template <typename T>
+ enable_if_var_size_list<T, Status> Visit(const T& type) {
+ return LoadList(type);
+ }
+
+ Status Visit(const MapType& type) {
+ RETURN_NOT_OK(LoadList(type));
+ return MapArray::ValidateChildData(out_->child_data);
+ }
+
+ Status Visit(const FixedSizeListType& type) {
+ out_->buffers.resize(1);
+
+ RETURN_NOT_OK(LoadCommon(type.id()));
+
+ const int num_children = type.num_fields();
+ if (num_children != 1) {
+ return Status::Invalid("Wrong number of children: ", num_children);
+ }
+
+ return LoadChildren(type.fields());
+ }
+
+ Status Visit(const StructType& type) {
+ out_->buffers.resize(1);
+ RETURN_NOT_OK(LoadCommon(type.id()));
+ return LoadChildren(type.fields());
+ }
+
+ Status Visit(const UnionType& type) {
+ int n_buffers = type.mode() == UnionMode::SPARSE ? 2 : 3;
+ out_->buffers.resize(n_buffers);
+
+ RETURN_NOT_OK(LoadCommon(type.id()));
+
+ if (out_->null_count != 0 && out_->buffers[0] != nullptr) {
+ return Status::Invalid(
+ "Cannot read pre-1.0.0 Union array with top-level validity bitmap");
+ }
+ out_->buffers[0] = nullptr;
+ out_->null_count = 0;
+
+ if (out_->length > 0) {
+ RETURN_NOT_OK(GetBuffer(buffer_index_, &out_->buffers[1]));
+ if (type.mode() == UnionMode::DENSE) {
+ RETURN_NOT_OK(GetBuffer(buffer_index_ + 1, &out_->buffers[2]));
+ }
+ }
+ buffer_index_ += n_buffers - 1;
+ return LoadChildren(type.fields());
+ }
+
+ Status Visit(const DictionaryType& type) {
+ // out_->dictionary will be filled later in ResolveDictionaries()
+ return LoadType(*type.index_type());
+ }
+
+ Status Visit(const ExtensionType& type) { return LoadType(*type.storage_type()); }
+
+ private:
+ const flatbuf::RecordBatch* metadata_;
+ const MetadataVersion metadata_version_;
+ io::RandomAccessFile* file_;
+ int max_recursion_depth_;
+ int buffer_index_ = 0;
+ int field_index_ = 0;
+ bool skip_io_ = false;
+
+ const Field* field_;
+ ArrayData* out_;
+};
+
+Result<std::shared_ptr<RecordBatch>> LoadRecordBatchSubset(
+ const flatbuf::RecordBatch* metadata, const std::shared_ptr<Schema>& schema,
+ const std::vector<bool>* inclusion_mask, const IpcReadContext& context,
+ io::RandomAccessFile* file) {
+ ArrayLoader loader(metadata, context.metadata_version, context.options, file);
+
+ ArrayDataVector columns(schema->num_fields());
+ ArrayDataVector filtered_columns;
+ FieldVector filtered_fields;
+ std::shared_ptr<Schema> filtered_schema;
+
+ for (int i = 0; i < schema->num_fields(); ++i) {
+ const Field& field = *schema->field(i);
+ if (!inclusion_mask || (*inclusion_mask)[i]) {
+ // Read field
+ auto column = std::make_shared<ArrayData>();
+ RETURN_NOT_OK(loader.Load(&field, column.get()));
+ if (metadata->length() != column->length) {
+ return Status::IOError("Array length did not match record batch length");
+ }
+ columns[i] = std::move(column);
+ if (inclusion_mask) {
+ filtered_columns.push_back(columns[i]);
+ filtered_fields.push_back(schema->field(i));
+ }
+ } else {
+ RETURN_NOT_OK(loader.SkipField(&field));
+ }
+ }
+
+ RETURN_NOT_OK(ResolveDictionaries(columns, *context.dictionary_memo,
+ context.options.memory_pool));
+
+ if (inclusion_mask) {
+ filtered_schema = ::arrow::schema(std::move(filtered_fields), schema->metadata());
+ columns.clear();
+ } else {
+ filtered_schema = schema;
+ filtered_columns = std::move(columns);
+ }
+ if (context.compression != Compression::UNCOMPRESSED) {
+ RETURN_NOT_OK(
+ DecompressBuffers(context.compression, context.options, &filtered_columns));
+ }
+
+ if (context.swap_endian) {
+ for (int i = 0; i < static_cast<int>(filtered_columns.size()); ++i) {
+ ARROW_ASSIGN_OR_RAISE(filtered_columns[i],
+ arrow::internal::SwapEndianArrayData(filtered_columns[i]));
+ }
+ }
+ return RecordBatch::Make(std::move(filtered_schema), metadata->length(),
+ std::move(filtered_columns));
+}
+
+Result<std::shared_ptr<RecordBatch>> LoadRecordBatch(
+ const flatbuf::RecordBatch* metadata, const std::shared_ptr<Schema>& schema,
+ const std::vector<bool>& inclusion_mask, const IpcReadContext& context,
+ io::RandomAccessFile* file) {
+ if (inclusion_mask.size() > 0) {
+ return LoadRecordBatchSubset(metadata, schema, &inclusion_mask, context, file);
+ } else {
+ return LoadRecordBatchSubset(metadata, schema, /*inclusion_mask=*/nullptr, context, file);
+ }
+}
+
+Status GetCompression(const flatbuf::RecordBatch* batch, Compression::type* out) {
+ *out = Compression::UNCOMPRESSED;
+ const flatbuf::BodyCompression* compression = batch->compression();
+ if (compression != nullptr) {
+ if (compression->method() != flatbuf::BodyCompressionMethod::BUFFER) {
+ return Status::Invalid("This library only supports BUFFER compression method");
+ }
+
+ if (compression->codec() == flatbuf::CompressionType::LZ4_FRAME) {
+ *out = Compression::LZ4_FRAME;
+ } else if (compression->codec() == flatbuf::CompressionType::ZSTD) {
+ *out = Compression::ZSTD;
+ } else {
+ return Status::Invalid("Unsupported codec in RecordBatch::compression metadata");
+ }
+ return Status::OK();
+ }
+ return Status::OK();
+}
+
+Status GetCompressionExperimental(const flatbuf::Message* message,
+ Compression::type* out) {
+ *out = Compression::UNCOMPRESSED;
+ if (message->custom_metadata() != nullptr) {
+ std::shared_ptr<KeyValueMetadata> metadata;
+ RETURN_NOT_OK(internal::GetKeyValueMetadata(message->custom_metadata(), &metadata));
+ int index = metadata->FindKey("ARROW:experimental_compression");
+ if (index != -1) {
+ auto name = arrow::internal::AsciiToLower(metadata->value(index));
+ ARROW_ASSIGN_OR_RAISE(*out, util::Codec::GetCompressionType(name));
+ }
+ return internal::CheckCompressionSupported(*out);
+ }
+ return Status::OK();
+}
+
+static Status ReadContiguousPayload(io::InputStream* file,
+ std::unique_ptr<Message>* message) {
+ ARROW_ASSIGN_OR_RAISE(*message, ReadMessage(file));
+ if (*message == nullptr) {
+ return Status::Invalid("Unable to read metadata at offset");
+ }
+ return Status::OK();
+}
+
+Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
+ const std::shared_ptr<Schema>& schema, const DictionaryMemo* dictionary_memo,
+ const IpcReadOptions& options, io::InputStream* file) {
+ std::unique_ptr<Message> message;
+ RETURN_NOT_OK(ReadContiguousPayload(file, &message));
+ CHECK_HAS_BODY(*message);
+ ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
+ return ReadRecordBatch(*message->metadata(), schema, dictionary_memo, options,
+ reader.get());
+}
+
+Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
+ const Message& message, const std::shared_ptr<Schema>& schema,
+ const DictionaryMemo* dictionary_memo, const IpcReadOptions& options) {
+ CHECK_MESSAGE_TYPE(MessageType::RECORD_BATCH, message.type());
+ CHECK_HAS_BODY(message);
+ ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
+ return ReadRecordBatch(*message.metadata(), schema, dictionary_memo, options,
+ reader.get());
+}
+
+Result<std::shared_ptr<RecordBatch>> ReadRecordBatchInternal(
+ const Buffer& metadata, const std::shared_ptr<Schema>& schema,
+ const std::vector<bool>& inclusion_mask, IpcReadContext& context,
+ io::RandomAccessFile* file) {
+ const flatbuf::Message* message = nullptr;
+ RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
+ auto batch = message->header_as_RecordBatch();
+ if (batch == nullptr) {
+ return Status::IOError(
+ "Header-type of flatbuffer-encoded Message is not RecordBatch.");
+ }
+
+ Compression::type compression;
+ RETURN_NOT_OK(GetCompression(batch, &compression));
+ if (context.compression == Compression::UNCOMPRESSED &&
+ message->version() == flatbuf::MetadataVersion::V4) {
+ RETURN_NOT_OK(GetCompressionExperimental(message, &compression));
+ }
+ context.compression = compression;
+ context.metadata_version = internal::GetMetadataVersion(message->version());
+ return LoadRecordBatch(batch, schema, inclusion_mask, context, file);
+}
+
+Status GetInclusionMaskAndOutSchema(const std::shared_ptr<Schema>& full_schema,
+ const std::vector<int>& included_indices,
+ std::vector<bool>* inclusion_mask,
+ std::shared_ptr<Schema>* out_schema) {
+ inclusion_mask->clear();
+ if (included_indices.empty()) {
+ *out_schema = full_schema;
+ return Status::OK();
+ }
+
+ inclusion_mask->resize(full_schema->num_fields(), false);
+
+ auto included_indices_sorted = included_indices;
+ std::sort(included_indices_sorted.begin(), included_indices_sorted.end());
+
+ FieldVector included_fields;
+ for (int i : included_indices_sorted) {
+ if (i < 0 || i >= full_schema->num_fields()) {
+ return Status::Invalid("Out of bounds field index: ", i);
+ }
+
+ if (inclusion_mask->at(i)) continue;
+
+ inclusion_mask->at(i) = true;
+ included_fields.push_back(full_schema->field(i));
+ }
+
+ *out_schema = schema(std::move(included_fields), full_schema->endianness(),
+ full_schema->metadata());
+ return Status::OK();
+}
+
+Status UnpackSchemaMessage(const void* opaque_schema, const IpcReadOptions& options,
+ DictionaryMemo* dictionary_memo,
+ std::shared_ptr<Schema>* schema,
+ std::shared_ptr<Schema>* out_schema,
+ std::vector<bool>* field_inclusion_mask, bool* swap_endian) {
+ RETURN_NOT_OK(internal::GetSchema(opaque_schema, dictionary_memo, schema));
+
+ RETURN_NOT_OK(GetInclusionMaskAndOutSchema(*schema, options.included_fields,
+ field_inclusion_mask, out_schema));
+ *swap_endian = options.ensure_native_endian && !out_schema->get()->is_native_endian();
+ if (*swap_endian) {
+ *schema = schema->get()->WithEndianness(Endianness::Native);
+ *out_schema = out_schema->get()->WithEndianness(Endianness::Native);
+ }
+ return Status::OK();
+}
+
+Status UnpackSchemaMessage(const Message& message, const IpcReadOptions& options,
+ DictionaryMemo* dictionary_memo,
+ std::shared_ptr<Schema>* schema,
+ std::shared_ptr<Schema>* out_schema,
+ std::vector<bool>* field_inclusion_mask, bool* swap_endian) {
+ CHECK_MESSAGE_TYPE(MessageType::SCHEMA, message.type());
+ CHECK_HAS_NO_BODY(message);
+
+ return UnpackSchemaMessage(message.header(), options, dictionary_memo, schema,
+ out_schema, field_inclusion_mask, swap_endian);
+}
+
+Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
+ const Buffer& metadata, const std::shared_ptr<Schema>& schema,
+ const DictionaryMemo* dictionary_memo, const IpcReadOptions& options,
+ io::RandomAccessFile* file) {
+ std::shared_ptr<Schema> out_schema;
+ // Empty means do not use
+ std::vector<bool> inclusion_mask;
+ IpcReadContext context(const_cast<DictionaryMemo*>(dictionary_memo), options, false);
+ RETURN_NOT_OK(GetInclusionMaskAndOutSchema(schema, context.options.included_fields,
+ &inclusion_mask, &out_schema));
+ return ReadRecordBatchInternal(metadata, schema, inclusion_mask, context, file);
+}
+
+Status ReadDictionary(const Buffer& metadata, const IpcReadContext& context,
+ DictionaryKind* kind, io::RandomAccessFile* file) {
+ const flatbuf::Message* message = nullptr;
+ RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
+ const auto dictionary_batch = message->header_as_DictionaryBatch();
+ if (dictionary_batch == nullptr) {
+ return Status::IOError(
+ "Header-type of flatbuffer-encoded Message is not DictionaryBatch.");
+ }
+
+ // The dictionary is embedded in a record batch with a single column
+ const auto batch_meta = dictionary_batch->data();
+
+ CHECK_FLATBUFFERS_NOT_NULL(batch_meta, "DictionaryBatch.data");
+
+ Compression::type compression;
+ RETURN_NOT_OK(GetCompression(batch_meta, &compression));
+ if (compression == Compression::UNCOMPRESSED &&
+ message->version() == flatbuf::MetadataVersion::V4) {
+ RETURN_NOT_OK(GetCompressionExperimental(message, &compression));
+ }
+
+ const int64_t id = dictionary_batch->id();
+
+ ARROW_ASSIGN_OR_RAISE(auto value_type, context.dictionary_memo->GetDictionaryType(id));
+
+ ArrayLoader loader(batch_meta, internal::GetMetadataVersion(message->version()),
+ context.options, file);
+ auto dict_data = std::make_shared<ArrayData>();
+ const Field dummy_field("", value_type);
+ RETURN_NOT_OK(loader.Load(&dummy_field, dict_data.get()));
+
+ if (compression != Compression::UNCOMPRESSED) {
+ ArrayDataVector dict_fields{dict_data};
+ RETURN_NOT_OK(DecompressBuffers(compression, context.options, &dict_fields));
+ }
+
+ if (context.swap_endian) {
+ ARROW_ASSIGN_OR_RAISE(dict_data, ::arrow::internal::SwapEndianArrayData(dict_data));
+ }
+
+ if (dictionary_batch->isDelta()) {
+ if (kind != nullptr) {
+ *kind = DictionaryKind::Delta;
+ }
+ return context.dictionary_memo->AddDictionaryDelta(id, dict_data);
+ }
+ ARROW_ASSIGN_OR_RAISE(bool inserted,
+ context.dictionary_memo->AddOrReplaceDictionary(id, dict_data));
+ if (kind != nullptr) {
+ *kind = inserted ? DictionaryKind::New : DictionaryKind::Replacement;
+ }
+ return Status::OK();
+}
+
+Status ReadDictionary(const Message& message, const IpcReadContext& context,
+ DictionaryKind* kind) {
+ DCHECK_EQ(message.type(), MessageType::DICTIONARY_BATCH);
+ CHECK_HAS_BODY(message);
+ ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
+ return ReadDictionary(*message.metadata(), context, kind, reader.get());
+}
+
+
+class StreamDecoder2::StreamDecoder2Impl : public MessageDecoderListener {
+ private:
+ enum State {
+ SCHEMA,
+ INITIAL_DICTIONARIES,
+ RECORD_BATCHES,
+ EOS,
+ };
+
+ public:
+ explicit StreamDecoder2Impl(std::shared_ptr<Listener> listener, IpcReadOptions options)
+ : listener_(std::move(listener)),
+ options_(std::move(options)),
+ state_(State::SCHEMA),
+ message_decoder_(std::shared_ptr<StreamDecoder2Impl>(this, [](void*) {}),
+ options_.memory_pool),
+ n_required_dictionaries_(0),
+ dictionary_memo_(std::make_unique<DictionaryMemo>()) {}
+
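+ // Clears the decoded schema, dictionaries and field inclusion mask and resets the
+ // underlying message decoder so this instance can consume a fresh IPC stream.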
+ void Reset() {
+ state_ = State::SCHEMA;
+ field_inclusion_mask_.clear();
+ n_required_dictionaries_ = 0;
+ dictionary_memo_ = std::make_unique<DictionaryMemo>();
+ schema_ = out_schema_ = nullptr;
+ message_decoder_.Reset();
+ }
+
+ Status OnMessageDecoded(std::unique_ptr<Message> message) override {
+ ++stats_.num_messages;
+ switch (state_) {
+ case State::SCHEMA:
+ ARROW_RETURN_NOT_OK(OnSchemaMessageDecoded(std::move(message)));
+ break;
+ case State::INITIAL_DICTIONARIES:
+ ARROW_RETURN_NOT_OK(OnInitialDictionaryMessageDecoded(std::move(message)));
+ break;
+ case State::RECORD_BATCHES:
+ ARROW_RETURN_NOT_OK(OnRecordBatchMessageDecoded(std::move(message)));
+ break;
+ case State::EOS:
+ break;
+ }
+ return Status::OK();
+ }
+
+ Status OnEOS() override {
+ state_ = State::EOS;
+ return listener_->OnEOS();
+ }
+ Status Consume(const uint8_t* data, int64_t size) {
+ return message_decoder_.Consume(data, size);
+ }
+
+ Status Consume(std::shared_ptr<Buffer> buffer) {
+ return message_decoder_.Consume(std::move(buffer));
+ }
+
+ std::shared_ptr<Schema> schema() const { return out_schema_; }
+
+ int64_t next_required_size() const { return message_decoder_.next_required_size(); }
+
+ ReadStats stats() const { return stats_; }
+
+ private:
+ Status OnSchemaMessageDecoded(std::unique_ptr<Message> message) {
+ RETURN_NOT_OK(UnpackSchemaMessage(*message, options_, dictionary_memo_.get(), &schema_,
+ &out_schema_, &field_inclusion_mask_,
+ &swap_endian_));
+
+ n_required_dictionaries_ = dictionary_memo_->fields().num_fields();
+ if (n_required_dictionaries_ == 0) {
+ state_ = State::RECORD_BATCHES;
+ RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_));
+ } else {
+ state_ = State::INITIAL_DICTIONARIES;
+ }
+ return Status::OK();
+ }
+
+ Status OnInitialDictionaryMessageDecoded(std::unique_ptr<Message> message) {
+ if (message->type() != MessageType::DICTIONARY_BATCH) {
+ return Status::Invalid("IPC stream did not have the expected number (",
+ dictionary_memo_->fields().num_fields(),
+ ") of dictionaries at the start of the stream");
+ }
+ RETURN_NOT_OK(ReadDictionary(*message));
+ n_required_dictionaries_--;
+ if (n_required_dictionaries_ == 0) {
+ state_ = State::RECORD_BATCHES;
+ ARROW_RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_));
+ }
+ return Status::OK();
+ }
+
+ Status OnRecordBatchMessageDecoded(std::unique_ptr<Message> message) {
+ if (message->type() == MessageType::DICTIONARY_BATCH) {
+ return ReadDictionary(*message);
+ } else {
+ CHECK_HAS_BODY(*message);
+ ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
+ IpcReadContext context(dictionary_memo_.get(), options_, swap_endian_);
+ ARROW_ASSIGN_OR_RAISE(
+ auto batch,
+ ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_,
+ context, reader.get()));
+ ++stats_.num_record_batches;
+ return listener_->OnRecordBatchDecoded(std::move(batch));
+ }
+ }
+
+ Status ReadDictionary(const Message& message) {
+ DictionaryKind kind;
+ IpcReadContext context(dictionary_memo_.get(), options_, swap_endian_);
+ RETURN_NOT_OK(::arrow::ipc::NDqs::ReadDictionary(message, context, &kind));
+ ++stats_.num_dictionary_batches;
+ switch (kind) {
+ case DictionaryKind::New:
+ break;
+ case DictionaryKind::Delta:
+ ++stats_.num_dictionary_deltas;
+ break;
+ case DictionaryKind::Replacement:
+ ++stats_.num_replaced_dictionaries;
+ break;
+ }
+ return Status::OK();
+ }
+
+ std::shared_ptr<Listener> listener_;
+ const IpcReadOptions options_;
+ State state_;
+ MessageDecoder2 message_decoder_;
+ std::vector<bool> field_inclusion_mask_;
+ int n_required_dictionaries_;
+ std::unique_ptr<DictionaryMemo> dictionary_memo_;
+ std::shared_ptr<Schema> schema_, out_schema_;
+ ReadStats stats_;
+ bool swap_endian_;
+};
+
+Result<std::shared_ptr<Schema>> ReadSchema(io::InputStream* stream,
+ DictionaryMemo* dictionary_memo) {
+ std::unique_ptr<MessageReader> reader = MessageReader::Open(stream);
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message, reader->ReadNextMessage());
+ if (!message) {
+ return Status::Invalid("Tried reading schema message, was null or length 0");
+ }
+ CHECK_MESSAGE_TYPE(MessageType::SCHEMA, message->type());
+ return ReadSchema(*message, dictionary_memo);
+}
+
+Result<std::shared_ptr<Schema>> ReadSchema(const Message& message,
+ DictionaryMemo* dictionary_memo) {
+ std::shared_ptr<Schema> result;
+ RETURN_NOT_OK(internal::GetSchema(message.header(), dictionary_memo, &result));
+ return result;
+}
+
+Result<std::shared_ptr<Tensor>> ReadTensor(io::InputStream* file) {
+ std::unique_ptr<Message> message;
+ RETURN_NOT_OK(ReadContiguousPayload(file, &message));
+ return ReadTensor(*message);
+}
+
+Result<std::shared_ptr<Tensor>> ReadTensor(const Message& message) {
+ std::shared_ptr<DataType> type;
+ std::vector<int64_t> shape;
+ std::vector<int64_t> strides;
+ std::vector<std::string> dim_names;
+ CHECK_HAS_BODY(message);
+ RETURN_NOT_OK(internal::GetTensorMetadata(*message.metadata(), &type, &shape, &strides,
+ &dim_names));
+ return Tensor::Make(type, message.body(), shape, strides, dim_names);
+}
+
+
+StreamDecoder2::StreamDecoder2(std::shared_ptr<Listener> listener, IpcReadOptions options) {
+ impl_.reset(new StreamDecoder2::StreamDecoder2Impl(std::move(listener), options));
+}
+
+StreamDecoder2::~StreamDecoder2() {}
+
+Status StreamDecoder2::Consume(const uint8_t* data, int64_t size) {
+ return impl_->Consume(data, size);
+}
+
+void StreamDecoder2::Reset() {
+ impl_->Reset();
+}
+
+Status StreamDecoder2::Consume(std::shared_ptr<Buffer> buffer) {
+ return impl_->Consume(std::move(buffer));
+}
+
+std::shared_ptr<Schema> StreamDecoder2::schema() const { return impl_->schema(); }
+
+int64_t StreamDecoder2::next_required_size() const { return impl_->next_required_size(); }
+
+ReadStats StreamDecoder2::stats() const { return impl_->stats(); }
+
+class InputStreamMessageReader : public MessageReader, public MessageDecoderListener {
+ public:
+ explicit InputStreamMessageReader(io::InputStream* stream)
+ : stream_(stream),
+ owned_stream_(),
+ message_(),
+ decoder_(std::shared_ptr<InputStreamMessageReader>(this, [](void*) {})) {}
+
+ explicit InputStreamMessageReader(const std::shared_ptr<io::InputStream>& owned_stream)
+ : InputStreamMessageReader(owned_stream.get()) {
+ owned_stream_ = owned_stream;
+ }
+
+ ~InputStreamMessageReader() {}
+
+ Status OnMessageDecoded(std::unique_ptr<Message> message) override {
+ message_ = std::move(message);
+ return Status::OK();
+ }
+
+ Result<std::unique_ptr<Message>> ReadNextMessage() override {
+ ARROW_RETURN_NOT_OK(DecodeMessage(&decoder_, stream_));
+ return std::move(message_);
+ }
+
+ private:
+ io::InputStream* stream_;
+ std::shared_ptr<io::InputStream> owned_stream_;
+ std::unique_ptr<Message> message_;
+ MessageDecoder decoder_;
+};
+
+
+std::unique_ptr<MessageReader> MessageReader::Open(io::InputStream* stream) {
+ return std::unique_ptr<MessageReader>(new InputStreamMessageReader(stream));
+}
+
+std::unique_ptr<MessageReader> MessageReader::Open(
+ const std::shared_ptr<io::InputStream>& owned_stream) {
+ return std::unique_ptr<MessageReader>(new InputStreamMessageReader(owned_stream));
+}
+
+} // namespace NDqs
+
+} // namespace ipc
+} // namespace arrow
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.h b/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.h
new file mode 100644
index 00000000000..f5ba1be7fda
--- /dev/null
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/stream_decoder.h
@@ -0,0 +1,85 @@
+// Almost a verbatim copy of Arrow's reader.cc + message.cc, with comments stripped.
+// TODO(): remove once .Reset() is added to the contrib Arrow version.
+#pragma once
+
+#include <arrow/ipc/reader.h>
+#include <arrow/ipc/message.h>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include <arrow/io/caching.h>
+#include <arrow/io/type_fwd.h>
+#include <arrow/ipc/message.h>
+#include <arrow/ipc/options.h>
+#include <arrow/record_batch.h>
+#include <arrow/result.h>
+#include <arrow/type_fwd.h>
+#include <arrow/util/async_generator.h>
+#include <arrow/util/macros.h>
+#include <arrow/util/visibility.h>
+
+namespace arrow {
+namespace ipc::NDqs {
+
+class ARROW_EXPORT MessageDecoder2 {
+ public:
+ enum State {
+ INITIAL,
+ METADATA_LENGTH,
+ METADATA,
+ BODY,
+ EOS,
+ };
+
+ explicit MessageDecoder2(std::shared_ptr<MessageDecoderListener> listener,
+ MemoryPool* pool = default_memory_pool());
+
+ MessageDecoder2(std::shared_ptr<MessageDecoderListener> listener, State initial_state,
+ int64_t initial_next_required_size,
+ MemoryPool* pool = default_memory_pool());
+
+ virtual ~MessageDecoder2();
+ Status Consume(const uint8_t* data, int64_t size);
+ Status Consume(std::shared_ptr<Buffer> buffer);
+ int64_t next_required_size() const;
+ State state() const;
+ void Reset();
+
+ private:
+ class MessageDecoderImpl;
+ std::unique_ptr<MessageDecoderImpl> impl_;
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(MessageDecoder2);
+};
+
+class ARROW_EXPORT MessageReader {
+ public:
+ virtual ~MessageReader() = default;
+ static std::unique_ptr<MessageReader> Open(io::InputStream* stream);
+ static std::unique_ptr<MessageReader> Open(
+ const std::shared_ptr<io::InputStream>& owned_stream);
+ virtual Result<std::unique_ptr<Message>> ReadNextMessage() = 0;
+};
+
+class ARROW_EXPORT StreamDecoder2 {
+ public:
+ StreamDecoder2(std::shared_ptr<Listener> listener,
+ IpcReadOptions options = IpcReadOptions::Defaults());
+
+ virtual ~StreamDecoder2();
+ Status Consume(const uint8_t* data, int64_t size);
+ Status Consume(std::shared_ptr<Buffer> buffer);
+ std::shared_ptr<Schema> schema() const;
+ int64_t next_required_size() const;
+ ReadStats stats() const;
+ void Reset();
+
+ private:
+ class StreamDecoder2Impl;
+ std::unique_ptr<StreamDecoder2Impl> impl_;
+};
+} // namespace ipc::NDqs
+} // namespace arrow
diff --git a/ydb/library/yql/providers/yt/comp_nodes/dq/ya.make b/ydb/library/yql/providers/yt/comp_nodes/dq/ya.make
index 6f96e18fe33..48ae5d6ea10 100644
--- a/ydb/library/yql/providers/yt/comp_nodes/dq/ya.make
+++ b/ydb/library/yql/providers/yt/comp_nodes/dq/ya.make
@@ -5,10 +5,18 @@ PEERDIR(
ydb/library/yql/providers/yt/comp_nodes
ydb/library/yql/providers/yt/codec
ydb/library/yql/providers/common/codec
+ ydb/core/formats/arrow
yt/cpp/mapreduce/interface
yt/cpp/mapreduce/common
library/cpp/yson/node
yt/yt/core
+ ydb/library/yql/public/udf/arrow
+ contrib/libs/apache/arrow
+ contrib/libs/flatbuffers
+)
+
+ADDINCL(
+ contrib/libs/flatbuffers/include
)
IF(LINUX)
@@ -19,7 +27,13 @@ IF(LINUX)
)
SRCS(
+ stream_decoder.cpp
dq_yt_rpc_reader.cpp
+ dq_yt_rpc_helpers.cpp
+ dq_yt_block_reader.cpp
+ )
+ CFLAGS(
+ -Wno-unused-parameter
)
ENDIF()
@@ -29,7 +43,6 @@ SRCS(
dq_yt_writer.cpp
)
-
YQL_LAST_ABI_VERSION()
diff --git a/ydb/library/yql/providers/yt/mkql_dq/yql_yt_dq_transform.cpp b/ydb/library/yql/providers/yt/mkql_dq/yql_yt_dq_transform.cpp
index 41d652354cf..a29a5878630 100644
--- a/ydb/library/yql/providers/yt/mkql_dq/yql_yt_dq_transform.cpp
+++ b/ydb/library/yql/providers/yt/mkql_dq/yql_yt_dq_transform.cpp
@@ -31,7 +31,7 @@ public:
}
NMiniKQL::TCallableVisitFunc operator()(NMiniKQL::TInternName name) {
- if (TaskParams.contains("yt") && name == "DqYtRead") {
+ if (TaskParams.contains("yt") && (name == "DqYtRead" || name == "DqYtBlockRead")) {
return [this](NMiniKQL::TCallable& callable, const NMiniKQL::TTypeEnvironment& env) {
using namespace NMiniKQL;
diff --git a/ydb/library/yql/providers/yt/provider/yql_yt_dq_integration.cpp b/ydb/library/yql/providers/yt/provider/yql_yt_dq_integration.cpp
index 8f54076f94b..c83a855c270 100644
--- a/ydb/library/yql/providers/yt/provider/yql_yt_dq_integration.cpp
+++ b/ydb/library/yql/providers/yt/provider/yql_yt_dq_integration.cpp
@@ -344,6 +344,33 @@ public:
return false;
}
+ bool CanBlockRead(const NNodes::TExprBase& node, TExprContext&, TTypeAnnotationContext&) override {
+ auto wrap = node.Cast<TDqReadWideWrap>();
+ auto maybeRead = wrap.Input().Maybe<TYtReadTable>();
+ if (!maybeRead) {
+ return false;
+ }
+
+
+ if (!State_->Configuration->UseRPCReaderInDQ.Get(maybeRead.Cast().DataSource().Cluster().StringValue()).GetOrElse(DEFAULT_USE_RPC_READER_IN_DQ)) {
+ return false;
+ }
+
+ const auto structType = GetSeqItemType(maybeRead.Raw()->GetTypeAnn()->Cast<TTupleExprType>()->GetItems().back())->Cast<TStructExprType>();
+ if (!CanBlockReadTypes(structType)) {
+ return false;
+ }
+
+ const TYtSectionList& sectionList = wrap.Input().Cast<TYtReadTable>().Input();
+ for (size_t i = 0; i < sectionList.Size(); ++i) {
+ auto section = sectionList.Item(i);
+ if (!NYql::GetSettingAsColumnList(section.Settings().Ref(), EYtSettingType::SysColumns).empty()) {
+ return false;
+ }
+ }
+ return true;
+ }
+
TMaybe<TOptimizerStatistics> ReadStatistics(const TExprNode::TPtr& read, TExprContext& ctx) override {
Y_UNUSED(ctx);
TOptimizerStatistics stat(0, 0);
diff --git a/ydb/library/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp b/ydb/library/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp
index 3a5c0733650..e4d97ea43be 100644
--- a/ydb/library/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp
+++ b/ydb/library/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp
@@ -288,7 +288,8 @@ TRuntimeNode BuildDqYtInputCall(
const TYtState::TPtr& state,
NCommon::TMkqlBuildContext& ctx,
size_t inflight,
- size_t timeout)
+ size_t timeout,
+ bool enableBlockReader)
{
NYT::TNode specNode = NYT::TNode::CreateMap();
NYT::TNode& tablesNode = specNode[YqlIOSpecTables];
@@ -363,12 +364,12 @@ TRuntimeNode BuildDqYtInputCall(
auto res = uniqSpecs.emplace(NYT::NodeToCanonicalYsonString(specNode), refName);
if (res.second) {
registryNode[refName] = specNode;
- }
- else {
+ } else {
refName = res.first->second;
}
tablesNode.Add(refName);
- auto skiffNode = SingleTableSpecToInputSkiff(specNode, structColumns, true, true, false);
+ // TODO: Enable range indexes
+ auto skiffNode = SingleTableSpecToInputSkiff(specNode, structColumns, true, !enableBlockReader, false);
const auto tmpFolder = GetTablesTmpFolder(*state->Configuration);
auto tableName = pathInfo.Table->Name;
if (pathInfo.Table->IsAnonymous && !TYtTableInfo::HasSubstAnonymousLabel(pathInfo.Table->FromNode.Cast())) {
@@ -402,7 +403,7 @@ TRuntimeNode BuildDqYtInputCall(
auto server = state->Gateway->GetClusterServer(clusterName);
YQL_ENSURE(server, "Invalid YT cluster: " << clusterName);
- TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), "DqYtRead", outputType);
+ TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), enableBlockReader ? "DqYtBlockRead" : "DqYtRead", outputType);
call.Add(ctx.ProgramBuilder.NewDataLiteral<NUdf::EDataSlot::String>(server));
call.Add(ctx.ProgramBuilder.NewDataLiteral<NUdf::EDataSlot::String>(tokenName));
@@ -475,6 +476,35 @@ void RegisterYtMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler) {
void RegisterDqYtMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler, const TYtState::TPtr& state) {
+ compiler.ChainCallable(TDqReadBlockWideWrap::CallableName(),
+ [state](const TExprNode& node, NCommon::TMkqlBuildContext& ctx) {
+ if (const auto& wrapper = TDqReadBlockWideWrap(&node); wrapper.Input().Maybe<TYtReadTable>().IsValid()) {
+ const auto ytRead = wrapper.Input().Cast<TYtReadTable>();
+ const auto readType = ytRead.Ref().GetTypeAnn()->Cast<TTupleExprType>()->GetItems().back();
+ const auto inputItemType = NCommon::BuildType(wrapper.Input().Ref(), GetSeqItemType(*readType), ctx.ProgramBuilder);
+ const auto cluster = ytRead.DataSource().Cluster().StringValue();
+ size_t inflight = state->Configuration->UseRPCReaderInDQ.Get(cluster).GetOrElse(DEFAULT_USE_RPC_READER_IN_DQ) ? state->Configuration->DQRPCReaderInflight.Get(cluster).GetOrElse(DEFAULT_RPC_READER_INFLIGHT) : 0;
+ size_t timeout = state->Configuration->DQRPCReaderTimeout.Get(cluster).GetOrElse(DEFAULT_RPC_READER_TIMEOUT).MilliSeconds();
+ const auto outputType = NCommon::BuildType(wrapper.Ref(), *wrapper.Ref().GetTypeAnn(), ctx.ProgramBuilder);
+ TString tokenName;
+ if (auto secureParams = wrapper.Token()) {
+ tokenName = secureParams.Cast().Name().StringValue();
+ }
+
+ bool solid = false;
+ for (const auto& flag : wrapper.Flags())
+ if ((solid = flag.Value() == "Solid"))
+ break;
+
+ if (solid)
+ return BuildDqYtInputCall<false>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, inflight != 0);
+ else
+ return BuildDqYtInputCall<true>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, inflight != 0);
+ }
+
+ return TRuntimeNode();
+ });
+
compiler.ChainCallable(TDqReadWideWrap::CallableName(),
[state](const TExprNode& node, NCommon::TMkqlBuildContext& ctx) {
if (const auto& wrapper = TDqReadWideWrap(&node); wrapper.Input().Maybe<TYtReadTable>().IsValid()) {
@@ -496,9 +526,9 @@ void RegisterDqYtMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler, con
break;
if (solid)
- return BuildDqYtInputCall<false>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, isRPC, timeout);
+ return BuildDqYtInputCall<false>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, isRPC, timeout, false);
else
- return BuildDqYtInputCall<true>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, isRPC, timeout);
+ return BuildDqYtInputCall<true>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, isRPC, timeout, false);
}
return TRuntimeNode();
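Distilling the wiring above: the path registered for TDqReadBlockWideWrap reuses the RPC reader settings, so a DqYtBlockRead callable is emitted only when the per-cluster RPC reader is enabled (non-zero inflight), while the pre-existing TDqReadWideWrap path now passes enableBlockReader = false explicitly. A minimal restatement of that decision, with the names taken from the hunks above and cfg standing in for state->Configuration:

    // Sketch: how the callable choice falls out of the settings used above.
    bool useRpc = cfg->UseRPCReaderInDQ.Get(cluster).GetOrElse(DEFAULT_USE_RPC_READER_IN_DQ);
    size_t inflight = useRpc
        ? cfg->DQRPCReaderInflight.Get(cluster).GetOrElse(DEFAULT_RPC_READER_INFLIGHT)
        : 0;
    bool enableBlockReader = inflight != 0;   // only reachable from the TDqReadBlockWideWrap chain
    const char* callableName = enableBlockReader ? "DqYtBlockRead" : "DqYtRead";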