// a function handler. See addHandlerImpl.
using LaunchPolicy = std::function<Error(std::function<Error()>)>;
+ FunctionIdT getInvalidFunctionId() const {
+ return FnIdAllocator.getInvalidId();
+ }
+
/// Add the given handler to the handler map and make it available for
/// autonegotiation and execution.
template <typename Func, typename HandlerT>
FunctionIdT handleNegotiate(const std::string &Name) {
auto I = LocalFunctionIds.find(Name);
if (I == LocalFunctionIds.end())
- return FnIdAllocator.getInvalidId();
+ return getInvalidFunctionId();
return I->second;
}
// If autonegotiation indicates that the remote end doesn't support this
// function, return an unknown function error.
- if (RemoteId == FnIdAllocator.getInvalidId())
+ if (RemoteId == getInvalidFunctionId())
return orcError(OrcErrorCode::UnknownRPCFunction);
// Autonegotiation succeeded and returned a valid id. Update the map and
}
/// Negotiate a function id for Func with the other end of the channel.
- template <typename Func> Error negotiateFunction() {
+ template <typename Func> Error negotiateFunction(bool Retry = false) {
using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate;
+ // Check if we already have a function id...
+ auto I = this->RemoteFunctionIds.find(Func::getPrototype());
+ if (I != this->RemoteFunctionIds.end()) {
+ // If it's valid there's nothing left to do.
+ if (I->second != this->getInvalidFunctionId())
+ return Error::success();
+ // If it's invalid and we can't re-attempt negotiation, throw an error.
+ if (!Retry)
+ return orcError(OrcErrorCode::UnknownRPCFunction);
+ }
+
+ // We don't have a function id for Func yet, call the remote to try to
+ // negotiate one.
if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) {
this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
+ if (*RemoteIdOrErr == this->getInvalidFunctionId())
+ return orcError(OrcErrorCode::UnknownRPCFunction);
return Error::success();
} else
return RemoteIdOrErr.takeError();
}
- /// Convenience method for negotiating multiple functions at once.
- template <typename Func> Error negotiateFunctions() {
- return negotiateFunction<Func>();
- }
-
- /// Convenience method for negotiating multiple functions at once.
- template <typename Func1, typename Func2, typename... Funcs>
- Error negotiateFunctions() {
- if (auto Err = negotiateFunction<Func1>())
- return Err;
- return negotiateFunctions<Func2, Funcs...>();
- }
-
/// Return type for non-blocking call primitives.
template <typename Func>
using NonBlockingCallResult = typename detail::ResultTraits<
}
/// Negotiate a function id for Func with the other end of the channel.
- template <typename Func> Error negotiateFunction() {
+ template <typename Func> Error negotiateFunction(bool Retry = false) {
using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate;
+ // Check if we already have a function id...
+ auto I = this->RemoteFunctionIds.find(Func::getPrototype());
+ if (I != this->RemoteFunctionIds.end()) {
+ // If it's valid there's nothing left to do.
+ if (I->second != this->getInvalidFunctionId())
+ return Error::success();
+ // If it's invalid and we can't re-attempt negotiation, throw an error.
+ if (!Retry)
+ return orcError(OrcErrorCode::UnknownRPCFunction);
+ }
+
+ // We don't have a function id for Func yet, call the remote to try to
+ // negotiate one.
if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) {
this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
+ if (*RemoteIdOrErr == this->getInvalidFunctionId())
+ return orcError(OrcErrorCode::UnknownRPCFunction);
return Error::success();
} else
return RemoteIdOrErr.takeError();
}
- /// Convenience method for negotiating multiple functions at once.
- template <typename Func> Error negotiateFunctions() {
- return negotiateFunction<Func>();
- }
-
- /// Convenience method for negotiating multiple functions at once.
- template <typename Func1, typename Func2, typename... Funcs>
- Error negotiateFunctions() {
- if (auto Err = negotiateFunction<Func1>())
- return Err;
- return negotiateFunctions<Func2, Funcs...>();
- }
-
template <typename Func, typename... ArgTs,
typename AltRetT = typename Func::ReturnType>
typename detail::ResultTraits<AltRetT>::ErrorReturnType
uint32_t NumOutstandingCalls;
};
+/// @brief Convenience class for grouping RPC Functions into APIs that can be
+/// negotiated as a block.
+///
+template <typename... Funcs>
+class APICalls {
+public:
+
+ /// @brief Test whether this API contains Function F.
+ template <typename F>
+ class Contains {
+ public:
+ static const bool value = false;
+ };
+
+ /// @brief Negotiate all functions in this API.
+ template <typename RPCEndpoint>
+ static Error negotiate(RPCEndpoint &R) {
+ return Error::success();
+ }
+};
+
+template <typename Func, typename... Funcs>
+class APICalls<Func, Funcs...> {
+public:
+
+ template <typename F>
+ class Contains {
+ public:
+ static const bool value = std::is_same<F, Func>::value |
+ APICalls<Funcs...>::template Contains<F>::value;
+ };
+
+ template <typename RPCEndpoint>
+ static Error negotiate(RPCEndpoint &R) {
+ if (auto Err = R.template negotiateFunction<Func>())
+ return Err;
+ return APICalls<Funcs...>::negotiate(R);
+ }
+
+};
+
+template <typename... InnerFuncs, typename... Funcs>
+class APICalls<APICalls<InnerFuncs...>, Funcs...> {
+public:
+
+ template <typename F>
+ class Contains {
+ public:
+ static const bool value =
+ APICalls<InnerFuncs...>::template Contains<F>::value |
+ APICalls<Funcs...>::template Contains<F>::value;
+ };
+
+ template <typename RPCEndpoint>
+ static Error negotiate(RPCEndpoint &R) {
+ if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
+ return Err;
+ return APICalls<Funcs...>::negotiate(R);
+ }
+
+};
+
} // end namespace rpc
} // end namespace orc
} // end namespace llvm
} // end namespace orc
} // end namespace llvm
-class DummyRPCAPI {
-public:
+namespace DummyRPCAPI {
class VoidBool : public Function<VoidBool, void(bool)> {
public:
ServerThread.join();
}
+
+TEST(DummyRPC, TestAPICalls) {
+
+ using DummyCalls1 = APICalls<DummyRPCAPI::VoidBool, DummyRPCAPI::IntInt>;
+ using DummyCalls2 = APICalls<DummyRPCAPI::AllTheTypes>;
+ using DummyCalls3 = APICalls<DummyCalls1, DummyRPCAPI::CustomType>;
+ using DummyCallsAll = APICalls<DummyCalls1, DummyCalls2, DummyRPCAPI::CustomType>;
+
+ static_assert(DummyCalls1::Contains<DummyRPCAPI::VoidBool>::value,
+ "Contains<Func> template should return true here");
+ static_assert(!DummyCalls1::Contains<DummyRPCAPI::CustomType>::value,
+ "Contains<Func> template should return false here");
+
+ Queue Q1, Q2;
+ DummyRPCEndpoint Client(Q1, Q2);
+ DummyRPCEndpoint Server(Q2, Q1);
+
+ std::thread ServerThread(
+ [&]() {
+ Server.addHandler<DummyRPCAPI::VoidBool>([](bool b) { });
+ Server.addHandler<DummyRPCAPI::IntInt>([](int x) { return x; });
+ Server.addHandler<DummyRPCAPI::CustomType>([](RPCFoo F) {});
+
+ for (unsigned I = 0; I < 4; ++I) {
+ auto Err = Server.handleOne();
+ (void)!!Err;
+ }
+ });
+
+ {
+ auto Err = DummyCalls1::negotiate(Client);
+ EXPECT_FALSE(!!Err) << "DummyCalls1::negotiate failed";
+ }
+
+ {
+ auto Err = DummyCalls3::negotiate(Client);
+ EXPECT_FALSE(!!Err) << "DummyCalls3::negotiate failed";
+ }
+
+ {
+ auto Err = DummyCallsAll::negotiate(Client);
+ EXPECT_EQ(errorToErrorCode(std::move(Err)).value(),
+ static_cast<int>(OrcErrorCode::UnknownRPCFunction))
+ << "Uxpected 'UnknownRPCFunction' error for attempted negotiate of "
+ "unsupported function";
+ }
+
+ ServerThread.join();
+}