diff options
author | aozeritsky <aozeritsky@ydb.tech> | 2023-11-22 17:23:26 +0300 |
---|---|---|
committer | aozeritsky <aozeritsky@ydb.tech> | 2023-11-22 20:38:10 +0300 |
commit | 39f63fbcf0250f7db43e8c6b440ff2acae9d56cc (patch) | |
tree | 36f03ea845666bcc35d27f6291fea7cb43e2e6b8 | |
parent | 2a8db7ea366cd6757bd39bebbda36ef00cc19844 (diff) | |
download | ydb-39f63fbcf0250f7db43e8c6b440ff2acae9d56cc.tar.gz |
Fix race: create map keys in constructor
4 files changed, 45 insertions, 39 deletions
diff --git a/ydb/library/yql/providers/dq/task_runner/tasks_runner_local.cpp b/ydb/library/yql/providers/dq/task_runner/tasks_runner_local.cpp index 1188231a1b..bb1bb560eb 100644 --- a/ydb/library/yql/providers/dq/task_runner/tasks_runner_local.cpp +++ b/ydb/library/yql/providers/dq/task_runner/tasks_runner_local.cpp @@ -141,8 +141,8 @@ public: return new TLocalOutputChannel(Runner->GetOutputChannel(channelId), Task.GetId(), Task.GetStageId(), &QueryStat); } - IDqAsyncInputBuffer::TPtr GetSource(ui64 index) override { - return Runner->GetSource(index); + IDqAsyncInputBuffer* GetSource(ui64 index) override { + return Runner->GetSource(index).Get(); } IDqAsyncOutputBuffer::TPtr GetSink(ui64 index) override { diff --git a/ydb/library/yql/providers/dq/task_runner/tasks_runner_pipe.cpp b/ydb/library/yql/providers/dq/task_runner/tasks_runner_pipe.cpp index 35f36fc55a..2df5b962bd 100644 --- a/ydb/library/yql/providers/dq/task_runner/tasks_runner_pipe.cpp +++ b/ydb/library/yql/providers/dq/task_runner/tasks_runner_pipe.cpp @@ -538,16 +538,17 @@ public: , ChannelId(channelId) , Input(input) , Output(output) - , ProtocolVersion(taskRunner->GetProtocolVersion()) + , TaskRunner(taskRunner) , FreeSpace(channelBufferSize) { } i64 GetFreeSpace() override { - if (ProtocolVersion <= 1) { + int protocolVersion = TaskRunner->GetProtocolVersion(); + if (protocolVersion <= 1) { return std::numeric_limits<i64>::max(); } - if (ProtocolVersion < 6) { + if (protocolVersion < 6) { NDqProto::TCommandHeader header; header.SetVersion(2); header.SetCommand(NDqProto::TCommandHeader::GET_FREE_SPACE); @@ -584,7 +585,7 @@ public: written = countingOutput.Counter(); } - if (ProtocolVersion >= 6) { + if (TaskRunner->GetProtocolVersion() >= 6) { // estimate free space FreeSpace -= written; } @@ -606,7 +607,8 @@ private: IInputStream& Input; IOutputStream& Output; - i32 ProtocolVersion; + ITaskRunner::TPtr TaskRunner; + i64 FreeSpace; }; @@ -740,13 +742,12 @@ public: , InputType(inputType) , BufferSize(channelBufferSize) , FreeSpace(channelBufferSize) - , ProtocolVersion(TaskRunner->GetProtocolVersion()) { PushStats.InputIndex = inputIndex; } i64 GetFreeSpace() const override { - if (ProtocolVersion < 6) { + if (TaskRunner->GetProtocolVersion() < 6) { NDqProto::TCommandHeader header; header.SetVersion(4); header.SetCommand(NDqProto::TCommandHeader::GET_FREE_SPACE_SOURCE); @@ -767,7 +768,7 @@ public: } ui64 GetStoredBytes() const override { - if (ProtocolVersion < 6) { + if (TaskRunner->GetProtocolVersion() < 6) { NDqProto::TCommandHeader header; header.SetVersion(4); header.SetCommand(NDqProto::TCommandHeader::GET_STORED_BYTES_SOURCE); @@ -813,7 +814,7 @@ public: SaveRopeToPipe(Output, serialized.Payload); } - if (ProtocolVersion >= 6) { + if (TaskRunner->GetProtocolVersion() >= 6) { FreeSpace -= space; } } @@ -884,7 +885,6 @@ private: TDqInputStats PopStats; i64 BufferSize; i64 FreeSpace; - i32 ProtocolVersion; }; /*______________________________________________________________________________________________*/ @@ -1284,6 +1284,7 @@ public: Alloc->Release(); StderrReader->Start(); InitTaskMeta(); + InitChannels(); } ~TTaskRunner() { @@ -1354,35 +1355,31 @@ public: response.Load(&Input); if (GetProtocolVersion() >= 6) { for (auto& space : response.GetChannelFreeSpace()) { - auto* channel = static_cast<TInputChannel*>(GetInputChannel(space.GetId()).Get()); - channel->SetFreeSpace(space.GetSpace()); + auto channel = InputChannels.find(space.GetId()); YQL_ENSURE(channel != InputChannels.end()); + channel->second->SetFreeSpace(space.GetSpace()); } for (auto& space : response.GetSourceFreeSpace()) { - auto* source = static_cast<TDqSource*>(GetSource(space.GetId()).Get()); - source->SetFreeSpace(space.GetSpace()); + auto source = Sources.find(space.GetId()); YQL_ENSURE(source != Sources.end()); + source->second->SetFreeSpace(space.GetSpace()); } } return response; } IInputChannel::TPtr GetInputChannel(ui64 channelId) override { - auto& channel = InputChannels[channelId]; - if (channel == nullptr) { - channel = new TInputChannel(this, Task.GetId(), channelId, Input, Output, ChannelBufferSize); - } - return channel; + auto channel = InputChannels.find(channelId); + YQL_ENSURE(channel != InputChannels.end()); + return channel->second; } IOutputChannel::TPtr GetOutputChannel(ui64 channelId) override { return new TOutputChannel(Task.GetId(), channelId, Input, Output); } - IDqAsyncInputBuffer::TPtr GetSource(ui64 index) override { - auto& source = Sources[index]; - if (source == nullptr) { - source = new TDqSource(Task.GetId(), index, InputTypes.at(index), ChannelBufferSize, this); - } - return source; + IDqAsyncInputBuffer* GetSource(ui64 index) override { + auto source = Sources.find(index); + YQL_ENSURE(source != Sources.end()); + return source->second.Get(); } TDqSink::TPtr GetSink(ui64 index) override { @@ -1543,14 +1540,28 @@ private: } } + void InitChannels() { + for (ui32 i = 0; i < Task.InputsSize(); ++i) { + auto& inputDesc = Task.GetInputs(i); + if (inputDesc.HasSource()) { + Sources[i] = new TDqSource(Task.GetId(), i, InputTypes.at(i), ChannelBufferSize, this); + } else { + for (auto& inputChannelDesc : inputDesc.GetChannels()) { + ui64 channelId = inputChannelDesc.GetId(); + InputChannels[channelId] = new TInputChannel(this, Task.GetId(), channelId, Input, Output, ChannelBufferSize); + } + } + } + } + private: const TString TraceId; NDqProto::TDqTask Task; THashMap<TString, TString> SecureParams; THashMap<TString, TString> TaskParams; TVector<TString> ReadRanges; - THashMap<ui64, IInputChannel::TPtr> InputChannels; - THashMap<ui64, IDqAsyncInputBuffer::TPtr> Sources; + THashMap<ui64, TIntrusivePtr<TInputChannel>> InputChannels; + THashMap<ui64, TIntrusivePtr<TDqSource>> Sources; i64 ChannelBufferSize = 0; std::shared_ptr <NKikimr::NMiniKQL::TScopedAlloc> Alloc; @@ -1644,11 +1655,7 @@ public: } IDqAsyncInputBuffer::TPtr GetSource(ui64 inputIndex) override { - auto& source = Sources[inputIndex]; - if (!source) { - source = static_cast<TDqSource*>(Delegate->GetSource(inputIndex).Get()); - } - return source; + return Delegate->GetSource(inputIndex); } IDqOutputChannel::TPtr GetOutputChannel(ui64 channelId) override @@ -1812,7 +1819,6 @@ private: mutable TDqMeteringStats MeteringStats; mutable THashMap<ui64, TIntrusivePtr<TDqInputChannel>> InputChannels; - THashMap<ui64, TIntrusivePtr<TDqSource>> Sources; mutable THashMap<ui64, TIntrusivePtr<TDqOutputChannel>> OutputChannels; THashMap<ui64, TIntrusivePtr<TDqSink>> Sinks; }; diff --git a/ydb/library/yql/providers/dq/task_runner/tasks_runner_proxy.h b/ydb/library/yql/providers/dq/task_runner/tasks_runner_proxy.h index c99c3f197f..011d9e196e 100644 --- a/ydb/library/yql/providers/dq/task_runner/tasks_runner_proxy.h +++ b/ydb/library/yql/providers/dq/task_runner/tasks_runner_proxy.h @@ -54,7 +54,7 @@ public: virtual IInputChannel::TPtr GetInputChannel(ui64 channelId) = 0; virtual IOutputChannel::TPtr GetOutputChannel(ui64 channelId) = 0; - virtual NDq::IDqAsyncInputBuffer::TPtr GetSource(ui64 index) = 0; + virtual NDq::IDqAsyncInputBuffer* GetSource(ui64 index) = 0; virtual NDq::IDqAsyncOutputBuffer::TPtr GetSink(ui64 index) = 0; virtual const THashMap<TString,TString>& GetTaskParams() const = 0; diff --git a/ydb/library/yql/providers/dq/task_runner_actor/task_runner_actor.cpp b/ydb/library/yql/providers/dq/task_runner_actor/task_runner_actor.cpp index c16c670f53..6bb207b8a1 100644 --- a/ydb/library/yql/providers/dq/task_runner_actor/task_runner_actor.cpp +++ b/ydb/library/yql/providers/dq/task_runner_actor/task_runner_actor.cpp @@ -259,14 +259,14 @@ private: YQL_ENSURE(!batch.IsWide()); - auto source = TaskRunner->GetSource(index); + auto* source = TaskRunner->GetSource(index); TDqDataSerializer dataSerializer(TaskRunner->GetTypeEnv(), TaskRunner->GetHolderFactory(), DataTransportVersion); TDqSerializedBatch serialized = dataSerializer.Serialize(batch, source->GetInputType()); - Invoker->Invoke([serialized=std::move(serialized),taskRunner=TaskRunner, actorSystem, selfId, cookie, parentId=ParentId, space, finish, index, settings=Settings, stageId=StageId]() mutable { + Invoker->Invoke([serialized=std::move(serialized), taskRunner=TaskRunner, actorSystem, selfId, cookie, parentId=ParentId, space, finish, index, settings=Settings, stageId=StageId]() mutable { try { // auto guard = taskRunner->BindAllocator(); // only for local mode - auto source = taskRunner->GetSource(index); + auto* source = taskRunner->GetSource(index); source->Push(std::move(serialized), space); if (finish) { source->Finish(); |