OSDN Git Service

[ORC] Replace the serialize/deserialize function pair with a SerializationTraits
authorLang Hames <lhames@gmail.com>
Mon, 12 Sep 2016 20:34:41 +0000 (20:34 +0000)
committerLang Hames <lhames@gmail.com>
Mon, 12 Sep 2016 20:34:41 +0000 (20:34 +0000)
class.

SerializationTraits provides serialize and deserialize methods corresponding to
the earlier functions, but also provides a name for the type. In future, this
name will be used to render function signatures as strings, which will in turn
be used to negotiate and verify API support between RPC clients and servers.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@281254 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h
include/llvm/ExecutionEngine/Orc/RPCByteChannel.h
include/llvm/ExecutionEngine/Orc/RPCSerialization.h [new file with mode: 0644]
include/llvm/ExecutionEngine/Orc/RPCUtils.h

index 2b3caf0..33d6b60 100644 (file)
@@ -40,27 +40,34 @@ private:
   uint64_t Size;
 };
 
-inline Error serialize(RPCByteChannel &C, const DirectBufferWriter &DBW) {
-  if (auto EC = serialize(C, DBW.getDst()))
-    return EC;
-  if (auto EC = serialize(C, DBW.getSize()))
-    return EC;
-  return C.appendBytes(DBW.getSrc(), DBW.getSize());
-}
-
-inline Error deserialize(RPCByteChannel &C, DirectBufferWriter &DBW) {
-  JITTargetAddress Dst;
-  if (auto EC = deserialize(C, Dst))
-    return EC;
-  uint64_t Size;
-  if (auto EC = deserialize(C, Size))
-    return EC;
-  char *Addr = reinterpret_cast<char *>(static_cast<uintptr_t>(Dst));
-
-  DBW = DirectBufferWriter(0, Dst, Size);
+template <>
+class SerializationTraits<RPCByteChannel, DirectBufferWriter> {
+public:
 
-  return C.readBytes(Addr, Size);
-}
+  static const char* getName() { return "DirectBufferWriter"; }
+
+  static Error serialize(RPCByteChannel &C, const DirectBufferWriter &DBW) {
+    if (auto EC = serializeSeq(C, DBW.getDst()))
+      return EC;
+    if (auto EC = serializeSeq(C, DBW.getSize()))
+      return EC;
+    return C.appendBytes(DBW.getSrc(), DBW.getSize());
+  }
+
+  static Error deserialize(RPCByteChannel &C, DirectBufferWriter &DBW) {
+    JITTargetAddress Dst;
+    if (auto EC = deserializeSeq(C, Dst))
+      return EC;
+    uint64_t Size;
+    if (auto EC = deserializeSeq(C, Size))
+      return EC;
+    char *Addr = reinterpret_cast<char *>(static_cast<uintptr_t>(Dst));
+
+    DBW = DirectBufferWriter(0, Dst, Size);
+
+    return C.readBytes(Addr, Size);
+  }
+};
 
 class OrcRemoteTargetRPCAPI : public RPC<RPCByteChannel> {
 protected:
index 1069cb9..c8cb42d 100644 (file)
@@ -11,6 +11,7 @@
 #define LLVM_EXECUTIONENGINE_ORC_RPCBYTECHANNEL_H
 
 #include "OrcError.h"
+#include "RPCSerialization.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringRef.h"
@@ -21,6 +22,7 @@
 #include <mutex>
 #include <string>
 #include <tuple>
+#include <type_traits>
 #include <vector>
 
 namespace llvm {
@@ -79,169 +81,148 @@ inline Error endReceiveMessage(RPCByteChannel &C) {
   return Error::success();
 }
 
-/// RPC channel serialization for a variadic list of arguments.
-template <typename T, typename... Ts>
-Error serializeSeq(RPCByteChannel &C, const T &Arg, const Ts &... Args) {
-  if (auto Err = serialize(C, Arg))
-    return Err;
-  return serializeSeq(C, Args...);
-}
-
-/// RPC channel serialization for an (empty) variadic list of arguments.
-inline Error serializeSeq(RPCByteChannel &C) { return Error::success(); }
+template <typename ChannelT, typename T,
+          typename =
+            typename std::enable_if<
+                       std::is_base_of<RPCByteChannel, ChannelT>::value>::
+                         type>
+class RPCByteChannelPrimitiveSerialization {
+public:
+  static Error serialize(ChannelT &C, T V) {
+    support::endian::byte_swap<T, support::big>(V);
+    return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T));
+  };
 
-/// RPC channel deserialization for a variadic list of arguments.
-template <typename T, typename... Ts>
-Error deserializeSeq(RPCByteChannel &C, T &Arg, Ts &... Args) {
-  if (auto Err = deserialize(C, Arg))
-    return Err;
-  return deserializeSeq(C, Args...);
-}
+  static Error deserialize(ChannelT &C, T &V) {
+    if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T)))
+      return Err;
+    support::endian::byte_swap<T, support::big>(V);
+    return Error::success();
+  };
+};
 
