diff options
author | shadchin <shadchin@yandex-team.ru> | 2022-02-10 16:44:30 +0300 |
---|---|---|
committer | Daniil Cherednik <dcherednik@yandex-team.ru> | 2022-02-10 16:44:30 +0300 |
commit | 2598ef1d0aee359b4b6d5fdd1758916d5907d04f (patch) | |
tree | 012bb94d777798f1f56ac1cec429509766d05181 /contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc | |
parent | 6751af0b0c1b952fede40b19b71da8025b5d8bcf (diff) | |
download | ydb-2598ef1d0aee359b4b6d5fdd1758916d5907d04f.tar.gz |
Restoring authorship annotation for <shadchin@yandex-team.ru>. Commit 1 of 2.
Diffstat (limited to 'contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc')
32 files changed, 5795 insertions, 5795 deletions
diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h index aa58b3d35f..abfc9333d8 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h @@ -30,7 +30,7 @@ #include "llvm/ExecutionEngine/Orc/Layer.h" #include "llvm/ExecutionEngine/Orc/LazyReexports.h" #include "llvm/ExecutionEngine/Orc/Speculation.h" -#include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" +#include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" #include "llvm/ExecutionEngine/RuntimeDyld.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/Constant.h" @@ -101,8 +101,8 @@ public: /// Emits the given module. This should not be called by clients: it will be /// called by the JIT when a definition added via the add method is requested. - void emit(std::unique_ptr<MaterializationResponsibility> R, - ThreadSafeModule TSM) override; + void emit(std::unique_ptr<MaterializationResponsibility> R, + ThreadSafeModule TSM) override; private: struct PerDylibResources { @@ -126,8 +126,8 @@ private: void expandPartition(GlobalValueSet &Partition); - void emitPartition(std::unique_ptr<MaterializationResponsibility> R, - ThreadSafeModule TSM, + void emitPartition(std::unique_ptr<MaterializationResponsibility> R, + ThreadSafeModule TSM, IRMaterializationUnit::SymbolNameToDefinitionMap Defs); mutable std::mutex CODLayerMutex; diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Core.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Core.h index 3a51e885ae..08ada986f3 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Core.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Core.h @@ -23,14 +23,14 @@ #include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/FunctionExtras.h" -#include "llvm/ADT/IntrusiveRefCntPtr.h" -#include "llvm/ExecutionEngine/JITLink/JITLinkDylib.h" +#include "llvm/ADT/IntrusiveRefCntPtr.h" +#include "llvm/ExecutionEngine/JITLink/JITLinkDylib.h" #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/Orc/SymbolStringPool.h" #include "llvm/ExecutionEngine/OrcV1Deprecation.h" #include "llvm/Support/Debug.h" -#include <atomic> +#include <atomic> #include <memory> #include <vector> @@ -43,68 +43,68 @@ class ExecutionSession; class MaterializationUnit; class MaterializationResponsibility; class JITDylib; -class ResourceTracker; -class InProgressLookupState; - +class ResourceTracker; +class InProgressLookupState; + enum class SymbolState : uint8_t; -using ResourceTrackerSP = IntrusiveRefCntPtr<ResourceTracker>; -using JITDylibSP = IntrusiveRefCntPtr<JITDylib>; - -using ResourceKey = uintptr_t; - -/// API to remove / transfer ownership of JIT resources. -class ResourceTracker : public ThreadSafeRefCountedBase<ResourceTracker> { -private: - friend class ExecutionSession; - friend class JITDylib; - friend class MaterializationResponsibility; - -public: - ResourceTracker(const ResourceTracker &) = delete; - ResourceTracker &operator=(const ResourceTracker &) = delete; - ResourceTracker(ResourceTracker &&) = delete; - ResourceTracker &operator=(ResourceTracker &&) = delete; - - ~ResourceTracker(); - - /// Return the JITDylib targeted by this tracker. - JITDylib &getJITDylib() const { - return *reinterpret_cast<JITDylib *>(JDAndFlag.load() & - ~static_cast<uintptr_t>(1)); - } - - /// Remove all resources associated with this key. - Error remove(); - - /// Transfer all resources associated with this key to the given - /// tracker, which must target the same JITDylib as this one. - void transferTo(ResourceTracker &DstRT); - - /// Return true if this tracker has become defunct. - bool isDefunct() const { return JDAndFlag.load() & 0x1; } - - /// Returns the key associated with this tracker. - /// This method should not be used except for debug logging: there is no - /// guarantee that the returned value will remain valid. - ResourceKey getKeyUnsafe() const { return reinterpret_cast<uintptr_t>(this); } - -private: - ResourceTracker(JITDylibSP JD); - - void makeDefunct(); - - std::atomic_uintptr_t JDAndFlag; -}; - -/// Listens for ResourceTracker operations. -class ResourceManager { -public: - virtual ~ResourceManager(); - virtual Error handleRemoveResources(ResourceKey K) = 0; - virtual void handleTransferResources(ResourceKey DstK, ResourceKey SrcK) = 0; -}; - +using ResourceTrackerSP = IntrusiveRefCntPtr<ResourceTracker>; +using JITDylibSP = IntrusiveRefCntPtr<JITDylib>; + +using ResourceKey = uintptr_t; + +/// API to remove / transfer ownership of JIT resources. +class ResourceTracker : public ThreadSafeRefCountedBase<ResourceTracker> { +private: + friend class ExecutionSession; + friend class JITDylib; + friend class MaterializationResponsibility; + +public: + ResourceTracker(const ResourceTracker &) = delete; + ResourceTracker &operator=(const ResourceTracker &) = delete; + ResourceTracker(ResourceTracker &&) = delete; + ResourceTracker &operator=(ResourceTracker &&) = delete; + + ~ResourceTracker(); + + /// Return the JITDylib targeted by this tracker. + JITDylib &getJITDylib() const { + return *reinterpret_cast<JITDylib *>(JDAndFlag.load() & + ~static_cast<uintptr_t>(1)); + } + + /// Remove all resources associated with this key. + Error remove(); + + /// Transfer all resources associated with this key to the given + /// tracker, which must target the same JITDylib as this one. + void transferTo(ResourceTracker &DstRT); + + /// Return true if this tracker has become defunct. + bool isDefunct() const { return JDAndFlag.load() & 0x1; } + + /// Returns the key associated with this tracker. + /// This method should not be used except for debug logging: there is no + /// guarantee that the returned value will remain valid. + ResourceKey getKeyUnsafe() const { return reinterpret_cast<uintptr_t>(this); } + +private: + ResourceTracker(JITDylibSP JD); + + void makeDefunct(); + + std::atomic_uintptr_t JDAndFlag; +}; + +/// Listens for ResourceTracker operations. +class ResourceManager { +public: + virtual ~ResourceManager(); + virtual Error handleRemoveResources(ResourceKey K) = 0; + virtual void handleTransferResources(ResourceKey DstK, ResourceKey SrcK) = 0; +}; + /// A set of symbol names (represented by SymbolStringPtrs for // efficiency). using SymbolNameSet = DenseSet<SymbolStringPtr>; @@ -224,21 +224,21 @@ public: /// Add an element to the set. The client is responsible for checking that /// duplicates are not added. - SymbolLookupSet & - add(SymbolStringPtr Name, - SymbolLookupFlags Flags = SymbolLookupFlags::RequiredSymbol) { + SymbolLookupSet & + add(SymbolStringPtr Name, + SymbolLookupFlags Flags = SymbolLookupFlags::RequiredSymbol) { Symbols.push_back(std::make_pair(std::move(Name), Flags)); - return *this; - } - - /// Quickly append one lookup set to another. - SymbolLookupSet &append(SymbolLookupSet Other) { - Symbols.reserve(Symbols.size() + Other.size()); - for (auto &KV : Other) - Symbols.push_back(std::move(KV)); - return *this; + return *this; } + /// Quickly append one lookup set to another. + SymbolLookupSet &append(SymbolLookupSet Other) { + Symbols.reserve(Symbols.size() + Other.size()); + for (auto &KV : Other) + Symbols.push_back(std::move(KV)); + return *this; + } + bool empty() const { return Symbols.empty(); } UnderlyingVector::size_type size() const { return Symbols.size(); } iterator begin() { return Symbols.begin(); } @@ -363,7 +363,7 @@ public: for (UnderlyingVector::size_type I = 1; I != Symbols.size(); ++I) if (Symbols[I].first == Symbols[I - 1].first) return true; - return false; + return false; } #endif @@ -394,18 +394,18 @@ using RegisterDependenciesFunction = /// are no dependants to register with. extern RegisterDependenciesFunction NoDependenciesToRegister; -class ResourceTrackerDefunct : public ErrorInfo<ResourceTrackerDefunct> { -public: - static char ID; - - ResourceTrackerDefunct(ResourceTrackerSP RT); - std::error_code convertToErrorCode() const override; - void log(raw_ostream &OS) const override; - -private: - ResourceTrackerSP RT; -}; - +class ResourceTrackerDefunct : public ErrorInfo<ResourceTrackerDefunct> { +public: + static char ID; + + ResourceTrackerDefunct(ResourceTrackerSP RT); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + +private: + ResourceTrackerSP RT; +}; + /// Used to notify a JITDylib that the given set of symbols failed to /// materialize. class FailedToMaterialize : public ErrorInfo<FailedToMaterialize> { @@ -496,10 +496,10 @@ private: /// emit symbols, or abandon materialization by notifying any unmaterialized /// symbols of an error. class MaterializationResponsibility { - friend class ExecutionSession; - + friend class ExecutionSession; + public: - MaterializationResponsibility(MaterializationResponsibility &&) = delete; + MaterializationResponsibility(MaterializationResponsibility &&) = delete; MaterializationResponsibility & operator=(MaterializationResponsibility &&) = delete; @@ -508,15 +508,15 @@ public: /// emitted or notified of an error. ~MaterializationResponsibility(); - /// Returns the ResourceTracker for this instance. - template <typename Func> Error withResourceKeyDo(Func &&F) const; - + /// Returns the ResourceTracker for this instance. + template <typename Func> Error withResourceKeyDo(Func &&F) const; + /// Returns the target JITDylib that these symbols are being materialized /// into. JITDylib &getTargetJITDylib() const { return *JD; } - /// Returns the ExecutionSession for this instance. - ExecutionSession &getExecutionSession(); + /// Returns the ExecutionSession for this instance. + ExecutionSession &getExecutionSession(); /// Returns the symbol flags map for this responsibility instance. /// Note: The returned flags may have transient flags (Lazy, Materializing) @@ -601,13 +601,13 @@ public: /// materializers to break up work based on run-time information (e.g. /// by introspecting which symbols have actually been looked up and /// materializing only those). - Error replace(std::unique_ptr<MaterializationUnit> MU); + Error replace(std::unique_ptr<MaterializationUnit> MU); /// Delegates responsibility for the given symbols to the returned /// materialization responsibility. Useful for breaking up work between /// threads, or different kinds of materialization processes. - Expected<std::unique_ptr<MaterializationResponsibility>> - delegate(const SymbolNameSet &Symbols); + Expected<std::unique_ptr<MaterializationResponsibility>> + delegate(const SymbolNameSet &Symbols); void addDependencies(const SymbolStringPtr &Name, const SymbolDependenceMap &Dependencies); @@ -618,15 +618,15 @@ public: private: /// Create a MaterializationResponsibility for the given JITDylib and /// initial symbols. - MaterializationResponsibility(JITDylibSP JD, SymbolFlagsMap SymbolFlags, - SymbolStringPtr InitSymbol) + MaterializationResponsibility(JITDylibSP JD, SymbolFlagsMap SymbolFlags, + SymbolStringPtr InitSymbol) : JD(std::move(JD)), SymbolFlags(std::move(SymbolFlags)), - InitSymbol(std::move(InitSymbol)) { - assert(this->JD && "Cannot initialize with null JITDylib"); + InitSymbol(std::move(InitSymbol)) { + assert(this->JD && "Cannot initialize with null JITDylib"); assert(!this->SymbolFlags.empty() && "Materializing nothing?"); } - JITDylibSP JD; + JITDylibSP JD; SymbolFlagsMap SymbolFlags; SymbolStringPtr InitSymbol; }; @@ -645,9 +645,9 @@ class MaterializationUnit { public: MaterializationUnit(SymbolFlagsMap InitalSymbolFlags, - SymbolStringPtr InitSymbol) + SymbolStringPtr InitSymbol) : SymbolFlags(std::move(InitalSymbolFlags)), - InitSymbol(std::move(InitSymbol)) { + InitSymbol(std::move(InitSymbol)) { assert((!this->InitSymbol || this->SymbolFlags.count(this->InitSymbol)) && "If set, InitSymbol should appear in InitialSymbolFlags map"); } @@ -667,8 +667,8 @@ public: /// Implementations of this method should materialize all symbols /// in the materialzation unit, except for those that have been /// previously discarded. - virtual void - materialize(std::unique_ptr<MaterializationResponsibility> R) = 0; + virtual void + materialize(std::unique_ptr<MaterializationResponsibility> R) = 0; /// Called by JITDylibs to notify MaterializationUnits that the given symbol /// has been overridden. @@ -697,12 +697,12 @@ private: /// materialized. class AbsoluteSymbolsMaterializationUnit : public MaterializationUnit { public: - AbsoluteSymbolsMaterializationUnit(SymbolMap Symbols); + AbsoluteSymbolsMaterializationUnit(SymbolMap Symbols); StringRef getName() const override; private: - void materialize(std::unique_ptr<MaterializationResponsibility> R) override; + void materialize(std::unique_ptr<MaterializationResponsibility> R) override; void discard(const JITDylib &JD, const SymbolStringPtr &Name) override; static SymbolFlagsMap extractFlags(const SymbolMap &Symbols); @@ -720,9 +720,9 @@ private: /// \endcode /// inline std::unique_ptr<AbsoluteSymbolsMaterializationUnit> -absoluteSymbols(SymbolMap Symbols) { +absoluteSymbols(SymbolMap Symbols) { return std::make_unique<AbsoluteSymbolsMaterializationUnit>( - std::move(Symbols)); + std::move(Symbols)); } /// A materialization unit for symbol aliases. Allows existing symbols to be @@ -739,12 +739,12 @@ public: /// resolved. ReExportsMaterializationUnit(JITDylib *SourceJD, JITDylibLookupFlags SourceJDLookupFlags, - SymbolAliasMap Aliases); + SymbolAliasMap Aliases); StringRef getName() const override; private: - void materialize(std::unique_ptr<MaterializationResponsibility> R) override; + void materialize(std::unique_ptr<MaterializationResponsibility> R) override; void discard(const JITDylib &JD, const SymbolStringPtr &Name) override; static SymbolFlagsMap extractFlags(const SymbolAliasMap &Aliases); @@ -765,9 +765,9 @@ private: /// return Err; /// \endcode inline std::unique_ptr<ReExportsMaterializationUnit> -symbolAliases(SymbolAliasMap Aliases) { +symbolAliases(SymbolAliasMap Aliases) { return std::make_unique<ReExportsMaterializationUnit>( - nullptr, JITDylibLookupFlags::MatchAllSymbols, std::move(Aliases)); + nullptr, JITDylibLookupFlags::MatchAllSymbols, std::move(Aliases)); } /// Create a materialization unit for re-exporting symbols from another JITDylib @@ -776,9 +776,9 @@ symbolAliases(SymbolAliasMap Aliases) { inline std::unique_ptr<ReExportsMaterializationUnit> reexports(JITDylib &SourceJD, SymbolAliasMap Aliases, JITDylibLookupFlags SourceJDLookupFlags = - JITDylibLookupFlags::MatchExportedSymbolsOnly) { + JITDylibLookupFlags::MatchExportedSymbolsOnly) { return std::make_unique<ReExportsMaterializationUnit>( - &SourceJD, SourceJDLookupFlags, std::move(Aliases)); + &SourceJD, SourceJDLookupFlags, std::move(Aliases)); } /// Build a SymbolAliasMap for the common case where you want to re-export @@ -802,10 +802,10 @@ enum class SymbolState : uint8_t { /// makes a callback when all symbols are available. class AsynchronousSymbolQuery { friend class ExecutionSession; - friend class InProgressFullLookupState; + friend class InProgressFullLookupState; friend class JITDylib; friend class JITSymbolResolverAdapter; - friend class MaterializationResponsibility; + friend class MaterializationResponsibility; public: /// Create a query for the given symbols. The NotifyComplete @@ -849,57 +849,57 @@ private: SymbolState RequiredState; }; -/// Wraps state for a lookup-in-progress. -/// DefinitionGenerators can optionally take ownership of a LookupState object -/// to suspend a lookup-in-progress while they search for definitions. -class LookupState { - friend class OrcV2CAPIHelper; - friend class ExecutionSession; - -public: - LookupState(); - LookupState(LookupState &&); - LookupState &operator=(LookupState &&); - ~LookupState(); - - /// Continue the lookup. This can be called by DefinitionGenerators - /// to re-start a captured query-application operation. - void continueLookup(Error Err); - -private: - LookupState(std::unique_ptr<InProgressLookupState> IPLS); - - // For C API. - void reset(InProgressLookupState *IPLS); - - std::unique_ptr<InProgressLookupState> IPLS; -}; - -/// Definition generators can be attached to JITDylibs to generate new -/// definitions for otherwise unresolved symbols during lookup. -class DefinitionGenerator { -public: - virtual ~DefinitionGenerator(); - - /// DefinitionGenerators should override this method to insert new - /// definitions into the parent JITDylib. K specifies the kind of this - /// lookup. JD specifies the target JITDylib being searched, and - /// JDLookupFlags specifies whether the search should match against - /// hidden symbols. Finally, Symbols describes the set of unresolved - /// symbols and their associated lookup flags. - virtual Error tryToGenerate(LookupState &LS, LookupKind K, JITDylib &JD, - JITDylibLookupFlags JDLookupFlags, - const SymbolLookupSet &LookupSet) = 0; -}; - +/// Wraps state for a lookup-in-progress. +/// DefinitionGenerators can optionally take ownership of a LookupState object +/// to suspend a lookup-in-progress while they search for definitions. +class LookupState { + friend class OrcV2CAPIHelper; + friend class ExecutionSession; + +public: + LookupState(); + LookupState(LookupState &&); + LookupState &operator=(LookupState &&); + ~LookupState(); + + /// Continue the lookup. This can be called by DefinitionGenerators + /// to re-start a captured query-application operation. + void continueLookup(Error Err); + +private: + LookupState(std::unique_ptr<InProgressLookupState> IPLS); + + // For C API. + void reset(InProgressLookupState *IPLS); + + std::unique_ptr<InProgressLookupState> IPLS; +}; + +/// Definition generators can be attached to JITDylibs to generate new +/// definitions for otherwise unresolved symbols during lookup. +class DefinitionGenerator { +public: + virtual ~DefinitionGenerator(); + + /// DefinitionGenerators should override this method to insert new + /// definitions into the parent JITDylib. K specifies the kind of this + /// lookup. JD specifies the target JITDylib being searched, and + /// JDLookupFlags specifies whether the search should match against + /// hidden symbols. Finally, Symbols describes the set of unresolved + /// symbols and their associated lookup flags. + virtual Error tryToGenerate(LookupState &LS, LookupKind K, JITDylib &JD, + JITDylibLookupFlags JDLookupFlags, + const SymbolLookupSet &LookupSet) = 0; +}; + /// A symbol table that supports asynchoronous symbol queries. /// /// Represents a virtual shared object. Instances can not be copied or moved, so /// their addresses may be used as keys for resource management. /// JITDylib state changes must be made via an ExecutionSession to guarantee /// that they are synchronized with respect to other JITDylib operations. -class JITDylib : public ThreadSafeRefCountedBase<JITDylib>, - public jitlink::JITLinkDylib { +class JITDylib : public ThreadSafeRefCountedBase<JITDylib>, + public jitlink::JITLinkDylib { friend class AsynchronousSymbolQuery; friend class ExecutionSession; friend class Platform; @@ -920,21 +920,21 @@ public: /// Get a reference to the ExecutionSession for this JITDylib. ExecutionSession &getExecutionSession() const { return ES; } - /// Calls remove on all trackers currently associated with this JITDylib. - /// Does not run static deinits. - /// - /// Note that removal happens outside the session lock, so new code may be - /// added concurrently while the clear is underway, and the newly added - /// code will *not* be cleared. Adding new code concurrently with a clear - /// is usually a bug and should be avoided. - Error clear(); - - /// Get the default resource tracker for this JITDylib. - ResourceTrackerSP getDefaultResourceTracker(); - - /// Create a resource tracker for this JITDylib. - ResourceTrackerSP createResourceTracker(); - + /// Calls remove on all trackers currently associated with this JITDylib. + /// Does not run static deinits. + /// + /// Note that removal happens outside the session lock, so new code may be + /// added concurrently while the clear is underway, and the newly added + /// code will *not* be cleared. Adding new code concurrently with a clear + /// is usually a bug and should be avoided. + Error clear(); + + /// Get the default resource tracker for this JITDylib. + ResourceTrackerSP getDefaultResourceTracker(); + + /// Create a resource tracker for this JITDylib. + ResourceTrackerSP createResourceTracker(); + /// Adds a definition generator to this JITDylib and returns a referenece to /// it. /// @@ -995,13 +995,13 @@ public: /// Define all symbols provided by the materialization unit to be part of this /// JITDylib. /// - /// If RT is not specified then the default resource tracker will be used. - /// + /// If RT is not specified then the default resource tracker will be used. + /// /// This overload always takes ownership of the MaterializationUnit. If any /// errors occur, the MaterializationUnit consumed. template <typename MaterializationUnitType> - Error define(std::unique_ptr<MaterializationUnitType> &&MU, - ResourceTrackerSP RT = nullptr); + Error define(std::unique_ptr<MaterializationUnitType> &&MU, + ResourceTrackerSP RT = nullptr); /// Define all symbols provided by the materialization unit to be part of this /// JITDylib. @@ -1011,8 +1011,8 @@ public: /// may allow the caller to modify the MaterializationUnit to correct the /// issue, then re-call define. template <typename MaterializationUnitType> - Error define(std::unique_ptr<MaterializationUnitType> &MU, - ResourceTrackerSP RT = nullptr); + Error define(std::unique_ptr<MaterializationUnitType> &MU, + ResourceTrackerSP RT = nullptr); /// Tries to remove the given symbols. /// @@ -1029,44 +1029,44 @@ public: /// Dump current JITDylib state to OS. void dump(raw_ostream &OS); - /// Returns the given JITDylibs and all of their transitive dependencies in - /// DFS order (based on linkage relationships). Each JITDylib will appear - /// only once. - static std::vector<JITDylibSP> getDFSLinkOrder(ArrayRef<JITDylibSP> JDs); - - /// Returns the given JITDylibs and all of their transitive dependensies in - /// reverse DFS order (based on linkage relationships). Each JITDylib will - /// appear only once. - static std::vector<JITDylibSP> - getReverseDFSLinkOrder(ArrayRef<JITDylibSP> JDs); - - /// Return this JITDylib and its transitive dependencies in DFS order - /// based on linkage relationships. - std::vector<JITDylibSP> getDFSLinkOrder(); - - /// Rteurn this JITDylib and its transitive dependencies in reverse DFS order - /// based on linkage relationships. - std::vector<JITDylibSP> getReverseDFSLinkOrder(); - + /// Returns the given JITDylibs and all of their transitive dependencies in + /// DFS order (based on linkage relationships). Each JITDylib will appear + /// only once. + static std::vector<JITDylibSP> getDFSLinkOrder(ArrayRef<JITDylibSP> JDs); + + /// Returns the given JITDylibs and all of their transitive dependensies in + /// reverse DFS order (based on linkage relationships). Each JITDylib will + /// appear only once. + static std::vector<JITDylibSP> + getReverseDFSLinkOrder(ArrayRef<JITDylibSP> JDs); + + /// Return this JITDylib and its transitive dependencies in DFS order + /// based on linkage relationships. + std::vector<JITDylibSP> getDFSLinkOrder(); + + /// Rteurn this JITDylib and its transitive dependencies in reverse DFS order + /// based on linkage relationships. + std::vector<JITDylibSP> getReverseDFSLinkOrder(); + private: using AsynchronousSymbolQueryList = std::vector<std::shared_ptr<AsynchronousSymbolQuery>>; struct UnmaterializedInfo { - UnmaterializedInfo(std::unique_ptr<MaterializationUnit> MU, - ResourceTracker *RT) - : MU(std::move(MU)), RT(RT) {} + UnmaterializedInfo(std::unique_ptr<MaterializationUnit> MU, + ResourceTracker *RT) + : MU(std::move(MU)), RT(RT) {} std::unique_ptr<MaterializationUnit> MU; - ResourceTracker *RT; + ResourceTracker *RT; }; using UnmaterializedInfosMap = DenseMap<SymbolStringPtr, std::shared_ptr<UnmaterializedInfo>>; - using UnmaterializedInfosList = - std::vector<std::shared_ptr<UnmaterializedInfo>>; - + using UnmaterializedInfosList = + std::vector<std::shared_ptr<UnmaterializedInfo>>; + struct MaterializingInfo { SymbolDependenceMap Dependants; SymbolDependenceMap UnemittedDependencies; @@ -1133,16 +1133,16 @@ private: JITDylib(ExecutionSession &ES, std::string Name); - ResourceTrackerSP getTracker(MaterializationResponsibility &MR); - std::pair<AsynchronousSymbolQuerySet, std::shared_ptr<SymbolDependenceMap>> - removeTracker(ResourceTracker &RT); + ResourceTrackerSP getTracker(MaterializationResponsibility &MR); + std::pair<AsynchronousSymbolQuerySet, std::shared_ptr<SymbolDependenceMap>> + removeTracker(ResourceTracker &RT); - void transferTracker(ResourceTracker &DstRT, ResourceTracker &SrcRT); + void transferTracker(ResourceTracker &DstRT, ResourceTracker &SrcRT); - Error defineImpl(MaterializationUnit &MU); + Error defineImpl(MaterializationUnit &MU); - void installMaterializationUnit(std::unique_ptr<MaterializationUnit> MU, - ResourceTracker &RT); + void installMaterializationUnit(std::unique_ptr<MaterializationUnit> MU, + ResourceTracker &RT); void detachQueryHelper(AsynchronousSymbolQuery &Q, const SymbolNameSet &QuerySymbols); @@ -1153,45 +1153,45 @@ private: Expected<SymbolFlagsMap> defineMaterializing(SymbolFlagsMap SymbolFlags); - Error replace(MaterializationResponsibility &FromMR, - std::unique_ptr<MaterializationUnit> MU); - - Expected<std::unique_ptr<MaterializationResponsibility>> - delegate(MaterializationResponsibility &FromMR, SymbolFlagsMap SymbolFlags, - SymbolStringPtr InitSymbol); + Error replace(MaterializationResponsibility &FromMR, + std::unique_ptr<MaterializationUnit> MU); + Expected<std::unique_ptr<MaterializationResponsibility>> + delegate(MaterializationResponsibility &FromMR, SymbolFlagsMap SymbolFlags, + SymbolStringPtr InitSymbol); + SymbolNameSet getRequestedSymbols(const SymbolFlagsMap &SymbolFlags) const; void addDependencies(const SymbolStringPtr &Name, const SymbolDependenceMap &Dependants); - Error resolve(MaterializationResponsibility &MR, const SymbolMap &Resolved); - - Error emit(MaterializationResponsibility &MR, const SymbolFlagsMap &Emitted); + Error resolve(MaterializationResponsibility &MR, const SymbolMap &Resolved); - void unlinkMaterializationResponsibility(MaterializationResponsibility &MR); + Error emit(MaterializationResponsibility &MR, const SymbolFlagsMap &Emitted); + void unlinkMaterializationResponsibility(MaterializationResponsibility &MR); + using FailedSymbolsWorklist = std::vector<std::pair<JITDylib *, SymbolStringPtr>>; - static std::pair<AsynchronousSymbolQuerySet, - std::shared_ptr<SymbolDependenceMap>> - failSymbols(FailedSymbolsWorklist); - + static std::pair<AsynchronousSymbolQuerySet, + std::shared_ptr<SymbolDependenceMap>> + failSymbols(FailedSymbolsWorklist); + ExecutionSession &ES; std::string JITDylibName; - std::mutex GeneratorsMutex; + std::mutex GeneratorsMutex; bool Open = true; SymbolTable Symbols; UnmaterializedInfosMap UnmaterializedInfos; MaterializingInfosMap MaterializingInfos; - std::vector<std::shared_ptr<DefinitionGenerator>> DefGenerators; + std::vector<std::shared_ptr<DefinitionGenerator>> DefGenerators; JITDylibSearchOrder LinkOrder; - ResourceTrackerSP DefaultTracker; - - // Map trackers to sets of symbols tracked. - DenseMap<ResourceTracker *, SymbolNameVector> TrackerSymbols; - DenseMap<MaterializationResponsibility *, ResourceTracker *> MRTrackers; + ResourceTrackerSP DefaultTracker; + + // Map trackers to sets of symbols tracked. + DenseMap<ResourceTracker *, SymbolNameVector> TrackerSymbols; + DenseMap<MaterializationResponsibility *, ResourceTracker *> MRTrackers; }; /// Platforms set up standard symbols and mediate interactions between dynamic @@ -1210,12 +1210,12 @@ public: /// This method will be called under the ExecutionSession lock each time a /// MaterializationUnit is added to a JITDylib. - virtual Error notifyAdding(ResourceTracker &RT, - const MaterializationUnit &MU) = 0; + virtual Error notifyAdding(ResourceTracker &RT, + const MaterializationUnit &MU) = 0; /// This method will be called under the ExecutionSession lock when a - /// ResourceTracker is removed. - virtual Error notifyRemoving(ResourceTracker &RT) = 0; + /// ResourceTracker is removed. + virtual Error notifyRemoving(ResourceTracker &RT) = 0; /// A utility function for looking up initializer symbols. Performs a blocking /// lookup for the given symbols in each of the given JITDylibs. @@ -1226,12 +1226,12 @@ public: /// An ExecutionSession represents a running JIT program. class ExecutionSession { - friend class InProgressLookupFlagsState; - friend class InProgressFullLookupState; + friend class InProgressLookupFlagsState; + friend class InProgressFullLookupState; friend class JITDylib; - friend class LookupState; - friend class MaterializationResponsibility; - friend class ResourceTracker; + friend class LookupState; + friend class MaterializationResponsibility; + friend class ResourceTracker; public: /// For reporting errors. @@ -1240,16 +1240,16 @@ public: /// For dispatching MaterializationUnit::materialize calls. using DispatchMaterializationFunction = std::function<void(std::unique_ptr<MaterializationUnit> MU, - std::unique_ptr<MaterializationResponsibility> MR)>; + std::unique_ptr<MaterializationResponsibility> MR)>; /// Construct an ExecutionSession. /// /// SymbolStringPools may be shared between ExecutionSessions. ExecutionSession(std::shared_ptr<SymbolStringPool> SSP = nullptr); - /// End the session. Closes all JITDylibs. - Error endSession(); - + /// End the session. Closes all JITDylibs. + Error endSession(); + /// Add a symbol name to the SymbolStringPool and return a pointer to it. SymbolStringPtr intern(StringRef SymName) { return SSP->intern(SymName); } @@ -1269,14 +1269,14 @@ public: return F(); } - /// Register the given ResourceManager with this ExecutionSession. - /// Managers will be notified of events in reverse order of registration. - void registerResourceManager(ResourceManager &RM); - - /// Deregister the given ResourceManager with this ExecutionSession. - /// Manager must have been previously registered. - void deregisterResourceManager(ResourceManager &RM); - + /// Register the given ResourceManager with this ExecutionSession. + /// Managers will be notified of events in reverse order of registration. + void registerResourceManager(ResourceManager &RM); + + /// Deregister the given ResourceManager with this ExecutionSession. + /// Manager must have been previously registered. + void deregisterResourceManager(ResourceManager &RM); + /// Return a pointer to the "name" JITDylib. /// Ownership of JITDylib remains within Execution Session JITDylib *getJITDylibByName(StringRef Name); @@ -1320,18 +1320,18 @@ public: return *this; } - /// Search the given JITDylibs to find the flags associated with each of the - /// given symbols. - void lookupFlags(LookupKind K, JITDylibSearchOrder SearchOrder, - SymbolLookupSet Symbols, - unique_function<void(Expected<SymbolFlagsMap>)> OnComplete); + /// Search the given JITDylibs to find the flags associated with each of the + /// given symbols. + void lookupFlags(LookupKind K, JITDylibSearchOrder SearchOrder, + SymbolLookupSet Symbols, + unique_function<void(Expected<SymbolFlagsMap>)> OnComplete); - /// Blocking version of lookupFlags. - Expected<SymbolFlagsMap> lookupFlags(LookupKind K, - JITDylibSearchOrder SearchOrder, - SymbolLookupSet Symbols); + /// Blocking version of lookupFlags. + Expected<SymbolFlagsMap> lookupFlags(LookupKind K, + JITDylibSearchOrder SearchOrder, + SymbolLookupSet Symbols); - /// Search the given JITDylibs for the given symbols. + /// Search the given JITDylibs for the given symbols. /// /// SearchOrder lists the JITDylibs to search. For each dylib, the associated /// boolean indicates whether the search should match against non-exported @@ -1391,11 +1391,11 @@ public: SymbolState RequiredState = SymbolState::Ready); /// Materialize the given unit. - void - dispatchMaterialization(std::unique_ptr<MaterializationUnit> MU, - std::unique_ptr<MaterializationResponsibility> MR) { + void + dispatchMaterialization(std::unique_ptr<MaterializationUnit> MU, + std::unique_ptr<MaterializationResponsibility> MR) { assert(MU && "MU must be non-null"); - DEBUG_WITH_TYPE("orc", dumpDispatchInfo(MR->getTargetJITDylib(), *MU)); + DEBUG_WITH_TYPE("orc", dumpDispatchInfo(MR->getTargetJITDylib(), *MU)); DispatchMaterialization(std::move(MU), std::move(MR)); } @@ -1407,124 +1407,124 @@ private: logAllUnhandledErrors(std::move(Err), errs(), "JIT session error: "); } - static void materializeOnCurrentThread( - std::unique_ptr<MaterializationUnit> MU, - std::unique_ptr<MaterializationResponsibility> MR) { + static void materializeOnCurrentThread( + std::unique_ptr<MaterializationUnit> MU, + std::unique_ptr<MaterializationResponsibility> MR) { MU->materialize(std::move(MR)); } - void dispatchOutstandingMUs(); - - static std::unique_ptr<MaterializationResponsibility> - createMaterializationResponsibility(ResourceTracker &RT, - SymbolFlagsMap Symbols, - SymbolStringPtr InitSymbol) { - auto &JD = RT.getJITDylib(); - std::unique_ptr<MaterializationResponsibility> MR( - new MaterializationResponsibility(&JD, std::move(Symbols), - std::move(InitSymbol))); - JD.MRTrackers[MR.get()] = &RT; - return MR; - } - - Error removeResourceTracker(ResourceTracker &RT); - void transferResourceTracker(ResourceTracker &DstRT, ResourceTracker &SrcRT); - void destroyResourceTracker(ResourceTracker &RT); - - // State machine functions for query application.. - - /// IL_updateCandidatesFor is called to remove already-defined symbols that - /// match a given query from the set of candidate symbols to generate - /// definitions for (no need to generate a definition if one already exists). - Error IL_updateCandidatesFor(JITDylib &JD, JITDylibLookupFlags JDLookupFlags, - SymbolLookupSet &Candidates, - SymbolLookupSet *NonCandidates); - - /// OL_applyQueryPhase1 is an optionally re-startable loop for triggering - /// definition generation. It is called when a lookup is performed, and again - /// each time that LookupState::continueLookup is called. - void OL_applyQueryPhase1(std::unique_ptr<InProgressLookupState> IPLS, - Error Err); - - /// OL_completeLookup is run once phase 1 successfully completes for a lookup - /// call. It attempts to attach the symbol to all symbol table entries and - /// collect all MaterializationUnits to dispatch. If this method fails then - /// all MaterializationUnits will be left un-materialized. - void OL_completeLookup(std::unique_ptr<InProgressLookupState> IPLS, - std::shared_ptr<AsynchronousSymbolQuery> Q, - RegisterDependenciesFunction RegisterDependencies); - - /// OL_completeLookupFlags is run once phase 1 successfully completes for a - /// lookupFlags call. - void OL_completeLookupFlags( - std::unique_ptr<InProgressLookupState> IPLS, - unique_function<void(Expected<SymbolFlagsMap>)> OnComplete); - - // State machine functions for MaterializationResponsibility. - void OL_destroyMaterializationResponsibility( - MaterializationResponsibility &MR); - SymbolNameSet OL_getRequestedSymbols(const MaterializationResponsibility &MR); - Error OL_notifyResolved(MaterializationResponsibility &MR, - const SymbolMap &Symbols); - Error OL_notifyEmitted(MaterializationResponsibility &MR); - Error OL_defineMaterializing(MaterializationResponsibility &MR, - SymbolFlagsMap SymbolFlags); - void OL_notifyFailed(MaterializationResponsibility &MR); - Error OL_replace(MaterializationResponsibility &MR, - std::unique_ptr<MaterializationUnit> MU); - Expected<std::unique_ptr<MaterializationResponsibility>> - OL_delegate(MaterializationResponsibility &MR, const SymbolNameSet &Symbols); - void OL_addDependencies(MaterializationResponsibility &MR, - const SymbolStringPtr &Name, - const SymbolDependenceMap &Dependencies); - void OL_addDependenciesForAll(MaterializationResponsibility &MR, - const SymbolDependenceMap &Dependencies); - + void dispatchOutstandingMUs(); + + static std::unique_ptr<MaterializationResponsibility> + createMaterializationResponsibility(ResourceTracker &RT, + SymbolFlagsMap Symbols, + SymbolStringPtr InitSymbol) { + auto &JD = RT.getJITDylib(); + std::unique_ptr<MaterializationResponsibility> MR( + new MaterializationResponsibility(&JD, std::move(Symbols), + std::move(InitSymbol))); + JD.MRTrackers[MR.get()] = &RT; + return MR; + } + + Error removeResourceTracker(ResourceTracker &RT); + void transferResourceTracker(ResourceTracker &DstRT, ResourceTracker &SrcRT); + void destroyResourceTracker(ResourceTracker &RT); + + // State machine functions for query application.. + + /// IL_updateCandidatesFor is called to remove already-defined symbols that + /// match a given query from the set of candidate symbols to generate + /// definitions for (no need to generate a definition if one already exists). + Error IL_updateCandidatesFor(JITDylib &JD, JITDylibLookupFlags JDLookupFlags, + SymbolLookupSet &Candidates, + SymbolLookupSet *NonCandidates); + + /// OL_applyQueryPhase1 is an optionally re-startable loop for triggering + /// definition generation. It is called when a lookup is performed, and again + /// each time that LookupState::continueLookup is called. + void OL_applyQueryPhase1(std::unique_ptr<InProgressLookupState> IPLS, + Error Err); + + /// OL_completeLookup is run once phase 1 successfully completes for a lookup + /// call. It attempts to attach the symbol to all symbol table entries and + /// collect all MaterializationUnits to dispatch. If this method fails then + /// all MaterializationUnits will be left un-materialized. + void OL_completeLookup(std::unique_ptr<InProgressLookupState> IPLS, + std::shared_ptr<AsynchronousSymbolQuery> Q, + RegisterDependenciesFunction RegisterDependencies); + + /// OL_completeLookupFlags is run once phase 1 successfully completes for a + /// lookupFlags call. + void OL_completeLookupFlags( + std::unique_ptr<InProgressLookupState> IPLS, + unique_function<void(Expected<SymbolFlagsMap>)> OnComplete); + + // State machine functions for MaterializationResponsibility. + void OL_destroyMaterializationResponsibility( + MaterializationResponsibility &MR); + SymbolNameSet OL_getRequestedSymbols(const MaterializationResponsibility &MR); + Error OL_notifyResolved(MaterializationResponsibility &MR, + const SymbolMap &Symbols); + Error OL_notifyEmitted(MaterializationResponsibility &MR); + Error OL_defineMaterializing(MaterializationResponsibility &MR, + SymbolFlagsMap SymbolFlags); + void OL_notifyFailed(MaterializationResponsibility &MR); + Error OL_replace(MaterializationResponsibility &MR, + std::unique_ptr<MaterializationUnit> MU); + Expected<std::unique_ptr<MaterializationResponsibility>> + OL_delegate(MaterializationResponsibility &MR, const SymbolNameSet &Symbols); + void OL_addDependencies(MaterializationResponsibility &MR, + const SymbolStringPtr &Name, + const SymbolDependenceMap &Dependencies); + void OL_addDependenciesForAll(MaterializationResponsibility &MR, + const SymbolDependenceMap &Dependencies); + #ifndef NDEBUG void dumpDispatchInfo(JITDylib &JD, MaterializationUnit &MU); #endif // NDEBUG mutable std::recursive_mutex SessionMutex; - bool SessionOpen = true; + bool SessionOpen = true; std::shared_ptr<SymbolStringPool> SSP; std::unique_ptr<Platform> P; ErrorReporter ReportError = logErrorsToStdErr; DispatchMaterializationFunction DispatchMaterialization = materializeOnCurrentThread; - std::vector<ResourceManager *> ResourceManagers; - - std::vector<JITDylibSP> JDs; + std::vector<ResourceManager *> ResourceManagers; + std::vector<JITDylibSP> JDs; + // FIXME: Remove this (and runOutstandingMUs) once the linking layer works // with callbacks from asynchronous queries. mutable std::recursive_mutex OutstandingMUsMutex; std::vector<std::pair<std::unique_ptr<MaterializationUnit>, - std::unique_ptr<MaterializationResponsibility>>> + std::unique_ptr<MaterializationResponsibility>>> OutstandingMUs; }; -inline ExecutionSession &MaterializationResponsibility::getExecutionSession() { - return JD->getExecutionSession(); -} - -template <typename Func> -Error MaterializationResponsibility::withResourceKeyDo(Func &&F) const { - return JD->getExecutionSession().runSessionLocked([&]() -> Error { - auto I = JD->MRTrackers.find(this); - assert(I != JD->MRTrackers.end() && "No tracker for this MR"); - if (I->second->isDefunct()) - return make_error<ResourceTrackerDefunct>(I->second); - F(I->second->getKeyUnsafe()); - return Error::success(); - }); -} - +inline ExecutionSession &MaterializationResponsibility::getExecutionSession() { + return JD->getExecutionSession(); +} + +template <typename Func> +Error MaterializationResponsibility::withResourceKeyDo(Func &&F) const { + return JD->getExecutionSession().runSessionLocked([&]() -> Error { + auto I = JD->MRTrackers.find(this); + assert(I != JD->MRTrackers.end() && "No tracker for this MR"); + if (I->second->isDefunct()) + return make_error<ResourceTrackerDefunct>(I->second); + F(I->second->getKeyUnsafe()); + return Error::success(); + }); +} + template <typename GeneratorT> GeneratorT &JITDylib::addGenerator(std::unique_ptr<GeneratorT> DefGenerator) { auto &G = *DefGenerator; - std::lock_guard<std::mutex> Lock(GeneratorsMutex); - DefGenerators.push_back(std::move(DefGenerator)); + std::lock_guard<std::mutex> Lock(GeneratorsMutex); + DefGenerators.push_back(std::move(DefGenerator)); return G; } @@ -1535,8 +1535,8 @@ auto JITDylib::withLinkOrderDo(Func &&F) } template <typename MaterializationUnitType> -Error JITDylib::define(std::unique_ptr<MaterializationUnitType> &&MU, - ResourceTrackerSP RT) { +Error JITDylib::define(std::unique_ptr<MaterializationUnitType> &&MU, + ResourceTrackerSP RT) { assert(MU && "Can not define with a null MU"); if (MU->getSymbols().empty()) { @@ -1548,36 +1548,36 @@ Error JITDylib::define(std::unique_ptr<MaterializationUnitType> &&MU, return Error::success(); } else DEBUG_WITH_TYPE("orc", { - dbgs() << "Defining MU " << MU->getName() << " for " << getName() - << " (tracker: "; - if (RT == getDefaultResourceTracker()) - dbgs() << "default)"; - else if (RT) - dbgs() << RT.get() << ")\n"; - else - dbgs() << "0x0, default will be used)\n"; + dbgs() << "Defining MU " << MU->getName() << " for " << getName() + << " (tracker: "; + if (RT == getDefaultResourceTracker()) + dbgs() << "default)"; + else if (RT) + dbgs() << RT.get() << ")\n"; + else + dbgs() << "0x0, default will be used)\n"; }); return ES.runSessionLocked([&, this]() -> Error { if (auto Err = defineImpl(*MU)) return Err; - if (!RT) - RT = getDefaultResourceTracker(); - + if (!RT) + RT = getDefaultResourceTracker(); + if (auto *P = ES.getPlatform()) { - if (auto Err = P->notifyAdding(*RT, *MU)) + if (auto Err = P->notifyAdding(*RT, *MU)) return Err; } - installMaterializationUnit(std::move(MU), *RT); + installMaterializationUnit(std::move(MU), *RT); return Error::success(); }); } template <typename MaterializationUnitType> -Error JITDylib::define(std::unique_ptr<MaterializationUnitType> &MU, - ResourceTrackerSP RT) { +Error JITDylib::define(std::unique_ptr<MaterializationUnitType> &MU, + ResourceTrackerSP RT) { assert(MU && "Can not define with a null MU"); if (MU->getSymbols().empty()) { @@ -1589,36 +1589,36 @@ Error JITDylib::define(std::unique_ptr<MaterializationUnitType> &MU, return Error::success(); } else DEBUG_WITH_TYPE("orc", { - dbgs() << "Defining MU " << MU->getName() << " for " << getName() - << " (tracker: "; - if (RT == getDefaultResourceTracker()) - dbgs() << "default)"; - else if (RT) - dbgs() << RT.get() << ")\n"; - else - dbgs() << "0x0, default will be used)\n"; + dbgs() << "Defining MU " << MU->getName() << " for " << getName() + << " (tracker: "; + if (RT == getDefaultResourceTracker()) + dbgs() << "default)"; + else if (RT) + dbgs() << RT.get() << ")\n"; + else + dbgs() << "0x0, default will be used)\n"; }); return ES.runSessionLocked([&, this]() -> Error { if (auto Err = defineImpl(*MU)) return Err; - if (!RT) - RT = getDefaultResourceTracker(); - + if (!RT) + RT = getDefaultResourceTracker(); + if (auto *P = ES.getPlatform()) { - if (auto Err = P->notifyAdding(*RT, *MU)) + if (auto Err = P->notifyAdding(*RT, *MU)) return Err; } - installMaterializationUnit(std::move(MU), *RT); + installMaterializationUnit(std::move(MU), *RT); return Error::success(); }); } /// ReexportsGenerator can be used with JITDylib::addGenerator to automatically /// re-export a subset of the source JITDylib's symbols in the target. -class ReexportsGenerator : public DefinitionGenerator { +class ReexportsGenerator : public DefinitionGenerator { public: using SymbolPredicate = std::function<bool(SymbolStringPtr)>; @@ -1629,7 +1629,7 @@ public: JITDylibLookupFlags SourceJDLookupFlags, SymbolPredicate Allow = SymbolPredicate()); - Error tryToGenerate(LookupState &LS, LookupKind K, JITDylib &JD, + Error tryToGenerate(LookupState &LS, LookupKind K, JITDylib &JD, JITDylibLookupFlags JDLookupFlags, const SymbolLookupSet &LookupSet) override; @@ -1639,57 +1639,57 @@ private: SymbolPredicate Allow; }; -// --------------- IMPLEMENTATION -------------- -// Implementations for inline functions/methods. -// --------------------------------------------- - -inline MaterializationResponsibility::~MaterializationResponsibility() { - JD->getExecutionSession().OL_destroyMaterializationResponsibility(*this); -} - -inline SymbolNameSet MaterializationResponsibility::getRequestedSymbols() const { - return JD->getExecutionSession().OL_getRequestedSymbols(*this); -} - -inline Error MaterializationResponsibility::notifyResolved( - const SymbolMap &Symbols) { - return JD->getExecutionSession().OL_notifyResolved(*this, Symbols); -} - -inline Error MaterializationResponsibility::notifyEmitted() { - return JD->getExecutionSession().OL_notifyEmitted(*this); -} - -inline Error MaterializationResponsibility::defineMaterializing( - SymbolFlagsMap SymbolFlags) { - return JD->getExecutionSession().OL_defineMaterializing( - *this, std::move(SymbolFlags)); -} - -inline void MaterializationResponsibility::failMaterialization() { - JD->getExecutionSession().OL_notifyFailed(*this); -} - -inline Error MaterializationResponsibility::replace( - std::unique_ptr<MaterializationUnit> MU) { - return JD->getExecutionSession().OL_replace(*this, std::move(MU)); -} - -inline Expected<std::unique_ptr<MaterializationResponsibility>> -MaterializationResponsibility::delegate(const SymbolNameSet &Symbols) { - return JD->getExecutionSession().OL_delegate(*this, Symbols); -} - -inline void MaterializationResponsibility::addDependencies( - const SymbolStringPtr &Name, const SymbolDependenceMap &Dependencies) { - JD->getExecutionSession().OL_addDependencies(*this, Name, Dependencies); -} - -inline void MaterializationResponsibility::addDependenciesForAll( - const SymbolDependenceMap &Dependencies) { - JD->getExecutionSession().OL_addDependenciesForAll(*this, Dependencies); -} - +// --------------- IMPLEMENTATION -------------- +// Implementations for inline functions/methods. +// --------------------------------------------- + +inline MaterializationResponsibility::~MaterializationResponsibility() { + JD->getExecutionSession().OL_destroyMaterializationResponsibility(*this); +} + +inline SymbolNameSet MaterializationResponsibility::getRequestedSymbols() const { + return JD->getExecutionSession().OL_getRequestedSymbols(*this); +} + +inline Error MaterializationResponsibility::notifyResolved( + const SymbolMap &Symbols) { + return JD->getExecutionSession().OL_notifyResolved(*this, Symbols); +} + +inline Error MaterializationResponsibility::notifyEmitted() { + return JD->getExecutionSession().OL_notifyEmitted(*this); +} + +inline Error MaterializationResponsibility::defineMaterializing( + SymbolFlagsMap SymbolFlags) { + return JD->getExecutionSession().OL_defineMaterializing( + *this, std::move(SymbolFlags)); +} + +inline void MaterializationResponsibility::failMaterialization() { + JD->getExecutionSession().OL_notifyFailed(*this); +} + +inline Error MaterializationResponsibility::replace( + std::unique_ptr<MaterializationUnit> MU) { + return JD->getExecutionSession().OL_replace(*this, std::move(MU)); +} + +inline Expected<std::unique_ptr<MaterializationResponsibility>> +MaterializationResponsibility::delegate(const SymbolNameSet &Symbols) { + return JD->getExecutionSession().OL_delegate(*this, Symbols); +} + +inline void MaterializationResponsibility::addDependencies( + const SymbolStringPtr &Name, const SymbolDependenceMap &Dependencies) { + JD->getExecutionSession().OL_addDependencies(*this, Name, Dependencies); +} + +inline void MaterializationResponsibility::addDependenciesForAll( + const SymbolDependenceMap &Dependencies) { + JD->getExecutionSession().OL_addDependenciesForAll(*this, Dependencies); +} + } // End namespace orc } // End namespace llvm diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ExecutionUtils.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ExecutionUtils.h index 1c94201394..334677e783 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ExecutionUtils.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ExecutionUtils.h @@ -25,7 +25,7 @@ #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/Mangling.h" -#include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" +#include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" #include "llvm/ExecutionEngine/RuntimeDyld.h" #include "llvm/Object/Archive.h" #include "llvm/Support/DynamicLibrary.h" @@ -222,7 +222,7 @@ private: /// If an instance of this class is attached to a JITDylib as a fallback /// definition generator, then any symbol found in the given DynamicLibrary that /// passes the 'Allow' predicate will be added to the JITDylib. -class DynamicLibrarySearchGenerator : public DefinitionGenerator { +class DynamicLibrarySearchGenerator : public DefinitionGenerator { public: using SymbolPredicate = std::function<bool(const SymbolStringPtr &)>; @@ -250,7 +250,7 @@ public: return Load(nullptr, GlobalPrefix, std::move(Allow)); } - Error tryToGenerate(LookupState &LS, LookupKind K, JITDylib &JD, + Error tryToGenerate(LookupState &LS, LookupKind K, JITDylib &JD, JITDylibLookupFlags JDLookupFlags, const SymbolLookupSet &Symbols) override; @@ -265,7 +265,7 @@ private: /// If an instance of this class is attached to a JITDylib as a fallback /// definition generator, then any symbol found in the archive will result in /// the containing object being added to the JITDylib. -class StaticLibraryDefinitionGenerator : public DefinitionGenerator { +class StaticLibraryDefinitionGenerator : public DefinitionGenerator { public: /// Try to create a StaticLibraryDefinitionGenerator from the given path. /// @@ -288,7 +288,7 @@ public: static Expected<std::unique_ptr<StaticLibraryDefinitionGenerator>> Create(ObjectLayer &L, std::unique_ptr<MemoryBuffer> ArchiveBuffer); - Error tryToGenerate(LookupState &LS, LookupKind K, JITDylib &JD, + Error tryToGenerate(LookupState &LS, LookupKind K, JITDylib &JD, JITDylibLookupFlags JDLookupFlags, const SymbolLookupSet &Symbols) override; diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h index fd729bf7e5..b4c4565f5b 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h @@ -52,8 +52,8 @@ public: IRSymbolMapper::ManglingOptions MO; }; - using NotifyCompiledFunction = std::function<void( - MaterializationResponsibility &R, ThreadSafeModule TSM)>; + using NotifyCompiledFunction = std::function<void( + MaterializationResponsibility &R, ThreadSafeModule TSM)>; IRCompileLayer(ExecutionSession &ES, ObjectLayer &BaseLayer, std::unique_ptr<IRCompiler> Compile); @@ -62,8 +62,8 @@ public: void setNotifyCompiled(NotifyCompiledFunction NotifyCompiled); - void emit(std::unique_ptr<MaterializationResponsibility> R, - ThreadSafeModule TSM) override; + void emit(std::unique_ptr<MaterializationResponsibility> R, + ThreadSafeModule TSM) override; private: mutable std::mutex IRLayerMutex; diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/IRTransformLayer.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/IRTransformLayer.h index 7ea7ee2a62..03a35bb841 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/IRTransformLayer.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/IRTransformLayer.h @@ -20,7 +20,7 @@ #ifndef LLVM_EXECUTIONENGINE_ORC_IRTRANSFORMLAYER_H #define LLVM_EXECUTIONENGINE_ORC_IRTRANSFORMLAYER_H -#include "llvm/ADT/FunctionExtras.h" +#include "llvm/ADT/FunctionExtras.h" #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/Orc/Layer.h" #include <memory> @@ -35,7 +35,7 @@ namespace orc { /// before operating on the module. class IRTransformLayer : public IRLayer { public: - using TransformFunction = unique_function<Expected<ThreadSafeModule>( + using TransformFunction = unique_function<Expected<ThreadSafeModule>( ThreadSafeModule, MaterializationResponsibility &R)>; IRTransformLayer(ExecutionSession &ES, IRLayer &BaseLayer, @@ -45,8 +45,8 @@ public: this->Transform = std::move(Transform); } - void emit(std::unique_ptr<MaterializationResponsibility> R, - ThreadSafeModule TSM) override; + void emit(std::unique_ptr<MaterializationResponsibility> R, + ThreadSafeModule TSM) override; static ThreadSafeModule identityTransform(ThreadSafeModule TSM, MaterializationResponsibility &R) { diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/IndirectionUtils.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/IndirectionUtils.h index 1ba1a2945d..f262d603f4 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/IndirectionUtils.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/IndirectionUtils.h @@ -69,33 +69,33 @@ public: JITTargetAddress TrampolineAddr, NotifyLandingResolvedFunction OnLandingResolved) const>; - virtual ~TrampolinePool(); + virtual ~TrampolinePool(); /// Get an available trampoline address. /// Returns an error if no trampoline can be created. - Expected<JITTargetAddress> getTrampoline() { - std::lock_guard<std::mutex> Lock(TPMutex); - if (AvailableTrampolines.empty()) { - if (auto Err = grow()) - return std::move(Err); - } - assert(!AvailableTrampolines.empty() && "Failed to grow trampoline pool"); - auto TrampolineAddr = AvailableTrampolines.back(); - AvailableTrampolines.pop_back(); - return TrampolineAddr; - } - - /// Returns the given trampoline to the pool for re-use. - void releaseTrampoline(JITTargetAddress TrampolineAddr) { - std::lock_guard<std::mutex> Lock(TPMutex); - AvailableTrampolines.push_back(TrampolineAddr); - } - -protected: - virtual Error grow() = 0; - - std::mutex TPMutex; - std::vector<JITTargetAddress> AvailableTrampolines; + Expected<JITTargetAddress> getTrampoline() { + std::lock_guard<std::mutex> Lock(TPMutex); + if (AvailableTrampolines.empty()) { + if (auto Err = grow()) + return std::move(Err); + } + assert(!AvailableTrampolines.empty() && "Failed to grow trampoline pool"); + auto TrampolineAddr = AvailableTrampolines.back(); + AvailableTrampolines.pop_back(); + return TrampolineAddr; + } + + /// Returns the given trampoline to the pool for re-use. + void releaseTrampoline(JITTargetAddress TrampolineAddr) { + std::lock_guard<std::mutex> Lock(TPMutex); + AvailableTrampolines.push_back(TrampolineAddr); + } + +protected: + virtual Error grow() = 0; + + std::mutex TPMutex; + std::vector<JITTargetAddress> AvailableTrampolines; }; /// A trampoline pool for trampolines within the current process. @@ -160,8 +160,8 @@ private: } } - Error grow() override { - assert(AvailableTrampolines.empty() && "Growing prematurely?"); + Error grow() override { + assert(AvailableTrampolines.empty() && "Growing prematurely?"); std::error_code EC; auto TrampolineBlock = @@ -181,7 +181,7 @@ private: pointerToJITTargetAddress(ResolverBlock.base()), NumTrampolines); for (unsigned I = 0; I < NumTrampolines; ++I) - AvailableTrampolines.push_back(pointerToJITTargetAddress( + AvailableTrampolines.push_back(pointerToJITTargetAddress( TrampolineMem + (I * ORCABI::TrampolineSize))); if (auto EC = sys::Memory::protectMappedMemory( diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/LLJIT.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/LLJIT.h index d27b2d3036..565340c68e 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/LLJIT.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/LLJIT.h @@ -35,8 +35,8 @@ namespace orc { class LLJITBuilderState; class LLLazyJITBuilderState; -class ObjectTransformLayer; -class TargetProcessControl; +class ObjectTransformLayer; +class TargetProcessControl; /// A pre-fabricated ORC JIT stack that can serve as an alternative to MCJIT. /// @@ -93,8 +93,8 @@ public: return ES->createJITDylib(std::move(Name)); } - /// Adds an IR module with the given ResourceTracker. - Error addIRModule(ResourceTrackerSP RT, ThreadSafeModule TSM); + /// Adds an IR module with the given ResourceTracker. + Error addIRModule(ResourceTrackerSP RT, ThreadSafeModule TSM); /// Adds an IR module to the given JITDylib. Error addIRModule(JITDylib &JD, ThreadSafeModule TSM); @@ -105,9 +105,9 @@ public: } /// Adds an object file to the given JITDylib. - Error addObjectFile(ResourceTrackerSP RT, std::unique_ptr<MemoryBuffer> Obj); - - /// Adds an object file to the given JITDylib. + Error addObjectFile(ResourceTrackerSP RT, std::unique_ptr<MemoryBuffer> Obj); + + /// Adds an object file to the given JITDylib. Error addObjectFile(JITDylib &JD, std::unique_ptr<MemoryBuffer> Obj); /// Adds an object file to the given JITDylib. @@ -176,7 +176,7 @@ public: ObjectLayer &getObjLinkingLayer() { return *ObjLinkingLayer; } /// Returns a reference to the object transform layer. - ObjectTransformLayer &getObjTransformLayer() { return *ObjTransformLayer; } + ObjectTransformLayer &getObjTransformLayer() { return *ObjTransformLayer; } /// Returns a reference to the IR transform layer. IRTransformLayer &getIRTransformLayer() { return *TransformLayer; } @@ -193,7 +193,7 @@ public: } protected: - static Expected<std::unique_ptr<ObjectLayer>> + static Expected<std::unique_ptr<ObjectLayer>> createObjectLinkingLayer(LLJITBuilderState &S, ExecutionSession &ES); static Expected<std::unique_ptr<IRCompileLayer::IRCompiler>> @@ -216,7 +216,7 @@ protected: std::unique_ptr<ThreadPool> CompileThreads; std::unique_ptr<ObjectLayer> ObjLinkingLayer; - std::unique_ptr<ObjectTransformLayer> ObjTransformLayer; + std::unique_ptr<ObjectTransformLayer> ObjTransformLayer; std::unique_ptr<IRCompileLayer> CompileLayer; std::unique_ptr<IRTransformLayer> TransformLayer; std::unique_ptr<IRTransformLayer> InitHelperTransformLayer; @@ -235,9 +235,9 @@ public: CODLayer->setPartitionFunction(std::move(Partition)); } - /// Returns a reference to the on-demand layer. - CompileOnDemandLayer &getCompileOnDemandLayer() { return *CODLayer; } - + /// Returns a reference to the on-demand layer. + CompileOnDemandLayer &getCompileOnDemandLayer() { return *CODLayer; } + /// Add a module to be lazily compiled to JITDylib JD. Error addLazyIRModule(JITDylib &JD, ThreadSafeModule M); @@ -257,9 +257,9 @@ private: class LLJITBuilderState { public: - using ObjectLinkingLayerCreator = - std::function<Expected<std::unique_ptr<ObjectLayer>>(ExecutionSession &, - const Triple &)>; + using ObjectLinkingLayerCreator = + std::function<Expected<std::unique_ptr<ObjectLayer>>(ExecutionSession &, + const Triple &)>; using CompileFunctionCreator = std::function<Expected<std::unique_ptr<IRCompileLayer::IRCompiler>>( @@ -274,7 +274,7 @@ public: CompileFunctionCreator CreateCompileFunction; PlatformSetupFunction SetUpPlatform; unsigned NumCompileThreads = 0; - TargetProcessControl *TPC = nullptr; + TargetProcessControl *TPC = nullptr; /// Called prior to JIT class construcion to fix up defaults. Error prepareForConstruction(); @@ -357,17 +357,17 @@ public: return impl(); } - /// Set a TargetProcessControl object. - /// - /// If the platform uses ObjectLinkingLayer by default and no - /// ObjectLinkingLayerCreator has been set then the TargetProcessControl - /// object will be used to supply the memory manager for the - /// ObjectLinkingLayer. - SetterImpl &setTargetProcessControl(TargetProcessControl &TPC) { - impl().TPC = &TPC; - return impl(); - } - + /// Set a TargetProcessControl object. + /// + /// If the platform uses ObjectLinkingLayer by default and no + /// ObjectLinkingLayerCreator has been set then the TargetProcessControl + /// object will be used to supply the memory manager for the + /// ObjectLinkingLayer. + SetterImpl &setTargetProcessControl(TargetProcessControl &TPC) { + impl().TPC = &TPC; + return impl(); + } + /// Create an instance of the JIT. Expected<std::unique_ptr<JITType>> create() { if (auto Err = impl().prepareForConstruction()) diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Layer.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Layer.h index 3d315e52f4..7fe794009d 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Layer.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Layer.h @@ -41,15 +41,15 @@ public: /// SymbolFlags and SymbolToDefinition maps. IRMaterializationUnit(ExecutionSession &ES, const IRSymbolMapper::ManglingOptions &MO, - ThreadSafeModule TSM); + ThreadSafeModule TSM); /// Create an IRMaterializationLayer from a module, and pre-existing /// SymbolFlags and SymbolToDefinition maps. The maps must provide /// entries for each definition in M. /// This constructor is useful for delegating work from one /// IRMaterializationUnit to another. - IRMaterializationUnit(ThreadSafeModule TSM, SymbolFlagsMap SymbolFlags, - SymbolStringPtr InitSymbol, + IRMaterializationUnit(ThreadSafeModule TSM, SymbolFlagsMap SymbolFlags, + SymbolStringPtr InitSymbol, SymbolNameToDefinitionMap SymbolToDefinition); /// Return the ModuleIdentifier as the name for this MaterializationUnit. @@ -101,19 +101,19 @@ public: /// Returns the current value of the CloneToNewContextOnEmit flag. bool getCloneToNewContextOnEmit() const { return CloneToNewContextOnEmit; } - /// Add a MaterializatinoUnit representing the given IR to the JITDylib - /// targeted by the given tracker. - virtual Error add(ResourceTrackerSP RT, ThreadSafeModule TSM); - + /// Add a MaterializatinoUnit representing the given IR to the JITDylib + /// targeted by the given tracker. + virtual Error add(ResourceTrackerSP RT, ThreadSafeModule TSM); + /// Adds a MaterializationUnit representing the given IR to the given - /// JITDylib. If RT is not specif - Error add(JITDylib &JD, ThreadSafeModule TSM) { - return add(JD.getDefaultResourceTracker(), std::move(TSM)); - } + /// JITDylib. If RT is not specif + Error add(JITDylib &JD, ThreadSafeModule TSM) { + return add(JD.getDefaultResourceTracker(), std::move(TSM)); + } /// Emit should materialize the given IR. - virtual void emit(std::unique_ptr<MaterializationResponsibility> R, - ThreadSafeModule TSM) = 0; + virtual void emit(std::unique_ptr<MaterializationResponsibility> R, + ThreadSafeModule TSM) = 0; private: bool CloneToNewContextOnEmit = false; @@ -127,10 +127,10 @@ class BasicIRLayerMaterializationUnit : public IRMaterializationUnit { public: BasicIRLayerMaterializationUnit(IRLayer &L, const IRSymbolMapper::ManglingOptions &MO, - ThreadSafeModule TSM); + ThreadSafeModule TSM); private: - void materialize(std::unique_ptr<MaterializationResponsibility> R) override; + void materialize(std::unique_ptr<MaterializationResponsibility> R) override; IRLayer &L; }; @@ -146,14 +146,14 @@ public: /// Adds a MaterializationUnit representing the given IR to the given /// JITDylib. - virtual Error add(ResourceTrackerSP RT, std::unique_ptr<MemoryBuffer> O); - - Error add(JITDylib &JD, std::unique_ptr<MemoryBuffer> O) { - return add(JD.getDefaultResourceTracker(), std::move(O)); - } + virtual Error add(ResourceTrackerSP RT, std::unique_ptr<MemoryBuffer> O); + Error add(JITDylib &JD, std::unique_ptr<MemoryBuffer> O) { + return add(JD.getDefaultResourceTracker(), std::move(O)); + } + /// Emit should materialize the given IR. - virtual void emit(std::unique_ptr<MaterializationResponsibility> R, + virtual void emit(std::unique_ptr<MaterializationResponsibility> R, std::unique_ptr<MemoryBuffer> O) = 0; private: @@ -165,9 +165,9 @@ private: class BasicObjectLayerMaterializationUnit : public MaterializationUnit { public: static Expected<std::unique_ptr<BasicObjectLayerMaterializationUnit>> - Create(ObjectLayer &L, std::unique_ptr<MemoryBuffer> O); + Create(ObjectLayer &L, std::unique_ptr<MemoryBuffer> O); - BasicObjectLayerMaterializationUnit(ObjectLayer &L, + BasicObjectLayerMaterializationUnit(ObjectLayer &L, std::unique_ptr<MemoryBuffer> O, SymbolFlagsMap SymbolFlags, SymbolStringPtr InitSymbol); @@ -176,7 +176,7 @@ public: StringRef getName() const override; private: - void materialize(std::unique_ptr<MaterializationResponsibility> R) override; + void materialize(std::unique_ptr<MaterializationResponsibility> R) override; void discard(const JITDylib &JD, const SymbolStringPtr &Name) override; ObjectLayer &L; diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/LazyReexports.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/LazyReexports.h index 71b5831d32..4f4d089463 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/LazyReexports.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/LazyReexports.h @@ -47,9 +47,9 @@ public: using NotifyResolvedFunction = unique_function<Error(JITTargetAddress ResolvedAddr)>; - LazyCallThroughManager(ExecutionSession &ES, - JITTargetAddress ErrorHandlerAddr, TrampolinePool *TP); - + LazyCallThroughManager(ExecutionSession &ES, + JITTargetAddress ErrorHandlerAddr, TrampolinePool *TP); + // Return a free call-through trampoline and bind it to look up and call // through to the given symbol. Expected<JITTargetAddress> @@ -151,12 +151,12 @@ public: IndirectStubsManager &ISManager, JITDylib &SourceJD, SymbolAliasMap CallableAliases, - ImplSymbolMap *SrcJDLoc); + ImplSymbolMap *SrcJDLoc); StringRef getName() const override; private: - void materialize(std::unique_ptr<MaterializationResponsibility> R) override; + void materialize(std::unique_ptr<MaterializationResponsibility> R) override; void discard(const JITDylib &JD, const SymbolStringPtr &Name) override; static SymbolFlagsMap extractFlags(const SymbolAliasMap &Aliases); @@ -173,10 +173,10 @@ private: inline std::unique_ptr<LazyReexportsMaterializationUnit> lazyReexports(LazyCallThroughManager &LCTManager, IndirectStubsManager &ISManager, JITDylib &SourceJD, - SymbolAliasMap CallableAliases, - ImplSymbolMap *SrcJDLoc = nullptr) { + SymbolAliasMap CallableAliases, + ImplSymbolMap *SrcJDLoc = nullptr) { return std::make_unique<LazyReexportsMaterializationUnit>( - LCTManager, ISManager, SourceJD, std::move(CallableAliases), SrcJDLoc); + LCTManager, ISManager, SourceJD, std::move(CallableAliases), SrcJDLoc); } } // End namespace orc diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/MachOPlatform.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/MachOPlatform.h index 6e83f86621..997a0aca6e 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/MachOPlatform.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/MachOPlatform.h @@ -105,9 +105,9 @@ public: ExecutionSession &getExecutionSession() const { return ES; } Error setupJITDylib(JITDylib &JD) override; - Error notifyAdding(ResourceTracker &RT, - const MaterializationUnit &MU) override; - Error notifyRemoving(ResourceTracker &RT) override; + Error notifyAdding(ResourceTracker &RT, + const MaterializationUnit &MU) override; + Error notifyRemoving(ResourceTracker &RT) override; Expected<InitializerSequence> getInitializerSequence(JITDylib &JD); @@ -127,19 +127,19 @@ private: LocalDependenciesMap getSyntheticSymbolLocalDependencies( MaterializationResponsibility &MR) override; - // FIXME: We should be tentatively tracking scraped sections and discarding - // if the MR fails. - Error notifyFailed(MaterializationResponsibility &MR) override { - return Error::success(); - } - - Error notifyRemovingResources(ResourceKey K) override { - return Error::success(); - } - - void notifyTransferringResources(ResourceKey DstKey, - ResourceKey SrcKey) override {} - + // FIXME: We should be tentatively tracking scraped sections and discarding + // if the MR fails. + Error notifyFailed(MaterializationResponsibility &MR) override { + return Error::success(); + } + + Error notifyRemovingResources(ResourceKey K) override { + return Error::success(); + } + + void notifyTransferringResources(ResourceKey DstKey, + ResourceKey SrcKey) override {} + private: using InitSymbolDepMap = DenseMap<MaterializationResponsibility *, JITLinkSymbolVector>; diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h index f9aa582b31..0c16ece95a 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h @@ -42,7 +42,7 @@ namespace llvm { namespace jitlink { class EHFrameRegistrar; -class LinkGraph; +class LinkGraph; class Symbol; } // namespace jitlink @@ -59,7 +59,7 @@ class ObjectLinkingLayerJITLinkContext; /// Clients can use this class to add relocatable object files to an /// ExecutionSession, and it typically serves as the base layer (underneath /// a compiling layer like IRCompileLayer) for the rest of the JIT. -class ObjectLinkingLayer : public ObjectLayer, private ResourceManager { +class ObjectLinkingLayer : public ObjectLayer, private ResourceManager { friend class ObjectLinkingLayerJITLinkContext; public: @@ -80,10 +80,10 @@ public: virtual Error notifyEmitted(MaterializationResponsibility &MR) { return Error::success(); } - virtual Error notifyFailed(MaterializationResponsibility &MR) = 0; - virtual Error notifyRemovingResources(ResourceKey K) = 0; - virtual void notifyTransferringResources(ResourceKey DstKey, - ResourceKey SrcKey) = 0; + virtual Error notifyFailed(MaterializationResponsibility &MR) = 0; + virtual Error notifyRemovingResources(ResourceKey K) = 0; + virtual void notifyTransferringResources(ResourceKey DstKey, + ResourceKey SrcKey) = 0; /// Return any dependencies that synthetic symbols (e.g. init symbols) /// have on locally scoped jitlink::Symbols. This is used by the @@ -98,15 +98,15 @@ public: using ReturnObjectBufferFunction = std::function<void(std::unique_ptr<MemoryBuffer>)>; - /// Construct an ObjectLinkingLayer. - ObjectLinkingLayer(ExecutionSession &ES, - jitlink::JITLinkMemoryManager &MemMgr); - - /// Construct an ObjectLinkingLayer. Takes ownership of the given - /// JITLinkMemoryManager. This method is a temporary hack to simplify - /// co-existence with RTDyldObjectLinkingLayer (which also owns its - /// allocators). + /// Construct an ObjectLinkingLayer. ObjectLinkingLayer(ExecutionSession &ES, + jitlink::JITLinkMemoryManager &MemMgr); + + /// Construct an ObjectLinkingLayer. Takes ownership of the given + /// JITLinkMemoryManager. This method is a temporary hack to simplify + /// co-existence with RTDyldObjectLinkingLayer (which also owns its + /// allocators). + ObjectLinkingLayer(ExecutionSession &ES, std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr); /// Destruct an ObjectLinkingLayer. @@ -126,14 +126,14 @@ public: return *this; } - /// Emit an object file. - void emit(std::unique_ptr<MaterializationResponsibility> R, + /// Emit an object file. + void emit(std::unique_ptr<MaterializationResponsibility> R, std::unique_ptr<MemoryBuffer> O) override; - /// Emit a LinkGraph. - void emit(std::unique_ptr<MaterializationResponsibility> R, - std::unique_ptr<jitlink::LinkGraph> G); - + /// Emit a LinkGraph. + void emit(std::unique_ptr<MaterializationResponsibility> R, + std::unique_ptr<jitlink::LinkGraph> G); + /// Instructs this ObjectLinkingLayer instance to override the symbol flags /// found in the AtomGraph with the flags supplied by the /// MaterializationResponsibility instance. This is a workaround to support @@ -173,31 +173,31 @@ private: void notifyLoaded(MaterializationResponsibility &MR); Error notifyEmitted(MaterializationResponsibility &MR, AllocPtr Alloc); - Error handleRemoveResources(ResourceKey K) override; - void handleTransferResources(ResourceKey DstKey, ResourceKey SrcKey) override; + Error handleRemoveResources(ResourceKey K) override; + void handleTransferResources(ResourceKey DstKey, ResourceKey SrcKey) override; mutable std::mutex LayerMutex; - jitlink::JITLinkMemoryManager &MemMgr; - std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgrOwnership; + jitlink::JITLinkMemoryManager &MemMgr; + std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgrOwnership; bool OverrideObjectFlags = false; bool AutoClaimObjectSymbols = false; ReturnObjectBufferFunction ReturnObjectBuffer; - DenseMap<ResourceKey, std::vector<AllocPtr>> Allocs; + DenseMap<ResourceKey, std::vector<AllocPtr>> Allocs; std::vector<std::unique_ptr<Plugin>> Plugins; }; class EHFrameRegistrationPlugin : public ObjectLinkingLayer::Plugin { public: - EHFrameRegistrationPlugin( - ExecutionSession &ES, - std::unique_ptr<jitlink::EHFrameRegistrar> Registrar); + EHFrameRegistrationPlugin( + ExecutionSession &ES, + std::unique_ptr<jitlink::EHFrameRegistrar> Registrar); void modifyPassConfig(MaterializationResponsibility &MR, const Triple &TT, jitlink::PassConfiguration &PassConfig) override; - Error notifyEmitted(MaterializationResponsibility &MR) override; - Error notifyFailed(MaterializationResponsibility &MR) override; - Error notifyRemovingResources(ResourceKey K) override; - void notifyTransferringResources(ResourceKey DstKey, - ResourceKey SrcKey) override; + Error notifyEmitted(MaterializationResponsibility &MR) override; + Error notifyFailed(MaterializationResponsibility &MR) override; + Error notifyRemovingResources(ResourceKey K) override; + void notifyTransferringResources(ResourceKey DstKey, + ResourceKey SrcKey) override; private: @@ -207,10 +207,10 @@ private: }; std::mutex EHFramePluginMutex; - ExecutionSession &ES; - std::unique_ptr<jitlink::EHFrameRegistrar> Registrar; + ExecutionSession &ES; + std::unique_ptr<jitlink::EHFrameRegistrar> Registrar; DenseMap<MaterializationResponsibility *, EHFrameRange> InProcessLinks; - DenseMap<ResourceKey, std::vector<EHFrameRange>> EHFrameRanges; + DenseMap<ResourceKey, std::vector<EHFrameRange>> EHFrameRanges; }; } // end namespace orc diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ObjectTransformLayer.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ObjectTransformLayer.h index d37ff22757..e1ec55031b 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ObjectTransformLayer.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ObjectTransformLayer.h @@ -38,7 +38,7 @@ public: ObjectTransformLayer(ExecutionSession &ES, ObjectLayer &BaseLayer, TransformFunction Transform = TransformFunction()); - void emit(std::unique_ptr<MaterializationResponsibility> R, + void emit(std::unique_ptr<MaterializationResponsibility> R, std::unique_ptr<MemoryBuffer> O) override; void setTransform(TransformFunction Transform) { diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRPCTargetProcessControl.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRPCTargetProcessControl.h index 1b6578db1c..80142adcdf 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRPCTargetProcessControl.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRPCTargetProcessControl.h @@ -1,426 +1,426 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===--- OrcRPCTargetProcessControl.h - Remote target control ---*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Utilities for interacting with target processes. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_ORCRPCTARGETPROCESSCONTROL_H -#define LLVM_EXECUTIONENGINE_ORC_ORCRPCTARGETPROCESSCONTROL_H - -#include "llvm/ExecutionEngine/Orc/Shared/RPCUtils.h" -#include "llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h" -#include "llvm/ExecutionEngine/Orc/TargetProcess/OrcRPCTPCServer.h" -#include "llvm/ExecutionEngine/Orc/TargetProcessControl.h" -#include "llvm/Support/MSVCErrorWorkarounds.h" - -namespace llvm { -namespace orc { - -/// JITLinkMemoryManager implementation for a process connected via an ORC RPC -/// endpoint. -template <typename OrcRPCTPCImplT> -class OrcRPCTPCJITLinkMemoryManager : public jitlink::JITLinkMemoryManager { -private: - struct HostAlloc { - std::unique_ptr<char[]> Mem; - uint64_t Size; - }; - - struct TargetAlloc { - JITTargetAddress Address = 0; - uint64_t AllocatedSize = 0; - }; - - using HostAllocMap = DenseMap<int, HostAlloc>; - using TargetAllocMap = DenseMap<int, TargetAlloc>; - -public: - class OrcRPCAllocation : public Allocation { - public: - OrcRPCAllocation(OrcRPCTPCJITLinkMemoryManager<OrcRPCTPCImplT> &Parent, - HostAllocMap HostAllocs, TargetAllocMap TargetAllocs) - : Parent(Parent), HostAllocs(std::move(HostAllocs)), - TargetAllocs(std::move(TargetAllocs)) { - assert(HostAllocs.size() == TargetAllocs.size() && - "HostAllocs size should match TargetAllocs"); - } - - ~OrcRPCAllocation() override { - assert(TargetAllocs.empty() && "failed to deallocate"); - } - - MutableArrayRef<char> getWorkingMemory(ProtectionFlags Seg) override { - auto I = HostAllocs.find(Seg); - assert(I != HostAllocs.end() && "No host allocation for segment"); - auto &HA = I->second; - return {HA.Mem.get(), static_cast<size_t>(HA.Size)}; - } - - JITTargetAddress getTargetMemory(ProtectionFlags Seg) override { - auto I = TargetAllocs.find(Seg); - assert(I != TargetAllocs.end() && "No target allocation for segment"); - return I->second.Address; - } - - void finalizeAsync(FinalizeContinuation OnFinalize) override { - - std::vector<tpctypes::BufferWrite> BufferWrites; - orcrpctpc::ReleaseOrFinalizeMemRequest FMR; - - for (auto &KV : HostAllocs) { - assert(TargetAllocs.count(KV.first) && - "No target allocation for buffer"); - auto &HA = KV.second; - auto &TA = TargetAllocs[KV.first]; - BufferWrites.push_back({TA.Address, StringRef(HA.Mem.get(), HA.Size)}); - FMR.push_back({orcrpctpc::toWireProtectionFlags( - static_cast<sys::Memory::ProtectionFlags>(KV.first)), - TA.Address, TA.AllocatedSize}); - } - - DEBUG_WITH_TYPE("orc", { - dbgs() << "finalizeAsync " << (void *)this << ":\n"; - auto FMRI = FMR.begin(); - for (auto &B : BufferWrites) { - auto Prot = FMRI->Prot; - ++FMRI; - dbgs() << " Writing " << formatv("{0:x16}", B.Buffer.size()) - << " bytes to " << ((Prot & orcrpctpc::WPF_Read) ? 'R' : '-') - << ((Prot & orcrpctpc::WPF_Write) ? 'W' : '-') - << ((Prot & orcrpctpc::WPF_Exec) ? 'X' : '-') - << " segment: local " << (const void *)B.Buffer.data() - << " -> target " << formatv("{0:x16}", B.Address) << "\n"; - } - }); - if (auto Err = - Parent.Parent.getMemoryAccess().writeBuffers(BufferWrites)) { - OnFinalize(std::move(Err)); - return; - } - - DEBUG_WITH_TYPE("orc", dbgs() << " Applying permissions...\n"); - if (auto Err = - Parent.getEndpoint().template callAsync<orcrpctpc::FinalizeMem>( - [OF = std::move(OnFinalize)](Error Err2) { - // FIXME: Dispatch to work queue. - std::thread([OF = std::move(OF), - Err3 = std::move(Err2)]() mutable { - DEBUG_WITH_TYPE( - "orc", { dbgs() << " finalizeAsync complete\n"; }); - OF(std::move(Err3)); - }).detach(); - return Error::success(); - }, - FMR)) { - DEBUG_WITH_TYPE("orc", dbgs() << " failed.\n"); - Parent.getEndpoint().abandonPendingResponses(); - Parent.reportError(std::move(Err)); - } - DEBUG_WITH_TYPE("orc", { - dbgs() << "Leaving finalizeAsync (finalization may continue in " - "background)\n"; - }); - } - - Error deallocate() override { - orcrpctpc::ReleaseOrFinalizeMemRequest RMR; - for (auto &KV : TargetAllocs) - RMR.push_back({orcrpctpc::toWireProtectionFlags( - static_cast<sys::Memory::ProtectionFlags>(KV.first)), - KV.second.Address, KV.second.AllocatedSize}); - TargetAllocs.clear(); - - return Parent.getEndpoint().template callB<orcrpctpc::ReleaseMem>(RMR); - } - - private: - OrcRPCTPCJITLinkMemoryManager<OrcRPCTPCImplT> &Parent; - HostAllocMap HostAllocs; - TargetAllocMap TargetAllocs; - }; - - OrcRPCTPCJITLinkMemoryManager(OrcRPCTPCImplT &Parent) : Parent(Parent) {} - - Expected<std::unique_ptr<Allocation>> - allocate(const jitlink::JITLinkDylib *JD, - const SegmentsRequestMap &Request) override { - orcrpctpc::ReserveMemRequest RMR; - HostAllocMap HostAllocs; - - for (auto &KV : Request) { - assert(KV.second.getContentSize() <= std::numeric_limits<size_t>::max() && - "Content size is out-of-range for host"); - - RMR.push_back({orcrpctpc::toWireProtectionFlags( - static_cast<sys::Memory::ProtectionFlags>(KV.first)), - KV.second.getContentSize() + KV.second.getZeroFillSize(), - KV.second.getAlignment()}); - HostAllocs[KV.first] = { - std::make_unique<char[]>(KV.second.getContentSize()), - KV.second.getContentSize()}; - } - - DEBUG_WITH_TYPE("orc", { - dbgs() << "Orc remote memmgr got request:\n"; - for (auto &KV : Request) - dbgs() << " permissions: " - << ((KV.first & sys::Memory::MF_READ) ? 'R' : '-') - << ((KV.first & sys::Memory::MF_WRITE) ? 'W' : '-') - << ((KV.first & sys::Memory::MF_EXEC) ? 'X' : '-') - << ", content size: " - << formatv("{0:x16}", KV.second.getContentSize()) - << " + zero-fill-size: " - << formatv("{0:x16}", KV.second.getZeroFillSize()) - << ", align: " << KV.second.getAlignment() << "\n"; - }); - - // FIXME: LLVM RPC needs to be fixed to support alt - // serialization/deserialization on return types. For now just - // translate from std::map to DenseMap manually. - auto TmpTargetAllocs = - Parent.getEndpoint().template callB<orcrpctpc::ReserveMem>(RMR); - if (!TmpTargetAllocs) - return TmpTargetAllocs.takeError(); - - if (TmpTargetAllocs->size() != RMR.size()) - return make_error<StringError>( - "Number of target allocations does not match request", - inconvertibleErrorCode()); - - TargetAllocMap TargetAllocs; - for (auto &E : *TmpTargetAllocs) - TargetAllocs[orcrpctpc::fromWireProtectionFlags(E.Prot)] = { - E.Address, E.AllocatedSize}; - - DEBUG_WITH_TYPE("orc", { - auto HAI = HostAllocs.begin(); - for (auto &KV : TargetAllocs) - dbgs() << " permissions: " - << ((KV.first & sys::Memory::MF_READ) ? 'R' : '-') - << ((KV.first & sys::Memory::MF_WRITE) ? 'W' : '-') - << ((KV.first & sys::Memory::MF_EXEC) ? 'X' : '-') - << " assigned local " << (void *)HAI->second.Mem.get() - << ", target " << formatv("{0:x16}", KV.second.Address) << "\n"; - }); - - return std::make_unique<OrcRPCAllocation>(*this, std::move(HostAllocs), - std::move(TargetAllocs)); - } - -private: - void reportError(Error Err) { Parent.reportError(std::move(Err)); } - - decltype(std::declval<OrcRPCTPCImplT>().getEndpoint()) getEndpoint() { - return Parent.getEndpoint(); - } - - OrcRPCTPCImplT &Parent; -}; - -/// TargetProcessControl::MemoryAccess implementation for a process connected -/// via an ORC RPC endpoint. -template <typename OrcRPCTPCImplT> -class OrcRPCTPCMemoryAccess : public TargetProcessControl::MemoryAccess { -public: - OrcRPCTPCMemoryAccess(OrcRPCTPCImplT &Parent) : Parent(Parent) {} - - void writeUInt8s(ArrayRef<tpctypes::UInt8Write> Ws, - WriteResultFn OnWriteComplete) override { - writeViaRPC<orcrpctpc::WriteUInt8s>(Ws, std::move(OnWriteComplete)); - } - - void writeUInt16s(ArrayRef<tpctypes::UInt16Write> Ws, - WriteResultFn OnWriteComplete) override { - writeViaRPC<orcrpctpc::WriteUInt16s>(Ws, std::move(OnWriteComplete)); - } - - void writeUInt32s(ArrayRef<tpctypes::UInt32Write> Ws, - WriteResultFn OnWriteComplete) override { - writeViaRPC<orcrpctpc::WriteUInt32s>(Ws, std::move(OnWriteComplete)); - } - - void writeUInt64s(ArrayRef<tpctypes::UInt64Write> Ws, - WriteResultFn OnWriteComplete) override { - writeViaRPC<orcrpctpc::WriteUInt64s>(Ws, std::move(OnWriteComplete)); - } - - void writeBuffers(ArrayRef<tpctypes::BufferWrite> Ws, - WriteResultFn OnWriteComplete) override { - writeViaRPC<orcrpctpc::WriteBuffers>(Ws, std::move(OnWriteComplete)); - } - -private: - template <typename WriteRPCFunction, typename WriteElementT> - void writeViaRPC(ArrayRef<WriteElementT> Ws, WriteResultFn OnWriteComplete) { - if (auto Err = Parent.getEndpoint().template callAsync<WriteRPCFunction>( - [OWC = std::move(OnWriteComplete)](Error Err2) mutable -> Error { - OWC(std::move(Err2)); - return Error::success(); - }, - Ws)) { - Parent.reportError(std::move(Err)); - Parent.getEndpoint().abandonPendingResponses(); - } - } - - OrcRPCTPCImplT &Parent; -}; - -// TargetProcessControl for a process connected via an ORC RPC Endpoint. -template <typename RPCEndpointT> -class OrcRPCTargetProcessControlBase : public TargetProcessControl { -public: - using ErrorReporter = unique_function<void(Error)>; - - using OnCloseConnectionFunction = unique_function<Error(Error)>; - - OrcRPCTargetProcessControlBase(std::shared_ptr<SymbolStringPool> SSP, - RPCEndpointT &EP, ErrorReporter ReportError) - : TargetProcessControl(std::move(SSP)), - ReportError(std::move(ReportError)), EP(EP) {} - - void reportError(Error Err) { ReportError(std::move(Err)); } - - RPCEndpointT &getEndpoint() { return EP; } - - Expected<tpctypes::DylibHandle> loadDylib(const char *DylibPath) override { - DEBUG_WITH_TYPE("orc", { - dbgs() << "Loading dylib \"" << (DylibPath ? DylibPath : "") << "\" "; - if (!DylibPath) - dbgs() << "(process symbols)"; - dbgs() << "\n"; - }); - if (!DylibPath) - DylibPath = ""; - auto H = EP.template callB<orcrpctpc::LoadDylib>(DylibPath); - DEBUG_WITH_TYPE("orc", { - if (H) - dbgs() << " got handle " << formatv("{0:x16}", *H) << "\n"; - else - dbgs() << " error, unable to load\n"; - }); - return H; - } - - Expected<std::vector<tpctypes::LookupResult>> - lookupSymbols(ArrayRef<LookupRequest> Request) override { - std::vector<orcrpctpc::RemoteLookupRequest> RR; - for (auto &E : Request) { - RR.push_back({}); - RR.back().first = E.Handle; - for (auto &KV : E.Symbols) - RR.back().second.push_back( - {(*KV.first).str(), - KV.second == SymbolLookupFlags::WeaklyReferencedSymbol}); - } - DEBUG_WITH_TYPE("orc", { - dbgs() << "Compound lookup:\n"; - for (auto &R : Request) { - dbgs() << " In " << formatv("{0:x16}", R.Handle) << ": {"; - bool First = true; - for (auto &KV : R.Symbols) { - dbgs() << (First ? "" : ",") << " " << *KV.first; - First = false; - } - dbgs() << " }\n"; - } - }); - return EP.template callB<orcrpctpc::LookupSymbols>(RR); - } - - Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr, - ArrayRef<std::string> Args) override { - DEBUG_WITH_TYPE("orc", { - dbgs() << "Running as main: " << formatv("{0:x16}", MainFnAddr) - << ", args = ["; - for (unsigned I = 0; I != Args.size(); ++I) - dbgs() << (I ? "," : "") << " \"" << Args[I] << "\""; - dbgs() << "]\n"; - }); - auto Result = EP.template callB<orcrpctpc::RunMain>(MainFnAddr, Args); - DEBUG_WITH_TYPE("orc", { - dbgs() << " call to " << formatv("{0:x16}", MainFnAddr); - if (Result) - dbgs() << " returned result " << *Result << "\n"; - else - dbgs() << " failed\n"; - }); - return Result; - } - - Expected<tpctypes::WrapperFunctionResult> - runWrapper(JITTargetAddress WrapperFnAddr, - ArrayRef<uint8_t> ArgBuffer) override { - DEBUG_WITH_TYPE("orc", { - dbgs() << "Running as wrapper function " - << formatv("{0:x16}", WrapperFnAddr) << " with " - << formatv("{0:x16}", ArgBuffer.size()) << " argument buffer\n"; - }); - auto Result = - EP.template callB<orcrpctpc::RunWrapper>(WrapperFnAddr, ArgBuffer); - // dbgs() << "Returned from runWrapper...\n"; - return Result; - } - - Error closeConnection(OnCloseConnectionFunction OnCloseConnection) { - DEBUG_WITH_TYPE("orc", dbgs() << "Closing connection to remote\n"); - return EP.template callAsync<orcrpctpc::CloseConnection>( - std::move(OnCloseConnection)); - } - - Error closeConnectionAndWait() { - std::promise<MSVCPError> P; - auto F = P.get_future(); - if (auto Err = closeConnection([&](Error Err2) -> Error { - P.set_value(std::move(Err2)); - return Error::success(); - })) { - EP.abandonAllPendingResponses(); - return joinErrors(std::move(Err), F.get()); - } - return F.get(); - } - -protected: - /// Subclasses must call this during construction to initialize the - /// TargetTriple and PageSize members. - Error initializeORCRPCTPCBase() { - if (auto TripleOrErr = EP.template callB<orcrpctpc::GetTargetTriple>()) - TargetTriple = Triple(*TripleOrErr); - else - return TripleOrErr.takeError(); - - if (auto PageSizeOrErr = EP.template callB<orcrpctpc::GetPageSize>()) - PageSize = *PageSizeOrErr; - else - return PageSizeOrErr.takeError(); - - return Error::success(); - } - -private: - ErrorReporter ReportError; - RPCEndpointT &EP; -}; - -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_ORCRPCTARGETPROCESSCONTROL_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===--- OrcRPCTargetProcessControl.h - Remote target control ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utilities for interacting with target processes. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_ORCRPCTARGETPROCESSCONTROL_H +#define LLVM_EXECUTIONENGINE_ORC_ORCRPCTARGETPROCESSCONTROL_H + +#include "llvm/ExecutionEngine/Orc/Shared/RPCUtils.h" +#include "llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h" +#include "llvm/ExecutionEngine/Orc/TargetProcess/OrcRPCTPCServer.h" +#include "llvm/ExecutionEngine/Orc/TargetProcessControl.h" +#include "llvm/Support/MSVCErrorWorkarounds.h" + +namespace llvm { +namespace orc { + +/// JITLinkMemoryManager implementation for a process connected via an ORC RPC +/// endpoint. +template <typename OrcRPCTPCImplT> +class OrcRPCTPCJITLinkMemoryManager : public jitlink::JITLinkMemoryManager { +private: + struct HostAlloc { + std::unique_ptr<char[]> Mem; + uint64_t Size; + }; + + struct TargetAlloc { + JITTargetAddress Address = 0; + uint64_t AllocatedSize = 0; + }; + + using HostAllocMap = DenseMap<int, HostAlloc>; + using TargetAllocMap = DenseMap<int, TargetAlloc>; + +public: + class OrcRPCAllocation : public Allocation { + public: + OrcRPCAllocation(OrcRPCTPCJITLinkMemoryManager<OrcRPCTPCImplT> &Parent, + HostAllocMap HostAllocs, TargetAllocMap TargetAllocs) + : Parent(Parent), HostAllocs(std::move(HostAllocs)), + TargetAllocs(std::move(TargetAllocs)) { + assert(HostAllocs.size() == TargetAllocs.size() && + "HostAllocs size should match TargetAllocs"); + } + + ~OrcRPCAllocation() override { + assert(TargetAllocs.empty() && "failed to deallocate"); + } + + MutableArrayRef<char> getWorkingMemory(ProtectionFlags Seg) override { + auto I = HostAllocs.find(Seg); + assert(I != HostAllocs.end() && "No host allocation for segment"); + auto &HA = I->second; + return {HA.Mem.get(), static_cast<size_t>(HA.Size)}; + } + + JITTargetAddress getTargetMemory(ProtectionFlags Seg) override { + auto I = TargetAllocs.find(Seg); + assert(I != TargetAllocs.end() && "No target allocation for segment"); + return I->second.Address; + } + + void finalizeAsync(FinalizeContinuation OnFinalize) override { + + std::vector<tpctypes::BufferWrite> BufferWrites; + orcrpctpc::ReleaseOrFinalizeMemRequest FMR; + + for (auto &KV : HostAllocs) { + assert(TargetAllocs.count(KV.first) && + "No target allocation for buffer"); + auto &HA = KV.second; + auto &TA = TargetAllocs[KV.first]; + BufferWrites.push_back({TA.Address, StringRef(HA.Mem.get(), HA.Size)}); + FMR.push_back({orcrpctpc::toWireProtectionFlags( + static_cast<sys::Memory::ProtectionFlags>(KV.first)), + TA.Address, TA.AllocatedSize}); + } + + DEBUG_WITH_TYPE("orc", { + dbgs() << "finalizeAsync " << (void *)this << ":\n"; + auto FMRI = FMR.begin(); + for (auto &B : BufferWrites) { + auto Prot = FMRI->Prot; + ++FMRI; + dbgs() << " Writing " << formatv("{0:x16}", B.Buffer.size()) + << " bytes to " << ((Prot & orcrpctpc::WPF_Read) ? 'R' : '-') + << ((Prot & orcrpctpc::WPF_Write) ? 'W' : '-') + << ((Prot & orcrpctpc::WPF_Exec) ? 'X' : '-') + << " segment: local " << (const void *)B.Buffer.data() + << " -> target " << formatv("{0:x16}", B.Address) << "\n"; + } + }); + if (auto Err = + Parent.Parent.getMemoryAccess().writeBuffers(BufferWrites)) { + OnFinalize(std::move(Err)); + return; + } + + DEBUG_WITH_TYPE("orc", dbgs() << " Applying permissions...\n"); + if (auto Err = + Parent.getEndpoint().template callAsync<orcrpctpc::FinalizeMem>( + [OF = std::move(OnFinalize)](Error Err2) { + // FIXME: Dispatch to work queue. + std::thread([OF = std::move(OF), + Err3 = std::move(Err2)]() mutable { + DEBUG_WITH_TYPE( + "orc", { dbgs() << " finalizeAsync complete\n"; }); + OF(std::move(Err3)); + }).detach(); + return Error::success(); + }, + FMR)) { + DEBUG_WITH_TYPE("orc", dbgs() << " failed.\n"); + Parent.getEndpoint().abandonPendingResponses(); + Parent.reportError(std::move(Err)); + } + DEBUG_WITH_TYPE("orc", { + dbgs() << "Leaving finalizeAsync (finalization may continue in " + "background)\n"; + }); + } + + Error deallocate() override { + orcrpctpc::ReleaseOrFinalizeMemRequest RMR; + for (auto &KV : TargetAllocs) + RMR.push_back({orcrpctpc::toWireProtectionFlags( + static_cast<sys::Memory::ProtectionFlags>(KV.first)), + KV.second.Address, KV.second.AllocatedSize}); + TargetAllocs.clear(); + + return Parent.getEndpoint().template callB<orcrpctpc::ReleaseMem>(RMR); + } + + private: + OrcRPCTPCJITLinkMemoryManager<OrcRPCTPCImplT> &Parent; + HostAllocMap HostAllocs; + TargetAllocMap TargetAllocs; + }; + + OrcRPCTPCJITLinkMemoryManager(OrcRPCTPCImplT &Parent) : Parent(Parent) {} + + Expected<std::unique_ptr<Allocation>> + allocate(const jitlink::JITLinkDylib *JD, + const SegmentsRequestMap &Request) override { + orcrpctpc::ReserveMemRequest RMR; + HostAllocMap HostAllocs; + + for (auto &KV : Request) { + assert(KV.second.getContentSize() <= std::numeric_limits<size_t>::max() && + "Content size is out-of-range for host"); + + RMR.push_back({orcrpctpc::toWireProtectionFlags( + static_cast<sys::Memory::ProtectionFlags>(KV.first)), + KV.second.getContentSize() + KV.second.getZeroFillSize(), + KV.second.getAlignment()}); + HostAllocs[KV.first] = { + std::make_unique<char[]>(KV.second.getContentSize()), + KV.second.getContentSize()}; + } + + DEBUG_WITH_TYPE("orc", { + dbgs() << "Orc remote memmgr got request:\n"; + for (auto &KV : Request) + dbgs() << " permissions: " + << ((KV.first & sys::Memory::MF_READ) ? 'R' : '-') + << ((KV.first & sys::Memory::MF_WRITE) ? 'W' : '-') + << ((KV.first & sys::Memory::MF_EXEC) ? 'X' : '-') + << ", content size: " + << formatv("{0:x16}", KV.second.getContentSize()) + << " + zero-fill-size: " + << formatv("{0:x16}", KV.second.getZeroFillSize()) + << ", align: " << KV.second.getAlignment() << "\n"; + }); + + // FIXME: LLVM RPC needs to be fixed to support alt + // serialization/deserialization on return types. For now just + // translate from std::map to DenseMap manually. + auto TmpTargetAllocs = + Parent.getEndpoint().template callB<orcrpctpc::ReserveMem>(RMR); + if (!TmpTargetAllocs) + return TmpTargetAllocs.takeError(); + + if (TmpTargetAllocs->size() != RMR.size()) + return make_error<StringError>( + "Number of target allocations does not match request", + inconvertibleErrorCode()); + + TargetAllocMap TargetAllocs; + for (auto &E : *TmpTargetAllocs) + TargetAllocs[orcrpctpc::fromWireProtectionFlags(E.Prot)] = { + E.Address, E.AllocatedSize}; + + DEBUG_WITH_TYPE("orc", { + auto HAI = HostAllocs.begin(); + for (auto &KV : TargetAllocs) + dbgs() << " permissions: " + << ((KV.first & sys::Memory::MF_READ) ? 'R' : '-') + << ((KV.first & sys::Memory::MF_WRITE) ? 'W' : '-') + << ((KV.first & sys::Memory::MF_EXEC) ? 'X' : '-') + << " assigned local " << (void *)HAI->second.Mem.get() + << ", target " << formatv("{0:x16}", KV.second.Address) << "\n"; + }); + + return std::make_unique<OrcRPCAllocation>(*this, std::move(HostAllocs), + std::move(TargetAllocs)); + } + +private: + void reportError(Error Err) { Parent.reportError(std::move(Err)); } + + decltype(std::declval<OrcRPCTPCImplT>().getEndpoint()) getEndpoint() { + return Parent.getEndpoint(); + } + + OrcRPCTPCImplT &Parent; +}; + +/// TargetProcessControl::MemoryAccess implementation for a process connected +/// via an ORC RPC endpoint. +template <typename OrcRPCTPCImplT> +class OrcRPCTPCMemoryAccess : public TargetProcessControl::MemoryAccess { +public: + OrcRPCTPCMemoryAccess(OrcRPCTPCImplT &Parent) : Parent(Parent) {} + + void writeUInt8s(ArrayRef<tpctypes::UInt8Write> Ws, + WriteResultFn OnWriteComplete) override { + writeViaRPC<orcrpctpc::WriteUInt8s>(Ws, std::move(OnWriteComplete)); + } + + void writeUInt16s(ArrayRef<tpctypes::UInt16Write> Ws, + WriteResultFn OnWriteComplete) override { + writeViaRPC<orcrpctpc::WriteUInt16s>(Ws, std::move(OnWriteComplete)); + } + + void writeUInt32s(ArrayRef<tpctypes::UInt32Write> Ws, + WriteResultFn OnWriteComplete) override { + writeViaRPC<orcrpctpc::WriteUInt32s>(Ws, std::move(OnWriteComplete)); + } + + void writeUInt64s(ArrayRef<tpctypes::UInt64Write> Ws, + WriteResultFn OnWriteComplete) override { + writeViaRPC<orcrpctpc::WriteUInt64s>(Ws, std::move(OnWriteComplete)); + } + + void writeBuffers(ArrayRef<tpctypes::BufferWrite> Ws, + WriteResultFn OnWriteComplete) override { + writeViaRPC<orcrpctpc::WriteBuffers>(Ws, std::move(OnWriteComplete)); + } + +private: + template <typename WriteRPCFunction, typename WriteElementT> + void writeViaRPC(ArrayRef<WriteElementT> Ws, WriteResultFn OnWriteComplete) { + if (auto Err = Parent.getEndpoint().template callAsync<WriteRPCFunction>( + [OWC = std::move(OnWriteComplete)](Error Err2) mutable -> Error { + OWC(std::move(Err2)); + return Error::success(); + }, + Ws)) { + Parent.reportError(std::move(Err)); + Parent.getEndpoint().abandonPendingResponses(); + } + } + + OrcRPCTPCImplT &Parent; +}; + +// TargetProcessControl for a process connected via an ORC RPC Endpoint. +template <typename RPCEndpointT> +class OrcRPCTargetProcessControlBase : public TargetProcessControl { +public: + using ErrorReporter = unique_function<void(Error)>; + + using OnCloseConnectionFunction = unique_function<Error(Error)>; + + OrcRPCTargetProcessControlBase(std::shared_ptr<SymbolStringPool> SSP, + RPCEndpointT &EP, ErrorReporter ReportError) + : TargetProcessControl(std::move(SSP)), + ReportError(std::move(ReportError)), EP(EP) {} + + void reportError(Error Err) { ReportError(std::move(Err)); } + + RPCEndpointT &getEndpoint() { return EP; } + + Expected<tpctypes::DylibHandle> loadDylib(const char *DylibPath) override { + DEBUG_WITH_TYPE("orc", { + dbgs() << "Loading dylib \"" << (DylibPath ? DylibPath : "") << "\" "; + if (!DylibPath) + dbgs() << "(process symbols)"; + dbgs() << "\n"; + }); + if (!DylibPath) + DylibPath = ""; + auto H = EP.template callB<orcrpctpc::LoadDylib>(DylibPath); + DEBUG_WITH_TYPE("orc", { + if (H) + dbgs() << " got handle " << formatv("{0:x16}", *H) << "\n"; + else + dbgs() << " error, unable to load\n"; + }); + return H; + } + + Expected<std::vector<tpctypes::LookupResult>> + lookupSymbols(ArrayRef<LookupRequest> Request) override { + std::vector<orcrpctpc::RemoteLookupRequest> RR; + for (auto &E : Request) { + RR.push_back({}); + RR.back().first = E.Handle; + for (auto &KV : E.Symbols) + RR.back().second.push_back( + {(*KV.first).str(), + KV.second == SymbolLookupFlags::WeaklyReferencedSymbol}); + } + DEBUG_WITH_TYPE("orc", { + dbgs() << "Compound lookup:\n"; + for (auto &R : Request) { + dbgs() << " In " << formatv("{0:x16}", R.Handle) << ": {"; + bool First = true; + for (auto &KV : R.Symbols) { + dbgs() << (First ? "" : ",") << " " << *KV.first; + First = false; + } + dbgs() << " }\n"; + } + }); + return EP.template callB<orcrpctpc::LookupSymbols>(RR); + } + + Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr, + ArrayRef<std::string> Args) override { + DEBUG_WITH_TYPE("orc", { + dbgs() << "Running as main: " << formatv("{0:x16}", MainFnAddr) + << ", args = ["; + for (unsigned I = 0; I != Args.size(); ++I) + dbgs() << (I ? "," : "") << " \"" << Args[I] << "\""; + dbgs() << "]\n"; + }); + auto Result = EP.template callB<orcrpctpc::RunMain>(MainFnAddr, Args); + DEBUG_WITH_TYPE("orc", { + dbgs() << " call to " << formatv("{0:x16}", MainFnAddr); + if (Result) + dbgs() << " returned result " << *Result << "\n"; + else + dbgs() << " failed\n"; + }); + return Result; + } + + Expected<tpctypes::WrapperFunctionResult> + runWrapper(JITTargetAddress WrapperFnAddr, + ArrayRef<uint8_t> ArgBuffer) override { + DEBUG_WITH_TYPE("orc", { + dbgs() << "Running as wrapper function " + << formatv("{0:x16}", WrapperFnAddr) << " with " + << formatv("{0:x16}", ArgBuffer.size()) << " argument buffer\n"; + }); + auto Result = + EP.template callB<orcrpctpc::RunWrapper>(WrapperFnAddr, ArgBuffer); + // dbgs() << "Returned from runWrapper...\n"; + return Result; + } + + Error closeConnection(OnCloseConnectionFunction OnCloseConnection) { + DEBUG_WITH_TYPE("orc", dbgs() << "Closing connection to remote\n"); + return EP.template callAsync<orcrpctpc::CloseConnection>( + std::move(OnCloseConnection)); + } + + Error closeConnectionAndWait() { + std::promise<MSVCPError> P; + auto F = P.get_future(); + if (auto Err = closeConnection([&](Error Err2) -> Error { + P.set_value(std::move(Err2)); + return Error::success(); + })) { + EP.abandonAllPendingResponses(); + return joinErrors(std::move(Err), F.get()); + } + return F.get(); + } + +protected: + /// Subclasses must call this during construction to initialize the + /// TargetTriple and PageSize members. + Error initializeORCRPCTPCBase() { + if (auto TripleOrErr = EP.template callB<orcrpctpc::GetTargetTriple>()) + TargetTriple = Triple(*TripleOrErr); + else + return TripleOrErr.takeError(); + + if (auto PageSizeOrErr = EP.template callB<orcrpctpc::GetPageSize>()) + PageSize = *PageSizeOrErr; + else + return PageSizeOrErr.takeError(); + + return Error::success(); + } + +private: + ErrorReporter ReportError; + RPCEndpointT &EP; +}; + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_ORCRPCTARGETPROCESSCONTROL_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h index e639958887..508d7b92da 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h @@ -27,7 +27,7 @@ #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/ExecutionEngine/JITSymbol.h" -#include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h" +#include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h" #include "llvm/ExecutionEngine/Orc/IndirectionUtils.h" #include "llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h" #include "llvm/ExecutionEngine/RuntimeDyld.h" @@ -61,7 +61,7 @@ namespace remote { /// OrcRemoteTargetServer class) via an RPC system (see RPCUtils.h) to carry out /// its actions. class OrcRemoteTargetClient - : public shared::SingleThreadedRPCEndpoint<shared::RawByteChannel> { + : public shared::SingleThreadedRPCEndpoint<shared::RawByteChannel> { public: /// Remote-mapped RuntimeDyld-compatible memory manager. class RemoteRTDyldMemoryManager : public RuntimeDyld::MemoryManager { @@ -337,221 +337,221 @@ public: std::vector<EHFrame> RegisteredEHFrames; }; - class RPCMMAlloc : public jitlink::JITLinkMemoryManager::Allocation { - using AllocationMap = DenseMap<unsigned, sys::MemoryBlock>; - using FinalizeContinuation = - jitlink::JITLinkMemoryManager::Allocation::FinalizeContinuation; - using ProtectionFlags = sys::Memory::ProtectionFlags; - using SegmentsRequestMap = - DenseMap<unsigned, jitlink::JITLinkMemoryManager::SegmentRequest>; - - RPCMMAlloc(OrcRemoteTargetClient &Client, ResourceIdMgr::ResourceId Id) - : Client(Client), Id(Id) {} - - public: - static Expected<std::unique_ptr<RPCMMAlloc>> - Create(OrcRemoteTargetClient &Client, ResourceIdMgr::ResourceId Id, - const SegmentsRequestMap &Request) { - auto *MM = new RPCMMAlloc(Client, Id); - - if (Error Err = MM->allocateHostBlocks(Request)) - return std::move(Err); - - if (Error Err = MM->allocateTargetBlocks()) - return std::move(Err); - - return std::unique_ptr<RPCMMAlloc>(MM); - } - - MutableArrayRef<char> getWorkingMemory(ProtectionFlags Seg) override { - assert(HostSegBlocks.count(Seg) && "No allocation for segment"); - return {static_cast<char *>(HostSegBlocks[Seg].base()), - HostSegBlocks[Seg].allocatedSize()}; - } - - JITTargetAddress getTargetMemory(ProtectionFlags Seg) override { - assert(TargetSegBlocks.count(Seg) && "No allocation for segment"); - return pointerToJITTargetAddress(TargetSegBlocks[Seg].base()); - } - - void finalizeAsync(FinalizeContinuation OnFinalize) override { - // Host allocations (working memory) remain ReadWrite. - OnFinalize(copyAndProtect()); - } - - Error deallocate() override { - // TODO: Cannot release target allocation. RPCAPI has no function - // symmetric to reserveMem(). Add RPC call like freeMem()? - return errorCodeToError(sys::Memory::releaseMappedMemory(HostAllocation)); - } - - private: - OrcRemoteTargetClient &Client; - ResourceIdMgr::ResourceId Id; - AllocationMap HostSegBlocks; - AllocationMap TargetSegBlocks; - JITTargetAddress TargetSegmentAddr; - sys::MemoryBlock HostAllocation; - - Error allocateHostBlocks(const SegmentsRequestMap &Request) { - unsigned TargetPageSize = Client.getPageSize(); - - if (!isPowerOf2_64(static_cast<uint64_t>(TargetPageSize))) - return make_error<StringError>("Host page size is not a power of 2", - inconvertibleErrorCode()); - - auto TotalSize = calcTotalAllocSize(Request, TargetPageSize); - if (!TotalSize) - return TotalSize.takeError(); - - // Allocate one slab to cover all the segments. - const sys::Memory::ProtectionFlags ReadWrite = - static_cast<sys::Memory::ProtectionFlags>(sys::Memory::MF_READ | - sys::Memory::MF_WRITE); - std::error_code EC; - HostAllocation = - sys::Memory::allocateMappedMemory(*TotalSize, nullptr, ReadWrite, EC); - if (EC) - return errorCodeToError(EC); - - char *SlabAddr = static_cast<char *>(HostAllocation.base()); -#ifndef NDEBUG - char *SlabAddrEnd = SlabAddr + HostAllocation.allocatedSize(); -#endif - - // Allocate segment memory from the slab. - for (auto &KV : Request) { - const auto &Seg = KV.second; - - uint64_t SegmentSize = Seg.getContentSize() + Seg.getZeroFillSize(); - uint64_t AlignedSegmentSize = alignTo(SegmentSize, TargetPageSize); - - // Zero out zero-fill memory. - char *ZeroFillBegin = SlabAddr + Seg.getContentSize(); - memset(ZeroFillBegin, 0, Seg.getZeroFillSize()); - - // Record the block for this segment. - HostSegBlocks[KV.first] = - sys::MemoryBlock(SlabAddr, AlignedSegmentSize); - - SlabAddr += AlignedSegmentSize; - assert(SlabAddr <= SlabAddrEnd && "Out of range"); - } - - return Error::success(); - } - - Error allocateTargetBlocks() { - // Reserve memory for all blocks on the target. We need as much space on - // the target as we allocated on the host. - TargetSegmentAddr = Client.reserveMem(Id, HostAllocation.allocatedSize(), - Client.getPageSize()); - if (!TargetSegmentAddr) - return make_error<StringError>("Failed to reserve memory on the target", - inconvertibleErrorCode()); - - // Map memory blocks into the allocation, that match the host allocation. - JITTargetAddress TargetAllocAddr = TargetSegmentAddr; - for (const auto &KV : HostSegBlocks) { - size_t TargetAllocSize = KV.second.allocatedSize(); - - TargetSegBlocks[KV.first] = - sys::MemoryBlock(jitTargetAddressToPointer<void *>(TargetAllocAddr), - TargetAllocSize); - - TargetAllocAddr += TargetAllocSize; - assert(TargetAllocAddr - TargetSegmentAddr <= - HostAllocation.allocatedSize() && - "Out of range on target"); - } - - return Error::success(); - } - - Error copyAndProtect() { - unsigned Permissions = 0u; - - // Copy segments one by one. - for (auto &KV : TargetSegBlocks) { - Permissions |= KV.first; - - const sys::MemoryBlock &TargetBlock = KV.second; - const sys::MemoryBlock &HostBlock = HostSegBlocks.lookup(KV.first); - - size_t TargetAllocSize = TargetBlock.allocatedSize(); - auto TargetAllocAddr = pointerToJITTargetAddress(TargetBlock.base()); - auto *HostAllocBegin = static_cast<const char *>(HostBlock.base()); - - bool CopyErr = - Client.writeMem(TargetAllocAddr, HostAllocBegin, TargetAllocSize); - if (CopyErr) - return createStringError(inconvertibleErrorCode(), - "Failed to copy %d segment to the target", - KV.first); - } - - // Set permission flags for all segments at once. - bool ProtectErr = - Client.setProtections(Id, TargetSegmentAddr, Permissions); - if (ProtectErr) - return createStringError(inconvertibleErrorCode(), - "Failed to apply permissions for %d segment " - "on the target", - Permissions); - return Error::success(); - } - - static Expected<size_t> - calcTotalAllocSize(const SegmentsRequestMap &Request, - unsigned TargetPageSize) { - size_t TotalSize = 0; - for (const auto &KV : Request) { - const auto &Seg = KV.second; - - if (Seg.getAlignment() > TargetPageSize) - return make_error<StringError>("Cannot request alignment higher than " - "page alignment on target", - inconvertibleErrorCode()); - - TotalSize = alignTo(TotalSize, TargetPageSize); - TotalSize += Seg.getContentSize(); - TotalSize += Seg.getZeroFillSize(); - } - - return TotalSize; - } - }; - - class RemoteJITLinkMemoryManager : public jitlink::JITLinkMemoryManager { - public: - RemoteJITLinkMemoryManager(OrcRemoteTargetClient &Client, - ResourceIdMgr::ResourceId Id) - : Client(Client), Id(Id) {} - - RemoteJITLinkMemoryManager(const RemoteJITLinkMemoryManager &) = delete; - RemoteJITLinkMemoryManager(RemoteJITLinkMemoryManager &&) = default; - - RemoteJITLinkMemoryManager & - operator=(const RemoteJITLinkMemoryManager &) = delete; - RemoteJITLinkMemoryManager & - operator=(RemoteJITLinkMemoryManager &&) = delete; - - ~RemoteJITLinkMemoryManager() { - Client.destroyRemoteAllocator(Id); - LLVM_DEBUG(dbgs() << "Destroyed remote allocator " << Id << "\n"); - } - - Expected<std::unique_ptr<Allocation>> - allocate(const jitlink::JITLinkDylib *JD, - const SegmentsRequestMap &Request) override { - return RPCMMAlloc::Create(Client, Id, Request); - } - - private: - OrcRemoteTargetClient &Client; - ResourceIdMgr::ResourceId Id; - }; - + class RPCMMAlloc : public jitlink::JITLinkMemoryManager::Allocation { + using AllocationMap = DenseMap<unsigned, sys::MemoryBlock>; + using FinalizeContinuation = + jitlink::JITLinkMemoryManager::Allocation::FinalizeContinuation; + using ProtectionFlags = sys::Memory::ProtectionFlags; + using SegmentsRequestMap = + DenseMap<unsigned, jitlink::JITLinkMemoryManager::SegmentRequest>; + + RPCMMAlloc(OrcRemoteTargetClient &Client, ResourceIdMgr::ResourceId Id) + : Client(Client), Id(Id) {} + + public: + static Expected<std::unique_ptr<RPCMMAlloc>> + Create(OrcRemoteTargetClient &Client, ResourceIdMgr::ResourceId Id, + const SegmentsRequestMap &Request) { + auto *MM = new RPCMMAlloc(Client, Id); + + if (Error Err = MM->allocateHostBlocks(Request)) + return std::move(Err); + + if (Error Err = MM->allocateTargetBlocks()) + return std::move(Err); + + return std::unique_ptr<RPCMMAlloc>(MM); + } + + MutableArrayRef<char> getWorkingMemory(ProtectionFlags Seg) override { + assert(HostSegBlocks.count(Seg) && "No allocation for segment"); + return {static_cast<char *>(HostSegBlocks[Seg].base()), + HostSegBlocks[Seg].allocatedSize()}; + } + + JITTargetAddress getTargetMemory(ProtectionFlags Seg) override { + assert(TargetSegBlocks.count(Seg) && "No allocation for segment"); + return pointerToJITTargetAddress(TargetSegBlocks[Seg].base()); + } + + void finalizeAsync(FinalizeContinuation OnFinalize) override { + // Host allocations (working memory) remain ReadWrite. + OnFinalize(copyAndProtect()); + } + + Error deallocate() override { + // TODO: Cannot release target allocation. RPCAPI has no function + // symmetric to reserveMem(). Add RPC call like freeMem()? + return errorCodeToError(sys::Memory::releaseMappedMemory(HostAllocation)); + } + + private: + OrcRemoteTargetClient &Client; + ResourceIdMgr::ResourceId Id; + AllocationMap HostSegBlocks; + AllocationMap TargetSegBlocks; + JITTargetAddress TargetSegmentAddr; + sys::MemoryBlock HostAllocation; + + Error allocateHostBlocks(const SegmentsRequestMap &Request) { + unsigned TargetPageSize = Client.getPageSize(); + + if (!isPowerOf2_64(static_cast<uint64_t>(TargetPageSize))) + return make_error<StringError>("Host page size is not a power of 2", + inconvertibleErrorCode()); + + auto TotalSize = calcTotalAllocSize(Request, TargetPageSize); + if (!TotalSize) + return TotalSize.takeError(); + + // Allocate one slab to cover all the segments. + const sys::Memory::ProtectionFlags ReadWrite = + static_cast<sys::Memory::ProtectionFlags>(sys::Memory::MF_READ | + sys::Memory::MF_WRITE); + std::error_code EC; + HostAllocation = + sys::Memory::allocateMappedMemory(*TotalSize, nullptr, ReadWrite, EC); + if (EC) + return errorCodeToError(EC); + + char *SlabAddr = static_cast<char *>(HostAllocation.base()); +#ifndef NDEBUG + char *SlabAddrEnd = SlabAddr + HostAllocation.allocatedSize(); +#endif + + // Allocate segment memory from the slab. + for (auto &KV : Request) { + const auto &Seg = KV.second; + + uint64_t SegmentSize = Seg.getContentSize() + Seg.getZeroFillSize(); + uint64_t AlignedSegmentSize = alignTo(SegmentSize, TargetPageSize); + + // Zero out zero-fill memory. + char *ZeroFillBegin = SlabAddr + Seg.getContentSize(); + memset(ZeroFillBegin, 0, Seg.getZeroFillSize()); + + // Record the block for this segment. + HostSegBlocks[KV.first] = + sys::MemoryBlock(SlabAddr, AlignedSegmentSize); + + SlabAddr += AlignedSegmentSize; + assert(SlabAddr <= SlabAddrEnd && "Out of range"); + } + + return Error::success(); + } + + Error allocateTargetBlocks() { + // Reserve memory for all blocks on the target. We need as much space on + // the target as we allocated on the host. + TargetSegmentAddr = Client.reserveMem(Id, HostAllocation.allocatedSize(), + Client.getPageSize()); + if (!TargetSegmentAddr) + return make_error<StringError>("Failed to reserve memory on the target", + inconvertibleErrorCode()); + + // Map memory blocks into the allocation, that match the host allocation. + JITTargetAddress TargetAllocAddr = TargetSegmentAddr; + for (const auto &KV : HostSegBlocks) { + size_t TargetAllocSize = KV.second.allocatedSize(); + + TargetSegBlocks[KV.first] = + sys::MemoryBlock(jitTargetAddressToPointer<void *>(TargetAllocAddr), + TargetAllocSize); + + TargetAllocAddr += TargetAllocSize; + assert(TargetAllocAddr - TargetSegmentAddr <= + HostAllocation.allocatedSize() && + "Out of range on target"); + } + + return Error::success(); + } + + Error copyAndProtect() { + unsigned Permissions = 0u; + + // Copy segments one by one. + for (auto &KV : TargetSegBlocks) { + Permissions |= KV.first; + + const sys::MemoryBlock &TargetBlock = KV.second; + const sys::MemoryBlock &HostBlock = HostSegBlocks.lookup(KV.first); + + size_t TargetAllocSize = TargetBlock.allocatedSize(); + auto TargetAllocAddr = pointerToJITTargetAddress(TargetBlock.base()); + auto *HostAllocBegin = static_cast<const char *>(HostBlock.base()); + + bool CopyErr = + Client.writeMem(TargetAllocAddr, HostAllocBegin, TargetAllocSize); + if (CopyErr) + return createStringError(inconvertibleErrorCode(), + "Failed to copy %d segment to the target", + KV.first); + } + + // Set permission flags for all segments at once. + bool ProtectErr = + Client.setProtections(Id, TargetSegmentAddr, Permissions); + if (ProtectErr) + return createStringError(inconvertibleErrorCode(), + "Failed to apply permissions for %d segment " + "on the target", + Permissions); + return Error::success(); + } + + static Expected<size_t> + calcTotalAllocSize(const SegmentsRequestMap &Request, + unsigned TargetPageSize) { + size_t TotalSize = 0; + for (const auto &KV : Request) { + const auto &Seg = KV.second; + + if (Seg.getAlignment() > TargetPageSize) + return make_error<StringError>("Cannot request alignment higher than " + "page alignment on target", + inconvertibleErrorCode()); + + TotalSize = alignTo(TotalSize, TargetPageSize); + TotalSize += Seg.getContentSize(); + TotalSize += Seg.getZeroFillSize(); + } + + return TotalSize; + } + }; + + class RemoteJITLinkMemoryManager : public jitlink::JITLinkMemoryManager { + public: + RemoteJITLinkMemoryManager(OrcRemoteTargetClient &Client, + ResourceIdMgr::ResourceId Id) + : Client(Client), Id(Id) {} + + RemoteJITLinkMemoryManager(const RemoteJITLinkMemoryManager &) = delete; + RemoteJITLinkMemoryManager(RemoteJITLinkMemoryManager &&) = default; + + RemoteJITLinkMemoryManager & + operator=(const RemoteJITLinkMemoryManager &) = delete; + RemoteJITLinkMemoryManager & + operator=(RemoteJITLinkMemoryManager &&) = delete; + + ~RemoteJITLinkMemoryManager() { + Client.destroyRemoteAllocator(Id); + LLVM_DEBUG(dbgs() << "Destroyed remote allocator " << Id << "\n"); + } + + Expected<std::unique_ptr<Allocation>> + allocate(const jitlink::JITLinkDylib *JD, + const SegmentsRequestMap &Request) override { + return RPCMMAlloc::Create(Client, Id, Request); + } + + private: + OrcRemoteTargetClient &Client; + ResourceIdMgr::ResourceId Id; + }; + /// Remote indirect stubs manager. class RemoteIndirectStubsManager : public IndirectStubsManager { public: @@ -677,7 +677,7 @@ public: RemoteTrampolinePool(OrcRemoteTargetClient &Client) : Client(Client) {} private: - Error grow() override { + Error grow() override { JITTargetAddress BlockAddr = 0; uint32_t NumTrampolines = 0; if (auto TrampolineInfoOrErr = Client.emitTrampolineBlock()) @@ -687,7 +687,7 @@ public: uint32_t TrampolineSize = Client.getTrampolineSize(); for (unsigned I = 0; I < NumTrampolines; ++I) - AvailableTrampolines.push_back(BlockAddr + (I * TrampolineSize)); + AvailableTrampolines.push_back(BlockAddr + (I * TrampolineSize)); return Error::success(); } @@ -710,7 +710,7 @@ public: /// Channel is the ChannelT instance to communicate on. It is assumed that /// the channel is ready to be read from and written to. static Expected<std::unique_ptr<OrcRemoteTargetClient>> - Create(shared::RawByteChannel &Channel, ExecutionSession &ES) { + Create(shared::RawByteChannel &Channel, ExecutionSession &ES) { Error Err = Error::success(); auto Client = std::unique_ptr<OrcRemoteTargetClient>( new OrcRemoteTargetClient(Channel, ES, Err)); @@ -727,14 +727,14 @@ public: return callB<exec::CallIntVoid>(Addr); } - /// Call the int(int) function at the given address in the target and return - /// its result. - Expected<int> callIntInt(JITTargetAddress Addr, int Arg) { - LLVM_DEBUG(dbgs() << "Calling int(*)(int) " << format("0x%016" PRIx64, Addr) - << "\n"); - return callB<exec::CallIntInt>(Addr, Arg); - } - + /// Call the int(int) function at the given address in the target and return + /// its result. + Expected<int> callIntInt(JITTargetAddress Addr, int Arg) { + LLVM_DEBUG(dbgs() << "Calling int(*)(int) " << format("0x%016" PRIx64, Addr) + << "\n"); + return callB<exec::CallIntInt>(Addr, Arg); + } + /// Call the int(int, char*[]) function at the given address in the target and /// return its result. Expected<int> callMain(JITTargetAddress Addr, @@ -763,18 +763,18 @@ public: new RemoteRTDyldMemoryManager(*this, Id)); } - /// Create a JITLink-compatible memory manager which will allocate working - /// memory on the host and target memory on the remote target. - Expected<std::unique_ptr<RemoteJITLinkMemoryManager>> - createRemoteJITLinkMemoryManager() { - auto Id = AllocatorIds.getNext(); - if (auto Err = callB<mem::CreateRemoteAllocator>(Id)) - return std::move(Err); - LLVM_DEBUG(dbgs() << "Created remote allocator " << Id << "\n"); - return std::unique_ptr<RemoteJITLinkMemoryManager>( - new RemoteJITLinkMemoryManager(*this, Id)); - } - + /// Create a JITLink-compatible memory manager which will allocate working + /// memory on the host and target memory on the remote target. + Expected<std::unique_ptr<RemoteJITLinkMemoryManager>> + createRemoteJITLinkMemoryManager() { + auto Id = AllocatorIds.getNext(); + if (auto Err = callB<mem::CreateRemoteAllocator>(Id)) + return std::move(Err); + LLVM_DEBUG(dbgs() << "Created remote allocator " << Id << "\n"); + return std::unique_ptr<RemoteJITLinkMemoryManager>( + new RemoteJITLinkMemoryManager(*this, Id)); + } + /// Create an RCIndirectStubsManager that will allocate stubs on the remote /// target. Expected<std::unique_ptr<RemoteIndirectStubsManager>> @@ -812,10 +812,10 @@ public: Error terminateSession() { return callB<utils::TerminateSession>(); } private: - OrcRemoteTargetClient(shared::RawByteChannel &Channel, ExecutionSession &ES, + OrcRemoteTargetClient(shared::RawByteChannel &Channel, ExecutionSession &ES, Error &Err) - : shared::SingleThreadedRPCEndpoint<shared::RawByteChannel>(Channel, - true), + : shared::SingleThreadedRPCEndpoint<shared::RawByteChannel>(Channel, + true), ES(ES) { ErrorAsOutParameter EAO(&Err); diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h index ff0dc7d33f..95f7205c06 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h @@ -23,8 +23,8 @@ #define LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETRPCAPI_H #include "llvm/ExecutionEngine/JITSymbol.h" -#include "llvm/ExecutionEngine/Orc/Shared/RPCUtils.h" -#include "llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h" +#include "llvm/ExecutionEngine/Orc/Shared/RPCUtils.h" +#include "llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h" namespace llvm { namespace orc { @@ -80,9 +80,9 @@ private: } // end namespace remote -namespace shared { +namespace shared { -template <> class SerializationTypeName<JITSymbolFlags> { +template <> class SerializationTypeName<JITSymbolFlags> { public: static const char *getName() { return "JITSymbolFlags"; } }; @@ -106,7 +106,7 @@ public: } }; -template <> class SerializationTypeName<remote::DirectBufferWriter> { +template <> class SerializationTypeName<remote::DirectBufferWriter> { public: static const char *getName() { return "DirectBufferWriter"; } }; @@ -139,7 +139,7 @@ public: } }; -} // end namespace shared +} // end namespace shared namespace remote { @@ -173,20 +173,20 @@ private: namespace eh { /// Registers EH frames on the remote. -class RegisterEHFrames - : public shared::RPCFunction<RegisterEHFrames, - void(JITTargetAddress Addr, uint32_t Size)> { -public: - static const char *getName() { return "RegisterEHFrames"; } -}; +class RegisterEHFrames + : public shared::RPCFunction<RegisterEHFrames, + void(JITTargetAddress Addr, uint32_t Size)> { +public: + static const char *getName() { return "RegisterEHFrames"; } +}; /// Deregisters EH frames on the remote. -class DeregisterEHFrames - : public shared::RPCFunction<DeregisterEHFrames, - void(JITTargetAddress Addr, uint32_t Size)> { -public: - static const char *getName() { return "DeregisterEHFrames"; } -}; +class DeregisterEHFrames + : public shared::RPCFunction<DeregisterEHFrames, + void(JITTargetAddress Addr, uint32_t Size)> { +public: + static const char *getName() { return "DeregisterEHFrames"; } +}; } // end namespace eh @@ -195,38 +195,38 @@ namespace exec { /// Call an 'int32_t()'-type function on the remote, returns the called /// function's return value. -class CallIntVoid - : public shared::RPCFunction<CallIntVoid, int32_t(JITTargetAddress Addr)> { -public: - static const char *getName() { return "CallIntVoid"; } -}; - - /// Call an 'int32_t(int32_t)'-type function on the remote, returns the called - /// function's return value. -class CallIntInt - : public shared::RPCFunction<CallIntInt, - int32_t(JITTargetAddress Addr, int)> { -public: - static const char *getName() { return "CallIntInt"; } -}; - +class CallIntVoid + : public shared::RPCFunction<CallIntVoid, int32_t(JITTargetAddress Addr)> { +public: + static const char *getName() { return "CallIntVoid"; } +}; + + /// Call an 'int32_t(int32_t)'-type function on the remote, returns the called + /// function's return value. +class CallIntInt + : public shared::RPCFunction<CallIntInt, + int32_t(JITTargetAddress Addr, int)> { +public: + static const char *getName() { return "CallIntInt"; } +}; + /// Call an 'int32_t(int32_t, char**)'-type function on the remote, returns the /// called function's return value. -class CallMain - : public shared::RPCFunction<CallMain, - int32_t(JITTargetAddress Addr, - std::vector<std::string> Args)> { -public: - static const char *getName() { return "CallMain"; } -}; +class CallMain + : public shared::RPCFunction<CallMain, + int32_t(JITTargetAddress Addr, + std::vector<std::string> Args)> { +public: + static const char *getName() { return "CallMain"; } +}; /// Calls a 'void()'-type function on the remote, returns when the called /// function completes. -class CallVoidVoid - : public shared::RPCFunction<CallVoidVoid, void(JITTargetAddress FnAddr)> { -public: - static const char *getName() { return "CallVoidVoid"; } -}; +class CallVoidVoid + : public shared::RPCFunction<CallVoidVoid, void(JITTargetAddress FnAddr)> { +public: + static const char *getName() { return "CallVoidVoid"; } +}; } // end namespace exec @@ -234,62 +234,62 @@ public: namespace mem { /// Creates a memory allocator on the remote. -class CreateRemoteAllocator - : public shared::RPCFunction<CreateRemoteAllocator, - void(ResourceIdMgr::ResourceId AllocatorID)> { -public: - static const char *getName() { return "CreateRemoteAllocator"; } -}; +class CreateRemoteAllocator + : public shared::RPCFunction<CreateRemoteAllocator, + void(ResourceIdMgr::ResourceId AllocatorID)> { +public: + static const char *getName() { return "CreateRemoteAllocator"; } +}; /// Destroys a remote allocator, freeing any memory allocated by it. -class DestroyRemoteAllocator - : public shared::RPCFunction<DestroyRemoteAllocator, - void(ResourceIdMgr::ResourceId AllocatorID)> { -public: - static const char *getName() { return "DestroyRemoteAllocator"; } -}; +class DestroyRemoteAllocator + : public shared::RPCFunction<DestroyRemoteAllocator, + void(ResourceIdMgr::ResourceId AllocatorID)> { +public: + static const char *getName() { return "DestroyRemoteAllocator"; } +}; /// Read a remote memory block. -class ReadMem - : public shared::RPCFunction< - ReadMem, std::vector<uint8_t>(JITTargetAddress Src, uint64_t Size)> { -public: - static const char *getName() { return "ReadMem"; } -}; +class ReadMem + : public shared::RPCFunction< + ReadMem, std::vector<uint8_t>(JITTargetAddress Src, uint64_t Size)> { +public: + static const char *getName() { return "ReadMem"; } +}; /// Reserve a block of memory on the remote via the given allocator. -class ReserveMem - : public shared::RPCFunction< - ReserveMem, JITTargetAddress(ResourceIdMgr::ResourceId AllocID, - uint64_t Size, uint32_t Align)> { -public: - static const char *getName() { return "ReserveMem"; } -}; +class ReserveMem + : public shared::RPCFunction< + ReserveMem, JITTargetAddress(ResourceIdMgr::ResourceId AllocID, + uint64_t Size, uint32_t Align)> { +public: + static const char *getName() { return "ReserveMem"; } +}; /// Set the memory protection on a memory block. -class SetProtections - : public shared::RPCFunction< - SetProtections, void(ResourceIdMgr::ResourceId AllocID, - JITTargetAddress Dst, uint32_t ProtFlags)> { -public: - static const char *getName() { return "SetProtections"; } -}; +class SetProtections + : public shared::RPCFunction< + SetProtections, void(ResourceIdMgr::ResourceId AllocID, + JITTargetAddress Dst, uint32_t ProtFlags)> { +public: + static const char *getName() { return "SetProtections"; } +}; /// Write to a remote memory block. -class WriteMem - : public shared::RPCFunction<WriteMem, - void(remote::DirectBufferWriter DB)> { -public: - static const char *getName() { return "WriteMem"; } -}; +class WriteMem + : public shared::RPCFunction<WriteMem, + void(remote::DirectBufferWriter DB)> { +public: + static const char *getName() { return "WriteMem"; } +}; /// Write to a remote pointer. -class WritePtr - : public shared::RPCFunction<WritePtr, void(JITTargetAddress Dst, - JITTargetAddress Val)> { -public: - static const char *getName() { return "WritePtr"; } -}; +class WritePtr + : public shared::RPCFunction<WritePtr, void(JITTargetAddress Dst, + JITTargetAddress Val)> { +public: + static const char *getName() { return "WritePtr"; } +}; } // end namespace mem @@ -297,46 +297,46 @@ public: namespace stubs { /// Creates an indirect stub owner on the remote. -class CreateIndirectStubsOwner - : public shared::RPCFunction<CreateIndirectStubsOwner, - void(ResourceIdMgr::ResourceId StubOwnerID)> { -public: - static const char *getName() { return "CreateIndirectStubsOwner"; } -}; +class CreateIndirectStubsOwner + : public shared::RPCFunction<CreateIndirectStubsOwner, + void(ResourceIdMgr::ResourceId StubOwnerID)> { +public: + static const char *getName() { return "CreateIndirectStubsOwner"; } +}; /// RPC function for destroying an indirect stubs owner. -class DestroyIndirectStubsOwner - : public shared::RPCFunction<DestroyIndirectStubsOwner, - void(ResourceIdMgr::ResourceId StubsOwnerID)> { -public: - static const char *getName() { return "DestroyIndirectStubsOwner"; } -}; +class DestroyIndirectStubsOwner + : public shared::RPCFunction<DestroyIndirectStubsOwner, + void(ResourceIdMgr::ResourceId StubsOwnerID)> { +public: + static const char *getName() { return "DestroyIndirectStubsOwner"; } +}; /// EmitIndirectStubs result is (StubsBase, PtrsBase, NumStubsEmitted). -class EmitIndirectStubs - : public shared::RPCFunction< - EmitIndirectStubs, - std::tuple<JITTargetAddress, JITTargetAddress, uint32_t>( - ResourceIdMgr::ResourceId StubsOwnerID, - uint32_t NumStubsRequired)> { -public: - static const char *getName() { return "EmitIndirectStubs"; } -}; +class EmitIndirectStubs + : public shared::RPCFunction< + EmitIndirectStubs, + std::tuple<JITTargetAddress, JITTargetAddress, uint32_t>( + ResourceIdMgr::ResourceId StubsOwnerID, + uint32_t NumStubsRequired)> { +public: + static const char *getName() { return "EmitIndirectStubs"; } +}; /// RPC function to emit the resolver block and return its address. -class EmitResolverBlock - : public shared::RPCFunction<EmitResolverBlock, void()> { -public: - static const char *getName() { return "EmitResolverBlock"; } -}; +class EmitResolverBlock + : public shared::RPCFunction<EmitResolverBlock, void()> { +public: + static const char *getName() { return "EmitResolverBlock"; } +}; /// EmitTrampolineBlock result is (BlockAddr, NumTrampolines). -class EmitTrampolineBlock - : public shared::RPCFunction<EmitTrampolineBlock, - std::tuple<JITTargetAddress, uint32_t>()> { -public: - static const char *getName() { return "EmitTrampolineBlock"; } -}; +class EmitTrampolineBlock + : public shared::RPCFunction<EmitTrampolineBlock, + std::tuple<JITTargetAddress, uint32_t>()> { +public: + static const char *getName() { return "EmitTrampolineBlock"; } +}; } // end namespace stubs @@ -345,44 +345,44 @@ namespace utils { /// GetRemoteInfo result is (Triple, PointerSize, PageSize, TrampolineSize, /// IndirectStubsSize). -class GetRemoteInfo - : public shared::RPCFunction< - GetRemoteInfo, - std::tuple<std::string, uint32_t, uint32_t, uint32_t, uint32_t>()> { -public: - static const char *getName() { return "GetRemoteInfo"; } -}; +class GetRemoteInfo + : public shared::RPCFunction< + GetRemoteInfo, + std::tuple<std::string, uint32_t, uint32_t, uint32_t, uint32_t>()> { +public: + static const char *getName() { return "GetRemoteInfo"; } +}; /// Get the address of a remote symbol. -class GetSymbolAddress - : public shared::RPCFunction<GetSymbolAddress, - JITTargetAddress(std::string SymbolName)> { -public: - static const char *getName() { return "GetSymbolAddress"; } -}; +class GetSymbolAddress + : public shared::RPCFunction<GetSymbolAddress, + JITTargetAddress(std::string SymbolName)> { +public: + static const char *getName() { return "GetSymbolAddress"; } +}; /// Request that the host execute a compile callback. -class RequestCompile - : public shared::RPCFunction< - RequestCompile, JITTargetAddress(JITTargetAddress TrampolineAddr)> { -public: - static const char *getName() { return "RequestCompile"; } -}; +class RequestCompile + : public shared::RPCFunction< + RequestCompile, JITTargetAddress(JITTargetAddress TrampolineAddr)> { +public: + static const char *getName() { return "RequestCompile"; } +}; /// Notify the remote and terminate the session. -class TerminateSession : public shared::RPCFunction<TerminateSession, void()> { -public: - static const char *getName() { return "TerminateSession"; } -}; +class TerminateSession : public shared::RPCFunction<TerminateSession, void()> { +public: + static const char *getName() { return "TerminateSession"; } +}; } // namespace utils class OrcRemoteTargetRPCAPI - : public shared::SingleThreadedRPCEndpoint<shared::RawByteChannel> { + : public shared::SingleThreadedRPCEndpoint<shared::RawByteChannel> { public: // FIXME: Remove constructors once MSVC supports synthesizing move-ops. - OrcRemoteTargetRPCAPI(shared::RawByteChannel &C) - : shared::SingleThreadedRPCEndpoint<shared::RawByteChannel>(C, true) {} + OrcRemoteTargetRPCAPI(shared::RawByteChannel &C) + : shared::SingleThreadedRPCEndpoint<shared::RawByteChannel>(C, true) {} }; } // end namespace remote diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h index d7bb4f591a..1002264934 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h @@ -24,7 +24,7 @@ #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/Orc/IndirectionUtils.h" #include "llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h" -#include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" +#include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Error.h" #include "llvm/Support/Format.h" @@ -53,7 +53,7 @@ namespace remote { template <typename ChannelT, typename TargetT> class OrcRemoteTargetServer - : public shared::SingleThreadedRPCEndpoint<shared::RawByteChannel> { + : public shared::SingleThreadedRPCEndpoint<shared::RawByteChannel> { public: using SymbolLookupFtor = std::function<JITTargetAddress(const std::string &Name)>; @@ -64,14 +64,14 @@ public: OrcRemoteTargetServer(ChannelT &Channel, SymbolLookupFtor SymbolLookup, EHFrameRegistrationFtor EHFramesRegister, EHFrameRegistrationFtor EHFramesDeregister) - : shared::SingleThreadedRPCEndpoint<shared::RawByteChannel>(Channel, - true), + : shared::SingleThreadedRPCEndpoint<shared::RawByteChannel>(Channel, + true), SymbolLookup(std::move(SymbolLookup)), EHFramesRegister(std::move(EHFramesRegister)), EHFramesDeregister(std::move(EHFramesDeregister)) { using ThisT = std::remove_reference_t<decltype(*this)>; addHandler<exec::CallIntVoid>(*this, &ThisT::handleCallIntVoid); - addHandler<exec::CallIntInt>(*this, &ThisT::handleCallIntInt); + addHandler<exec::CallIntInt>(*this, &ThisT::handleCallIntInt); addHandler<exec::CallMain>(*this, &ThisT::handleCallMain); addHandler<exec::CallVoidVoid>(*this, &ThisT::handleCallVoidVoid); addHandler<mem::CreateRemoteAllocator>(*this, @@ -177,19 +177,19 @@ private: return Result; } - Expected<int32_t> handleCallIntInt(JITTargetAddress Addr, int Arg) { - using IntIntFnTy = int (*)(int); - - IntIntFnTy Fn = reinterpret_cast<IntIntFnTy>(static_cast<uintptr_t>(Addr)); - - LLVM_DEBUG(dbgs() << " Calling " << format("0x%016x", Addr) - << " with argument " << Arg << "\n"); - int Result = Fn(Arg); - LLVM_DEBUG(dbgs() << " Result = " << Result << "\n"); - - return Result; - } - + Expected<int32_t> handleCallIntInt(JITTargetAddress Addr, int Arg) { + using IntIntFnTy = int (*)(int); + + IntIntFnTy Fn = reinterpret_cast<IntIntFnTy>(static_cast<uintptr_t>(Addr)); + + LLVM_DEBUG(dbgs() << " Calling " << format("0x%016x", Addr) + << " with argument " << Arg << "\n"); + int Result = Fn(Arg); + LLVM_DEBUG(dbgs() << " Result = " << Result << "\n"); + + return Result; + } + Expected<int32_t> handleCallMain(JITTargetAddress Addr, std::vector<std::string> Args) { using MainFnTy = int (*)(int, const char *[]); diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h index cb1755e21e..d7531d1087 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h @@ -42,16 +42,16 @@ namespace llvm { namespace orc { -class RTDyldObjectLinkingLayer : public ObjectLayer, private ResourceManager { +class RTDyldObjectLinkingLayer : public ObjectLayer, private ResourceManager { public: /// Functor for receiving object-loaded notifications. - using NotifyLoadedFunction = std::function<void( - MaterializationResponsibility &R, const object::ObjectFile &Obj, - const RuntimeDyld::LoadedObjectInfo &)>; + using NotifyLoadedFunction = std::function<void( + MaterializationResponsibility &R, const object::ObjectFile &Obj, + const RuntimeDyld::LoadedObjectInfo &)>; /// Functor for receiving finalization notifications. - using NotifyEmittedFunction = std::function<void( - MaterializationResponsibility &R, std::unique_ptr<MemoryBuffer>)>; + using NotifyEmittedFunction = std::function<void( + MaterializationResponsibility &R, std::unique_ptr<MemoryBuffer>)>; using GetMemoryManagerFunction = std::function<std::unique_ptr<RuntimeDyld::MemoryManager>()>; @@ -64,7 +64,7 @@ public: ~RTDyldObjectLinkingLayer(); /// Emit the object. - void emit(std::unique_ptr<MaterializationResponsibility> R, + void emit(std::unique_ptr<MaterializationResponsibility> R, std::unique_ptr<MemoryBuffer> O) override; /// Set the NotifyLoaded callback. @@ -129,24 +129,24 @@ public: void unregisterJITEventListener(JITEventListener &L); private: - using MemoryManagerUP = std::unique_ptr<RuntimeDyld::MemoryManager>; - - Error onObjLoad(MaterializationResponsibility &R, + using MemoryManagerUP = std::unique_ptr<RuntimeDyld::MemoryManager>; + + Error onObjLoad(MaterializationResponsibility &R, const object::ObjectFile &Obj, - RuntimeDyld::MemoryManager &MemMgr, - RuntimeDyld::LoadedObjectInfo &LoadedObjInfo, + RuntimeDyld::MemoryManager &MemMgr, + RuntimeDyld::LoadedObjectInfo &LoadedObjInfo, std::map<StringRef, JITEvaluatedSymbol> Resolved, std::set<StringRef> &InternalSymbols); - void onObjEmit(MaterializationResponsibility &R, + void onObjEmit(MaterializationResponsibility &R, object::OwningBinary<object::ObjectFile> O, - std::unique_ptr<RuntimeDyld::MemoryManager> MemMgr, - std::unique_ptr<RuntimeDyld::LoadedObjectInfo> LoadedObjInfo, - Error Err); - - Error handleRemoveResources(ResourceKey K) override; - void handleTransferResources(ResourceKey DstKey, ResourceKey SrcKey) override; + std::unique_ptr<RuntimeDyld::MemoryManager> MemMgr, + std::unique_ptr<RuntimeDyld::LoadedObjectInfo> LoadedObjInfo, + Error Err); + Error handleRemoveResources(ResourceKey K) override; + void handleTransferResources(ResourceKey DstKey, ResourceKey SrcKey) override; + mutable std::mutex RTDyldLayerMutex; GetMemoryManagerFunction GetMemoryManager; NotifyLoadedFunction NotifyLoaded; @@ -154,7 +154,7 @@ private: bool ProcessAllSections = false; bool OverrideObjectFlags = false; bool AutoClaimObjectSymbols = false; - DenseMap<ResourceKey, std::vector<MemoryManagerUP>> MemMgrs; + DenseMap<ResourceKey, std::vector<MemoryManagerUP>> MemMgrs; std::vector<JITEventListener *> EventListeners; }; diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/FDRawByteChannel.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/FDRawByteChannel.h index e1a376bc6b..d4aa712442 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/FDRawByteChannel.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/FDRawByteChannel.h @@ -1,90 +1,90 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===- FDRawByteChannel.h - File descriptor based byte-channel -*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// File descriptor based RawByteChannel. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_FDRAWBYTECHANNEL_H -#define LLVM_EXECUTIONENGINE_ORC_SHARED_FDRAWBYTECHANNEL_H - -#include "llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h" - -#if !defined(_MSC_VER) && !defined(__MINGW32__) -#include <unistd.h> -#else -#include <io.h> -#endif - -namespace llvm { -namespace orc { -namespace shared { - -/// Serialization channel that reads from and writes from file descriptors. -class FDRawByteChannel final : public RawByteChannel { -public: - FDRawByteChannel(int InFD, int OutFD) : InFD(InFD), OutFD(OutFD) {} - - llvm::Error readBytes(char *Dst, unsigned Size) override { - assert(Dst && "Attempt to read into null."); - ssize_t Completed = 0; - while (Completed < static_cast<ssize_t>(Size)) { - ssize_t Read = ::read(InFD, Dst + Completed, Size - Completed); - if (Read <= 0) { - auto ErrNo = errno; - if (ErrNo == EAGAIN || ErrNo == EINTR) - continue; - else - return llvm::errorCodeToError( - std::error_code(errno, std::generic_category())); - } - Completed += Read; - } - return llvm::Error::success(); - } - - llvm::Error appendBytes(const char *Src, unsigned Size) override { - assert(Src && "Attempt to append from null."); - ssize_t Completed = 0; - while (Completed < static_cast<ssize_t>(Size)) { - ssize_t Written = ::write(OutFD, Src + Completed, Size - Completed); - if (Written < 0) { - auto ErrNo = errno; - if (ErrNo == EAGAIN || ErrNo == EINTR) - continue; - else - return llvm::errorCodeToError( - std::error_code(errno, std::generic_category())); - } - Completed += Written; - } - return llvm::Error::success(); - } - - llvm::Error send() override { return llvm::Error::success(); } - -private: - int InFD, OutFD; -}; - -} // namespace shared -} // namespace orc -} // namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_FDRAWBYTECHANNEL_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===- FDRawByteChannel.h - File descriptor based byte-channel -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// File descriptor based RawByteChannel. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_FDRAWBYTECHANNEL_H +#define LLVM_EXECUTIONENGINE_ORC_SHARED_FDRAWBYTECHANNEL_H + +#include "llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h" + +#if !defined(_MSC_VER) && !defined(__MINGW32__) +#include <unistd.h> +#else +#include <io.h> +#endif + +namespace llvm { +namespace orc { +namespace shared { + +/// Serialization channel that reads from and writes from file descriptors. +class FDRawByteChannel final : public RawByteChannel { +public: + FDRawByteChannel(int InFD, int OutFD) : InFD(InFD), OutFD(OutFD) {} + + llvm::Error readBytes(char *Dst, unsigned Size) override { + assert(Dst && "Attempt to read into null."); + ssize_t Completed = 0; + while (Completed < static_cast<ssize_t>(Size)) { + ssize_t Read = ::read(InFD, Dst + Completed, Size - Completed); + if (Read <= 0) { + auto ErrNo = errno; + if (ErrNo == EAGAIN || ErrNo == EINTR) + continue; + else + return llvm::errorCodeToError( + std::error_code(errno, std::generic_category())); + } + Completed += Read; + } + return llvm::Error::success(); + } + + llvm::Error appendBytes(const char *Src, unsigned Size) override { + assert(Src && "Attempt to append from null."); + ssize_t Completed = 0; + while (Completed < static_cast<ssize_t>(Size)) { + ssize_t Written = ::write(OutFD, Src + Completed, Size - Completed); + if (Written < 0) { + auto ErrNo = errno; + if (ErrNo == EAGAIN || ErrNo == EINTR) + continue; + else + return llvm::errorCodeToError( + std::error_code(errno, std::generic_category())); + } + Completed += Written; + } + return llvm::Error::success(); + } + + llvm::Error send() override { return llvm::Error::success(); } + +private: + int InFD, OutFD; +}; + +} // namespace shared +} // namespace orc +} // namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_FDRAWBYTECHANNEL_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/OrcError.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/OrcError.h index 2dde3afdce..172c35a221 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/OrcError.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/OrcError.h @@ -1,85 +1,85 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===------ OrcError.h - Reject symbol lookup requests ------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Define an error category, error codes, and helper utilities for Orc. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_ORCERROR_H -#define LLVM_EXECUTIONENGINE_ORC_ORCERROR_H - -#include "llvm/Support/Error.h" -#include "llvm/Support/raw_ostream.h" -#include <string> -#include <system_error> - -namespace llvm { -namespace orc { - -enum class OrcErrorCode : int { - // RPC Errors - UnknownORCError = 1, - DuplicateDefinition, - JITSymbolNotFound, - RemoteAllocatorDoesNotExist, - RemoteAllocatorIdAlreadyInUse, - RemoteMProtectAddrUnrecognized, - RemoteIndirectStubsOwnerDoesNotExist, - RemoteIndirectStubsOwnerIdAlreadyInUse, - RPCConnectionClosed, - RPCCouldNotNegotiateFunction, - RPCResponseAbandoned, - UnexpectedRPCCall, - UnexpectedRPCResponse, - UnknownErrorCodeFromRemote, - UnknownResourceHandle, - MissingSymbolDefinitions, - UnexpectedSymbolDefinitions, -}; - -std::error_code orcError(OrcErrorCode ErrCode); - -class DuplicateDefinition : public ErrorInfo<DuplicateDefinition> { -public: - static char ID; - - DuplicateDefinition(std::string SymbolName); - std::error_code convertToErrorCode() const override; - void log(raw_ostream &OS) const override; - const std::string &getSymbolName() const; -private: - std::string SymbolName; -}; - -class JITSymbolNotFound : public ErrorInfo<JITSymbolNotFound> { -public: - static char ID; - - JITSymbolNotFound(std::string SymbolName); - std::error_code convertToErrorCode() const override; - void log(raw_ostream &OS) const override; - const std::string &getSymbolName() const; -private: - std::string SymbolName; -}; - -} // End namespace orc. -} // End namespace llvm. - -#endif // LLVM_EXECUTIONENGINE_ORC_ORCERROR_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===------ OrcError.h - Reject symbol lookup requests ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Define an error category, error codes, and helper utilities for Orc. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_ORCERROR_H +#define LLVM_EXECUTIONENGINE_ORC_ORCERROR_H + +#include "llvm/Support/Error.h" +#include "llvm/Support/raw_ostream.h" +#include <string> +#include <system_error> + +namespace llvm { +namespace orc { + +enum class OrcErrorCode : int { + // RPC Errors + UnknownORCError = 1, + DuplicateDefinition, + JITSymbolNotFound, + RemoteAllocatorDoesNotExist, + RemoteAllocatorIdAlreadyInUse, + RemoteMProtectAddrUnrecognized, + RemoteIndirectStubsOwnerDoesNotExist, + RemoteIndirectStubsOwnerIdAlreadyInUse, + RPCConnectionClosed, + RPCCouldNotNegotiateFunction, + RPCResponseAbandoned, + UnexpectedRPCCall, + UnexpectedRPCResponse, + UnknownErrorCodeFromRemote, + UnknownResourceHandle, + MissingSymbolDefinitions, + UnexpectedSymbolDefinitions, +}; + +std::error_code orcError(OrcErrorCode ErrCode); + +class DuplicateDefinition : public ErrorInfo<DuplicateDefinition> { +public: + static char ID; + + DuplicateDefinition(std::string SymbolName); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + const std::string &getSymbolName() const; +private: + std::string SymbolName; +}; + +class JITSymbolNotFound : public ErrorInfo<JITSymbolNotFound> { +public: + static char ID; + + JITSymbolNotFound(std::string SymbolName); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + const std::string &getSymbolName() const; +private: + std::string SymbolName; +}; + +} // End namespace orc. +} // End namespace llvm. + +#endif // LLVM_EXECUTIONENGINE_ORC_ORCERROR_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RPCUtils.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RPCUtils.h index 4bc6d3577b..26b64ee2db 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RPCUtils.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RPCUtils.h @@ -1,1668 +1,1668 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Utilities to support construction of simple RPC APIs. -// -// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++ -// programmers, high performance, low memory overhead, and efficient use of the -// communications channel. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H -#define LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H - -#include <map> -#include <thread> -#include <vector> - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" -#include "llvm/ExecutionEngine/Orc/Shared/Serialization.h" -#include "llvm/Support/MSVCErrorWorkarounds.h" - -#include <future> - -namespace llvm { -namespace orc { -namespace shared { - -/// Base class of all fatal RPC errors (those that necessarily result in the -/// termination of the RPC session). -class RPCFatalError : public ErrorInfo<RPCFatalError> { -public: - static char ID; -}; - -/// RPCConnectionClosed is returned from RPC operations if the RPC connection -/// has already been closed due to either an error or graceful disconnection. -class ConnectionClosed : public ErrorInfo<ConnectionClosed> { -public: - static char ID; - std::error_code convertToErrorCode() const override; - void log(raw_ostream &OS) const override; -}; - -/// BadFunctionCall is returned from handleOne when the remote makes a call with -/// an unrecognized function id. -/// -/// This error is fatal because Orc RPC needs to know how to parse a function -/// call to know where the next call starts, and if it doesn't recognize the -/// function id it cannot parse the call. -template <typename FnIdT, typename SeqNoT> -class BadFunctionCall - : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> { -public: - static char ID; - - BadFunctionCall(FnIdT FnId, SeqNoT SeqNo) - : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {} - - std::error_code convertToErrorCode() const override { - return orcError(OrcErrorCode::UnexpectedRPCCall); - } - - void log(raw_ostream &OS) const override { - OS << "Call to invalid RPC function id '" << FnId - << "' with " - "sequence number " - << SeqNo; - } - -private: - FnIdT FnId; - SeqNoT SeqNo; -}; - -template <typename FnIdT, typename SeqNoT> -char BadFunctionCall<FnIdT, SeqNoT>::ID = 0; - -/// InvalidSequenceNumberForResponse is returned from handleOne when a response -/// call arrives with a sequence number that doesn't correspond to any in-flight -/// function call. -/// -/// This error is fatal because Orc RPC needs to know how to parse the rest of -/// the response call to know where the next call starts, and if it doesn't have -/// a result parser for this sequence number it can't do that. -template <typename SeqNoT> -class InvalidSequenceNumberForResponse - : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, - RPCFatalError> { -public: - static char ID; - - InvalidSequenceNumberForResponse(SeqNoT SeqNo) : SeqNo(std::move(SeqNo)) {} - - std::error_code convertToErrorCode() const override { - return orcError(OrcErrorCode::UnexpectedRPCCall); - }; - - void log(raw_ostream &OS) const override { - OS << "Response has unknown sequence number " << SeqNo; - } - -private: - SeqNoT SeqNo; -}; - -template <typename SeqNoT> -char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0; - -/// This non-fatal error will be passed to asynchronous result handlers in place -/// of a result if the connection goes down before a result returns, or if the -/// function to be called cannot be negotiated with the remote. -class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> { -public: - static char ID; - - std::error_code convertToErrorCode() const override; - void log(raw_ostream &OS) const override; -}; - -/// This error is returned if the remote does not have a handler installed for -/// the given RPC function. -class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> { -public: - static char ID; - - CouldNotNegotiate(std::string Signature); - std::error_code convertToErrorCode() const override; - void log(raw_ostream &OS) const override; - const std::string &getSignature() const { return Signature; } - -private: - std::string Signature; -}; - -template <typename DerivedFunc, typename FnT> class RPCFunction; - -// RPC Function class. -// DerivedFunc should be a user defined class with a static 'getName()' method -// returning a const char* representing the function's name. -template <typename DerivedFunc, typename RetT, typename... ArgTs> -class RPCFunction<DerivedFunc, RetT(ArgTs...)> { -public: - /// User defined function type. - using Type = RetT(ArgTs...); - - /// Return type. - using ReturnType = RetT; - - /// Returns the full function prototype as a string. - static const char *getPrototype() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << SerializationTypeName<RetT>::getName() << " " - << DerivedFunc::getName() << "(" - << SerializationTypeNameSequence<ArgTs...>() << ")"; - return Name; - }(); - return Name.data(); - } -}; - -/// Allocates RPC function ids during autonegotiation. -/// Specializations of this class must provide four members: -/// -/// static T getInvalidId(): -/// Should return a reserved id that will be used to represent missing -/// functions during autonegotiation. -/// -/// static T getResponseId(): -/// Should return a reserved id that will be used to send function responses -/// (return values). -/// -/// static T getNegotiateId(): -/// Should return a reserved id for the negotiate function, which will be used -/// to negotiate ids for user defined functions. -/// -/// template <typename Func> T allocate(): -/// Allocate a unique id for function Func. -template <typename T, typename = void> class RPCFunctionIdAllocator; - -/// This specialization of RPCFunctionIdAllocator provides a default -/// implementation for integral types. -template <typename T> -class RPCFunctionIdAllocator<T, std::enable_if_t<std::is_integral<T>::value>> { -public: - static T getInvalidId() { return T(0); } - static T getResponseId() { return T(1); } - static T getNegotiateId() { return T(2); } - - template <typename Func> T allocate() { return NextId++; } - -private: - T NextId = 3; -}; - -namespace detail { - -/// Provides a typedef for a tuple containing the decayed argument types. -template <typename T> class RPCFunctionArgsTuple; - -template <typename RetT, typename... ArgTs> -class RPCFunctionArgsTuple<RetT(ArgTs...)> { -public: - using Type = std::tuple<std::decay_t<std::remove_reference_t<ArgTs>>...>; -}; - -// ResultTraits provides typedefs and utilities specific to the return type -// of functions. -template <typename RetT> class ResultTraits { -public: - // The return type wrapped in llvm::Expected. - using ErrorReturnType = Expected<RetT>; - -#ifdef _MSC_VER - // The ErrorReturnType wrapped in a std::promise. - using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>; - - // The ErrorReturnType wrapped in a std::future. - using ReturnFutureType = std::future<MSVCPExpected<RetT>>; -#else - // The ErrorReturnType wrapped in a std::promise. - using ReturnPromiseType = std::promise<ErrorReturnType>; - - // The ErrorReturnType wrapped in a std::future. - using ReturnFutureType = std::future<ErrorReturnType>; -#endif - - // Create a 'blank' value of the ErrorReturnType, ready and safe to - // overwrite. - static ErrorReturnType createBlankErrorReturnValue() { - return ErrorReturnType(RetT()); - } - - // Consume an abandoned ErrorReturnType. - static void consumeAbandoned(ErrorReturnType RetOrErr) { - consumeError(RetOrErr.takeError()); - } -}; - -// ResultTraits specialization for void functions. -template <> class ResultTraits<void> { -public: - // For void functions, ErrorReturnType is llvm::Error. - using ErrorReturnType = Error; - -#ifdef _MSC_VER - // The ErrorReturnType wrapped in a std::promise. - using ReturnPromiseType = std::promise<MSVCPError>; - - // The ErrorReturnType wrapped in a std::future. - using ReturnFutureType = std::future<MSVCPError>; -#else - // The ErrorReturnType wrapped in a std::promise. - using ReturnPromiseType = std::promise<ErrorReturnType>; - - // The ErrorReturnType wrapped in a std::future. - using ReturnFutureType = std::future<ErrorReturnType>; -#endif - - // Create a 'blank' value of the ErrorReturnType, ready and safe to - // overwrite. - static ErrorReturnType createBlankErrorReturnValue() { - return ErrorReturnType::success(); - } - - // Consume an abandoned ErrorReturnType. - static void consumeAbandoned(ErrorReturnType Err) { - consumeError(std::move(Err)); - } -}; - -// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows -// handlers for void RPC functions to return either void (in which case they -// implicitly succeed) or Error (in which case their error return is -// propagated). See usage in HandlerTraits::runHandlerHelper. -template <> class ResultTraits<Error> : public ResultTraits<void> {}; - -// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows -// handlers for RPC functions returning a T to return either a T (in which -// case they implicitly succeed) or Expected<T> (in which case their error -// return is propagated). See usage in HandlerTraits::runHandlerHelper. -template <typename RetT> -class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {}; - -// Determines whether an RPC function's defined error return type supports -// error return value. -template <typename T> class SupportsErrorReturn { -public: - static const bool value = false; -}; - -template <> class SupportsErrorReturn<Error> { -public: - static const bool value = true; -}; - -template <typename T> class SupportsErrorReturn<Expected<T>> { -public: - static const bool value = true; -}; - -// RespondHelper packages return values based on whether or not the declared -// RPC function return type supports error returns. -template <bool FuncSupportsErrorReturn> class RespondHelper; - -// RespondHelper specialization for functions that support error returns. -template <> class RespondHelper<true> { -public: - // Send Expected<T>. - template <typename WireRetT, typename HandlerRetT, typename ChannelT, - typename FunctionIdT, typename SequenceNumberT> - static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, - SequenceNumberT SeqNo, - Expected<HandlerRetT> ResultOrErr) { - if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>()) - return ResultOrErr.takeError(); - - // Open the response message. - if (auto Err = C.startSendMessage(ResponseId, SeqNo)) - return Err; - - // Serialize the result. - if (auto Err = - SerializationTraits<ChannelT, WireRetT, Expected<HandlerRetT>>:: - serialize(C, std::move(ResultOrErr))) - return Err; - - // Close the response message. - if (auto Err = C.endSendMessage()) - return Err; - return C.send(); - } - - template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> - static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, - SequenceNumberT SeqNo, Error Err) { - if (Err && Err.isA<RPCFatalError>()) - return Err; - if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) - return Err2; - if (auto Err2 = serializeSeq(C, std::move(Err))) - return Err2; - if (auto Err2 = C.endSendMessage()) - return Err2; - return C.send(); - } -}; - -// RespondHelper specialization for functions that do not support error returns. -template <> class RespondHelper<false> { -public: - template <typename WireRetT, typename HandlerRetT, typename ChannelT, - typename FunctionIdT, typename SequenceNumberT> - static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, - SequenceNumberT SeqNo, - Expected<HandlerRetT> ResultOrErr) { - if (auto Err = ResultOrErr.takeError()) - return Err; - - // Open the response message. - if (auto Err = C.startSendMessage(ResponseId, SeqNo)) - return Err; - - // Serialize the result. - if (auto Err = - SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize( - C, *ResultOrErr)) - return Err; - - // End the response message. - if (auto Err = C.endSendMessage()) - return Err; - - return C.send(); - } - - template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> - static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, - SequenceNumberT SeqNo, Error Err) { - if (Err) - return Err; - if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) - return Err2; - if (auto Err2 = C.endSendMessage()) - return Err2; - return C.send(); - } -}; - -// Send a response of the given wire return type (WireRetT) over the -// channel, with the given sequence number. -template <typename WireRetT, typename HandlerRetT, typename ChannelT, - typename FunctionIdT, typename SequenceNumberT> -Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, - Expected<HandlerRetT> ResultOrErr) { - return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: - template sendResult<WireRetT>(C, ResponseId, SeqNo, - std::move(ResultOrErr)); -} - -// Send an empty response message on the given channel to indicate that -// the handler ran. -template <typename WireRetT, typename ChannelT, typename FunctionIdT, - typename SequenceNumberT> -Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, - Error Err) { - return RespondHelper<SupportsErrorReturn<WireRetT>::value>::sendResult( - C, ResponseId, SeqNo, std::move(Err)); -} - -// Converts a given type to the equivalent error return type. -template <typename T> class WrappedHandlerReturn { -public: - using Type = Expected<T>; -}; - -template <typename T> class WrappedHandlerReturn<Expected<T>> { -public: - using Type = Expected<T>; -}; - -template <> class WrappedHandlerReturn<void> { -public: - using Type = Error; -}; - -template <> class WrappedHandlerReturn<Error> { -public: - using Type = Error; -}; - -template <> class WrappedHandlerReturn<ErrorSuccess> { -public: - using Type = Error; -}; - -// Traits class that strips the response function from the list of handler -// arguments. -template <typename FnT> class AsyncHandlerTraits; - -template <typename ResultT, typename... ArgTs> -class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, - ArgTs...)> { -public: - using Type = Error(ArgTs...); - using ResultType = Expected<ResultT>; -}; - -template <typename... ArgTs> -class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> { -public: - using Type = Error(ArgTs...); - using ResultType = Error; -}; - -template <typename... ArgTs> -class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> { -public: - using Type = Error(ArgTs...); - using ResultType = Error; -}; - -template <typename... ArgTs> -class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> { -public: - using Type = Error(ArgTs...); - using ResultType = Error; -}; - -template <typename ResponseHandlerT, typename... ArgTs> -class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> - : public AsyncHandlerTraits<Error(std::decay_t<ResponseHandlerT>, - ArgTs...)> {}; - -// This template class provides utilities related to RPC function handlers. -// The base case applies to non-function types (the template class is -// specialized for function types) and inherits from the appropriate -// speciilization for the given non-function type's call operator. -template <typename HandlerT> -class HandlerTraits - : public HandlerTraits< - decltype(&std::remove_reference<HandlerT>::type::operator())> {}; - -// Traits for handlers with a given function type. -template <typename RetT, typename... ArgTs> -class HandlerTraits<RetT(ArgTs...)> { -public: - // Function type of the handler. - using Type = RetT(ArgTs...); - - // Return type of the handler. - using ReturnType = RetT; - - // Call the given handler with the given arguments. - template <typename HandlerT, typename... TArgTs> - static typename WrappedHandlerReturn<RetT>::Type - unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) { - return unpackAndRunHelper(Handler, Args, - std::index_sequence_for<TArgTs...>()); - } - - // Call the given handler with the given arguments. - template <typename HandlerT, typename ResponderT, typename... TArgTs> - static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder, - std::tuple<TArgTs...> &Args) { - return unpackAndRunAsyncHelper(Handler, Responder, Args, - std::index_sequence_for<TArgTs...>()); - } - - // Call the given handler with the given arguments. - template <typename HandlerT> - static std::enable_if_t< - std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, Error> - run(HandlerT &Handler, ArgTs &&...Args) { - Handler(std::move(Args)...); - return Error::success(); - } - - template <typename HandlerT, typename... TArgTs> - static std::enable_if_t< - !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, - typename HandlerTraits<HandlerT>::ReturnType> - run(HandlerT &Handler, TArgTs... Args) { - return Handler(std::move(Args)...); - } - - // Serialize arguments to the channel. - template <typename ChannelT, typename... CArgTs> - static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) { - return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...); - } - - // Deserialize arguments from the channel. - template <typename ChannelT, typename... CArgTs> - static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) { - return deserializeArgsHelper(C, Args, std::index_sequence_for<CArgTs...>()); - } - -private: - template <typename ChannelT, typename... CArgTs, size_t... Indexes> - static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args, - std::index_sequence<Indexes...> _) { - return SequenceSerialization<ChannelT, ArgTs...>::deserialize( - C, std::get<Indexes>(Args)...); - } - - template <typename HandlerT, typename ArgTuple, size_t... Indexes> - static typename WrappedHandlerReturn< - typename HandlerTraits<HandlerT>::ReturnType>::Type - unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args, - std::index_sequence<Indexes...>) { - return run(Handler, std::move(std::get<Indexes>(Args))...); - } - - template <typename HandlerT, typename ResponderT, typename ArgTuple, - size_t... Indexes> - static typename WrappedHandlerReturn< - typename HandlerTraits<HandlerT>::ReturnType>::Type - unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder, - ArgTuple &Args, std::index_sequence<Indexes...>) { - return run(Handler, Responder, std::move(std::get<Indexes>(Args))...); - } -}; - -// Handler traits for free functions. -template <typename RetT, typename... ArgTs> -class HandlerTraits<RetT (*)(ArgTs...)> : public HandlerTraits<RetT(ArgTs...)> { -}; - -// Handler traits for class methods (especially call operators for lambdas). -template <typename Class, typename RetT, typename... ArgTs> -class HandlerTraits<RetT (Class::*)(ArgTs...)> - : public HandlerTraits<RetT(ArgTs...)> {}; - -// Handler traits for const class methods (especially call operators for -// lambdas). -template <typename Class, typename RetT, typename... ArgTs> -class HandlerTraits<RetT (Class::*)(ArgTs...) const> - : public HandlerTraits<RetT(ArgTs...)> {}; - -// Utility to peel the Expected wrapper off a response handler error type. -template <typename HandlerT> class ResponseHandlerArg; - -template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> { -public: - using ArgType = Expected<ArgT>; - using UnwrappedArgType = ArgT; -}; - -template <typename ArgT> -class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> { -public: - using ArgType = Expected<ArgT>; - using UnwrappedArgType = ArgT; -}; - -template <> class ResponseHandlerArg<Error(Error)> { -public: - using ArgType = Error; -}; - -template <> class ResponseHandlerArg<ErrorSuccess(Error)> { -public: - using ArgType = Error; -}; - -// ResponseHandler represents a handler for a not-yet-received function call -// result. -template <typename ChannelT> class ResponseHandler { -public: - virtual ~ResponseHandler() {} - - // Reads the function result off the wire and acts on it. The meaning of - // "act" will depend on how this method is implemented in any given - // ResponseHandler subclass but could, for example, mean running a - // user-specified handler or setting a promise value. - virtual Error handleResponse(ChannelT &C) = 0; - - // Abandons this outstanding result. - virtual void abandon() = 0; - - // Create an error instance representing an abandoned response. - static Error createAbandonedResponseError() { - return make_error<ResponseAbandoned>(); - } -}; - -// ResponseHandler subclass for RPC functions with non-void returns. -template <typename ChannelT, typename FuncRetT, typename HandlerT> -class ResponseHandlerImpl : public ResponseHandler<ChannelT> { -public: - ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} - - // Handle the result by deserializing it from the channel then passing it - // to the user defined handler. - Error handleResponse(ChannelT &C) override { - using UnwrappedArgType = typename ResponseHandlerArg< - typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType; - UnwrappedArgType Result; - if (auto Err = - SerializationTraits<ChannelT, FuncRetT, - UnwrappedArgType>::deserialize(C, Result)) - return Err; - if (auto Err = C.endReceiveMessage()) - return Err; - return Handler(std::move(Result)); - } - - // Abandon this response by calling the handler with an 'abandoned response' - // error. - void abandon() override { - if (auto Err = Handler(this->createAbandonedResponseError())) { - // Handlers should not fail when passed an abandoned response error. - report_fatal_error(std::move(Err)); - } - } - -private: - HandlerT Handler; -}; - -// ResponseHandler subclass for RPC functions with void returns. -template <typename ChannelT, typename HandlerT> -class ResponseHandlerImpl<ChannelT, void, HandlerT> - : public ResponseHandler<ChannelT> { -public: - ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} - - // Handle the result (no actual value, just a notification that the function - // has completed on the remote end) by calling the user-defined handler with - // Error::success(). - Error handleResponse(ChannelT &C) override { - if (auto Err = C.endReceiveMessage()) - return Err; - return Handler(Error::success()); - } - - // Abandon this response by calling the handler with an 'abandoned response' - // error. - void abandon() override { - if (auto Err = Handler(this->createAbandonedResponseError())) { - // Handlers should not fail when passed an abandoned response error. - report_fatal_error(std::move(Err)); - } - } - -private: - HandlerT Handler; -}; - -template <typename ChannelT, typename FuncRetT, typename HandlerT> -class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT> - : public ResponseHandler<ChannelT> { -public: - ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} - - // Handle the result by deserializing it from the channel then passing it - // to the user defined handler. - Error handleResponse(ChannelT &C) override { - using HandlerArgType = typename ResponseHandlerArg< - typename HandlerTraits<HandlerT>::Type>::ArgType; - HandlerArgType Result((typename HandlerArgType::value_type())); - - if (auto Err = SerializationTraits<ChannelT, Expected<FuncRetT>, - HandlerArgType>::deserialize(C, Result)) - return Err; - if (auto Err = C.endReceiveMessage()) - return Err; - return Handler(std::move(Result)); - } - - // Abandon this response by calling the handler with an 'abandoned response' - // error. - void abandon() override { - if (auto Err = Handler(this->createAbandonedResponseError())) { - // Handlers should not fail when passed an abandoned response error. - report_fatal_error(std::move(Err)); - } - } - -private: - HandlerT Handler; -}; - -template <typename ChannelT, typename HandlerT> -class ResponseHandlerImpl<ChannelT, Error, HandlerT> - : public ResponseHandler<ChannelT> { -public: - ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} - - // Handle the result by deserializing it from the channel then passing it - // to the user defined handler. - Error handleResponse(ChannelT &C) override { - Error Result = Error::success(); - if (auto Err = SerializationTraits<ChannelT, Error, Error>::deserialize( - C, Result)) { - consumeError(std::move(Result)); - return Err; - } - if (auto Err = C.endReceiveMessage()) { - consumeError(std::move(Result)); - return Err; - } - return Handler(std::move(Result)); - } - - // Abandon this response by calling the handler with an 'abandoned response' - // error. - void abandon() override { - if (auto Err = Handler(this->createAbandonedResponseError())) { - // Handlers should not fail when passed an abandoned response error. - report_fatal_error(std::move(Err)); - } - } - -private: - HandlerT Handler; -}; - -// Create a ResponseHandler from a given user handler. -template <typename ChannelT, typename FuncRetT, typename HandlerT> -std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) { - return std::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>( - std::move(H)); -} - -// Helper for wrapping member functions up as functors. This is useful for -// installing methods as result handlers. -template <typename ClassT, typename RetT, typename... ArgTs> -class MemberFnWrapper { -public: - using MethodT = RetT (ClassT::*)(ArgTs...); - MemberFnWrapper(ClassT &Instance, MethodT Method) - : Instance(Instance), Method(Method) {} - RetT operator()(ArgTs &&...Args) { - return (Instance.*Method)(std::move(Args)...); - } - -private: - ClassT &Instance; - MethodT Method; -}; - -// Helper that provides a Functor for deserializing arguments. -template <typename... ArgTs> class ReadArgs { -public: - Error operator()() { return Error::success(); } -}; - -template <typename ArgT, typename... ArgTs> -class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> { -public: - ReadArgs(ArgT &Arg, ArgTs &...Args) : ReadArgs<ArgTs...>(Args...), Arg(Arg) {} - - Error operator()(ArgT &ArgVal, ArgTs &...ArgVals) { - this->Arg = std::move(ArgVal); - return ReadArgs<ArgTs...>::operator()(ArgVals...); - } - -private: - ArgT &Arg; -}; - -// Manage sequence numbers. -template <typename SequenceNumberT> class SequenceNumberManager { -public: - // Reset, making all sequence numbers available. - void reset() { - std::lock_guard<std::mutex> Lock(SeqNoLock); - NextSequenceNumber = 0; - FreeSequenceNumbers.clear(); - } - - // Get the next available sequence number. Will re-use numbers that have - // been released. - SequenceNumberT getSequenceNumber() { - std::lock_guard<std::mutex> Lock(SeqNoLock); - if (FreeSequenceNumbers.empty()) - return NextSequenceNumber++; - auto SequenceNumber = FreeSequenceNumbers.back(); - FreeSequenceNumbers.pop_back(); - return SequenceNumber; - } - - // Release a sequence number, making it available for re-use. - void releaseSequenceNumber(SequenceNumberT SequenceNumber) { - std::lock_guard<std::mutex> Lock(SeqNoLock); - FreeSequenceNumbers.push_back(SequenceNumber); - } - -private: - std::mutex SeqNoLock; - SequenceNumberT NextSequenceNumber = 0; - std::vector<SequenceNumberT> FreeSequenceNumbers; -}; - -// Checks that predicate P holds for each corresponding pair of type arguments -// from T1 and T2 tuple. -template <template <class, class> class P, typename T1Tuple, typename T2Tuple> -class RPCArgTypeCheckHelper; - -template <template <class, class> class P> -class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> { -public: - static const bool value = true; -}; - -template <template <class, class> class P, typename T, typename... Ts, - typename U, typename... Us> -class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> { -public: - static const bool value = - P<T, U>::value && - RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value; -}; - -template <template <class, class> class P, typename T1Sig, typename T2Sig> -class RPCArgTypeCheck { -public: - using T1Tuple = typename RPCFunctionArgsTuple<T1Sig>::Type; - using T2Tuple = typename RPCFunctionArgsTuple<T2Sig>::Type; - - static_assert(std::tuple_size<T1Tuple>::value >= - std::tuple_size<T2Tuple>::value, - "Too many arguments to RPC call"); - static_assert(std::tuple_size<T1Tuple>::value <= - std::tuple_size<T2Tuple>::value, - "Too few arguments to RPC call"); - - static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value; -}; - -template <typename ChannelT, typename WireT, typename ConcreteT> -class CanSerialize { -private: - using S = SerializationTraits<ChannelT, WireT, ConcreteT>; - - template <typename T> - static std::true_type check( - std::enable_if_t<std::is_same<decltype(T::serialize( - std::declval<ChannelT &>(), - std::declval<const ConcreteT &>())), - Error>::value, - void *>); - - template <typename> static std::false_type check(...); - -public: - static const bool value = decltype(check<S>(0))::value; -}; - -template <typename ChannelT, typename WireT, typename ConcreteT> -class CanDeserialize { -private: - using S = SerializationTraits<ChannelT, WireT, ConcreteT>; - - template <typename T> - static std::true_type - check(std::enable_if_t< - std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(), - std::declval<ConcreteT &>())), - Error>::value, - void *>); - - template <typename> static std::false_type check(...); - -public: - static const bool value = decltype(check<S>(0))::value; -}; - -/// Contains primitive utilities for defining, calling and handling calls to -/// remote procedures. ChannelT is a bidirectional stream conforming to the -/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure -/// identifier type that must be serializable on ChannelT, and SequenceNumberT -/// is an integral type that will be used to number in-flight function calls. -/// -/// These utilities support the construction of very primitive RPC utilities. -/// Their intent is to ensure correct serialization and deserialization of -/// procedure arguments, and to keep the client and server's view of the API in -/// sync. -template <typename ImplT, typename ChannelT, typename FunctionIdT, - typename SequenceNumberT> -class RPCEndpointBase { -protected: - class OrcRPCInvalid : public RPCFunction<OrcRPCInvalid, void()> { - public: - static const char *getName() { return "__orc_rpc$invalid"; } - }; - - class OrcRPCResponse : public RPCFunction<OrcRPCResponse, void()> { - public: - static const char *getName() { return "__orc_rpc$response"; } - }; - - class OrcRPCNegotiate - : public RPCFunction<OrcRPCNegotiate, FunctionIdT(std::string)> { - public: - static const char *getName() { return "__orc_rpc$negotiate"; } - }; - - // Helper predicate for testing for the presence of SerializeTraits - // serializers. - template <typename WireT, typename ConcreteT> - class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> { - public: - using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value; - - static_assert(value, "Missing serializer for argument (Can't serialize the " - "first template type argument of CanSerializeCheck " - "from the second)"); - }; - - // Helper predicate for testing for the presence of SerializeTraits - // deserializers. - template <typename WireT, typename ConcreteT> - class CanDeserializeCheck - : detail::CanDeserialize<ChannelT, WireT, ConcreteT> { - public: - using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value; - - static_assert(value, "Missing deserializer for argument (Can't deserialize " - "the second template type argument of " - "CanDeserializeCheck from the first)"); - }; - -public: - /// Construct an RPC instance on a channel. - RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation) - : C(C), LazyAutoNegotiation(LazyAutoNegotiation) { - // Hold ResponseId in a special variable, since we expect Response to be - // called relatively frequently, and want to avoid the map lookup. - ResponseId = FnIdAllocator.getResponseId(); - RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId; - - // Register the negotiate function id and handler. - auto NegotiateId = FnIdAllocator.getNegotiateId(); - RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId; - Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>( - [this](const std::string &Name) { return handleNegotiate(Name); }); - } - - /// Negotiate a function id for Func with the other end of the channel. - template <typename Func> Error negotiateFunction(bool Retry = false) { - return getRemoteFunctionId<Func>(true, Retry).takeError(); - } - - /// Append a call Func, does not call send on the channel. - /// The first argument specifies a user-defined handler to be run when the - /// function returns. The handler should take an Expected<Func::ReturnType>, - /// or an Error (if Func::ReturnType is void). The handler will be called - /// with an error if the return value is abandoned due to a channel error. - template <typename Func, typename HandlerT, typename... ArgTs> - Error appendCallAsync(HandlerT Handler, const ArgTs &...Args) { - - static_assert( - detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type, - void(ArgTs...)>::value, - ""); - - // Look up the function ID. - FunctionIdT FnId; - if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false)) - FnId = *FnIdOrErr; - else { - // Negotiation failed. Notify the handler then return the negotiate-failed - // error. - cantFail(Handler(make_error<ResponseAbandoned>())); - return FnIdOrErr.takeError(); - } - - SequenceNumberT SeqNo; // initialized in locked scope below. - { - // Lock the pending responses map and sequence number manager. - std::lock_guard<std::mutex> Lock(ResponsesMutex); - - // Allocate a sequence number. - SeqNo = SequenceNumberMgr.getSequenceNumber(); - assert(!PendingResponses.count(SeqNo) && - "Sequence number already allocated"); - - // Install the user handler. - PendingResponses[SeqNo] = - detail::createResponseHandler<ChannelT, typename Func::ReturnType>( - std::move(Handler)); - } - - // Open the function call message. - if (auto Err = C.startSendMessage(FnId, SeqNo)) { - abandonPendingResponses(); - return Err; - } - - // Serialize the call arguments. - if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs( - C, Args...)) { - abandonPendingResponses(); - return Err; - } - - // Close the function call messagee. - if (auto Err = C.endSendMessage()) { - abandonPendingResponses(); - return Err; - } - - return Error::success(); - } - - Error sendAppendedCalls() { return C.send(); }; - - template <typename Func, typename HandlerT, typename... ArgTs> - Error callAsync(HandlerT Handler, const ArgTs &...Args) { - if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...)) - return Err; - return C.send(); - } - - /// Handle one incoming call. - Error handleOne() { - FunctionIdT FnId; - SequenceNumberT SeqNo; - if (auto Err = C.startReceiveMessage(FnId, SeqNo)) { - abandonPendingResponses(); - return Err; - } - if (FnId == ResponseId) - return handleResponse(SeqNo); - auto I = Handlers.find(FnId); - if (I != Handlers.end()) - return I->second(C, SeqNo); - - // else: No handler found. Report error to client? - return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId, - SeqNo); - } - - /// Helper for handling setter procedures - this method returns a functor that - /// sets the variables referred to by Args... to values deserialized from the - /// channel. - /// E.g. - /// - /// typedef Function<0, bool, int> Func1; - /// - /// ... - /// bool B; - /// int I; - /// if (auto Err = expect<Func1>(Channel, readArgs(B, I))) - /// /* Handle Args */ ; - /// - template <typename... ArgTs> - static detail::ReadArgs<ArgTs...> readArgs(ArgTs &...Args) { - return detail::ReadArgs<ArgTs...>(Args...); - } - - /// Abandon all outstanding result handlers. - /// - /// This will call all currently registered result handlers to receive an - /// "abandoned" error as their argument. This is used internally by the RPC - /// in error situations, but can also be called directly by clients who are - /// disconnecting from the remote and don't or can't expect responses to their - /// outstanding calls. (Especially for outstanding blocking calls, calling - /// this function may be necessary to avoid dead threads). - void abandonPendingResponses() { - // Lock the pending responses map and sequence number manager. - std::lock_guard<std::mutex> Lock(ResponsesMutex); - - for (auto &KV : PendingResponses) - KV.second->abandon(); - PendingResponses.clear(); - SequenceNumberMgr.reset(); - } - - /// Remove the handler for the given function. - /// A handler must currently be registered for this function. - template <typename Func> void removeHandler() { - auto IdItr = LocalFunctionIds.find(Func::getPrototype()); - assert(IdItr != LocalFunctionIds.end() && - "Function does not have a registered handler"); - auto HandlerItr = Handlers.find(IdItr->second); - assert(HandlerItr != Handlers.end() && - "Function does not have a registered handler"); - Handlers.erase(HandlerItr); - } - - /// Clear all handlers. - void clearHandlers() { Handlers.clear(); } - -protected: - FunctionIdT getInvalidFunctionId() const { - return FnIdAllocator.getInvalidId(); - } - - /// Add the given handler to the handler map and make it available for - /// autonegotiation and execution. - template <typename Func, typename HandlerT> - void addHandlerImpl(HandlerT Handler) { - - static_assert(detail::RPCArgTypeCheck< - CanDeserializeCheck, typename Func::Type, - typename detail::HandlerTraits<HandlerT>::Type>::value, - ""); - - FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); - LocalFunctionIds[Func::getPrototype()] = NewFnId; - Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler)); - } - - template <typename Func, typename HandlerT> - void addAsyncHandlerImpl(HandlerT Handler) { - - static_assert( - detail::RPCArgTypeCheck< - CanDeserializeCheck, typename Func::Type, - typename detail::AsyncHandlerTraits< - typename detail::HandlerTraits<HandlerT>::Type>::Type>::value, - ""); - - FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); - LocalFunctionIds[Func::getPrototype()] = NewFnId; - Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler)); - } - - Error handleResponse(SequenceNumberT SeqNo) { - using Handler = typename decltype(PendingResponses)::mapped_type; - Handler PRHandler; - - { - // Lock the pending responses map and sequence number manager. - std::unique_lock<std::mutex> Lock(ResponsesMutex); - auto I = PendingResponses.find(SeqNo); - - if (I != PendingResponses.end()) { - PRHandler = std::move(I->second); - PendingResponses.erase(I); - SequenceNumberMgr.releaseSequenceNumber(SeqNo); - } else { - // Unlock the pending results map to prevent recursive lock. - Lock.unlock(); - abandonPendingResponses(); - return make_error<InvalidSequenceNumberForResponse<SequenceNumberT>>( - SeqNo); - } - } - - assert(PRHandler && - "If we didn't find a response handler we should have bailed out"); - - if (auto Err = PRHandler->handleResponse(C)) { - abandonPendingResponses(); - return Err; - } - - return Error::success(); - } - - FunctionIdT handleNegotiate(const std::string &Name) { - auto I = LocalFunctionIds.find(Name); - if (I == LocalFunctionIds.end()) - return getInvalidFunctionId(); - return I->second; - } - - // Find the remote FunctionId for the given function. - template <typename Func> - Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap, - bool NegotiateIfInvalid) { - bool DoNegotiate; - - // Check if we already have a function id... - auto I = RemoteFunctionIds.find(Func::getPrototype()); - if (I != RemoteFunctionIds.end()) { - // If it's valid there's nothing left to do. - if (I->second != getInvalidFunctionId()) - return I->second; - DoNegotiate = NegotiateIfInvalid; - } else - DoNegotiate = NegotiateIfNotInMap; - - // We don't have a function id for Func yet, but we're allowed to try to - // negotiate one. - if (DoNegotiate) { - auto &Impl = static_cast<ImplT &>(*this); - if (auto RemoteIdOrErr = - Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { - RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; - if (*RemoteIdOrErr == getInvalidFunctionId()) - return make_error<CouldNotNegotiate>(Func::getPrototype()); - return *RemoteIdOrErr; - } else - return RemoteIdOrErr.takeError(); - } - - // No key was available in the map and we weren't allowed to try to - // negotiate one, so return an unknown function error. - return make_error<CouldNotNegotiate>(Func::getPrototype()); - } - - using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>; - - // Wrap the given user handler in the necessary argument-deserialization code, - // result-serialization code, and call to the launch policy (if present). - template <typename Func, typename HandlerT> - WrappedHandlerFn wrapHandler(HandlerT Handler) { - return [this, Handler](ChannelT &Channel, - SequenceNumberT SeqNo) mutable -> Error { - // Start by deserializing the arguments. - using ArgsTuple = typename detail::RPCFunctionArgsTuple< - typename detail::HandlerTraits<HandlerT>::Type>::Type; - auto Args = std::make_shared<ArgsTuple>(); - - if (auto Err = - detail::HandlerTraits<typename Func::Type>::deserializeArgs( - Channel, *Args)) - return Err; - - // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning - // for RPCArgs. Void cast RPCArgs to work around this for now. - // FIXME: Remove this workaround once we can assume a working GCC version. - (void)Args; - - // End receieve message, unlocking the channel for reading. - if (auto Err = Channel.endReceiveMessage()) - return Err; - - using HTraits = detail::HandlerTraits<HandlerT>; - using FuncReturn = typename Func::ReturnType; - return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo, - HTraits::unpackAndRun(Handler, *Args)); - }; - } - - // Wrap the given user handler in the necessary argument-deserialization code, - // result-serialization code, and call to the launch policy (if present). - template <typename Func, typename HandlerT> - WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) { - return [this, Handler](ChannelT &Channel, - SequenceNumberT SeqNo) mutable -> Error { - // Start by deserializing the arguments. - using AHTraits = detail::AsyncHandlerTraits< - typename detail::HandlerTraits<HandlerT>::Type>; - using ArgsTuple = - typename detail::RPCFunctionArgsTuple<typename AHTraits::Type>::Type; - auto Args = std::make_shared<ArgsTuple>(); - - if (auto Err = - detail::HandlerTraits<typename Func::Type>::deserializeArgs( - Channel, *Args)) - return Err; - - // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning - // for RPCArgs. Void cast RPCArgs to work around this for now. - // FIXME: Remove this workaround once we can assume a working GCC version. - (void)Args; - - // End receieve message, unlocking the channel for reading. - if (auto Err = Channel.endReceiveMessage()) - return Err; - - using HTraits = detail::HandlerTraits<HandlerT>; - using FuncReturn = typename Func::ReturnType; - auto Responder = [this, - SeqNo](typename AHTraits::ResultType RetVal) -> Error { - return detail::respond<FuncReturn>(C, ResponseId, SeqNo, - std::move(RetVal)); - }; - - return HTraits::unpackAndRunAsync(Handler, Responder, *Args); - }; - } - - ChannelT &C; - - bool LazyAutoNegotiation; - - RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator; - - FunctionIdT ResponseId; - std::map<std::string, FunctionIdT> LocalFunctionIds; - std::map<const char *, FunctionIdT> RemoteFunctionIds; - - std::map<FunctionIdT, WrappedHandlerFn> Handlers; - - std::mutex ResponsesMutex; - detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr; - std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>> - PendingResponses; -}; - -} // end namespace detail - -template <typename ChannelT, typename FunctionIdT = uint32_t, - typename SequenceNumberT = uint32_t> -class MultiThreadedRPCEndpoint - : public detail::RPCEndpointBase< - MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, - ChannelT, FunctionIdT, SequenceNumberT> { -private: - using BaseClass = detail::RPCEndpointBase< - MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, - ChannelT, FunctionIdT, SequenceNumberT>; - -public: - MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) - : BaseClass(C, LazyAutoNegotiation) {} - - /// Add a handler for the given RPC function. - /// This installs the given handler functor for the given RPCFunction, and - /// makes the RPC function available for negotiation/calling from the remote. - template <typename Func, typename HandlerT> - void addHandler(HandlerT Handler) { - return this->template addHandlerImpl<Func>(std::move(Handler)); - } - - /// Add a class-method as a handler. - template <typename Func, typename ClassT, typename RetT, typename... ArgTs> - void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { - addHandler<Func>( - detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); - } - - template <typename Func, typename HandlerT> - void addAsyncHandler(HandlerT Handler) { - return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); - } - - /// Add a class-method as a handler. - template <typename Func, typename ClassT, typename RetT, typename... ArgTs> - void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { - addAsyncHandler<Func>( - detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); - } - - /// Return type for non-blocking call primitives. - template <typename Func> - using NonBlockingCallResult = typename detail::ResultTraits< - typename Func::ReturnType>::ReturnFutureType; - - /// Call Func on Channel C. Does not block, does not call send. Returns a pair - /// of a future result and the sequence number assigned to the result. - /// - /// This utility function is primarily used for single-threaded mode support, - /// where the sequence number can be used to wait for the corresponding - /// result. In multi-threaded mode the appendCallNB method, which does not - /// return the sequence numeber, should be preferred. - template <typename Func, typename... ArgTs> - Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &...Args) { - using RTraits = detail::ResultTraits<typename Func::ReturnType>; - using ErrorReturn = typename RTraits::ErrorReturnType; - using ErrorReturnPromise = typename RTraits::ReturnPromiseType; - - ErrorReturnPromise Promise; - auto FutureResult = Promise.get_future(); - - if (auto Err = this->template appendCallAsync<Func>( - [Promise = std::move(Promise)](ErrorReturn RetOrErr) mutable { - Promise.set_value(std::move(RetOrErr)); - return Error::success(); - }, - Args...)) { - RTraits::consumeAbandoned(FutureResult.get()); - return std::move(Err); - } - return std::move(FutureResult); - } - - /// The same as appendCallNBWithSeq, except that it calls C.send() to - /// flush the channel after serializing the call. - template <typename Func, typename... ArgTs> - Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &...Args) { - auto Result = appendCallNB<Func>(Args...); - if (!Result) - return Result; - if (auto Err = this->C.send()) { - this->abandonPendingResponses(); - detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( - std::move(Result->get())); - return std::move(Err); - } - return Result; - } - - /// Call Func on Channel C. Blocks waiting for a result. Returns an Error - /// for void functions or an Expected<T> for functions returning a T. - /// - /// This function is for use in threaded code where another thread is - /// handling responses and incoming calls. - template <typename Func, typename... ArgTs, - typename AltRetT = typename Func::ReturnType> - typename detail::ResultTraits<AltRetT>::ErrorReturnType - callB(const ArgTs &...Args) { - if (auto FutureResOrErr = callNB<Func>(Args...)) - return FutureResOrErr->get(); - else - return FutureResOrErr.takeError(); - } - - /// Handle incoming RPC calls. - Error handlerLoop() { - while (true) - if (auto Err = this->handleOne()) - return Err; - return Error::success(); - } -}; - -template <typename ChannelT, typename FunctionIdT = uint32_t, - typename SequenceNumberT = uint32_t> -class SingleThreadedRPCEndpoint - : public detail::RPCEndpointBase< - SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, - ChannelT, FunctionIdT, SequenceNumberT> { -private: - using BaseClass = detail::RPCEndpointBase< - SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, - ChannelT, FunctionIdT, SequenceNumberT>; - -public: - SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) - : BaseClass(C, LazyAutoNegotiation) {} - - template <typename Func, typename HandlerT> - void addHandler(HandlerT Handler) { - return this->template addHandlerImpl<Func>(std::move(Handler)); - } - - template <typename Func, typename ClassT, typename RetT, typename... ArgTs> - void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { - addHandler<Func>( - detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); - } - - template <typename Func, typename HandlerT> - void addAsyncHandler(HandlerT Handler) { - return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); - } - - /// Add a class-method as a handler. - template <typename Func, typename ClassT, typename RetT, typename... ArgTs> - void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { - addAsyncHandler<Func>( - detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); - } - - template <typename Func, typename... ArgTs, - typename AltRetT = typename Func::ReturnType> - typename detail::ResultTraits<AltRetT>::ErrorReturnType - callB(const ArgTs &...Args) { - bool ReceivedResponse = false; - using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType; - auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue(); - - // We have to 'Check' result (which we know is in a success state at this - // point) so that it can be overwritten in the async handler. - (void)!!Result; - - if (auto Err = this->template appendCallAsync<Func>( - [&](ResultType R) { - Result = std::move(R); - ReceivedResponse = true; - return Error::success(); - }, - Args...)) { - detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( - std::move(Result)); - return std::move(Err); - } - - if (auto Err = this->C.send()) { - detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( - std::move(Result)); - return std::move(Err); - } - - while (!ReceivedResponse) { - if (auto Err = this->handleOne()) { - detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( - std::move(Result)); - return std::move(Err); - } - } - - return Result; - } -}; - -/// Asynchronous dispatch for a function on an RPC endpoint. -template <typename RPCClass, typename Func> class RPCAsyncDispatch { -public: - RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {} - - template <typename HandlerT, typename... ArgTs> - Error operator()(HandlerT Handler, const ArgTs &...Args) const { - return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...); - } - -private: - RPCClass &Endpoint; -}; - -/// Construct an asynchronous dispatcher from an RPC endpoint and a Func. -template <typename Func, typename RPCEndpointT> -RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) { - return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint); -} - -/// Allows a set of asynchrounous calls to be dispatched, and then -/// waited on as a group. -class ParallelCallGroup { -public: - ParallelCallGroup() = default; - ParallelCallGroup(const ParallelCallGroup &) = delete; - ParallelCallGroup &operator=(const ParallelCallGroup &) = delete; - - /// Make as asynchronous call. - template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs> - Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler, - const ArgTs &...Args) { - // Increment the count of outstanding calls. This has to happen before - // we invoke the call, as the handler may (depending on scheduling) - // be run immediately on another thread, and we don't want the decrement - // in the wrapped handler below to run before the increment. - { - std::unique_lock<std::mutex> Lock(M); - ++NumOutstandingCalls; - } - - // Wrap the user handler in a lambda that will decrement the - // outstanding calls count, then poke the condition variable. - using ArgType = typename detail::ResponseHandlerArg< - typename detail::HandlerTraits<HandlerT>::Type>::ArgType; - auto WrappedHandler = [this, Handler = std::move(Handler)](ArgType Arg) { - auto Err = Handler(std::move(Arg)); - std::unique_lock<std::mutex> Lock(M); - --NumOutstandingCalls; - CV.notify_all(); - return Err; - }; - - return AsyncDispatch(std::move(WrappedHandler), Args...); - } - - /// Blocks until all calls have been completed and their return value - /// handlers run. - void wait() { - std::unique_lock<std::mutex> Lock(M); - while (NumOutstandingCalls > 0) - CV.wait(Lock); - } - -private: - std::mutex M; - std::condition_variable CV; - uint32_t NumOutstandingCalls = 0; -}; - -/// Convenience class for grouping RPCFunctions into APIs that can be -/// negotiated as a block. -/// -template <typename... Funcs> class APICalls { -public: - /// Test whether this API contains Function F. - template <typename F> class Contains { - public: - static const bool value = false; - }; - - /// Negotiate all functions in this API. - template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) { - return Error::success(); - } -}; - -template <typename Func, typename... Funcs> class APICalls<Func, Funcs...> { -public: - template <typename F> class Contains { - public: - static const bool value = std::is_same<F, Func>::value | - APICalls<Funcs...>::template Contains<F>::value; - }; - - template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) { - if (auto Err = R.template negotiateFunction<Func>()) - return Err; - return APICalls<Funcs...>::negotiate(R); - } -}; - -template <typename... InnerFuncs, typename... Funcs> -class APICalls<APICalls<InnerFuncs...>, Funcs...> { -public: - template <typename F> class Contains { - public: - static const bool value = - APICalls<InnerFuncs...>::template Contains<F>::value | - APICalls<Funcs...>::template Contains<F>::value; - }; - - template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) { - if (auto Err = APICalls<InnerFuncs...>::negotiate(R)) - return Err; - return APICalls<Funcs...>::negotiate(R); - } -}; - -} // end namespace shared -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utilities to support construction of simple RPC APIs. +// +// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++ +// programmers, high performance, low memory overhead, and efficient use of the +// communications channel. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H +#define LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H + +#include <map> +#include <thread> +#include <vector> + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" +#include "llvm/ExecutionEngine/Orc/Shared/Serialization.h" +#include "llvm/Support/MSVCErrorWorkarounds.h" + +#include <future> + +namespace llvm { +namespace orc { +namespace shared { + +/// Base class of all fatal RPC errors (those that necessarily result in the +/// termination of the RPC session). +class RPCFatalError : public ErrorInfo<RPCFatalError> { +public: + static char ID; +}; + +/// RPCConnectionClosed is returned from RPC operations if the RPC connection +/// has already been closed due to either an error or graceful disconnection. +class ConnectionClosed : public ErrorInfo<ConnectionClosed> { +public: + static char ID; + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; +}; + +/// BadFunctionCall is returned from handleOne when the remote makes a call with +/// an unrecognized function id. +/// +/// This error is fatal because Orc RPC needs to know how to parse a function +/// call to know where the next call starts, and if it doesn't recognize the +/// function id it cannot parse the call. +template <typename FnIdT, typename SeqNoT> +class BadFunctionCall + : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> { +public: + static char ID; + + BadFunctionCall(FnIdT FnId, SeqNoT SeqNo) + : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {} + + std::error_code convertToErrorCode() const override { + return orcError(OrcErrorCode::UnexpectedRPCCall); + } + + void log(raw_ostream &OS) const override { + OS << "Call to invalid RPC function id '" << FnId + << "' with " + "sequence number " + << SeqNo; + } + +private: + FnIdT FnId; + SeqNoT SeqNo; +}; + +template <typename FnIdT, typename SeqNoT> +char BadFunctionCall<FnIdT, SeqNoT>::ID = 0; + +/// InvalidSequenceNumberForResponse is returned from handleOne when a response +/// call arrives with a sequence number that doesn't correspond to any in-flight +/// function call. +/// +/// This error is fatal because Orc RPC needs to know how to parse the rest of +/// the response call to know where the next call starts, and if it doesn't have +/// a result parser for this sequence number it can't do that. +template <typename SeqNoT> +class InvalidSequenceNumberForResponse + : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, + RPCFatalError> { +public: + static char ID; + + InvalidSequenceNumberForResponse(SeqNoT SeqNo) : SeqNo(std::move(SeqNo)) {} + + std::error_code convertToErrorCode() const override { + return orcError(OrcErrorCode::UnexpectedRPCCall); + }; + + void log(raw_ostream &OS) const override { + OS << "Response has unknown sequence number " << SeqNo; + } + +private: + SeqNoT SeqNo; +}; + +template <typename SeqNoT> +char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0; + +/// This non-fatal error will be passed to asynchronous result handlers in place +/// of a result if the connection goes down before a result returns, or if the +/// function to be called cannot be negotiated with the remote. +class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> { +public: + static char ID; + + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; +}; + +/// This error is returned if the remote does not have a handler installed for +/// the given RPC function. +class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> { +public: + static char ID; + + CouldNotNegotiate(std::string Signature); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + const std::string &getSignature() const { return Signature; } + +private: + std::string Signature; +}; + +template <typename DerivedFunc, typename FnT> class RPCFunction; + +// RPC Function class. +// DerivedFunc should be a user defined class with a static 'getName()' method +// returning a const char* representing the function's name. +template <typename DerivedFunc, typename RetT, typename... ArgTs> +class RPCFunction<DerivedFunc, RetT(ArgTs...)> { +public: + /// User defined function type. + using Type = RetT(ArgTs...); + + /// Return type. + using ReturnType = RetT; + + /// Returns the full function prototype as a string. + static const char *getPrototype() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << SerializationTypeName<RetT>::getName() << " " + << DerivedFunc::getName() << "(" + << SerializationTypeNameSequence<ArgTs...>() << ")"; + return Name; + }(); + return Name.data(); + } +}; + +/// Allocates RPC function ids during autonegotiation. +/// Specializations of this class must provide four members: +/// +/// static T getInvalidId(): +/// Should return a reserved id that will be used to represent missing +/// functions during autonegotiation. +/// +/// static T getResponseId(): +/// Should return a reserved id that will be used to send function responses +/// (return values). +/// +/// static T getNegotiateId(): +/// Should return a reserved id for the negotiate function, which will be used +/// to negotiate ids for user defined functions. +/// +/// template <typename Func> T allocate(): +/// Allocate a unique id for function Func. +template <typename T, typename = void> class RPCFunctionIdAllocator; + +/// This specialization of RPCFunctionIdAllocator provides a default +/// implementation for integral types. +template <typename T> +class RPCFunctionIdAllocator<T, std::enable_if_t<std::is_integral<T>::value>> { +public: + static T getInvalidId() { return T(0); } + static T getResponseId() { return T(1); } + static T getNegotiateId() { return T(2); } + + template <typename Func> T allocate() { return NextId++; } + +private: + T NextId = 3; +}; + +namespace detail { + +/// Provides a typedef for a tuple containing the decayed argument types. +template <typename T> class RPCFunctionArgsTuple; + +template <typename RetT, typename... ArgTs> +class RPCFunctionArgsTuple<RetT(ArgTs...)> { +public: + using Type = std::tuple<std::decay_t<std::remove_reference_t<ArgTs>>...>; +}; + +// ResultTraits provides typedefs and utilities specific to the return type +// of functions. +template <typename RetT> class ResultTraits { +public: + // The return type wrapped in llvm::Expected. + using ErrorReturnType = Expected<RetT>; + +#ifdef _MSC_VER + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<MSVCPExpected<RetT>>; +#else + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<ErrorReturnType>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<ErrorReturnType>; +#endif + + // Create a 'blank' value of the ErrorReturnType, ready and safe to + // overwrite. + static ErrorReturnType createBlankErrorReturnValue() { + return ErrorReturnType(RetT()); + } + + // Consume an abandoned ErrorReturnType. + static void consumeAbandoned(ErrorReturnType RetOrErr) { + consumeError(RetOrErr.takeError()); + } +}; + +// ResultTraits specialization for void functions. +template <> class ResultTraits<void> { +public: + // For void functions, ErrorReturnType is llvm::Error. + using ErrorReturnType = Error; + +#ifdef _MSC_VER + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<MSVCPError>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<MSVCPError>; +#else + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<ErrorReturnType>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<ErrorReturnType>; +#endif + + // Create a 'blank' value of the ErrorReturnType, ready and safe to + // overwrite. + static ErrorReturnType createBlankErrorReturnValue() { + return ErrorReturnType::success(); + } + + // Consume an abandoned ErrorReturnType. + static void consumeAbandoned(ErrorReturnType Err) { + consumeError(std::move(Err)); + } +}; + +// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows +// handlers for void RPC functions to return either void (in which case they +// implicitly succeed) or Error (in which case their error return is +// propagated). See usage in HandlerTraits::runHandlerHelper. +template <> class ResultTraits<Error> : public ResultTraits<void> {}; + +// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows +// handlers for RPC functions returning a T to return either a T (in which +// case they implicitly succeed) or Expected<T> (in which case their error +// return is propagated). See usage in HandlerTraits::runHandlerHelper. +template <typename RetT> +class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {}; + +// Determines whether an RPC function's defined error return type supports +// error return value. +template <typename T> class SupportsErrorReturn { +public: + static const bool value = false; +}; + +template <> class SupportsErrorReturn<Error> { +public: + static const bool value = true; +}; + +template <typename T> class SupportsErrorReturn<Expected<T>> { +public: + static const bool value = true; +}; + +// RespondHelper packages return values based on whether or not the declared +// RPC function return type supports error returns. +template <bool FuncSupportsErrorReturn> class RespondHelper; + +// RespondHelper specialization for functions that support error returns. +template <> class RespondHelper<true> { +public: + // Send Expected<T>. + template <typename WireRetT, typename HandlerRetT, typename ChannelT, + typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, + Expected<HandlerRetT> ResultOrErr) { + if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>()) + return ResultOrErr.takeError(); + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = + SerializationTraits<ChannelT, WireRetT, Expected<HandlerRetT>>:: + serialize(C, std::move(ResultOrErr))) + return Err; + + // Close the response message. + if (auto Err = C.endSendMessage()) + return Err; + return C.send(); + } + + template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err && Err.isA<RPCFatalError>()) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + if (auto Err2 = serializeSeq(C, std::move(Err))) + return Err2; + if (auto Err2 = C.endSendMessage()) + return Err2; + return C.send(); + } +}; + +// RespondHelper specialization for functions that do not support error returns. +template <> class RespondHelper<false> { +public: + template <typename WireRetT, typename HandlerRetT, typename ChannelT, + typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, + Expected<HandlerRetT> ResultOrErr) { + if (auto Err = ResultOrErr.takeError()) + return Err; + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = + SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize( + C, *ResultOrErr)) + return Err; + + // End the response message. + if (auto Err = C.endSendMessage()) + return Err; + + return C.send(); + } + + template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + if (auto Err2 = C.endSendMessage()) + return Err2; + return C.send(); + } +}; + +// Send a response of the given wire return type (WireRetT) over the +// channel, with the given sequence number. +template <typename WireRetT, typename HandlerRetT, typename ChannelT, + typename FunctionIdT, typename SequenceNumberT> +Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, + Expected<HandlerRetT> ResultOrErr) { + return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: + template sendResult<WireRetT>(C, ResponseId, SeqNo, + std::move(ResultOrErr)); +} + +// Send an empty response message on the given channel to indicate that +// the handler ran. +template <typename WireRetT, typename ChannelT, typename FunctionIdT, + typename SequenceNumberT> +Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, + Error Err) { + return RespondHelper<SupportsErrorReturn<WireRetT>::value>::sendResult( + C, ResponseId, SeqNo, std::move(Err)); +} + +// Converts a given type to the equivalent error return type. +template <typename T> class WrappedHandlerReturn { +public: + using Type = Expected<T>; +}; + +template <typename T> class WrappedHandlerReturn<Expected<T>> { +public: + using Type = Expected<T>; +}; + +template <> class WrappedHandlerReturn<void> { +public: + using Type = Error; +}; + +template <> class WrappedHandlerReturn<Error> { +public: + using Type = Error; +}; + +template <> class WrappedHandlerReturn<ErrorSuccess> { +public: + using Type = Error; +}; + +// Traits class that strips the response function from the list of handler +// arguments. +template <typename FnT> class AsyncHandlerTraits; + +template <typename ResultT, typename... ArgTs> +class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, + ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Expected<ResultT>; +}; + +template <typename... ArgTs> +class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Error; +}; + +template <typename... ArgTs> +class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Error; +}; + +template <typename... ArgTs> +class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Error; +}; + +template <typename ResponseHandlerT, typename... ArgTs> +class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> + : public AsyncHandlerTraits<Error(std::decay_t<ResponseHandlerT>, + ArgTs...)> {}; + +// This template class provides utilities related to RPC function handlers. +// The base case applies to non-function types (the template class is +// specialized for function types) and inherits from the appropriate +// speciilization for the given non-function type's call operator. +template <typename HandlerT> +class HandlerTraits + : public HandlerTraits< + decltype(&std::remove_reference<HandlerT>::type::operator())> {}; + +// Traits for handlers with a given function type. +template <typename RetT, typename... ArgTs> +class HandlerTraits<RetT(ArgTs...)> { +public: + // Function type of the handler. + using Type = RetT(ArgTs...); + + // Return type of the handler. + using ReturnType = RetT; + + // Call the given handler with the given arguments. + template <typename HandlerT, typename... TArgTs> + static typename WrappedHandlerReturn<RetT>::Type + unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) { + return unpackAndRunHelper(Handler, Args, + std::index_sequence_for<TArgTs...>()); + } + + // Call the given handler with the given arguments. + template <typename HandlerT, typename ResponderT, typename... TArgTs> + static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder, + std::tuple<TArgTs...> &Args) { + return unpackAndRunAsyncHelper(Handler, Responder, Args, + std::index_sequence_for<TArgTs...>()); + } + + // Call the given handler with the given arguments. + template <typename HandlerT> + static std::enable_if_t< + std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, Error> + run(HandlerT &Handler, ArgTs &&...Args) { + Handler(std::move(Args)...); + return Error::success(); + } + + template <typename HandlerT, typename... TArgTs> + static std::enable_if_t< + !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, + typename HandlerTraits<HandlerT>::ReturnType> + run(HandlerT &Handler, TArgTs... Args) { + return Handler(std::move(Args)...); + } + + // Serialize arguments to the channel. + template <typename ChannelT, typename... CArgTs> + static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) { + return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...); + } + + // Deserialize arguments from the channel. + template <typename ChannelT, typename... CArgTs> + static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) { + return deserializeArgsHelper(C, Args, std::index_sequence_for<CArgTs...>()); + } + +private: + template <typename ChannelT, typename... CArgTs, size_t... Indexes> + static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args, + std::index_sequence<Indexes...> _) { + return SequenceSerialization<ChannelT, ArgTs...>::deserialize( + C, std::get<Indexes>(Args)...); + } + + template <typename HandlerT, typename ArgTuple, size_t... Indexes> + static typename WrappedHandlerReturn< + typename HandlerTraits<HandlerT>::ReturnType>::Type + unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args, + std::index_sequence<Indexes...>) { + return run(Handler, std::move(std::get<Indexes>(Args))...); + } + + template <typename HandlerT, typename ResponderT, typename ArgTuple, + size_t... Indexes> + static typename WrappedHandlerReturn< + typename HandlerTraits<HandlerT>::ReturnType>::Type + unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder, + ArgTuple &Args, std::index_sequence<Indexes...>) { + return run(Handler, Responder, std::move(std::get<Indexes>(Args))...); + } +}; + +// Handler traits for free functions. +template <typename RetT, typename... ArgTs> +class HandlerTraits<RetT (*)(ArgTs...)> : public HandlerTraits<RetT(ArgTs...)> { +}; + +// Handler traits for class methods (especially call operators for lambdas). +template <typename Class, typename RetT, typename... ArgTs> +class HandlerTraits<RetT (Class::*)(ArgTs...)> + : public HandlerTraits<RetT(ArgTs...)> {}; + +// Handler traits for const class methods (especially call operators for +// lambdas). +template <typename Class, typename RetT, typename... ArgTs> +class HandlerTraits<RetT (Class::*)(ArgTs...) const> + : public HandlerTraits<RetT(ArgTs...)> {}; + +// Utility to peel the Expected wrapper off a response handler error type. +template <typename HandlerT> class ResponseHandlerArg; + +template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> { +public: + using ArgType = Expected<ArgT>; + using UnwrappedArgType = ArgT; +}; + +template <typename ArgT> +class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> { +public: + using ArgType = Expected<ArgT>; + using UnwrappedArgType = ArgT; +}; + +template <> class ResponseHandlerArg<Error(Error)> { +public: + using ArgType = Error; +}; + +template <> class ResponseHandlerArg<ErrorSuccess(Error)> { +public: + using ArgType = Error; +}; + +// ResponseHandler represents a handler for a not-yet-received function call +// result. +template <typename ChannelT> class ResponseHandler { +public: + virtual ~ResponseHandler() {} + + // Reads the function result off the wire and acts on it. The meaning of + // "act" will depend on how this method is implemented in any given + // ResponseHandler subclass but could, for example, mean running a + // user-specified handler or setting a promise value. + virtual Error handleResponse(ChannelT &C) = 0; + + // Abandons this outstanding result. + virtual void abandon() = 0; + + // Create an error instance representing an abandoned response. + static Error createAbandonedResponseError() { + return make_error<ResponseAbandoned>(); + } +}; + +// ResponseHandler subclass for RPC functions with non-void returns. +template <typename ChannelT, typename FuncRetT, typename HandlerT> +class ResponseHandlerImpl : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + using UnwrappedArgType = typename ResponseHandlerArg< + typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType; + UnwrappedArgType Result; + if (auto Err = + SerializationTraits<ChannelT, FuncRetT, + UnwrappedArgType>::deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +// ResponseHandler subclass for RPC functions with void returns. +template <typename ChannelT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, void, HandlerT> + : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result (no actual value, just a notification that the function + // has completed on the remote end) by calling the user-defined handler with + // Error::success(). + Error handleResponse(ChannelT &C) override { + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(Error::success()); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +template <typename ChannelT, typename FuncRetT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT> + : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + using HandlerArgType = typename ResponseHandlerArg< + typename HandlerTraits<HandlerT>::Type>::ArgType; + HandlerArgType Result((typename HandlerArgType::value_type())); + + if (auto Err = SerializationTraits<ChannelT, Expected<FuncRetT>, + HandlerArgType>::deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +template <typename ChannelT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, Error, HandlerT> + : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + Error Result = Error::success(); + if (auto Err = SerializationTraits<ChannelT, Error, Error>::deserialize( + C, Result)) { + consumeError(std::move(Result)); + return Err; + } + if (auto Err = C.endReceiveMessage()) { + consumeError(std::move(Result)); + return Err; + } + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +// Create a ResponseHandler from a given user handler. +template <typename ChannelT, typename FuncRetT, typename HandlerT> +std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) { + return std::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>( + std::move(H)); +} + +// Helper for wrapping member functions up as functors. This is useful for +// installing methods as result handlers. +template <typename ClassT, typename RetT, typename... ArgTs> +class MemberFnWrapper { +public: + using MethodT = RetT (ClassT::*)(ArgTs...); + MemberFnWrapper(ClassT &Instance, MethodT Method) + : Instance(Instance), Method(Method) {} + RetT operator()(ArgTs &&...Args) { + return (Instance.*Method)(std::move(Args)...); + } + +private: + ClassT &Instance; + MethodT Method; +}; + +// Helper that provides a Functor for deserializing arguments. +template <typename... ArgTs> class ReadArgs { +public: + Error operator()() { return Error::success(); } +}; + +template <typename ArgT, typename... ArgTs> +class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> { +public: + ReadArgs(ArgT &Arg, ArgTs &...Args) : ReadArgs<ArgTs...>(Args...), Arg(Arg) {} + + Error operator()(ArgT &ArgVal, ArgTs &...ArgVals) { + this->Arg = std::move(ArgVal); + return ReadArgs<ArgTs...>::operator()(ArgVals...); + } + +private: + ArgT &Arg; +}; + +// Manage sequence numbers. +template <typename SequenceNumberT> class SequenceNumberManager { +public: + // Reset, making all sequence numbers available. + void reset() { + std::lock_guard<std::mutex> Lock(SeqNoLock); + NextSequenceNumber = 0; + FreeSequenceNumbers.clear(); + } + + // Get the next available sequence number. Will re-use numbers that have + // been released. + SequenceNumberT getSequenceNumber() { + std::lock_guard<std::mutex> Lock(SeqNoLock); + if (FreeSequenceNumbers.empty()) + return NextSequenceNumber++; + auto SequenceNumber = FreeSequenceNumbers.back(); + FreeSequenceNumbers.pop_back(); + return SequenceNumber; + } + + // Release a sequence number, making it available for re-use. + void releaseSequenceNumber(SequenceNumberT SequenceNumber) { + std::lock_guard<std::mutex> Lock(SeqNoLock); + FreeSequenceNumbers.push_back(SequenceNumber); + } + +private: + std::mutex SeqNoLock; + SequenceNumberT NextSequenceNumber = 0; + std::vector<SequenceNumberT> FreeSequenceNumbers; +}; + +// Checks that predicate P holds for each corresponding pair of type arguments +// from T1 and T2 tuple. +template <template <class, class> class P, typename T1Tuple, typename T2Tuple> +class RPCArgTypeCheckHelper; + +template <template <class, class> class P> +class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> { +public: + static const bool value = true; +}; + +template <template <class, class> class P, typename T, typename... Ts, + typename U, typename... Us> +class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> { +public: + static const bool value = + P<T, U>::value && + RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value; +}; + +template <template <class, class> class P, typename T1Sig, typename T2Sig> +class RPCArgTypeCheck { +public: + using T1Tuple = typename RPCFunctionArgsTuple<T1Sig>::Type; + using T2Tuple = typename RPCFunctionArgsTuple<T2Sig>::Type; + + static_assert(std::tuple_size<T1Tuple>::value >= + std::tuple_size<T2Tuple>::value, + "Too many arguments to RPC call"); + static_assert(std::tuple_size<T1Tuple>::value <= + std::tuple_size<T2Tuple>::value, + "Too few arguments to RPC call"); + + static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value; +}; + +template <typename ChannelT, typename WireT, typename ConcreteT> +class CanSerialize { +private: + using S = SerializationTraits<ChannelT, WireT, ConcreteT>; + + template <typename T> + static std::true_type check( + std::enable_if_t<std::is_same<decltype(T::serialize( + std::declval<ChannelT &>(), + std::declval<const ConcreteT &>())), + Error>::value, + void *>); + + template <typename> static std::false_type check(...); + +public: + static const bool value = decltype(check<S>(0))::value; +}; + +template <typename ChannelT, typename WireT, typename ConcreteT> +class CanDeserialize { +private: + using S = SerializationTraits<ChannelT, WireT, ConcreteT>; + + template <typename T> + static std::true_type + check(std::enable_if_t< + std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(), + std::declval<ConcreteT &>())), + Error>::value, + void *>); + + template <typename> static std::false_type check(...); + +public: + static const bool value = decltype(check<S>(0))::value; +}; + +/// Contains primitive utilities for defining, calling and handling calls to +/// remote procedures. ChannelT is a bidirectional stream conforming to the +/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure +/// identifier type that must be serializable on ChannelT, and SequenceNumberT +/// is an integral type that will be used to number in-flight function calls. +/// +/// These utilities support the construction of very primitive RPC utilities. +/// Their intent is to ensure correct serialization and deserialization of +/// procedure arguments, and to keep the client and server's view of the API in +/// sync. +template <typename ImplT, typename ChannelT, typename FunctionIdT, + typename SequenceNumberT> +class RPCEndpointBase { +protected: + class OrcRPCInvalid : public RPCFunction<OrcRPCInvalid, void()> { + public: + static const char *getName() { return "__orc_rpc$invalid"; } + }; + + class OrcRPCResponse : public RPCFunction<OrcRPCResponse, void()> { + public: + static const char *getName() { return "__orc_rpc$response"; } + }; + + class OrcRPCNegotiate + : public RPCFunction<OrcRPCNegotiate, FunctionIdT(std::string)> { + public: + static const char *getName() { return "__orc_rpc$negotiate"; } + }; + + // Helper predicate for testing for the presence of SerializeTraits + // serializers. + template <typename WireT, typename ConcreteT> + class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> { + public: + using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value; + + static_assert(value, "Missing serializer for argument (Can't serialize the " + "first template type argument of CanSerializeCheck " + "from the second)"); + }; + + // Helper predicate for testing for the presence of SerializeTraits + // deserializers. + template <typename WireT, typename ConcreteT> + class CanDeserializeCheck + : detail::CanDeserialize<ChannelT, WireT, ConcreteT> { + public: + using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value; + + static_assert(value, "Missing deserializer for argument (Can't deserialize " + "the second template type argument of " + "CanDeserializeCheck from the first)"); + }; + +public: + /// Construct an RPC instance on a channel. + RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation) + : C(C), LazyAutoNegotiation(LazyAutoNegotiation) { + // Hold ResponseId in a special variable, since we expect Response to be + // called relatively frequently, and want to avoid the map lookup. + ResponseId = FnIdAllocator.getResponseId(); + RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId; + + // Register the negotiate function id and handler. + auto NegotiateId = FnIdAllocator.getNegotiateId(); + RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId; + Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>( + [this](const std::string &Name) { return handleNegotiate(Name); }); + } + + /// Negotiate a function id for Func with the other end of the channel. + template <typename Func> Error negotiateFunction(bool Retry = false) { + return getRemoteFunctionId<Func>(true, Retry).takeError(); + } + + /// Append a call Func, does not call send on the channel. + /// The first argument specifies a user-defined handler to be run when the + /// function returns. The handler should take an Expected<Func::ReturnType>, + /// or an Error (if Func::ReturnType is void). The handler will be called + /// with an error if the return value is abandoned due to a channel error. + template <typename Func, typename HandlerT, typename... ArgTs> + Error appendCallAsync(HandlerT Handler, const ArgTs &...Args) { + + static_assert( + detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type, + void(ArgTs...)>::value, + ""); + + // Look up the function ID. + FunctionIdT FnId; + if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false)) + FnId = *FnIdOrErr; + else { + // Negotiation failed. Notify the handler then return the negotiate-failed + // error. + cantFail(Handler(make_error<ResponseAbandoned>())); + return FnIdOrErr.takeError(); + } + + SequenceNumberT SeqNo; // initialized in locked scope below. + { + // Lock the pending responses map and sequence number manager. + std::lock_guard<std::mutex> Lock(ResponsesMutex); + + // Allocate a sequence number. + SeqNo = SequenceNumberMgr.getSequenceNumber(); + assert(!PendingResponses.count(SeqNo) && + "Sequence number already allocated"); + + // Install the user handler. + PendingResponses[SeqNo] = + detail::createResponseHandler<ChannelT, typename Func::ReturnType>( + std::move(Handler)); + } + + // Open the function call message. + if (auto Err = C.startSendMessage(FnId, SeqNo)) { + abandonPendingResponses(); + return Err; + } + + // Serialize the call arguments. + if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs( + C, Args...)) { + abandonPendingResponses(); + return Err; + } + + // Close the function call messagee. + if (auto Err = C.endSendMessage()) { + abandonPendingResponses(); + return Err; + } + + return Error::success(); + } + + Error sendAppendedCalls() { return C.send(); }; + + template <typename Func, typename HandlerT, typename... ArgTs> + Error callAsync(HandlerT Handler, const ArgTs &...Args) { + if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...)) + return Err; + return C.send(); + } + + /// Handle one incoming call. + Error handleOne() { + FunctionIdT FnId; + SequenceNumberT SeqNo; + if (auto Err = C.startReceiveMessage(FnId, SeqNo)) { + abandonPendingResponses(); + return Err; + } + if (FnId == ResponseId) + return handleResponse(SeqNo); + auto I = Handlers.find(FnId); + if (I != Handlers.end()) + return I->second(C, SeqNo); + + // else: No handler found. Report error to client? + return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId, + SeqNo); + } + + /// Helper for handling setter procedures - this method returns a functor that + /// sets the variables referred to by Args... to values deserialized from the + /// channel. + /// E.g. + /// + /// typedef Function<0, bool, int> Func1; + /// + /// ... + /// bool B; + /// int I; + /// if (auto Err = expect<Func1>(Channel, readArgs(B, I))) + /// /* Handle Args */ ; + /// + template <typename... ArgTs> + static detail::ReadArgs<ArgTs...> readArgs(ArgTs &...Args) { + return detail::ReadArgs<ArgTs...>(Args...); + } + + /// Abandon all outstanding result handlers. + /// + /// This will call all currently registered result handlers to receive an + /// "abandoned" error as their argument. This is used internally by the RPC + /// in error situations, but can also be called directly by clients who are + /// disconnecting from the remote and don't or can't expect responses to their + /// outstanding calls. (Especially for outstanding blocking calls, calling + /// this function may be necessary to avoid dead threads). + void abandonPendingResponses() { + // Lock the pending responses map and sequence number manager. + std::lock_guard<std::mutex> Lock(ResponsesMutex); + + for (auto &KV : PendingResponses) + KV.second->abandon(); + PendingResponses.clear(); + SequenceNumberMgr.reset(); + } + + /// Remove the handler for the given function. + /// A handler must currently be registered for this function. + template <typename Func> void removeHandler() { + auto IdItr = LocalFunctionIds.find(Func::getPrototype()); + assert(IdItr != LocalFunctionIds.end() && + "Function does not have a registered handler"); + auto HandlerItr = Handlers.find(IdItr->second); + assert(HandlerItr != Handlers.end() && + "Function does not have a registered handler"); + Handlers.erase(HandlerItr); + } + + /// Clear all handlers. + void clearHandlers() { Handlers.clear(); } + +protected: + FunctionIdT getInvalidFunctionId() const { + return FnIdAllocator.getInvalidId(); + } + + /// Add the given handler to the handler map and make it available for + /// autonegotiation and execution. + template <typename Func, typename HandlerT> + void addHandlerImpl(HandlerT Handler) { + + static_assert(detail::RPCArgTypeCheck< + CanDeserializeCheck, typename Func::Type, + typename detail::HandlerTraits<HandlerT>::Type>::value, + ""); + + FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); + LocalFunctionIds[Func::getPrototype()] = NewFnId; + Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler)); + } + + template <typename Func, typename HandlerT> + void addAsyncHandlerImpl(HandlerT Handler) { + + static_assert( + detail::RPCArgTypeCheck< + CanDeserializeCheck, typename Func::Type, + typename detail::AsyncHandlerTraits< + typename detail::HandlerTraits<HandlerT>::Type>::Type>::value, + ""); + + FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); + LocalFunctionIds[Func::getPrototype()] = NewFnId; + Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler)); + } + + Error handleResponse(SequenceNumberT SeqNo) { + using Handler = typename decltype(PendingResponses)::mapped_type; + Handler PRHandler; + + { + // Lock the pending responses map and sequence number manager. + std::unique_lock<std::mutex> Lock(ResponsesMutex); + auto I = PendingResponses.find(SeqNo); + + if (I != PendingResponses.end()) { + PRHandler = std::move(I->second); + PendingResponses.erase(I); + SequenceNumberMgr.releaseSequenceNumber(SeqNo); + } else { + // Unlock the pending results map to prevent recursive lock. + Lock.unlock(); + abandonPendingResponses(); + return make_error<InvalidSequenceNumberForResponse<SequenceNumberT>>( + SeqNo); + } + } + + assert(PRHandler && + "If we didn't find a response handler we should have bailed out"); + + if (auto Err = PRHandler->handleResponse(C)) { + abandonPendingResponses(); + return Err; + } + + return Error::success(); + } + + FunctionIdT handleNegotiate(const std::string &Name) { + auto I = LocalFunctionIds.find(Name); + if (I == LocalFunctionIds.end()) + return getInvalidFunctionId(); + return I->second; + } + + // Find the remote FunctionId for the given function. + template <typename Func> + Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap, + bool NegotiateIfInvalid) { + bool DoNegotiate; + + // Check if we already have a function id... + auto I = RemoteFunctionIds.find(Func::getPrototype()); + if (I != RemoteFunctionIds.end()) { + // If it's valid there's nothing left to do. + if (I->second != getInvalidFunctionId()) + return I->second; + DoNegotiate = NegotiateIfInvalid; + } else + DoNegotiate = NegotiateIfNotInMap; + + // We don't have a function id for Func yet, but we're allowed to try to + // negotiate one. + if (DoNegotiate) { + auto &Impl = static_cast<ImplT &>(*this); + if (auto RemoteIdOrErr = + Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { + RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; + if (*RemoteIdOrErr == getInvalidFunctionId()) + return make_error<CouldNotNegotiate>(Func::getPrototype()); + return *RemoteIdOrErr; + } else + return RemoteIdOrErr.takeError(); + } + + // No key was available in the map and we weren't allowed to try to + // negotiate one, so return an unknown function error. + return make_error<CouldNotNegotiate>(Func::getPrototype()); + } + + using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>; + + // Wrap the given user handler in the necessary argument-deserialization code, + // result-serialization code, and call to the launch policy (if present). + template <typename Func, typename HandlerT> + WrappedHandlerFn wrapHandler(HandlerT Handler) { + return [this, Handler](ChannelT &Channel, + SequenceNumberT SeqNo) mutable -> Error { + // Start by deserializing the arguments. + using ArgsTuple = typename detail::RPCFunctionArgsTuple< + typename detail::HandlerTraits<HandlerT>::Type>::Type; + auto Args = std::make_shared<ArgsTuple>(); + + if (auto Err = + detail::HandlerTraits<typename Func::Type>::deserializeArgs( + Channel, *Args)) + return Err; + + // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning + // for RPCArgs. Void cast RPCArgs to work around this for now. + // FIXME: Remove this workaround once we can assume a working GCC version. + (void)Args; + + // End receieve message, unlocking the channel for reading. + if (auto Err = Channel.endReceiveMessage()) + return Err; + + using HTraits = detail::HandlerTraits<HandlerT>; + using FuncReturn = typename Func::ReturnType; + return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo, + HTraits::unpackAndRun(Handler, *Args)); + }; + } + + // Wrap the given user handler in the necessary argument-deserialization code, + // result-serialization code, and call to the launch policy (if present). + template <typename Func, typename HandlerT> + WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) { + return [this, Handler](ChannelT &Channel, + SequenceNumberT SeqNo) mutable -> Error { + // Start by deserializing the arguments. + using AHTraits = detail::AsyncHandlerTraits< + typename detail::HandlerTraits<HandlerT>::Type>; + using ArgsTuple = + typename detail::RPCFunctionArgsTuple<typename AHTraits::Type>::Type; + auto Args = std::make_shared<ArgsTuple>(); + + if (auto Err = + detail::HandlerTraits<typename Func::Type>::deserializeArgs( + Channel, *Args)) + return Err; + + // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning + // for RPCArgs. Void cast RPCArgs to work around this for now. + // FIXME: Remove this workaround once we can assume a working GCC version. + (void)Args; + + // End receieve message, unlocking the channel for reading. + if (auto Err = Channel.endReceiveMessage()) + return Err; + + using HTraits = detail::HandlerTraits<HandlerT>; + using FuncReturn = typename Func::ReturnType; + auto Responder = [this, + SeqNo](typename AHTraits::ResultType RetVal) -> Error { + return detail::respond<FuncReturn>(C, ResponseId, SeqNo, + std::move(RetVal)); + }; + + return HTraits::unpackAndRunAsync(Handler, Responder, *Args); + }; + } + + ChannelT &C; + + bool LazyAutoNegotiation; + + RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator; + + FunctionIdT ResponseId; + std::map<std::string, FunctionIdT> LocalFunctionIds; + std::map<const char *, FunctionIdT> RemoteFunctionIds; + + std::map<FunctionIdT, WrappedHandlerFn> Handlers; + + std::mutex ResponsesMutex; + detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr; + std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>> + PendingResponses; +}; + +} // end namespace detail + +template <typename ChannelT, typename FunctionIdT = uint32_t, + typename SequenceNumberT = uint32_t> +class MultiThreadedRPCEndpoint + : public detail::RPCEndpointBase< + MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT> { +private: + using BaseClass = detail::RPCEndpointBase< + MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT>; + +public: + MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) + : BaseClass(C, LazyAutoNegotiation) {} + + /// Add a handler for the given RPC function. + /// This installs the given handler functor for the given RPCFunction, and + /// makes the RPC function available for negotiation/calling from the remote. + template <typename Func, typename HandlerT> + void addHandler(HandlerT Handler) { + return this->template addHandlerImpl<Func>(std::move(Handler)); + } + + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); + } + + template <typename Func, typename HandlerT> + void addAsyncHandler(HandlerT Handler) { + return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); + } + + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addAsyncHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); + } + + /// Return type for non-blocking call primitives. + template <typename Func> + using NonBlockingCallResult = typename detail::ResultTraits< + typename Func::ReturnType>::ReturnFutureType; + + /// Call Func on Channel C. Does not block, does not call send. Returns a pair + /// of a future result and the sequence number assigned to the result. + /// + /// This utility function is primarily used for single-threaded mode support, + /// where the sequence number can be used to wait for the corresponding + /// result. In multi-threaded mode the appendCallNB method, which does not + /// return the sequence numeber, should be preferred. + template <typename Func, typename... ArgTs> + Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &...Args) { + using RTraits = detail::ResultTraits<typename Func::ReturnType>; + using ErrorReturn = typename RTraits::ErrorReturnType; + using ErrorReturnPromise = typename RTraits::ReturnPromiseType; + + ErrorReturnPromise Promise; + auto FutureResult = Promise.get_future(); + + if (auto Err = this->template appendCallAsync<Func>( + [Promise = std::move(Promise)](ErrorReturn RetOrErr) mutable { + Promise.set_value(std::move(RetOrErr)); + return Error::success(); + }, + Args...)) { + RTraits::consumeAbandoned(FutureResult.get()); + return std::move(Err); + } + return std::move(FutureResult); + } + + /// The same as appendCallNBWithSeq, except that it calls C.send() to + /// flush the channel after serializing the call. + template <typename Func, typename... ArgTs> + Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &...Args) { + auto Result = appendCallNB<Func>(Args...); + if (!Result) + return Result; + if (auto Err = this->C.send()) { + this->abandonPendingResponses(); + detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( + std::move(Result->get())); + return std::move(Err); + } + return Result; + } + + /// Call Func on Channel C. Blocks waiting for a result. Returns an Error + /// for void functions or an Expected<T> for functions returning a T. + /// + /// This function is for use in threaded code where another thread is + /// handling responses and incoming calls. + template <typename Func, typename... ArgTs, + typename AltRetT = typename Func::ReturnType> + typename detail::ResultTraits<AltRetT>::ErrorReturnType + callB(const ArgTs &...Args) { + if (auto FutureResOrErr = callNB<Func>(Args...)) + return FutureResOrErr->get(); + else + return FutureResOrErr.takeError(); + } + + /// Handle incoming RPC calls. + Error handlerLoop() { + while (true) + if (auto Err = this->handleOne()) + return Err; + return Error::success(); + } +}; + +template <typename ChannelT, typename FunctionIdT = uint32_t, + typename SequenceNumberT = uint32_t> +class SingleThreadedRPCEndpoint + : public detail::RPCEndpointBase< + SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT> { +private: + using BaseClass = detail::RPCEndpointBase< + SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT>; + +public: + SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) + : BaseClass(C, LazyAutoNegotiation) {} + + template <typename Func, typename HandlerT> + void addHandler(HandlerT Handler) { + return this->template addHandlerImpl<Func>(std::move(Handler)); + } + + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); + } + + template <typename Func, typename HandlerT> + void addAsyncHandler(HandlerT Handler) { + return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); + } + + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addAsyncHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); + } + + template <typename Func, typename... ArgTs, + typename AltRetT = typename Func::ReturnType> + typename detail::ResultTraits<AltRetT>::ErrorReturnType + callB(const ArgTs &...Args) { + bool ReceivedResponse = false; + using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType; + auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue(); + + // We have to 'Check' result (which we know is in a success state at this + // point) so that it can be overwritten in the async handler. + (void)!!Result; + + if (auto Err = this->template appendCallAsync<Func>( + [&](ResultType R) { + Result = std::move(R); + ReceivedResponse = true; + return Error::success(); + }, + Args...)) { + detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( + std::move(Result)); + return std::move(Err); + } + + if (auto Err = this->C.send()) { + detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( + std::move(Result)); + return std::move(Err); + } + + while (!ReceivedResponse) { + if (auto Err = this->handleOne()) { + detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( + std::move(Result)); + return std::move(Err); + } + } + + return Result; + } +}; + +/// Asynchronous dispatch for a function on an RPC endpoint. +template <typename RPCClass, typename Func> class RPCAsyncDispatch { +public: + RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {} + + template <typename HandlerT, typename... ArgTs> + Error operator()(HandlerT Handler, const ArgTs &...Args) const { + return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...); + } + +private: + RPCClass &Endpoint; +}; + +/// Construct an asynchronous dispatcher from an RPC endpoint and a Func. +template <typename Func, typename RPCEndpointT> +RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) { + return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint); +} + +/// Allows a set of asynchrounous calls to be dispatched, and then +/// waited on as a group. +class ParallelCallGroup { +public: + ParallelCallGroup() = default; + ParallelCallGroup(const ParallelCallGroup &) = delete; + ParallelCallGroup &operator=(const ParallelCallGroup &) = delete; + + /// Make as asynchronous call. + template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs> + Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler, + const ArgTs &...Args) { + // Increment the count of outstanding calls. This has to happen before + // we invoke the call, as the handler may (depending on scheduling) + // be run immediately on another thread, and we don't want the decrement + // in the wrapped handler below to run before the increment. + { + std::unique_lock<std::mutex> Lock(M); + ++NumOutstandingCalls; + } + + // Wrap the user handler in a lambda that will decrement the + // outstanding calls count, then poke the condition variable. + using ArgType = typename detail::ResponseHandlerArg< + typename detail::HandlerTraits<HandlerT>::Type>::ArgType; + auto WrappedHandler = [this, Handler = std::move(Handler)](ArgType Arg) { + auto Err = Handler(std::move(Arg)); + std::unique_lock<std::mutex> Lock(M); + --NumOutstandingCalls; + CV.notify_all(); + return Err; + }; + + return AsyncDispatch(std::move(WrappedHandler), Args...); + } + + /// Blocks until all calls have been completed and their return value + /// handlers run. + void wait() { + std::unique_lock<std::mutex> Lock(M); + while (NumOutstandingCalls > 0) + CV.wait(Lock); + } + +private: + std::mutex M; + std::condition_variable CV; + uint32_t NumOutstandingCalls = 0; +}; + +/// Convenience class for grouping RPCFunctions into APIs that can be +/// negotiated as a block. +/// +template <typename... Funcs> class APICalls { +public: + /// Test whether this API contains Function F. + template <typename F> class Contains { + public: + static const bool value = false; + }; + + /// Negotiate all functions in this API. + template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) { + return Error::success(); + } +}; + +template <typename Func, typename... Funcs> class APICalls<Func, Funcs...> { +public: + template <typename F> class Contains { + public: + static const bool value = std::is_same<F, Func>::value | + APICalls<Funcs...>::template Contains<F>::value; + }; + + template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) { + if (auto Err = R.template negotiateFunction<Func>()) + return Err; + return APICalls<Funcs...>::negotiate(R); + } +}; + +template <typename... InnerFuncs, typename... Funcs> +class APICalls<APICalls<InnerFuncs...>, Funcs...> { +public: + template <typename F> class Contains { + public: + static const bool value = + APICalls<InnerFuncs...>::template Contains<F>::value | + APICalls<Funcs...>::template Contains<F>::value; + }; + + template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) { + if (auto Err = APICalls<InnerFuncs...>::negotiate(R)) + return Err; + return APICalls<Funcs...>::negotiate(R); + } +}; + +} // end namespace shared +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h index 94bb6c7739..4f6175af33 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h @@ -1,194 +1,194 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===- RawByteChannel.h -----------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H -#define LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H - -#include "llvm/ADT/StringRef.h" -#include "llvm/ExecutionEngine/Orc/Shared/Serialization.h" -#include "llvm/Support/Endian.h" -#include "llvm/Support/Error.h" -#include <cstdint> -#include <mutex> -#include <string> -#include <type_traits> - -namespace llvm { -namespace orc { -namespace shared { - -/// Interface for byte-streams to be used with ORC Serialization. -class RawByteChannel { -public: - virtual ~RawByteChannel() = default; - - /// Read Size bytes from the stream into *Dst. - virtual Error readBytes(char *Dst, unsigned Size) = 0; - - /// Read size bytes from *Src and append them to the stream. - virtual Error appendBytes(const char *Src, unsigned Size) = 0; - - /// Flush the stream if possible. - virtual Error send() = 0; - - /// Notify the channel that we're starting a message send. - /// Locks the channel for writing. - template <typename FunctionIdT, typename SequenceIdT> - Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) { - writeLock.lock(); - if (auto Err = serializeSeq(*this, FnId, SeqNo)) { - writeLock.unlock(); - return Err; - } - return Error::success(); - } - - /// Notify the channel that we're ending a message send. - /// Unlocks the channel for writing. - Error endSendMessage() { - writeLock.unlock(); - return Error::success(); - } - - /// Notify the channel that we're starting a message receive. - /// Locks the channel for reading. - template <typename FunctionIdT, typename SequenceNumberT> - Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) { - readLock.lock(); - if (auto Err = deserializeSeq(*this, FnId, SeqNo)) { - readLock.unlock(); - return Err; - } - return Error::success(); - } - - /// Notify the channel that we're ending a message receive. - /// Unlocks the channel for reading. - Error endReceiveMessage() { - readLock.unlock(); - return Error::success(); - } - - /// Get the lock for stream reading. - std::mutex &getReadLock() { return readLock; } - - /// Get the lock for stream writing. - std::mutex &getWriteLock() { return writeLock; } - -private: - std::mutex readLock, writeLock; -}; - -template <typename ChannelT, typename T> -class SerializationTraits< - ChannelT, T, T, - std::enable_if_t< - std::is_base_of<RawByteChannel, ChannelT>::value && - (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value || - std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value || - std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value || - std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value || - std::is_same<T, char>::value)>> { -public: - static Error serialize(ChannelT &C, T V) { - support::endian::byte_swap<T, support::big>(V); - return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T)); - }; - - static Error deserialize(ChannelT &C, T &V) { - if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T))) - return Err; - support::endian::byte_swap<T, support::big>(V); - return Error::success(); - }; -}; - -template <typename ChannelT> -class SerializationTraits< - ChannelT, bool, bool, - std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { -public: - static Error serialize(ChannelT &C, bool V) { - uint8_t Tmp = V ? 1 : 0; - if (auto Err = C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1)) - return Err; - return Error::success(); - } - - static Error deserialize(ChannelT &C, bool &V) { - uint8_t Tmp = 0; - if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1)) - return Err; - V = Tmp != 0; - return Error::success(); - } -}; - -template <typename ChannelT> -class SerializationTraits< - ChannelT, std::string, StringRef, - std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { -public: - /// Serialization channel serialization for std::strings. - static Error serialize(RawByteChannel &C, StringRef S) { - if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) - return Err; - return C.appendBytes((const char *)S.data(), S.size()); - } -}; - -template <typename ChannelT, typename T> -class SerializationTraits< - ChannelT, std::string, T, - std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value && - (std::is_same<T, const char *>::value || - std::is_same<T, char *>::value)>> { -public: - static Error serialize(RawByteChannel &C, const char *S) { - return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, - S); - } -}; - -template <typename ChannelT> -class SerializationTraits< - ChannelT, std::string, std::string, - std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { -public: - /// Serialization channel serialization for std::strings. - static Error serialize(RawByteChannel &C, const std::string &S) { - return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, - S); - } - - /// Serialization channel deserialization for std::strings. - static Error deserialize(RawByteChannel &C, std::string &S) { - uint64_t Count = 0; - if (auto Err = deserializeSeq(C, Count)) - return Err; - S.resize(Count); - return C.readBytes(&S[0], Count); - } -}; - -} // end namespace shared -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===- RawByteChannel.h -----------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H +#define LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H + +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/Shared/Serialization.h" +#include "llvm/Support/Endian.h" +#include "llvm/Support/Error.h" +#include <cstdint> +#include <mutex> +#include <string> +#include <type_traits> + +namespace llvm { +namespace orc { +namespace shared { + +/// Interface for byte-streams to be used with ORC Serialization. +class RawByteChannel { +public: + virtual ~RawByteChannel() = default; + + /// Read Size bytes from the stream into *Dst. + virtual Error readBytes(char *Dst, unsigned Size) = 0; + + /// Read size bytes from *Src and append them to the stream. + virtual Error appendBytes(const char *Src, unsigned Size) = 0; + + /// Flush the stream if possible. + virtual Error send() = 0; + + /// Notify the channel that we're starting a message send. + /// Locks the channel for writing. + template <typename FunctionIdT, typename SequenceIdT> + Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) { + writeLock.lock(); + if (auto Err = serializeSeq(*this, FnId, SeqNo)) { + writeLock.unlock(); + return Err; + } + return Error::success(); + } + + /// Notify the channel that we're ending a message send. + /// Unlocks the channel for writing. + Error endSendMessage() { + writeLock.unlock(); + return Error::success(); + } + + /// Notify the channel that we're starting a message receive. + /// Locks the channel for reading. + template <typename FunctionIdT, typename SequenceNumberT> + Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) { + readLock.lock(); + if (auto Err = deserializeSeq(*this, FnId, SeqNo)) { + readLock.unlock(); + return Err; + } + return Error::success(); + } + + /// Notify the channel that we're ending a message receive. + /// Unlocks the channel for reading. + Error endReceiveMessage() { + readLock.unlock(); + return Error::success(); + } + + /// Get the lock for stream reading. + std::mutex &getReadLock() { return readLock; } + + /// Get the lock for stream writing. + std::mutex &getWriteLock() { return writeLock; } + +private: + std::mutex readLock, writeLock; +}; + +template <typename ChannelT, typename T> +class SerializationTraits< + ChannelT, T, T, + std::enable_if_t< + std::is_base_of<RawByteChannel, ChannelT>::value && + (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value || + std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value || + std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value || + std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value || + std::is_same<T, char>::value)>> { +public: + static Error serialize(ChannelT &C, T V) { + support::endian::byte_swap<T, support::big>(V); + return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T)); + }; + + static Error deserialize(ChannelT &C, T &V) { + if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T))) + return Err; + support::endian::byte_swap<T, support::big>(V); + return Error::success(); + }; +}; + +template <typename ChannelT> +class SerializationTraits< + ChannelT, bool, bool, + std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { +public: + static Error serialize(ChannelT &C, bool V) { + uint8_t Tmp = V ? 1 : 0; + if (auto Err = C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1)) + return Err; + return Error::success(); + } + + static Error deserialize(ChannelT &C, bool &V) { + uint8_t Tmp = 0; + if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1)) + return Err; + V = Tmp != 0; + return Error::success(); + } +}; + +template <typename ChannelT> +class SerializationTraits< + ChannelT, std::string, StringRef, + std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { +public: + /// Serialization channel serialization for std::strings. + static Error serialize(RawByteChannel &C, StringRef S) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) + return Err; + return C.appendBytes((const char *)S.data(), S.size()); + } +}; + +template <typename ChannelT, typename T> +class SerializationTraits< + ChannelT, std::string, T, + std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value && + (std::is_same<T, const char *>::value || + std::is_same<T, char *>::value)>> { +public: + static Error serialize(RawByteChannel &C, const char *S) { + return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, + S); + } +}; + +template <typename ChannelT> +class SerializationTraits< + ChannelT, std::string, std::string, + std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { +public: + /// Serialization channel serialization for std::strings. + static Error serialize(RawByteChannel &C, const std::string &S) { + return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, + S); + } + + /// Serialization channel deserialization for std::strings. + static Error deserialize(RawByteChannel &C, std::string &S) { + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + S.resize(Count); + return C.readBytes(&S[0], Count); + } +}; + +} // end namespace shared +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/Serialization.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/Serialization.h index 5f4e2767f0..fa48a7af43 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/Serialization.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/Serialization.h @@ -1,780 +1,780 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===- Serialization.h ------------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_SERIALIZATION_H -#define LLVM_EXECUTIONENGINE_ORC_SHARED_SERIALIZATION_H - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" -#include "llvm/Support/thread.h" -#include <map> -#include <mutex> -#include <set> -#include <sstream> -#include <string> -#include <vector> - -namespace llvm { -namespace orc { -namespace shared { - -template <typename T> class SerializationTypeName; - -/// TypeNameSequence is a utility for rendering sequences of types to a string -/// by rendering each type, separated by ", ". -template <typename... ArgTs> class SerializationTypeNameSequence {}; - -/// Render an empty TypeNameSequence to an ostream. -template <typename OStream> -OStream &operator<<(OStream &OS, const SerializationTypeNameSequence<> &V) { - return OS; -} - -/// Render a TypeNameSequence of a single type to an ostream. -template <typename OStream, typename ArgT> -OStream &operator<<(OStream &OS, const SerializationTypeNameSequence<ArgT> &V) { - OS << SerializationTypeName<ArgT>::getName(); - return OS; -} - -/// Render a TypeNameSequence of more than one type to an ostream. -template <typename OStream, typename ArgT1, typename ArgT2, typename... ArgTs> -OStream & -operator<<(OStream &OS, - const SerializationTypeNameSequence<ArgT1, ArgT2, ArgTs...> &V) { - OS << SerializationTypeName<ArgT1>::getName() << ", " - << SerializationTypeNameSequence<ArgT2, ArgTs...>(); - return OS; -} - -template <> class SerializationTypeName<void> { -public: - static const char *getName() { return "void"; } -}; - -template <> class SerializationTypeName<int8_t> { -public: - static const char *getName() { return "int8_t"; } -}; - -template <> class SerializationTypeName<uint8_t> { -public: - static const char *getName() { return "uint8_t"; } -}; - -template <> class SerializationTypeName<int16_t> { -public: - static const char *getName() { return "int16_t"; } -}; - -template <> class SerializationTypeName<uint16_t> { -public: - static const char *getName() { return "uint16_t"; } -}; - -template <> class SerializationTypeName<int32_t> { -public: - static const char *getName() { return "int32_t"; } -}; - -template <> class SerializationTypeName<uint32_t> { -public: - static const char *getName() { return "uint32_t"; } -}; - -template <> class SerializationTypeName<int64_t> { -public: - static const char *getName() { return "int64_t"; } -}; - -template <> class SerializationTypeName<uint64_t> { -public: - static const char *getName() { return "uint64_t"; } -}; - -template <> class SerializationTypeName<bool> { -public: - static const char *getName() { return "bool"; } -}; - -template <> class SerializationTypeName<std::string> { -public: - static const char *getName() { return "std::string"; } -}; - -template <> class SerializationTypeName<Error> { -public: - static const char *getName() { return "Error"; } -}; - -template <typename T> class SerializationTypeName<Expected<T>> { -public: - static const char *getName() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << "Expected<" << SerializationTypeNameSequence<T>() << ">"; - return Name; - }(); - return Name.data(); - } -}; - -template <typename T1, typename T2> -class SerializationTypeName<std::pair<T1, T2>> { -public: - static const char *getName() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << "std::pair<" << SerializationTypeNameSequence<T1, T2>() << ">"; - return Name; - }(); - return Name.data(); - } -}; - -template <typename... ArgTs> class SerializationTypeName<std::tuple<ArgTs...>> { -public: - static const char *getName() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << "std::tuple<" << SerializationTypeNameSequence<ArgTs...>() << ">"; - return Name; - }(); - return Name.data(); - } -}; - -template <typename T> class SerializationTypeName<Optional<T>> { -public: - static const char *getName() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << "Optional<" << SerializationTypeName<T>::getName() << ">"; - return Name; - }(); - return Name.data(); - } -}; - -template <typename T> class SerializationTypeName<std::vector<T>> { -public: - static const char *getName() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << "std::vector<" << SerializationTypeName<T>::getName() << ">"; - return Name; - }(); - return Name.data(); - } -}; - -template <typename T> class SerializationTypeName<std::set<T>> { -public: - static const char *getName() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << "std::set<" << SerializationTypeName<T>::getName() << ">"; - return Name; - }(); - return Name.data(); - } -}; - -template <typename K, typename V> class SerializationTypeName<std::map<K, V>> { -public: - static const char *getName() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << "std::map<" << SerializationTypeNameSequence<K, V>() << ">"; - return Name; - }(); - return Name.data(); - } -}; - -/// The SerializationTraits<ChannelT, T> class describes how to serialize and -/// deserialize an instance of type T to/from an abstract channel of type -/// ChannelT. It also provides a representation of the type's name via the -/// getName method. -/// -/// Specializations of this class should provide the following functions: -/// -/// @code{.cpp} -/// -/// static const char* getName(); -/// static Error serialize(ChannelT&, const T&); -/// static Error deserialize(ChannelT&, T&); -/// -/// @endcode -/// -/// The third argument of SerializationTraits is intended to support SFINAE. -/// E.g.: -/// -/// @code{.cpp} -/// -/// class MyVirtualChannel { ... }; -/// -/// template <DerivedChannelT> -/// class SerializationTraits<DerivedChannelT, bool, -/// std::enable_if_t< -/// std::is_base_of<VirtChannel, DerivedChannel>::value -/// >> { -/// public: -/// static const char* getName() { ... }; -/// } -/// -/// @endcode -template <typename ChannelT, typename WireType, - typename ConcreteType = WireType, typename = void> -class SerializationTraits; - -template <typename ChannelT> class SequenceTraits { -public: - static Error emitSeparator(ChannelT &C) { return Error::success(); } - static Error consumeSeparator(ChannelT &C) { return Error::success(); } -}; - -/// Utility class for serializing sequences of values of varying types. -/// Specializations of this class contain 'serialize' and 'deserialize' methods -/// for the given channel. The ArgTs... list will determine the "over-the-wire" -/// types to be serialized. The serialize and deserialize methods take a list -/// CArgTs... ("caller arg types") which must be the same length as ArgTs..., -/// but may be different types from ArgTs, provided that for each CArgT there -/// is a SerializationTraits specialization -/// SerializeTraits<ChannelT, ArgT, CArgT> with methods that can serialize the -/// caller argument to over-the-wire value. -template <typename ChannelT, typename... ArgTs> class SequenceSerialization; - -template <typename ChannelT> class SequenceSerialization<ChannelT> { -public: - static Error serialize(ChannelT &C) { return Error::success(); } - static Error deserialize(ChannelT &C) { return Error::success(); } -}; - -template <typename ChannelT, typename ArgT> -class SequenceSerialization<ChannelT, ArgT> { -public: - template <typename CArgT> static Error serialize(ChannelT &C, CArgT &&CArg) { - return SerializationTraits<ChannelT, ArgT, std::decay_t<CArgT>>::serialize( - C, std::forward<CArgT>(CArg)); - } - - template <typename CArgT> static Error deserialize(ChannelT &C, CArgT &CArg) { - return SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg); - } -}; - -template <typename ChannelT, typename ArgT, typename... ArgTs> -class SequenceSerialization<ChannelT, ArgT, ArgTs...> { -public: - template <typename CArgT, typename... CArgTs> - static Error serialize(ChannelT &C, CArgT &&CArg, CArgTs &&...CArgs) { - if (auto Err = - SerializationTraits<ChannelT, ArgT, std::decay_t<CArgT>>::serialize( - C, std::forward<CArgT>(CArg))) - return Err; - if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C)) - return Err; - return SequenceSerialization<ChannelT, ArgTs...>::serialize( - C, std::forward<CArgTs>(CArgs)...); - } - - template <typename CArgT, typename... CArgTs> - static Error deserialize(ChannelT &C, CArgT &CArg, CArgTs &...CArgs) { - if (auto Err = - SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg)) - return Err; - if (auto Err = SequenceTraits<ChannelT>::consumeSeparator(C)) - return Err; - return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, CArgs...); - } -}; - -template <typename ChannelT, typename... ArgTs> -Error serializeSeq(ChannelT &C, ArgTs &&...Args) { - return SequenceSerialization<ChannelT, std::decay_t<ArgTs>...>::serialize( - C, std::forward<ArgTs>(Args)...); -} - -template <typename ChannelT, typename... ArgTs> -Error deserializeSeq(ChannelT &C, ArgTs &...Args) { - return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...); -} - -template <typename ChannelT> class SerializationTraits<ChannelT, Error> { -public: - using WrappedErrorSerializer = - std::function<Error(ChannelT &C, const ErrorInfoBase &)>; - - using WrappedErrorDeserializer = - std::function<Error(ChannelT &C, Error &Err)>; - - template <typename ErrorInfoT, typename SerializeFtor, - typename DeserializeFtor> - static void registerErrorType(std::string Name, SerializeFtor Serialize, - DeserializeFtor Deserialize) { - assert(!Name.empty() && - "The empty string is reserved for the Success value"); - - const std::string *KeyName = nullptr; - { - // We're abusing the stability of std::map here: We take a reference to - // the key of the deserializers map to save us from duplicating the string - // in the serializer. This should be changed to use a stringpool if we - // switch to a map type that may move keys in memory. - std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); - auto I = Deserializers.insert( - Deserializers.begin(), - std::make_pair(std::move(Name), std::move(Deserialize))); - KeyName = &I->first; - } - - { - assert(KeyName != nullptr && "No keyname pointer"); - std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); - Serializers[ErrorInfoT::classID()] = - [KeyName, Serialize = std::move(Serialize)]( - ChannelT &C, const ErrorInfoBase &EIB) -> Error { - assert(EIB.dynamicClassID() == ErrorInfoT::classID() && - "Serializer called for wrong error type"); - if (auto Err = serializeSeq(C, *KeyName)) - return Err; - return Serialize(C, static_cast<const ErrorInfoT &>(EIB)); - }; - } - } - - static Error serialize(ChannelT &C, Error &&Err) { - std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); - - if (!Err) - return serializeSeq(C, std::string()); - - return handleErrors(std::move(Err), [&C](const ErrorInfoBase &EIB) { - auto SI = Serializers.find(EIB.dynamicClassID()); - if (SI == Serializers.end()) - return serializeAsStringError(C, EIB); - return (SI->second)(C, EIB); - }); - } - - static Error deserialize(ChannelT &C, Error &Err) { - std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); - - std::string Key; - if (auto Err = deserializeSeq(C, Key)) - return Err; - - if (Key.empty()) { - ErrorAsOutParameter EAO(&Err); - Err = Error::success(); - return Error::success(); - } - - auto DI = Deserializers.find(Key); - assert(DI != Deserializers.end() && "No deserializer for error type"); - return (DI->second)(C, Err); - } - -private: - static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) { - std::string ErrMsg; - { - raw_string_ostream ErrMsgStream(ErrMsg); - EIB.log(ErrMsgStream); - } - return serialize(C, make_error<StringError>(std::move(ErrMsg), - inconvertibleErrorCode())); - } - - static std::recursive_mutex SerializersMutex; - static std::recursive_mutex DeserializersMutex; - static std::map<const void *, WrappedErrorSerializer> Serializers; - static std::map<std::string, WrappedErrorDeserializer> Deserializers; -}; - -template <typename ChannelT> -std::recursive_mutex SerializationTraits<ChannelT, Error>::SerializersMutex; - -template <typename ChannelT> -std::recursive_mutex SerializationTraits<ChannelT, Error>::DeserializersMutex; - -template <typename ChannelT> -std::map<const void *, - typename SerializationTraits<ChannelT, Error>::WrappedErrorSerializer> - SerializationTraits<ChannelT, Error>::Serializers; - -template <typename ChannelT> -std::map<std::string, typename SerializationTraits< - ChannelT, Error>::WrappedErrorDeserializer> - SerializationTraits<ChannelT, Error>::Deserializers; - -/// Registers a serializer and deserializer for the given error type on the -/// given channel type. -template <typename ChannelT, typename ErrorInfoT, typename SerializeFtor, - typename DeserializeFtor> -void registerErrorSerialization(std::string Name, SerializeFtor &&Serialize, - DeserializeFtor &&Deserialize) { - SerializationTraits<ChannelT, Error>::template registerErrorType<ErrorInfoT>( - std::move(Name), std::forward<SerializeFtor>(Serialize), - std::forward<DeserializeFtor>(Deserialize)); -} - -/// Registers serialization/deserialization for StringError. -template <typename ChannelT> void registerStringError() { - static bool AlreadyRegistered = false; - if (!AlreadyRegistered) { - registerErrorSerialization<ChannelT, StringError>( - "StringError", - [](ChannelT &C, const StringError &SE) { - return serializeSeq(C, SE.getMessage()); - }, - [](ChannelT &C, Error &Err) -> Error { - ErrorAsOutParameter EAO(&Err); - std::string Msg; - if (auto E2 = deserializeSeq(C, Msg)) - return E2; - Err = make_error<StringError>( - std::move(Msg), - orcError(OrcErrorCode::UnknownErrorCodeFromRemote)); - return Error::success(); - }); - AlreadyRegistered = true; - } -} - -/// SerializationTraits for Expected<T1> from an Expected<T2>. -template <typename ChannelT, typename T1, typename T2> -class SerializationTraits<ChannelT, Expected<T1>, Expected<T2>> { -public: - static Error serialize(ChannelT &C, Expected<T2> &&ValOrErr) { - if (ValOrErr) { - if (auto Err = serializeSeq(C, true)) - return Err; - return SerializationTraits<ChannelT, T1, T2>::serialize(C, *ValOrErr); - } - if (auto Err = serializeSeq(C, false)) - return Err; - return serializeSeq(C, ValOrErr.takeError()); - } - - static Error deserialize(ChannelT &C, Expected<T2> &ValOrErr) { - ExpectedAsOutParameter<T2> EAO(&ValOrErr); - bool HasValue; - if (auto Err = deserializeSeq(C, HasValue)) - return Err; - if (HasValue) - return SerializationTraits<ChannelT, T1, T2>::deserialize(C, *ValOrErr); - Error Err = Error::success(); - if (auto E2 = deserializeSeq(C, Err)) - return E2; - ValOrErr = std::move(Err); - return Error::success(); - } -}; - -/// SerializationTraits for Expected<T1> from a T2. -template <typename ChannelT, typename T1, typename T2> -class SerializationTraits<ChannelT, Expected<T1>, T2> { -public: - static Error serialize(ChannelT &C, T2 &&Val) { - return serializeSeq(C, Expected<T2>(std::forward<T2>(Val))); - } -}; - -/// SerializationTraits for Expected<T1> from an Error. -template <typename ChannelT, typename T> -class SerializationTraits<ChannelT, Expected<T>, Error> { -public: - static Error serialize(ChannelT &C, Error &&Err) { - return serializeSeq(C, Expected<T>(std::move(Err))); - } -}; - -/// SerializationTraits default specialization for std::pair. -template <typename ChannelT, typename T1, typename T2, typename T3, typename T4> -class SerializationTraits<ChannelT, std::pair<T1, T2>, std::pair<T3, T4>> { -public: - static Error serialize(ChannelT &C, const std::pair<T3, T4> &V) { - if (auto Err = SerializationTraits<ChannelT, T1, T3>::serialize(C, V.first)) - return Err; - return SerializationTraits<ChannelT, T2, T4>::serialize(C, V.second); - } - - static Error deserialize(ChannelT &C, std::pair<T3, T4> &V) { - if (auto Err = - SerializationTraits<ChannelT, T1, T3>::deserialize(C, V.first)) - return Err; - return SerializationTraits<ChannelT, T2, T4>::deserialize(C, V.second); - } -}; - -/// SerializationTraits default specialization for std::tuple. -template <typename ChannelT, typename... ArgTs> -class SerializationTraits<ChannelT, std::tuple<ArgTs...>> { -public: - /// RPC channel serialization for std::tuple. - static Error serialize(ChannelT &C, const std::tuple<ArgTs...> &V) { - return serializeTupleHelper(C, V, std::index_sequence_for<ArgTs...>()); - } - - /// RPC channel deserialization for std::tuple. - static Error deserialize(ChannelT &C, std::tuple<ArgTs...> &V) { - return deserializeTupleHelper(C, V, std::index_sequence_for<ArgTs...>()); - } - -private: - // Serialization helper for std::tuple. - template <size_t... Is> - static Error serializeTupleHelper(ChannelT &C, const std::tuple<ArgTs...> &V, - std::index_sequence<Is...> _) { - return serializeSeq(C, std::get<Is>(V)...); - } - - // Serialization helper for std::tuple. - template <size_t... Is> - static Error deserializeTupleHelper(ChannelT &C, std::tuple<ArgTs...> &V, - std::index_sequence<Is...> _) { - return deserializeSeq(C, std::get<Is>(V)...); - } -}; - -template <typename ChannelT, typename T> -class SerializationTraits<ChannelT, Optional<T>> { -public: - /// Serialize an Optional<T>. - static Error serialize(ChannelT &C, const Optional<T> &O) { - if (auto Err = serializeSeq(C, O != None)) - return Err; - if (O) - if (auto Err = serializeSeq(C, *O)) - return Err; - return Error::success(); - } - - /// Deserialize an Optional<T>. - static Error deserialize(ChannelT &C, Optional<T> &O) { - bool HasValue = false; - if (auto Err = deserializeSeq(C, HasValue)) - return Err; - if (HasValue) - if (auto Err = deserializeSeq(C, *O)) - return Err; - return Error::success(); - }; -}; - -/// SerializationTraits default specialization for std::vector. -template <typename ChannelT, typename T> -class SerializationTraits<ChannelT, std::vector<T>> { -public: - /// Serialize a std::vector<T> from std::vector<T>. - static Error serialize(ChannelT &C, const std::vector<T> &V) { - if (auto Err = serializeSeq(C, static_cast<uint64_t>(V.size()))) - return Err; - - for (const auto &E : V) - if (auto Err = serializeSeq(C, E)) - return Err; - - return Error::success(); - } - - /// Deserialize a std::vector<T> to a std::vector<T>. - static Error deserialize(ChannelT &C, std::vector<T> &V) { - assert(V.empty() && - "Expected default-constructed vector to deserialize into"); - - uint64_t Count = 0; - if (auto Err = deserializeSeq(C, Count)) - return Err; - - V.resize(Count); - for (auto &E : V) - if (auto Err = deserializeSeq(C, E)) - return Err; - - return Error::success(); - } -}; - -/// Enable vector serialization from an ArrayRef. -template <typename ChannelT, typename T> -class SerializationTraits<ChannelT, std::vector<T>, ArrayRef<T>> { -public: - static Error serialize(ChannelT &C, ArrayRef<T> V) { - if (auto Err = serializeSeq(C, static_cast<uint64_t>(V.size()))) - return Err; - - for (const auto &E : V) - if (auto Err = serializeSeq(C, E)) - return Err; - - return Error::success(); - } -}; - -template <typename ChannelT, typename T, typename T2> -class SerializationTraits<ChannelT, std::set<T>, std::set<T2>> { -public: - /// Serialize a std::set<T> from std::set<T2>. - static Error serialize(ChannelT &C, const std::set<T2> &S) { - if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) - return Err; - - for (const auto &E : S) - if (auto Err = SerializationTraits<ChannelT, T, T2>::serialize(C, E)) - return Err; - - return Error::success(); - } - - /// Deserialize a std::set<T> to a std::set<T>. - static Error deserialize(ChannelT &C, std::set<T2> &S) { - assert(S.empty() && "Expected default-constructed set to deserialize into"); - - uint64_t Count = 0; - if (auto Err = deserializeSeq(C, Count)) - return Err; - - while (Count-- != 0) { - T2 Val; - if (auto Err = SerializationTraits<ChannelT, T, T2>::deserialize(C, Val)) - return Err; - - auto Added = S.insert(Val).second; - if (!Added) - return make_error<StringError>("Duplicate element in deserialized set", - orcError(OrcErrorCode::UnknownORCError)); - } - - return Error::success(); - } -}; - -template <typename ChannelT, typename K, typename V, typename K2, typename V2> -class SerializationTraits<ChannelT, std::map<K, V>, std::map<K2, V2>> { -public: - /// Serialize a std::map<K, V> from std::map<K2, V2>. - static Error serialize(ChannelT &C, const std::map<K2, V2> &M) { - if (auto Err = serializeSeq(C, static_cast<uint64_t>(M.size()))) - return Err; - - for (const auto &E : M) { - if (auto Err = - SerializationTraits<ChannelT, K, K2>::serialize(C, E.first)) - return Err; - if (auto Err = - SerializationTraits<ChannelT, V, V2>::serialize(C, E.second)) - return Err; - } - - return Error::success(); - } - - /// Deserialize a std::map<K, V> to a std::map<K, V>. - static Error deserialize(ChannelT &C, std::map<K2, V2> &M) { - assert(M.empty() && "Expected default-constructed map to deserialize into"); - - uint64_t Count = 0; - if (auto Err = deserializeSeq(C, Count)) - return Err; - - while (Count-- != 0) { - std::pair<K2, V2> Val; - if (auto Err = - SerializationTraits<ChannelT, K, K2>::deserialize(C, Val.first)) - return Err; - - if (auto Err = - SerializationTraits<ChannelT, V, V2>::deserialize(C, Val.second)) - return Err; - - auto Added = M.insert(Val).second; - if (!Added) - return make_error<StringError>("Duplicate element in deserialized map", - orcError(OrcErrorCode::UnknownORCError)); - } - - return Error::success(); - } -}; - -template <typename ChannelT, typename K, typename V, typename K2, typename V2> -class SerializationTraits<ChannelT, std::map<K, V>, DenseMap<K2, V2>> { -public: - /// Serialize a std::map<K, V> from DenseMap<K2, V2>. - static Error serialize(ChannelT &C, const DenseMap<K2, V2> &M) { - if (auto Err = serializeSeq(C, static_cast<uint64_t>(M.size()))) - return Err; - - for (auto &E : M) { - if (auto Err = - SerializationTraits<ChannelT, K, K2>::serialize(C, E.first)) - return Err; - - if (auto Err = - SerializationTraits<ChannelT, V, V2>::serialize(C, E.second)) - return Err; - } - - return Error::success(); - } - - /// Serialize a std::map<K, V> from DenseMap<K2, V2>. - static Error deserialize(ChannelT &C, DenseMap<K2, V2> &M) { - assert(M.empty() && "Expected default-constructed map to deserialize into"); - - uint64_t Count = 0; - if (auto Err = deserializeSeq(C, Count)) - return Err; - - while (Count-- != 0) { - std::pair<K2, V2> Val; - if (auto Err = - SerializationTraits<ChannelT, K, K2>::deserialize(C, Val.first)) - return Err; - - if (auto Err = - SerializationTraits<ChannelT, V, V2>::deserialize(C, Val.second)) - return Err; - - auto Added = M.insert(Val).second; - if (!Added) - return make_error<StringError>("Duplicate element in deserialized map", - orcError(OrcErrorCode::UnknownORCError)); - } - - return Error::success(); - } -}; - -} // namespace shared -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_RPC_RPCSERIALIZATION_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===- Serialization.h ------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_SERIALIZATION_H +#define LLVM_EXECUTIONENGINE_ORC_SHARED_SERIALIZATION_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" +#include "llvm/Support/thread.h" +#include <map> +#include <mutex> +#include <set> +#include <sstream> +#include <string> +#include <vector> + +namespace llvm { +namespace orc { +namespace shared { + +template <typename T> class SerializationTypeName; + +/// TypeNameSequence is a utility for rendering sequences of types to a string +/// by rendering each type, separated by ", ". +template <typename... ArgTs> class SerializationTypeNameSequence {}; + +/// Render an empty TypeNameSequence to an ostream. +template <typename OStream> +OStream &operator<<(OStream &OS, const SerializationTypeNameSequence<> &V) { + return OS; +} + +/// Render a TypeNameSequence of a single type to an ostream. +template <typename OStream, typename ArgT> +OStream &operator<<(OStream &OS, const SerializationTypeNameSequence<ArgT> &V) { + OS << SerializationTypeName<ArgT>::getName(); + return OS; +} + +/// Render a TypeNameSequence of more than one type to an ostream. +template <typename OStream, typename ArgT1, typename ArgT2, typename... ArgTs> +OStream & +operator<<(OStream &OS, + const SerializationTypeNameSequence<ArgT1, ArgT2, ArgTs...> &V) { + OS << SerializationTypeName<ArgT1>::getName() << ", " + << SerializationTypeNameSequence<ArgT2, ArgTs...>(); + return OS; +} + +template <> class SerializationTypeName<void> { +public: + static const char *getName() { return "void"; } +}; + +template <> class SerializationTypeName<int8_t> { +public: + static const char *getName() { return "int8_t"; } +}; + +template <> class SerializationTypeName<uint8_t> { +public: + static const char *getName() { return "uint8_t"; } +}; + +template <> class SerializationTypeName<int16_t> { +public: + static const char *getName() { return "int16_t"; } +}; + +template <> class SerializationTypeName<uint16_t> { +public: + static const char *getName() { return "uint16_t"; } +}; + +template <> class SerializationTypeName<int32_t> { +public: + static const char *getName() { return "int32_t"; } +}; + +template <> class SerializationTypeName<uint32_t> { +public: + static const char *getName() { return "uint32_t"; } +}; + +template <> class SerializationTypeName<int64_t> { +public: + static const char *getName() { return "int64_t"; } +}; + +template <> class SerializationTypeName<uint64_t> { +public: + static const char *getName() { return "uint64_t"; } +}; + +template <> class SerializationTypeName<bool> { +public: + static const char *getName() { return "bool"; } +}; + +template <> class SerializationTypeName<std::string> { +public: + static const char *getName() { return "std::string"; } +}; + +template <> class SerializationTypeName<Error> { +public: + static const char *getName() { return "Error"; } +}; + +template <typename T> class SerializationTypeName<Expected<T>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "Expected<" << SerializationTypeNameSequence<T>() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename T1, typename T2> +class SerializationTypeName<std::pair<T1, T2>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::pair<" << SerializationTypeNameSequence<T1, T2>() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename... ArgTs> class SerializationTypeName<std::tuple<ArgTs...>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::tuple<" << SerializationTypeNameSequence<ArgTs...>() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename T> class SerializationTypeName<Optional<T>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "Optional<" << SerializationTypeName<T>::getName() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename T> class SerializationTypeName<std::vector<T>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::vector<" << SerializationTypeName<T>::getName() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename T> class SerializationTypeName<std::set<T>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::set<" << SerializationTypeName<T>::getName() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename K, typename V> class SerializationTypeName<std::map<K, V>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::map<" << SerializationTypeNameSequence<K, V>() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +/// The SerializationTraits<ChannelT, T> class describes how to serialize and +/// deserialize an instance of type T to/from an abstract channel of type +/// ChannelT. It also provides a representation of the type's name via the +/// getName method. +/// +/// Specializations of this class should provide the following functions: +/// +/// @code{.cpp} +/// +/// static const char* getName(); +/// static Error serialize(ChannelT&, const T&); +/// static Error deserialize(ChannelT&, T&); +/// +/// @endcode +/// +/// The third argument of SerializationTraits is intended to support SFINAE. +/// E.g.: +/// +/// @code{.cpp} +/// +/// class MyVirtualChannel { ... }; +/// +/// template <DerivedChannelT> +/// class SerializationTraits<DerivedChannelT, bool, +/// std::enable_if_t< +/// std::is_base_of<VirtChannel, DerivedChannel>::value +/// >> { +/// public: +/// static const char* getName() { ... }; +/// } +/// +/// @endcode +template <typename ChannelT, typename WireType, + typename ConcreteType = WireType, typename = void> +class SerializationTraits; + +template <typename ChannelT> class SequenceTraits { +public: + static Error emitSeparator(ChannelT &C) { return Error::success(); } + static Error consumeSeparator(ChannelT &C) { return Error::success(); } +}; + +/// Utility class for serializing sequences of values of varying types. +/// Specializations of this class contain 'serialize' and 'deserialize' methods +/// for the given channel. The ArgTs... list will determine the "over-the-wire" +/// types to be serialized. The serialize and deserialize methods take a list +/// CArgTs... ("caller arg types") which must be the same length as ArgTs..., +/// but may be different types from ArgTs, provided that for each CArgT there +/// is a SerializationTraits specialization +/// SerializeTraits<ChannelT, ArgT, CArgT> with methods that can serialize the +/// caller argument to over-the-wire value. +template <typename ChannelT, typename... ArgTs> class SequenceSerialization; + +template <typename ChannelT> class SequenceSerialization<ChannelT> { +public: + static Error serialize(ChannelT &C) { return Error::success(); } + static Error deserialize(ChannelT &C) { return Error::success(); } +}; + +template <typename ChannelT, typename ArgT> +class SequenceSerialization<ChannelT, ArgT> { +public: + template <typename CArgT> static Error serialize(ChannelT &C, CArgT &&CArg) { + return SerializationTraits<ChannelT, ArgT, std::decay_t<CArgT>>::serialize( + C, std::forward<CArgT>(CArg)); + } + + template <typename CArgT> static Error deserialize(ChannelT &C, CArgT &CArg) { + return SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg); + } +}; + +template <typename ChannelT, typename ArgT, typename... ArgTs> +class SequenceSerialization<ChannelT, ArgT, ArgTs...> { +public: + template <typename CArgT, typename... CArgTs> + static Error serialize(ChannelT &C, CArgT &&CArg, CArgTs &&...CArgs) { + if (auto Err = + SerializationTraits<ChannelT, ArgT, std::decay_t<CArgT>>::serialize( + C, std::forward<CArgT>(CArg))) + return Err; + if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C)) + return Err; + return SequenceSerialization<ChannelT, ArgTs...>::serialize( + C, std::forward<CArgTs>(CArgs)...); + } + + template <typename CArgT, typename... CArgTs> + static Error deserialize(ChannelT &C, CArgT &CArg, CArgTs &...CArgs) { + if (auto Err = + SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg)) + return Err; + if (auto Err = SequenceTraits<ChannelT>::consumeSeparator(C)) + return Err; + return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, CArgs...); + } +}; + +template <typename ChannelT, typename... ArgTs> +Error serializeSeq(ChannelT &C, ArgTs &&...Args) { + return SequenceSerialization<ChannelT, std::decay_t<ArgTs>...>::serialize( + C, std::forward<ArgTs>(Args)...); +} + +template <typename ChannelT, typename... ArgTs> +Error deserializeSeq(ChannelT &C, ArgTs &...Args) { + return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...); +} + +template <typename ChannelT> class SerializationTraits<ChannelT, Error> { +public: + using WrappedErrorSerializer = + std::function<Error(ChannelT &C, const ErrorInfoBase &)>; + + using WrappedErrorDeserializer = + std::function<Error(ChannelT &C, Error &Err)>; + + template <typename ErrorInfoT, typename SerializeFtor, + typename DeserializeFtor> + static void registerErrorType(std::string Name, SerializeFtor Serialize, + DeserializeFtor Deserialize) { + assert(!Name.empty() && + "The empty string is reserved for the Success value"); + + const std::string *KeyName = nullptr; + { + // We're abusing the stability of std::map here: We take a reference to + // the key of the deserializers map to save us from duplicating the string + // in the serializer. This should be changed to use a stringpool if we + // switch to a map type that may move keys in memory. + std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); + auto I = Deserializers.insert( + Deserializers.begin(), + std::make_pair(std::move(Name), std::move(Deserialize))); + KeyName = &I->first; + } + + { + assert(KeyName != nullptr && "No keyname pointer"); + std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); + Serializers[ErrorInfoT::classID()] = + [KeyName, Serialize = std::move(Serialize)]( + ChannelT &C, const ErrorInfoBase &EIB) -> Error { + assert(EIB.dynamicClassID() == ErrorInfoT::classID() && + "Serializer called for wrong error type"); + if (auto Err = serializeSeq(C, *KeyName)) + return Err; + return Serialize(C, static_cast<const ErrorInfoT &>(EIB)); + }; + } + } + + static Error serialize(ChannelT &C, Error &&Err) { + std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); + + if (!Err) + return serializeSeq(C, std::string()); + + return handleErrors(std::move(Err), [&C](const ErrorInfoBase &EIB) { + auto SI = Serializers.find(EIB.dynamicClassID()); + if (SI == Serializers.end()) + return serializeAsStringError(C, EIB); + return (SI->second)(C, EIB); + }); + } + + static Error deserialize(ChannelT &C, Error &Err) { + std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); + + std::string Key; + if (auto Err = deserializeSeq(C, Key)) + return Err; + + if (Key.empty()) { + ErrorAsOutParameter EAO(&Err); + Err = Error::success(); + return Error::success(); + } + + auto DI = Deserializers.find(Key); + assert(DI != Deserializers.end() && "No deserializer for error type"); + return (DI->second)(C, Err); + } + +private: + static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) { + std::string ErrMsg; + { + raw_string_ostream ErrMsgStream(ErrMsg); + EIB.log(ErrMsgStream); + } + return serialize(C, make_error<StringError>(std::move(ErrMsg), + inconvertibleErrorCode())); + } + + static std::recursive_mutex SerializersMutex; + static std::recursive_mutex DeserializersMutex; + static std::map<const void *, WrappedErrorSerializer> Serializers; + static std::map<std::string, WrappedErrorDeserializer> Deserializers; +}; + +template <typename ChannelT> +std::recursive_mutex SerializationTraits<ChannelT, Error>::SerializersMutex; + +template <typename ChannelT> +std::recursive_mutex SerializationTraits<ChannelT, Error>::DeserializersMutex; + +template <typename ChannelT> +std::map<const void *, + typename SerializationTraits<ChannelT, Error>::WrappedErrorSerializer> + SerializationTraits<ChannelT, Error>::Serializers; + +template <typename ChannelT> +std::map<std::string, typename SerializationTraits< + ChannelT, Error>::WrappedErrorDeserializer> + SerializationTraits<ChannelT, Error>::Deserializers; + +/// Registers a serializer and deserializer for the given error type on the +/// given channel type. +template <typename ChannelT, typename ErrorInfoT, typename SerializeFtor, + typename DeserializeFtor> +void registerErrorSerialization(std::string Name, SerializeFtor &&Serialize, + DeserializeFtor &&Deserialize) { + SerializationTraits<ChannelT, Error>::template registerErrorType<ErrorInfoT>( + std::move(Name), std::forward<SerializeFtor>(Serialize), + std::forward<DeserializeFtor>(Deserialize)); +} + +/// Registers serialization/deserialization for StringError. +template <typename ChannelT> void registerStringError() { + static bool AlreadyRegistered = false; + if (!AlreadyRegistered) { + registerErrorSerialization<ChannelT, StringError>( + "StringError", + [](ChannelT &C, const StringError &SE) { + return serializeSeq(C, SE.getMessage()); + }, + [](ChannelT &C, Error &Err) -> Error { + ErrorAsOutParameter EAO(&Err); + std::string Msg; + if (auto E2 = deserializeSeq(C, Msg)) + return E2; + Err = make_error<StringError>( + std::move(Msg), + orcError(OrcErrorCode::UnknownErrorCodeFromRemote)); + return Error::success(); + }); + AlreadyRegistered = true; + } +} + +/// SerializationTraits for Expected<T1> from an Expected<T2>. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, Expected<T2>> { +public: + static Error serialize(ChannelT &C, Expected<T2> &&ValOrErr) { + if (ValOrErr) { + if (auto Err = serializeSeq(C, true)) + return Err; + return SerializationTraits<ChannelT, T1, T2>::serialize(C, *ValOrErr); + } + if (auto Err = serializeSeq(C, false)) + return Err; + return serializeSeq(C, ValOrErr.takeError()); + } + + static Error deserialize(ChannelT &C, Expected<T2> &ValOrErr) { + ExpectedAsOutParameter<T2> EAO(&ValOrErr); + bool HasValue; + if (auto Err = deserializeSeq(C, HasValue)) + return Err; + if (HasValue) + return SerializationTraits<ChannelT, T1, T2>::deserialize(C, *ValOrErr); + Error Err = Error::success(); + if (auto E2 = deserializeSeq(C, Err)) + return E2; + ValOrErr = std::move(Err); + return Error::success(); + } +}; + +/// SerializationTraits for Expected<T1> from a T2. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, T2> { +public: + static Error serialize(ChannelT &C, T2 &&Val) { + return serializeSeq(C, Expected<T2>(std::forward<T2>(Val))); + } +}; + +/// SerializationTraits for Expected<T1> from an Error. +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, Expected<T>, Error> { +public: + static Error serialize(ChannelT &C, Error &&Err) { + return serializeSeq(C, Expected<T>(std::move(Err))); + } +}; + +/// SerializationTraits default specialization for std::pair. +template <typename ChannelT, typename T1, typename T2, typename T3, typename T4> +class SerializationTraits<ChannelT, std::pair<T1, T2>, std::pair<T3, T4>> { +public: + static Error serialize(ChannelT &C, const std::pair<T3, T4> &V) { + if (auto Err = SerializationTraits<ChannelT, T1, T3>::serialize(C, V.first)) + return Err; + return SerializationTraits<ChannelT, T2, T4>::serialize(C, V.second); + } + + static Error deserialize(ChannelT &C, std::pair<T3, T4> &V) { + if (auto Err = + SerializationTraits<ChannelT, T1, T3>::deserialize(C, V.first)) + return Err; + return SerializationTraits<ChannelT, T2, T4>::deserialize(C, V.second); + } +}; + +/// SerializationTraits default specialization for std::tuple. +template <typename ChannelT, typename... ArgTs> +class SerializationTraits<ChannelT, std::tuple<ArgTs...>> { +public: + /// RPC channel serialization for std::tuple. + static Error serialize(ChannelT &C, const std::tuple<ArgTs...> &V) { + return serializeTupleHelper(C, V, std::index_sequence_for<ArgTs...>()); + } + + /// RPC channel deserialization for std::tuple. + static Error deserialize(ChannelT &C, std::tuple<ArgTs...> &V) { + return deserializeTupleHelper(C, V, std::index_sequence_for<ArgTs...>()); + } + +private: + // Serialization helper for std::tuple. + template <size_t... Is> + static Error serializeTupleHelper(ChannelT &C, const std::tuple<ArgTs...> &V, + std::index_sequence<Is...> _) { + return serializeSeq(C, std::get<Is>(V)...); + } + + // Serialization helper for std::tuple. + template <size_t... Is> + static Error deserializeTupleHelper(ChannelT &C, std::tuple<ArgTs...> &V, + std::index_sequence<Is...> _) { + return deserializeSeq(C, std::get<Is>(V)...); + } +}; + +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, Optional<T>> { +public: + /// Serialize an Optional<T>. + static Error serialize(ChannelT &C, const Optional<T> &O) { + if (auto Err = serializeSeq(C, O != None)) + return Err; + if (O) + if (auto Err = serializeSeq(C, *O)) + return Err; + return Error::success(); + } + + /// Deserialize an Optional<T>. + static Error deserialize(ChannelT &C, Optional<T> &O) { + bool HasValue = false; + if (auto Err = deserializeSeq(C, HasValue)) + return Err; + if (HasValue) + if (auto Err = deserializeSeq(C, *O)) + return Err; + return Error::success(); + }; +}; + +/// SerializationTraits default specialization for std::vector. +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, std::vector<T>> { +public: + /// Serialize a std::vector<T> from std::vector<T>. + static Error serialize(ChannelT &C, const std::vector<T> &V) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(V.size()))) + return Err; + + for (const auto &E : V) + if (auto Err = serializeSeq(C, E)) + return Err; + + return Error::success(); + } + + /// Deserialize a std::vector<T> to a std::vector<T>. + static Error deserialize(ChannelT &C, std::vector<T> &V) { + assert(V.empty() && + "Expected default-constructed vector to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + V.resize(Count); + for (auto &E : V) + if (auto Err = deserializeSeq(C, E)) + return Err; + + return Error::success(); + } +}; + +/// Enable vector serialization from an ArrayRef. +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, std::vector<T>, ArrayRef<T>> { +public: + static Error serialize(ChannelT &C, ArrayRef<T> V) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(V.size()))) + return Err; + + for (const auto &E : V) + if (auto Err = serializeSeq(C, E)) + return Err; + + return Error::success(); + } +}; + +template <typename ChannelT, typename T, typename T2> +class SerializationTraits<ChannelT, std::set<T>, std::set<T2>> { +public: + /// Serialize a std::set<T> from std::set<T2>. + static Error serialize(ChannelT &C, const std::set<T2> &S) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) + return Err; + + for (const auto &E : S) + if (auto Err = SerializationTraits<ChannelT, T, T2>::serialize(C, E)) + return Err; + + return Error::success(); + } + + /// Deserialize a std::set<T> to a std::set<T>. + static Error deserialize(ChannelT &C, std::set<T2> &S) { + assert(S.empty() && "Expected default-constructed set to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + while (Count-- != 0) { + T2 Val; + if (auto Err = SerializationTraits<ChannelT, T, T2>::deserialize(C, Val)) + return Err; + + auto Added = S.insert(Val).second; + if (!Added) + return make_error<StringError>("Duplicate element in deserialized set", + orcError(OrcErrorCode::UnknownORCError)); + } + + return Error::success(); + } +}; + +template <typename ChannelT, typename K, typename V, typename K2, typename V2> +class SerializationTraits<ChannelT, std::map<K, V>, std::map<K2, V2>> { +public: + /// Serialize a std::map<K, V> from std::map<K2, V2>. + static Error serialize(ChannelT &C, const std::map<K2, V2> &M) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(M.size()))) + return Err; + + for (const auto &E : M) { + if (auto Err = + SerializationTraits<ChannelT, K, K2>::serialize(C, E.first)) + return Err; + if (auto Err = + SerializationTraits<ChannelT, V, V2>::serialize(C, E.second)) + return Err; + } + + return Error::success(); + } + + /// Deserialize a std::map<K, V> to a std::map<K, V>. + static Error deserialize(ChannelT &C, std::map<K2, V2> &M) { + assert(M.empty() && "Expected default-constructed map to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + while (Count-- != 0) { + std::pair<K2, V2> Val; + if (auto Err = + SerializationTraits<ChannelT, K, K2>::deserialize(C, Val.first)) + return Err; + + if (auto Err = + SerializationTraits<ChannelT, V, V2>::deserialize(C, Val.second)) + return Err; + + auto Added = M.insert(Val).second; + if (!Added) + return make_error<StringError>("Duplicate element in deserialized map", + orcError(OrcErrorCode::UnknownORCError)); + } + + return Error::success(); + } +}; + +template <typename ChannelT, typename K, typename V, typename K2, typename V2> +class SerializationTraits<ChannelT, std::map<K, V>, DenseMap<K2, V2>> { +public: + /// Serialize a std::map<K, V> from DenseMap<K2, V2>. + static Error serialize(ChannelT &C, const DenseMap<K2, V2> &M) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(M.size()))) + return Err; + + for (auto &E : M) { + if (auto Err = + SerializationTraits<ChannelT, K, K2>::serialize(C, E.first)) + return Err; + + if (auto Err = + SerializationTraits<ChannelT, V, V2>::serialize(C, E.second)) + return Err; + } + + return Error::success(); + } + + /// Serialize a std::map<K, V> from DenseMap<K2, V2>. + static Error deserialize(ChannelT &C, DenseMap<K2, V2> &M) { + assert(M.empty() && "Expected default-constructed map to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + while (Count-- != 0) { + std::pair<K2, V2> Val; + if (auto Err = + SerializationTraits<ChannelT, K, K2>::deserialize(C, Val.first)) + return Err; + + if (auto Err = + SerializationTraits<ChannelT, V, V2>::deserialize(C, Val.second)) + return Err; + + auto Added = M.insert(Val).second; + if (!Added) + return make_error<StringError>("Duplicate element in deserialized map", + orcError(OrcErrorCode::UnknownORCError)); + } + + return Error::success(); + } +}; + +} // namespace shared +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_RPC_RPCSERIALIZATION_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h index c3dce579d7..9fc8dfaead 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h @@ -1,176 +1,176 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===--- TargetProcessControlTypes.h -- Shared Core/TPC types ---*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// TargetProcessControl types that are used by both the Orc and -// OrcTargetProcess libraries. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_TARGETPROCESSCONTROLTYPES_H -#define LLVM_EXECUTIONENGINE_ORC_SHARED_TARGETPROCESSCONTROLTYPES_H - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ExecutionEngine/JITSymbol.h" - -#include <vector> - -namespace llvm { -namespace orc { -namespace tpctypes { - -template <typename T> struct UIntWrite { - UIntWrite() = default; - UIntWrite(JITTargetAddress Address, T Value) - : Address(Address), Value(Value) {} - - JITTargetAddress Address = 0; - T Value = 0; -}; - -/// Describes a write to a uint8_t. -using UInt8Write = UIntWrite<uint8_t>; - -/// Describes a write to a uint16_t. -using UInt16Write = UIntWrite<uint16_t>; - -/// Describes a write to a uint32_t. -using UInt32Write = UIntWrite<uint32_t>; - -/// Describes a write to a uint64_t. -using UInt64Write = UIntWrite<uint64_t>; - -/// Describes a write to a buffer. -/// For use with TargetProcessControl::MemoryAccess objects. -struct BufferWrite { - BufferWrite() = default; - BufferWrite(JITTargetAddress Address, StringRef Buffer) - : Address(Address), Buffer(Buffer) {} - - JITTargetAddress Address = 0; - StringRef Buffer; -}; - -/// A handle used to represent a loaded dylib in the target process. -using DylibHandle = JITTargetAddress; - -using LookupResult = std::vector<JITTargetAddress>; - -/// Either a uint8_t array or a uint8_t*. -union CWrapperFunctionResultData { - uint8_t Value[8]; - uint8_t *ValuePtr; -}; - -/// C ABI compatible wrapper function result. -/// -/// This can be safely returned from extern "C" functions, but should be used -/// to construct a WrapperFunctionResult for safety. -struct CWrapperFunctionResult { - uint64_t Size; - CWrapperFunctionResultData Data; - void (*Destroy)(CWrapperFunctionResultData Data, uint64_t Size); -}; - -/// C++ wrapper function result: Same as CWrapperFunctionResult but -/// auto-releases memory. -class WrapperFunctionResult { -public: - /// Create a default WrapperFunctionResult. - WrapperFunctionResult() { zeroInit(R); } - - /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This - /// instance takes ownership of the result object and will automatically - /// call the Destroy member upon destruction. - WrapperFunctionResult(CWrapperFunctionResult R) : R(R) {} - - WrapperFunctionResult(const WrapperFunctionResult &) = delete; - WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; - - WrapperFunctionResult(WrapperFunctionResult &&Other) { - zeroInit(R); - std::swap(R, Other.R); - } - - WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { - CWrapperFunctionResult Tmp; - zeroInit(Tmp); - std::swap(Tmp, Other.R); - std::swap(R, Tmp); - return *this; - } - - ~WrapperFunctionResult() { - if (R.Destroy) - R.Destroy(R.Data, R.Size); - } - - /// Relinquish ownership of and return the CWrapperFunctionResult. - CWrapperFunctionResult release() { - CWrapperFunctionResult Tmp; - zeroInit(Tmp); - std::swap(R, Tmp); - return Tmp; - } - - /// Get an ArrayRef covering the data in the result. - ArrayRef<uint8_t> getData() const { - if (R.Size <= 8) - return ArrayRef<uint8_t>(R.Data.Value, R.Size); - return ArrayRef<uint8_t>(R.Data.ValuePtr, R.Size); - } - - /// Create a WrapperFunctionResult from the given integer, provided its - /// size is no greater than 64 bits. - template <typename T, - typename _ = std::enable_if_t<std::is_integral<T>::value && - sizeof(T) <= sizeof(uint64_t)>> - static WrapperFunctionResult from(T Value) { - CWrapperFunctionResult R; - R.Size = sizeof(T); - memcpy(&R.Data.Value, Value, R.Size); - R.Destroy = nullptr; - return R; - } - - /// Create a WrapperFunctionResult from the given string. - static WrapperFunctionResult from(StringRef S); - - /// Always free Data.ValuePtr by calling free on it. - static void destroyWithFree(CWrapperFunctionResultData Data, uint64_t Size); - - /// Always free Data.ValuePtr by calling delete[] on it. - static void destroyWithDeleteArray(CWrapperFunctionResultData Data, - uint64_t Size); - -private: - static void zeroInit(CWrapperFunctionResult &R) { - R.Size = 0; - R.Data.ValuePtr = nullptr; - R.Destroy = nullptr; - } - - CWrapperFunctionResult R; -}; - -} // end namespace tpctypes -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_TARGETPROCESSCONTROLTYPES_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===--- TargetProcessControlTypes.h -- Shared Core/TPC types ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TargetProcessControl types that are used by both the Orc and +// OrcTargetProcess libraries. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_TARGETPROCESSCONTROLTYPES_H +#define LLVM_EXECUTIONENGINE_ORC_SHARED_TARGETPROCESSCONTROLTYPES_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/JITSymbol.h" + +#include <vector> + +namespace llvm { +namespace orc { +namespace tpctypes { + +template <typename T> struct UIntWrite { + UIntWrite() = default; + UIntWrite(JITTargetAddress Address, T Value) + : Address(Address), Value(Value) {} + + JITTargetAddress Address = 0; + T Value = 0; +}; + +/// Describes a write to a uint8_t. +using UInt8Write = UIntWrite<uint8_t>; + +/// Describes a write to a uint16_t. +using UInt16Write = UIntWrite<uint16_t>; + +/// Describes a write to a uint32_t. +using UInt32Write = UIntWrite<uint32_t>; + +/// Describes a write to a uint64_t. +using UInt64Write = UIntWrite<uint64_t>; + +/// Describes a write to a buffer. +/// For use with TargetProcessControl::MemoryAccess objects. +struct BufferWrite { + BufferWrite() = default; + BufferWrite(JITTargetAddress Address, StringRef Buffer) + : Address(Address), Buffer(Buffer) {} + + JITTargetAddress Address = 0; + StringRef Buffer; +}; + +/// A handle used to represent a loaded dylib in the target process. +using DylibHandle = JITTargetAddress; + +using LookupResult = std::vector<JITTargetAddress>; + +/// Either a uint8_t array or a uint8_t*. +union CWrapperFunctionResultData { + uint8_t Value[8]; + uint8_t *ValuePtr; +}; + +/// C ABI compatible wrapper function result. +/// +/// This can be safely returned from extern "C" functions, but should be used +/// to construct a WrapperFunctionResult for safety. +struct CWrapperFunctionResult { + uint64_t Size; + CWrapperFunctionResultData Data; + void (*Destroy)(CWrapperFunctionResultData Data, uint64_t Size); +}; + +/// C++ wrapper function result: Same as CWrapperFunctionResult but +/// auto-releases memory. +class WrapperFunctionResult { +public: + /// Create a default WrapperFunctionResult. + WrapperFunctionResult() { zeroInit(R); } + + /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This + /// instance takes ownership of the result object and will automatically + /// call the Destroy member upon destruction. + WrapperFunctionResult(CWrapperFunctionResult R) : R(R) {} + + WrapperFunctionResult(const WrapperFunctionResult &) = delete; + WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; + + WrapperFunctionResult(WrapperFunctionResult &&Other) { + zeroInit(R); + std::swap(R, Other.R); + } + + WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { + CWrapperFunctionResult Tmp; + zeroInit(Tmp); + std::swap(Tmp, Other.R); + std::swap(R, Tmp); + return *this; + } + + ~WrapperFunctionResult() { + if (R.Destroy) + R.Destroy(R.Data, R.Size); + } + + /// Relinquish ownership of and return the CWrapperFunctionResult. + CWrapperFunctionResult release() { + CWrapperFunctionResult Tmp; + zeroInit(Tmp); + std::swap(R, Tmp); + return Tmp; + } + + /// Get an ArrayRef covering the data in the result. + ArrayRef<uint8_t> getData() const { + if (R.Size <= 8) + return ArrayRef<uint8_t>(R.Data.Value, R.Size); + return ArrayRef<uint8_t>(R.Data.ValuePtr, R.Size); + } + + /// Create a WrapperFunctionResult from the given integer, provided its + /// size is no greater than 64 bits. + template <typename T, + typename _ = std::enable_if_t<std::is_integral<T>::value && + sizeof(T) <= sizeof(uint64_t)>> + static WrapperFunctionResult from(T Value) { + CWrapperFunctionResult R; + R.Size = sizeof(T); + memcpy(&R.Data.Value, Value, R.Size); + R.Destroy = nullptr; + return R; + } + + /// Create a WrapperFunctionResult from the given string. + static WrapperFunctionResult from(StringRef S); + + /// Always free Data.ValuePtr by calling free on it. + static void destroyWithFree(CWrapperFunctionResultData Data, uint64_t Size); + + /// Always free Data.ValuePtr by calling delete[] on it. + static void destroyWithDeleteArray(CWrapperFunctionResultData Data, + uint64_t Size); + +private: + static void zeroInit(CWrapperFunctionResult &R) { + R.Size = 0; + R.Data.ValuePtr = nullptr; + R.Destroy = nullptr; + } + + CWrapperFunctionResult R; +}; + +} // end namespace tpctypes +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_TARGETPROCESSCONTROLTYPES_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Speculation.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Speculation.h index f4193ff075..b2b091fb89 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Speculation.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Speculation.h @@ -188,8 +188,8 @@ public: : IRLayer(ES, BaseLayer.getManglingOptions()), NextLayer(BaseLayer), S(Spec), Mangle(Mangle), QueryAnalysis(Interpreter) {} - void emit(std::unique_ptr<MaterializationResponsibility> R, - ThreadSafeModule TSM) override; + void emit(std::unique_ptr<MaterializationResponsibility> R, + ThreadSafeModule TSM) override; private: TargetAndLikelies diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TPCDynamicLibrarySearchGenerator.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TPCDynamicLibrarySearchGenerator.h index 8ef5804bf6..95d3c02a54 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TPCDynamicLibrarySearchGenerator.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TPCDynamicLibrarySearchGenerator.h @@ -1,77 +1,77 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===------------ TPCDynamicLibrarySearchGenerator.h ------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Support loading and searching of dynamic libraries in a target process via -// the TargetProcessControl class. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_TPCDYNAMICLIBRARYSEARCHGENERATOR_H -#define LLVM_EXECUTIONENGINE_ORC_TPCDYNAMICLIBRARYSEARCHGENERATOR_H - -#include "llvm/ADT/FunctionExtras.h" -#include "llvm/ExecutionEngine/Orc/TargetProcessControl.h" - -namespace llvm { -namespace orc { - -class TPCDynamicLibrarySearchGenerator : public DefinitionGenerator { -public: - using SymbolPredicate = unique_function<bool(const SymbolStringPtr &)>; - - /// Create a DynamicLibrarySearchGenerator that searches for symbols in the - /// library with the given handle. - /// - /// If the Allow predicate is given then only symbols matching the predicate - /// will be searched for. If the predicate is not given then all symbols will - /// be searched for. - TPCDynamicLibrarySearchGenerator(TargetProcessControl &TPC, - tpctypes::DylibHandle H, - SymbolPredicate Allow = SymbolPredicate()) - : TPC(TPC), H(H), Allow(std::move(Allow)) {} - - /// Permanently loads the library at the given path and, on success, returns - /// a DynamicLibrarySearchGenerator that will search it for symbol definitions - /// in the library. On failure returns the reason the library failed to load. - static Expected<std::unique_ptr<TPCDynamicLibrarySearchGenerator>> - Load(TargetProcessControl &TPC, const char *LibraryPath, - SymbolPredicate Allow = SymbolPredicate()); - - /// Creates a TPCDynamicLibrarySearchGenerator that searches for symbols in - /// the target process. - static Expected<std::unique_ptr<TPCDynamicLibrarySearchGenerator>> - GetForTargetProcess(TargetProcessControl &TPC, - SymbolPredicate Allow = SymbolPredicate()) { - return Load(TPC, nullptr, std::move(Allow)); - } - - Error tryToGenerate(LookupState &LS, LookupKind K, JITDylib &JD, - JITDylibLookupFlags JDLookupFlags, - const SymbolLookupSet &Symbols) override; - -private: - TargetProcessControl &TPC; - tpctypes::DylibHandle H; - SymbolPredicate Allow; -}; - -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_TPCDYNAMICLIBRARYSEARCHGENERATOR_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===------------ TPCDynamicLibrarySearchGenerator.h ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Support loading and searching of dynamic libraries in a target process via +// the TargetProcessControl class. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_TPCDYNAMICLIBRARYSEARCHGENERATOR_H +#define LLVM_EXECUTIONENGINE_ORC_TPCDYNAMICLIBRARYSEARCHGENERATOR_H + +#include "llvm/ADT/FunctionExtras.h" +#include "llvm/ExecutionEngine/Orc/TargetProcessControl.h" + +namespace llvm { +namespace orc { + +class TPCDynamicLibrarySearchGenerator : public DefinitionGenerator { +public: + using SymbolPredicate = unique_function<bool(const SymbolStringPtr &)>; + + /// Create a DynamicLibrarySearchGenerator that searches for symbols in the + /// library with the given handle. + /// + /// If the Allow predicate is given then only symbols matching the predicate + /// will be searched for. If the predicate is not given then all symbols will + /// be searched for. + TPCDynamicLibrarySearchGenerator(TargetProcessControl &TPC, + tpctypes::DylibHandle H, + SymbolPredicate Allow = SymbolPredicate()) + : TPC(TPC), H(H), Allow(std::move(Allow)) {} + + /// Permanently loads the library at the given path and, on success, returns + /// a DynamicLibrarySearchGenerator that will search it for symbol definitions + /// in the library. On failure returns the reason the library failed to load. + static Expected<std::unique_ptr<TPCDynamicLibrarySearchGenerator>> + Load(TargetProcessControl &TPC, const char *LibraryPath, + SymbolPredicate Allow = SymbolPredicate()); + + /// Creates a TPCDynamicLibrarySearchGenerator that searches for symbols in + /// the target process. + static Expected<std::unique_ptr<TPCDynamicLibrarySearchGenerator>> + GetForTargetProcess(TargetProcessControl &TPC, + SymbolPredicate Allow = SymbolPredicate()) { + return Load(TPC, nullptr, std::move(Allow)); + } + + Error tryToGenerate(LookupState &LS, LookupKind K, JITDylib &JD, + JITDylibLookupFlags JDLookupFlags, + const SymbolLookupSet &Symbols) override; + +private: + TargetProcessControl &TPC; + tpctypes::DylibHandle H; + SymbolPredicate Allow; +}; + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_TPCDYNAMICLIBRARYSEARCHGENERATOR_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TPCEHFrameRegistrar.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TPCEHFrameRegistrar.h index 15716d9ff0..9f9eb8420b 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TPCEHFrameRegistrar.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TPCEHFrameRegistrar.h @@ -1,65 +1,65 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===-- TPCEHFrameRegistrar.h - TPC based eh-frame registration -*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// TargetProcessControl based eh-frame registration. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_TPCEHFRAMEREGISTRAR_H -#define LLVM_EXECUTIONENGINE_ORC_TPCEHFRAMEREGISTRAR_H - -#include "llvm/ExecutionEngine/JITLink/EHFrameSupport.h" -#include "llvm/ExecutionEngine/Orc/TargetProcessControl.h" - -namespace llvm { -namespace orc { - -/// Register/Deregisters EH frames in a remote process via a -/// TargetProcessControl instance. -class TPCEHFrameRegistrar : public jitlink::EHFrameRegistrar { -public: - /// Create from a TargetProcessControl instance alone. This will use - /// the TPC's lookupSymbols method to find the registration/deregistration - /// funciton addresses by name. - static Expected<std::unique_ptr<TPCEHFrameRegistrar>> - Create(TargetProcessControl &TPC); - - /// Create a TPCEHFrameRegistrar with the given TargetProcessControl - /// object and registration/deregistration function addresses. - TPCEHFrameRegistrar(TargetProcessControl &TPC, - JITTargetAddress RegisterEHFrameWrapperFnAddr, - JITTargetAddress DeregisterEHFRameWrapperFnAddr) - : TPC(TPC), RegisterEHFrameWrapperFnAddr(RegisterEHFrameWrapperFnAddr), - DeregisterEHFrameWrapperFnAddr(DeregisterEHFRameWrapperFnAddr) {} - - Error registerEHFrames(JITTargetAddress EHFrameSectionAddr, - size_t EHFrameSectionSize) override; - Error deregisterEHFrames(JITTargetAddress EHFrameSectionAddr, - size_t EHFrameSectionSize) override; - -private: - TargetProcessControl &TPC; - JITTargetAddress RegisterEHFrameWrapperFnAddr; - JITTargetAddress DeregisterEHFrameWrapperFnAddr; -}; - -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_TPCEHFRAMEREGISTRAR_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===-- TPCEHFrameRegistrar.h - TPC based eh-frame registration -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TargetProcessControl based eh-frame registration. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_TPCEHFRAMEREGISTRAR_H +#define LLVM_EXECUTIONENGINE_ORC_TPCEHFRAMEREGISTRAR_H + +#include "llvm/ExecutionEngine/JITLink/EHFrameSupport.h" +#include "llvm/ExecutionEngine/Orc/TargetProcessControl.h" + +namespace llvm { +namespace orc { + +/// Register/Deregisters EH frames in a remote process via a +/// TargetProcessControl instance. +class TPCEHFrameRegistrar : public jitlink::EHFrameRegistrar { +public: + /// Create from a TargetProcessControl instance alone. This will use + /// the TPC's lookupSymbols method to find the registration/deregistration + /// funciton addresses by name. + static Expected<std::unique_ptr<TPCEHFrameRegistrar>> + Create(TargetProcessControl &TPC); + + /// Create a TPCEHFrameRegistrar with the given TargetProcessControl + /// object and registration/deregistration function addresses. + TPCEHFrameRegistrar(TargetProcessControl &TPC, + JITTargetAddress RegisterEHFrameWrapperFnAddr, + JITTargetAddress DeregisterEHFRameWrapperFnAddr) + : TPC(TPC), RegisterEHFrameWrapperFnAddr(RegisterEHFrameWrapperFnAddr), + DeregisterEHFrameWrapperFnAddr(DeregisterEHFRameWrapperFnAddr) {} + + Error registerEHFrames(JITTargetAddress EHFrameSectionAddr, + size_t EHFrameSectionSize) override; + Error deregisterEHFrames(JITTargetAddress EHFrameSectionAddr, + size_t EHFrameSectionSize) override; + +private: + TargetProcessControl &TPC; + JITTargetAddress RegisterEHFrameWrapperFnAddr; + JITTargetAddress DeregisterEHFrameWrapperFnAddr; +}; + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_TPCEHFRAMEREGISTRAR_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TPCIndirectionUtils.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TPCIndirectionUtils.h index 0d4de7bfe4..30643b4eaa 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TPCIndirectionUtils.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TPCIndirectionUtils.h @@ -1,233 +1,233 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===--- TPCIndirectionUtils.h - TPC based indirection utils ----*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Indirection utilities (stubs, trampolines, lazy call-throughs) that use the -// TargetProcessControl API to interact with the target process. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_TPCINDIRECTIONUTILS_H -#define LLVM_EXECUTIONENGINE_ORC_TPCINDIRECTIONUTILS_H - -#include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h" -#include "llvm/ExecutionEngine/Orc/IndirectionUtils.h" -#include "llvm/ExecutionEngine/Orc/LazyReexports.h" - -#include <mutex> - -namespace llvm { -namespace orc { - -class TargetProcessControl; - -/// Provides TargetProcessControl based indirect stubs, trampoline pool and -/// lazy call through manager. -class TPCIndirectionUtils { - friend class TPCIndirectionUtilsAccess; - -public: - /// ABI support base class. Used to write resolver, stub, and trampoline - /// blocks. - class ABISupport { - protected: - ABISupport(unsigned PointerSize, unsigned TrampolineSize, unsigned StubSize, - unsigned StubToPointerMaxDisplacement, unsigned ResolverCodeSize) - : PointerSize(PointerSize), TrampolineSize(TrampolineSize), - StubSize(StubSize), - StubToPointerMaxDisplacement(StubToPointerMaxDisplacement), - ResolverCodeSize(ResolverCodeSize) {} - - public: - virtual ~ABISupport(); - - unsigned getPointerSize() const { return PointerSize; } - unsigned getTrampolineSize() const { return TrampolineSize; } - unsigned getStubSize() const { return StubSize; } - unsigned getStubToPointerMaxDisplacement() const { - return StubToPointerMaxDisplacement; - } - unsigned getResolverCodeSize() const { return ResolverCodeSize; } - - virtual void writeResolverCode(char *ResolverWorkingMem, - JITTargetAddress ResolverTargetAddr, - JITTargetAddress ReentryFnAddr, - JITTargetAddress ReentryCtxAddr) const = 0; - - virtual void writeTrampolines(char *TrampolineBlockWorkingMem, - JITTargetAddress TrampolineBlockTragetAddr, - JITTargetAddress ResolverAddr, - unsigned NumTrampolines) const = 0; - - virtual void - writeIndirectStubsBlock(char *StubsBlockWorkingMem, - JITTargetAddress StubsBlockTargetAddress, - JITTargetAddress PointersBlockTargetAddress, - unsigned NumStubs) const = 0; - - private: - unsigned PointerSize = 0; - unsigned TrampolineSize = 0; - unsigned StubSize = 0; - unsigned StubToPointerMaxDisplacement = 0; - unsigned ResolverCodeSize = 0; - }; - - /// Create using the given ABI class. - template <typename ORCABI> - static std::unique_ptr<TPCIndirectionUtils> - CreateWithABI(TargetProcessControl &TPC); - - /// Create based on the TargetProcessControl triple. - static Expected<std::unique_ptr<TPCIndirectionUtils>> - Create(TargetProcessControl &TPC); - - /// Return a reference to the TargetProcessControl object. - TargetProcessControl &getTargetProcessControl() const { return TPC; } - - /// Return a reference to the ABISupport object for this instance. - ABISupport &getABISupport() const { return *ABI; } - - /// Release memory for resources held by this instance. This *must* be called - /// prior to destruction of the class. - Error cleanup(); - - /// Write resolver code to the target process and return its address. - /// This must be called before any call to createTrampolinePool or - /// createLazyCallThroughManager. - Expected<JITTargetAddress> - writeResolverBlock(JITTargetAddress ReentryFnAddr, - JITTargetAddress ReentryCtxAddr); - - /// Returns the address of the Resolver block. Returns zero if the - /// writeResolverBlock method has not previously been called. - JITTargetAddress getResolverBlockAddress() const { return ResolverBlockAddr; } - - /// Create an IndirectStubsManager for the target process. - std::unique_ptr<IndirectStubsManager> createIndirectStubsManager(); - - /// Create a TrampolinePool for the target process. - TrampolinePool &getTrampolinePool(); - - /// Create a LazyCallThroughManager. - /// This function should only be called once. - LazyCallThroughManager & - createLazyCallThroughManager(ExecutionSession &ES, - JITTargetAddress ErrorHandlerAddr); - - /// Create a LazyCallThroughManager for the target process. - LazyCallThroughManager &getLazyCallThroughManager() { - assert(LCTM && "createLazyCallThroughManager must be called first"); - return *LCTM; - } - -private: - using Allocation = jitlink::JITLinkMemoryManager::Allocation; - - struct IndirectStubInfo { - IndirectStubInfo() = default; - IndirectStubInfo(JITTargetAddress StubAddress, - JITTargetAddress PointerAddress) - : StubAddress(StubAddress), PointerAddress(PointerAddress) {} - JITTargetAddress StubAddress = 0; - JITTargetAddress PointerAddress = 0; - }; - - using IndirectStubInfoVector = std::vector<IndirectStubInfo>; - - /// Create a TPCIndirectionUtils instance. - TPCIndirectionUtils(TargetProcessControl &TPC, - std::unique_ptr<ABISupport> ABI); - - Expected<IndirectStubInfoVector> getIndirectStubs(unsigned NumStubs); - - std::mutex TPCUIMutex; - TargetProcessControl &TPC; - std::unique_ptr<ABISupport> ABI; - JITTargetAddress ResolverBlockAddr; - std::unique_ptr<jitlink::JITLinkMemoryManager::Allocation> ResolverBlock; - std::unique_ptr<TrampolinePool> TP; - std::unique_ptr<LazyCallThroughManager> LCTM; - - std::vector<IndirectStubInfo> AvailableIndirectStubs; - std::vector<std::unique_ptr<Allocation>> IndirectStubAllocs; -}; - -/// This will call writeResolver on the given TPCIndirectionUtils instance -/// to set up re-entry via a function that will directly return the trampoline -/// landing address. -/// -/// The TPCIndirectionUtils' LazyCallThroughManager must have been previously -/// created via TPCIndirectionUtils::createLazyCallThroughManager. -/// -/// The TPCIndirectionUtils' writeResolver method must not have been previously -/// called. -/// -/// This function is experimental and likely subject to revision. -Error setUpInProcessLCTMReentryViaTPCIU(TPCIndirectionUtils &TPCIU); - -namespace detail { - -template <typename ORCABI> -class ABISupportImpl : public TPCIndirectionUtils::ABISupport { -public: - ABISupportImpl() - : ABISupport(ORCABI::PointerSize, ORCABI::TrampolineSize, - ORCABI::StubSize, ORCABI::StubToPointerMaxDisplacement, - ORCABI::ResolverCodeSize) {} - - void writeResolverCode(char *ResolverWorkingMem, - JITTargetAddress ResolverTargetAddr, - JITTargetAddress ReentryFnAddr, - JITTargetAddress ReentryCtxAddr) const override { - ORCABI::writeResolverCode(ResolverWorkingMem, ResolverTargetAddr, - ReentryFnAddr, ReentryCtxAddr); - } - - void writeTrampolines(char *TrampolineBlockWorkingMem, - JITTargetAddress TrampolineBlockTargetAddr, - JITTargetAddress ResolverAddr, - unsigned NumTrampolines) const override { - ORCABI::writeTrampolines(TrampolineBlockWorkingMem, - TrampolineBlockTargetAddr, ResolverAddr, - NumTrampolines); - } - - void writeIndirectStubsBlock(char *StubsBlockWorkingMem, - JITTargetAddress StubsBlockTargetAddress, - JITTargetAddress PointersBlockTargetAddress, - unsigned NumStubs) const override { - ORCABI::writeIndirectStubsBlock(StubsBlockWorkingMem, - StubsBlockTargetAddress, - PointersBlockTargetAddress, NumStubs); - } -}; - -} // end namespace detail - -template <typename ORCABI> -std::unique_ptr<TPCIndirectionUtils> -TPCIndirectionUtils::CreateWithABI(TargetProcessControl &TPC) { - return std::unique_ptr<TPCIndirectionUtils>(new TPCIndirectionUtils( - TPC, std::make_unique<detail::ABISupportImpl<ORCABI>>())); -} - -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_TPCINDIRECTIONUTILS_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===--- TPCIndirectionUtils.h - TPC based indirection utils ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Indirection utilities (stubs, trampolines, lazy call-throughs) that use the +// TargetProcessControl API to interact with the target process. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_TPCINDIRECTIONUTILS_H +#define LLVM_EXECUTIONENGINE_ORC_TPCINDIRECTIONUTILS_H + +#include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h" +#include "llvm/ExecutionEngine/Orc/IndirectionUtils.h" +#include "llvm/ExecutionEngine/Orc/LazyReexports.h" + +#include <mutex> + +namespace llvm { +namespace orc { + +class TargetProcessControl; + +/// Provides TargetProcessControl based indirect stubs, trampoline pool and +/// lazy call through manager. +class TPCIndirectionUtils { + friend class TPCIndirectionUtilsAccess; + +public: + /// ABI support base class. Used to write resolver, stub, and trampoline + /// blocks. + class ABISupport { + protected: + ABISupport(unsigned PointerSize, unsigned TrampolineSize, unsigned StubSize, + unsigned StubToPointerMaxDisplacement, unsigned ResolverCodeSize) + : PointerSize(PointerSize), TrampolineSize(TrampolineSize), + StubSize(StubSize), + StubToPointerMaxDisplacement(StubToPointerMaxDisplacement), + ResolverCodeSize(ResolverCodeSize) {} + + public: + virtual ~ABISupport(); + + unsigned getPointerSize() const { return PointerSize; } + unsigned getTrampolineSize() const { return TrampolineSize; } + unsigned getStubSize() const { return StubSize; } + unsigned getStubToPointerMaxDisplacement() const { + return StubToPointerMaxDisplacement; + } + unsigned getResolverCodeSize() const { return ResolverCodeSize; } + + virtual void writeResolverCode(char *ResolverWorkingMem, + JITTargetAddress ResolverTargetAddr, + JITTargetAddress ReentryFnAddr, + JITTargetAddress ReentryCtxAddr) const = 0; + + virtual void writeTrampolines(char *TrampolineBlockWorkingMem, + JITTargetAddress TrampolineBlockTragetAddr, + JITTargetAddress ResolverAddr, + unsigned NumTrampolines) const = 0; + + virtual void + writeIndirectStubsBlock(char *StubsBlockWorkingMem, + JITTargetAddress StubsBlockTargetAddress, + JITTargetAddress PointersBlockTargetAddress, + unsigned NumStubs) const = 0; + + private: + unsigned PointerSize = 0; + unsigned TrampolineSize = 0; + unsigned StubSize = 0; + unsigned StubToPointerMaxDisplacement = 0; + unsigned ResolverCodeSize = 0; + }; + + /// Create using the given ABI class. + template <typename ORCABI> + static std::unique_ptr<TPCIndirectionUtils> + CreateWithABI(TargetProcessControl &TPC); + + /// Create based on the TargetProcessControl triple. + static Expected<std::unique_ptr<TPCIndirectionUtils>> + Create(TargetProcessControl &TPC); + + /// Return a reference to the TargetProcessControl object. + TargetProcessControl &getTargetProcessControl() const { return TPC; } + + /// Return a reference to the ABISupport object for this instance. + ABISupport &getABISupport() const { return *ABI; } + + /// Release memory for resources held by this instance. This *must* be called + /// prior to destruction of the class. + Error cleanup(); + + /// Write resolver code to the target process and return its address. + /// This must be called before any call to createTrampolinePool or + /// createLazyCallThroughManager. + Expected<JITTargetAddress> + writeResolverBlock(JITTargetAddress ReentryFnAddr, + JITTargetAddress ReentryCtxAddr); + + /// Returns the address of the Resolver block. Returns zero if the + /// writeResolverBlock method has not previously been called. + JITTargetAddress getResolverBlockAddress() const { return ResolverBlockAddr; } + + /// Create an IndirectStubsManager for the target process. + std::unique_ptr<IndirectStubsManager> createIndirectStubsManager(); + + /// Create a TrampolinePool for the target process. + TrampolinePool &getTrampolinePool(); + + /// Create a LazyCallThroughManager. + /// This function should only be called once. + LazyCallThroughManager & + createLazyCallThroughManager(ExecutionSession &ES, + JITTargetAddress ErrorHandlerAddr); + + /// Create a LazyCallThroughManager for the target process. + LazyCallThroughManager &getLazyCallThroughManager() { + assert(LCTM && "createLazyCallThroughManager must be called first"); + return *LCTM; + } + +private: + using Allocation = jitlink::JITLinkMemoryManager::Allocation; + + struct IndirectStubInfo { + IndirectStubInfo() = default; + IndirectStubInfo(JITTargetAddress StubAddress, + JITTargetAddress PointerAddress) + : StubAddress(StubAddress), PointerAddress(PointerAddress) {} + JITTargetAddress StubAddress = 0; + JITTargetAddress PointerAddress = 0; + }; + + using IndirectStubInfoVector = std::vector<IndirectStubInfo>; + + /// Create a TPCIndirectionUtils instance. + TPCIndirectionUtils(TargetProcessControl &TPC, + std::unique_ptr<ABISupport> ABI); + + Expected<IndirectStubInfoVector> getIndirectStubs(unsigned NumStubs); + + std::mutex TPCUIMutex; + TargetProcessControl &TPC; + std::unique_ptr<ABISupport> ABI; + JITTargetAddress ResolverBlockAddr; + std::unique_ptr<jitlink::JITLinkMemoryManager::Allocation> ResolverBlock; + std::unique_ptr<TrampolinePool> TP; + std::unique_ptr<LazyCallThroughManager> LCTM; + + std::vector<IndirectStubInfo> AvailableIndirectStubs; + std::vector<std::unique_ptr<Allocation>> IndirectStubAllocs; +}; + +/// This will call writeResolver on the given TPCIndirectionUtils instance +/// to set up re-entry via a function that will directly return the trampoline +/// landing address. +/// +/// The TPCIndirectionUtils' LazyCallThroughManager must have been previously +/// created via TPCIndirectionUtils::createLazyCallThroughManager. +/// +/// The TPCIndirectionUtils' writeResolver method must not have been previously +/// called. +/// +/// This function is experimental and likely subject to revision. +Error setUpInProcessLCTMReentryViaTPCIU(TPCIndirectionUtils &TPCIU); + +namespace detail { + +template <typename ORCABI> +class ABISupportImpl : public TPCIndirectionUtils::ABISupport { +public: + ABISupportImpl() + : ABISupport(ORCABI::PointerSize, ORCABI::TrampolineSize, + ORCABI::StubSize, ORCABI::StubToPointerMaxDisplacement, + ORCABI::ResolverCodeSize) {} + + void writeResolverCode(char *ResolverWorkingMem, + JITTargetAddress ResolverTargetAddr, + JITTargetAddress ReentryFnAddr, + JITTargetAddress ReentryCtxAddr) const override { + ORCABI::writeResolverCode(ResolverWorkingMem, ResolverTargetAddr, + ReentryFnAddr, ReentryCtxAddr); + } + + void writeTrampolines(char *TrampolineBlockWorkingMem, + JITTargetAddress TrampolineBlockTargetAddr, + JITTargetAddress ResolverAddr, + unsigned NumTrampolines) const override { + ORCABI::writeTrampolines(TrampolineBlockWorkingMem, + TrampolineBlockTargetAddr, ResolverAddr, + NumTrampolines); + } + + void writeIndirectStubsBlock(char *StubsBlockWorkingMem, + JITTargetAddress StubsBlockTargetAddress, + JITTargetAddress PointersBlockTargetAddress, + unsigned NumStubs) const override { + ORCABI::writeIndirectStubsBlock(StubsBlockWorkingMem, + StubsBlockTargetAddress, + PointersBlockTargetAddress, NumStubs); + } +}; + +} // end namespace detail + +template <typename ORCABI> +std::unique_ptr<TPCIndirectionUtils> +TPCIndirectionUtils::CreateWithABI(TargetProcessControl &TPC) { + return std::unique_ptr<TPCIndirectionUtils>(new TPCIndirectionUtils( + TPC, std::make_unique<detail::ABISupportImpl<ORCABI>>())); +} + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_TPCINDIRECTIONUTILS_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcess/OrcRPCTPCServer.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcess/OrcRPCTPCServer.h index 763b08fbcd..f4cf20aac3 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcess/OrcRPCTPCServer.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcess/OrcRPCTPCServer.h @@ -1,631 +1,631 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===-- OrcRPCTPCServer.h -- OrcRPCTargetProcessControl Server --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// OrcRPCTargetProcessControl server class. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_ORCRPCTPCSERVER_H -#define LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_ORCRPCTPCSERVER_H - -#include "llvm/ADT/BitmaskEnum.h" -#include "llvm/ExecutionEngine/Orc/Shared/RPCUtils.h" -#include "llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h" -#include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h" -#include "llvm/ExecutionEngine/Orc/TargetProcess/TargetExecutionUtils.h" -#include "llvm/ExecutionEngine/Orc/TargetProcessControl.h" -#include "llvm/Support/DynamicLibrary.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/Host.h" -#include "llvm/Support/MathExtras.h" -#include "llvm/Support/Memory.h" -#include "llvm/Support/Process.h" - -#include <atomic> - -namespace llvm { -namespace orc { - -namespace orcrpctpc { - -enum WireProtectionFlags : uint8_t { - WPF_None = 0, - WPF_Read = 1U << 0, - WPF_Write = 1U << 1, - WPF_Exec = 1U << 2, - LLVM_MARK_AS_BITMASK_ENUM(WPF_Exec) -}; - -/// Convert from sys::Memory::ProtectionFlags -inline WireProtectionFlags -toWireProtectionFlags(sys::Memory::ProtectionFlags PF) { - WireProtectionFlags WPF = WPF_None; - if (PF & sys::Memory::MF_READ) - WPF |= WPF_Read; - if (PF & sys::Memory::MF_WRITE) - WPF |= WPF_Write; - if (PF & sys::Memory::MF_EXEC) - WPF |= WPF_Exec; - return WPF; -} - -inline sys::Memory::ProtectionFlags -fromWireProtectionFlags(WireProtectionFlags WPF) { - int PF = 0; - if (WPF & WPF_Read) - PF |= sys::Memory::MF_READ; - if (WPF & WPF_Write) - PF |= sys::Memory::MF_WRITE; - if (WPF & WPF_Exec) - PF |= sys::Memory::MF_EXEC; - return static_cast<sys::Memory::ProtectionFlags>(PF); -} - -struct ReserveMemRequestElement { - WireProtectionFlags Prot = WPF_None; - uint64_t Size = 0; - uint64_t Alignment = 0; -}; - -using ReserveMemRequest = std::vector<ReserveMemRequestElement>; - -struct ReserveMemResultElement { - WireProtectionFlags Prot = WPF_None; - JITTargetAddress Address = 0; - uint64_t AllocatedSize = 0; -}; - -using ReserveMemResult = std::vector<ReserveMemResultElement>; - -struct ReleaseOrFinalizeMemRequestElement { - WireProtectionFlags Prot = WPF_None; - JITTargetAddress Address = 0; - uint64_t Size = 0; -}; - -using ReleaseOrFinalizeMemRequest = - std::vector<ReleaseOrFinalizeMemRequestElement>; - -} // end namespace orcrpctpc - -namespace shared { - -template <> class SerializationTypeName<tpctypes::UInt8Write> { -public: - static const char *getName() { return "UInt8Write"; } -}; - -template <> class SerializationTypeName<tpctypes::UInt16Write> { -public: - static const char *getName() { return "UInt16Write"; } -}; - -template <> class SerializationTypeName<tpctypes::UInt32Write> { -public: - static const char *getName() { return "UInt32Write"; } -}; - -template <> class SerializationTypeName<tpctypes::UInt64Write> { -public: - static const char *getName() { return "UInt64Write"; } -}; - -template <> class SerializationTypeName<tpctypes::BufferWrite> { -public: - static const char *getName() { return "BufferWrite"; } -}; - -template <> class SerializationTypeName<orcrpctpc::ReserveMemRequestElement> { -public: - static const char *getName() { return "ReserveMemRequestElement"; } -}; - -template <> class SerializationTypeName<orcrpctpc::ReserveMemResultElement> { -public: - static const char *getName() { return "ReserveMemResultElement"; } -}; - -template <> -class SerializationTypeName<orcrpctpc::ReleaseOrFinalizeMemRequestElement> { -public: - static const char *getName() { return "ReleaseOrFinalizeMemRequestElement"; } -}; - -template <> class SerializationTypeName<tpctypes::WrapperFunctionResult> { -public: - static const char *getName() { return "WrapperFunctionResult"; } -}; - -template <typename ChannelT, typename WriteT> -class SerializationTraits< - ChannelT, WriteT, WriteT, - std::enable_if_t<std::is_same<WriteT, tpctypes::UInt8Write>::value || - std::is_same<WriteT, tpctypes::UInt16Write>::value || - std::is_same<WriteT, tpctypes::UInt32Write>::value || - std::is_same<WriteT, tpctypes::UInt64Write>::value>> { -public: - static Error serialize(ChannelT &C, const WriteT &W) { - return serializeSeq(C, W.Address, W.Value); - } - static Error deserialize(ChannelT &C, WriteT &W) { - return deserializeSeq(C, W.Address, W.Value); - } -}; - -template <typename ChannelT> -class SerializationTraits< - ChannelT, tpctypes::BufferWrite, tpctypes::BufferWrite, - std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { -public: - static Error serialize(ChannelT &C, const tpctypes::BufferWrite &W) { - uint64_t Size = W.Buffer.size(); - if (auto Err = serializeSeq(C, W.Address, Size)) - return Err; - - return C.appendBytes(W.Buffer.data(), Size); - } - static Error deserialize(ChannelT &C, tpctypes::BufferWrite &W) { - JITTargetAddress Address; - uint64_t Size; - - if (auto Err = deserializeSeq(C, Address, Size)) - return Err; - - char *Buffer = jitTargetAddressToPointer<char *>(Address); - - if (auto Err = C.readBytes(Buffer, Size)) - return Err; - - W = {Address, StringRef(Buffer, Size)}; - return Error::success(); - } -}; - -template <typename ChannelT> -class SerializationTraits<ChannelT, orcrpctpc::ReserveMemRequestElement> { -public: - static Error serialize(ChannelT &C, - const orcrpctpc::ReserveMemRequestElement &E) { - return serializeSeq(C, static_cast<uint8_t>(E.Prot), E.Size, E.Alignment); - } - - static Error deserialize(ChannelT &C, - orcrpctpc::ReserveMemRequestElement &E) { - return deserializeSeq(C, *reinterpret_cast<uint8_t *>(&E.Prot), E.Size, - E.Alignment); - } -}; - -template <typename ChannelT> -class SerializationTraits<ChannelT, orcrpctpc::ReserveMemResultElement> { -public: - static Error serialize(ChannelT &C, - const orcrpctpc::ReserveMemResultElement &E) { - return serializeSeq(C, static_cast<uint8_t>(E.Prot), E.Address, - E.AllocatedSize); - } - - static Error deserialize(ChannelT &C, orcrpctpc::ReserveMemResultElement &E) { - return deserializeSeq(C, *reinterpret_cast<uint8_t *>(&E.Prot), E.Address, - E.AllocatedSize); - } -}; - -template <typename ChannelT> -class SerializationTraits<ChannelT, - orcrpctpc::ReleaseOrFinalizeMemRequestElement> { -public: - static Error - serialize(ChannelT &C, - const orcrpctpc::ReleaseOrFinalizeMemRequestElement &E) { - return serializeSeq(C, static_cast<uint8_t>(E.Prot), E.Address, E.Size); - } - - static Error deserialize(ChannelT &C, - orcrpctpc::ReleaseOrFinalizeMemRequestElement &E) { - return deserializeSeq(C, *reinterpret_cast<uint8_t *>(&E.Prot), E.Address, - E.Size); - } -}; - -template <typename ChannelT> -class SerializationTraits< - ChannelT, tpctypes::WrapperFunctionResult, tpctypes::WrapperFunctionResult, - std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { -public: - static Error serialize(ChannelT &C, - const tpctypes::WrapperFunctionResult &E) { - auto Data = E.getData(); - if (auto Err = serializeSeq(C, static_cast<uint64_t>(Data.size()))) - return Err; - if (Data.size() == 0) - return Error::success(); - return C.appendBytes(reinterpret_cast<const char *>(Data.data()), - Data.size()); - } - - static Error deserialize(ChannelT &C, tpctypes::WrapperFunctionResult &E) { - tpctypes::CWrapperFunctionResult R; - - R.Size = 0; - R.Data.ValuePtr = nullptr; - R.Destroy = nullptr; - - if (auto Err = deserializeSeq(C, R.Size)) - return Err; - if (R.Size == 0) - return Error::success(); - R.Data.ValuePtr = new uint8_t[R.Size]; - if (auto Err = - C.readBytes(reinterpret_cast<char *>(R.Data.ValuePtr), R.Size)) { - R.Destroy = tpctypes::WrapperFunctionResult::destroyWithDeleteArray; - return Err; - } - - E = tpctypes::WrapperFunctionResult(R); - return Error::success(); - } -}; - -} // end namespace shared - -namespace orcrpctpc { - -using RemoteSymbolLookupSet = std::vector<std::pair<std::string, bool>>; -using RemoteLookupRequest = - std::pair<tpctypes::DylibHandle, RemoteSymbolLookupSet>; - -class GetTargetTriple - : public shared::RPCFunction<GetTargetTriple, std::string()> { -public: - static const char *getName() { return "GetTargetTriple"; } -}; - -class GetPageSize : public shared::RPCFunction<GetPageSize, uint64_t()> { -public: - static const char *getName() { return "GetPageSize"; } -}; - -class ReserveMem - : public shared::RPCFunction<ReserveMem, Expected<ReserveMemResult>( - ReserveMemRequest)> { -public: - static const char *getName() { return "ReserveMem"; } -}; - -class FinalizeMem - : public shared::RPCFunction<FinalizeMem, - Error(ReleaseOrFinalizeMemRequest)> { -public: - static const char *getName() { return "FinalizeMem"; } -}; - -class ReleaseMem - : public shared::RPCFunction<ReleaseMem, - Error(ReleaseOrFinalizeMemRequest)> { -public: - static const char *getName() { return "ReleaseMem"; } -}; - -class WriteUInt8s - : public shared::RPCFunction<WriteUInt8s, - Error(std::vector<tpctypes::UInt8Write>)> { -public: - static const char *getName() { return "WriteUInt8s"; } -}; - -class WriteUInt16s - : public shared::RPCFunction<WriteUInt16s, - Error(std::vector<tpctypes::UInt16Write>)> { -public: - static const char *getName() { return "WriteUInt16s"; } -}; - -class WriteUInt32s - : public shared::RPCFunction<WriteUInt32s, - Error(std::vector<tpctypes::UInt32Write>)> { -public: - static const char *getName() { return "WriteUInt32s"; } -}; - -class WriteUInt64s - : public shared::RPCFunction<WriteUInt64s, - Error(std::vector<tpctypes::UInt64Write>)> { -public: - static const char *getName() { return "WriteUInt64s"; } -}; - -class WriteBuffers - : public shared::RPCFunction<WriteBuffers, - Error(std::vector<tpctypes::BufferWrite>)> { -public: - static const char *getName() { return "WriteBuffers"; } -}; - -class LoadDylib - : public shared::RPCFunction<LoadDylib, Expected<tpctypes::DylibHandle>( - std::string DylibPath)> { -public: - static const char *getName() { return "LoadDylib"; } -}; - -class LookupSymbols - : public shared::RPCFunction<LookupSymbols, - Expected<std::vector<tpctypes::LookupResult>>( - std::vector<RemoteLookupRequest>)> { -public: - static const char *getName() { return "LookupSymbols"; } -}; - -class RunMain - : public shared::RPCFunction<RunMain, - int32_t(JITTargetAddress MainAddr, - std::vector<std::string> Args)> { -public: - static const char *getName() { return "RunMain"; } -}; - -class RunWrapper - : public shared::RPCFunction<RunWrapper, - tpctypes::WrapperFunctionResult( - JITTargetAddress, std::vector<uint8_t>)> { -public: - static const char *getName() { return "RunWrapper"; } -}; - -class CloseConnection : public shared::RPCFunction<CloseConnection, void()> { -public: - static const char *getName() { return "CloseConnection"; } -}; - -} // end namespace orcrpctpc - -/// TargetProcessControl for a process connected via an ORC RPC Endpoint. -template <typename RPCEndpointT> class OrcRPCTPCServer { -public: - /// Create an OrcRPCTPCServer from the given endpoint. - OrcRPCTPCServer(RPCEndpointT &EP) : EP(EP) { - using ThisT = OrcRPCTPCServer<RPCEndpointT>; - - TripleStr = sys::getProcessTriple(); - PageSize = sys::Process::getPageSizeEstimate(); - - EP.template addHandler<orcrpctpc::GetTargetTriple>(*this, - &ThisT::getTargetTriple); - EP.template addHandler<orcrpctpc::GetPageSize>(*this, &ThisT::getPageSize); - - EP.template addHandler<orcrpctpc::ReserveMem>(*this, &ThisT::reserveMemory); - EP.template addHandler<orcrpctpc::FinalizeMem>(*this, - &ThisT::finalizeMemory); - EP.template addHandler<orcrpctpc::ReleaseMem>(*this, &ThisT::releaseMemory); - - EP.template addHandler<orcrpctpc::WriteUInt8s>( - handleWriteUInt<tpctypes::UInt8Write>); - EP.template addHandler<orcrpctpc::WriteUInt16s>( - handleWriteUInt<tpctypes::UInt16Write>); - EP.template addHandler<orcrpctpc::WriteUInt32s>( - handleWriteUInt<tpctypes::UInt32Write>); - EP.template addHandler<orcrpctpc::WriteUInt64s>( - handleWriteUInt<tpctypes::UInt64Write>); - EP.template addHandler<orcrpctpc::WriteBuffers>(handleWriteBuffer); - - EP.template addHandler<orcrpctpc::LoadDylib>(*this, &ThisT::loadDylib); - EP.template addHandler<orcrpctpc::LookupSymbols>(*this, - &ThisT::lookupSymbols); - - EP.template addHandler<orcrpctpc::RunMain>(*this, &ThisT::runMain); - EP.template addHandler<orcrpctpc::RunWrapper>(*this, &ThisT::runWrapper); - - EP.template addHandler<orcrpctpc::CloseConnection>(*this, - &ThisT::closeConnection); - } - - /// Set the ProgramName to be used as the first argv element when running - /// functions via runAsMain. - void setProgramName(Optional<std::string> ProgramName = None) { - this->ProgramName = std::move(ProgramName); - } - - /// Get the RPC endpoint for this server. - RPCEndpointT &getEndpoint() { return EP; } - - /// Run the server loop. - Error run() { - while (!Finished) { - if (auto Err = EP.handleOne()) - return Err; - } - return Error::success(); - } - -private: - std::string getTargetTriple() { return TripleStr; } - uint64_t getPageSize() { return PageSize; } - - template <typename WriteT> - static void handleWriteUInt(const std::vector<WriteT> &Ws) { - using ValueT = decltype(std::declval<WriteT>().Value); - for (auto &W : Ws) - *jitTargetAddressToPointer<ValueT *>(W.Address) = W.Value; - } - - std::string getProtStr(orcrpctpc::WireProtectionFlags WPF) { - std::string Result; - Result += (WPF & orcrpctpc::WPF_Read) ? 'R' : '-'; - Result += (WPF & orcrpctpc::WPF_Write) ? 'W' : '-'; - Result += (WPF & orcrpctpc::WPF_Exec) ? 'X' : '-'; - return Result; - } - - static void handleWriteBuffer(const std::vector<tpctypes::BufferWrite> &Ws) { - for (auto &W : Ws) { - memcpy(jitTargetAddressToPointer<char *>(W.Address), W.Buffer.data(), - W.Buffer.size()); - } - } - - Expected<orcrpctpc::ReserveMemResult> - reserveMemory(const orcrpctpc::ReserveMemRequest &Request) { - orcrpctpc::ReserveMemResult Allocs; - auto PF = sys::Memory::MF_READ | sys::Memory::MF_WRITE; - - uint64_t TotalSize = 0; - - for (const auto &E : Request) { - uint64_t Size = alignTo(E.Size, PageSize); - uint16_t Align = E.Alignment; - - if ((Align > PageSize) || (PageSize % Align)) - return make_error<StringError>( - "Page alignmen does not satisfy requested alignment", - inconvertibleErrorCode()); - - TotalSize += Size; - } - - // Allocate memory slab. - std::error_code EC; - auto MB = sys::Memory::allocateMappedMemory(TotalSize, nullptr, PF, EC); - if (EC) - return make_error<StringError>("Unable to allocate memory: " + - EC.message(), - inconvertibleErrorCode()); - - // Zero-fill the whole thing. - memset(MB.base(), 0, MB.allocatedSize()); - - // Carve up sections to return. - uint64_t SectionBase = 0; - for (const auto &E : Request) { - uint64_t SectionSize = alignTo(E.Size, PageSize); - Allocs.push_back({E.Prot, - pointerToJITTargetAddress(MB.base()) + SectionBase, - SectionSize}); - SectionBase += SectionSize; - } - - return Allocs; - } - - Error finalizeMemory(const orcrpctpc::ReleaseOrFinalizeMemRequest &FMR) { - for (const auto &E : FMR) { - sys::MemoryBlock MB(jitTargetAddressToPointer<void *>(E.Address), E.Size); - - auto PF = orcrpctpc::fromWireProtectionFlags(E.Prot); - if (auto EC = - sys::Memory::protectMappedMemory(MB, static_cast<unsigned>(PF))) - return make_error<StringError>("error protecting memory: " + - EC.message(), - inconvertibleErrorCode()); - } - return Error::success(); - } - - Error releaseMemory(const orcrpctpc::ReleaseOrFinalizeMemRequest &RMR) { - for (const auto &E : RMR) { - sys::MemoryBlock MB(jitTargetAddressToPointer<void *>(E.Address), E.Size); - - if (auto EC = sys::Memory::releaseMappedMemory(MB)) - return make_error<StringError>("error release memory: " + EC.message(), - inconvertibleErrorCode()); - } - return Error::success(); - } - - Expected<tpctypes::DylibHandle> loadDylib(const std::string &Path) { - std::string ErrMsg; - const char *DLPath = !Path.empty() ? Path.c_str() : nullptr; - auto DL = sys::DynamicLibrary::getPermanentLibrary(DLPath, &ErrMsg); - if (!DL.isValid()) - return make_error<StringError>(std::move(ErrMsg), - inconvertibleErrorCode()); - - tpctypes::DylibHandle H = Dylibs.size(); - Dylibs[H] = std::move(DL); - return H; - } - - Expected<std::vector<tpctypes::LookupResult>> - lookupSymbols(const std::vector<orcrpctpc::RemoteLookupRequest> &Request) { - std::vector<tpctypes::LookupResult> Result; - - for (const auto &E : Request) { - auto I = Dylibs.find(E.first); - if (I == Dylibs.end()) - return make_error<StringError>("Unrecognized handle", - inconvertibleErrorCode()); - auto &DL = I->second; - Result.push_back({}); - - for (const auto &KV : E.second) { - auto &SymString = KV.first; - bool WeakReference = KV.second; - - const char *Sym = SymString.c_str(); -#ifdef __APPLE__ - if (*Sym == '_') - ++Sym; -#endif - - void *Addr = DL.getAddressOfSymbol(Sym); - if (!Addr && !WeakReference) - return make_error<StringError>(Twine("Missing definition for ") + Sym, - inconvertibleErrorCode()); - - Result.back().push_back(pointerToJITTargetAddress(Addr)); - } - } - - return Result; - } - - int32_t runMain(JITTargetAddress MainFnAddr, - const std::vector<std::string> &Args) { - Optional<StringRef> ProgramNameOverride; - if (ProgramName) - ProgramNameOverride = *ProgramName; - - return runAsMain( - jitTargetAddressToFunction<int (*)(int, char *[])>(MainFnAddr), Args, - ProgramNameOverride); - } - - tpctypes::WrapperFunctionResult - runWrapper(JITTargetAddress WrapperFnAddr, - const std::vector<uint8_t> &ArgBuffer) { - using WrapperFnTy = tpctypes::CWrapperFunctionResult (*)( - const uint8_t *Data, uint64_t Size); - auto *WrapperFn = jitTargetAddressToFunction<WrapperFnTy>(WrapperFnAddr); - return WrapperFn(ArgBuffer.data(), ArgBuffer.size()); - } - - void closeConnection() { Finished = true; } - - std::string TripleStr; - uint64_t PageSize = 0; - Optional<std::string> ProgramName; - RPCEndpointT &EP; - std::atomic<bool> Finished{false}; - DenseMap<tpctypes::DylibHandle, sys::DynamicLibrary> Dylibs; -}; - -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_ORCRPCTPCSERVER_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===-- OrcRPCTPCServer.h -- OrcRPCTargetProcessControl Server --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// OrcRPCTargetProcessControl server class. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_ORCRPCTPCSERVER_H +#define LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_ORCRPCTPCSERVER_H + +#include "llvm/ADT/BitmaskEnum.h" +#include "llvm/ExecutionEngine/Orc/Shared/RPCUtils.h" +#include "llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h" +#include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h" +#include "llvm/ExecutionEngine/Orc/TargetProcess/TargetExecutionUtils.h" +#include "llvm/ExecutionEngine/Orc/TargetProcessControl.h" +#include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Host.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/Memory.h" +#include "llvm/Support/Process.h" + +#include <atomic> + +namespace llvm { +namespace orc { + +namespace orcrpctpc { + +enum WireProtectionFlags : uint8_t { + WPF_None = 0, + WPF_Read = 1U << 0, + WPF_Write = 1U << 1, + WPF_Exec = 1U << 2, + LLVM_MARK_AS_BITMASK_ENUM(WPF_Exec) +}; + +/// Convert from sys::Memory::ProtectionFlags +inline WireProtectionFlags +toWireProtectionFlags(sys::Memory::ProtectionFlags PF) { + WireProtectionFlags WPF = WPF_None; + if (PF & sys::Memory::MF_READ) + WPF |= WPF_Read; + if (PF & sys::Memory::MF_WRITE) + WPF |= WPF_Write; + if (PF & sys::Memory::MF_EXEC) + WPF |= WPF_Exec; + return WPF; +} + +inline sys::Memory::ProtectionFlags +fromWireProtectionFlags(WireProtectionFlags WPF) { + int PF = 0; + if (WPF & WPF_Read) + PF |= sys::Memory::MF_READ; + if (WPF & WPF_Write) + PF |= sys::Memory::MF_WRITE; + if (WPF & WPF_Exec) + PF |= sys::Memory::MF_EXEC; + return static_cast<sys::Memory::ProtectionFlags>(PF); +} + +struct ReserveMemRequestElement { + WireProtectionFlags Prot = WPF_None; + uint64_t Size = 0; + uint64_t Alignment = 0; +}; + +using ReserveMemRequest = std::vector<ReserveMemRequestElement>; + +struct ReserveMemResultElement { + WireProtectionFlags Prot = WPF_None; + JITTargetAddress Address = 0; + uint64_t AllocatedSize = 0; +}; + +using ReserveMemResult = std::vector<ReserveMemResultElement>; + +struct ReleaseOrFinalizeMemRequestElement { + WireProtectionFlags Prot = WPF_None; + JITTargetAddress Address = 0; + uint64_t Size = 0; +}; + +using ReleaseOrFinalizeMemRequest = + std::vector<ReleaseOrFinalizeMemRequestElement>; + +} // end namespace orcrpctpc + +namespace shared { + +template <> class SerializationTypeName<tpctypes::UInt8Write> { +public: + static const char *getName() { return "UInt8Write"; } +}; + +template <> class SerializationTypeName<tpctypes::UInt16Write> { +public: + static const char *getName() { return "UInt16Write"; } +}; + +template <> class SerializationTypeName<tpctypes::UInt32Write> { +public: + static const char *getName() { return "UInt32Write"; } +}; + +template <> class SerializationTypeName<tpctypes::UInt64Write> { +public: + static const char *getName() { return "UInt64Write"; } +}; + +template <> class SerializationTypeName<tpctypes::BufferWrite> { +public: + static const char *getName() { return "BufferWrite"; } +}; + +template <> class SerializationTypeName<orcrpctpc::ReserveMemRequestElement> { +public: + static const char *getName() { return "ReserveMemRequestElement"; } +}; + +template <> class SerializationTypeName<orcrpctpc::ReserveMemResultElement> { +public: + static const char *getName() { return "ReserveMemResultElement"; } +}; + +template <> +class SerializationTypeName<orcrpctpc::ReleaseOrFinalizeMemRequestElement> { +public: + static const char *getName() { return "ReleaseOrFinalizeMemRequestElement"; } +}; + +template <> class SerializationTypeName<tpctypes::WrapperFunctionResult> { +public: + static const char *getName() { return "WrapperFunctionResult"; } +}; + +template <typename ChannelT, typename WriteT> +class SerializationTraits< + ChannelT, WriteT, WriteT, + std::enable_if_t<std::is_same<WriteT, tpctypes::UInt8Write>::value || + std::is_same<WriteT, tpctypes::UInt16Write>::value || + std::is_same<WriteT, tpctypes::UInt32Write>::value || + std::is_same<WriteT, tpctypes::UInt64Write>::value>> { +public: + static Error serialize(ChannelT &C, const WriteT &W) { + return serializeSeq(C, W.Address, W.Value); + } + static Error deserialize(ChannelT &C, WriteT &W) { + return deserializeSeq(C, W.Address, W.Value); + } +}; + +template <typename ChannelT> +class SerializationTraits< + ChannelT, tpctypes::BufferWrite, tpctypes::BufferWrite, + std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { +public: + static Error serialize(ChannelT &C, const tpctypes::BufferWrite &W) { + uint64_t Size = W.Buffer.size(); + if (auto Err = serializeSeq(C, W.Address, Size)) + return Err; + + return C.appendBytes(W.Buffer.data(), Size); + } + static Error deserialize(ChannelT &C, tpctypes::BufferWrite &W) { + JITTargetAddress Address; + uint64_t Size; + + if (auto Err = deserializeSeq(C, Address, Size)) + return Err; + + char *Buffer = jitTargetAddressToPointer<char *>(Address); + + if (auto Err = C.readBytes(Buffer, Size)) + return Err; + + W = {Address, StringRef(Buffer, Size)}; + return Error::success(); + } +}; + +template <typename ChannelT> +class SerializationTraits<ChannelT, orcrpctpc::ReserveMemRequestElement> { +public: + static Error serialize(ChannelT &C, + const orcrpctpc::ReserveMemRequestElement &E) { + return serializeSeq(C, static_cast<uint8_t>(E.Prot), E.Size, E.Alignment); + } + + static Error deserialize(ChannelT &C, + orcrpctpc::ReserveMemRequestElement &E) { + return deserializeSeq(C, *reinterpret_cast<uint8_t *>(&E.Prot), E.Size, + E.Alignment); + } +}; + +template <typename ChannelT> +class SerializationTraits<ChannelT, orcrpctpc::ReserveMemResultElement> { +public: + static Error serialize(ChannelT &C, + const orcrpctpc::ReserveMemResultElement &E) { + return serializeSeq(C, static_cast<uint8_t>(E.Prot), E.Address, + E.AllocatedSize); + } + + static Error deserialize(ChannelT &C, orcrpctpc::ReserveMemResultElement &E) { + return deserializeSeq(C, *reinterpret_cast<uint8_t *>(&E.Prot), E.Address, + E.AllocatedSize); + } +}; + +template <typename ChannelT> +class SerializationTraits<ChannelT, + orcrpctpc::ReleaseOrFinalizeMemRequestElement> { +public: + static Error + serialize(ChannelT &C, + const orcrpctpc::ReleaseOrFinalizeMemRequestElement &E) { + return serializeSeq(C, static_cast<uint8_t>(E.Prot), E.Address, E.Size); + } + + static Error deserialize(ChannelT &C, + orcrpctpc::ReleaseOrFinalizeMemRequestElement &E) { + return deserializeSeq(C, *reinterpret_cast<uint8_t *>(&E.Prot), E.Address, + E.Size); + } +}; + +template <typename ChannelT> +class SerializationTraits< + ChannelT, tpctypes::WrapperFunctionResult, tpctypes::WrapperFunctionResult, + std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { +public: + static Error serialize(ChannelT &C, + const tpctypes::WrapperFunctionResult &E) { + auto Data = E.getData(); + if (auto Err = serializeSeq(C, static_cast<uint64_t>(Data.size()))) + return Err; + if (Data.size() == 0) + return Error::success(); + return C.appendBytes(reinterpret_cast<const char *>(Data.data()), + Data.size()); + } + + static Error deserialize(ChannelT &C, tpctypes::WrapperFunctionResult &E) { + tpctypes::CWrapperFunctionResult R; + + R.Size = 0; + R.Data.ValuePtr = nullptr; + R.Destroy = nullptr; + + if (auto Err = deserializeSeq(C, R.Size)) + return Err; + if (R.Size == 0) + return Error::success(); + R.Data.ValuePtr = new uint8_t[R.Size]; + if (auto Err = + C.readBytes(reinterpret_cast<char *>(R.Data.ValuePtr), R.Size)) { + R.Destroy = tpctypes::WrapperFunctionResult::destroyWithDeleteArray; + return Err; + } + + E = tpctypes::WrapperFunctionResult(R); + return Error::success(); + } +}; + +} // end namespace shared + +namespace orcrpctpc { + +using RemoteSymbolLookupSet = std::vector<std::pair<std::string, bool>>; +using RemoteLookupRequest = + std::pair<tpctypes::DylibHandle, RemoteSymbolLookupSet>; + +class GetTargetTriple + : public shared::RPCFunction<GetTargetTriple, std::string()> { +public: + static const char *getName() { return "GetTargetTriple"; } +}; + +class GetPageSize : public shared::RPCFunction<GetPageSize, uint64_t()> { +public: + static const char *getName() { return "GetPageSize"; } +}; + +class ReserveMem + : public shared::RPCFunction<ReserveMem, Expected<ReserveMemResult>( + ReserveMemRequest)> { +public: + static const char *getName() { return "ReserveMem"; } +}; + +class FinalizeMem + : public shared::RPCFunction<FinalizeMem, + Error(ReleaseOrFinalizeMemRequest)> { +public: + static const char *getName() { return "FinalizeMem"; } +}; + +class ReleaseMem + : public shared::RPCFunction<ReleaseMem, + Error(ReleaseOrFinalizeMemRequest)> { +public: + static const char *getName() { return "ReleaseMem"; } +}; + +class WriteUInt8s + : public shared::RPCFunction<WriteUInt8s, + Error(std::vector<tpctypes::UInt8Write>)> { +public: + static const char *getName() { return "WriteUInt8s"; } +}; + +class WriteUInt16s + : public shared::RPCFunction<WriteUInt16s, + Error(std::vector<tpctypes::UInt16Write>)> { +public: + static const char *getName() { return "WriteUInt16s"; } +}; + +class WriteUInt32s + : public shared::RPCFunction<WriteUInt32s, + Error(std::vector<tpctypes::UInt32Write>)> { +public: + static const char *getName() { return "WriteUInt32s"; } +}; + +class WriteUInt64s + : public shared::RPCFunction<WriteUInt64s, + Error(std::vector<tpctypes::UInt64Write>)> { +public: + static const char *getName() { return "WriteUInt64s"; } +}; + +class WriteBuffers + : public shared::RPCFunction<WriteBuffers, + Error(std::vector<tpctypes::BufferWrite>)> { +public: + static const char *getName() { return "WriteBuffers"; } +}; + +class LoadDylib + : public shared::RPCFunction<LoadDylib, Expected<tpctypes::DylibHandle>( + std::string DylibPath)> { +public: + static const char *getName() { return "LoadDylib"; } +}; + +class LookupSymbols + : public shared::RPCFunction<LookupSymbols, + Expected<std::vector<tpctypes::LookupResult>>( + std::vector<RemoteLookupRequest>)> { +public: + static const char *getName() { return "LookupSymbols"; } +}; + +class RunMain + : public shared::RPCFunction<RunMain, + int32_t(JITTargetAddress MainAddr, + std::vector<std::string> Args)> { +public: + static const char *getName() { return "RunMain"; } +}; + +class RunWrapper + : public shared::RPCFunction<RunWrapper, + tpctypes::WrapperFunctionResult( + JITTargetAddress, std::vector<uint8_t>)> { +public: + static const char *getName() { return "RunWrapper"; } +}; + +class CloseConnection : public shared::RPCFunction<CloseConnection, void()> { +public: + static const char *getName() { return "CloseConnection"; } +}; + +} // end namespace orcrpctpc + +/// TargetProcessControl for a process connected via an ORC RPC Endpoint. +template <typename RPCEndpointT> class OrcRPCTPCServer { +public: + /// Create an OrcRPCTPCServer from the given endpoint. + OrcRPCTPCServer(RPCEndpointT &EP) : EP(EP) { + using ThisT = OrcRPCTPCServer<RPCEndpointT>; + + TripleStr = sys::getProcessTriple(); + PageSize = sys::Process::getPageSizeEstimate(); + + EP.template addHandler<orcrpctpc::GetTargetTriple>(*this, + &ThisT::getTargetTriple); + EP.template addHandler<orcrpctpc::GetPageSize>(*this, &ThisT::getPageSize); + + EP.template addHandler<orcrpctpc::ReserveMem>(*this, &ThisT::reserveMemory); + EP.template addHandler<orcrpctpc::FinalizeMem>(*this, + &ThisT::finalizeMemory); + EP.template addHandler<orcrpctpc::ReleaseMem>(*this, &ThisT::releaseMemory); + + EP.template addHandler<orcrpctpc::WriteUInt8s>( + handleWriteUInt<tpctypes::UInt8Write>); + EP.template addHandler<orcrpctpc::WriteUInt16s>( + handleWriteUInt<tpctypes::UInt16Write>); + EP.template addHandler<orcrpctpc::WriteUInt32s>( + handleWriteUInt<tpctypes::UInt32Write>); + EP.template addHandler<orcrpctpc::WriteUInt64s>( + handleWriteUInt<tpctypes::UInt64Write>); + EP.template addHandler<orcrpctpc::WriteBuffers>(handleWriteBuffer); + + EP.template addHandler<orcrpctpc::LoadDylib>(*this, &ThisT::loadDylib); + EP.template addHandler<orcrpctpc::LookupSymbols>(*this, + &ThisT::lookupSymbols); + + EP.template addHandler<orcrpctpc::RunMain>(*this, &ThisT::runMain); + EP.template addHandler<orcrpctpc::RunWrapper>(*this, &ThisT::runWrapper); + + EP.template addHandler<orcrpctpc::CloseConnection>(*this, + &ThisT::closeConnection); + } + + /// Set the ProgramName to be used as the first argv element when running + /// functions via runAsMain. + void setProgramName(Optional<std::string> ProgramName = None) { + this->ProgramName = std::move(ProgramName); + } + + /// Get the RPC endpoint for this server. + RPCEndpointT &getEndpoint() { return EP; } + + /// Run the server loop. + Error run() { + while (!Finished) { + if (auto Err = EP.handleOne()) + return Err; + } + return Error::success(); + } + +private: + std::string getTargetTriple() { return TripleStr; } + uint64_t getPageSize() { return PageSize; } + + template <typename WriteT> + static void handleWriteUInt(const std::vector<WriteT> &Ws) { + using ValueT = decltype(std::declval<WriteT>().Value); + for (auto &W : Ws) + *jitTargetAddressToPointer<ValueT *>(W.Address) = W.Value; + } + + std::string getProtStr(orcrpctpc::WireProtectionFlags WPF) { + std::string Result; + Result += (WPF & orcrpctpc::WPF_Read) ? 'R' : '-'; + Result += (WPF & orcrpctpc::WPF_Write) ? 'W' : '-'; + Result += (WPF & orcrpctpc::WPF_Exec) ? 'X' : '-'; + return Result; + } + + static void handleWriteBuffer(const std::vector<tpctypes::BufferWrite> &Ws) { + for (auto &W : Ws) { + memcpy(jitTargetAddressToPointer<char *>(W.Address), W.Buffer.data(), + W.Buffer.size()); + } + } + + Expected<orcrpctpc::ReserveMemResult> + reserveMemory(const orcrpctpc::ReserveMemRequest &Request) { + orcrpctpc::ReserveMemResult Allocs; + auto PF = sys::Memory::MF_READ | sys::Memory::MF_WRITE; + + uint64_t TotalSize = 0; + + for (const auto &E : Request) { + uint64_t Size = alignTo(E.Size, PageSize); + uint16_t Align = E.Alignment; + + if ((Align > PageSize) || (PageSize % Align)) + return make_error<StringError>( + "Page alignmen does not satisfy requested alignment", + inconvertibleErrorCode()); + + TotalSize += Size; + } + + // Allocate memory slab. + std::error_code EC; + auto MB = sys::Memory::allocateMappedMemory(TotalSize, nullptr, PF, EC); + if (EC) + return make_error<StringError>("Unable to allocate memory: " + + EC.message(), + inconvertibleErrorCode()); + + // Zero-fill the whole thing. + memset(MB.base(), 0, MB.allocatedSize()); + + // Carve up sections to return. + uint64_t SectionBase = 0; + for (const auto &E : Request) { + uint64_t SectionSize = alignTo(E.Size, PageSize); + Allocs.push_back({E.Prot, + pointerToJITTargetAddress(MB.base()) + SectionBase, + SectionSize}); + SectionBase += SectionSize; + } + + return Allocs; + } + + Error finalizeMemory(const orcrpctpc::ReleaseOrFinalizeMemRequest &FMR) { + for (const auto &E : FMR) { + sys::MemoryBlock MB(jitTargetAddressToPointer<void *>(E.Address), E.Size); + + auto PF = orcrpctpc::fromWireProtectionFlags(E.Prot); + if (auto EC = + sys::Memory::protectMappedMemory(MB, static_cast<unsigned>(PF))) + return make_error<StringError>("error protecting memory: " + + EC.message(), + inconvertibleErrorCode()); + } + return Error::success(); + } + + Error releaseMemory(const orcrpctpc::ReleaseOrFinalizeMemRequest &RMR) { + for (const auto &E : RMR) { + sys::MemoryBlock MB(jitTargetAddressToPointer<void *>(E.Address), E.Size); + + if (auto EC = sys::Memory::releaseMappedMemory(MB)) + return make_error<StringError>("error release memory: " + EC.message(), + inconvertibleErrorCode()); + } + return Error::success(); + } + + Expected<tpctypes::DylibHandle> loadDylib(const std::string &Path) { + std::string ErrMsg; + const char *DLPath = !Path.empty() ? Path.c_str() : nullptr; + auto DL = sys::DynamicLibrary::getPermanentLibrary(DLPath, &ErrMsg); + if (!DL.isValid()) + return make_error<StringError>(std::move(ErrMsg), + inconvertibleErrorCode()); + + tpctypes::DylibHandle H = Dylibs.size(); + Dylibs[H] = std::move(DL); + return H; + } + + Expected<std::vector<tpctypes::LookupResult>> + lookupSymbols(const std::vector<orcrpctpc::RemoteLookupRequest> &Request) { + std::vector<tpctypes::LookupResult> Result; + + for (const auto &E : Request) { + auto I = Dylibs.find(E.first); + if (I == Dylibs.end()) + return make_error<StringError>("Unrecognized handle", + inconvertibleErrorCode()); + auto &DL = I->second; + Result.push_back({}); + + for (const auto &KV : E.second) { + auto &SymString = KV.first; + bool WeakReference = KV.second; + + const char *Sym = SymString.c_str(); +#ifdef __APPLE__ + if (*Sym == '_') + ++Sym; +#endif + + void *Addr = DL.getAddressOfSymbol(Sym); + if (!Addr && !WeakReference) + return make_error<StringError>(Twine("Missing definition for ") + Sym, + inconvertibleErrorCode()); + + Result.back().push_back(pointerToJITTargetAddress(Addr)); + } + } + + return Result; + } + + int32_t runMain(JITTargetAddress MainFnAddr, + const std::vector<std::string> &Args) { + Optional<StringRef> ProgramNameOverride; + if (ProgramName) + ProgramNameOverride = *ProgramName; + + return runAsMain( + jitTargetAddressToFunction<int (*)(int, char *[])>(MainFnAddr), Args, + ProgramNameOverride); + } + + tpctypes::WrapperFunctionResult + runWrapper(JITTargetAddress WrapperFnAddr, + const std::vector<uint8_t> &ArgBuffer) { + using WrapperFnTy = tpctypes::CWrapperFunctionResult (*)( + const uint8_t *Data, uint64_t Size); + auto *WrapperFn = jitTargetAddressToFunction<WrapperFnTy>(WrapperFnAddr); + return WrapperFn(ArgBuffer.data(), ArgBuffer.size()); + } + + void closeConnection() { Finished = true; } + + std::string TripleStr; + uint64_t PageSize = 0; + Optional<std::string> ProgramName; + RPCEndpointT &EP; + std::atomic<bool> Finished{false}; + DenseMap<tpctypes::DylibHandle, sys::DynamicLibrary> Dylibs; +}; + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_ORCRPCTPCSERVER_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcess/RegisterEHFrames.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcess/RegisterEHFrames.h index b784287fb8..a178dc7df1 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcess/RegisterEHFrames.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcess/RegisterEHFrames.h @@ -1,52 +1,52 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===----- RegisterEHFrames.h -- Register EH frame sections -----*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Support for dynamically registering and deregistering eh-frame sections -// in-process via libunwind. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_REGISTEREHFRAMES_H -#define LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_REGISTEREHFRAMES_H - -#include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h" -#include "llvm/Support/Error.h" -#include <vector> - -namespace llvm { -namespace orc { - -/// Register frames in the given eh-frame section with libunwind. -Error registerEHFrameSection(const void *EHFrameSectionAddr, - size_t EHFrameSectionSize); - -/// Unregister frames in the given eh-frame section with libunwind. -Error deregisterEHFrameSection(const void *EHFrameSectionAddr, - size_t EHFrameSectionSize); - -} // end namespace orc -} // end namespace llvm - -extern "C" llvm::orc::tpctypes::CWrapperFunctionResult -llvm_orc_registerEHFrameSectionWrapper(uint8_t *Data, uint64_t Size); - -extern "C" llvm::orc::tpctypes::CWrapperFunctionResult -llvm_orc_deregisterEHFrameSectionWrapper(uint8_t *Data, uint64_t Size); - -#endif // LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_REGISTEREHFRAMES_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===----- RegisterEHFrames.h -- Register EH frame sections -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Support for dynamically registering and deregistering eh-frame sections +// in-process via libunwind. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_REGISTEREHFRAMES_H +#define LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_REGISTEREHFRAMES_H + +#include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h" +#include "llvm/Support/Error.h" +#include <vector> + +namespace llvm { +namespace orc { + +/// Register frames in the given eh-frame section with libunwind. +Error registerEHFrameSection(const void *EHFrameSectionAddr, + size_t EHFrameSectionSize); + +/// Unregister frames in the given eh-frame section with libunwind. +Error deregisterEHFrameSection(const void *EHFrameSectionAddr, + size_t EHFrameSectionSize); + +} // end namespace orc +} // end namespace llvm + +extern "C" llvm::orc::tpctypes::CWrapperFunctionResult +llvm_orc_registerEHFrameSectionWrapper(uint8_t *Data, uint64_t Size); + +extern "C" llvm::orc::tpctypes::CWrapperFunctionResult +llvm_orc_deregisterEHFrameSectionWrapper(uint8_t *Data, uint64_t Size); + +#endif // LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_REGISTEREHFRAMES_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcess/TargetExecutionUtils.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcess/TargetExecutionUtils.h index 15a354dd34..334660f4a9 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcess/TargetExecutionUtils.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcess/TargetExecutionUtils.h @@ -1,49 +1,49 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===-- TargetExecutionUtils.h - Utils for execution in target --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Utilities for execution in the target process. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_TARGETEXECUTIONUTILS_H -#define LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_TARGETEXECUTIONUTILS_H - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/StringRef.h" -#include <string> - -namespace llvm { -namespace orc { - -/// Run a main function, returning the result. -/// -/// If the optional ProgramName argument is given then it will be inserted -/// before the strings in Args as the first argument to the called function. -/// -/// It is legal to have an empty argument list and no program name, however -/// many main functions will expect a name argument at least, and will fail -/// if none is provided. -int runAsMain(int (*Main)(int, char *[]), ArrayRef<std::string> Args, - Optional<StringRef> ProgramName = None); - -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_TARGETEXECUTIONUTILS_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===-- TargetExecutionUtils.h - Utils for execution in target --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utilities for execution in the target process. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_TARGETEXECUTIONUTILS_H +#define LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_TARGETEXECUTIONUTILS_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringRef.h" +#include <string> + +namespace llvm { +namespace orc { + +/// Run a main function, returning the result. +/// +/// If the optional ProgramName argument is given then it will be inserted +/// before the strings in Args as the first argument to the called function. +/// +/// It is legal to have an empty argument list and no program name, however +/// many main functions will expect a name argument at least, and will fail +/// if none is provided. +int runAsMain(int (*Main)(int, char *[]), ArrayRef<std::string> Args, + Optional<StringRef> ProgramName = None); + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_TARGETEXECUTIONUTILS_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcessControl.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcessControl.h index 952bcc2a33..2c2e62ae64 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcessControl.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/TargetProcessControl.h @@ -1,229 +1,229 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===--- TargetProcessControl.h - Target process control APIs ---*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Utilities for interacting with target processes. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_TARGETPROCESSCONTROL_H -#define LLVM_EXECUTIONENGINE_ORC_TARGETPROCESSCONTROL_H - -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Triple.h" -#include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h" -#include "llvm/ExecutionEngine/Orc/Core.h" -#include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h" -#include "llvm/Support/DynamicLibrary.h" -#include "llvm/Support/MSVCErrorWorkarounds.h" - -#include <future> -#include <vector> - -namespace llvm { -namespace orc { - -/// TargetProcessControl supports interaction with a JIT target process. -class TargetProcessControl { -public: - /// APIs for manipulating memory in the target process. - class MemoryAccess { - public: - /// Callback function for asynchronous writes. - using WriteResultFn = unique_function<void(Error)>; - - virtual ~MemoryAccess(); - - virtual void writeUInt8s(ArrayRef<tpctypes::UInt8Write> Ws, - WriteResultFn OnWriteComplete) = 0; - - virtual void writeUInt16s(ArrayRef<tpctypes::UInt16Write> Ws, - WriteResultFn OnWriteComplete) = 0; - - virtual void writeUInt32s(ArrayRef<tpctypes::UInt32Write> Ws, - WriteResultFn OnWriteComplete) = 0; - - virtual void writeUInt64s(ArrayRef<tpctypes::UInt64Write> Ws, - WriteResultFn OnWriteComplete) = 0; - - virtual void writeBuffers(ArrayRef<tpctypes::BufferWrite> Ws, - WriteResultFn OnWriteComplete) = 0; - - Error writeUInt8s(ArrayRef<tpctypes::UInt8Write> Ws) { - std::promise<MSVCPError> ResultP; - auto ResultF = ResultP.get_future(); - writeUInt8s(Ws, [&](Error Err) { ResultP.set_value(std::move(Err)); }); - return ResultF.get(); - } - - Error writeUInt16s(ArrayRef<tpctypes::UInt16Write> Ws) { - std::promise<MSVCPError> ResultP; - auto ResultF = ResultP.get_future(); - writeUInt16s(Ws, [&](Error Err) { ResultP.set_value(std::move(Err)); }); - return ResultF.get(); - } - - Error writeUInt32s(ArrayRef<tpctypes::UInt32Write> Ws) { - std::promise<MSVCPError> ResultP; - auto ResultF = ResultP.get_future(); - writeUInt32s(Ws, [&](Error Err) { ResultP.set_value(std::move(Err)); }); - return ResultF.get(); - } - - Error writeUInt64s(ArrayRef<tpctypes::UInt64Write> Ws) { - std::promise<MSVCPError> ResultP; - auto ResultF = ResultP.get_future(); - writeUInt64s(Ws, [&](Error Err) { ResultP.set_value(std::move(Err)); }); - return ResultF.get(); - } - - Error writeBuffers(ArrayRef<tpctypes::BufferWrite> Ws) { - std::promise<MSVCPError> ResultP; - auto ResultF = ResultP.get_future(); - writeBuffers(Ws, [&](Error Err) { ResultP.set_value(std::move(Err)); }); - return ResultF.get(); - } - }; - - /// A pair of a dylib and a set of symbols to be looked up. - struct LookupRequest { - LookupRequest(tpctypes::DylibHandle Handle, const SymbolLookupSet &Symbols) - : Handle(Handle), Symbols(Symbols) {} - tpctypes::DylibHandle Handle; - const SymbolLookupSet &Symbols; - }; - - virtual ~TargetProcessControl(); - - /// Intern a symbol name in the SymbolStringPool. - SymbolStringPtr intern(StringRef SymName) { return SSP->intern(SymName); } - - /// Return a shared pointer to the SymbolStringPool for this instance. - std::shared_ptr<SymbolStringPool> getSymbolStringPool() const { return SSP; } - - /// Return the Triple for the target process. - const Triple &getTargetTriple() const { return TargetTriple; } - - /// Get the page size for the target process. - unsigned getPageSize() const { return PageSize; } - - /// Return a MemoryAccess object for the target process. - MemoryAccess &getMemoryAccess() const { return *MemAccess; } - - /// Return a JITLinkMemoryManager for the target process. - jitlink::JITLinkMemoryManager &getMemMgr() const { return *MemMgr; } - - /// Load the dynamic library at the given path and return a handle to it. - /// If LibraryPath is null this function will return the global handle for - /// the target process. - virtual Expected<tpctypes::DylibHandle> loadDylib(const char *DylibPath) = 0; - - /// Search for symbols in the target process. - /// - /// The result of the lookup is a 2-dimentional array of target addresses - /// that correspond to the lookup order. If a required symbol is not - /// found then this method will return an error. If a weakly referenced - /// symbol is not found then it be assigned a '0' value in the result. - /// that correspond to the lookup order. - virtual Expected<std::vector<tpctypes::LookupResult>> - lookupSymbols(ArrayRef<LookupRequest> Request) = 0; - - /// Run function with a main-like signature. - virtual Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr, - ArrayRef<std::string> Args) = 0; - - /// Run a wrapper function with signature: - /// - /// \code{.cpp} - /// CWrapperFunctionResult fn(uint8_t *Data, uint64_t Size); - /// \endcode{.cpp} - /// - virtual Expected<tpctypes::WrapperFunctionResult> - runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef<uint8_t> ArgBuffer) = 0; - - /// Disconnect from the target process. - /// - /// This should be called after the JIT session is shut down. - virtual Error disconnect() = 0; - -protected: - TargetProcessControl(std::shared_ptr<SymbolStringPool> SSP) - : SSP(std::move(SSP)) {} - - std::shared_ptr<SymbolStringPool> SSP; - Triple TargetTriple; - unsigned PageSize = 0; - MemoryAccess *MemAccess = nullptr; - jitlink::JITLinkMemoryManager *MemMgr = nullptr; -}; - -/// A TargetProcessControl implementation targeting the current process. -class SelfTargetProcessControl : public TargetProcessControl, - private TargetProcessControl::MemoryAccess { -public: - SelfTargetProcessControl( - std::shared_ptr<SymbolStringPool> SSP, Triple TargetTriple, - unsigned PageSize, std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr); - - /// Create a SelfTargetProcessControl with the given memory manager. - /// If no memory manager is given a jitlink::InProcessMemoryManager will - /// be used by default. - static Expected<std::unique_ptr<SelfTargetProcessControl>> - Create(std::shared_ptr<SymbolStringPool> SSP, - std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr = nullptr); - - Expected<tpctypes::DylibHandle> loadDylib(const char *DylibPath) override; - - Expected<std::vector<tpctypes::LookupResult>> - lookupSymbols(ArrayRef<LookupRequest> Request) override; - - Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr, - ArrayRef<std::string> Args) override; - - Expected<tpctypes::WrapperFunctionResult> - runWrapper(JITTargetAddress WrapperFnAddr, - ArrayRef<uint8_t> ArgBuffer) override; - - Error disconnect() override; - -private: - void writeUInt8s(ArrayRef<tpctypes::UInt8Write> Ws, - WriteResultFn OnWriteComplete) override; - - void writeUInt16s(ArrayRef<tpctypes::UInt16Write> Ws, - WriteResultFn OnWriteComplete) override; - - void writeUInt32s(ArrayRef<tpctypes::UInt32Write> Ws, - WriteResultFn OnWriteComplete) override; - - void writeUInt64s(ArrayRef<tpctypes::UInt64Write> Ws, - WriteResultFn OnWriteComplete) override; - - void writeBuffers(ArrayRef<tpctypes::BufferWrite> Ws, - WriteResultFn OnWriteComplete) override; - - std::unique_ptr<jitlink::JITLinkMemoryManager> OwnedMemMgr; - char GlobalManglingPrefix = 0; - std::vector<std::unique_ptr<sys::DynamicLibrary>> DynamicLibraries; -}; - -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_TARGETPROCESSCONTROL_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===--- TargetProcessControl.h - Target process control APIs ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utilities for interacting with target processes. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_TARGETPROCESSCONTROL_H +#define LLVM_EXECUTIONENGINE_ORC_TARGETPROCESSCONTROL_H + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Triple.h" +#include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h" +#include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/MSVCErrorWorkarounds.h" + +#include <future> +#include <vector> + +namespace llvm { +namespace orc { + +/// TargetProcessControl supports interaction with a JIT target process. +class TargetProcessControl { +public: + /// APIs for manipulating memory in the target process. + class MemoryAccess { + public: + /// Callback function for asynchronous writes. + using WriteResultFn = unique_function<void(Error)>; + + virtual ~MemoryAccess(); + + virtual void writeUInt8s(ArrayRef<tpctypes::UInt8Write> Ws, + WriteResultFn OnWriteComplete) = 0; + + virtual void writeUInt16s(ArrayRef<tpctypes::UInt16Write> Ws, + WriteResultFn OnWriteComplete) = 0; + + virtual void writeUInt32s(ArrayRef<tpctypes::UInt32Write> Ws, + WriteResultFn OnWriteComplete) = 0; + + virtual void writeUInt64s(ArrayRef<tpctypes::UInt64Write> Ws, + WriteResultFn OnWriteComplete) = 0; + + virtual void writeBuffers(ArrayRef<tpctypes::BufferWrite> Ws, + WriteResultFn OnWriteComplete) = 0; + + Error writeUInt8s(ArrayRef<tpctypes::UInt8Write> Ws) { + std::promise<MSVCPError> ResultP; + auto ResultF = ResultP.get_future(); + writeUInt8s(Ws, [&](Error Err) { ResultP.set_value(std::move(Err)); }); + return ResultF.get(); + } + + Error writeUInt16s(ArrayRef<tpctypes::UInt16Write> Ws) { + std::promise<MSVCPError> ResultP; + auto ResultF = ResultP.get_future(); + writeUInt16s(Ws, [&](Error Err) { ResultP.set_value(std::move(Err)); }); + return ResultF.get(); + } + + Error writeUInt32s(ArrayRef<tpctypes::UInt32Write> Ws) { + std::promise<MSVCPError> ResultP; + auto ResultF = ResultP.get_future(); + writeUInt32s(Ws, [&](Error Err) { ResultP.set_value(std::move(Err)); }); + return ResultF.get(); + } + + Error writeUInt64s(ArrayRef<tpctypes::UInt64Write> Ws) { + std::promise<MSVCPError> ResultP; + auto ResultF = ResultP.get_future(); + writeUInt64s(Ws, [&](Error Err) { ResultP.set_value(std::move(Err)); }); + return ResultF.get(); + } + + Error writeBuffers(ArrayRef<tpctypes::BufferWrite> Ws) { + std::promise<MSVCPError> ResultP; + auto ResultF = ResultP.get_future(); + writeBuffers(Ws, [&](Error Err) { ResultP.set_value(std::move(Err)); }); + return ResultF.get(); + } + }; + + /// A pair of a dylib and a set of symbols to be looked up. + struct LookupRequest { + LookupRequest(tpctypes::DylibHandle Handle, const SymbolLookupSet &Symbols) + : Handle(Handle), Symbols(Symbols) {} + tpctypes::DylibHandle Handle; + const SymbolLookupSet &Symbols; + }; + + virtual ~TargetProcessControl(); + + /// Intern a symbol name in the SymbolStringPool. + SymbolStringPtr intern(StringRef SymName) { return SSP->intern(SymName); } + + /// Return a shared pointer to the SymbolStringPool for this instance. + std::shared_ptr<SymbolStringPool> getSymbolStringPool() const { return SSP; } + + /// Return the Triple for the target process. + const Triple &getTargetTriple() const { return TargetTriple; } + + /// Get the page size for the target process. + unsigned getPageSize() const { return PageSize; } + + /// Return a MemoryAccess object for the target process. + MemoryAccess &getMemoryAccess() const { return *MemAccess; } + + /// Return a JITLinkMemoryManager for the target process. + jitlink::JITLinkMemoryManager &getMemMgr() const { return *MemMgr; } + + /// Load the dynamic library at the given path and return a handle to it. + /// If LibraryPath is null this function will return the global handle for + /// the target process. + virtual Expected<tpctypes::DylibHandle> loadDylib(const char *DylibPath) = 0; + + /// Search for symbols in the target process. + /// + /// The result of the lookup is a 2-dimentional array of target addresses + /// that correspond to the lookup order. If a required symbol is not + /// found then this method will return an error. If a weakly referenced + /// symbol is not found then it be assigned a '0' value in the result. + /// that correspond to the lookup order. + virtual Expected<std::vector<tpctypes::LookupResult>> + lookupSymbols(ArrayRef<LookupRequest> Request) = 0; + + /// Run function with a main-like signature. + virtual Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr, + ArrayRef<std::string> Args) = 0; + + /// Run a wrapper function with signature: + /// + /// \code{.cpp} + /// CWrapperFunctionResult fn(uint8_t *Data, uint64_t Size); + /// \endcode{.cpp} + /// + virtual Expected<tpctypes::WrapperFunctionResult> + runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef<uint8_t> ArgBuffer) = 0; + + /// Disconnect from the target process. + /// + /// This should be called after the JIT session is shut down. + virtual Error disconnect() = 0; + +protected: + TargetProcessControl(std::shared_ptr<SymbolStringPool> SSP) + : SSP(std::move(SSP)) {} + + std::shared_ptr<SymbolStringPool> SSP; + Triple TargetTriple; + unsigned PageSize = 0; + MemoryAccess *MemAccess = nullptr; + jitlink::JITLinkMemoryManager *MemMgr = nullptr; +}; + +/// A TargetProcessControl implementation targeting the current process. +class SelfTargetProcessControl : public TargetProcessControl, + private TargetProcessControl::MemoryAccess { +public: + SelfTargetProcessControl( + std::shared_ptr<SymbolStringPool> SSP, Triple TargetTriple, + unsigned PageSize, std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr); + + /// Create a SelfTargetProcessControl with the given memory manager. + /// If no memory manager is given a jitlink::InProcessMemoryManager will + /// be used by default. + static Expected<std::unique_ptr<SelfTargetProcessControl>> + Create(std::shared_ptr<SymbolStringPool> SSP, + std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr = nullptr); + + Expected<tpctypes::DylibHandle> loadDylib(const char *DylibPath) override; + + Expected<std::vector<tpctypes::LookupResult>> + lookupSymbols(ArrayRef<LookupRequest> Request) override; + + Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr, + ArrayRef<std::string> Args) override; + + Expected<tpctypes::WrapperFunctionResult> + runWrapper(JITTargetAddress WrapperFnAddr, + ArrayRef<uint8_t> ArgBuffer) override; + + Error disconnect() override; + +private: + void writeUInt8s(ArrayRef<tpctypes::UInt8Write> Ws, + WriteResultFn OnWriteComplete) override; + + void writeUInt16s(ArrayRef<tpctypes::UInt16Write> Ws, + WriteResultFn OnWriteComplete) override; + + void writeUInt32s(ArrayRef<tpctypes::UInt32Write> Ws, + WriteResultFn OnWriteComplete) override; + + void writeUInt64s(ArrayRef<tpctypes::UInt64Write> Ws, + WriteResultFn OnWriteComplete) override; + + void writeBuffers(ArrayRef<tpctypes::BufferWrite> Ws, + WriteResultFn OnWriteComplete) override; + + std::unique_ptr<jitlink::JITLinkMemoryManager> OwnedMemMgr; + char GlobalManglingPrefix = 0; + std::vector<std::unique_ptr<sys::DynamicLibrary>> DynamicLibraries; +}; + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_TARGETPROCESSCONTROL_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h index d1b0e81018..3e412f3385 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h @@ -169,7 +169,7 @@ using GVModifier = std::function<void(GlobalValue &)>; /// Clones the given module on to a new context. ThreadSafeModule -cloneToNewContext(const ThreadSafeModule &TSMW, +cloneToNewContext(const ThreadSafeModule &TSMW, GVPredicate ShouldCloneDef = GVPredicate(), GVModifier UpdateClonedDefSource = GVModifier()); |