LogicalResult deserialize();
/// Collects the final SPIR-V ModuleOp.
- Optional<spirv::ModuleOp> collect();
+ spirv::OwningSPIRVModuleRef collect();
private:
//===--------------------------------------------------------------------===//
//===--------------------------------------------------------------------===//
/// Initializes the `module` ModuleOp in this deserializer instance.
- spirv::ModuleOp createModuleOp();
+ spirv::OwningSPIRVModuleRef createModuleOp();
/// Processes SPIR-V module header in `binary`.
LogicalResult processHeader();
Location unknownLoc;
/// The SPIR-V ModuleOp.
- Optional<spirv::ModuleOp> module;
+ spirv::OwningSPIRVModuleRef module;
/// The current function under construction.
Optional<spirv::FuncOp> curFunction;
return success();
}
-Optional<spirv::ModuleOp> Deserializer::collect() { return module; }
+spirv::OwningSPIRVModuleRef Deserializer::collect() {
+ return std::move(module);
+}
//===----------------------------------------------------------------------===//
// Module structure
//===----------------------------------------------------------------------===//
-spirv::ModuleOp Deserializer::createModuleOp() {
+spirv::OwningSPIRVModuleRef Deserializer::createModuleOp() {
OpBuilder builder(context);
OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
spirv::ModuleOp::build(builder, state);
// Go through all ops and remap the operands.
auto remapOperands = [&](Operation *op) {
for (auto &operand : op->getOpOperands())
- if (auto mappedOp = mapper.lookupOrNull(operand.get()))
+ if (Value mappedOp = mapper.lookupOrNull(operand.get()))
operand.set(mappedOp);
for (auto &succOp : op->getBlockOperands())
- if (auto mappedOp = mapper.lookupOrNull(succOp.get()))
+ if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
succOp.set(mappedOp);
};
for (auto &block : body) {
return emitError(unknownLoc,
"missing Execution Model specification in OpEntryPoint");
}
- auto exec_model = opBuilder.getI32IntegerAttr(words[wordIndex++]);
+ auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]);
if (wordIndex >= words.size()) {
return emitError(unknownLoc, "missing <id> in OpEntryPoint");
}
interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
wordIndex++;
}
- opBuilder.create<spirv::EntryPointOp>(unknownLoc, exec_model,
+ opBuilder.create<spirv::EntryPointOp>(unknownLoc, execModel,
opBuilder.getSymbolRefAttr(fnName),
opBuilder.getArrayAttr(interface));
return success();
if (failed(deserializer.deserialize()))
return nullptr;
- return deserializer.collect().getValueOr(nullptr);
+ return deserializer.collect();
}
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
}
Type getFloatStructType() {
- OpBuilder opBuilder(module.body());
+ OpBuilder opBuilder(module->body());
llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
auto structType = spirv::StructType::get(elementTypes, offsetInfo);
}
void addGlobalVar(Type type, llvm::StringRef name) {
- OpBuilder opBuilder(module.body());
+ OpBuilder opBuilder(module->body());
auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
opBuilder.create<spirv::GlobalVariableOp>(
UnknownLoc::get(&context), TypeAttr::get(ptrType),
protected:
MLIRContext context;
- spirv::ModuleOp module;
+ spirv::OwningSPIRVModuleRef module;
SmallVector<uint32_t, 0> binary;
};
TEST_F(SerializationTest, BlockDecorationTest) {
auto structType = getFloatStructType();
addGlobalVar(structType, "var0");
- ASSERT_TRUE(succeeded(spirv::serialize(module, binary)));
+ ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
auto hasBlockDecoration = [](spirv::Opcode opcode,
ArrayRef<uint32_t> operands) -> bool {
if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)