-/// RPC channel serialization for an (empty) variadic list of arguments.
-inline Error deserializeSeq(RPCByteChannel &C) { return Error::success(); }
-
-/// RPC channel serialization for integer primitives.
-template <typename T>
-typename std::enable_if<
-    std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value ||
-        std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value ||
-        std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value ||
-        std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value,
-    Error>::type
-serialize(RPCByteChannel &C, T V) {
-  support::endian::byte_swap<T, support::big>(V);
-  return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T));
-}
+template <typename ChannelT>
+class SerializationTraits<ChannelT, uint64_t>
+  : public RPCByteChannelPrimitiveSerialization<ChannelT, uint64_t> {
+public:
+  static const char* getName() { return "uint64_t"; }
+};
 
-/// RPC channel deserialization for integer primitives.
-template <typename T>
-typename std::enable_if<
-    std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value ||
-        std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value ||
-        std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value ||
-        std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value,
-    Error>::type
-deserialize(RPCByteChannel &C, T &V) {
-  if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T)))
-    return Err;
-  support::endian::byte_swap<T, support::big>(V);
-  return Error::success();
-}
+template <typename ChannelT>
+class SerializationTraits<ChannelT, int64_t>
+  : public RPCByteChannelPrimitiveSerialization<ChannelT, int64_t> {
+public:
+  static const char* getName() { return "int64_t"; }
+};
 
-/// RPC channel serialization for enums.
-template <typename T>
-typename std::enable_if<std::is_enum<T>::value, Error>::type
-serialize(RPCByteChannel &C, T V) {
-  return serialize(C, static_cast<typename std::underlying_type<T>::type>(V));
-}
+template <typename ChannelT>
+class SerializationTraits<ChannelT, uint32_t>
+  : public RPCByteChannelPrimitiveSerialization<ChannelT, uint32_t> {
+public:
+  static const char* getName() { return "uint32_t"; }
+};
 
-/// RPC channel deserialization for enums.
-template <typename T>
-typename std::enable_if<std::is_enum<T>::value, Error>::type
-deserialize(RPCByteChannel &C, T &V) {
-  typename std::underlying_type<T>::type Tmp;
-  Error Err = deserialize(C, Tmp);
-  V = static_cast<T>(Tmp);
-  return Err;
-}
+template <typename ChannelT>
+class SerializationTraits<ChannelT, int32_t>
+  : public RPCByteChannelPrimitiveSerialization<ChannelT, int32_t> {
+public:
+  static const char* getName() { return "int32_t"; }
+};
 
-/// RPC channel serialization for bools.
-inline Error serialize(RPCByteChannel &C, bool V) {
-  uint8_t VN = V ? 1 : 0;
-  return C.appendBytes(reinterpret_cast<const char *>(&VN), 1);
-}
+template <typename ChannelT>
+class SerializationTraits<ChannelT, uint16_t>
+  : public RPCByteChannelPrimitiveSerialization<ChannelT, uint16_t> {
+public:
+  static const char* getName() { return "uint16_t"; }
+};
 
-/// RPC channel deserialization for bools.
-inline Error deserialize(RPCByteChannel &C, bool &V) {
-  uint8_t VN = 0;
-  if (auto Err = C.readBytes(reinterpret_cast<char *>(&VN), 1))
-    return Err;
+template <typename ChannelT>
+class SerializationTraits<ChannelT, int16_t>
+  : public RPCByteChannelPrimitiveSerialization<ChannelT, int16_t> {
+public:
+  static const char* getName() { return "int16_t"; }
+};
 
-  V = (VN != 0);
-  return Error::success();
-}
+template <typename ChannelT>
+class SerializationTraits<ChannelT, uint8_t>
+  : public RPCByteChannelPrimitiveSerialization<ChannelT, uint8_t> {
+public:
+  static const char* getName() { return "uint8_t"; }
+};
 
