aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/libs/llvm14/lib/ExecutionEngine/Orc/TargetProcess/SimpleRemoteEPCServer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/libs/llvm14/lib/ExecutionEngine/Orc/TargetProcess/SimpleRemoteEPCServer.cpp')
-rw-r--r--contrib/libs/llvm14/lib/ExecutionEngine/Orc/TargetProcess/SimpleRemoteEPCServer.cpp293
1 files changed, 293 insertions, 0 deletions
diff --git a/contrib/libs/llvm14/lib/ExecutionEngine/Orc/TargetProcess/SimpleRemoteEPCServer.cpp b/contrib/libs/llvm14/lib/ExecutionEngine/Orc/TargetProcess/SimpleRemoteEPCServer.cpp
new file mode 100644
index 0000000000..b6b21bde11
--- /dev/null
+++ b/contrib/libs/llvm14/lib/ExecutionEngine/Orc/TargetProcess/SimpleRemoteEPCServer.cpp
@@ -0,0 +1,293 @@
+//===------- SimpleEPCServer.cpp - EPC over simple abstract channel -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ExecutionEngine/Orc/TargetProcess/SimpleRemoteEPCServer.h"
+
+#include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Host.h"
+#include "llvm/Support/Process.h"
+
+#include "OrcRTBootstrap.h"
+
+#define DEBUG_TYPE "orc"
+
+using namespace llvm::orc::shared;
+
+namespace llvm {
+namespace orc {
+
+ExecutorBootstrapService::~ExecutorBootstrapService() {}
+
+SimpleRemoteEPCServer::Dispatcher::~Dispatcher() {}
+
+#if LLVM_ENABLE_THREADS
+void SimpleRemoteEPCServer::ThreadDispatcher::dispatch(
+ unique_function<void()> Work) {
+ {
+ std::lock_guard<std::mutex> Lock(DispatchMutex);
+ if (!Running)
+ return;
+ ++Outstanding;
+ }
+
+ std::thread([this, Work = std::move(Work)]() mutable {
+ Work();
+ std::lock_guard<std::mutex> Lock(DispatchMutex);
+ --Outstanding;
+ OutstandingCV.notify_all();
+ }).detach();
+}
+
+void SimpleRemoteEPCServer::ThreadDispatcher::shutdown() {
+ std::unique_lock<std::mutex> Lock(DispatchMutex);
+ Running = false;
+ OutstandingCV.wait(Lock, [this]() { return Outstanding == 0; });
+}
+#endif
+
+StringMap<ExecutorAddr> SimpleRemoteEPCServer::defaultBootstrapSymbols() {
+ StringMap<ExecutorAddr> DBS;
+ rt_bootstrap::addTo(DBS);
+ return DBS;
+}
+
+Expected<SimpleRemoteEPCTransportClient::HandleMessageAction>
+SimpleRemoteEPCServer::handleMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
+ ExecutorAddr TagAddr,
+ SimpleRemoteEPCArgBytesVector ArgBytes) {
+
+ LLVM_DEBUG({
+ dbgs() << "SimpleRemoteEPCServer::handleMessage: opc = ";
+ switch (OpC) {
+ case SimpleRemoteEPCOpcode::Setup:
+ dbgs() << "Setup";
+ assert(SeqNo == 0 && "Non-zero SeqNo for Setup?");
+ assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Setup?");
+ break;
+ case SimpleRemoteEPCOpcode::Hangup:
+ dbgs() << "Hangup";
+ assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
+ assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
+ break;
+ case SimpleRemoteEPCOpcode::Result:
+ dbgs() << "Result";
+ assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
+ break;
+ case SimpleRemoteEPCOpcode::CallWrapper:
+ dbgs() << "CallWrapper";
+ break;
+ }
+ dbgs() << ", seqno = " << SeqNo
+ << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
+ << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
+ << " bytes\n";
+ });
+
+ using UT = std::underlying_type_t<SimpleRemoteEPCOpcode>;
+ if (static_cast<UT>(OpC) > static_cast<UT>(SimpleRemoteEPCOpcode::LastOpC))
+ return make_error<StringError>("Unexpected opcode",
+ inconvertibleErrorCode());
+
+ // TODO: Clean detach message?
+ switch (OpC) {
+ case SimpleRemoteEPCOpcode::Setup:
+ return make_error<StringError>("Unexpected Setup opcode",
+ inconvertibleErrorCode());
+ case SimpleRemoteEPCOpcode::Hangup:
+ return SimpleRemoteEPCTransportClient::EndSession;
+ case SimpleRemoteEPCOpcode::Result:
+ if (auto Err = handleResult(SeqNo, TagAddr, std::move(ArgBytes)))
+ return std::move(Err);
+ break;
+ case SimpleRemoteEPCOpcode::CallWrapper:
+ handleCallWrapper(SeqNo, TagAddr, std::move(ArgBytes));
+ break;
+ }
+ return ContinueSession;
+}
+
+Error SimpleRemoteEPCServer::waitForDisconnect() {
+ std::unique_lock<std::mutex> Lock(ServerStateMutex);
+ ShutdownCV.wait(Lock, [this]() { return RunState == ServerShutDown; });
+ return std::move(ShutdownErr);
+}
+
+void SimpleRemoteEPCServer::handleDisconnect(Error Err) {
+ PendingJITDispatchResultsMap TmpPending;
+
+ {
+ std::lock_guard<std::mutex> Lock(ServerStateMutex);
+ std::swap(TmpPending, PendingJITDispatchResults);
+ RunState = ServerShuttingDown;
+ }
+
+ // Send out-of-band errors to any waiting threads.
+ for (auto &KV : TmpPending)
+ KV.second->set_value(
+ shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
+
+ // Wait for dispatcher to clear.
+ D->shutdown();
+
+ // Shut down services.
+ while (!Services.empty()) {
+ ShutdownErr =
+ joinErrors(std::move(ShutdownErr), Services.back()->shutdown());
+ Services.pop_back();
+ }
+
+ std::lock_guard<std::mutex> Lock(ServerStateMutex);
+ ShutdownErr = joinErrors(std::move(ShutdownErr), std::move(Err));
+ RunState = ServerShutDown;
+ ShutdownCV.notify_all();
+}
+
+Error SimpleRemoteEPCServer::sendMessage(SimpleRemoteEPCOpcode OpC,
+ uint64_t SeqNo, ExecutorAddr TagAddr,
+ ArrayRef<char> ArgBytes) {
+
+ LLVM_DEBUG({
+ dbgs() << "SimpleRemoteEPCServer::sendMessage: opc = ";
+ switch (OpC) {
+ case SimpleRemoteEPCOpcode::Setup:
+ dbgs() << "Setup";
+ assert(SeqNo == 0 && "Non-zero SeqNo for Setup?");
+ assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Setup?");
+ break;
+ case SimpleRemoteEPCOpcode::Hangup:
+ dbgs() << "Hangup";
+ assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
+ assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
+ break;
+ case SimpleRemoteEPCOpcode::Result:
+ dbgs() << "Result";
+ assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
+ break;
+ case SimpleRemoteEPCOpcode::CallWrapper:
+ dbgs() << "CallWrapper";
+ break;
+ }
+ dbgs() << ", seqno = " << SeqNo
+ << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
+ << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
+ << " bytes\n";
+ });
+ auto Err = T->sendMessage(OpC, SeqNo, TagAddr, ArgBytes);
+ LLVM_DEBUG({
+ if (Err)
+ dbgs() << " \\--> SimpleRemoteEPC::sendMessage failed\n";
+ });
+ return Err;
+}
+
+Error SimpleRemoteEPCServer::sendSetupMessage(
+ StringMap<ExecutorAddr> BootstrapSymbols) {
+
+ using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
+
+ std::vector<char> SetupPacket;
+ SimpleRemoteEPCExecutorInfo EI;
+ EI.TargetTriple = sys::getProcessTriple();
+ if (auto PageSize = sys::Process::getPageSize())
+ EI.PageSize = *PageSize;
+ else
+ return PageSize.takeError();
+ EI.BootstrapSymbols = std::move(BootstrapSymbols);
+
+ assert(!EI.BootstrapSymbols.count(ExecutorSessionObjectName) &&
+ "Dispatch context name should not be set");
+ assert(!EI.BootstrapSymbols.count(DispatchFnName) &&
+ "Dispatch function name should not be set");
+ EI.BootstrapSymbols[ExecutorSessionObjectName] = ExecutorAddr::fromPtr(this);
+ EI.BootstrapSymbols[DispatchFnName] = ExecutorAddr::fromPtr(jitDispatchEntry);
+
+ using SPSSerialize =
+ shared::SPSArgList<shared::SPSSimpleRemoteEPCExecutorInfo>;
+ auto SetupPacketBytes =
+ shared::WrapperFunctionResult::allocate(SPSSerialize::size(EI));
+ shared::SPSOutputBuffer OB(SetupPacketBytes.data(), SetupPacketBytes.size());
+ if (!SPSSerialize::serialize(OB, EI))
+ return make_error<StringError>("Could not send setup packet",
+ inconvertibleErrorCode());
+
+ return sendMessage(SimpleRemoteEPCOpcode::Setup, 0, ExecutorAddr(),
+ {SetupPacketBytes.data(), SetupPacketBytes.size()});
+}
+
+Error SimpleRemoteEPCServer::handleResult(
+ uint64_t SeqNo, ExecutorAddr TagAddr,
+ SimpleRemoteEPCArgBytesVector ArgBytes) {
+ std::promise<shared::WrapperFunctionResult> *P = nullptr;
+ {
+ std::lock_guard<std::mutex> Lock(ServerStateMutex);
+ auto I = PendingJITDispatchResults.find(SeqNo);
+ if (I == PendingJITDispatchResults.end())
+ return make_error<StringError>("No call for sequence number " +
+ Twine(SeqNo),
+ inconvertibleErrorCode());
+ P = I->second;
+ PendingJITDispatchResults.erase(I);
+ releaseSeqNo(SeqNo);
+ }
+ auto R = shared::WrapperFunctionResult::allocate(ArgBytes.size());
+ memcpy(R.data(), ArgBytes.data(), ArgBytes.size());
+ P->set_value(std::move(R));
+ return Error::success();
+}
+
+void SimpleRemoteEPCServer::handleCallWrapper(
+ uint64_t RemoteSeqNo, ExecutorAddr TagAddr,
+ SimpleRemoteEPCArgBytesVector ArgBytes) {
+ D->dispatch([this, RemoteSeqNo, TagAddr, ArgBytes = std::move(ArgBytes)]() {
+ using WrapperFnTy =
+ shared::CWrapperFunctionResult (*)(const char *, size_t);
+ auto *Fn = TagAddr.toPtr<WrapperFnTy>();
+ shared::WrapperFunctionResult ResultBytes(
+ Fn(ArgBytes.data(), ArgBytes.size()));
+ if (auto Err = sendMessage(SimpleRemoteEPCOpcode::Result, RemoteSeqNo,
+ ExecutorAddr(),
+ {ResultBytes.data(), ResultBytes.size()}))
+ ReportError(std::move(Err));
+ });
+}
+
+shared::WrapperFunctionResult
+SimpleRemoteEPCServer::doJITDispatch(const void *FnTag, const char *ArgData,
+ size_t ArgSize) {
+ uint64_t SeqNo;
+ std::promise<shared::WrapperFunctionResult> ResultP;
+ auto ResultF = ResultP.get_future();
+ {
+ std::lock_guard<std::mutex> Lock(ServerStateMutex);
+ if (RunState != ServerRunning)
+ return shared::WrapperFunctionResult::createOutOfBandError(
+ "jit_dispatch not available (EPC server shut down)");
+
+ SeqNo = getNextSeqNo();
+ assert(!PendingJITDispatchResults.count(SeqNo) && "SeqNo already in use");
+ PendingJITDispatchResults[SeqNo] = &ResultP;
+ }
+
+ if (auto Err = sendMessage(SimpleRemoteEPCOpcode::CallWrapper, SeqNo,
+ ExecutorAddr::fromPtr(FnTag), {ArgData, ArgSize}))
+ ReportError(std::move(Err));
+
+ return ResultF.get();
+}
+
+shared::CWrapperFunctionResult
+SimpleRemoteEPCServer::jitDispatchEntry(void *DispatchCtx, const void *FnTag,
+ const char *ArgData, size_t ArgSize) {
+ return reinterpret_cast<SimpleRemoteEPCServer *>(DispatchCtx)
+ ->doJITDispatch(FnTag, ArgData, ArgSize)
+ .release();
+}
+
+} // end namespace orc
+} // end namespace llvm