OSDN Git Service

[mlir][spirv] Allow mixed type cooperative matrix muladd
authorThomas Raoux <thomasraoux@google.com>
Thu, 18 Jun 2020 20:05:09 +0000 (13:05 -0700)
committerThomas Raoux <thomasraoux@google.com>
Thu, 18 Jun 2020 20:05:09 +0000 (13:05 -0700)
muladd can have differenti types for lhs/rhs and acc/destination. Change
verifier and update the test to use supported example.

Differential Revision: https://reviews.llvm.org/D82042

mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/cooperative-matrix.mlir

index efe6858..87456f0 100644 (file)
@@ -2753,8 +2753,7 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
       typeR.getScope() != typeB.getScope() ||
       typeR.getScope() != typeC.getScope())
     return op.emitOpError("matrix scope must match");
-  if (typeR.getElementType() != typeA.getElementType() ||
-      typeR.getElementType() != typeB.getElementType() ||
+  if (typeA.getElementType() != typeB.getElementType() ||
       typeR.getElementType() != typeC.getElementType())
     return op.emitOpError("matrix element type must match");
   return success();
index 51c7090..a2dafad 100644 (file)
@@ -37,9 +37,9 @@ spv.func @cooperative_matrix_length() -> i32 "None" {
 }
 
 // CHECK-LABEL: @cooperative_matrix_muladd
-spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}  : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
-  %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x32xi8, Subgroup>, %b : !spv.coopmatrix<32x8xi8, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}  : !spv.coopmatrix<8x32xi8, Subgroup>, !spv.coopmatrix<32x8xi8, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+  %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x32xi8, Subgroup>, !spv.coopmatrix<32x8xi8, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
   spv.Return
 }