+++ /dev/null
-// Copyright ©2014 The Gonum Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package gonum
-
-import (
- "testing"
-
- "golang.org/x/exp/rand"
-
- "gonum.org/v1/gonum/blas"
-)
-
-func TestDgemmParallel(t *testing.T) {
- for i, test := range []struct {
- m int
- n int
- k int
- alpha float64
- tA blas.Transpose
- tB blas.Transpose
- }{
- {
- m: 3,
- n: 4,
- k: 2,
- alpha: 2.5,
- tA: blas.NoTrans,
- tB: blas.NoTrans,
- },
- {
- m: blockSize*2 + 5,
- n: 3,
- k: 2,
- alpha: 2.5,
- tA: blas.NoTrans,
- tB: blas.NoTrans,
- },
- {
- m: 3,
- n: blockSize * 2,
- k: 2,
- alpha: 2.5,
- tA: blas.NoTrans,
- tB: blas.NoTrans,
- },
- {
- m: 2,
- n: 3,
- k: blockSize*3 - 2,
- alpha: 2.5,
- tA: blas.NoTrans,
- tB: blas.NoTrans,
- },
- {
- m: blockSize * minParBlock,
- n: 3,
- k: 2,
- alpha: 2.5,
- tA: blas.NoTrans,
- tB: blas.NoTrans,
- },
- {
- m: 3,
- n: blockSize * minParBlock,
- k: 2,
- alpha: 2.5,
- tA: blas.NoTrans,
- tB: blas.NoTrans,
- },
- {
- m: 2,
- n: 3,
- k: blockSize * minParBlock,
- alpha: 2.5,
- tA: blas.NoTrans,
- tB: blas.NoTrans,
- },
- {
- m: blockSize*minParBlock + 1,
- n: blockSize * minParBlock,
- k: 3,
- alpha: 2.5,
- tA: blas.NoTrans,
- tB: blas.NoTrans,
- },
- {
- m: 3,
- n: blockSize*minParBlock + 2,
- k: blockSize * 3,
- alpha: 2.5,
- tA: blas.NoTrans,
- tB: blas.NoTrans,
- },
- {
- m: blockSize * minParBlock,
- n: 3,
- k: blockSize * minParBlock,
- alpha: 2.5,
- tA: blas.NoTrans,
- tB: blas.NoTrans,
- },
- {
- m: blockSize * minParBlock,
- n: blockSize * minParBlock,
- k: blockSize * 3,
- alpha: 2.5,
- tA: blas.NoTrans,
- tB: blas.NoTrans,
- },
- {
- m: blockSize + blockSize/2,
- n: blockSize + blockSize/2,
- k: blockSize + blockSize/2,
- alpha: 2.5,
- tA: blas.NoTrans,
- tB: blas.NoTrans,
- },
- } {
- testMatchParallelSerial(t, i, blas.NoTrans, blas.NoTrans, test.m, test.n, test.k, test.alpha)
- testMatchParallelSerial(t, i, blas.Trans, blas.NoTrans, test.m, test.n, test.k, test.alpha)
- testMatchParallelSerial(t, i, blas.NoTrans, blas.Trans, test.m, test.n, test.k, test.alpha)
- testMatchParallelSerial(t, i, blas.Trans, blas.Trans, test.m, test.n, test.k, test.alpha)
- }
-}
-
-func testMatchParallelSerial(t *testing.T, i int, tA, tB blas.Transpose, m, n, k int, alpha float64) {
- var (
- rowA, colA int
- rowB, colB int
- )
- if tA == blas.NoTrans {
- rowA = m
- colA = k
- } else {
- rowA = k
- colA = m
- }
- if tB == blas.NoTrans {
- rowB = k
- colB = n
- } else {
- rowB = n
- colB = k
- }
- a := randmat(rowA, colA, colA)
- b := randmat(rowB, colB, colB)
- c := randmat(m, n, n)
-
- aClone := a.clone()
- bClone := b.clone()
- cClone := c.clone()
-
- lda := colA
- ldb := colB
- ldc := n
- dgemmSerial(tA == blas.Trans, tB == blas.Trans, m, n, k, a.data, lda, b.data, ldb, cClone.data, ldc, alpha)
- dgemmParallel(tA == blas.Trans, tB == blas.Trans, m, n, k, a.data, lda, b.data, ldb, c.data, ldc, alpha)
- if !a.equal(aClone) {
- t.Errorf("Case %v: a changed during call to dgemmParallel", i)
- }
- if !b.equal(bClone) {
- t.Errorf("Case %v: b changed during call to dgemmParallel", i)
- }
- if !c.equalWithinAbs(cClone, 1e-12) {
- t.Errorf("Case %v: answer not equal parallel and serial", i)
- }
-}
-
-func randmat(r, c, stride int) general64 {
- data := make([]float64, r*stride+c)
- for i := range data {
- data[i] = rand.Float64()
- }
- return general64{
- data: data,
- rows: r,
- cols: c,
- stride: stride,
- }
-}