From b57c915405a13b5321756af7b82212652f2fa378 Mon Sep 17 00:00:00 2001 From: Lang Hames Date: Mon, 12 Sep 2016 20:34:41 +0000 Subject: [PATCH] [ORC] Replace the serialize/deserialize function pair with a SerializationTraits class. SerializationTraits provides serialize and deserialize methods corresponding to the earlier functions, but also provides a name for the type. In future, this name will be used to render function signatures as strings, which will in turn be used to negotiate and verify API support between RPC clients and servers. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@281254 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h | 47 ++-- include/llvm/ExecutionEngine/Orc/RPCByteChannel.h | 269 ++++++++++----------- .../llvm/ExecutionEngine/Orc/RPCSerialization.h | 205 ++++++++++++++++ include/llvm/ExecutionEngine/Orc/RPCUtils.h | 6 +- 4 files changed, 360 insertions(+), 167 deletions(-) create mode 100644 include/llvm/ExecutionEngine/Orc/RPCSerialization.h diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h index 2b3caf06067..33d6b604c61 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h @@ -40,27 +40,34 @@ private: uint64_t Size; }; -inline Error serialize(RPCByteChannel &C, const DirectBufferWriter &DBW) { - if (auto EC = serialize(C, DBW.getDst())) - return EC; - if (auto EC = serialize(C, DBW.getSize())) - return EC; - return C.appendBytes(DBW.getSrc(), DBW.getSize()); -} - -inline Error deserialize(RPCByteChannel &C, DirectBufferWriter &DBW) { - JITTargetAddress Dst; - if (auto EC = deserialize(C, Dst)) - return EC; - uint64_t Size; - if (auto EC = deserialize(C, Size)) - return EC; - char *Addr = reinterpret_cast(static_cast(Dst)); - - DBW = DirectBufferWriter(0, Dst, Size); +template <> +class SerializationTraits { +public: - return C.readBytes(Addr, Size); -} + static const char* getName() { return "DirectBufferWriter"; } + + static Error serialize(RPCByteChannel &C, const DirectBufferWriter &DBW) { + if (auto EC = serializeSeq(C, DBW.getDst())) + return EC; + if (auto EC = serializeSeq(C, DBW.getSize())) + return EC; + return C.appendBytes(DBW.getSrc(), DBW.getSize()); + } + + static Error deserialize(RPCByteChannel &C, DirectBufferWriter &DBW) { + JITTargetAddress Dst; + if (auto EC = deserializeSeq(C, Dst)) + return EC; + uint64_t Size; + if (auto EC = deserializeSeq(C, Size)) + return EC; + char *Addr = reinterpret_cast(static_cast(Dst)); + + DBW = DirectBufferWriter(0, Dst, Size); + + return C.readBytes(Addr, Size); + } +}; class OrcRemoteTargetRPCAPI : public RPC { protected: diff --git a/include/llvm/ExecutionEngine/Orc/RPCByteChannel.h b/include/llvm/ExecutionEngine/Orc/RPCByteChannel.h index 1069cb91d36..c8cb42d5374 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCByteChannel.h +++ b/include/llvm/ExecutionEngine/Orc/RPCByteChannel.h @@ -11,6 +11,7 @@ #define LLVM_EXECUTIONENGINE_ORC_RPCBYTECHANNEL_H #include "OrcError.h" +#include "RPCSerialization.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" @@ -21,6 +22,7 @@ #include #include #include +#include #include namespace llvm { @@ -79,169 +81,148 @@ inline Error endReceiveMessage(RPCByteChannel &C) { return Error::success(); } -/// RPC channel serialization for a variadic list of arguments. -template -Error serializeSeq(RPCByteChannel &C, const T &Arg, const Ts &... Args) { - if (auto Err = serialize(C, Arg)) - return Err; - return serializeSeq(C, Args...); -} - -/// RPC channel serialization for an (empty) variadic list of arguments. -inline Error serializeSeq(RPCByteChannel &C) { return Error::success(); } +template ::value>:: + type> +class RPCByteChannelPrimitiveSerialization { +public: + static Error serialize(ChannelT &C, T V) { + support::endian::byte_swap(V); + return C.appendBytes(reinterpret_cast(&V), sizeof(T)); + }; -/// RPC channel deserialization for a variadic list of arguments. -template -Error deserializeSeq(RPCByteChannel &C, T &Arg, Ts &... Args) { - if (auto Err = deserialize(C, Arg)) - return Err; - return deserializeSeq(C, Args...); -} + static Error deserialize(ChannelT &C, T &V) { + if (auto Err = C.readBytes(reinterpret_cast(&V), sizeof(T))) + return Err; + support::endian::byte_swap(V); + return Error::success(); + }; +}; -/// RPC channel serialization for an (empty) variadic list of arguments. -inline Error deserializeSeq(RPCByteChannel &C) { return Error::success(); } - -/// RPC channel serialization for integer primitives. -template -typename std::enable_if< - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value, - Error>::type -serialize(RPCByteChannel &C, T V) { - support::endian::byte_swap(V); - return C.appendBytes(reinterpret_cast(&V), sizeof(T)); -} +template +class SerializationTraits + : public RPCByteChannelPrimitiveSerialization { +public: + static const char* getName() { return "uint64_t"; } +}; -/// RPC channel deserialization for integer primitives. -template -typename std::enable_if< - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value, - Error>::type -deserialize(RPCByteChannel &C, T &V) { - if (auto Err = C.readBytes(reinterpret_cast(&V), sizeof(T))) - return Err; - support::endian::byte_swap(V); - return Error::success(); -} +template +class SerializationTraits + : public RPCByteChannelPrimitiveSerialization { +public: + static const char* getName() { return "int64_t"; } +}; -/// RPC channel serialization for enums. -template -typename std::enable_if::value, Error>::type -serialize(RPCByteChannel &C, T V) { - return serialize(C, static_cast::type>(V)); -} +template +class SerializationTraits + : public RPCByteChannelPrimitiveSerialization { +public: + static const char* getName() { return "uint32_t"; } +}; -/// RPC channel deserialization for enums. -template -typename std::enable_if::value, Error>::type -deserialize(RPCByteChannel &C, T &V) { - typename std::underlying_type::type Tmp; - Error Err = deserialize(C, Tmp); - V = static_cast(Tmp); - return Err; -} +template +class SerializationTraits + : public RPCByteChannelPrimitiveSerialization { +public: + static const char* getName() { return "int32_t"; } +}; -/// RPC channel serialization for bools. -inline Error serialize(RPCByteChannel &C, bool V) { - uint8_t VN = V ? 1 : 0; - return C.appendBytes(reinterpret_cast(&VN), 1); -} +template +class SerializationTraits + : public RPCByteChannelPrimitiveSerialization { +public: + static const char* getName() { return "uint16_t"; } +}; -/// RPC channel deserialization for bools. -inline Error deserialize(RPCByteChannel &C, bool &V) { - uint8_t VN = 0; - if (auto Err = C.readBytes(reinterpret_cast(&VN), 1)) - return Err; +template +class SerializationTraits + : public RPCByteChannelPrimitiveSerialization { +public: + static const char* getName() { return "int16_t"; } +}; - V = (VN != 0); - return Error::success(); -} +template +class SerializationTraits + : public RPCByteChannelPrimitiveSerialization { +public: + static const char* getName() { return "uint8_t"; } +}; -/// RPC channel serialization for StringRefs. -/// Note: There is no corresponding deseralization for this, as StringRef -/// doesn't own its memory and so can't hold the deserialized data. -inline Error serialize(RPCByteChannel &C, StringRef S) { - if (auto Err = serialize(C, static_cast(S.size()))) - return Err; - return C.appendBytes((const char *)S.bytes_begin(), S.size()); -} +template +class SerializationTraits + : public RPCByteChannelPrimitiveSerialization { +public: + static const char* getName() { return "int8_t"; } +}; -/// RPC channel serialization for std::strings. -inline Error serialize(RPCByteChannel &C, const std::string &S) { - return serialize(C, StringRef(S)); -} +template +class SerializationTraits + : public RPCByteChannelPrimitiveSerialization { +public: + static const char* getName() { return "char"; } -/// RPC channel deserialization for std::strings. -inline Error deserialize(RPCByteChannel &C, std::string &S) { - uint64_t Count; - if (auto Err = deserialize(C, Count)) - return Err; - S.resize(Count); - return C.readBytes(&S[0], Count); -} + static Error serialize(RPCByteChannel &C, char V) { + return serializeSeq(C, static_cast(V)); + }; -// Serialization helper for std::tuple. -template -inline Error serializeTupleHelper(RPCByteChannel &C, const TupleT &V, - llvm::index_sequence _) { - return serializeSeq(C, std::get(V)...); -} + static Error deserialize(RPCByteChannel &C, char &V) { + uint8_t VV; + if (auto Err = deserializeSeq(C, VV)) + return Err; + V = static_cast(V); + return Error::success(); + }; +}; -/// RPC channel serialization for std::tuple. -template -inline Error serialize(RPCByteChannel &C, const std::tuple &V) { - return serializeTupleHelper(C, V, llvm::index_sequence_for()); -} +template +class SerializationTraits::value>:: + type> { +public: + static const char* getName() { return "bool"; } -// Serialization helper for std::tuple. -template -inline Error deserializeTupleHelper(RPCByteChannel &C, TupleT &V, - llvm::index_sequence _) { - return deserializeSeq(C, std::get(V)...); -} + static Error serialize(ChannelT &C, bool V) { + return C.appendBytes(reinterpret_cast(&V), 1); + } -/// RPC channel deserialization for std::tuple. -template -inline Error deserialize(RPCByteChannel &C, std::tuple &V) { - return deserializeTupleHelper(C, V, llvm::index_sequence_for()); -} + static Error deserialize(ChannelT &C, bool &V) { + return C.readBytes(reinterpret_cast(&V), 1); + } +}; -/// RPC channel serialization for ArrayRef. -template Error serialize(RPCByteChannel &C, const ArrayRef &A) { - if (auto Err = serialize(C, static_cast(A.size()))) - return Err; +template +class SerializationTraits::value>:: + type> { +public: + static const char* getName() { return "std::string"; } - for (const auto &E : A) - if (auto Err = serialize(C, E)) + static Error serialize(RPCByteChannel &C, StringRef S) { + if (auto Err = SerializationTraits:: + serialize(C, static_cast(S.size()))) return Err; - - return Error::success(); -} - -/// RPC channel serialization for std::array. -template Error serialize(RPCByteChannel &C, - const std::vector &V) { - return serialize(C, ArrayRef(V)); -} - -/// RPC channel deserialization for std::array. -template Error deserialize(RPCByteChannel &C, std::vector &V) { - uint64_t Count = 0; - if (auto Err = deserialize(C, Count)) - return Err; - - V.resize(Count); - for (auto &E : V) - if (auto Err = deserialize(C, E)) + return C.appendBytes((const char *)S.bytes_begin(), S.size()); + } + + /// RPC channel serialization for std::strings. + static Error serialize(RPCByteChannel &C, const std::string &S) { + return serialize(C, StringRef(S)); + } + + /// RPC channel deserialization for std::strings. + static Error deserialize(RPCByteChannel &C, std::string &S) { + uint64_t Count = 0; + if (auto Err = SerializationTraits:: + deserialize(C, Count)) return Err; - - return Error::success(); -} + S.resize(Count); + return C.readBytes(&S[0], Count); + } +}; } // end namespace remote } // end namespace orc diff --git a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h new file mode 100644 index 00000000000..5ed6e45ae20 --- /dev/null +++ b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h @@ -0,0 +1,205 @@ +//===- llvm/ExecutionEngine/Orc/RPCSerialization.h --------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H +#define LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H + +#include "OrcError.h" +#include + +namespace llvm { +namespace orc { +namespace remote { + +template +class SerializationTraits {}; + +/// RPC channel serialization for a variadic list of arguments. +template +Error serializeSeq(ChannelT &C, const T &Arg, const Ts &... Args) { + if (auto Err = SerializationTraits::serialize(C, Arg)) + return Err; + return serializeSeq(C, Args...); +} + +/// RPC channel serialization for an (empty) variadic list of arguments. +template +Error serializeSeq(ChannelT &C) { return Error::success(); } + +/// RPC channel deserialization for a variadic list of arguments. +template +Error deserializeSeq(ChannelT &C, T &Arg, Ts &... Args) { + if (auto Err = SerializationTraits::deserialize(C, Arg)) + return Err; + return deserializeSeq(C, Args...); +} + +/// RPC channel serialization for an (empty) variadic list of arguments. +template +Error deserializeSeq(ChannelT &C) { return Error::success(); } + +template +class TypeNameSequence {}; + +template +OStream& operator<<(OStream &OS, const TypeNameSequence &V) { + OS << SerializationTraits::getName(); + return OS; +} + +template +OStream& +operator<<(OStream &OS, + const TypeNameSequence &V) { + OS << SerializationTraits::getName() << ", " + << TypeNameSequence(); + return OS; +} + +/// Serialization for pairs. +template +class SerializationTraits> { +public: + static const char* getName() { + std::lock_guard Lock(NameMutex); + if (Name.empty()) + Name = (std::ostringstream() + << "std::pair<" + << TypeNameSequence() + << ">").str(); + + return Name.data(); + } + + static Error serialize(ChannelT &C, const std::pair &V) { + return serializeSeq(C, V.first, V.second); + } + + static Error deserialize(ChannelT &C, std::pair &V) { + return deserializeSeq(C, V.first, V.second); + } +private: + static std::mutex NameMutex; + static std::string Name; +}; + +template +std::mutex SerializationTraits>::NameMutex; + +template +std::string SerializationTraits>::Name; + +/// Serialization for tuples. +template +class SerializationTraits> { +public: + + static const char* getName() { + std::lock_guard Lock(NameMutex); + if (Name.empty()) + Name = (std::ostringstream() + << "std::tuple<" + << TypeNameSequence() + << ">").str(); + + return Name.data(); + } + + /// RPC channel serialization for std::tuple. + static Error serialize(ChannelT &C, const std::tuple &V) { + return serializeTupleHelper(C, V, llvm::index_sequence_for()); + } + + /// RPC channel deserialization for std::tuple. + static Error deserialize(ChannelT &C, std::tuple &V) { + return deserializeTupleHelper(C, V, llvm::index_sequence_for()); + } + +private: + + // Serialization helper for std::tuple. + template + static Error serializeTupleHelper(ChannelT &C, const std::tuple &V, + llvm::index_sequence _) { + return serializeSeq(C, std::get(V)...); + } + + // Serialization helper for std::tuple. + template + static Error deserializeTupleHelper(ChannelT &C, std::tuple &V, + llvm::index_sequence _) { + return deserializeSeq(C, std::get(V)...); + } + + static std::mutex NameMutex; + static std::string Name; +}; + +template +std::mutex SerializationTraits>::NameMutex; + +template +std::string SerializationTraits>::Name; + +template +class SerializationTraits> { +public: + + static const char* getName() { + std::lock_guard Lock(NameMutex); + if (Name.empty()) + Name = (std::ostringstream() + << "std::vector<" << TypeNameSequence() + << ">").str(); + return Name.data(); + } + + static Error serialize(ChannelT &C, const std::vector &V) { + if (auto Err = SerializationTraits:: + serialize(C, static_cast(V.size()))) + return Err; + + for (const auto &E : V) + if (auto Err = SerializationTraits::serialize(C, E)) + return Err; + + return Error::success(); + } + + static Error deserialize(ChannelT &C, std::vector &V) { + uint64_t Count = 0; + if (auto Err = SerializationTraits:: + deserialize(C, Count)) + return Err; + + V.resize(Count); + for (auto &E : V) + if (auto Err = SerializationTraits::deserialize(C, E)) + return Err; + + return Error::success(); + } + +private: + static std::mutex NameMutex; + static std::string Name; +}; + +template +std::mutex SerializationTraits>::NameMutex; + +template +std::string SerializationTraits>::Name; + +} // end namespace remote +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H diff --git a/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/include/llvm/ExecutionEngine/Orc/RPCUtils.h index f061fff405b..e766a7f47f2 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -166,7 +166,7 @@ protected: template static Error readResult(ChannelT &C, std::promise &P) { RetT Val; - auto Err = deserialize(C, Val); + auto Err = deserializeSeq(C, Val); auto Err2 = endReceiveMessage(C); Err = joinErrors(std::move(Err), std::move(Err2)); if (Err) @@ -581,7 +581,7 @@ public: if (auto Err = startReceiveMessage(C)) return Err; - return deserialize(C, Id); + return deserializeSeq(C, Id); } /// Deserialize args for Func from C and call Handler. The signature of @@ -645,7 +645,7 @@ public: /// This should be called from the receive loop to retrieve results. Error handleResponse(ChannelT &C, SequenceNumberT *SeqNoRet = nullptr) { SequenceNumberT SeqNo; - if (auto Err = deserialize(C, SeqNo)) { + if (auto Err = deserializeSeq(C, SeqNo)) { abandonOutstandingResults(); return Err; } -- 2.11.0