diff options
author | shadchin <shadchin@yandex-team.ru> | 2022-02-10 16:44:39 +0300 |
---|---|---|
committer | Daniil Cherednik <dcherednik@yandex-team.ru> | 2022-02-10 16:44:39 +0300 |
commit | e9656aae26e0358d5378e5b63dcac5c8dbe0e4d0 (patch) | |
tree | 64175d5cadab313b3e7039ebaa06c5bc3295e274 /contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared | |
parent | 2598ef1d0aee359b4b6d5fdd1758916d5907d04f (diff) | |
download | ydb-e9656aae26e0358d5378e5b63dcac5c8dbe0e4d0.tar.gz |
Restoring authorship annotation for <shadchin@yandex-team.ru>. Commit 2 of 2.
Diffstat (limited to 'contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared')
6 files changed, 2993 insertions, 2993 deletions
diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/FDRawByteChannel.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/FDRawByteChannel.h index d4aa712442..e1a376bc6b 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/FDRawByteChannel.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/FDRawByteChannel.h @@ -1,90 +1,90 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===- FDRawByteChannel.h - File descriptor based byte-channel -*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// File descriptor based RawByteChannel. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_FDRAWBYTECHANNEL_H -#define LLVM_EXECUTIONENGINE_ORC_SHARED_FDRAWBYTECHANNEL_H - -#include "llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h" - -#if !defined(_MSC_VER) && !defined(__MINGW32__) -#include <unistd.h> -#else -#include <io.h> -#endif - -namespace llvm { -namespace orc { -namespace shared { - -/// Serialization channel that reads from and writes from file descriptors. -class FDRawByteChannel final : public RawByteChannel { -public: - FDRawByteChannel(int InFD, int OutFD) : InFD(InFD), OutFD(OutFD) {} - - llvm::Error readBytes(char *Dst, unsigned Size) override { - assert(Dst && "Attempt to read into null."); - ssize_t Completed = 0; - while (Completed < static_cast<ssize_t>(Size)) { - ssize_t Read = ::read(InFD, Dst + Completed, Size - Completed); - if (Read <= 0) { - auto ErrNo = errno; - if (ErrNo == EAGAIN || ErrNo == EINTR) - continue; - else - return llvm::errorCodeToError( - std::error_code(errno, std::generic_category())); - } - Completed += Read; - } - return llvm::Error::success(); - } - - llvm::Error appendBytes(const char *Src, unsigned Size) override { - assert(Src && "Attempt to append from null."); - ssize_t Completed = 0; - while (Completed < static_cast<ssize_t>(Size)) { - ssize_t Written = ::write(OutFD, Src + Completed, Size - Completed); - if (Written < 0) { - auto ErrNo = errno; - if (ErrNo == EAGAIN || ErrNo == EINTR) - continue; - else - return llvm::errorCodeToError( - std::error_code(errno, std::generic_category())); - } - Completed += Written; - } - return llvm::Error::success(); - } - - llvm::Error send() override { return llvm::Error::success(); } - -private: - int InFD, OutFD; -}; - -} // namespace shared -} // namespace orc -} // namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_FDRAWBYTECHANNEL_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===- FDRawByteChannel.h - File descriptor based byte-channel -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// File descriptor based RawByteChannel. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_FDRAWBYTECHANNEL_H +#define LLVM_EXECUTIONENGINE_ORC_SHARED_FDRAWBYTECHANNEL_H + +#include "llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h" + +#if !defined(_MSC_VER) && !defined(__MINGW32__) +#include <unistd.h> +#else +#include <io.h> +#endif + +namespace llvm { +namespace orc { +namespace shared { + +/// Serialization channel that reads from and writes from file descriptors. +class FDRawByteChannel final : public RawByteChannel { +public: + FDRawByteChannel(int InFD, int OutFD) : InFD(InFD), OutFD(OutFD) {} + + llvm::Error readBytes(char *Dst, unsigned Size) override { + assert(Dst && "Attempt to read into null."); + ssize_t Completed = 0; + while (Completed < static_cast<ssize_t>(Size)) { + ssize_t Read = ::read(InFD, Dst + Completed, Size - Completed); + if (Read <= 0) { + auto ErrNo = errno; + if (ErrNo == EAGAIN || ErrNo == EINTR) + continue; + else + return llvm::errorCodeToError( + std::error_code(errno, std::generic_category())); + } + Completed += Read; + } + return llvm::Error::success(); + } + + llvm::Error appendBytes(const char *Src, unsigned Size) override { + assert(Src && "Attempt to append from null."); + ssize_t Completed = 0; + while (Completed < static_cast<ssize_t>(Size)) { + ssize_t Written = ::write(OutFD, Src + Completed, Size - Completed); + if (Written < 0) { + auto ErrNo = errno; + if (ErrNo == EAGAIN || ErrNo == EINTR) + continue; + else + return llvm::errorCodeToError( + std::error_code(errno, std::generic_category())); + } + Completed += Written; + } + return llvm::Error::success(); + } + + llvm::Error send() override { return llvm::Error::success(); } + +private: + int InFD, OutFD; +}; + +} // namespace shared +} // namespace orc +} // namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_FDRAWBYTECHANNEL_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/OrcError.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/OrcError.h index 172c35a221..2dde3afdce 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/OrcError.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/OrcError.h @@ -1,85 +1,85 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===------ OrcError.h - Reject symbol lookup requests ------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Define an error category, error codes, and helper utilities for Orc. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_ORCERROR_H -#define LLVM_EXECUTIONENGINE_ORC_ORCERROR_H - -#include "llvm/Support/Error.h" -#include "llvm/Support/raw_ostream.h" -#include <string> -#include <system_error> - -namespace llvm { -namespace orc { - -enum class OrcErrorCode : int { - // RPC Errors - UnknownORCError = 1, - DuplicateDefinition, - JITSymbolNotFound, - RemoteAllocatorDoesNotExist, - RemoteAllocatorIdAlreadyInUse, - RemoteMProtectAddrUnrecognized, - RemoteIndirectStubsOwnerDoesNotExist, - RemoteIndirectStubsOwnerIdAlreadyInUse, - RPCConnectionClosed, - RPCCouldNotNegotiateFunction, - RPCResponseAbandoned, - UnexpectedRPCCall, - UnexpectedRPCResponse, - UnknownErrorCodeFromRemote, - UnknownResourceHandle, - MissingSymbolDefinitions, - UnexpectedSymbolDefinitions, -}; - -std::error_code orcError(OrcErrorCode ErrCode); - -class DuplicateDefinition : public ErrorInfo<DuplicateDefinition> { -public: - static char ID; - - DuplicateDefinition(std::string SymbolName); - std::error_code convertToErrorCode() const override; - void log(raw_ostream &OS) const override; - const std::string &getSymbolName() const; -private: - std::string SymbolName; -}; - -class JITSymbolNotFound : public ErrorInfo<JITSymbolNotFound> { -public: - static char ID; - - JITSymbolNotFound(std::string SymbolName); - std::error_code convertToErrorCode() const override; - void log(raw_ostream &OS) const override; - const std::string &getSymbolName() const; -private: - std::string SymbolName; -}; - -} // End namespace orc. -} // End namespace llvm. - -#endif // LLVM_EXECUTIONENGINE_ORC_ORCERROR_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===------ OrcError.h - Reject symbol lookup requests ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Define an error category, error codes, and helper utilities for Orc. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_ORCERROR_H +#define LLVM_EXECUTIONENGINE_ORC_ORCERROR_H + +#include "llvm/Support/Error.h" +#include "llvm/Support/raw_ostream.h" +#include <string> +#include <system_error> + +namespace llvm { +namespace orc { + +enum class OrcErrorCode : int { + // RPC Errors + UnknownORCError = 1, + DuplicateDefinition, + JITSymbolNotFound, + RemoteAllocatorDoesNotExist, + RemoteAllocatorIdAlreadyInUse, + RemoteMProtectAddrUnrecognized, + RemoteIndirectStubsOwnerDoesNotExist, + RemoteIndirectStubsOwnerIdAlreadyInUse, + RPCConnectionClosed, + RPCCouldNotNegotiateFunction, + RPCResponseAbandoned, + UnexpectedRPCCall, + UnexpectedRPCResponse, + UnknownErrorCodeFromRemote, + UnknownResourceHandle, + MissingSymbolDefinitions, + UnexpectedSymbolDefinitions, +}; + +std::error_code orcError(OrcErrorCode ErrCode); + +class DuplicateDefinition : public ErrorInfo<DuplicateDefinition> { +public: + static char ID; + + DuplicateDefinition(std::string SymbolName); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + const std::string &getSymbolName() const; +private: + std::string SymbolName; +}; + +class JITSymbolNotFound : public ErrorInfo<JITSymbolNotFound> { +public: + static char ID; + + JITSymbolNotFound(std::string SymbolName); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + const std::string &getSymbolName() const; +private: + std::string SymbolName; +}; + +} // End namespace orc. +} // End namespace llvm. + +#endif // LLVM_EXECUTIONENGINE_ORC_ORCERROR_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RPCUtils.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RPCUtils.h index 26b64ee2db..4bc6d3577b 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RPCUtils.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RPCUtils.h @@ -1,1668 +1,1668 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Utilities to support construction of simple RPC APIs. -// -// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++ -// programmers, high performance, low memory overhead, and efficient use of the -// communications channel. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H -#define LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H - -#include <map> -#include <thread> -#include <vector> - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" -#include "llvm/ExecutionEngine/Orc/Shared/Serialization.h" -#include "llvm/Support/MSVCErrorWorkarounds.h" - -#include <future> - -namespace llvm { -namespace orc { -namespace shared { - -/// Base class of all fatal RPC errors (those that necessarily result in the -/// termination of the RPC session). -class RPCFatalError : public ErrorInfo<RPCFatalError> { -public: - static char ID; -}; - -/// RPCConnectionClosed is returned from RPC operations if the RPC connection -/// has already been closed due to either an error or graceful disconnection. -class ConnectionClosed : public ErrorInfo<ConnectionClosed> { -public: - static char ID; - std::error_code convertToErrorCode() const override; - void log(raw_ostream &OS) const override; -}; - -/// BadFunctionCall is returned from handleOne when the remote makes a call with -/// an unrecognized function id. -/// -/// This error is fatal because Orc RPC needs to know how to parse a function -/// call to know where the next call starts, and if it doesn't recognize the -/// function id it cannot parse the call. -template <typename FnIdT, typename SeqNoT> -class BadFunctionCall - : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> { -public: - static char ID; - - BadFunctionCall(FnIdT FnId, SeqNoT SeqNo) - : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {} - - std::error_code convertToErrorCode() const override { - return orcError(OrcErrorCode::UnexpectedRPCCall); - } - - void log(raw_ostream &OS) const override { - OS << "Call to invalid RPC function id '" << FnId - << "' with " - "sequence number " - << SeqNo; - } - -private: - FnIdT FnId; - SeqNoT SeqNo; -}; - -template <typename FnIdT, typename SeqNoT> -char BadFunctionCall<FnIdT, SeqNoT>::ID = 0; - -/// InvalidSequenceNumberForResponse is returned from handleOne when a response -/// call arrives with a sequence number that doesn't correspond to any in-flight -/// function call. -/// -/// This error is fatal because Orc RPC needs to know how to parse the rest of -/// the response call to know where the next call starts, and if it doesn't have -/// a result parser for this sequence number it can't do that. -template <typename SeqNoT> -class InvalidSequenceNumberForResponse - : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, - RPCFatalError> { -public: - static char ID; - - InvalidSequenceNumberForResponse(SeqNoT SeqNo) : SeqNo(std::move(SeqNo)) {} - - std::error_code convertToErrorCode() const override { - return orcError(OrcErrorCode::UnexpectedRPCCall); - }; - - void log(raw_ostream &OS) const override { - OS << "Response has unknown sequence number " << SeqNo; - } - -private: - SeqNoT SeqNo; -}; - -template <typename SeqNoT> -char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0; - -/// This non-fatal error will be passed to asynchronous result handlers in place -/// of a result if the connection goes down before a result returns, or if the -/// function to be called cannot be negotiated with the remote. -class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> { -public: - static char ID; - - std::error_code convertToErrorCode() const override; - void log(raw_ostream &OS) const override; -}; - -/// This error is returned if the remote does not have a handler installed for -/// the given RPC function. -class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> { -public: - static char ID; - - CouldNotNegotiate(std::string Signature); - std::error_code convertToErrorCode() const override; - void log(raw_ostream &OS) const override; - const std::string &getSignature() const { return Signature; } - -private: - std::string Signature; -}; - -template <typename DerivedFunc, typename FnT> class RPCFunction; - -// RPC Function class. -// DerivedFunc should be a user defined class with a static 'getName()' method -// returning a const char* representing the function's name. -template <typename DerivedFunc, typename RetT, typename... ArgTs> -class RPCFunction<DerivedFunc, RetT(ArgTs...)> { -public: - /// User defined function type. - using Type = RetT(ArgTs...); - - /// Return type. - using ReturnType = RetT; - - /// Returns the full function prototype as a string. - static const char *getPrototype() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << SerializationTypeName<RetT>::getName() << " " - << DerivedFunc::getName() << "(" - << SerializationTypeNameSequence<ArgTs...>() << ")"; - return Name; - }(); - return Name.data(); - } -}; - -/// Allocates RPC function ids during autonegotiation. -/// Specializations of this class must provide four members: -/// -/// static T getInvalidId(): -/// Should return a reserved id that will be used to represent missing -/// functions during autonegotiation. -/// -/// static T getResponseId(): -/// Should return a reserved id that will be used to send function responses -/// (return values). -/// -/// static T getNegotiateId(): -/// Should return a reserved id for the negotiate function, which will be used -/// to negotiate ids for user defined functions. -/// -/// template <typename Func> T allocate(): -/// Allocate a unique id for function Func. -template <typename T, typename = void> class RPCFunctionIdAllocator; - -/// This specialization of RPCFunctionIdAllocator provides a default -/// implementation for integral types. -template <typename T> -class RPCFunctionIdAllocator<T, std::enable_if_t<std::is_integral<T>::value>> { -public: - static T getInvalidId() { return T(0); } - static T getResponseId() { return T(1); } - static T getNegotiateId() { return T(2); } - - template <typename Func> T allocate() { return NextId++; } - -private: - T NextId = 3; -}; - -namespace detail { - -/// Provides a typedef for a tuple containing the decayed argument types. -template <typename T> class RPCFunctionArgsTuple; - -template <typename RetT, typename... ArgTs> -class RPCFunctionArgsTuple<RetT(ArgTs...)> { -public: - using Type = std::tuple<std::decay_t<std::remove_reference_t<ArgTs>>...>; -}; - -// ResultTraits provides typedefs and utilities specific to the return type -// of functions. -template <typename RetT> class ResultTraits { -public: - // The return type wrapped in llvm::Expected. - using ErrorReturnType = Expected<RetT>; - -#ifdef _MSC_VER - // The ErrorReturnType wrapped in a std::promise. - using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>; - - // The ErrorReturnType wrapped in a std::future. - using ReturnFutureType = std::future<MSVCPExpected<RetT>>; -#else - // The ErrorReturnType wrapped in a std::promise. - using ReturnPromiseType = std::promise<ErrorReturnType>; - - // The ErrorReturnType wrapped in a std::future. - using ReturnFutureType = std::future<ErrorReturnType>; -#endif - - // Create a 'blank' value of the ErrorReturnType, ready and safe to - // overwrite. - static ErrorReturnType createBlankErrorReturnValue() { - return ErrorReturnType(RetT()); - } - - // Consume an abandoned ErrorReturnType. - static void consumeAbandoned(ErrorReturnType RetOrErr) { - consumeError(RetOrErr.takeError()); - } -}; - -// ResultTraits specialization for void functions. -template <> class ResultTraits<void> { -public: - // For void functions, ErrorReturnType is llvm::Error. - using ErrorReturnType = Error; - -#ifdef _MSC_VER - // The ErrorReturnType wrapped in a std::promise. - using ReturnPromiseType = std::promise<MSVCPError>; - - // The ErrorReturnType wrapped in a std::future. - using ReturnFutureType = std::future<MSVCPError>; -#else - // The ErrorReturnType wrapped in a std::promise. - using ReturnPromiseType = std::promise<ErrorReturnType>; - - // The ErrorReturnType wrapped in a std::future. - using ReturnFutureType = std::future<ErrorReturnType>; -#endif - - // Create a 'blank' value of the ErrorReturnType, ready and safe to - // overwrite. - static ErrorReturnType createBlankErrorReturnValue() { - return ErrorReturnType::success(); - } - - // Consume an abandoned ErrorReturnType. - static void consumeAbandoned(ErrorReturnType Err) { - consumeError(std::move(Err)); - } -}; - -// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows -// handlers for void RPC functions to return either void (in which case they -// implicitly succeed) or Error (in which case their error return is -// propagated). See usage in HandlerTraits::runHandlerHelper. -template <> class ResultTraits<Error> : public ResultTraits<void> {}; - -// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows -// handlers for RPC functions returning a T to return either a T (in which -// case they implicitly succeed) or Expected<T> (in which case their error -// return is propagated). See usage in HandlerTraits::runHandlerHelper. -template <typename RetT> -class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {}; - -// Determines whether an RPC function's defined error return type supports -// error return value. -template <typename T> class SupportsErrorReturn { -public: - static const bool value = false; -}; - -template <> class SupportsErrorReturn<Error> { -public: - static const bool value = true; -}; - -template <typename T> class SupportsErrorReturn<Expected<T>> { -public: - static const bool value = true; -}; - -// RespondHelper packages return values based on whether or not the declared -// RPC function return type supports error returns. -template <bool FuncSupportsErrorReturn> class RespondHelper; - -// RespondHelper specialization for functions that support error returns. -template <> class RespondHelper<true> { -public: - // Send Expected<T>. - template <typename WireRetT, typename HandlerRetT, typename ChannelT, - typename FunctionIdT, typename SequenceNumberT> - static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, - SequenceNumberT SeqNo, - Expected<HandlerRetT> ResultOrErr) { - if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>()) - return ResultOrErr.takeError(); - - // Open the response message. - if (auto Err = C.startSendMessage(ResponseId, SeqNo)) - return Err; - - // Serialize the result. - if (auto Err = - SerializationTraits<ChannelT, WireRetT, Expected<HandlerRetT>>:: - serialize(C, std::move(ResultOrErr))) - return Err; - - // Close the response message. - if (auto Err = C.endSendMessage()) - return Err; - return C.send(); - } - - template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> - static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, - SequenceNumberT SeqNo, Error Err) { - if (Err && Err.isA<RPCFatalError>()) - return Err; - if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) - return Err2; - if (auto Err2 = serializeSeq(C, std::move(Err))) - return Err2; - if (auto Err2 = C.endSendMessage()) - return Err2; - return C.send(); - } -}; - -// RespondHelper specialization for functions that do not support error returns. -template <> class RespondHelper<false> { -public: - template <typename WireRetT, typename HandlerRetT, typename ChannelT, - typename FunctionIdT, typename SequenceNumberT> - static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, - SequenceNumberT SeqNo, - Expected<HandlerRetT> ResultOrErr) { - if (auto Err = ResultOrErr.takeError()) - return Err; - - // Open the response message. - if (auto Err = C.startSendMessage(ResponseId, SeqNo)) - return Err; - - // Serialize the result. - if (auto Err = - SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize( - C, *ResultOrErr)) - return Err; - - // End the response message. - if (auto Err = C.endSendMessage()) - return Err; - - return C.send(); - } - - template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> - static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, - SequenceNumberT SeqNo, Error Err) { - if (Err) - return Err; - if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) - return Err2; - if (auto Err2 = C.endSendMessage()) - return Err2; - return C.send(); - } -}; - -// Send a response of the given wire return type (WireRetT) over the -// channel, with the given sequence number. -template <typename WireRetT, typename HandlerRetT, typename ChannelT, - typename FunctionIdT, typename SequenceNumberT> -Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, - Expected<HandlerRetT> ResultOrErr) { - return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: - template sendResult<WireRetT>(C, ResponseId, SeqNo, - std::move(ResultOrErr)); -} - -// Send an empty response message on the given channel to indicate that -// the handler ran. -template <typename WireRetT, typename ChannelT, typename FunctionIdT, - typename SequenceNumberT> -Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, - Error Err) { - return RespondHelper<SupportsErrorReturn<WireRetT>::value>::sendResult( - C, ResponseId, SeqNo, std::move(Err)); -} - -// Converts a given type to the equivalent error return type. -template <typename T> class WrappedHandlerReturn { -public: - using Type = Expected<T>; -}; - -template <typename T> class WrappedHandlerReturn<Expected<T>> { -public: - using Type = Expected<T>; -}; - -template <> class WrappedHandlerReturn<void> { -public: - using Type = Error; -}; - -template <> class WrappedHandlerReturn<Error> { -public: - using Type = Error; -}; - -template <> class WrappedHandlerReturn<ErrorSuccess> { -public: - using Type = Error; -}; - -// Traits class that strips the response function from the list of handler -// arguments. -template <typename FnT> class AsyncHandlerTraits; - -template <typename ResultT, typename... ArgTs> -class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, - ArgTs...)> { -public: - using Type = Error(ArgTs...); - using ResultType = Expected<ResultT>; -}; - -template <typename... ArgTs> -class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> { -public: - using Type = Error(ArgTs...); - using ResultType = Error; -}; - -template <typename... ArgTs> -class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> { -public: - using Type = Error(ArgTs...); - using ResultType = Error; -}; - -template <typename... ArgTs> -class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> { -public: - using Type = Error(ArgTs...); - using ResultType = Error; -}; - -template <typename ResponseHandlerT, typename... ArgTs> -class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> - : public AsyncHandlerTraits<Error(std::decay_t<ResponseHandlerT>, - ArgTs...)> {}; - -// This template class provides utilities related to RPC function handlers. -// The base case applies to non-function types (the template class is -// specialized for function types) and inherits from the appropriate -// speciilization for the given non-function type's call operator. -template <typename HandlerT> -class HandlerTraits - : public HandlerTraits< - decltype(&std::remove_reference<HandlerT>::type::operator())> {}; - -// Traits for handlers with a given function type. -template <typename RetT, typename... ArgTs> -class HandlerTraits<RetT(ArgTs...)> { -public: - // Function type of the handler. - using Type = RetT(ArgTs...); - - // Return type of the handler. - using ReturnType = RetT; - - // Call the given handler with the given arguments. - template <typename HandlerT, typename... TArgTs> - static typename WrappedHandlerReturn<RetT>::Type - unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) { - return unpackAndRunHelper(Handler, Args, - std::index_sequence_for<TArgTs...>()); - } - - // Call the given handler with the given arguments. - template <typename HandlerT, typename ResponderT, typename... TArgTs> - static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder, - std::tuple<TArgTs...> &Args) { - return unpackAndRunAsyncHelper(Handler, Responder, Args, - std::index_sequence_for<TArgTs...>()); - } - - // Call the given handler with the given arguments. - template <typename HandlerT> - static std::enable_if_t< - std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, Error> - run(HandlerT &Handler, ArgTs &&...Args) { - Handler(std::move(Args)...); - return Error::success(); - } - - template <typename HandlerT, typename... TArgTs> - static std::enable_if_t< - !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, - typename HandlerTraits<HandlerT>::ReturnType> - run(HandlerT &Handler, TArgTs... Args) { - return Handler(std::move(Args)...); - } - - // Serialize arguments to the channel. - template <typename ChannelT, typename... CArgTs> - static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) { - return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...); - } - - // Deserialize arguments from the channel. - template <typename ChannelT, typename... CArgTs> - static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) { - return deserializeArgsHelper(C, Args, std::index_sequence_for<CArgTs...>()); - } - -private: - template <typename ChannelT, typename... CArgTs, size_t... Indexes> - static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args, - std::index_sequence<Indexes...> _) { - return SequenceSerialization<ChannelT, ArgTs...>::deserialize( - C, std::get<Indexes>(Args)...); - } - - template <typename HandlerT, typename ArgTuple, size_t... Indexes> - static typename WrappedHandlerReturn< - typename HandlerTraits<HandlerT>::ReturnType>::Type - unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args, - std::index_sequence<Indexes...>) { - return run(Handler, std::move(std::get<Indexes>(Args))...); - } - - template <typename HandlerT, typename ResponderT, typename ArgTuple, - size_t... Indexes> - static typename WrappedHandlerReturn< - typename HandlerTraits<HandlerT>::ReturnType>::Type - unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder, - ArgTuple &Args, std::index_sequence<Indexes...>) { - return run(Handler, Responder, std::move(std::get<Indexes>(Args))...); - } -}; - -// Handler traits for free functions. -template <typename RetT, typename... ArgTs> -class HandlerTraits<RetT (*)(ArgTs...)> : public HandlerTraits<RetT(ArgTs...)> { -}; - -// Handler traits for class methods (especially call operators for lambdas). -template <typename Class, typename RetT, typename... ArgTs> -class HandlerTraits<RetT (Class::*)(ArgTs...)> - : public HandlerTraits<RetT(ArgTs...)> {}; - -// Handler traits for const class methods (especially call operators for -// lambdas). -template <typename Class, typename RetT, typename... ArgTs> -class HandlerTraits<RetT (Class::*)(ArgTs...) const> - : public HandlerTraits<RetT(ArgTs...)> {}; - -// Utility to peel the Expected wrapper off a response handler error type. -template <typename HandlerT> class ResponseHandlerArg; - -template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> { -public: - using ArgType = Expected<ArgT>; - using UnwrappedArgType = ArgT; -}; - -template <typename ArgT> -class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> { -public: - using ArgType = Expected<ArgT>; - using UnwrappedArgType = ArgT; -}; - -template <> class ResponseHandlerArg<Error(Error)> { -public: - using ArgType = Error; -}; - -template <> class ResponseHandlerArg<ErrorSuccess(Error)> { -public: - using ArgType = Error; -}; - -// ResponseHandler represents a handler for a not-yet-received function call -// result. -template <typename ChannelT> class ResponseHandler { -public: - virtual ~ResponseHandler() {} - - // Reads the function result off the wire and acts on it. The meaning of - // "act" will depend on how this method is implemented in any given - // ResponseHandler subclass but could, for example, mean running a - // user-specified handler or setting a promise value. - virtual Error handleResponse(ChannelT &C) = 0; - - // Abandons this outstanding result. - virtual void abandon() = 0; - - // Create an error instance representing an abandoned response. - static Error createAbandonedResponseError() { - return make_error<ResponseAbandoned>(); - } -}; - -// ResponseHandler subclass for RPC functions with non-void returns. -template <typename ChannelT, typename FuncRetT, typename HandlerT> -class ResponseHandlerImpl : public ResponseHandler<ChannelT> { -public: - ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} - - // Handle the result by deserializing it from the channel then passing it - // to the user defined handler. - Error handleResponse(ChannelT &C) override { - using UnwrappedArgType = typename ResponseHandlerArg< - typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType; - UnwrappedArgType Result; - if (auto Err = - SerializationTraits<ChannelT, FuncRetT, - UnwrappedArgType>::deserialize(C, Result)) - return Err; - if (auto Err = C.endReceiveMessage()) - return Err; - return Handler(std::move(Result)); - } - - // Abandon this response by calling the handler with an 'abandoned response' - // error. - void abandon() override { - if (auto Err = Handler(this->createAbandonedResponseError())) { - // Handlers should not fail when passed an abandoned response error. - report_fatal_error(std::move(Err)); - } - } - -private: - HandlerT Handler; -}; - -// ResponseHandler subclass for RPC functions with void returns. -template <typename ChannelT, typename HandlerT> -class ResponseHandlerImpl<ChannelT, void, HandlerT> - : public ResponseHandler<ChannelT> { -public: - ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} - - // Handle the result (no actual value, just a notification that the function - // has completed on the remote end) by calling the user-defined handler with - // Error::success(). - Error handleResponse(ChannelT &C) override { - if (auto Err = C.endReceiveMessage()) - return Err; - return Handler(Error::success()); - } - - // Abandon this response by calling the handler with an 'abandoned response' - // error. - void abandon() override { - if (auto Err = Handler(this->createAbandonedResponseError())) { - // Handlers should not fail when passed an abandoned response error. - report_fatal_error(std::move(Err)); - } - } - -private: - HandlerT Handler; -}; - -template <typename ChannelT, typename FuncRetT, typename HandlerT> -class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT> - : public ResponseHandler<ChannelT> { -public: - ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} - - // Handle the result by deserializing it from the channel then passing it - // to the user defined handler. - Error handleResponse(ChannelT &C) override { - using HandlerArgType = typename ResponseHandlerArg< - typename HandlerTraits<HandlerT>::Type>::ArgType; - HandlerArgType Result((typename HandlerArgType::value_type())); - - if (auto Err = SerializationTraits<ChannelT, Expected<FuncRetT>, - HandlerArgType>::deserialize(C, Result)) - return Err; - if (auto Err = C.endReceiveMessage()) - return Err; - return Handler(std::move(Result)); - } - - // Abandon this response by calling the handler with an 'abandoned response' - // error. - void abandon() override { - if (auto Err = Handler(this->createAbandonedResponseError())) { - // Handlers should not fail when passed an abandoned response error. - report_fatal_error(std::move(Err)); - } - } - -private: - HandlerT Handler; -}; - -template <typename ChannelT, typename HandlerT> -class ResponseHandlerImpl<ChannelT, Error, HandlerT> - : public ResponseHandler<ChannelT> { -public: - ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} - - // Handle the result by deserializing it from the channel then passing it - // to the user defined handler. - Error handleResponse(ChannelT &C) override { - Error Result = Error::success(); - if (auto Err = SerializationTraits<ChannelT, Error, Error>::deserialize( - C, Result)) { - consumeError(std::move(Result)); - return Err; - } - if (auto Err = C.endReceiveMessage()) { - consumeError(std::move(Result)); - return Err; - } - return Handler(std::move(Result)); - } - - // Abandon this response by calling the handler with an 'abandoned response' - // error. - void abandon() override { - if (auto Err = Handler(this->createAbandonedResponseError())) { - // Handlers should not fail when passed an abandoned response error. - report_fatal_error(std::move(Err)); - } - } - -private: - HandlerT Handler; -}; - -// Create a ResponseHandler from a given user handler. -template <typename ChannelT, typename FuncRetT, typename HandlerT> -std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) { - return std::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>( - std::move(H)); -} - -// Helper for wrapping member functions up as functors. This is useful for -// installing methods as result handlers. -template <typename ClassT, typename RetT, typename... ArgTs> -class MemberFnWrapper { -public: - using MethodT = RetT (ClassT::*)(ArgTs...); - MemberFnWrapper(ClassT &Instance, MethodT Method) - : Instance(Instance), Method(Method) {} - RetT operator()(ArgTs &&...Args) { - return (Instance.*Method)(std::move(Args)...); - } - -private: - ClassT &Instance; - MethodT Method; -}; - -// Helper that provides a Functor for deserializing arguments. -template <typename... ArgTs> class ReadArgs { -public: - Error operator()() { return Error::success(); } -}; - -template <typename ArgT, typename... ArgTs> -class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> { -public: - ReadArgs(ArgT &Arg, ArgTs &...Args) : ReadArgs<ArgTs...>(Args...), Arg(Arg) {} - - Error operator()(ArgT &ArgVal, ArgTs &...ArgVals) { - this->Arg = std::move(ArgVal); - return ReadArgs<ArgTs...>::operator()(ArgVals...); - } - -private: - ArgT &Arg; -}; - -// Manage sequence numbers. -template <typename SequenceNumberT> class SequenceNumberManager { -public: - // Reset, making all sequence numbers available. - void reset() { - std::lock_guard<std::mutex> Lock(SeqNoLock); - NextSequenceNumber = 0; - FreeSequenceNumbers.clear(); - } - - // Get the next available sequence number. Will re-use numbers that have - // been released. - SequenceNumberT getSequenceNumber() { - std::lock_guard<std::mutex> Lock(SeqNoLock); - if (FreeSequenceNumbers.empty()) - return NextSequenceNumber++; - auto SequenceNumber = FreeSequenceNumbers.back(); - FreeSequenceNumbers.pop_back(); - return SequenceNumber; - } - - // Release a sequence number, making it available for re-use. - void releaseSequenceNumber(SequenceNumberT SequenceNumber) { - std::lock_guard<std::mutex> Lock(SeqNoLock); - FreeSequenceNumbers.push_back(SequenceNumber); - } - -private: - std::mutex SeqNoLock; - SequenceNumberT NextSequenceNumber = 0; - std::vector<SequenceNumberT> FreeSequenceNumbers; -}; - -// Checks that predicate P holds for each corresponding pair of type arguments -// from T1 and T2 tuple. -template <template <class, class> class P, typename T1Tuple, typename T2Tuple> -class RPCArgTypeCheckHelper; - -template <template <class, class> class P> -class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> { -public: - static const bool value = true; -}; - -template <template <class, class> class P, typename T, typename... Ts, - typename U, typename... Us> -class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> { -public: - static const bool value = - P<T, U>::value && - RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value; -}; - -template <template <class, class> class P, typename T1Sig, typename T2Sig> -class RPCArgTypeCheck { -public: - using T1Tuple = typename RPCFunctionArgsTuple<T1Sig>::Type; - using T2Tuple = typename RPCFunctionArgsTuple<T2Sig>::Type; - - static_assert(std::tuple_size<T1Tuple>::value >= - std::tuple_size<T2Tuple>::value, - "Too many arguments to RPC call"); - static_assert(std::tuple_size<T1Tuple>::value <= - std::tuple_size<T2Tuple>::value, - "Too few arguments to RPC call"); - - static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value; -}; - -template <typename ChannelT, typename WireT, typename ConcreteT> -class CanSerialize { -private: - using S = SerializationTraits<ChannelT, WireT, ConcreteT>; - - template <typename T> - static std::true_type check( - std::enable_if_t<std::is_same<decltype(T::serialize( - std::declval<ChannelT &>(), - std::declval<const ConcreteT &>())), - Error>::value, - void *>); - - template <typename> static std::false_type check(...); - -public: - static const bool value = decltype(check<S>(0))::value; -}; - -template <typename ChannelT, typename WireT, typename ConcreteT> -class CanDeserialize { -private: - using S = SerializationTraits<ChannelT, WireT, ConcreteT>; - - template <typename T> - static std::true_type - check(std::enable_if_t< - std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(), - std::declval<ConcreteT &>())), - Error>::value, - void *>); - - template <typename> static std::false_type check(...); - -public: - static const bool value = decltype(check<S>(0))::value; -}; - -/// Contains primitive utilities for defining, calling and handling calls to -/// remote procedures. ChannelT is a bidirectional stream conforming to the -/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure -/// identifier type that must be serializable on ChannelT, and SequenceNumberT -/// is an integral type that will be used to number in-flight function calls. -/// -/// These utilities support the construction of very primitive RPC utilities. -/// Their intent is to ensure correct serialization and deserialization of -/// procedure arguments, and to keep the client and server's view of the API in -/// sync. -template <typename ImplT, typename ChannelT, typename FunctionIdT, - typename SequenceNumberT> -class RPCEndpointBase { -protected: - class OrcRPCInvalid : public RPCFunction<OrcRPCInvalid, void()> { - public: - static const char *getName() { return "__orc_rpc$invalid"; } - }; - - class OrcRPCResponse : public RPCFunction<OrcRPCResponse, void()> { - public: - static const char *getName() { return "__orc_rpc$response"; } - }; - - class OrcRPCNegotiate - : public RPCFunction<OrcRPCNegotiate, FunctionIdT(std::string)> { - public: - static const char *getName() { return "__orc_rpc$negotiate"; } - }; - - // Helper predicate for testing for the presence of SerializeTraits - // serializers. - template <typename WireT, typename ConcreteT> - class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> { - public: - using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value; - - static_assert(value, "Missing serializer for argument (Can't serialize the " - "first template type argument of CanSerializeCheck " - "from the second)"); - }; - - // Helper predicate for testing for the presence of SerializeTraits - // deserializers. - template <typename WireT, typename ConcreteT> - class CanDeserializeCheck - : detail::CanDeserialize<ChannelT, WireT, ConcreteT> { - public: - using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value; - - static_assert(value, "Missing deserializer for argument (Can't deserialize " - "the second template type argument of " - "CanDeserializeCheck from the first)"); - }; - -public: - /// Construct an RPC instance on a channel. - RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation) - : C(C), LazyAutoNegotiation(LazyAutoNegotiation) { - // Hold ResponseId in a special variable, since we expect Response to be - // called relatively frequently, and want to avoid the map lookup. - ResponseId = FnIdAllocator.getResponseId(); - RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId; - - // Register the negotiate function id and handler. - auto NegotiateId = FnIdAllocator.getNegotiateId(); - RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId; - Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>( - [this](const std::string &Name) { return handleNegotiate(Name); }); - } - - /// Negotiate a function id for Func with the other end of the channel. - template <typename Func> Error negotiateFunction(bool Retry = false) { - return getRemoteFunctionId<Func>(true, Retry).takeError(); - } - - /// Append a call Func, does not call send on the channel. - /// The first argument specifies a user-defined handler to be run when the - /// function returns. The handler should take an Expected<Func::ReturnType>, - /// or an Error (if Func::ReturnType is void). The handler will be called - /// with an error if the return value is abandoned due to a channel error. - template <typename Func, typename HandlerT, typename... ArgTs> - Error appendCallAsync(HandlerT Handler, const ArgTs &...Args) { - - static_assert( - detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type, - void(ArgTs...)>::value, - ""); - - // Look up the function ID. - FunctionIdT FnId; - if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false)) - FnId = *FnIdOrErr; - else { - // Negotiation failed. Notify the handler then return the negotiate-failed - // error. - cantFail(Handler(make_error<ResponseAbandoned>())); - return FnIdOrErr.takeError(); - } - - SequenceNumberT SeqNo; // initialized in locked scope below. - { - // Lock the pending responses map and sequence number manager. - std::lock_guard<std::mutex> Lock(ResponsesMutex); - - // Allocate a sequence number. - SeqNo = SequenceNumberMgr.getSequenceNumber(); - assert(!PendingResponses.count(SeqNo) && - "Sequence number already allocated"); - - // Install the user handler. - PendingResponses[SeqNo] = - detail::createResponseHandler<ChannelT, typename Func::ReturnType>( - std::move(Handler)); - } - - // Open the function call message. - if (auto Err = C.startSendMessage(FnId, SeqNo)) { - abandonPendingResponses(); - return Err; - } - - // Serialize the call arguments. - if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs( - C, Args...)) { - abandonPendingResponses(); - return Err; - } - - // Close the function call messagee. - if (auto Err = C.endSendMessage()) { - abandonPendingResponses(); - return Err; - } - - return Error::success(); - } - - Error sendAppendedCalls() { return C.send(); }; - - template <typename Func, typename HandlerT, typename... ArgTs> - Error callAsync(HandlerT Handler, const ArgTs &...Args) { - if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...)) - return Err; - return C.send(); - } - - /// Handle one incoming call. - Error handleOne() { - FunctionIdT FnId; - SequenceNumberT SeqNo; - if (auto Err = C.startReceiveMessage(FnId, SeqNo)) { - abandonPendingResponses(); - return Err; - } - if (FnId == ResponseId) - return handleResponse(SeqNo); - auto I = Handlers.find(FnId); - if (I != Handlers.end()) - return I->second(C, SeqNo); - - // else: No handler found. Report error to client? - return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId, - SeqNo); - } - - /// Helper for handling setter procedures - this method returns a functor that - /// sets the variables referred to by Args... to values deserialized from the - /// channel. - /// E.g. - /// - /// typedef Function<0, bool, int> Func1; - /// - /// ... - /// bool B; - /// int I; - /// if (auto Err = expect<Func1>(Channel, readArgs(B, I))) - /// /* Handle Args */ ; - /// - template <typename... ArgTs> - static detail::ReadArgs<ArgTs...> readArgs(ArgTs &...Args) { - return detail::ReadArgs<ArgTs...>(Args...); - } - - /// Abandon all outstanding result handlers. - /// - /// This will call all currently registered result handlers to receive an - /// "abandoned" error as their argument. This is used internally by the RPC - /// in error situations, but can also be called directly by clients who are - /// disconnecting from the remote and don't or can't expect responses to their - /// outstanding calls. (Especially for outstanding blocking calls, calling - /// this function may be necessary to avoid dead threads). - void abandonPendingResponses() { - // Lock the pending responses map and sequence number manager. - std::lock_guard<std::mutex> Lock(ResponsesMutex); - - for (auto &KV : PendingResponses) - KV.second->abandon(); - PendingResponses.clear(); - SequenceNumberMgr.reset(); - } - - /// Remove the handler for the given function. - /// A handler must currently be registered for this function. - template <typename Func> void removeHandler() { - auto IdItr = LocalFunctionIds.find(Func::getPrototype()); - assert(IdItr != LocalFunctionIds.end() && - "Function does not have a registered handler"); - auto HandlerItr = Handlers.find(IdItr->second); - assert(HandlerItr != Handlers.end() && - "Function does not have a registered handler"); - Handlers.erase(HandlerItr); - } - - /// Clear all handlers. - void clearHandlers() { Handlers.clear(); } - -protected: - FunctionIdT getInvalidFunctionId() const { - return FnIdAllocator.getInvalidId(); - } - - /// Add the given handler to the handler map and make it available for - /// autonegotiation and execution. - template <typename Func, typename HandlerT> - void addHandlerImpl(HandlerT Handler) { - - static_assert(detail::RPCArgTypeCheck< - CanDeserializeCheck, typename Func::Type, - typename detail::HandlerTraits<HandlerT>::Type>::value, - ""); - - FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); - LocalFunctionIds[Func::getPrototype()] = NewFnId; - Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler)); - } - - template <typename Func, typename HandlerT> - void addAsyncHandlerImpl(HandlerT Handler) { - - static_assert( - detail::RPCArgTypeCheck< - CanDeserializeCheck, typename Func::Type, - typename detail::AsyncHandlerTraits< - typename detail::HandlerTraits<HandlerT>::Type>::Type>::value, - ""); - - FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); - LocalFunctionIds[Func::getPrototype()] = NewFnId; - Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler)); - } - - Error handleResponse(SequenceNumberT SeqNo) { - using Handler = typename decltype(PendingResponses)::mapped_type; - Handler PRHandler; - - { - // Lock the pending responses map and sequence number manager. - std::unique_lock<std::mutex> Lock(ResponsesMutex); - auto I = PendingResponses.find(SeqNo); - - if (I != PendingResponses.end()) { - PRHandler = std::move(I->second); - PendingResponses.erase(I); - SequenceNumberMgr.releaseSequenceNumber(SeqNo); - } else { - // Unlock the pending results map to prevent recursive lock. - Lock.unlock(); - abandonPendingResponses(); - return make_error<InvalidSequenceNumberForResponse<SequenceNumberT>>( - SeqNo); - } - } - - assert(PRHandler && - "If we didn't find a response handler we should have bailed out"); - - if (auto Err = PRHandler->handleResponse(C)) { - abandonPendingResponses(); - return Err; - } - - return Error::success(); - } - - FunctionIdT handleNegotiate(const std::string &Name) { - auto I = LocalFunctionIds.find(Name); - if (I == LocalFunctionIds.end()) - return getInvalidFunctionId(); - return I->second; - } - - // Find the remote FunctionId for the given function. - template <typename Func> - Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap, - bool NegotiateIfInvalid) { - bool DoNegotiate; - - // Check if we already have a function id... - auto I = RemoteFunctionIds.find(Func::getPrototype()); - if (I != RemoteFunctionIds.end()) { - // If it's valid there's nothing left to do. - if (I->second != getInvalidFunctionId()) - return I->second; - DoNegotiate = NegotiateIfInvalid; - } else - DoNegotiate = NegotiateIfNotInMap; - - // We don't have a function id for Func yet, but we're allowed to try to - // negotiate one. - if (DoNegotiate) { - auto &Impl = static_cast<ImplT &>(*this); - if (auto RemoteIdOrErr = - Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { - RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; - if (*RemoteIdOrErr == getInvalidFunctionId()) - return make_error<CouldNotNegotiate>(Func::getPrototype()); - return *RemoteIdOrErr; - } else - return RemoteIdOrErr.takeError(); - } - - // No key was available in the map and we weren't allowed to try to - // negotiate one, so return an unknown function error. - return make_error<CouldNotNegotiate>(Func::getPrototype()); - } - - using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>; - - // Wrap the given user handler in the necessary argument-deserialization code, - // result-serialization code, and call to the launch policy (if present). - template <typename Func, typename HandlerT> - WrappedHandlerFn wrapHandler(HandlerT Handler) { - return [this, Handler](ChannelT &Channel, - SequenceNumberT SeqNo) mutable -> Error { - // Start by deserializing the arguments. - using ArgsTuple = typename detail::RPCFunctionArgsTuple< - typename detail::HandlerTraits<HandlerT>::Type>::Type; - auto Args = std::make_shared<ArgsTuple>(); - - if (auto Err = - detail::HandlerTraits<typename Func::Type>::deserializeArgs( - Channel, *Args)) - return Err; - - // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning - // for RPCArgs. Void cast RPCArgs to work around this for now. - // FIXME: Remove this workaround once we can assume a working GCC version. - (void)Args; - - // End receieve message, unlocking the channel for reading. - if (auto Err = Channel.endReceiveMessage()) - return Err; - - using HTraits = detail::HandlerTraits<HandlerT>; - using FuncReturn = typename Func::ReturnType; - return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo, - HTraits::unpackAndRun(Handler, *Args)); - }; - } - - // Wrap the given user handler in the necessary argument-deserialization code, - // result-serialization code, and call to the launch policy (if present). - template <typename Func, typename HandlerT> - WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) { - return [this, Handler](ChannelT &Channel, - SequenceNumberT SeqNo) mutable -> Error { - // Start by deserializing the arguments. - using AHTraits = detail::AsyncHandlerTraits< - typename detail::HandlerTraits<HandlerT>::Type>; - using ArgsTuple = - typename detail::RPCFunctionArgsTuple<typename AHTraits::Type>::Type; - auto Args = std::make_shared<ArgsTuple>(); - - if (auto Err = - detail::HandlerTraits<typename Func::Type>::deserializeArgs( - Channel, *Args)) - return Err; - - // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning - // for RPCArgs. Void cast RPCArgs to work around this for now. - // FIXME: Remove this workaround once we can assume a working GCC version. - (void)Args; - - // End receieve message, unlocking the channel for reading. - if (auto Err = Channel.endReceiveMessage()) - return Err; - - using HTraits = detail::HandlerTraits<HandlerT>; - using FuncReturn = typename Func::ReturnType; - auto Responder = [this, - SeqNo](typename AHTraits::ResultType RetVal) -> Error { - return detail::respond<FuncReturn>(C, ResponseId, SeqNo, - std::move(RetVal)); - }; - - return HTraits::unpackAndRunAsync(Handler, Responder, *Args); - }; - } - - ChannelT &C; - - bool LazyAutoNegotiation; - - RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator; - - FunctionIdT ResponseId; - std::map<std::string, FunctionIdT> LocalFunctionIds; - std::map<const char *, FunctionIdT> RemoteFunctionIds; - - std::map<FunctionIdT, WrappedHandlerFn> Handlers; - - std::mutex ResponsesMutex; - detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr; - std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>> - PendingResponses; -}; - -} // end namespace detail - -template <typename ChannelT, typename FunctionIdT = uint32_t, - typename SequenceNumberT = uint32_t> -class MultiThreadedRPCEndpoint - : public detail::RPCEndpointBase< - MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, - ChannelT, FunctionIdT, SequenceNumberT> { -private: - using BaseClass = detail::RPCEndpointBase< - MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, - ChannelT, FunctionIdT, SequenceNumberT>; - -public: - MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) - : BaseClass(C, LazyAutoNegotiation) {} - - /// Add a handler for the given RPC function. - /// This installs the given handler functor for the given RPCFunction, and - /// makes the RPC function available for negotiation/calling from the remote. - template <typename Func, typename HandlerT> - void addHandler(HandlerT Handler) { - return this->template addHandlerImpl<Func>(std::move(Handler)); - } - - /// Add a class-method as a handler. - template <typename Func, typename ClassT, typename RetT, typename... ArgTs> - void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { - addHandler<Func>( - detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); - } - - template <typename Func, typename HandlerT> - void addAsyncHandler(HandlerT Handler) { - return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); - } - - /// Add a class-method as a handler. - template <typename Func, typename ClassT, typename RetT, typename... ArgTs> - void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { - addAsyncHandler<Func>( - detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); - } - - /// Return type for non-blocking call primitives. - template <typename Func> - using NonBlockingCallResult = typename detail::ResultTraits< - typename Func::ReturnType>::ReturnFutureType; - - /// Call Func on Channel C. Does not block, does not call send. Returns a pair - /// of a future result and the sequence number assigned to the result. - /// - /// This utility function is primarily used for single-threaded mode support, - /// where the sequence number can be used to wait for the corresponding - /// result. In multi-threaded mode the appendCallNB method, which does not - /// return the sequence numeber, should be preferred. - template <typename Func, typename... ArgTs> - Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &...Args) { - using RTraits = detail::ResultTraits<typename Func::ReturnType>; - using ErrorReturn = typename RTraits::ErrorReturnType; - using ErrorReturnPromise = typename RTraits::ReturnPromiseType; - - ErrorReturnPromise Promise; - auto FutureResult = Promise.get_future(); - - if (auto Err = this->template appendCallAsync<Func>( - [Promise = std::move(Promise)](ErrorReturn RetOrErr) mutable { - Promise.set_value(std::move(RetOrErr)); - return Error::success(); - }, - Args...)) { - RTraits::consumeAbandoned(FutureResult.get()); - return std::move(Err); - } - return std::move(FutureResult); - } - - /// The same as appendCallNBWithSeq, except that it calls C.send() to - /// flush the channel after serializing the call. - template <typename Func, typename... ArgTs> - Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &...Args) { - auto Result = appendCallNB<Func>(Args...); - if (!Result) - return Result; - if (auto Err = this->C.send()) { - this->abandonPendingResponses(); - detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( - std::move(Result->get())); - return std::move(Err); - } - return Result; - } - - /// Call Func on Channel C. Blocks waiting for a result. Returns an Error - /// for void functions or an Expected<T> for functions returning a T. - /// - /// This function is for use in threaded code where another thread is - /// handling responses and incoming calls. - template <typename Func, typename... ArgTs, - typename AltRetT = typename Func::ReturnType> - typename detail::ResultTraits<AltRetT>::ErrorReturnType - callB(const ArgTs &...Args) { - if (auto FutureResOrErr = callNB<Func>(Args...)) - return FutureResOrErr->get(); - else - return FutureResOrErr.takeError(); - } - - /// Handle incoming RPC calls. - Error handlerLoop() { - while (true) - if (auto Err = this->handleOne()) - return Err; - return Error::success(); - } -}; - -template <typename ChannelT, typename FunctionIdT = uint32_t, - typename SequenceNumberT = uint32_t> -class SingleThreadedRPCEndpoint - : public detail::RPCEndpointBase< - SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, - ChannelT, FunctionIdT, SequenceNumberT> { -private: - using BaseClass = detail::RPCEndpointBase< - SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, - ChannelT, FunctionIdT, SequenceNumberT>; - -public: - SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) - : BaseClass(C, LazyAutoNegotiation) {} - - template <typename Func, typename HandlerT> - void addHandler(HandlerT Handler) { - return this->template addHandlerImpl<Func>(std::move(Handler)); - } - - template <typename Func, typename ClassT, typename RetT, typename... ArgTs> - void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { - addHandler<Func>( - detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); - } - - template <typename Func, typename HandlerT> - void addAsyncHandler(HandlerT Handler) { - return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); - } - - /// Add a class-method as a handler. - template <typename Func, typename ClassT, typename RetT, typename... ArgTs> - void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { - addAsyncHandler<Func>( - detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); - } - - template <typename Func, typename... ArgTs, - typename AltRetT = typename Func::ReturnType> - typename detail::ResultTraits<AltRetT>::ErrorReturnType - callB(const ArgTs &...Args) { - bool ReceivedResponse = false; - using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType; - auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue(); - - // We have to 'Check' result (which we know is in a success state at this - // point) so that it can be overwritten in the async handler. - (void)!!Result; - - if (auto Err = this->template appendCallAsync<Func>( - [&](ResultType R) { - Result = std::move(R); - ReceivedResponse = true; - return Error::success(); - }, - Args...)) { - detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( - std::move(Result)); - return std::move(Err); - } - - if (auto Err = this->C.send()) { - detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( - std::move(Result)); - return std::move(Err); - } - - while (!ReceivedResponse) { - if (auto Err = this->handleOne()) { - detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( - std::move(Result)); - return std::move(Err); - } - } - - return Result; - } -}; - -/// Asynchronous dispatch for a function on an RPC endpoint. -template <typename RPCClass, typename Func> class RPCAsyncDispatch { -public: - RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {} - - template <typename HandlerT, typename... ArgTs> - Error operator()(HandlerT Handler, const ArgTs &...Args) const { - return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...); - } - -private: - RPCClass &Endpoint; -}; - -/// Construct an asynchronous dispatcher from an RPC endpoint and a Func. -template <typename Func, typename RPCEndpointT> -RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) { - return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint); -} - -/// Allows a set of asynchrounous calls to be dispatched, and then -/// waited on as a group. -class ParallelCallGroup { -public: - ParallelCallGroup() = default; - ParallelCallGroup(const ParallelCallGroup &) = delete; - ParallelCallGroup &operator=(const ParallelCallGroup &) = delete; - - /// Make as asynchronous call. - template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs> - Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler, - const ArgTs &...Args) { - // Increment the count of outstanding calls. This has to happen before - // we invoke the call, as the handler may (depending on scheduling) - // be run immediately on another thread, and we don't want the decrement - // in the wrapped handler below to run before the increment. - { - std::unique_lock<std::mutex> Lock(M); - ++NumOutstandingCalls; - } - - // Wrap the user handler in a lambda that will decrement the - // outstanding calls count, then poke the condition variable. - using ArgType = typename detail::ResponseHandlerArg< - typename detail::HandlerTraits<HandlerT>::Type>::ArgType; - auto WrappedHandler = [this, Handler = std::move(Handler)](ArgType Arg) { - auto Err = Handler(std::move(Arg)); - std::unique_lock<std::mutex> Lock(M); - --NumOutstandingCalls; - CV.notify_all(); - return Err; - }; - - return AsyncDispatch(std::move(WrappedHandler), Args...); - } - - /// Blocks until all calls have been completed and their return value - /// handlers run. - void wait() { - std::unique_lock<std::mutex> Lock(M); - while (NumOutstandingCalls > 0) - CV.wait(Lock); - } - -private: - std::mutex M; - std::condition_variable CV; - uint32_t NumOutstandingCalls = 0; -}; - -/// Convenience class for grouping RPCFunctions into APIs that can be -/// negotiated as a block. -/// -template <typename... Funcs> class APICalls { -public: - /// Test whether this API contains Function F. - template <typename F> class Contains { - public: - static const bool value = false; - }; - - /// Negotiate all functions in this API. - template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) { - return Error::success(); - } -}; - -template <typename Func, typename... Funcs> class APICalls<Func, Funcs...> { -public: - template <typename F> class Contains { - public: - static const bool value = std::is_same<F, Func>::value | - APICalls<Funcs...>::template Contains<F>::value; - }; - - template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) { - if (auto Err = R.template negotiateFunction<Func>()) - return Err; - return APICalls<Funcs...>::negotiate(R); - } -}; - -template <typename... InnerFuncs, typename... Funcs> -class APICalls<APICalls<InnerFuncs...>, Funcs...> { -public: - template <typename F> class Contains { - public: - static const bool value = - APICalls<InnerFuncs...>::template Contains<F>::value | - APICalls<Funcs...>::template Contains<F>::value; - }; - - template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) { - if (auto Err = APICalls<InnerFuncs...>::negotiate(R)) - return Err; - return APICalls<Funcs...>::negotiate(R); - } -}; - -} // end namespace shared -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utilities to support construction of simple RPC APIs. +// +// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++ +// programmers, high performance, low memory overhead, and efficient use of the +// communications channel. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H +#define LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H + +#include <map> +#include <thread> +#include <vector> + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" +#include "llvm/ExecutionEngine/Orc/Shared/Serialization.h" +#include "llvm/Support/MSVCErrorWorkarounds.h" + +#include <future> + +namespace llvm { +namespace orc { +namespace shared { + +/// Base class of all fatal RPC errors (those that necessarily result in the +/// termination of the RPC session). +class RPCFatalError : public ErrorInfo<RPCFatalError> { +public: + static char ID; +}; + +/// RPCConnectionClosed is returned from RPC operations if the RPC connection +/// has already been closed due to either an error or graceful disconnection. +class ConnectionClosed : public ErrorInfo<ConnectionClosed> { +public: + static char ID; + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; +}; + +/// BadFunctionCall is returned from handleOne when the remote makes a call with +/// an unrecognized function id. +/// +/// This error is fatal because Orc RPC needs to know how to parse a function +/// call to know where the next call starts, and if it doesn't recognize the +/// function id it cannot parse the call. +template <typename FnIdT, typename SeqNoT> +class BadFunctionCall + : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> { +public: + static char ID; + + BadFunctionCall(FnIdT FnId, SeqNoT SeqNo) + : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {} + + std::error_code convertToErrorCode() const override { + return orcError(OrcErrorCode::UnexpectedRPCCall); + } + + void log(raw_ostream &OS) const override { + OS << "Call to invalid RPC function id '" << FnId + << "' with " + "sequence number " + << SeqNo; + } + +private: + FnIdT FnId; + SeqNoT SeqNo; +}; + +template <typename FnIdT, typename SeqNoT> +char BadFunctionCall<FnIdT, SeqNoT>::ID = 0; + +/// InvalidSequenceNumberForResponse is returned from handleOne when a response +/// call arrives with a sequence number that doesn't correspond to any in-flight +/// function call. +/// +/// This error is fatal because Orc RPC needs to know how to parse the rest of +/// the response call to know where the next call starts, and if it doesn't have +/// a result parser for this sequence number it can't do that. +template <typename SeqNoT> +class InvalidSequenceNumberForResponse + : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, + RPCFatalError> { +public: + static char ID; + + InvalidSequenceNumberForResponse(SeqNoT SeqNo) : SeqNo(std::move(SeqNo)) {} + + std::error_code convertToErrorCode() const override { + return orcError(OrcErrorCode::UnexpectedRPCCall); + }; + + void log(raw_ostream &OS) const override { + OS << "Response has unknown sequence number " << SeqNo; + } + +private: + SeqNoT SeqNo; +}; + +template <typename SeqNoT> +char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0; + +/// This non-fatal error will be passed to asynchronous result handlers in place +/// of a result if the connection goes down before a result returns, or if the +/// function to be called cannot be negotiated with the remote. +class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> { +public: + static char ID; + + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; +}; + +/// This error is returned if the remote does not have a handler installed for +/// the given RPC function. +class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> { +public: + static char ID; + + CouldNotNegotiate(std::string Signature); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + const std::string &getSignature() const { return Signature; } + +private: + std::string Signature; +}; + +template <typename DerivedFunc, typename FnT> class RPCFunction; + +// RPC Function class. +// DerivedFunc should be a user defined class with a static 'getName()' method +// returning a const char* representing the function's name. +template <typename DerivedFunc, typename RetT, typename... ArgTs> +class RPCFunction<DerivedFunc, RetT(ArgTs...)> { +public: + /// User defined function type. + using Type = RetT(ArgTs...); + + /// Return type. + using ReturnType = RetT; + + /// Returns the full function prototype as a string. + static const char *getPrototype() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << SerializationTypeName<RetT>::getName() << " " + << DerivedFunc::getName() << "(" + << SerializationTypeNameSequence<ArgTs...>() << ")"; + return Name; + }(); + return Name.data(); + } +}; + +/// Allocates RPC function ids during autonegotiation. +/// Specializations of this class must provide four members: +/// +/// static T getInvalidId(): +/// Should return a reserved id that will be used to represent missing +/// functions during autonegotiation. +/// +/// static T getResponseId(): +/// Should return a reserved id that will be used to send function responses +/// (return values). +/// +/// static T getNegotiateId(): +/// Should return a reserved id for the negotiate function, which will be used +/// to negotiate ids for user defined functions. +/// +/// template <typename Func> T allocate(): +/// Allocate a unique id for function Func. +template <typename T, typename = void> class RPCFunctionIdAllocator; + +/// This specialization of RPCFunctionIdAllocator provides a default +/// implementation for integral types. +template <typename T> +class RPCFunctionIdAllocator<T, std::enable_if_t<std::is_integral<T>::value>> { +public: + static T getInvalidId() { return T(0); } + static T getResponseId() { return T(1); } + static T getNegotiateId() { return T(2); } + + template <typename Func> T allocate() { return NextId++; } + +private: + T NextId = 3; +}; + +namespace detail { + +/// Provides a typedef for a tuple containing the decayed argument types. +template <typename T> class RPCFunctionArgsTuple; + +template <typename RetT, typename... ArgTs> +class RPCFunctionArgsTuple<RetT(ArgTs...)> { +public: + using Type = std::tuple<std::decay_t<std::remove_reference_t<ArgTs>>...>; +}; + +// ResultTraits provides typedefs and utilities specific to the return type +// of functions. +template <typename RetT> class ResultTraits { +public: + // The return type wrapped in llvm::Expected. + using ErrorReturnType = Expected<RetT>; + +#ifdef _MSC_VER + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<MSVCPExpected<RetT>>; +#else + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<ErrorReturnType>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<ErrorReturnType>; +#endif + + // Create a 'blank' value of the ErrorReturnType, ready and safe to + // overwrite. + static ErrorReturnType createBlankErrorReturnValue() { + return ErrorReturnType(RetT()); + } + + // Consume an abandoned ErrorReturnType. + static void consumeAbandoned(ErrorReturnType RetOrErr) { + consumeError(RetOrErr.takeError()); + } +}; + +// ResultTraits specialization for void functions. +template <> class ResultTraits<void> { +public: + // For void functions, ErrorReturnType is llvm::Error. + using ErrorReturnType = Error; + +#ifdef _MSC_VER + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<MSVCPError>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<MSVCPError>; +#else + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<ErrorReturnType>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<ErrorReturnType>; +#endif + + // Create a 'blank' value of the ErrorReturnType, ready and safe to + // overwrite. + static ErrorReturnType createBlankErrorReturnValue() { + return ErrorReturnType::success(); + } + + // Consume an abandoned ErrorReturnType. + static void consumeAbandoned(ErrorReturnType Err) { + consumeError(std::move(Err)); + } +}; + +// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows +// handlers for void RPC functions to return either void (in which case they +// implicitly succeed) or Error (in which case their error return is +// propagated). See usage in HandlerTraits::runHandlerHelper. +template <> class ResultTraits<Error> : public ResultTraits<void> {}; + +// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows +// handlers for RPC functions returning a T to return either a T (in which +// case they implicitly succeed) or Expected<T> (in which case their error +// return is propagated). See usage in HandlerTraits::runHandlerHelper. +template <typename RetT> +class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {}; + +// Determines whether an RPC function's defined error return type supports +// error return value. +template <typename T> class SupportsErrorReturn { +public: + static const bool value = false; +}; + +template <> class SupportsErrorReturn<Error> { +public: + static const bool value = true; +}; + +template <typename T> class SupportsErrorReturn<Expected<T>> { +public: + static const bool value = true; +}; + +// RespondHelper packages return values based on whether or not the declared +// RPC function return type supports error returns. +template <bool FuncSupportsErrorReturn> class RespondHelper; + +// RespondHelper specialization for functions that support error returns. +template <> class RespondHelper<true> { +public: + // Send Expected<T>. + template <typename WireRetT, typename HandlerRetT, typename ChannelT, + typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, + Expected<HandlerRetT> ResultOrErr) { + if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>()) + return ResultOrErr.takeError(); + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = + SerializationTraits<ChannelT, WireRetT, Expected<HandlerRetT>>:: + serialize(C, std::move(ResultOrErr))) + return Err; + + // Close the response message. + if (auto Err = C.endSendMessage()) + return Err; + return C.send(); + } + + template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err && Err.isA<RPCFatalError>()) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + if (auto Err2 = serializeSeq(C, std::move(Err))) + return Err2; + if (auto Err2 = C.endSendMessage()) + return Err2; + return C.send(); + } +}; + +// RespondHelper specialization for functions that do not support error returns. +template <> class RespondHelper<false> { +public: + template <typename WireRetT, typename HandlerRetT, typename ChannelT, + typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, + Expected<HandlerRetT> ResultOrErr) { + if (auto Err = ResultOrErr.takeError()) + return Err; + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = + SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize( + C, *ResultOrErr)) + return Err; + + // End the response message. + if (auto Err = C.endSendMessage()) + return Err; + + return C.send(); + } + + template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + if (auto Err2 = C.endSendMessage()) + return Err2; + return C.send(); + } +}; + +// Send a response of the given wire return type (WireRetT) over the +// channel, with the given sequence number. +template <typename WireRetT, typename HandlerRetT, typename ChannelT, + typename FunctionIdT, typename SequenceNumberT> +Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, + Expected<HandlerRetT> ResultOrErr) { + return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: + template sendResult<WireRetT>(C, ResponseId, SeqNo, + std::move(ResultOrErr)); +} + +// Send an empty response message on the given channel to indicate that +// the handler ran. +template <typename WireRetT, typename ChannelT, typename FunctionIdT, + typename SequenceNumberT> +Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, + Error Err) { + return RespondHelper<SupportsErrorReturn<WireRetT>::value>::sendResult( + C, ResponseId, SeqNo, std::move(Err)); +} + +// Converts a given type to the equivalent error return type. +template <typename T> class WrappedHandlerReturn { +public: + using Type = Expected<T>; +}; + +template <typename T> class WrappedHandlerReturn<Expected<T>> { +public: + using Type = Expected<T>; +}; + +template <> class WrappedHandlerReturn<void> { +public: + using Type = Error; +}; + +template <> class WrappedHandlerReturn<Error> { +public: + using Type = Error; +}; + +template <> class WrappedHandlerReturn<ErrorSuccess> { +public: + using Type = Error; +}; + +// Traits class that strips the response function from the list of handler +// arguments. +template <typename FnT> class AsyncHandlerTraits; + +template <typename ResultT, typename... ArgTs> +class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, + ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Expected<ResultT>; +}; + +template <typename... ArgTs> +class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Error; +}; + +template <typename... ArgTs> +class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Error; +}; + +template <typename... ArgTs> +class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Error; +}; + +template <typename ResponseHandlerT, typename... ArgTs> +class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> + : public AsyncHandlerTraits<Error(std::decay_t<ResponseHandlerT>, + ArgTs...)> {}; + +// This template class provides utilities related to RPC function handlers. +// The base case applies to non-function types (the template class is +// specialized for function types) and inherits from the appropriate +// speciilization for the given non-function type's call operator. +template <typename HandlerT> +class HandlerTraits + : public HandlerTraits< + decltype(&std::remove_reference<HandlerT>::type::operator())> {}; + +// Traits for handlers with a given function type. +template <typename RetT, typename... ArgTs> +class HandlerTraits<RetT(ArgTs...)> { +public: + // Function type of the handler. + using Type = RetT(ArgTs...); + + // Return type of the handler. + using ReturnType = RetT; + + // Call the given handler with the given arguments. + template <typename HandlerT, typename... TArgTs> + static typename WrappedHandlerReturn<RetT>::Type + unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) { + return unpackAndRunHelper(Handler, Args, + std::index_sequence_for<TArgTs...>()); + } + + // Call the given handler with the given arguments. + template <typename HandlerT, typename ResponderT, typename... TArgTs> + static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder, + std::tuple<TArgTs...> &Args) { + return unpackAndRunAsyncHelper(Handler, Responder, Args, + std::index_sequence_for<TArgTs...>()); + } + + // Call the given handler with the given arguments. + template <typename HandlerT> + static std::enable_if_t< + std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, Error> + run(HandlerT &Handler, ArgTs &&...Args) { + Handler(std::move(Args)...); + return Error::success(); + } + + template <typename HandlerT, typename... TArgTs> + static std::enable_if_t< + !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, + typename HandlerTraits<HandlerT>::ReturnType> + run(HandlerT &Handler, TArgTs... Args) { + return Handler(std::move(Args)...); + } + + // Serialize arguments to the channel. + template <typename ChannelT, typename... CArgTs> + static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) { + return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...); + } + + // Deserialize arguments from the channel. + template <typename ChannelT, typename... CArgTs> + static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) { + return deserializeArgsHelper(C, Args, std::index_sequence_for<CArgTs...>()); + } + +private: + template <typename ChannelT, typename... CArgTs, size_t... Indexes> + static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args, + std::index_sequence<Indexes...> _) { + return SequenceSerialization<ChannelT, ArgTs...>::deserialize( + C, std::get<Indexes>(Args)...); + } + + template <typename HandlerT, typename ArgTuple, size_t... Indexes> + static typename WrappedHandlerReturn< + typename HandlerTraits<HandlerT>::ReturnType>::Type + unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args, + std::index_sequence<Indexes...>) { + return run(Handler, std::move(std::get<Indexes>(Args))...); + } + + template <typename HandlerT, typename ResponderT, typename ArgTuple, + size_t... Indexes> + static typename WrappedHandlerReturn< + typename HandlerTraits<HandlerT>::ReturnType>::Type + unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder, + ArgTuple &Args, std::index_sequence<Indexes...>) { + return run(Handler, Responder, std::move(std::get<Indexes>(Args))...); + } +}; + +// Handler traits for free functions. +template <typename RetT, typename... ArgTs> +class HandlerTraits<RetT (*)(ArgTs...)> : public HandlerTraits<RetT(ArgTs...)> { +}; + +// Handler traits for class methods (especially call operators for lambdas). +template <typename Class, typename RetT, typename... ArgTs> +class HandlerTraits<RetT (Class::*)(ArgTs...)> + : public HandlerTraits<RetT(ArgTs...)> {}; + +// Handler traits for const class methods (especially call operators for +// lambdas). +template <typename Class, typename RetT, typename... ArgTs> +class HandlerTraits<RetT (Class::*)(ArgTs...) const> + : public HandlerTraits<RetT(ArgTs...)> {}; + +// Utility to peel the Expected wrapper off a response handler error type. +template <typename HandlerT> class ResponseHandlerArg; + +template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> { +public: + using ArgType = Expected<ArgT>; + using UnwrappedArgType = ArgT; +}; + +template <typename ArgT> +class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> { +public: + using ArgType = Expected<ArgT>; + using UnwrappedArgType = ArgT; +}; + +template <> class ResponseHandlerArg<Error(Error)> { +public: + using ArgType = Error; +}; + +template <> class ResponseHandlerArg<ErrorSuccess(Error)> { +public: + using ArgType = Error; +}; + +// ResponseHandler represents a handler for a not-yet-received function call +// result. +template <typename ChannelT> class ResponseHandler { +public: + virtual ~ResponseHandler() {} + + // Reads the function result off the wire and acts on it. The meaning of + // "act" will depend on how this method is implemented in any given + // ResponseHandler subclass but could, for example, mean running a + // user-specified handler or setting a promise value. + virtual Error handleResponse(ChannelT &C) = 0; + + // Abandons this outstanding result. + virtual void abandon() = 0; + + // Create an error instance representing an abandoned response. + static Error createAbandonedResponseError() { + return make_error<ResponseAbandoned>(); + } +}; + +// ResponseHandler subclass for RPC functions with non-void returns. +template <typename ChannelT, typename FuncRetT, typename HandlerT> +class ResponseHandlerImpl : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + using UnwrappedArgType = typename ResponseHandlerArg< + typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType; + UnwrappedArgType Result; + if (auto Err = + SerializationTraits<ChannelT, FuncRetT, + UnwrappedArgType>::deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +// ResponseHandler subclass for RPC functions with void returns. +template <typename ChannelT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, void, HandlerT> + : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result (no actual value, just a notification that the function + // has completed on the remote end) by calling the user-defined handler with + // Error::success(). + Error handleResponse(ChannelT &C) override { + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(Error::success()); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +template <typename ChannelT, typename FuncRetT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT> + : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + using HandlerArgType = typename ResponseHandlerArg< + typename HandlerTraits<HandlerT>::Type>::ArgType; + HandlerArgType Result((typename HandlerArgType::value_type())); + + if (auto Err = SerializationTraits<ChannelT, Expected<FuncRetT>, + HandlerArgType>::deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +template <typename ChannelT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, Error, HandlerT> + : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + Error Result = Error::success(); + if (auto Err = SerializationTraits<ChannelT, Error, Error>::deserialize( + C, Result)) { + consumeError(std::move(Result)); + return Err; + } + if (auto Err = C.endReceiveMessage()) { + consumeError(std::move(Result)); + return Err; + } + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +// Create a ResponseHandler from a given user handler. +template <typename ChannelT, typename FuncRetT, typename HandlerT> +std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) { + return std::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>( + std::move(H)); +} + +// Helper for wrapping member functions up as functors. This is useful for +// installing methods as result handlers. +template <typename ClassT, typename RetT, typename... ArgTs> +class MemberFnWrapper { +public: + using MethodT = RetT (ClassT::*)(ArgTs...); + MemberFnWrapper(ClassT &Instance, MethodT Method) + : Instance(Instance), Method(Method) {} + RetT operator()(ArgTs &&...Args) { + return (Instance.*Method)(std::move(Args)...); + } + +private: + ClassT &Instance; + MethodT Method; +}; + +// Helper that provides a Functor for deserializing arguments. +template <typename... ArgTs> class ReadArgs { +public: + Error operator()() { return Error::success(); } +}; + +template <typename ArgT, typename... ArgTs> +class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> { +public: + ReadArgs(ArgT &Arg, ArgTs &...Args) : ReadArgs<ArgTs...>(Args...), Arg(Arg) {} + + Error operator()(ArgT &ArgVal, ArgTs &...ArgVals) { + this->Arg = std::move(ArgVal); + return ReadArgs<ArgTs...>::operator()(ArgVals...); + } + +private: + ArgT &Arg; +}; + +// Manage sequence numbers. +template <typename SequenceNumberT> class SequenceNumberManager { +public: + // Reset, making all sequence numbers available. + void reset() { + std::lock_guard<std::mutex> Lock(SeqNoLock); + NextSequenceNumber = 0; + FreeSequenceNumbers.clear(); + } + + // Get the next available sequence number. Will re-use numbers that have + // been released. + SequenceNumberT getSequenceNumber() { + std::lock_guard<std::mutex> Lock(SeqNoLock); + if (FreeSequenceNumbers.empty()) + return NextSequenceNumber++; + auto SequenceNumber = FreeSequenceNumbers.back(); + FreeSequenceNumbers.pop_back(); + return SequenceNumber; + } + + // Release a sequence number, making it available for re-use. + void releaseSequenceNumber(SequenceNumberT SequenceNumber) { + std::lock_guard<std::mutex> Lock(SeqNoLock); + FreeSequenceNumbers.push_back(SequenceNumber); + } + +private: + std::mutex SeqNoLock; + SequenceNumberT NextSequenceNumber = 0; + std::vector<SequenceNumberT> FreeSequenceNumbers; +}; + +// Checks that predicate P holds for each corresponding pair of type arguments +// from T1 and T2 tuple. +template <template <class, class> class P, typename T1Tuple, typename T2Tuple> +class RPCArgTypeCheckHelper; + +template <template <class, class> class P> +class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> { +public: + static const bool value = true; +}; + +template <template <class, class> class P, typename T, typename... Ts, + typename U, typename... Us> +class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> { +public: + static const bool value = + P<T, U>::value && + RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value; +}; + +template <template <class, class> class P, typename T1Sig, typename T2Sig> +class RPCArgTypeCheck { +public: + using T1Tuple = typename RPCFunctionArgsTuple<T1Sig>::Type; + using T2Tuple = typename RPCFunctionArgsTuple<T2Sig>::Type; + + static_assert(std::tuple_size<T1Tuple>::value >= + std::tuple_size<T2Tuple>::value, + "Too many arguments to RPC call"); + static_assert(std::tuple_size<T1Tuple>::value <= + std::tuple_size<T2Tuple>::value, + "Too few arguments to RPC call"); + + static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value; +}; + +template <typename ChannelT, typename WireT, typename ConcreteT> +class CanSerialize { +private: + using S = SerializationTraits<ChannelT, WireT, ConcreteT>; + + template <typename T> + static std::true_type check( + std::enable_if_t<std::is_same<decltype(T::serialize( + std::declval<ChannelT &>(), + std::declval<const ConcreteT &>())), + Error>::value, + void *>); + + template <typename> static std::false_type check(...); + +public: + static const bool value = decltype(check<S>(0))::value; +}; + +template <typename ChannelT, typename WireT, typename ConcreteT> +class CanDeserialize { +private: + using S = SerializationTraits<ChannelT, WireT, ConcreteT>; + + template <typename T> + static std::true_type + check(std::enable_if_t< + std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(), + std::declval<ConcreteT &>())), + Error>::value, + void *>); + + template <typename> static std::false_type check(...); + +public: + static const bool value = decltype(check<S>(0))::value; +}; + +/// Contains primitive utilities for defining, calling and handling calls to +/// remote procedures. ChannelT is a bidirectional stream conforming to the +/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure +/// identifier type that must be serializable on ChannelT, and SequenceNumberT +/// is an integral type that will be used to number in-flight function calls. +/// +/// These utilities support the construction of very primitive RPC utilities. +/// Their intent is to ensure correct serialization and deserialization of +/// procedure arguments, and to keep the client and server's view of the API in +/// sync. +template <typename ImplT, typename ChannelT, typename FunctionIdT, + typename SequenceNumberT> +class RPCEndpointBase { +protected: + class OrcRPCInvalid : public RPCFunction<OrcRPCInvalid, void()> { + public: + static const char *getName() { return "__orc_rpc$invalid"; } + }; + + class OrcRPCResponse : public RPCFunction<OrcRPCResponse, void()> { + public: + static const char *getName() { return "__orc_rpc$response"; } + }; + + class OrcRPCNegotiate + : public RPCFunction<OrcRPCNegotiate, FunctionIdT(std::string)> { + public: + static const char *getName() { return "__orc_rpc$negotiate"; } + }; + + // Helper predicate for testing for the presence of SerializeTraits + // serializers. + template <typename WireT, typename ConcreteT> + class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> { + public: + using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value; + + static_assert(value, "Missing serializer for argument (Can't serialize the " + "first template type argument of CanSerializeCheck " + "from the second)"); + }; + + // Helper predicate for testing for the presence of SerializeTraits + // deserializers. + template <typename WireT, typename ConcreteT> + class CanDeserializeCheck + : detail::CanDeserialize<ChannelT, WireT, ConcreteT> { + public: + using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value; + + static_assert(value, "Missing deserializer for argument (Can't deserialize " + "the second template type argument of " + "CanDeserializeCheck from the first)"); + }; + +public: + /// Construct an RPC instance on a channel. + RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation) + : C(C), LazyAutoNegotiation(LazyAutoNegotiation) { + // Hold ResponseId in a special variable, since we expect Response to be + // called relatively frequently, and want to avoid the map lookup. + ResponseId = FnIdAllocator.getResponseId(); + RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId; + + // Register the negotiate function id and handler. + auto NegotiateId = FnIdAllocator.getNegotiateId(); + RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId; + Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>( + [this](const std::string &Name) { return handleNegotiate(Name); }); + } + + /// Negotiate a function id for Func with the other end of the channel. + template <typename Func> Error negotiateFunction(bool Retry = false) { + return getRemoteFunctionId<Func>(true, Retry).takeError(); + } + + /// Append a call Func, does not call send on the channel. + /// The first argument specifies a user-defined handler to be run when the + /// function returns. The handler should take an Expected<Func::ReturnType>, + /// or an Error (if Func::ReturnType is void). The handler will be called + /// with an error if the return value is abandoned due to a channel error. + template <typename Func, typename HandlerT, typename... ArgTs> + Error appendCallAsync(HandlerT Handler, const ArgTs &...Args) { + + static_assert( + detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type, + void(ArgTs...)>::value, + ""); + + // Look up the function ID. + FunctionIdT FnId; + if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false)) + FnId = *FnIdOrErr; + else { + // Negotiation failed. Notify the handler then return the negotiate-failed + // error. + cantFail(Handler(make_error<ResponseAbandoned>())); + return FnIdOrErr.takeError(); + } + + SequenceNumberT SeqNo; // initialized in locked scope below. + { + // Lock the pending responses map and sequence number manager. + std::lock_guard<std::mutex> Lock(ResponsesMutex); + + // Allocate a sequence number. + SeqNo = SequenceNumberMgr.getSequenceNumber(); + assert(!PendingResponses.count(SeqNo) && + "Sequence number already allocated"); + + // Install the user handler. + PendingResponses[SeqNo] = + detail::createResponseHandler<ChannelT, typename Func::ReturnType>( + std::move(Handler)); + } + + // Open the function call message. + if (auto Err = C.startSendMessage(FnId, SeqNo)) { + abandonPendingResponses(); + return Err; + } + + // Serialize the call arguments. + if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs( + C, Args...)) { + abandonPendingResponses(); + return Err; + } + + // Close the function call messagee. + if (auto Err = C.endSendMessage()) { + abandonPendingResponses(); + return Err; + } + + return Error::success(); + } + + Error sendAppendedCalls() { return C.send(); }; + + template <typename Func, typename HandlerT, typename... ArgTs> + Error callAsync(HandlerT Handler, const ArgTs &...Args) { + if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...)) + return Err; + return C.send(); + } + + /// Handle one incoming call. + Error handleOne() { + FunctionIdT FnId; + SequenceNumberT SeqNo; + if (auto Err = C.startReceiveMessage(FnId, SeqNo)) { + abandonPendingResponses(); + return Err; + } + if (FnId == ResponseId) + return handleResponse(SeqNo); + auto I = Handlers.find(FnId); + if (I != Handlers.end()) + return I->second(C, SeqNo); + + // else: No handler found. Report error to client? + return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId, + SeqNo); + } + + /// Helper for handling setter procedures - this method returns a functor that + /// sets the variables referred to by Args... to values deserialized from the + /// channel. + /// E.g. + /// + /// typedef Function<0, bool, int> Func1; + /// + /// ... + /// bool B; + /// int I; + /// if (auto Err = expect<Func1>(Channel, readArgs(B, I))) + /// /* Handle Args */ ; + /// + template <typename... ArgTs> + static detail::ReadArgs<ArgTs...> readArgs(ArgTs &...Args) { + return detail::ReadArgs<ArgTs...>(Args...); + } + + /// Abandon all outstanding result handlers. + /// + /// This will call all currently registered result handlers to receive an + /// "abandoned" error as their argument. This is used internally by the RPC + /// in error situations, but can also be called directly by clients who are + /// disconnecting from the remote and don't or can't expect responses to their + /// outstanding calls. (Especially for outstanding blocking calls, calling + /// this function may be necessary to avoid dead threads). + void abandonPendingResponses() { + // Lock the pending responses map and sequence number manager. + std::lock_guard<std::mutex> Lock(ResponsesMutex); + + for (auto &KV : PendingResponses) + KV.second->abandon(); + PendingResponses.clear(); + SequenceNumberMgr.reset(); + } + + /// Remove the handler for the given function. + /// A handler must currently be registered for this function. + template <typename Func> void removeHandler() { + auto IdItr = LocalFunctionIds.find(Func::getPrototype()); + assert(IdItr != LocalFunctionIds.end() && + "Function does not have a registered handler"); + auto HandlerItr = Handlers.find(IdItr->second); + assert(HandlerItr != Handlers.end() && + "Function does not have a registered handler"); + Handlers.erase(HandlerItr); + } + + /// Clear all handlers. + void clearHandlers() { Handlers.clear(); } + +protected: + FunctionIdT getInvalidFunctionId() const { + return FnIdAllocator.getInvalidId(); + } + + /// Add the given handler to the handler map and make it available for + /// autonegotiation and execution. + template <typename Func, typename HandlerT> + void addHandlerImpl(HandlerT Handler) { + + static_assert(detail::RPCArgTypeCheck< + CanDeserializeCheck, typename Func::Type, + typename detail::HandlerTraits<HandlerT>::Type>::value, + ""); + + FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); + LocalFunctionIds[Func::getPrototype()] = NewFnId; + Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler)); + } + + template <typename Func, typename HandlerT> + void addAsyncHandlerImpl(HandlerT Handler) { + + static_assert( + detail::RPCArgTypeCheck< + CanDeserializeCheck, typename Func::Type, + typename detail::AsyncHandlerTraits< + typename detail::HandlerTraits<HandlerT>::Type>::Type>::value, + ""); + + FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); + LocalFunctionIds[Func::getPrototype()] = NewFnId; + Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler)); + } + + Error handleResponse(SequenceNumberT SeqNo) { + using Handler = typename decltype(PendingResponses)::mapped_type; + Handler PRHandler; + + { + // Lock the pending responses map and sequence number manager. + std::unique_lock<std::mutex> Lock(ResponsesMutex); + auto I = PendingResponses.find(SeqNo); + + if (I != PendingResponses.end()) { + PRHandler = std::move(I->second); + PendingResponses.erase(I); + SequenceNumberMgr.releaseSequenceNumber(SeqNo); + } else { + // Unlock the pending results map to prevent recursive lock. + Lock.unlock(); + abandonPendingResponses(); + return make_error<InvalidSequenceNumberForResponse<SequenceNumberT>>( + SeqNo); + } + } + + assert(PRHandler && + "If we didn't find a response handler we should have bailed out"); + + if (auto Err = PRHandler->handleResponse(C)) { + abandonPendingResponses(); + return Err; + } + + return Error::success(); + } + + FunctionIdT handleNegotiate(const std::string &Name) { + auto I = LocalFunctionIds.find(Name); + if (I == LocalFunctionIds.end()) + return getInvalidFunctionId(); + return I->second; + } + + // Find the remote FunctionId for the given function. + template <typename Func> + Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap, + bool NegotiateIfInvalid) { + bool DoNegotiate; + + // Check if we already have a function id... + auto I = RemoteFunctionIds.find(Func::getPrototype()); + if (I != RemoteFunctionIds.end()) { + // If it's valid there's nothing left to do. + if (I->second != getInvalidFunctionId()) + return I->second; + DoNegotiate = NegotiateIfInvalid; + } else + DoNegotiate = NegotiateIfNotInMap; + + // We don't have a function id for Func yet, but we're allowed to try to + // negotiate one. + if (DoNegotiate) { + auto &Impl = static_cast<ImplT &>(*this); + if (auto RemoteIdOrErr = + Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { + RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; + if (*RemoteIdOrErr == getInvalidFunctionId()) + return make_error<CouldNotNegotiate>(Func::getPrototype()); + return *RemoteIdOrErr; + } else + return RemoteIdOrErr.takeError(); + } + + // No key was available in the map and we weren't allowed to try to + // negotiate one, so return an unknown function error. + return make_error<CouldNotNegotiate>(Func::getPrototype()); + } + + using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>; + + // Wrap the given user handler in the necessary argument-deserialization code, + // result-serialization code, and call to the launch policy (if present). + template <typename Func, typename HandlerT> + WrappedHandlerFn wrapHandler(HandlerT Handler) { + return [this, Handler](ChannelT &Channel, + SequenceNumberT SeqNo) mutable -> Error { + // Start by deserializing the arguments. + using ArgsTuple = typename detail::RPCFunctionArgsTuple< + typename detail::HandlerTraits<HandlerT>::Type>::Type; + auto Args = std::make_shared<ArgsTuple>(); + + if (auto Err = + detail::HandlerTraits<typename Func::Type>::deserializeArgs( + Channel, *Args)) + return Err; + + // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning + // for RPCArgs. Void cast RPCArgs to work around this for now. + // FIXME: Remove this workaround once we can assume a working GCC version. + (void)Args; + + // End receieve message, unlocking the channel for reading. + if (auto Err = Channel.endReceiveMessage()) + return Err; + + using HTraits = detail::HandlerTraits<HandlerT>; + using FuncReturn = typename Func::ReturnType; + return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo, + HTraits::unpackAndRun(Handler, *Args)); + }; + } + + // Wrap the given user handler in the necessary argument-deserialization code, + // result-serialization code, and call to the launch policy (if present). + template <typename Func, typename HandlerT> + WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) { + return [this, Handler](ChannelT &Channel, + SequenceNumberT SeqNo) mutable -> Error { + // Start by deserializing the arguments. + using AHTraits = detail::AsyncHandlerTraits< + typename detail::HandlerTraits<HandlerT>::Type>; + using ArgsTuple = + typename detail::RPCFunctionArgsTuple<typename AHTraits::Type>::Type; + auto Args = std::make_shared<ArgsTuple>(); + + if (auto Err = + detail::HandlerTraits<typename Func::Type>::deserializeArgs( + Channel, *Args)) + return Err; + + // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning + // for RPCArgs. Void cast RPCArgs to work around this for now. + // FIXME: Remove this workaround once we can assume a working GCC version. + (void)Args; + + // End receieve message, unlocking the channel for reading. + if (auto Err = Channel.endReceiveMessage()) + return Err; + + using HTraits = detail::HandlerTraits<HandlerT>; + using FuncReturn = typename Func::ReturnType; + auto Responder = [this, + SeqNo](typename AHTraits::ResultType RetVal) -> Error { + return detail::respond<FuncReturn>(C, ResponseId, SeqNo, + std::move(RetVal)); + }; + + return HTraits::unpackAndRunAsync(Handler, Responder, *Args); + }; + } + + ChannelT &C; + + bool LazyAutoNegotiation; + + RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator; + + FunctionIdT ResponseId; + std::map<std::string, FunctionIdT> LocalFunctionIds; + std::map<const char *, FunctionIdT> RemoteFunctionIds; + + std::map<FunctionIdT, WrappedHandlerFn> Handlers; + + std::mutex ResponsesMutex; + detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr; + std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>> + PendingResponses; +}; + +} // end namespace detail + +template <typename ChannelT, typename FunctionIdT = uint32_t, + typename SequenceNumberT = uint32_t> +class MultiThreadedRPCEndpoint + : public detail::RPCEndpointBase< + MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT> { +private: + using BaseClass = detail::RPCEndpointBase< + MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT>; + +public: + MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) + : BaseClass(C, LazyAutoNegotiation) {} + + /// Add a handler for the given RPC function. + /// This installs the given handler functor for the given RPCFunction, and + /// makes the RPC function available for negotiation/calling from the remote. + template <typename Func, typename HandlerT> + void addHandler(HandlerT Handler) { + return this->template addHandlerImpl<Func>(std::move(Handler)); + } + + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); + } + + template <typename Func, typename HandlerT> + void addAsyncHandler(HandlerT Handler) { + return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); + } + + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addAsyncHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); + } + + /// Return type for non-blocking call primitives. + template <typename Func> + using NonBlockingCallResult = typename detail::ResultTraits< + typename Func::ReturnType>::ReturnFutureType; + + /// Call Func on Channel C. Does not block, does not call send. Returns a pair + /// of a future result and the sequence number assigned to the result. + /// + /// This utility function is primarily used for single-threaded mode support, + /// where the sequence number can be used to wait for the corresponding + /// result. In multi-threaded mode the appendCallNB method, which does not + /// return the sequence numeber, should be preferred. + template <typename Func, typename... ArgTs> + Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &...Args) { + using RTraits = detail::ResultTraits<typename Func::ReturnType>; + using ErrorReturn = typename RTraits::ErrorReturnType; + using ErrorReturnPromise = typename RTraits::ReturnPromiseType; + + ErrorReturnPromise Promise; + auto FutureResult = Promise.get_future(); + + if (auto Err = this->template appendCallAsync<Func>( + [Promise = std::move(Promise)](ErrorReturn RetOrErr) mutable { + Promise.set_value(std::move(RetOrErr)); + return Error::success(); + }, + Args...)) { + RTraits::consumeAbandoned(FutureResult.get()); + return std::move(Err); + } + return std::move(FutureResult); + } + + /// The same as appendCallNBWithSeq, except that it calls C.send() to + /// flush the channel after serializing the call. + template <typename Func, typename... ArgTs> + Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &...Args) { + auto Result = appendCallNB<Func>(Args...); + if (!Result) + return Result; + if (auto Err = this->C.send()) { + this->abandonPendingResponses(); + detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( + std::move(Result->get())); + return std::move(Err); + } + return Result; + } + + /// Call Func on Channel C. Blocks waiting for a result. Returns an Error + /// for void functions or an Expected<T> for functions returning a T. + /// + /// This function is for use in threaded code where another thread is + /// handling responses and incoming calls. + template <typename Func, typename... ArgTs, + typename AltRetT = typename Func::ReturnType> + typename detail::ResultTraits<AltRetT>::ErrorReturnType + callB(const ArgTs &...Args) { + if (auto FutureResOrErr = callNB<Func>(Args...)) + return FutureResOrErr->get(); + else + return FutureResOrErr.takeError(); + } + + /// Handle incoming RPC calls. + Error handlerLoop() { + while (true) + if (auto Err = this->handleOne()) + return Err; + return Error::success(); + } +}; + +template <typename ChannelT, typename FunctionIdT = uint32_t, + typename SequenceNumberT = uint32_t> +class SingleThreadedRPCEndpoint + : public detail::RPCEndpointBase< + SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT> { +private: + using BaseClass = detail::RPCEndpointBase< + SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT>; + +public: + SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) + : BaseClass(C, LazyAutoNegotiation) {} + + template <typename Func, typename HandlerT> + void addHandler(HandlerT Handler) { + return this->template addHandlerImpl<Func>(std::move(Handler)); + } + + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); + } + + template <typename Func, typename HandlerT> + void addAsyncHandler(HandlerT Handler) { + return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); + } + + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addAsyncHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); + } + + template <typename Func, typename... ArgTs, + typename AltRetT = typename Func::ReturnType> + typename detail::ResultTraits<AltRetT>::ErrorReturnType + callB(const ArgTs &...Args) { + bool ReceivedResponse = false; + using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType; + auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue(); + + // We have to 'Check' result (which we know is in a success state at this + // point) so that it can be overwritten in the async handler. + (void)!!Result; + + if (auto Err = this->template appendCallAsync<Func>( + [&](ResultType R) { + Result = std::move(R); + ReceivedResponse = true; + return Error::success(); + }, + Args...)) { + detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( + std::move(Result)); + return std::move(Err); + } + + if (auto Err = this->C.send()) { + detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( + std::move(Result)); + return std::move(Err); + } + + while (!ReceivedResponse) { + if (auto Err = this->handleOne()) { + detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( + std::move(Result)); + return std::move(Err); + } + } + + return Result; + } +}; + +/// Asynchronous dispatch for a function on an RPC endpoint. +template <typename RPCClass, typename Func> class RPCAsyncDispatch { +public: + RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {} + + template <typename HandlerT, typename... ArgTs> + Error operator()(HandlerT Handler, const ArgTs &...Args) const { + return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...); + } + +private: + RPCClass &Endpoint; +}; + +/// Construct an asynchronous dispatcher from an RPC endpoint and a Func. +template <typename Func, typename RPCEndpointT> +RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) { + return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint); +} + +/// Allows a set of asynchrounous calls to be dispatched, and then +/// waited on as a group. +class ParallelCallGroup { +public: + ParallelCallGroup() = default; + ParallelCallGroup(const ParallelCallGroup &) = delete; + ParallelCallGroup &operator=(const ParallelCallGroup &) = delete; + + /// Make as asynchronous call. + template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs> + Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler, + const ArgTs &...Args) { + // Increment the count of outstanding calls. This has to happen before + // we invoke the call, as the handler may (depending on scheduling) + // be run immediately on another thread, and we don't want the decrement + // in the wrapped handler below to run before the increment. + { + std::unique_lock<std::mutex> Lock(M); + ++NumOutstandingCalls; + } + + // Wrap the user handler in a lambda that will decrement the + // outstanding calls count, then poke the condition variable. + using ArgType = typename detail::ResponseHandlerArg< + typename detail::HandlerTraits<HandlerT>::Type>::ArgType; + auto WrappedHandler = [this, Handler = std::move(Handler)](ArgType Arg) { + auto Err = Handler(std::move(Arg)); + std::unique_lock<std::mutex> Lock(M); + --NumOutstandingCalls; + CV.notify_all(); + return Err; + }; + + return AsyncDispatch(std::move(WrappedHandler), Args...); + } + + /// Blocks until all calls have been completed and their return value + /// handlers run. + void wait() { + std::unique_lock<std::mutex> Lock(M); + while (NumOutstandingCalls > 0) + CV.wait(Lock); + } + +private: + std::mutex M; + std::condition_variable CV; + uint32_t NumOutstandingCalls = 0; +}; + +/// Convenience class for grouping RPCFunctions into APIs that can be +/// negotiated as a block. +/// +template <typename... Funcs> class APICalls { +public: + /// Test whether this API contains Function F. + template <typename F> class Contains { + public: + static const bool value = false; + }; + + /// Negotiate all functions in this API. + template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) { + return Error::success(); + } +}; + +template <typename Func, typename... Funcs> class APICalls<Func, Funcs...> { +public: + template <typename F> class Contains { + public: + static const bool value = std::is_same<F, Func>::value | + APICalls<Funcs...>::template Contains<F>::value; + }; + + template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) { + if (auto Err = R.template negotiateFunction<Func>()) + return Err; + return APICalls<Funcs...>::negotiate(R); + } +}; + +template <typename... InnerFuncs, typename... Funcs> +class APICalls<APICalls<InnerFuncs...>, Funcs...> { +public: + template <typename F> class Contains { + public: + static const bool value = + APICalls<InnerFuncs...>::template Contains<F>::value | + APICalls<Funcs...>::template Contains<F>::value; + }; + + template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) { + if (auto Err = APICalls<InnerFuncs...>::negotiate(R)) + return Err; + return APICalls<Funcs...>::negotiate(R); + } +}; + +} // end namespace shared +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h index 4f6175af33..94bb6c7739 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h @@ -1,194 +1,194 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===- RawByteChannel.h -----------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H -#define LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H - -#include "llvm/ADT/StringRef.h" -#include "llvm/ExecutionEngine/Orc/Shared/Serialization.h" -#include "llvm/Support/Endian.h" -#include "llvm/Support/Error.h" -#include <cstdint> -#include <mutex> -#include <string> -#include <type_traits> - -namespace llvm { -namespace orc { -namespace shared { - -/// Interface for byte-streams to be used with ORC Serialization. -class RawByteChannel { -public: - virtual ~RawByteChannel() = default; - - /// Read Size bytes from the stream into *Dst. - virtual Error readBytes(char *Dst, unsigned Size) = 0; - - /// Read size bytes from *Src and append them to the stream. - virtual Error appendBytes(const char *Src, unsigned Size) = 0; - - /// Flush the stream if possible. - virtual Error send() = 0; - - /// Notify the channel that we're starting a message send. - /// Locks the channel for writing. - template <typename FunctionIdT, typename SequenceIdT> - Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) { - writeLock.lock(); - if (auto Err = serializeSeq(*this, FnId, SeqNo)) { - writeLock.unlock(); - return Err; - } - return Error::success(); - } - - /// Notify the channel that we're ending a message send. - /// Unlocks the channel for writing. - Error endSendMessage() { - writeLock.unlock(); - return Error::success(); - } - - /// Notify the channel that we're starting a message receive. - /// Locks the channel for reading. - template <typename FunctionIdT, typename SequenceNumberT> - Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) { - readLock.lock(); - if (auto Err = deserializeSeq(*this, FnId, SeqNo)) { - readLock.unlock(); - return Err; - } - return Error::success(); - } - - /// Notify the channel that we're ending a message receive. - /// Unlocks the channel for reading. - Error endReceiveMessage() { - readLock.unlock(); - return Error::success(); - } - - /// Get the lock for stream reading. - std::mutex &getReadLock() { return readLock; } - - /// Get the lock for stream writing. - std::mutex &getWriteLock() { return writeLock; } - -private: - std::mutex readLock, writeLock; -}; - -template <typename ChannelT, typename T> -class SerializationTraits< - ChannelT, T, T, - std::enable_if_t< - std::is_base_of<RawByteChannel, ChannelT>::value && - (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value || - std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value || - std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value || - std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value || - std::is_same<T, char>::value)>> { -public: - static Error serialize(ChannelT &C, T V) { - support::endian::byte_swap<T, support::big>(V); - return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T)); - }; - - static Error deserialize(ChannelT &C, T &V) { - if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T))) - return Err; - support::endian::byte_swap<T, support::big>(V); - return Error::success(); - }; -}; - -template <typename ChannelT> -class SerializationTraits< - ChannelT, bool, bool, - std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { -public: - static Error serialize(ChannelT &C, bool V) { - uint8_t Tmp = V ? 1 : 0; - if (auto Err = C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1)) - return Err; - return Error::success(); - } - - static Error deserialize(ChannelT &C, bool &V) { - uint8_t Tmp = 0; - if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1)) - return Err; - V = Tmp != 0; - return Error::success(); - } -}; - -template <typename ChannelT> -class SerializationTraits< - ChannelT, std::string, StringRef, - std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { -public: - /// Serialization channel serialization for std::strings. - static Error serialize(RawByteChannel &C, StringRef S) { - if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) - return Err; - return C.appendBytes((const char *)S.data(), S.size()); - } -}; - -template <typename ChannelT, typename T> -class SerializationTraits< - ChannelT, std::string, T, - std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value && - (std::is_same<T, const char *>::value || - std::is_same<T, char *>::value)>> { -public: - static Error serialize(RawByteChannel &C, const char *S) { - return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, - S); - } -}; - -template <typename ChannelT> -class SerializationTraits< - ChannelT, std::string, std::string, - std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { -public: - /// Serialization channel serialization for std::strings. - static Error serialize(RawByteChannel &C, const std::string &S) { - return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, - S); - } - - /// Serialization channel deserialization for std::strings. - static Error deserialize(RawByteChannel &C, std::string &S) { - uint64_t Count = 0; - if (auto Err = deserializeSeq(C, Count)) - return Err; - S.resize(Count); - return C.readBytes(&S[0], Count); - } -}; - -} // end namespace shared -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===- RawByteChannel.h -----------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H +#define LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H + +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/Shared/Serialization.h" +#include "llvm/Support/Endian.h" +#include "llvm/Support/Error.h" +#include <cstdint> +#include <mutex> +#include <string> +#include <type_traits> + +namespace llvm { +namespace orc { +namespace shared { + +/// Interface for byte-streams to be used with ORC Serialization. +class RawByteChannel { +public: + virtual ~RawByteChannel() = default; + + /// Read Size bytes from the stream into *Dst. + virtual Error readBytes(char *Dst, unsigned Size) = 0; + + /// Read size bytes from *Src and append them to the stream. + virtual Error appendBytes(const char *Src, unsigned Size) = 0; + + /// Flush the stream if possible. + virtual Error send() = 0; + + /// Notify the channel that we're starting a message send. + /// Locks the channel for writing. + template <typename FunctionIdT, typename SequenceIdT> + Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) { + writeLock.lock(); + if (auto Err = serializeSeq(*this, FnId, SeqNo)) { + writeLock.unlock(); + return Err; + } + return Error::success(); + } + + /// Notify the channel that we're ending a message send. + /// Unlocks the channel for writing. + Error endSendMessage() { + writeLock.unlock(); + return Error::success(); + } + + /// Notify the channel that we're starting a message receive. + /// Locks the channel for reading. + template <typename FunctionIdT, typename SequenceNumberT> + Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) { + readLock.lock(); + if (auto Err = deserializeSeq(*this, FnId, SeqNo)) { + readLock.unlock(); + return Err; + } + return Error::success(); + } + + /// Notify the channel that we're ending a message receive. + /// Unlocks the channel for reading. + Error endReceiveMessage() { + readLock.unlock(); + return Error::success(); + } + + /// Get the lock for stream reading. + std::mutex &getReadLock() { return readLock; } + + /// Get the lock for stream writing. + std::mutex &getWriteLock() { return writeLock; } + +private: + std::mutex readLock, writeLock; +}; + +template <typename ChannelT, typename T> +class SerializationTraits< + ChannelT, T, T, + std::enable_if_t< + std::is_base_of<RawByteChannel, ChannelT>::value && + (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value || + std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value || + std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value || + std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value || + std::is_same<T, char>::value)>> { +public: + static Error serialize(ChannelT &C, T V) { + support::endian::byte_swap<T, support::big>(V); + return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T)); + }; + + static Error deserialize(ChannelT &C, T &V) { + if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T))) + return Err; + support::endian::byte_swap<T, support::big>(V); + return Error::success(); + }; +}; + +template <typename ChannelT> +class SerializationTraits< + ChannelT, bool, bool, + std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { +public: + static Error serialize(ChannelT &C, bool V) { + uint8_t Tmp = V ? 1 : 0; + if (auto Err = C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1)) + return Err; + return Error::success(); + } + + static Error deserialize(ChannelT &C, bool &V) { + uint8_t Tmp = 0; + if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1)) + return Err; + V = Tmp != 0; + return Error::success(); + } +}; + +template <typename ChannelT> +class SerializationTraits< + ChannelT, std::string, StringRef, + std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { +public: + /// Serialization channel serialization for std::strings. + static Error serialize(RawByteChannel &C, StringRef S) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) + return Err; + return C.appendBytes((const char *)S.data(), S.size()); + } +}; + +template <typename ChannelT, typename T> +class SerializationTraits< + ChannelT, std::string, T, + std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value && + (std::is_same<T, const char *>::value || + std::is_same<T, char *>::value)>> { +public: + static Error serialize(RawByteChannel &C, const char *S) { + return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, + S); + } +}; + +template <typename ChannelT> +class SerializationTraits< + ChannelT, std::string, std::string, + std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { +public: + /// Serialization channel serialization for std::strings. + static Error serialize(RawByteChannel &C, const std::string &S) { + return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, + S); + } + + /// Serialization channel deserialization for std::strings. + static Error deserialize(RawByteChannel &C, std::string &S) { + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + S.resize(Count); + return C.readBytes(&S[0], Count); + } +}; + +} // end namespace shared +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/Serialization.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/Serialization.h index fa48a7af43..5f4e2767f0 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/Serialization.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/Serialization.h @@ -1,780 +1,780 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===- Serialization.h ------------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_SERIALIZATION_H -#define LLVM_EXECUTIONENGINE_ORC_SHARED_SERIALIZATION_H - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" -#include "llvm/Support/thread.h" -#include <map> -#include <mutex> -#include <set> -#include <sstream> -#include <string> -#include <vector> - -namespace llvm { -namespace orc { -namespace shared { - -template <typename T> class SerializationTypeName; - -/// TypeNameSequence is a utility for rendering sequences of types to a string -/// by rendering each type, separated by ", ". -template <typename... ArgTs> class SerializationTypeNameSequence {}; - -/// Render an empty TypeNameSequence to an ostream. -template <typename OStream> -OStream &operator<<(OStream &OS, const SerializationTypeNameSequence<> &V) { - return OS; -} - -/// Render a TypeNameSequence of a single type to an ostream. -template <typename OStream, typename ArgT> -OStream &operator<<(OStream &OS, const SerializationTypeNameSequence<ArgT> &V) { - OS << SerializationTypeName<ArgT>::getName(); - return OS; -} - -/// Render a TypeNameSequence of more than one type to an ostream. -template <typename OStream, typename ArgT1, typename ArgT2, typename... ArgTs> -OStream & -operator<<(OStream &OS, - const SerializationTypeNameSequence<ArgT1, ArgT2, ArgTs...> &V) { - OS << SerializationTypeName<ArgT1>::getName() << ", " - << SerializationTypeNameSequence<ArgT2, ArgTs...>(); - return OS; -} - -template <> class SerializationTypeName<void> { -public: - static const char *getName() { return "void"; } -}; - -template <> class SerializationTypeName<int8_t> { -public: - static const char *getName() { return "int8_t"; } -}; - -template <> class SerializationTypeName<uint8_t> { -public: - static const char *getName() { return "uint8_t"; } -}; - -template <> class SerializationTypeName<int16_t> { -public: - static const char *getName() { return "int16_t"; } -}; - -template <> class SerializationTypeName<uint16_t> { -public: - static const char *getName() { return "uint16_t"; } -}; - -template <> class SerializationTypeName<int32_t> { -public: - static const char *getName() { return "int32_t"; } -}; - -template <> class SerializationTypeName<uint32_t> { -public: - static const char *getName() { return "uint32_t"; } -}; - -template <> class SerializationTypeName<int64_t> { -public: - static const char *getName() { return "int64_t"; } -}; - -template <> class SerializationTypeName<uint64_t> { -public: - static const char *getName() { return "uint64_t"; } -}; - -template <> class SerializationTypeName<bool> { -public: - static const char *getName() { return "bool"; } -}; - -template <> class SerializationTypeName<std::string> { -public: - static const char *getName() { return "std::string"; } -}; - -template <> class SerializationTypeName<Error> { -public: - static const char *getName() { return "Error"; } -}; - -template <typename T> class SerializationTypeName<Expected<T>> { -public: - static const char *getName() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << "Expected<" << SerializationTypeNameSequence<T>() << ">"; - return Name; - }(); - return Name.data(); - } -}; - -template <typename T1, typename T2> -class SerializationTypeName<std::pair<T1, T2>> { -public: - static const char *getName() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << "std::pair<" << SerializationTypeNameSequence<T1, T2>() << ">"; - return Name; - }(); - return Name.data(); - } -}; - -template <typename... ArgTs> class SerializationTypeName<std::tuple<ArgTs...>> { -public: - static const char *getName() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << "std::tuple<" << SerializationTypeNameSequence<ArgTs...>() << ">"; - return Name; - }(); - return Name.data(); - } -}; - -template <typename T> class SerializationTypeName<Optional<T>> { -public: - static const char *getName() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << "Optional<" << SerializationTypeName<T>::getName() << ">"; - return Name; - }(); - return Name.data(); - } -}; - -template <typename T> class SerializationTypeName<std::vector<T>> { -public: - static const char *getName() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << "std::vector<" << SerializationTypeName<T>::getName() << ">"; - return Name; - }(); - return Name.data(); - } -}; - -template <typename T> class SerializationTypeName<std::set<T>> { -public: - static const char *getName() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << "std::set<" << SerializationTypeName<T>::getName() << ">"; - return Name; - }(); - return Name.data(); - } -}; - -template <typename K, typename V> class SerializationTypeName<std::map<K, V>> { -public: - static const char *getName() { - static std::string Name = [] { - std::string Name; - raw_string_ostream(Name) - << "std::map<" << SerializationTypeNameSequence<K, V>() << ">"; - return Name; - }(); - return Name.data(); - } -}; - -/// The SerializationTraits<ChannelT, T> class describes how to serialize and -/// deserialize an instance of type T to/from an abstract channel of type -/// ChannelT. It also provides a representation of the type's name via the -/// getName method. -/// -/// Specializations of this class should provide the following functions: -/// -/// @code{.cpp} -/// -/// static const char* getName(); -/// static Error serialize(ChannelT&, const T&); -/// static Error deserialize(ChannelT&, T&); -/// -/// @endcode -/// -/// The third argument of SerializationTraits is intended to support SFINAE. -/// E.g.: -/// -/// @code{.cpp} -/// -/// class MyVirtualChannel { ... }; -/// -/// template <DerivedChannelT> -/// class SerializationTraits<DerivedChannelT, bool, -/// std::enable_if_t< -/// std::is_base_of<VirtChannel, DerivedChannel>::value -/// >> { -/// public: -/// static const char* getName() { ... }; -/// } -/// -/// @endcode -template <typename ChannelT, typename WireType, - typename ConcreteType = WireType, typename = void> -class SerializationTraits; - -template <typename ChannelT> class SequenceTraits { -public: - static Error emitSeparator(ChannelT &C) { return Error::success(); } - static Error consumeSeparator(ChannelT &C) { return Error::success(); } -}; - -/// Utility class for serializing sequences of values of varying types. -/// Specializations of this class contain 'serialize' and 'deserialize' methods -/// for the given channel. The ArgTs... list will determine the "over-the-wire" -/// types to be serialized. The serialize and deserialize methods take a list -/// CArgTs... ("caller arg types") which must be the same length as ArgTs..., -/// but may be different types from ArgTs, provided that for each CArgT there -/// is a SerializationTraits specialization -/// SerializeTraits<ChannelT, ArgT, CArgT> with methods that can serialize the -/// caller argument to over-the-wire value. -template <typename ChannelT, typename... ArgTs> class SequenceSerialization; - -template <typename ChannelT> class SequenceSerialization<ChannelT> { -public: - static Error serialize(ChannelT &C) { return Error::success(); } - static Error deserialize(ChannelT &C) { return Error::success(); } -}; - -template <typename ChannelT, typename ArgT> -class SequenceSerialization<ChannelT, ArgT> { -public: - template <typename CArgT> static Error serialize(ChannelT &C, CArgT &&CArg) { - return SerializationTraits<ChannelT, ArgT, std::decay_t<CArgT>>::serialize( - C, std::forward<CArgT>(CArg)); - } - - template <typename CArgT> static Error deserialize(ChannelT &C, CArgT &CArg) { - return SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg); - } -}; - -template <typename ChannelT, typename ArgT, typename... ArgTs> -class SequenceSerialization<ChannelT, ArgT, ArgTs...> { -public: - template <typename CArgT, typename... CArgTs> - static Error serialize(ChannelT &C, CArgT &&CArg, CArgTs &&...CArgs) { - if (auto Err = - SerializationTraits<ChannelT, ArgT, std::decay_t<CArgT>>::serialize( - C, std::forward<CArgT>(CArg))) - return Err; - if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C)) - return Err; - return SequenceSerialization<ChannelT, ArgTs...>::serialize( - C, std::forward<CArgTs>(CArgs)...); - } - - template <typename CArgT, typename... CArgTs> - static Error deserialize(ChannelT &C, CArgT &CArg, CArgTs &...CArgs) { - if (auto Err = - SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg)) - return Err; - if (auto Err = SequenceTraits<ChannelT>::consumeSeparator(C)) - return Err; - return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, CArgs...); - } -}; - -template <typename ChannelT, typename... ArgTs> -Error serializeSeq(ChannelT &C, ArgTs &&...Args) { - return SequenceSerialization<ChannelT, std::decay_t<ArgTs>...>::serialize( - C, std::forward<ArgTs>(Args)...); -} - -template <typename ChannelT, typename... ArgTs> -Error deserializeSeq(ChannelT &C, ArgTs &...Args) { - return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...); -} - -template <typename ChannelT> class SerializationTraits<ChannelT, Error> { -public: - using WrappedErrorSerializer = - std::function<Error(ChannelT &C, const ErrorInfoBase &)>; - - using WrappedErrorDeserializer = - std::function<Error(ChannelT &C, Error &Err)>; - - template <typename ErrorInfoT, typename SerializeFtor, - typename DeserializeFtor> - static void registerErrorType(std::string Name, SerializeFtor Serialize, - DeserializeFtor Deserialize) { - assert(!Name.empty() && - "The empty string is reserved for the Success value"); - - const std::string *KeyName = nullptr; - { - // We're abusing the stability of std::map here: We take a reference to - // the key of the deserializers map to save us from duplicating the string - // in the serializer. This should be changed to use a stringpool if we - // switch to a map type that may move keys in memory. - std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); - auto I = Deserializers.insert( - Deserializers.begin(), - std::make_pair(std::move(Name), std::move(Deserialize))); - KeyName = &I->first; - } - - { - assert(KeyName != nullptr && "No keyname pointer"); - std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); - Serializers[ErrorInfoT::classID()] = - [KeyName, Serialize = std::move(Serialize)]( - ChannelT &C, const ErrorInfoBase &EIB) -> Error { - assert(EIB.dynamicClassID() == ErrorInfoT::classID() && - "Serializer called for wrong error type"); - if (auto Err = serializeSeq(C, *KeyName)) - return Err; - return Serialize(C, static_cast<const ErrorInfoT &>(EIB)); - }; - } - } - - static Error serialize(ChannelT &C, Error &&Err) { - std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); - - if (!Err) - return serializeSeq(C, std::string()); - - return handleErrors(std::move(Err), [&C](const ErrorInfoBase &EIB) { - auto SI = Serializers.find(EIB.dynamicClassID()); - if (SI == Serializers.end()) - return serializeAsStringError(C, EIB); - return (SI->second)(C, EIB); - }); - } - - static Error deserialize(ChannelT &C, Error &Err) { - std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); - - std::string Key; - if (auto Err = deserializeSeq(C, Key)) - return Err; - - if (Key.empty()) { - ErrorAsOutParameter EAO(&Err); - Err = Error::success(); - return Error::success(); - } - - auto DI = Deserializers.find(Key); - assert(DI != Deserializers.end() && "No deserializer for error type"); - return (DI->second)(C, Err); - } - -private: - static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) { - std::string ErrMsg; - { - raw_string_ostream ErrMsgStream(ErrMsg); - EIB.log(ErrMsgStream); - } - return serialize(C, make_error<StringError>(std::move(ErrMsg), - inconvertibleErrorCode())); - } - - static std::recursive_mutex SerializersMutex; - static std::recursive_mutex DeserializersMutex; - static std::map<const void *, WrappedErrorSerializer> Serializers; - static std::map<std::string, WrappedErrorDeserializer> Deserializers; -}; - -template <typename ChannelT> -std::recursive_mutex SerializationTraits<ChannelT, Error>::SerializersMutex; - -template <typename ChannelT> -std::recursive_mutex SerializationTraits<ChannelT, Error>::DeserializersMutex; - -template <typename ChannelT> -std::map<const void *, - typename SerializationTraits<ChannelT, Error>::WrappedErrorSerializer> - SerializationTraits<ChannelT, Error>::Serializers; - -template <typename ChannelT> -std::map<std::string, typename SerializationTraits< - ChannelT, Error>::WrappedErrorDeserializer> - SerializationTraits<ChannelT, Error>::Deserializers; - -/// Registers a serializer and deserializer for the given error type on the -/// given channel type. -template <typename ChannelT, typename ErrorInfoT, typename SerializeFtor, - typename DeserializeFtor> -void registerErrorSerialization(std::string Name, SerializeFtor &&Serialize, - DeserializeFtor &&Deserialize) { - SerializationTraits<ChannelT, Error>::template registerErrorType<ErrorInfoT>( - std::move(Name), std::forward<SerializeFtor>(Serialize), - std::forward<DeserializeFtor>(Deserialize)); -} - -/// Registers serialization/deserialization for StringError. -template <typename ChannelT> void registerStringError() { - static bool AlreadyRegistered = false; - if (!AlreadyRegistered) { - registerErrorSerialization<ChannelT, StringError>( - "StringError", - [](ChannelT &C, const StringError &SE) { - return serializeSeq(C, SE.getMessage()); - }, - [](ChannelT &C, Error &Err) -> Error { - ErrorAsOutParameter EAO(&Err); - std::string Msg; - if (auto E2 = deserializeSeq(C, Msg)) - return E2; - Err = make_error<StringError>( - std::move(Msg), - orcError(OrcErrorCode::UnknownErrorCodeFromRemote)); - return Error::success(); - }); - AlreadyRegistered = true; - } -} - -/// SerializationTraits for Expected<T1> from an Expected<T2>. -template <typename ChannelT, typename T1, typename T2> -class SerializationTraits<ChannelT, Expected<T1>, Expected<T2>> { -public: - static Error serialize(ChannelT &C, Expected<T2> &&ValOrErr) { - if (ValOrErr) { - if (auto Err = serializeSeq(C, true)) - return Err; - return SerializationTraits<ChannelT, T1, T2>::serialize(C, *ValOrErr); - } - if (auto Err = serializeSeq(C, false)) - return Err; - return serializeSeq(C, ValOrErr.takeError()); - } - - static Error deserialize(ChannelT &C, Expected<T2> &ValOrErr) { - ExpectedAsOutParameter<T2> EAO(&ValOrErr); - bool HasValue; - if (auto Err = deserializeSeq(C, HasValue)) - return Err; - if (HasValue) - return SerializationTraits<ChannelT, T1, T2>::deserialize(C, *ValOrErr); - Error Err = Error::success(); - if (auto E2 = deserializeSeq(C, Err)) - return E2; - ValOrErr = std::move(Err); - return Error::success(); - } -}; - -/// SerializationTraits for Expected<T1> from a T2. -template <typename ChannelT, typename T1, typename T2> -class SerializationTraits<ChannelT, Expected<T1>, T2> { -public: - static Error serialize(ChannelT &C, T2 &&Val) { - return serializeSeq(C, Expected<T2>(std::forward<T2>(Val))); - } -}; - -/// SerializationTraits for Expected<T1> from an Error. -template <typename ChannelT, typename T> -class SerializationTraits<ChannelT, Expected<T>, Error> { -public: - static Error serialize(ChannelT &C, Error &&Err) { - return serializeSeq(C, Expected<T>(std::move(Err))); - } -}; - -/// SerializationTraits default specialization for std::pair. -template <typename ChannelT, typename T1, typename T2, typename T3, typename T4> -class SerializationTraits<ChannelT, std::pair<T1, T2>, std::pair<T3, T4>> { -public: - static Error serialize(ChannelT &C, const std::pair<T3, T4> &V) { - if (auto Err = SerializationTraits<ChannelT, T1, T3>::serialize(C, V.first)) - return Err; - return SerializationTraits<ChannelT, T2, T4>::serialize(C, V.second); - } - - static Error deserialize(ChannelT &C, std::pair<T3, T4> &V) { - if (auto Err = - SerializationTraits<ChannelT, T1, T3>::deserialize(C, V.first)) - return Err; - return SerializationTraits<ChannelT, T2, T4>::deserialize(C, V.second); - } -}; - -/// SerializationTraits default specialization for std::tuple. -template <typename ChannelT, typename... ArgTs> -class SerializationTraits<ChannelT, std::tuple<ArgTs...>> { -public: - /// RPC channel serialization for std::tuple. - static Error serialize(ChannelT &C, const std::tuple<ArgTs...> &V) { - return serializeTupleHelper(C, V, std::index_sequence_for<ArgTs...>()); - } - - /// RPC channel deserialization for std::tuple. - static Error deserialize(ChannelT &C, std::tuple<ArgTs...> &V) { - return deserializeTupleHelper(C, V, std::index_sequence_for<ArgTs...>()); - } - -private: - // Serialization helper for std::tuple. - template <size_t... Is> - static Error serializeTupleHelper(ChannelT &C, const std::tuple<ArgTs...> &V, - std::index_sequence<Is...> _) { - return serializeSeq(C, std::get<Is>(V)...); - } - - // Serialization helper for std::tuple. - template <size_t... Is> - static Error deserializeTupleHelper(ChannelT &C, std::tuple<ArgTs...> &V, - std::index_sequence<Is...> _) { - return deserializeSeq(C, std::get<Is>(V)...); - } -}; - -template <typename ChannelT, typename T> -class SerializationTraits<ChannelT, Optional<T>> { -public: - /// Serialize an Optional<T>. - static Error serialize(ChannelT &C, const Optional<T> &O) { - if (auto Err = serializeSeq(C, O != None)) - return Err; - if (O) - if (auto Err = serializeSeq(C, *O)) - return Err; - return Error::success(); - } - - /// Deserialize an Optional<T>. - static Error deserialize(ChannelT &C, Optional<T> &O) { - bool HasValue = false; - if (auto Err = deserializeSeq(C, HasValue)) - return Err; - if (HasValue) - if (auto Err = deserializeSeq(C, *O)) - return Err; - return Error::success(); - }; -}; - -/// SerializationTraits default specialization for std::vector. -template <typename ChannelT, typename T> -class SerializationTraits<ChannelT, std::vector<T>> { -public: - /// Serialize a std::vector<T> from std::vector<T>. - static Error serialize(ChannelT &C, const std::vector<T> &V) { - if (auto Err = serializeSeq(C, static_cast<uint64_t>(V.size()))) - return Err; - - for (const auto &E : V) - if (auto Err = serializeSeq(C, E)) - return Err; - - return Error::success(); - } - - /// Deserialize a std::vector<T> to a std::vector<T>. - static Error deserialize(ChannelT &C, std::vector<T> &V) { - assert(V.empty() && - "Expected default-constructed vector to deserialize into"); - - uint64_t Count = 0; - if (auto Err = deserializeSeq(C, Count)) - return Err; - - V.resize(Count); - for (auto &E : V) - if (auto Err = deserializeSeq(C, E)) - return Err; - - return Error::success(); - } -}; - -/// Enable vector serialization from an ArrayRef. -template <typename ChannelT, typename T> -class SerializationTraits<ChannelT, std::vector<T>, ArrayRef<T>> { -public: - static Error serialize(ChannelT &C, ArrayRef<T> V) { - if (auto Err = serializeSeq(C, static_cast<uint64_t>(V.size()))) - return Err; - - for (const auto &E : V) - if (auto Err = serializeSeq(C, E)) - return Err; - - return Error::success(); - } -}; - -template <typename ChannelT, typename T, typename T2> -class SerializationTraits<ChannelT, std::set<T>, std::set<T2>> { -public: - /// Serialize a std::set<T> from std::set<T2>. - static Error serialize(ChannelT &C, const std::set<T2> &S) { - if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) - return Err; - - for (const auto &E : S) - if (auto Err = SerializationTraits<ChannelT, T, T2>::serialize(C, E)) - return Err; - - return Error::success(); - } - - /// Deserialize a std::set<T> to a std::set<T>. - static Error deserialize(ChannelT &C, std::set<T2> &S) { - assert(S.empty() && "Expected default-constructed set to deserialize into"); - - uint64_t Count = 0; - if (auto Err = deserializeSeq(C, Count)) - return Err; - - while (Count-- != 0) { - T2 Val; - if (auto Err = SerializationTraits<ChannelT, T, T2>::deserialize(C, Val)) - return Err; - - auto Added = S.insert(Val).second; - if (!Added) - return make_error<StringError>("Duplicate element in deserialized set", - orcError(OrcErrorCode::UnknownORCError)); - } - - return Error::success(); - } -}; - -template <typename ChannelT, typename K, typename V, typename K2, typename V2> -class SerializationTraits<ChannelT, std::map<K, V>, std::map<K2, V2>> { -public: - /// Serialize a std::map<K, V> from std::map<K2, V2>. - static Error serialize(ChannelT &C, const std::map<K2, V2> &M) { - if (auto Err = serializeSeq(C, static_cast<uint64_t>(M.size()))) - return Err; - - for (const auto &E : M) { - if (auto Err = - SerializationTraits<ChannelT, K, K2>::serialize(C, E.first)) - return Err; - if (auto Err = - SerializationTraits<ChannelT, V, V2>::serialize(C, E.second)) - return Err; - } - - return Error::success(); - } - - /// Deserialize a std::map<K, V> to a std::map<K, V>. - static Error deserialize(ChannelT &C, std::map<K2, V2> &M) { - assert(M.empty() && "Expected default-constructed map to deserialize into"); - - uint64_t Count = 0; - if (auto Err = deserializeSeq(C, Count)) - return Err; - - while (Count-- != 0) { - std::pair<K2, V2> Val; - if (auto Err = - SerializationTraits<ChannelT, K, K2>::deserialize(C, Val.first)) - return Err; - - if (auto Err = - SerializationTraits<ChannelT, V, V2>::deserialize(C, Val.second)) - return Err; - - auto Added = M.insert(Val).second; - if (!Added) - return make_error<StringError>("Duplicate element in deserialized map", - orcError(OrcErrorCode::UnknownORCError)); - } - - return Error::success(); - } -}; - -template <typename ChannelT, typename K, typename V, typename K2, typename V2> -class SerializationTraits<ChannelT, std::map<K, V>, DenseMap<K2, V2>> { -public: - /// Serialize a std::map<K, V> from DenseMap<K2, V2>. - static Error serialize(ChannelT &C, const DenseMap<K2, V2> &M) { - if (auto Err = serializeSeq(C, static_cast<uint64_t>(M.size()))) - return Err; - - for (auto &E : M) { - if (auto Err = - SerializationTraits<ChannelT, K, K2>::serialize(C, E.first)) - return Err; - - if (auto Err = - SerializationTraits<ChannelT, V, V2>::serialize(C, E.second)) - return Err; - } - - return Error::success(); - } - - /// Serialize a std::map<K, V> from DenseMap<K2, V2>. - static Error deserialize(ChannelT &C, DenseMap<K2, V2> &M) { - assert(M.empty() && "Expected default-constructed map to deserialize into"); - - uint64_t Count = 0; - if (auto Err = deserializeSeq(C, Count)) - return Err; - - while (Count-- != 0) { - std::pair<K2, V2> Val; - if (auto Err = - SerializationTraits<ChannelT, K, K2>::deserialize(C, Val.first)) - return Err; - - if (auto Err = - SerializationTraits<ChannelT, V, V2>::deserialize(C, Val.second)) - return Err; - - auto Added = M.insert(Val).second; - if (!Added) - return make_error<StringError>("Duplicate element in deserialized map", - orcError(OrcErrorCode::UnknownORCError)); - } - - return Error::success(); - } -}; - -} // namespace shared -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_RPC_RPCSERIALIZATION_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===- Serialization.h ------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_SERIALIZATION_H +#define LLVM_EXECUTIONENGINE_ORC_SHARED_SERIALIZATION_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" +#include "llvm/Support/thread.h" +#include <map> +#include <mutex> +#include <set> +#include <sstream> +#include <string> +#include <vector> + +namespace llvm { +namespace orc { +namespace shared { + +template <typename T> class SerializationTypeName; + +/// TypeNameSequence is a utility for rendering sequences of types to a string +/// by rendering each type, separated by ", ". +template <typename... ArgTs> class SerializationTypeNameSequence {}; + +/// Render an empty TypeNameSequence to an ostream. +template <typename OStream> +OStream &operator<<(OStream &OS, const SerializationTypeNameSequence<> &V) { + return OS; +} + +/// Render a TypeNameSequence of a single type to an ostream. +template <typename OStream, typename ArgT> +OStream &operator<<(OStream &OS, const SerializationTypeNameSequence<ArgT> &V) { + OS << SerializationTypeName<ArgT>::getName(); + return OS; +} + +/// Render a TypeNameSequence of more than one type to an ostream. +template <typename OStream, typename ArgT1, typename ArgT2, typename... ArgTs> +OStream & +operator<<(OStream &OS, + const SerializationTypeNameSequence<ArgT1, ArgT2, ArgTs...> &V) { + OS << SerializationTypeName<ArgT1>::getName() << ", " + << SerializationTypeNameSequence<ArgT2, ArgTs...>(); + return OS; +} + +template <> class SerializationTypeName<void> { +public: + static const char *getName() { return "void"; } +}; + +template <> class SerializationTypeName<int8_t> { +public: + static const char *getName() { return "int8_t"; } +}; + +template <> class SerializationTypeName<uint8_t> { +public: + static const char *getName() { return "uint8_t"; } +}; + +template <> class SerializationTypeName<int16_t> { +public: + static const char *getName() { return "int16_t"; } +}; + +template <> class SerializationTypeName<uint16_t> { +public: + static const char *getName() { return "uint16_t"; } +}; + +template <> class SerializationTypeName<int32_t> { +public: + static const char *getName() { return "int32_t"; } +}; + +template <> class SerializationTypeName<uint32_t> { +public: + static const char *getName() { return "uint32_t"; } +}; + +template <> class SerializationTypeName<int64_t> { +public: + static const char *getName() { return "int64_t"; } +}; + +template <> class SerializationTypeName<uint64_t> { +public: + static const char *getName() { return "uint64_t"; } +}; + +template <> class SerializationTypeName<bool> { +public: + static const char *getName() { return "bool"; } +}; + +template <> class SerializationTypeName<std::string> { +public: + static const char *getName() { return "std::string"; } +}; + +template <> class SerializationTypeName<Error> { +public: + static const char *getName() { return "Error"; } +}; + +template <typename T> class SerializationTypeName<Expected<T>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "Expected<" << SerializationTypeNameSequence<T>() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename T1, typename T2> +class SerializationTypeName<std::pair<T1, T2>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::pair<" << SerializationTypeNameSequence<T1, T2>() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename... ArgTs> class SerializationTypeName<std::tuple<ArgTs...>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::tuple<" << SerializationTypeNameSequence<ArgTs...>() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename T> class SerializationTypeName<Optional<T>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "Optional<" << SerializationTypeName<T>::getName() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename T> class SerializationTypeName<std::vector<T>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::vector<" << SerializationTypeName<T>::getName() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename T> class SerializationTypeName<std::set<T>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::set<" << SerializationTypeName<T>::getName() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename K, typename V> class SerializationTypeName<std::map<K, V>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::map<" << SerializationTypeNameSequence<K, V>() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +/// The SerializationTraits<ChannelT, T> class describes how to serialize and +/// deserialize an instance of type T to/from an abstract channel of type +/// ChannelT. It also provides a representation of the type's name via the +/// getName method. +/// +/// Specializations of this class should provide the following functions: +/// +/// @code{.cpp} +/// +/// static const char* getName(); +/// static Error serialize(ChannelT&, const T&); +/// static Error deserialize(ChannelT&, T&); +/// +/// @endcode +/// +/// The third argument of SerializationTraits is intended to support SFINAE. +/// E.g.: +/// +/// @code{.cpp} +/// +/// class MyVirtualChannel { ... }; +/// +/// template <DerivedChannelT> +/// class SerializationTraits<DerivedChannelT, bool, +/// std::enable_if_t< +/// std::is_base_of<VirtChannel, DerivedChannel>::value +/// >> { +/// public: +/// static const char* getName() { ... }; +/// } +/// +/// @endcode +template <typename ChannelT, typename WireType, + typename ConcreteType = WireType, typename = void> +class SerializationTraits; + +template <typename ChannelT> class SequenceTraits { +public: + static Error emitSeparator(ChannelT &C) { return Error::success(); } + static Error consumeSeparator(ChannelT &C) { return Error::success(); } +}; + +/// Utility class for serializing sequences of values of varying types. +/// Specializations of this class contain 'serialize' and 'deserialize' methods +/// for the given channel. The ArgTs... list will determine the "over-the-wire" +/// types to be serialized. The serialize and deserialize methods take a list +/// CArgTs... ("caller arg types") which must be the same length as ArgTs..., +/// but may be different types from ArgTs, provided that for each CArgT there +/// is a SerializationTraits specialization +/// SerializeTraits<ChannelT, ArgT, CArgT> with methods that can serialize the +/// caller argument to over-the-wire value. +template <typename ChannelT, typename... ArgTs> class SequenceSerialization; + +template <typename ChannelT> class SequenceSerialization<ChannelT> { +public: + static Error serialize(ChannelT &C) { return Error::success(); } + static Error deserialize(ChannelT &C) { return Error::success(); } +}; + +template <typename ChannelT, typename ArgT> +class SequenceSerialization<ChannelT, ArgT> { +public: + template <typename CArgT> static Error serialize(ChannelT &C, CArgT &&CArg) { + return SerializationTraits<ChannelT, ArgT, std::decay_t<CArgT>>::serialize( + C, std::forward<CArgT>(CArg)); + } + + template <typename CArgT> static Error deserialize(ChannelT &C, CArgT &CArg) { + return SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg); + } +}; + +template <typename ChannelT, typename ArgT, typename... ArgTs> +class SequenceSerialization<ChannelT, ArgT, ArgTs...> { +public: + template <typename CArgT, typename... CArgTs> + static Error serialize(ChannelT &C, CArgT &&CArg, CArgTs &&...CArgs) { + if (auto Err = + SerializationTraits<ChannelT, ArgT, std::decay_t<CArgT>>::serialize( + C, std::forward<CArgT>(CArg))) + return Err; + if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C)) + return Err; + return SequenceSerialization<ChannelT, ArgTs...>::serialize( + C, std::forward<CArgTs>(CArgs)...); + } + + template <typename CArgT, typename... CArgTs> + static Error deserialize(ChannelT &C, CArgT &CArg, CArgTs &...CArgs) { + if (auto Err = + SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg)) + return Err; + if (auto Err = SequenceTraits<ChannelT>::consumeSeparator(C)) + return Err; + return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, CArgs...); + } +}; + +template <typename ChannelT, typename... ArgTs> +Error serializeSeq(ChannelT &C, ArgTs &&...Args) { + return SequenceSerialization<ChannelT, std::decay_t<ArgTs>...>::serialize( + C, std::forward<ArgTs>(Args)...); +} + +template <typename ChannelT, typename... ArgTs> +Error deserializeSeq(ChannelT &C, ArgTs &...Args) { + return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...); +} + +template <typename ChannelT> class SerializationTraits<ChannelT, Error> { +public: + using WrappedErrorSerializer = + std::function<Error(ChannelT &C, const ErrorInfoBase &)>; + + using WrappedErrorDeserializer = + std::function<Error(ChannelT &C, Error &Err)>; + + template <typename ErrorInfoT, typename SerializeFtor, + typename DeserializeFtor> + static void registerErrorType(std::string Name, SerializeFtor Serialize, + DeserializeFtor Deserialize) { + assert(!Name.empty() && + "The empty string is reserved for the Success value"); + + const std::string *KeyName = nullptr; + { + // We're abusing the stability of std::map here: We take a reference to + // the key of the deserializers map to save us from duplicating the string + // in the serializer. This should be changed to use a stringpool if we + // switch to a map type that may move keys in memory. + std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); + auto I = Deserializers.insert( + Deserializers.begin(), + std::make_pair(std::move(Name), std::move(Deserialize))); + KeyName = &I->first; + } + + { + assert(KeyName != nullptr && "No keyname pointer"); + std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); + Serializers[ErrorInfoT::classID()] = + [KeyName, Serialize = std::move(Serialize)]( + ChannelT &C, const ErrorInfoBase &EIB) -> Error { + assert(EIB.dynamicClassID() == ErrorInfoT::classID() && + "Serializer called for wrong error type"); + if (auto Err = serializeSeq(C, *KeyName)) + return Err; + return Serialize(C, static_cast<const ErrorInfoT &>(EIB)); + }; + } + } + + static Error serialize(ChannelT &C, Error &&Err) { + std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); + + if (!Err) + return serializeSeq(C, std::string()); + + return handleErrors(std::move(Err), [&C](const ErrorInfoBase &EIB) { + auto SI = Serializers.find(EIB.dynamicClassID()); + if (SI == Serializers.end()) + return serializeAsStringError(C, EIB); + return (SI->second)(C, EIB); + }); + } + + static Error deserialize(ChannelT &C, Error &Err) { + std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); + + std::string Key; + if (auto Err = deserializeSeq(C, Key)) + return Err; + + if (Key.empty()) { + ErrorAsOutParameter EAO(&Err); + Err = Error::success(); + return Error::success(); + } + + auto DI = Deserializers.find(Key); + assert(DI != Deserializers.end() && "No deserializer for error type"); + return (DI->second)(C, Err); + } + +private: + static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) { + std::string ErrMsg; + { + raw_string_ostream ErrMsgStream(ErrMsg); + EIB.log(ErrMsgStream); + } + return serialize(C, make_error<StringError>(std::move(ErrMsg), + inconvertibleErrorCode())); + } + + static std::recursive_mutex SerializersMutex; + static std::recursive_mutex DeserializersMutex; + static std::map<const void *, WrappedErrorSerializer> Serializers; + static std::map<std::string, WrappedErrorDeserializer> Deserializers; +}; + +template <typename ChannelT> +std::recursive_mutex SerializationTraits<ChannelT, Error>::SerializersMutex; + +template <typename ChannelT> +std::recursive_mutex SerializationTraits<ChannelT, Error>::DeserializersMutex; + +template <typename ChannelT> +std::map<const void *, + typename SerializationTraits<ChannelT, Error>::WrappedErrorSerializer> + SerializationTraits<ChannelT, Error>::Serializers; + +template <typename ChannelT> +std::map<std::string, typename SerializationTraits< + ChannelT, Error>::WrappedErrorDeserializer> + SerializationTraits<ChannelT, Error>::Deserializers; + +/// Registers a serializer and deserializer for the given error type on the +/// given channel type. +template <typename ChannelT, typename ErrorInfoT, typename SerializeFtor, + typename DeserializeFtor> +void registerErrorSerialization(std::string Name, SerializeFtor &&Serialize, + DeserializeFtor &&Deserialize) { + SerializationTraits<ChannelT, Error>::template registerErrorType<ErrorInfoT>( + std::move(Name), std::forward<SerializeFtor>(Serialize), + std::forward<DeserializeFtor>(Deserialize)); +} + +/// Registers serialization/deserialization for StringError. +template <typename ChannelT> void registerStringError() { + static bool AlreadyRegistered = false; + if (!AlreadyRegistered) { + registerErrorSerialization<ChannelT, StringError>( + "StringError", + [](ChannelT &C, const StringError &SE) { + return serializeSeq(C, SE.getMessage()); + }, + [](ChannelT &C, Error &Err) -> Error { + ErrorAsOutParameter EAO(&Err); + std::string Msg; + if (auto E2 = deserializeSeq(C, Msg)) + return E2; + Err = make_error<StringError>( + std::move(Msg), + orcError(OrcErrorCode::UnknownErrorCodeFromRemote)); + return Error::success(); + }); + AlreadyRegistered = true; + } +} + +/// SerializationTraits for Expected<T1> from an Expected<T2>. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, Expected<T2>> { +public: + static Error serialize(ChannelT &C, Expected<T2> &&ValOrErr) { + if (ValOrErr) { + if (auto Err = serializeSeq(C, true)) + return Err; + return SerializationTraits<ChannelT, T1, T2>::serialize(C, *ValOrErr); + } + if (auto Err = serializeSeq(C, false)) + return Err; + return serializeSeq(C, ValOrErr.takeError()); + } + + static Error deserialize(ChannelT &C, Expected<T2> &ValOrErr) { + ExpectedAsOutParameter<T2> EAO(&ValOrErr); + bool HasValue; + if (auto Err = deserializeSeq(C, HasValue)) + return Err; + if (HasValue) + return SerializationTraits<ChannelT, T1, T2>::deserialize(C, *ValOrErr); + Error Err = Error::success(); + if (auto E2 = deserializeSeq(C, Err)) + return E2; + ValOrErr = std::move(Err); + return Error::success(); + } +}; + +/// SerializationTraits for Expected<T1> from a T2. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, T2> { +public: + static Error serialize(ChannelT &C, T2 &&Val) { + return serializeSeq(C, Expected<T2>(std::forward<T2>(Val))); + } +}; + +/// SerializationTraits for Expected<T1> from an Error. +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, Expected<T>, Error> { +public: + static Error serialize(ChannelT &C, Error &&Err) { + return serializeSeq(C, Expected<T>(std::move(Err))); + } +}; + +/// SerializationTraits default specialization for std::pair. +template <typename ChannelT, typename T1, typename T2, typename T3, typename T4> +class SerializationTraits<ChannelT, std::pair<T1, T2>, std::pair<T3, T4>> { +public: + static Error serialize(ChannelT &C, const std::pair<T3, T4> &V) { + if (auto Err = SerializationTraits<ChannelT, T1, T3>::serialize(C, V.first)) + return Err; + return SerializationTraits<ChannelT, T2, T4>::serialize(C, V.second); + } + + static Error deserialize(ChannelT &C, std::pair<T3, T4> &V) { + if (auto Err = + SerializationTraits<ChannelT, T1, T3>::deserialize(C, V.first)) + return Err; + return SerializationTraits<ChannelT, T2, T4>::deserialize(C, V.second); + } +}; + +/// SerializationTraits default specialization for std::tuple. +template <typename ChannelT, typename... ArgTs> +class SerializationTraits<ChannelT, std::tuple<ArgTs...>> { +public: + /// RPC channel serialization for std::tuple. + static Error serialize(ChannelT &C, const std::tuple<ArgTs...> &V) { + return serializeTupleHelper(C, V, std::index_sequence_for<ArgTs...>()); + } + + /// RPC channel deserialization for std::tuple. + static Error deserialize(ChannelT &C, std::tuple<ArgTs...> &V) { + return deserializeTupleHelper(C, V, std::index_sequence_for<ArgTs...>()); + } + +private: + // Serialization helper for std::tuple. + template <size_t... Is> + static Error serializeTupleHelper(ChannelT &C, const std::tuple<ArgTs...> &V, + std::index_sequence<Is...> _) { + return serializeSeq(C, std::get<Is>(V)...); + } + + // Serialization helper for std::tuple. + template <size_t... Is> + static Error deserializeTupleHelper(ChannelT &C, std::tuple<ArgTs...> &V, + std::index_sequence<Is...> _) { + return deserializeSeq(C, std::get<Is>(V)...); + } +}; + +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, Optional<T>> { +public: + /// Serialize an Optional<T>. + static Error serialize(ChannelT &C, const Optional<T> &O) { + if (auto Err = serializeSeq(C, O != None)) + return Err; + if (O) + if (auto Err = serializeSeq(C, *O)) + return Err; + return Error::success(); + } + + /// Deserialize an Optional<T>. + static Error deserialize(ChannelT &C, Optional<T> &O) { + bool HasValue = false; + if (auto Err = deserializeSeq(C, HasValue)) + return Err; + if (HasValue) + if (auto Err = deserializeSeq(C, *O)) + return Err; + return Error::success(); + }; +}; + +/// SerializationTraits default specialization for std::vector. +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, std::vector<T>> { +public: + /// Serialize a std::vector<T> from std::vector<T>. + static Error serialize(ChannelT &C, const std::vector<T> &V) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(V.size()))) + return Err; + + for (const auto &E : V) + if (auto Err = serializeSeq(C, E)) + return Err; + + return Error::success(); + } + + /// Deserialize a std::vector<T> to a std::vector<T>. + static Error deserialize(ChannelT &C, std::vector<T> &V) { + assert(V.empty() && + "Expected default-constructed vector to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + V.resize(Count); + for (auto &E : V) + if (auto Err = deserializeSeq(C, E)) + return Err; + + return Error::success(); + } +}; + +/// Enable vector serialization from an ArrayRef. +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, std::vector<T>, ArrayRef<T>> { +public: + static Error serialize(ChannelT &C, ArrayRef<T> V) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(V.size()))) + return Err; + + for (const auto &E : V) + if (auto Err = serializeSeq(C, E)) + return Err; + + return Error::success(); + } +}; + +template <typename ChannelT, typename T, typename T2> +class SerializationTraits<ChannelT, std::set<T>, std::set<T2>> { +public: + /// Serialize a std::set<T> from std::set<T2>. + static Error serialize(ChannelT &C, const std::set<T2> &S) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) + return Err; + + for (const auto &E : S) + if (auto Err = SerializationTraits<ChannelT, T, T2>::serialize(C, E)) + return Err; + + return Error::success(); + } + + /// Deserialize a std::set<T> to a std::set<T>. + static Error deserialize(ChannelT &C, std::set<T2> &S) { + assert(S.empty() && "Expected default-constructed set to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + while (Count-- != 0) { + T2 Val; + if (auto Err = SerializationTraits<ChannelT, T, T2>::deserialize(C, Val)) + return Err; + + auto Added = S.insert(Val).second; + if (!Added) + return make_error<StringError>("Duplicate element in deserialized set", + orcError(OrcErrorCode::UnknownORCError)); + } + + return Error::success(); + } +}; + +template <typename ChannelT, typename K, typename V, typename K2, typename V2> +class SerializationTraits<ChannelT, std::map<K, V>, std::map<K2, V2>> { +public: + /// Serialize a std::map<K, V> from std::map<K2, V2>. + static Error serialize(ChannelT &C, const std::map<K2, V2> &M) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(M.size()))) + return Err; + + for (const auto &E : M) { + if (auto Err = + SerializationTraits<ChannelT, K, K2>::serialize(C, E.first)) + return Err; + if (auto Err = + SerializationTraits<ChannelT, V, V2>::serialize(C, E.second)) + return Err; + } + + return Error::success(); + } + + /// Deserialize a std::map<K, V> to a std::map<K, V>. + static Error deserialize(ChannelT &C, std::map<K2, V2> &M) { + assert(M.empty() && "Expected default-constructed map to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + while (Count-- != 0) { + std::pair<K2, V2> Val; + if (auto Err = + SerializationTraits<ChannelT, K, K2>::deserialize(C, Val.first)) + return Err; + + if (auto Err = + SerializationTraits<ChannelT, V, V2>::deserialize(C, Val.second)) + return Err; + + auto Added = M.insert(Val).second; + if (!Added) + return make_error<StringError>("Duplicate element in deserialized map", + orcError(OrcErrorCode::UnknownORCError)); + } + + return Error::success(); + } +}; + +template <typename ChannelT, typename K, typename V, typename K2, typename V2> +class SerializationTraits<ChannelT, std::map<K, V>, DenseMap<K2, V2>> { +public: + /// Serialize a std::map<K, V> from DenseMap<K2, V2>. + static Error serialize(ChannelT &C, const DenseMap<K2, V2> &M) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(M.size()))) + return Err; + + for (auto &E : M) { + if (auto Err = + SerializationTraits<ChannelT, K, K2>::serialize(C, E.first)) + return Err; + + if (auto Err = + SerializationTraits<ChannelT, V, V2>::serialize(C, E.second)) + return Err; + } + + return Error::success(); + } + + /// Serialize a std::map<K, V> from DenseMap<K2, V2>. + static Error deserialize(ChannelT &C, DenseMap<K2, V2> &M) { + assert(M.empty() && "Expected default-constructed map to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + while (Count-- != 0) { + std::pair<K2, V2> Val; + if (auto Err = + SerializationTraits<ChannelT, K, K2>::deserialize(C, Val.first)) + return Err; + + if (auto Err = + SerializationTraits<ChannelT, V, V2>::deserialize(C, Val.second)) + return Err; + + auto Added = M.insert(Val).second; + if (!Added) + return make_error<StringError>("Duplicate element in deserialized map", + orcError(OrcErrorCode::UnknownORCError)); + } + + return Error::success(); + } +}; + +} // namespace shared +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_RPC_RPCSERIALIZATION_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h index 9fc8dfaead..c3dce579d7 100644 --- a/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h +++ b/contrib/libs/llvm12/include/llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h @@ -1,176 +1,176 @@ -#pragma once - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -//===--- TargetProcessControlTypes.h -- Shared Core/TPC types ---*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// TargetProcessControl types that are used by both the Orc and -// OrcTargetProcess libraries. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_TARGETPROCESSCONTROLTYPES_H -#define LLVM_EXECUTIONENGINE_ORC_SHARED_TARGETPROCESSCONTROLTYPES_H - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ExecutionEngine/JITSymbol.h" - -#include <vector> - -namespace llvm { -namespace orc { -namespace tpctypes { - -template <typename T> struct UIntWrite { - UIntWrite() = default; - UIntWrite(JITTargetAddress Address, T Value) - : Address(Address), Value(Value) {} - - JITTargetAddress Address = 0; - T Value = 0; -}; - -/// Describes a write to a uint8_t. -using UInt8Write = UIntWrite<uint8_t>; - -/// Describes a write to a uint16_t. -using UInt16Write = UIntWrite<uint16_t>; - -/// Describes a write to a uint32_t. -using UInt32Write = UIntWrite<uint32_t>; - -/// Describes a write to a uint64_t. -using UInt64Write = UIntWrite<uint64_t>; - -/// Describes a write to a buffer. -/// For use with TargetProcessControl::MemoryAccess objects. -struct BufferWrite { - BufferWrite() = default; - BufferWrite(JITTargetAddress Address, StringRef Buffer) - : Address(Address), Buffer(Buffer) {} - - JITTargetAddress Address = 0; - StringRef Buffer; -}; - -/// A handle used to represent a loaded dylib in the target process. -using DylibHandle = JITTargetAddress; - -using LookupResult = std::vector<JITTargetAddress>; - -/// Either a uint8_t array or a uint8_t*. -union CWrapperFunctionResultData { - uint8_t Value[8]; - uint8_t *ValuePtr; -}; - -/// C ABI compatible wrapper function result. -/// -/// This can be safely returned from extern "C" functions, but should be used -/// to construct a WrapperFunctionResult for safety. -struct CWrapperFunctionResult { - uint64_t Size; - CWrapperFunctionResultData Data; - void (*Destroy)(CWrapperFunctionResultData Data, uint64_t Size); -}; - -/// C++ wrapper function result: Same as CWrapperFunctionResult but -/// auto-releases memory. -class WrapperFunctionResult { -public: - /// Create a default WrapperFunctionResult. - WrapperFunctionResult() { zeroInit(R); } - - /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This - /// instance takes ownership of the result object and will automatically - /// call the Destroy member upon destruction. - WrapperFunctionResult(CWrapperFunctionResult R) : R(R) {} - - WrapperFunctionResult(const WrapperFunctionResult &) = delete; - WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; - - WrapperFunctionResult(WrapperFunctionResult &&Other) { - zeroInit(R); - std::swap(R, Other.R); - } - - WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { - CWrapperFunctionResult Tmp; - zeroInit(Tmp); - std::swap(Tmp, Other.R); - std::swap(R, Tmp); - return *this; - } - - ~WrapperFunctionResult() { - if (R.Destroy) - R.Destroy(R.Data, R.Size); - } - - /// Relinquish ownership of and return the CWrapperFunctionResult. - CWrapperFunctionResult release() { - CWrapperFunctionResult Tmp; - zeroInit(Tmp); - std::swap(R, Tmp); - return Tmp; - } - - /// Get an ArrayRef covering the data in the result. - ArrayRef<uint8_t> getData() const { - if (R.Size <= 8) - return ArrayRef<uint8_t>(R.Data.Value, R.Size); - return ArrayRef<uint8_t>(R.Data.ValuePtr, R.Size); - } - - /// Create a WrapperFunctionResult from the given integer, provided its - /// size is no greater than 64 bits. - template <typename T, - typename _ = std::enable_if_t<std::is_integral<T>::value && - sizeof(T) <= sizeof(uint64_t)>> - static WrapperFunctionResult from(T Value) { - CWrapperFunctionResult R; - R.Size = sizeof(T); - memcpy(&R.Data.Value, Value, R.Size); - R.Destroy = nullptr; - return R; - } - - /// Create a WrapperFunctionResult from the given string. - static WrapperFunctionResult from(StringRef S); - - /// Always free Data.ValuePtr by calling free on it. - static void destroyWithFree(CWrapperFunctionResultData Data, uint64_t Size); - - /// Always free Data.ValuePtr by calling delete[] on it. - static void destroyWithDeleteArray(CWrapperFunctionResultData Data, - uint64_t Size); - -private: - static void zeroInit(CWrapperFunctionResult &R) { - R.Size = 0; - R.Data.ValuePtr = nullptr; - R.Destroy = nullptr; - } - - CWrapperFunctionResult R; -}; - -} // end namespace tpctypes -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_TARGETPROCESSCONTROLTYPES_H - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +//===--- TargetProcessControlTypes.h -- Shared Core/TPC types ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TargetProcessControl types that are used by both the Orc and +// OrcTargetProcess libraries. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_TARGETPROCESSCONTROLTYPES_H +#define LLVM_EXECUTIONENGINE_ORC_SHARED_TARGETPROCESSCONTROLTYPES_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/JITSymbol.h" + +#include <vector> + +namespace llvm { +namespace orc { +namespace tpctypes { + +template <typename T> struct UIntWrite { + UIntWrite() = default; + UIntWrite(JITTargetAddress Address, T Value) + : Address(Address), Value(Value) {} + + JITTargetAddress Address = 0; + T Value = 0; +}; + +/// Describes a write to a uint8_t. +using UInt8Write = UIntWrite<uint8_t>; + +/// Describes a write to a uint16_t. +using UInt16Write = UIntWrite<uint16_t>; + +/// Describes a write to a uint32_t. +using UInt32Write = UIntWrite<uint32_t>; + +/// Describes a write to a uint64_t. +using UInt64Write = UIntWrite<uint64_t>; + +/// Describes a write to a buffer. +/// For use with TargetProcessControl::MemoryAccess objects. +struct BufferWrite { + BufferWrite() = default; + BufferWrite(JITTargetAddress Address, StringRef Buffer) + : Address(Address), Buffer(Buffer) {} + + JITTargetAddress Address = 0; + StringRef Buffer; +}; + +/// A handle used to represent a loaded dylib in the target process. +using DylibHandle = JITTargetAddress; + +using LookupResult = std::vector<JITTargetAddress>; + +/// Either a uint8_t array or a uint8_t*. +union CWrapperFunctionResultData { + uint8_t Value[8]; + uint8_t *ValuePtr; +}; + +/// C ABI compatible wrapper function result. +/// +/// This can be safely returned from extern "C" functions, but should be used +/// to construct a WrapperFunctionResult for safety. +struct CWrapperFunctionResult { + uint64_t Size; + CWrapperFunctionResultData Data; + void (*Destroy)(CWrapperFunctionResultData Data, uint64_t Size); +}; + +/// C++ wrapper function result: Same as CWrapperFunctionResult but +/// auto-releases memory. +class WrapperFunctionResult { +public: + /// Create a default WrapperFunctionResult. + WrapperFunctionResult() { zeroInit(R); } + + /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This + /// instance takes ownership of the result object and will automatically + /// call the Destroy member upon destruction. + WrapperFunctionResult(CWrapperFunctionResult R) : R(R) {} + + WrapperFunctionResult(const WrapperFunctionResult &) = delete; + WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; + + WrapperFunctionResult(WrapperFunctionResult &&Other) { + zeroInit(R); + std::swap(R, Other.R); + } + + WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { + CWrapperFunctionResult Tmp; + zeroInit(Tmp); + std::swap(Tmp, Other.R); + std::swap(R, Tmp); + return *this; + } + + ~WrapperFunctionResult() { + if (R.Destroy) + R.Destroy(R.Data, R.Size); + } + + /// Relinquish ownership of and return the CWrapperFunctionResult. + CWrapperFunctionResult release() { + CWrapperFunctionResult Tmp; + zeroInit(Tmp); + std::swap(R, Tmp); + return Tmp; + } + + /// Get an ArrayRef covering the data in the result. + ArrayRef<uint8_t> getData() const { + if (R.Size <= 8) + return ArrayRef<uint8_t>(R.Data.Value, R.Size); + return ArrayRef<uint8_t>(R.Data.ValuePtr, R.Size); + } + + /// Create a WrapperFunctionResult from the given integer, provided its + /// size is no greater than 64 bits. + template <typename T, + typename _ = std::enable_if_t<std::is_integral<T>::value && + sizeof(T) <= sizeof(uint64_t)>> + static WrapperFunctionResult from(T Value) { + CWrapperFunctionResult R; + R.Size = sizeof(T); + memcpy(&R.Data.Value, Value, R.Size); + R.Destroy = nullptr; + return R; + } + + /// Create a WrapperFunctionResult from the given string. + static WrapperFunctionResult from(StringRef S); + + /// Always free Data.ValuePtr by calling free on it. + static void destroyWithFree(CWrapperFunctionResultData Data, uint64_t Size); + + /// Always free Data.ValuePtr by calling delete[] on it. + static void destroyWithDeleteArray(CWrapperFunctionResultData Data, + uint64_t Size); + +private: + static void zeroInit(CWrapperFunctionResult &R) { + R.Size = 0; + R.Data.ValuePtr = nullptr; + R.Destroy = nullptr; + } + + CWrapperFunctionResult R; +}; + +} // end namespace tpctypes +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_TARGETPROCESSCONTROLTYPES_H + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif |