OSDN Git Service

Harden MLIR detection of misconfiguration when missing dialect registration
authorMehdi Amini <joker.eph@gmail.com>
Thu, 28 May 2020 08:08:20 +0000 (08:08 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 28 May 2020 08:14:49 +0000 (08:14 +0000)
This changes will catch error where C++ op are used without being
registered, either through creation with the OpBuilder or when trying to
cast to the C++ op.

Differential Revision: https://reviews.llvm.org/D80651

mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/MLIRContext.h
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/IR/MLIRContext.cpp

index 424eb98..0dcf4da 100644 (file)
@@ -374,6 +374,10 @@ public:
   template <typename OpTy, typename... Args>
   OpTy create(Location location, Args &&... args) {
     OperationState state(location, OpTy::getOperationName());
+    if (!state.name.getAbstractOperation())
+      llvm::report_fatal_error("Building op `" +
+                               state.name.getStringRef().str() +
+                               "` but it isn't registered in this MLIRContext");
     OpTy::build(*this, state, std::forward<Args>(args)...);
     auto *op = createOperation(state);
     auto result = dyn_cast<OpTy>(op);
@@ -390,6 +394,10 @@ public:
     // Create the operation without using 'createOperation' as we don't want to
     // insert it yet.
     OperationState state(location, OpTy::getOperationName());
+    if (!state.name.getAbstractOperation())
+      llvm::report_fatal_error("Building op `" +
+                               state.name.getStringRef().str() +
+                               "` but it isn't registered in this MLIRContext");
     OpTy::build(*this, state, std::forward<Args>(args)...);
     Operation *op = Operation::create(state);
 
index da0b0bd..8e75bb6 100644 (file)
@@ -85,6 +85,9 @@ public:
   /// directly.
   std::vector<AbstractOperation *> getRegisteredOperations();
 
+  /// Return true if this operation name is registered in this context.
+  bool isOperationRegistered(StringRef name);
+
   // This is effectively private given that only MLIRContext.cpp can see the
   // MLIRContextImpl type.
   MLIRContextImpl &getImpl() { return *impl; }
index bf5bd70..e92d54e 100644 (file)
@@ -1235,7 +1235,10 @@ public:
   static bool classof(Operation *op) {
     if (auto *abstractOp = op->getAbstractOperation())
       return TypeID::get<ConcreteType>() == abstractOp->typeID;
-    return op->getName().getStringRef() == ConcreteType::getOperationName();
+    assert(op->getContext()->isOperationRegistered(
+               ConcreteType::getOperationName()) &&
+           "Casting attempt to an unregistered operation");
+    return false;
   }
 
   /// This is the hook used by the AsmParser to parse the custom form of this
index 0728f29..da607a2 100644 (file)
@@ -543,6 +543,13 @@ std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
   return result;
 }
 
+bool MLIRContext::isOperationRegistered(StringRef name) {
+  // Lock access to the context registry.
+  ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
+
+  return impl->registeredOperations.count(name);
+}
+
 void Dialect::addOperation(AbstractOperation opInfo) {
   assert((getNamespace().empty() ||
           opInfo.name.split('.').first == getNamespace()) &&
@@ -621,8 +628,9 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) {
 static Dialect &lookupDialectForSymbol(MLIRContext *ctx, TypeID typeID) {
   auto &impl = ctx->getImpl();
   auto it = impl.registeredDialectSymbols.find(typeID);
-  assert(it != impl.registeredDialectSymbols.end() &&
-         "symbol is not registered.");
+  if (it == impl.registeredDialectSymbols.end())
+    llvm::report_fatal_error(
+        "Trying to create a type that was not registered in this MLIRContext.");
   return *it->second;
 }