aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoraozeritsky <aozeritsky@ydb.tech>2023-11-22 17:23:26 +0300
committeraozeritsky <aozeritsky@ydb.tech>2023-11-22 20:38:10 +0300
commit39f63fbcf0250f7db43e8c6b440ff2acae9d56cc (patch)
tree36f03ea845666bcc35d27f6291fea7cb43e2e6b8
parent2a8db7ea366cd6757bd39bebbda36ef00cc19844 (diff)
downloadydb-39f63fbcf0250f7db43e8c6b440ff2acae9d56cc.tar.gz
Fix race: create map keys in constructor
-rw-r--r--ydb/library/yql/providers/dq/task_runner/tasks_runner_local.cpp4
-rw-r--r--ydb/library/yql/providers/dq/task_runner/tasks_runner_pipe.cpp72
-rw-r--r--ydb/library/yql/providers/dq/task_runner/tasks_runner_proxy.h2
-rw-r--r--ydb/library/yql/providers/dq/task_runner_actor/task_runner_actor.cpp6
4 files changed, 45 insertions, 39 deletions
diff --git a/ydb/library/yql/providers/dq/task_runner/tasks_runner_local.cpp b/ydb/library/yql/providers/dq/task_runner/tasks_runner_local.cpp
index 1188231a1b..bb1bb560eb 100644
--- a/ydb/library/yql/providers/dq/task_runner/tasks_runner_local.cpp
+++ b/ydb/library/yql/providers/dq/task_runner/tasks_runner_local.cpp
@@ -141,8 +141,8 @@ public:
return new TLocalOutputChannel(Runner->GetOutputChannel(channelId), Task.GetId(), Task.GetStageId(), &QueryStat);
}
- IDqAsyncInputBuffer::TPtr GetSource(ui64 index) override {
- return Runner->GetSource(index);
+ IDqAsyncInputBuffer* GetSource(ui64 index) override {
+ return Runner->GetSource(index).Get();
}
IDqAsyncOutputBuffer::TPtr GetSink(ui64 index) override {
diff --git a/ydb/library/yql/providers/dq/task_runner/tasks_runner_pipe.cpp b/ydb/library/yql/providers/dq/task_runner/tasks_runner_pipe.cpp
index 35f36fc55a..2df5b962bd 100644
--- a/ydb/library/yql/providers/dq/task_runner/tasks_runner_pipe.cpp
+++ b/ydb/library/yql/providers/dq/task_runner/tasks_runner_pipe.cpp
@@ -538,16 +538,17 @@ public:
, ChannelId(channelId)
, Input(input)
, Output(output)
- , ProtocolVersion(taskRunner->GetProtocolVersion())
+ , TaskRunner(taskRunner)
, FreeSpace(channelBufferSize)
{ }
i64 GetFreeSpace() override {
- if (ProtocolVersion <= 1) {
+ int protocolVersion = TaskRunner->GetProtocolVersion();
+ if (protocolVersion <= 1) {
return std::numeric_limits<i64>::max();
}
- if (ProtocolVersion < 6) {
+ if (protocolVersion < 6) {
NDqProto::TCommandHeader header;
header.SetVersion(2);
header.SetCommand(NDqProto::TCommandHeader::GET_FREE_SPACE);
@@ -584,7 +585,7 @@ public:
written = countingOutput.Counter();
}
- if (ProtocolVersion >= 6) {
+ if (TaskRunner->GetProtocolVersion() >= 6) {
// estimate free space
FreeSpace -= written;
}
@@ -606,7 +607,8 @@ private:
IInputStream& Input;
IOutputStream& Output;
- i32 ProtocolVersion;
+ ITaskRunner::TPtr TaskRunner;
+
i64 FreeSpace;
};
@@ -740,13 +742,12 @@ public:
, InputType(inputType)
, BufferSize(channelBufferSize)
, FreeSpace(channelBufferSize)
- , ProtocolVersion(TaskRunner->GetProtocolVersion())
{
PushStats.InputIndex = inputIndex;
}
i64 GetFreeSpace() const override {
- if (ProtocolVersion < 6) {
+ if (TaskRunner->GetProtocolVersion() < 6) {
NDqProto::TCommandHeader header;
header.SetVersion(4);
header.SetCommand(NDqProto::TCommandHeader::GET_FREE_SPACE_SOURCE);
@@ -767,7 +768,7 @@ public:
}
ui64 GetStoredBytes() const override {
- if (ProtocolVersion < 6) {
+ if (TaskRunner->GetProtocolVersion() < 6) {
NDqProto::TCommandHeader header;
header.SetVersion(4);
header.SetCommand(NDqProto::TCommandHeader::GET_STORED_BYTES_SOURCE);
@@ -813,7 +814,7 @@ public:
SaveRopeToPipe(Output, serialized.Payload);
}
- if (ProtocolVersion >= 6) {
+ if (TaskRunner->GetProtocolVersion() >= 6) {
FreeSpace -= space;
}
}
@@ -884,7 +885,6 @@ private:
TDqInputStats PopStats;
i64 BufferSize;
i64 FreeSpace;
- i32 ProtocolVersion;
};
/*______________________________________________________________________________________________*/
@@ -1284,6 +1284,7 @@ public:
Alloc->Release();
StderrReader->Start();
InitTaskMeta();
+ InitChannels();
}
~TTaskRunner() {
@@ -1354,35 +1355,31 @@ public:
response.Load(&Input);
if (GetProtocolVersion() >= 6) {
for (auto& space : response.GetChannelFreeSpace()) {
- auto* channel = static_cast<TInputChannel*>(GetInputChannel(space.GetId()).Get());
- channel->SetFreeSpace(space.GetSpace());
+ auto channel = InputChannels.find(space.GetId()); YQL_ENSURE(channel != InputChannels.end());
+ channel->second->SetFreeSpace(space.GetSpace());
}
for (auto& space : response.GetSourceFreeSpace()) {
- auto* source = static_cast<TDqSource*>(GetSource(space.GetId()).Get());
- source->SetFreeSpace(space.GetSpace());
+ auto source = Sources.find(space.GetId()); YQL_ENSURE(source != Sources.end());
+ source->second->SetFreeSpace(space.GetSpace());
}
}
return response;
}
IInputChannel::TPtr GetInputChannel(ui64 channelId) override {
- auto& channel = InputChannels[channelId];
- if (channel == nullptr) {
- channel = new TInputChannel(this, Task.GetId(), channelId, Input, Output, ChannelBufferSize);
- }
- return channel;
+ auto channel = InputChannels.find(channelId);
+ YQL_ENSURE(channel != InputChannels.end());
+ return channel->second;
}
IOutputChannel::TPtr GetOutputChannel(ui64 channelId) override {
return new TOutputChannel(Task.GetId(), channelId, Input, Output);
}
- IDqAsyncInputBuffer::TPtr GetSource(ui64 index) override {
- auto& source = Sources[index];
- if (source == nullptr) {
- source = new TDqSource(Task.GetId(), index, InputTypes.at(index), ChannelBufferSize, this);
- }
- return source;
+ IDqAsyncInputBuffer* GetSource(ui64 index) override {
+ auto source = Sources.find(index);
+ YQL_ENSURE(source != Sources.end());
+ return source->second.Get();
}
TDqSink::TPtr GetSink(ui64 index) override {
@@ -1543,14 +1540,28 @@ private:
}
}
+ void InitChannels() {
+ for (ui32 i = 0; i < Task.InputsSize(); ++i) {
+ auto& inputDesc = Task.GetInputs(i);
+ if (inputDesc.HasSource()) {
+ Sources[i] = new TDqSource(Task.GetId(), i, InputTypes.at(i), ChannelBufferSize, this);
+ } else {
+ for (auto& inputChannelDesc : inputDesc.GetChannels()) {
+ ui64 channelId = inputChannelDesc.GetId();
+ InputChannels[channelId] = new TInputChannel(this, Task.GetId(), channelId, Input, Output, ChannelBufferSize);
+ }
+ }
+ }
+ }
+
private:
const TString TraceId;
NDqProto::TDqTask Task;
THashMap<TString, TString> SecureParams;
THashMap<TString, TString> TaskParams;
TVector<TString> ReadRanges;
- THashMap<ui64, IInputChannel::TPtr> InputChannels;
- THashMap<ui64, IDqAsyncInputBuffer::TPtr> Sources;
+ THashMap<ui64, TIntrusivePtr<TInputChannel>> InputChannels;
+ THashMap<ui64, TIntrusivePtr<TDqSource>> Sources;
i64 ChannelBufferSize = 0;
std::shared_ptr <NKikimr::NMiniKQL::TScopedAlloc> Alloc;
@@ -1644,11 +1655,7 @@ public:
}
IDqAsyncInputBuffer::TPtr GetSource(ui64 inputIndex) override {
- auto& source = Sources[inputIndex];
- if (!source) {
- source = static_cast<TDqSource*>(Delegate->GetSource(inputIndex).Get());
- }
- return source;
+ return Delegate->GetSource(inputIndex);
}
IDqOutputChannel::TPtr GetOutputChannel(ui64 channelId) override
@@ -1812,7 +1819,6 @@ private:
mutable TDqMeteringStats MeteringStats;
mutable THashMap<ui64, TIntrusivePtr<TDqInputChannel>> InputChannels;
- THashMap<ui64, TIntrusivePtr<TDqSource>> Sources;
mutable THashMap<ui64, TIntrusivePtr<TDqOutputChannel>> OutputChannels;
THashMap<ui64, TIntrusivePtr<TDqSink>> Sinks;
};
diff --git a/ydb/library/yql/providers/dq/task_runner/tasks_runner_proxy.h b/ydb/library/yql/providers/dq/task_runner/tasks_runner_proxy.h
index c99c3f197f..011d9e196e 100644
--- a/ydb/library/yql/providers/dq/task_runner/tasks_runner_proxy.h
+++ b/ydb/library/yql/providers/dq/task_runner/tasks_runner_proxy.h
@@ -54,7 +54,7 @@ public:
virtual IInputChannel::TPtr GetInputChannel(ui64 channelId) = 0;
virtual IOutputChannel::TPtr GetOutputChannel(ui64 channelId) = 0;
- virtual NDq::IDqAsyncInputBuffer::TPtr GetSource(ui64 index) = 0;
+ virtual NDq::IDqAsyncInputBuffer* GetSource(ui64 index) = 0;
virtual NDq::IDqAsyncOutputBuffer::TPtr GetSink(ui64 index) = 0;
virtual const THashMap<TString,TString>& GetTaskParams() const = 0;
diff --git a/ydb/library/yql/providers/dq/task_runner_actor/task_runner_actor.cpp b/ydb/library/yql/providers/dq/task_runner_actor/task_runner_actor.cpp
index c16c670f53..6bb207b8a1 100644
--- a/ydb/library/yql/providers/dq/task_runner_actor/task_runner_actor.cpp
+++ b/ydb/library/yql/providers/dq/task_runner_actor/task_runner_actor.cpp
@@ -259,14 +259,14 @@ private:
YQL_ENSURE(!batch.IsWide());
- auto source = TaskRunner->GetSource(index);
+ auto* source = TaskRunner->GetSource(index);
TDqDataSerializer dataSerializer(TaskRunner->GetTypeEnv(), TaskRunner->GetHolderFactory(), DataTransportVersion);
TDqSerializedBatch serialized = dataSerializer.Serialize(batch, source->GetInputType());
- Invoker->Invoke([serialized=std::move(serialized),taskRunner=TaskRunner, actorSystem, selfId, cookie, parentId=ParentId, space, finish, index, settings=Settings, stageId=StageId]() mutable {
+ Invoker->Invoke([serialized=std::move(serialized), taskRunner=TaskRunner, actorSystem, selfId, cookie, parentId=ParentId, space, finish, index, settings=Settings, stageId=StageId]() mutable {
try {
// auto guard = taskRunner->BindAllocator(); // only for local mode
- auto source = taskRunner->GetSource(index);
+ auto* source = taskRunner->GetSource(index);
source->Push(std::move(serialized), space);
if (finish) {
source->Finish();