#include "task_scheduler.h"
#include <util/system/thread.h>
#include <util/string/cast.h>
#include <util/stream/output.h>
// Out-of-line definition of the ITask destructor; it performs no work of its own.
TTaskScheduler::ITask::~ITask() = default;
// Out-of-line definition of the IRepeatedTask destructor; it performs no work of its own.
TTaskScheduler::IRepeatedTask::~IRepeatedTask() = default;
// Worker thread of TTaskScheduler. The thread body does nothing but run the
// owning scheduler's WorkerFunc() for this worker object.
class TTaskScheduler::TWorkerThread
    : public ISimpleThread
{
public:
    // `explicit` prevents a TTaskScheduler reference from being converted
    // into a worker thread implicitly.
    explicit TWorkerThread(TTaskScheduler& state)
        : Scheduler_(state)
    {
    }

    // Diagnostic fields: DebugState is overwritten and printed by
    // TTaskScheduler::ChangeDebugState(); DebugId is assigned by the
    // scheduler constructor.
    TString DebugState = "?";
    TString DebugId = "";

private:
    // Thread entry point: delegates to the scheduler loop and always
    // returns nullptr.
    void* ThreadProc() noexcept override {
        Scheduler_.WorkerFunc(this);
        return nullptr;
    }

private:
    // Scheduler whose queue this worker serves (held by reference, not owned).
    TTaskScheduler& Scheduler_;
};
// Creates `threadCount` worker objects and remembers the task limit.
// The workers are only constructed here; Start() is what launches them.
TTaskScheduler::TTaskScheduler(size_t threadCount, size_t maxTaskCount)
    : MaxTaskCount_(maxTaskCount)
{
    size_t index = 0;
    while (index < threadCount) {
        Workers_.push_back(new TWorkerThread(*this));
        // The worker's ordinal doubles as its id in debug output.
        Workers_.back()->DebugId = ToString(index);
        ++index;
    }
}
// Stops the scheduler on destruction. Stop() is wrapped in try/catch so that
// no exception leaves the destructor; a failure is only reported to Cdbg.
TTaskScheduler::~TTaskScheduler() {
    try {
        Stop();
    } catch (...) {
        // Endl terminates the record so it does not run into the next debug message.
        Cdbg << "task scheduler destruction error: " << CurrentExceptionMessage() << Endl;
    }
}
void TTaskScheduler::Start() {
for (auto& w : Workers_) {
w->Start();
}
}
void TTaskScheduler::Stop() {
with_lock (Lock_) {
IsStopped_ = true;
CondVar_.BroadCast();
}
for (auto& w: Workers_) {
w->Join();
}
Workers_.clear();
Queue_.clear();
}
// Returns the current value of the task counter, which TTaskWrapper
// increments on construction and decrements on destruction.
size_t TTaskScheduler::GetTaskCount() const {
    const auto alive = AtomicGet(TaskCounter_);
    return static_cast<size_t>(alive);
}
namespace {
    // Decorator around a user task that keeps an external atomic counter in
    // step with its own lifetime: +1 when constructed, -1 when destroyed.
    // Process() is forwarded to the wrapped task unchanged.
    class TTaskWrapper
        : public TTaskScheduler::ITask
        , private TNonCopyable
    {
    private:
        TTaskScheduler::ITaskRef Task_;
        TAtomic& Counter_;

    public:
        TTaskWrapper(TTaskScheduler::ITaskRef task, TAtomic& counter)
            : Task_(task)
            , Counter_(counter)
        {
            AtomicIncrement(Counter_);
        }

        ~TTaskWrapper() override {
            AtomicDecrement(Counter_);
        }

    private:
        // Pure delegation: the wrapped task decides the next run time.
        TInstant Process() override {
            return Task_->Process();
        }
    };
}
// Queues `task` to be run at `expire`.
// Returns false, and queues nothing, when the scheduler is stopped, has no
// workers, or accepting one more task would exceed MaxTaskCount_.
bool TTaskScheduler::Add(ITaskRef task, TInstant expire) {
    bool accepted = false;

    with_lock (Lock_) {
        // Evaluation order matters: IsStopped_ is tested before Workers_ is touched.
        accepted = !IsStopped_ && Workers_.size() > 0 && GetTaskCount() + 1 <= MaxTaskCount_;

        if (accepted) {
            // The wrapper keeps TaskCounter_ in step with the task's lifetime.
            ITaskRef counted = new TTaskWrapper(task, TaskCounter_);
            Queue_.insert(std::make_pair(expire, TTaskHolder(counted)));

            // Wake a worker only when nobody is already waiting on the earliest entry.
            const bool headUnclaimed = !Queue_.begin()->second.WaitingWorker;
            if (headUnclaimed) {
                CondVar_.Signal();
            }
        }
    }

    return accepted;
}
namespace {
class TRepeatedTask
: public TTaskScheduler::ITask
{
public:
TRepeatedTask(TTaskScheduler::IRepeatedTaskRef task, TDuration period, TInstant deadline)
: Task_(task)
, Period_(period)
, Deadline_(deadline)
{
}
private:
TInstant Process() final {
Deadline_ += Period_;
if (Task_->Process()) {
return Deadline_;
} else {
return TInstant::Max();
}
}
private:
TTaskScheduler::IRepeatedTaskRef Task_;
TDuration Period_;
TInstant Deadline_;
};
}
// Queues a repeated task: it is wrapped into a TRepeatedTask and queued for
// its first run one `period` from now. Returns what Add(ITaskRef, TInstant) returns.
bool TTaskScheduler::Add(IRepeatedTaskRef task, TDuration period) {
    const TInstant firstRun = Now() + period;
    return Add(ITaskRef(new TRepeatedTask(task, period, firstRun)), firstRun);
}
// Compile-time switch for the worker-state trace printed by ChangeDebugState(); off by default.
constexpr bool debugOutput = false;
// Debug helper. When debugOutput is enabled, stores `state` in the worker and
// prints one line to Cerr: timestamp, worker id, the state of every worker,
// the queue size and the task counter. When disabled it does nothing.
void TTaskScheduler::ChangeDebugState(TWorkerThread* thread, const TString& state) {
    if (debugOutput) {
        thread->DebugState = state;

        // The line is built in a buffer and written to Cerr with a single << call.
        TStringStream line;
        line << Now() << " " << thread->DebugId << ":\t";
        for (auto& worker : Workers_) {
            line << worker->DebugState << " ";
        }
        line << " [" << Queue_.size() << "] [" << TaskCounter_ << "]" << Endl;

        Cerr << line.Str();
    } else {
        Y_UNUSED(thread);
        Y_UNUSED(state);
    }
}
// Marks the queue entry `toWait` as claimed by `thread`, then waits on
// CondVar_ with the entry's time as the deadline.
// Returns the negation of WaitD()'s result; WorkerFunc() treats `true` as
// "run this entry now".
// NOTE(review): presumably WaitD() returns false when the deadline expires -
// confirm against TCondVar.
bool TTaskScheduler::Wait(TWorkerThread* thread, TQueueIterator& toWait) {
    ChangeDebugState(thread, "w");

    auto& entry = *toWait;
    entry.second.WaitingWorker = thread;

    const bool waitResult = CondVar_.WaitD(Lock_, entry.first);
    return !waitResult;
}
// Updates `toWait` using the first queue entry that no worker is waiting on.
// If `toWait` is Queue_.end(), that entry is taken. Otherwise it replaces
// `toWait` only when its key is smaller, in which case the old entry's
// WaitingWorker is reset so that it becomes available again.
// Only the first such entry is examined; if there is none, `toWait` is untouched.
void TTaskScheduler::ChooseFromQueue(TQueueIterator& toWait) {
    for (TQueueIterator candidate = Queue_.begin(); candidate != Queue_.end(); ++candidate) {
        if (!candidate->second.WaitingWorker) {
            const bool holdsNothing = (toWait == Queue_.end());

            if (holdsNothing || candidate->first < toWait->first) {
                if (!holdsNothing) {
                    // Release the entry held so far before switching.
                    toWait->second.WaitingWorker = nullptr;
                }
                toWait = candidate;
            }
            return;
        }
    }
}
// Main loop of a worker thread; invoked from TWorkerThread::ThreadProc().
//
// Each iteration:
//   1. runs the task picked on the previous iteration, outside Lock_;
//   2. under Lock_: returns if the scheduler is stopped, re-queues the task
//      if it asked for another run, then picks a queue entry and either takes
//      it immediately, waits for its time, or - when no entry could be picked -
//      waits on CondVar_.
//
// An exception thrown by a task is reported to Cdbg and swallowed; the task
// is then not re-queued because `repeat` keeps its TInstant::Max() default.
void TTaskScheduler::WorkerFunc(TWorkerThread* thread) {
    TThread::SetCurrentThreadName("TaskSchedWorker");

    // Queue_.end() means "this worker currently holds no queue entry".
    TQueueIterator toWait = Queue_.end();
    ITaskRef toDo;

    for (;;) {
        TInstant repeat = TInstant::Max();

        if (!!toDo) {
            try {
                repeat = toDo->Process();
            } catch (...) {
                // Endl terminates the record so successive failures do not run together.
                Cdbg << "task scheduler error: " << CurrentExceptionMessage() << Endl;
            }
        }

        with_lock (Lock_) {
            ChangeDebugState(thread, "f");

            if (IsStopped_) {
                ChangeDebugState(thread, "s");
                return;
            }

            if (!!toDo) {
                // TInstant::Max() is the "do not run again" answer.
                if (repeat < TInstant::Max()) {
                    Queue_.insert(std::make_pair(repeat, TTaskHolder(toDo)));
                }
            }

            toDo = nullptr;

            ChooseFromQueue(toWait);

            if (toWait != Queue_.end()) {
                // Take the entry if it is already due, or if Wait() reports true.
                if (toWait->first <= Now() || Wait(thread, toWait)) {
                    toDo = toWait->second.Task;
                    Queue_.erase(toWait);
                    toWait = Queue_.end();

                    // Pass the baton: if the new head of the queue has no waiting
                    // worker and there are other workers, wake one of them.
                    if (!Queue_.empty() && !Queue_.begin()->second.WaitingWorker && Workers_.size() > 1) {
                        CondVar_.Signal();
                    }

                    ChangeDebugState(thread, "p");
                }
                // Otherwise loop again: the stop flag and the queue are re-examined.
            } else {
                // No entry could be picked: block on CondVar_ until woken.
                ChangeDebugState(thread, "e");
                CondVar_.WaitI(Lock_);
            }
        }
    }
}