OSDN Git Service

Hulk did something
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / blas / gonum / pardgemm_test.go
diff --git a/vendor/gonum.org/v1/gonum/blas/gonum/pardgemm_test.go b/vendor/gonum.org/v1/gonum/blas/gonum/pardgemm_test.go
new file mode 100644 (file)
index 0000000..cc03bb3
--- /dev/null
@@ -0,0 +1,182 @@
+// Copyright ©2014 The Gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gonum
+
+import (
+       "testing"
+
+       "golang.org/x/exp/rand"
+
+       "gonum.org/v1/gonum/blas"
+)
+
+func TestDgemmParallel(t *testing.T) {
+       for i, test := range []struct {
+               m     int
+               n     int
+               k     int
+               alpha float64
+               tA    blas.Transpose
+               tB    blas.Transpose
+       }{
+               {
+                       m:     3,
+                       n:     4,
+                       k:     2,
+                       alpha: 2.5,
+                       tA:    blas.NoTrans,
+                       tB:    blas.NoTrans,
+               },
+               {
+                       m:     blockSize*2 + 5,
+                       n:     3,
+                       k:     2,
+                       alpha: 2.5,
+                       tA:    blas.NoTrans,
+                       tB:    blas.NoTrans,
+               },
+               {
+                       m:     3,
+                       n:     blockSize * 2,
+                       k:     2,
+                       alpha: 2.5,
+                       tA:    blas.NoTrans,
+                       tB:    blas.NoTrans,
+               },
+               {
+                       m:     2,
+                       n:     3,
+                       k:     blockSize*3 - 2,
+                       alpha: 2.5,
+                       tA:    blas.NoTrans,
+                       tB:    blas.NoTrans,
+               },
+               {
+                       m:     blockSize * minParBlock,
+                       n:     3,
+                       k:     2,
+                       alpha: 2.5,
+                       tA:    blas.NoTrans,
+                       tB:    blas.NoTrans,
+               },
+               {
+                       m:     3,
+                       n:     blockSize * minParBlock,
+                       k:     2,
+                       alpha: 2.5,
+                       tA:    blas.NoTrans,
+                       tB:    blas.NoTrans,
+               },
+               {
+                       m:     2,
+                       n:     3,
+                       k:     blockSize * minParBlock,
+                       alpha: 2.5,
+                       tA:    blas.NoTrans,
+                       tB:    blas.NoTrans,
+               },
+               {
+                       m:     blockSize*minParBlock + 1,
+                       n:     blockSize * minParBlock,
+                       k:     3,
+                       alpha: 2.5,
+                       tA:    blas.NoTrans,
+                       tB:    blas.NoTrans,
+               },
+               {
+                       m:     3,
+                       n:     blockSize*minParBlock + 2,
+                       k:     blockSize * 3,
+                       alpha: 2.5,
+                       tA:    blas.NoTrans,
+                       tB:    blas.NoTrans,
+               },
+               {
+                       m:     blockSize * minParBlock,
+                       n:     3,
+                       k:     blockSize * minParBlock,
+                       alpha: 2.5,
+                       tA:    blas.NoTrans,
+                       tB:    blas.NoTrans,
+               },
+               {
+                       m:     blockSize * minParBlock,
+                       n:     blockSize * minParBlock,
+                       k:     blockSize * 3,
+                       alpha: 2.5,
+                       tA:    blas.NoTrans,
+                       tB:    blas.NoTrans,
+               },
+               {
+                       m:     blockSize + blockSize/2,
+                       n:     blockSize + blockSize/2,
+                       k:     blockSize + blockSize/2,
+                       alpha: 2.5,
+                       tA:    blas.NoTrans,
+                       tB:    blas.NoTrans,
+               },
+       } {
+               testMatchParallelSerial(t, i, blas.NoTrans, blas.NoTrans, test.m, test.n, test.k, test.alpha)
+               testMatchParallelSerial(t, i, blas.Trans, blas.NoTrans, test.m, test.n, test.k, test.alpha)
+               testMatchParallelSerial(t, i, blas.NoTrans, blas.Trans, test.m, test.n, test.k, test.alpha)
+               testMatchParallelSerial(t, i, blas.Trans, blas.Trans, test.m, test.n, test.k, test.alpha)
+       }
+}
+
+func testMatchParallelSerial(t *testing.T, i int, tA, tB blas.Transpose, m, n, k int, alpha float64) {
+       var (
+               rowA, colA int
+               rowB, colB int
+       )
+       if tA == blas.NoTrans {
+               rowA = m
+               colA = k
+       } else {
+               rowA = k
+               colA = m
+       }
+       if tB == blas.NoTrans {
+               rowB = k
+               colB = n
+       } else {
+               rowB = n
+               colB = k
+       }
+       a := randmat(rowA, colA, colA)
+       b := randmat(rowB, colB, colB)
+       c := randmat(m, n, n)
+
+       aClone := a.clone()
+       bClone := b.clone()
+       cClone := c.clone()
+
+       lda := colA
+       ldb := colB
+       ldc := n
+       dgemmSerial(tA == blas.Trans, tB == blas.Trans, m, n, k, a.data, lda, b.data, ldb, cClone.data, ldc, alpha)
+       dgemmParallel(tA == blas.Trans, tB == blas.Trans, m, n, k, a.data, lda, b.data, ldb, c.data, ldc, alpha)
+       if !a.equal(aClone) {
+               t.Errorf("Case %v: a changed during call to dgemmParallel", i)
+       }
+       if !b.equal(bClone) {
+               t.Errorf("Case %v: b changed during call to dgemmParallel", i)
+       }
+       if !c.equalWithinAbs(cClone, 1e-12) {
+               t.Errorf("Case %v: answer not equal parallel and serial", i)
+       }
+}
+
+func randmat(r, c, stride int) general64 {
+       data := make([]float64, r*stride+c)
+       for i := range data {
+               data[i] = rand.Float64()
+       }
+       return general64{
+               data:   data,
+               rows:   r,
+               cols:   c,
+               stride: stride,
+       }
+}