OSDN Git Service

[RenderScript] fixes for L3 BLAS APIs
authorMiao Wang <miaowang@google.com>
Fri, 24 Apr 2015 18:19:53 +0000 (11:19 -0700)
committerMiao Wang <miaowang@google.com>
Thu, 7 May 2015 22:41:37 +0000 (15:41 -0700)
  - Typo for validateUplo
  - Typo in ZHEMM, element should be FLOAT64_2.
  - For GEMM and SYMM, SYRK, 'CONJ_TRANSPOSE' should also be handled in the
    validation process.
  - For SYMM, check matrix A is symmetric.
  - For HERK, the dimension validation was switched for Transpose case.
    Also, only Conj Trans is allowed in this case.
  - FOR SYR2K, fix the dimension check for Matrix C.
  - For TRMM & TRSM, fix the validation part for dimension check.

Change-Id: I559b5c5695aa82604de2955ae2327b694236d3ed

rs/java/android/renderscript/ScriptIntrinsicBLAS.java

index 7af61ac..65818b1 100644 (file)
@@ -242,7 +242,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
     }
 
     static void validateUplo(@Uplo int Uplo) {
-        if (Uplo != LEFT && Uplo != RIGHT) {
+        if (Uplo != UPPER && Uplo != LOWER) {
             throw new RSRuntimeException("Invalid uplo passed to BLAS");
         }
     }
@@ -986,56 +986,74 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
      */
 
     static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) {
-        int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1;
+        int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
         if ((A != null && !A.getType().getElement().isCompatible(e)) ||
             (B != null && !B.getType().getElement().isCompatible(e)) ||
             (C != null && !C.getType().getElement().isCompatible(e))) {
             throw new RSRuntimeException("Called BLAS with wrong Element type");
         }
-        if (C != null) {
-            cX = C.getType().getY();
-            cY = C.getType().getX();
+        if (C == null) {
+            //since matrix C is used to store the result, it cannot be null.
+            throw new RSRuntimeException("Allocation C cannot be null");
         }
+        cM = C.getType().getY();
+        cN = C.getType().getX();
+
         if (Side == RIGHT) {
+            if ((A == null && B != null) || (A != null && B == null)) {
+                throw new RSRuntimeException("Provided Matrix A without Matrix B, or vice versa");
+            }
             if (B != null) {
-                bX = A.getType().getY();
-                bY = A.getType().getX();
+                bM = A.getType().getY();
+                bN = A.getType().getX();
             }
             if (A != null) {
-                aX = B.getType().getY();
-                aY = B.getType().getX();
+                aM = B.getType().getY();
+                aN = B.getType().getX();
             }
         } else {
             if (A != null) {
-                if (TransA == TRANSPOSE) {
-                    aY = A.getType().getY();
-                    aX = A.getType().getX();
+                if (TransA != NO_TRANSPOSE) {
+                    aN = A.getType().getY();
+                    aM = A.getType().getX();
                 } else {
-                    aX = A.getType().getY();
-                    aY = A.getType().getX();
+                    aM = A.getType().getY();
+                    aN = A.getType().getX();
                 }
             }
             if (B != null) {
-                if (TransB == TRANSPOSE) {
-                    bY = B.getType().getY();
-                    bX = B.getType().getX();
+                if (TransB != NO_TRANSPOSE) {
+                    bN = B.getType().getY();
+                    bM = B.getType().getX();
                 } else {
-                    bX = B.getType().getY();
-                    bY = B.getType().getX();
+                    bM = B.getType().getY();
+                    bN = B.getType().getX();
                 }
             }
         }
         if (A != null && B != null && C != null) {
-            if (aY != bX || aX != cX || bY != cY) {
+            if (aN != bM || aM != cM || bN != cN) {
                 throw new RSRuntimeException("Called BLAS with invalid dimensions");
             }
         } else if (A != null && C != null) {
-            // A and C only
-            if (aX != cY || aY != cX) {
-                throw new RSRuntimeException("Called BLAS with invalid dimensions");
+            // A and C only, for SYRK
+            if (cM != cN) {
+                throw new RSRuntimeException("Matrix C is not symmetric");
+            }
+            if (TransA != NO_TRANSPOSE) {
+                if (aN != cM) {
+                    throw new RSRuntimeException("Called BLAS with invalid dimensions");
+                }
+            } else {
+                if (aM != cM) {
+                    throw new RSRuntimeException("Called BLAS with invalid dimensions");
+                }
             }
         } else if (A != null && B != null) {
             // A and B only
+            if (aN != bM) {
+                throw new RSRuntimeException("Called BLAS with invalid dimensions");
+            }
         }
 
     }
@@ -1047,14 +1065,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
         validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C);
 
         int M = -1, N = -1, K = -1;
-        if (TransA == TRANSPOSE) {
+        if (TransA != NO_TRANSPOSE) {
             M = A.getType().getX();
             K = A.getType().getY();
         } else {
             M = A.getType().getY();
             K = A.getType().getX();
         }
-        if (TransB == TRANSPOSE) {
+        if (TransB != NO_TRANSPOSE) {
             N = B.getType().getY();
         } else {
             N = B.getType().getX();
@@ -1068,14 +1086,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
         validateTranspose(TransB);
         validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C);
         int M = -1, N = -1, K = -1;
-        if (TransA == TRANSPOSE) {
+        if (TransA != NO_TRANSPOSE) {
             M = A.getType().getX();
             K = A.getType().getY();
         } else {
             M = A.getType().getY();
             K = A.getType().getX();
         }
-        if (TransB == TRANSPOSE) {
+        if (TransB != NO_TRANSPOSE) {
             N = B.getType().getY();
         } else {
             N = B.getType().getX();
@@ -1089,14 +1107,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
         validateTranspose(TransB);
         validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C);
         int M = -1, N = -1, K = -1;
-        if (TransA == TRANSPOSE) {
+        if (TransA != NO_TRANSPOSE) {
             M = A.getType().getX();
             K = A.getType().getY();
         } else {
             M = A.getType().getY();
             K = A.getType().getX();
         }
-        if (TransB == TRANSPOSE) {
+        if (TransB != NO_TRANSPOSE) {
             N = B.getType().getY();
         } else {
             N = B.getType().getX();
@@ -1111,14 +1129,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
         validateTranspose(TransB);
         validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C);
         int M = -1, N = -1, K = -1;
-        if (TransA == TRANSPOSE) {
+        if (TransA != NO_TRANSPOSE) {
             M = A.getType().getX();
             K = A.getType().getY();
         } else {
             M = A.getType().getY();
             K = A.getType().getX();
         }
-        if (TransB == TRANSPOSE) {
+        if (TransB != NO_TRANSPOSE) {
             N = B.getType().getY();
         } else {
             N = B.getType().getX();
@@ -1131,6 +1149,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
                       Allocation B, float beta, Allocation C) {
         validateSide(Side);
         validateUplo(Uplo);
+        //For SYMM, Matrix A should be symmetric
+        if (A.getType().getX() != A.getType().getY()) {
+            throw new RSRuntimeException("Matrix A is not symmetric");
+        }
         validateL3(Element.F32(mRS), 0, 0, Side, A, B, C);
         mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
                                         beta, C.getID(mRS), 0, 0, 0, 0);
@@ -1139,6 +1161,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
                       Allocation B, double beta, Allocation C) {
         validateSide(Side);
         validateUplo(Uplo);
+        if (A.getType().getX() != A.getType().getY()) {
+            throw new RSRuntimeException("Matrix A is not symmetric");
+        }
         validateL3(Element.F64(mRS), 0, 0, Side, A, B, C);
         mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
                                         beta, C.getID(mRS), 0, 0, 0, 0);
@@ -1147,6 +1172,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
                       Allocation B, Float2 beta, Allocation C) {
         validateSide(Side);
         validateUplo(Uplo);
+        if (A.getType().getX() != A.getType().getY()) {
+            throw new RSRuntimeException("Matrix A is not symmetric");
+        }
         validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C);
         mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
                                          beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
@@ -1155,6 +1183,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
                       Allocation B, Double2 beta, Allocation C) {
         validateSide(Side);
         validateUplo(Uplo);
+        if (A.getType().getX() != A.getType().getY()) {
+            throw new RSRuntimeException("Matrix A is not symmetric");
+        }
         validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C);
         mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
                                    beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
@@ -1165,7 +1196,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
         validateUplo(Uplo);
         validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C);
         int K = -1;
-        if (Trans == TRANSPOSE) {
+        if (Trans != NO_TRANSPOSE) {
             K = A.getType().getY();
         } else {
             K = A.getType().getX();
@@ -1179,7 +1210,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
         validateUplo(Uplo);
         validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C);
         int K = -1;
-        if (Trans == TRANSPOSE) {
+        if (Trans != NO_TRANSPOSE) {
             K = A.getType().getY();
         } else {
             K = A.getType().getX();
@@ -1191,7 +1222,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
         validateUplo(Uplo);
         validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C);
         int K = -1;
-        if (Trans == TRANSPOSE) {
+        if (Trans != NO_TRANSPOSE) {
             K = A.getType().getY();
         } else {
             K = A.getType().getX();
@@ -1204,7 +1235,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
         validateUplo(Uplo);
         validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C);
         int K = -1;
-        if (Trans == TRANSPOSE) {
+        if (Trans != NO_TRANSPOSE) {
             K = A.getType().getY();
         } else {
             K = A.getType().getX();
@@ -1230,7 +1261,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
             // check rows versus C
             Cdim = A.getType().getY();
         }
-        if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) {
+        if (C.getType().getX() != Cdim || C.getType().getY() != Cdim) {
             throw new RSRuntimeException("Invalid symmetric matrix in SYR2K");
         }
         // A dims == B dims
@@ -1286,26 +1317,26 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
     static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
         validateSide(Side);
         validateTranspose(TransA);
-        int aX = -1, aY = -1, bX = -1, bY = -1;
+        int aM = -1, aN = -1, bM = -1, bN = -1;
         if (!A.getType().getElement().isCompatible(e) ||
             !B.getType().getElement().isCompatible(e)) {
             throw new RSRuntimeException("Called BLAS with wrong Element type");
         }
-        if (TransA == TRANSPOSE) {
-            aY = A.getType().getY();
-            aX = A.getType().getX();
-        } else {
-            aY = A.getType().getX();
-            aX = A.getType().getY();
+
+        aM = A.getType().getY();
+        aN = A.getType().getX();
+        if (aM != aN) {
+            throw new RSRuntimeException("Called TRMM with a non-symmetric matrix A");
         }
-        bX = B.getType().getY();
-        bY = B.getType().getX();
+
+        bM = B.getType().getY();
+        bN = B.getType().getX();
         if (Side == LEFT) {
-            if (aX == 0 || aY != bX) {
+            if (aN != bM) {
                 throw new RSRuntimeException("Called TRMM with invalid matrices");
             }
         } else {
-            if (bY != aX || aY == 0) {
+            if (bN != aM) {
                 throw new RSRuntimeException("Called TRMM with invalid matrices");
             }
         }
@@ -1340,7 +1371,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
     }
 
     static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
-        int adim = -1, bX = -1, bY = -1;
+        int adim = -1, bM = -1, bN = -1;
         validateSide(Side);
         validateTranspose(TransA);
         if (!A.getType().getElement().isCompatible(e) ||
@@ -1354,16 +1385,16 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
             // for now we assume adapters are sufficient, will reevaluate in the future
             throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A");
         }
-        bX = B.getType().getY();
-        bY = B.getType().getX();
+        bM = B.getType().getY();
+        bN = B.getType().getX();
         if (Side == LEFT) {
             // A is M*M
-            if (adim != bY) {
+            if (adim != bM) {
                 throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
             }
         } else {
             // A is N*N
-            if (adim != bX) {
+            if (adim != bN) {
                 throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
             }
         }
@@ -1428,7 +1459,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
     }
     public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
         validateUplo(Uplo);
-        validateHEMM(Element.F32_2(mRS), Side, A, B, C);
+        validateHEMM(Element.F64_2(mRS), Side, A, B, C);
         mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
                                    alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
     }
@@ -1444,11 +1475,11 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
             throw new RSRuntimeException("Called HERK with non-square C");
         }
         if (Trans == NO_TRANSPOSE) {
-            if (cdim != A.getType().getX()) {
+            if (cdim != A.getType().getY()) {
                 throw new RSRuntimeException("Called HERK with invalid A");
             }
         } else {
-            if (cdim != A.getType().getY()) {
+            if (cdim != A.getType().getX()) {
                 throw new RSRuntimeException("Called HERK with invalid A");
             }
         }
@@ -1457,7 +1488,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
         validateUplo(Uplo);
         validateHERK(Element.F32_2(mRS), Trans, A, C);
         int k = 0;
-        if (Trans == TRANSPOSE) {
+        if (Trans == CONJ_TRANSPOSE) {
             k = A.getType().getY();
         } else {
             k = A.getType().getX();
@@ -1469,7 +1500,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
         validateUplo(Uplo);
         validateHERK(Element.F64_2(mRS), Trans, A, C);
         int k = 0;
-        if (Trans == TRANSPOSE) {
+        if (Trans == CONJ_TRANSPOSE) {
             k = A.getType().getY();
         } else {
             k = A.getType().getX();