1 // Copyright ©2014 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
10 "golang.org/x/exp/rand"
12 "gonum.org/v1/gonum/blas"
15 func TestDgemmParallel(t *testing.T) {
16 for i, test := range []struct {
57 m: blockSize * minParBlock,
66 n: blockSize * minParBlock,
75 k: blockSize * minParBlock,
81 m: blockSize*minParBlock + 1,
82 n: blockSize * minParBlock,
90 n: blockSize*minParBlock + 2,
97 m: blockSize * minParBlock,
99 k: blockSize * minParBlock,
105 m: blockSize * minParBlock,
106 n: blockSize * minParBlock,
113 m: blockSize + blockSize/2,
114 n: blockSize + blockSize/2,
115 k: blockSize + blockSize/2,
121 testMatchParallelSerial(t, i, blas.NoTrans, blas.NoTrans, test.m, test.n, test.k, test.alpha)
122 testMatchParallelSerial(t, i, blas.Trans, blas.NoTrans, test.m, test.n, test.k, test.alpha)
123 testMatchParallelSerial(t, i, blas.NoTrans, blas.Trans, test.m, test.n, test.k, test.alpha)
124 testMatchParallelSerial(t, i, blas.Trans, blas.Trans, test.m, test.n, test.k, test.alpha)
128 func testMatchParallelSerial(t *testing.T, i int, tA, tB blas.Transpose, m, n, k int, alpha float64) {
133 if tA == blas.NoTrans {
140 if tB == blas.NoTrans {
147 a := randmat(rowA, colA, colA)
148 b := randmat(rowB, colB, colB)
149 c := randmat(m, n, n)
158 dgemmSerial(tA == blas.Trans, tB == blas.Trans, m, n, k, a.data, lda, b.data, ldb, cClone.data, ldc, alpha)
159 dgemmParallel(tA == blas.Trans, tB == blas.Trans, m, n, k, a.data, lda, b.data, ldb, c.data, ldc, alpha)
160 if !a.equal(aClone) {
161 t.Errorf("Case %v: a changed during call to dgemmParallel", i)
163 if !b.equal(bClone) {
164 t.Errorf("Case %v: b changed during call to dgemmParallel", i)
166 if !c.equalWithinAbs(cClone, 1e-12) {
167 t.Errorf("Case %v: answer not equal parallel and serial", i)
171 func randmat(r, c, stride int) general64 {
172 data := make([]float64, r*stride+c)
173 for i := range data {
174 data[i] = rand.Float64()