-/// RPC channel serialization for StringRefs.
-/// Note: There is no corresponding deseralization for this, as StringRef
-/// doesn't own its memory and so can't hold the deserialized data.
-inline Error serialize(RPCByteChannel &C, StringRef S) {
-  if (auto Err = serialize(C, static_cast<uint64_t>(S.size())))
-    return Err;
-  return C.appendBytes((const char *)S.bytes_begin(), S.size());
-}
+template <typename ChannelT>
+class SerializationTraits<ChannelT, int8_t>
+  : public RPCByteChannelPrimitiveSerialization<ChannelT, int8_t> {
+public:
+  static const char* getName() { return "int8_t"; }
+};
 
-/// RPC channel serialization for std::strings.
-inline Error serialize(RPCByteChannel &C, const std::string &S) {
-  return serialize(C, StringRef(S));
-}
+template <typename ChannelT>
+class SerializationTraits<ChannelT, char>
+  : public RPCByteChannelPrimitiveSerialization<ChannelT, uint8_t> {
+public:
+  static const char* getName() { return "char"; }
 
-/// RPC channel deserialization for std::strings.
-inline Error deserialize(RPCByteChannel &C, std::string &S) {
-  uint64_t Count;
-  if (auto Err = deserialize(C, Count))
-    return Err;
-  S.resize(Count);
-  return C.readBytes(&S[0], Count);
-}
+  static Error serialize(RPCByteChannel &C, char V) {
+    return serializeSeq(C, static_cast<uint8_t>(V));
+  };
 
-// Serialization helper for std::tuple.
-template <typename TupleT, size_t... Is>
-inline Error serializeTupleHelper(RPCByteChannel &C, const TupleT &V,
-                                  llvm::index_sequence<Is...> _) {
-  return serializeSeq(C, std::get<Is>(V)...);
-}
+  static Error deserialize(RPCByteChannel &C, char &V) {
+    uint8_t VV;
+    if (auto Err = deserializeSeq(C, VV))
+      return Err;
+    V = static_cast<char>(V);
+    return Error::success();
+  };
+};
 
-/// RPC channel serialization for std::tuple.
-template <typename... ArgTs>
-inline Error serialize(RPCByteChannel &C, const std::tuple<ArgTs...> &V) {
-  return serializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>());
-}
+template <typename ChannelT>
+class SerializationTraits<ChannelT, bool,
+                          typename std::enable_if<
+                            std::is_base_of<RPCByteChannel, ChannelT>::value>::
+                              type> {
+public:
+  static const char* getName() { return "bool"; }
 
-// Serialization helper for std::tuple.
-template <typename TupleT, size_t... Is>
-inline Error deserializeTupleHelper(RPCByteChannel &C, TupleT &V,
-                                    llvm::index_sequence<Is...> _) {
-  return deserializeSeq(C, std::get<Is>(V)...);
-}
+  static Error serialize(ChannelT &C, bool V) {
+    return C.appendBytes(reinterpret_cast<const char *>(&V), 1);
+  }
 
-/// RPC channel deserialization for std::tuple.
-template <typename... ArgTs>
-inline Error deserialize(RPCByteChannel &C, std::tuple<ArgTs...> &V) {
-  return deserializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>());
-}
+  static Error deserialize(ChannelT &C, bool &V) {
+    return C.readBytes(reinterpret_cast<char *>(&V), 1);
+  }
+};
 
-/// RPC channel serialization for ArrayRef<T>.
-template <typename T> Error serialize(RPCByteChannel &C, const ArrayRef<T> &A) {
-  if (auto Err = serialize(C, static_cast<uint64_t>(A.size())))
-    return Err;
+template <typename ChannelT>
+class SerializationTraits<ChannelT, std::string,
+                          typename std::enable_if<
+                            std::is_base_of<RPCByteChannel, ChannelT>::value>::
+                              type> {
+public:
+  static const char* getName() { return "std::string"; }
 
-  for (const auto &E : A)
-    if (auto Err = serialize(C, E))
+  static Error serialize(RPCByteChannel &C, StringRef S) {
+    if (auto Err = SerializationTraits<RPCByteChannel, uint64_t>::
+                     serialize(C, static_cast<uint64_t>(S.size())))
       return Err;
-
-  return Error::success();
-}
-
-/// RPC channel serialization for std::array<T>.
-template <typename T> Error serialize(RPCByteChannel &C,
-                                      const std::vector<T> &V) {
-  return serialize(C, ArrayRef<T>(V));
-}
-
-/// RPC channel deserialization for std::array<T>.
-template <typename T> Error deserialize(RPCByteChannel &C, std::vector<T> &V) {
-  uint64_t Count = 0;
-  if (auto Err = deserialize(C, Count))
-    return Err;
-
-  V.resize(Count);
-  for (auto &E : V)
-    if (auto Err = deserialize(C, E))
+    return C.appendBytes((const char *)S.bytes_begin(), S.size());
+  }
+
+  /// RPC channel serialization for std::strings.
+  static Error serialize(RPCByteChannel &C, const std::string &S) {
+    return serialize(C, StringRef(S));
+  }
+
+  /// RPC channel deserialization for std::strings.
+  static Error deserialize(RPCByteChannel &C, std::string &S) {
+    uint64_t Count = 0;
+    if (auto Err = SerializationTraits<RPCByteChannel, uint64_t>::
+                     deserialize(C, Count))
       return Err;
-
-  return Error::success();
-}
+    S.resize(Count);
+    return C.readBytes(&S[0], Count);
+  }
+};
 
 } // end namespace remote
 } // end namespace orc
