#include "remote_server_session.h"
#include "remote_connection.h"
#include "remote_server_connection.h"
#include <library/cpp/messagebus/actor/temp_tls_vector.h>
#include <util/generic/cast.h>
#include <util/stream/output.h>
#include <util/system/yassert.h>
#include <typeinfo>
using namespace NActor;
using namespace NBus;
using namespace NBus::NPrivate;
// Constructs the server session: forwards the common arguments to
// TBusSessionImpl, sets up the ServerOwnedMessages tracker from the
// MaxInFlight / MaxInFlightBySize limits and remembers the handler.
//
// Throws yexception when a positive PerConnectionMaxInFlightBySize is
// smaller than MaxMessageSize.
TRemoteServerSession::TRemoteServerSession(TBusMessageQueue* queue,
                                           TBusProtocol* proto, IBusServerHandler* handler,
                                           const TBusServerSessionConfig& config, const TString& name)
    : TBusSessionImpl(false, queue, proto, handler, config, name)
    , ServerOwnedMessages(config.MaxInFlight, config.MaxInFlightBySize, "ServerOwnedMessages")
    , ServerHandler(handler)
{
    // The limit is validated only when it is positive.
    const bool perConnectionLimitSet = config.PerConnectionMaxInFlightBySize > 0;
    if (perConnectionLimitSet && config.PerConnectionMaxInFlightBySize < config.MaxMessageSize) {
        ythrow yexception()
            << "too low PerConnectionMaxInFlightBySize value";
    }
}
namespace NBus {
namespace NPrivate {
class TInvokeOnMessage: public IWorkItem {
private:
TRemoteServerSession* RemoteServerSession;
TBusMessagePtrAndHeader Request;
TIntrusivePtr<TRemoteServerConnection> Connection;
public:
TInvokeOnMessage(TRemoteServerSession* session, TBusMessagePtrAndHeader& request, TIntrusivePtr<TRemoteServerConnection>& connection)
: RemoteServerSession(session)
{
Y_ASSERT(!!connection);
Connection.Swap(connection);
Request.Swap(request);
}
void DoWork() override {
THolder<TInvokeOnMessage> holder(this);
RemoteServerSession->InvokeOnMessage(Request, Connection);
// TODO: TRemoteServerSessionSemaphore should be enough
RemoteServerSession->JobCount.Decrement();
}
};
}
}
// Processes a batch of requests received on connection `c`.
//
// All messages are first registered via AcquireInWorkRequests. Then, depending
// on Config.ExecuteOnMessageInWorkerPool, each message is either passed to
// InvokeOnMessage right here or wrapped into a TInvokeOnMessage work item;
// the collected items are counted in JobCount and enqueued to Queue in one call.
void TRemoteServerSession::OnMessageReceived(TRemoteConnection* c, TVectorSwaps<TBusMessagePtrAndHeader>& messages) {
    AcquireInWorkRequests(messages);

    const bool usePool = Config.ExecuteOnMessageInWorkerPool;

    TTempTlsVector< ::IWorkItem*> workQueueTemp;
    auto& jobs = *workQueueTemp.GetVector();
    if (usePool) {
        jobs.reserve(messages.size());
    }

    for (auto& message : messages) {
        // TODO: incref once
        TIntrusivePtr<TRemoteServerConnection> connection(CheckedCast<TRemoteServerConnection*>(c));

        if (!usePool) {
            InvokeOnMessage(message, connection);
            continue;
        }

        jobs.push_back(new TInvokeOnMessage(this, message, connection));
    }

    if (!usePool) {
        return;
    }

    // JobCount is raised before the items are enqueued; each item
    // decrements it at the end of its DoWork().
    JobCount.Add(jobs.size());
    Queue->EnqueueWork(jobs);
}
// Delivers one request to the server handler.
//
// When the session is marked Down, the request is released from the in-work
// accounting and reported through InvokeOnError with MESSAGE_SHUTDOWN.
// Otherwise a TBusIdentity is filled from the message (taking over `conn`
// by swap) and the message is passed to ServerHandler->OnMessage.
void TRemoteServerSession::InvokeOnMessage(TBusMessagePtrAndHeader& request, TIntrusivePtr<TRemoteServerConnection>& conn) {
    if (Y_UNLIKELY(AtomicGet(Down))) {
        ReleaseInWorkRequests(*conn.Get(), request.MessagePtr.Get());
        InvokeOnError(request.MessagePtr.Release(), MESSAGE_SHUTDOWN);
        return;
    }

    TWhatThreadDoesPushPop pp("OnMessage");

    TBusIdentity ident;
    ident.Connection.Swap(conn);

    request.MessagePtr->GetIdentity(ident);
    ident.RecvTime = request.MessagePtr->RecvTime;

    // The message's local flags (including MESSAGE_IN_WORK) are exchanged
    // with the identity's flags.
    Y_ASSERT(request.MessagePtr->LocalFlags & MESSAGE_IN_WORK);
    DoSwap(request.MessagePtr->LocalFlags, ident.LocalFlags);

#ifndef NDEBUG
    // typeid is applied to a named reference rather than to the
    // dereference expression itself.
    auto& message = *request.MessagePtr;
    ident.SetMessageType(typeid(message));
#endif

    TOnMessageContext context(request.MessagePtr.Release(), ident, this);
    ServerHandler->OnMessage(context);
}
// Releases the in-work accounting for the request described by `ident`
// without sending a reply. Always returns MESSAGE_OK.
EMessageStatus TRemoteServerSession::ForgetRequest(const TBusIdentity& ident) {
    // ReleaseInWork needs a mutable identity, hence the const_cast.
    TBusIdentity& mutableIdent = const_cast<TBusIdentity&>(ident);
    ReleaseInWork(mutableIdent);

    return MESSAGE_OK;
}
// Sends `reply` for the request described by `ident` over the identity's
// connection. Always returns MESSAGE_OK.
EMessageStatus TRemoteServerSession::SendReply(const TBusIdentity& ident, TBusMessage* reply) {
    reply->CheckClean();

    // ConvertInWork needs a mutable identity, hence the const_cast.
    TBusIdentity& mutableIdent = const_cast<TBusIdentity&>(ident);
    ConvertInWork(mutableIdent, reply);

    // RecvTime is copied from the identity after the conversion.
    reply->RecvTime = ident.RecvTime;

    ident.Connection->Send(reply);
    return MESSAGE_OK;
}
// Returns the current message count reported by ServerOwnedMessages.
int TRemoteServerSession::GetInFlight() const noexcept {
    return ServerOwnedMessages.GetCurrentCount();
}
// Fills the base-class status, then overwrites the in-flight count, in-flight
// size and input-paused fields from ServerOwnedMessages.
void TRemoteServerSession::FillStatus() {
    TBusSessionImpl::FillStatus();

    // TODO: weird
    auto& status = StatusData.Status;
    status.InFlightCount = ServerOwnedMessages.GetCurrentCount();
    status.InFlightSize = ServerOwnedMessages.GetCurrentSize();
    status.InputPaused = ServerOwnedMessages.IsLocked();
}
// Marks every message in `messages` with MESSAGE_IN_WORK and adds the batch
// (message count plus the sum of the header sizes) to ServerOwnedMessages.
void TRemoteServerSession::AcquireInWorkRequests(TArrayRef<const TBusMessagePtrAndHeader> messages) {
    TAtomicBase totalSize = 0;

    for (const auto& entry : messages) {
        // A message must not already be flagged as in work.
        Y_ASSERT(!(entry.MessagePtr->LocalFlags & MESSAGE_IN_WORK));
        entry.MessagePtr->LocalFlags |= MESSAGE_IN_WORK;

        totalSize += entry.MessagePtr->GetHeader()->Size;
    }

    ServerOwnedMessages.IncrementMultiple(messages.size(), totalSize);
}
// Clears MESSAGE_REPLY_IS_BEGING_SENT on every response and releases the
// batch (response count plus the sum of the stored RequestSize values)
// from ServerOwnedMessages.
void TRemoteServerSession::ReleaseInWorkResponses(TArrayRef<const TBusMessagePtrAndHeader> responses) {
    TAtomicBase totalSize = 0;

    for (const auto& entry : responses) {
        // Each response is expected to carry the "reply is being sent" flag.
        Y_ASSERT((entry.MessagePtr->LocalFlags & MESSAGE_REPLY_IS_BEGING_SENT));
        entry.MessagePtr->LocalFlags &= ~MESSAGE_REPLY_IS_BEGING_SENT;

        totalSize += entry.MessagePtr->RequestSize;
    }

    ServerOwnedMessages.ReleaseMultiple(responses.size(), totalSize);
}
// Clears MESSAGE_IN_WORK on `request` and gives back one message of the
// header-declared size both to the connection's quota and to
// ServerOwnedMessages.
void TRemoteServerSession::ReleaseInWorkRequests(TRemoteConnection& con, TBusMessage* request) {
    // The request is expected to be flagged as in work.
    Y_ASSERT((request->LocalFlags & MESSAGE_IN_WORK));
    request->LocalFlags &= ~MESSAGE_IN_WORK;

    const size_t requestSize = request->GetHeader()->Size;

    con.QuotaReturnAside(1, requestSize);
    ServerOwnedMessages.ReleaseMultiple(1, requestSize);
}
// Resets the in-work state of `ident` and gives back one message of
// ident.Size both to the identity's connection quota and to
// ServerOwnedMessages.
void TRemoteServerSession::ReleaseInWork(TBusIdentity& ident) {
    ident.SetInWork(false);

    const auto releasedSize = ident.Size;
    ident.Connection->QuotaReturnAside(1, releasedSize);
    ServerOwnedMessages.ReleaseMultiple(1, releasedSize);
}
// Moves the in-work state from the request identity `req` onto `reply`:
// the reply receives the identity and the request size and is flagged with
// MESSAGE_REPLY_IS_BEGING_SENT, while `req` is no longer marked in work.
void TRemoteServerSession::ConvertInWork(TBusIdentity& req, TBusMessage* reply) {
    reply->SetIdentity(req);
    req.SetInWork(false);

    // The request size is stored on the reply; ReleaseInWorkResponses sums
    // this field when releasing.
    reply->RequestSize = req.Size;

    // The reply must not already be flagged as being sent.
    Y_ASSERT(!(reply->LocalFlags & MESSAGE_REPLY_IS_BEGING_SENT));
    reply->LocalFlags |= MESSAGE_REPLY_IS_BEGING_SENT;
}
// Calls Stop() on ServerOwnedMessages first, then runs the base-class shutdown.
void TRemoteServerSession::Shutdown() {
    ServerOwnedMessages.Stop();

    TBusSessionImpl::Shutdown();
}
// Forwards the `pause` flag to ServerOwnedMessages.PauseByUsed.
void TRemoteServerSession::PauseInput(bool pause) {
    ServerOwnedMessages.PauseByUsed(pause);
}
// Returns Config.ListenPort; aborts the process when the configured port
// is not positive.
unsigned TRemoteServerSession::GetActualListenPort() {
    Y_ABORT_UNLESS(Config.ListenPort > 0, "state check");

    return Config.ListenPort;
}