OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / blas / gonum / pardgemm_test.go
1 // Copyright ©2014 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package gonum
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11
12         "gonum.org/v1/gonum/blas"
13 )
14
15 func TestDgemmParallel(t *testing.T) {
16         for i, test := range []struct {
17                 m     int
18                 n     int
19                 k     int
20                 alpha float64
21                 tA    blas.Transpose
22                 tB    blas.Transpose
23         }{
24                 {
25                         m:     3,
26                         n:     4,
27                         k:     2,
28                         alpha: 2.5,
29                         tA:    blas.NoTrans,
30                         tB:    blas.NoTrans,
31                 },
32                 {
33                         m:     blockSize*2 + 5,
34                         n:     3,
35                         k:     2,
36                         alpha: 2.5,
37                         tA:    blas.NoTrans,
38                         tB:    blas.NoTrans,
39                 },
40                 {
41                         m:     3,
42                         n:     blockSize * 2,
43                         k:     2,
44                         alpha: 2.5,
45                         tA:    blas.NoTrans,
46                         tB:    blas.NoTrans,
47                 },
48                 {
49                         m:     2,
50                         n:     3,
51                         k:     blockSize*3 - 2,
52                         alpha: 2.5,
53                         tA:    blas.NoTrans,
54                         tB:    blas.NoTrans,
55                 },
56                 {
57                         m:     blockSize * minParBlock,
58                         n:     3,
59                         k:     2,
60                         alpha: 2.5,
61                         tA:    blas.NoTrans,
62                         tB:    blas.NoTrans,
63                 },
64                 {
65                         m:     3,
66                         n:     blockSize * minParBlock,
67                         k:     2,
68                         alpha: 2.5,
69                         tA:    blas.NoTrans,
70                         tB:    blas.NoTrans,
71                 },
72                 {
73                         m:     2,
74                         n:     3,
75                         k:     blockSize * minParBlock,
76                         alpha: 2.5,
77                         tA:    blas.NoTrans,
78                         tB:    blas.NoTrans,
79                 },
80                 {
81                         m:     blockSize*minParBlock + 1,
82                         n:     blockSize * minParBlock,
83                         k:     3,
84                         alpha: 2.5,
85                         tA:    blas.NoTrans,
86                         tB:    blas.NoTrans,
87                 },
88                 {
89                         m:     3,
90                         n:     blockSize*minParBlock + 2,
91                         k:     blockSize * 3,
92                         alpha: 2.5,
93                         tA:    blas.NoTrans,
94                         tB:    blas.NoTrans,
95                 },
96                 {
97                         m:     blockSize * minParBlock,
98                         n:     3,
99                         k:     blockSize * minParBlock,
100                         alpha: 2.5,
101                         tA:    blas.NoTrans,
102                         tB:    blas.NoTrans,
103                 },
104                 {
105                         m:     blockSize * minParBlock,
106                         n:     blockSize * minParBlock,
107                         k:     blockSize * 3,
108                         alpha: 2.5,
109                         tA:    blas.NoTrans,
110                         tB:    blas.NoTrans,
111                 },
112                 {
113                         m:     blockSize + blockSize/2,
114                         n:     blockSize + blockSize/2,
115                         k:     blockSize + blockSize/2,
116                         alpha: 2.5,
117                         tA:    blas.NoTrans,
118                         tB:    blas.NoTrans,
119                 },
120         } {
121                 testMatchParallelSerial(t, i, blas.NoTrans, blas.NoTrans, test.m, test.n, test.k, test.alpha)
122                 testMatchParallelSerial(t, i, blas.Trans, blas.NoTrans, test.m, test.n, test.k, test.alpha)
123                 testMatchParallelSerial(t, i, blas.NoTrans, blas.Trans, test.m, test.n, test.k, test.alpha)
124                 testMatchParallelSerial(t, i, blas.Trans, blas.Trans, test.m, test.n, test.k, test.alpha)
125         }
126 }
127
128 func testMatchParallelSerial(t *testing.T, i int, tA, tB blas.Transpose, m, n, k int, alpha float64) {
129         var (
130                 rowA, colA int
131                 rowB, colB int
132         )
133         if tA == blas.NoTrans {
134                 rowA = m
135                 colA = k
136         } else {
137                 rowA = k
138                 colA = m
139         }
140         if tB == blas.NoTrans {
141                 rowB = k
142                 colB = n
143         } else {
144                 rowB = n
145                 colB = k
146         }
147         a := randmat(rowA, colA, colA)
148         b := randmat(rowB, colB, colB)
149         c := randmat(m, n, n)
150
151         aClone := a.clone()
152         bClone := b.clone()
153         cClone := c.clone()
154
155         lda := colA
156         ldb := colB
157         ldc := n
158         dgemmSerial(tA == blas.Trans, tB == blas.Trans, m, n, k, a.data, lda, b.data, ldb, cClone.data, ldc, alpha)
159         dgemmParallel(tA == blas.Trans, tB == blas.Trans, m, n, k, a.data, lda, b.data, ldb, c.data, ldc, alpha)
160         if !a.equal(aClone) {
161                 t.Errorf("Case %v: a changed during call to dgemmParallel", i)
162         }
163         if !b.equal(bClone) {
164                 t.Errorf("Case %v: b changed during call to dgemmParallel", i)
165         }
166         if !c.equalWithinAbs(cClone, 1e-12) {
167                 t.Errorf("Case %v: answer not equal parallel and serial", i)
168         }
169 }
170
171 func randmat(r, c, stride int) general64 {
172         data := make([]float64, r*stride+c)
173         for i := range data {
174                 data[i] = rand.Float64()
175         }
176         return general64{
177                 data:   data,
178                 rows:   r,
179                 cols:   c,
180                 stride: stride,
181         }
182 }