diff --git a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h
new file mode 100644 (file)
index 0000000..5ed6e45
--- /dev/null
@@ -0,0 +1,205 @@
+//===- llvm/ExecutionEngine/Orc/RPCSerialization.h --------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H
+#define LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H
+
+#include "OrcError.h"
+#include <sstream>
+
+namespace llvm {
+namespace orc {
+namespace remote {
+
+template <typename ChannelT, typename T, typename = void>
+class SerializationTraits {};
+
+/// RPC channel serialization for a variadic list of arguments.
+template <typename ChannelT, typename T, typename... Ts>
+Error serializeSeq(ChannelT &C, const T &Arg, const Ts &... Args) {
+  if (auto Err = SerializationTraits<ChannelT, T>::serialize(C, Arg))
+    return Err;
+  return serializeSeq(C, Args...);
+}
+
+/// RPC channel serialization for an (empty) variadic list of arguments.
+template <typename ChannelT>
+Error serializeSeq(ChannelT &C) { return Error::success(); }
+
+/// RPC channel deserialization for a variadic list of arguments.
+template <typename ChannelT, typename T, typename... Ts>
+Error deserializeSeq(ChannelT &C, T &Arg, Ts &... Args) {
+  if (auto Err = SerializationTraits<ChannelT, T>::deserialize(C, Arg))
+    return Err;
+  return deserializeSeq(C, Args...);
+}
+
+/// RPC channel serialization for an (empty) variadic list of arguments.
+template <typename ChannelT>
+Error deserializeSeq(ChannelT &C) { return Error::success(); }
+
+template <typename ChannelT, typename... ArgTs>
+class TypeNameSequence {};
+
+template <typename OStream, typename ChannelT, typename ArgT>
+OStream& operator<<(OStream &OS, const TypeNameSequence<ChannelT, ArgT> &V) {
+  OS << SerializationTraits<ChannelT, ArgT>::getName();
+  return OS;
+}
+
+template <typename OStream, typename ChannelT, typename ArgT1,
+          typename ArgT2, typename... ArgTs>
+OStream&
+operator<<(OStream &OS,
+           const TypeNameSequence<ChannelT, ArgT1, ArgT2, ArgTs...> &V) {
+  OS << SerializationTraits<ChannelT, ArgT1>::getName() << ", "
+     << TypeNameSequence<ChannelT, ArgT2, ArgTs...>();
+  return OS;
+}
+
+/// Serialization for pairs.
+template <typename ChannelT, typename T1, typename T2>
+class SerializationTraits<ChannelT, std::pair<T1, T2>> {
+public:
+  static const char* getName() {
+    std::lock_guard<std::mutex> Lock(NameMutex);
+    if (Name.empty())
+      Name = (std::ostringstream()
+               << "std::pair<"
+               << TypeNameSequence<ChannelT, T1, T2>()
+               << ">").str();
+
+    return Name.data();
+  }
+
+  static Error serialize(ChannelT &C, const std::pair<T1, T2> &V) {
+    return serializeSeq(C, V.first, V.second);
+  }
+
+  static Error deserialize(ChannelT &C, std::pair<T1, T2> &V) {
+    return deserializeSeq(C, V.first, V.second);
+  }
+private:
+  static std::mutex NameMutex;
+  static std::string Name;
+};
+
+template <typename ChannelT, typename T1, typename T2>
+std::mutex SerializationTraits<ChannelT, std::pair<T1, T2>>::NameMutex;
+
+template <typename ChannelT, typename T1, typename T2>
+std::string SerializationTraits<ChannelT, std::pair<T1, T2>>::Name;
+
+/// Serialization for tuples.
+template <typename ChannelT, typename... ArgTs>
+class SerializationTraits<ChannelT, std::tuple<ArgTs...>> {
+public:
+
+  static const char* getName() {
+    std::lock_guard<std::mutex> Lock(NameMutex);
+    if (Name.empty())
+      Name = (std::ostringstream()
+               << "std::tuple<"
+               << TypeNameSequence<ChannelT, ArgTs...>()
+               << ">").str();
+
+    return Name.data();
+  }
+
+  /// RPC channel serialization for std::tuple.
+  static Error serialize(ChannelT &C, const std::tuple<ArgTs...> &V) {
+    return serializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>());
+  }
+
+  /// RPC channel deserialization for std::tuple.
+  static Error deserialize(ChannelT &C, std::tuple<ArgTs...> &V) {
+    return deserializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>());
+  }
+
+private:
+
+  // Serialization helper for std::tuple.
+  template <size_t... Is>
+  static Error serializeTupleHelper(ChannelT &C, const std::tuple<ArgTs...> &V,
+                                    llvm::index_sequence<Is...> _) {
+    return serializeSeq(C, std::get<Is>(V)...);
+  }
+
+  // Serialization helper for std::tuple.
+  template <size_t... Is>
+  static Error deserializeTupleHelper(ChannelT &C, std::tuple<ArgTs...> &V,
+                                      llvm::index_sequence<Is...> _) {
+    return deserializeSeq(C, std::get<Is>(V)...);
+  }
+
+  static std::mutex NameMutex;
+  static std::string Name;
+};
+
+template <typename ChannelT, typename... ArgTs>
+std::mutex SerializationTraits<ChannelT, std::tuple<ArgTs...>>::NameMutex;
+
+template <typename ChannelT, typename... ArgTs>
+std::string SerializationTraits<ChannelT, std::tuple<ArgTs...>>::Name;
+
+template <typename ChannelT, typename T>
+class SerializationTraits<ChannelT, std::vector<T>> {
+public:
+
+  static const char* getName() {
+    std::lock_guard<std::mutex> Lock(NameMutex);
+    if (Name.empty())
+      Name = (std::ostringstream()
+                << "std::vector<" << TypeNameSequence<ChannelT, T>()
+                << ">").str();
+    return Name.data();
+  }
+
+  static Error serialize(ChannelT &C, const std::vector<T> &V) {
+    if (auto Err = SerializationTraits<ChannelT, uint64_t>::
+                     serialize(C, static_cast<uint64_t>(V.size())))
+      return Err;
+
+    for (const auto &E : V)
+      if (auto Err = SerializationTraits<ChannelT, T>::serialize(C, E))
+        return Err;
+
+    return Error::success();
+  }
+
+  static Error deserialize(ChannelT &C, std::vector<T> &V) {
+    uint64_t Count = 0;
+    if (auto Err = SerializationTraits<ChannelT, uint64_t>::
+                     deserialize(C, Count))
+      return Err;
+
+    V.resize(Count);
+    for (auto &E : V)
+      if (auto Err = SerializationTraits<ChannelT, T>::deserialize(C, E))
+        return Err;
+
+    return Error::success();
+  }
+
+private:
+  static std::mutex NameMutex;
+  static std::string Name;
+};
+
+template <typename ChannelT, typename T>
+std::mutex SerializationTraits<ChannelT, std::vector<T>>::NameMutex;
+
+template <typename ChannelT, typename T>
+std::string SerializationTraits<ChannelT, std::vector<T>>::Name;
+
+} // end namespace remote
+} // end namespace orc
+} // end namespace llvm
+
+#endif // LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H
index f061fff..e766a7f 100644 (file)
@@ -166,7 +166,7 @@ protected:
     template <typename ChannelT>
     static Error readResult(ChannelT &C, std::promise<PErrorReturn> &P) {
       RetT Val;
-      auto Err = deserialize(C, Val);
+      auto Err = deserializeSeq(C, Val);
       auto Err2 = endReceiveMessage(C);
       Err = joinErrors(std::move(Err), std::move(Err2));
       if (Err)
@@ -581,7 +581,7 @@ public:
     if (auto Err = startReceiveMessage(C))
       return Err;
 
-    return deserialize(C, Id);
+    return deserializeSeq(C, Id);
   }
 
   /// Deserialize args for Func from C and call Handler. The signature of
@@ -645,7 +645,7 @@ public:
   /// This should be called from the receive loop to retrieve results.
   Error handleResponse(ChannelT &C, SequenceNumberT *SeqNoRet = nullptr) {
     SequenceNumberT SeqNo;
-    if (auto Err = deserialize(C, SeqNo)) {
+    if (auto Err = deserializeSeq(C, SeqNo)) {
       abandonOutstandingResults();
       return Err;
     }