OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / blas / testblas / dgemm.go
1 package testblas
2
3 import (
4         "testing"
5
6         "gonum.org/v1/gonum/blas"
7 )
8
9 type Dgemmer interface {
10         Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
11 }
12
// DgemmCase is a single Dgemm test vector. Matrices are stored as slices of
// rows; ans holds what c is expected to contain after the multiply.
type DgemmCase struct {
	// Problem dimensions handed to Dgemm: op(A) is m×k, op(B) is k×n and
	// C is m×n.
	m int
	n int
	k int

	// Scalars of the update C = alpha*op(A)*op(B) + beta*C.
	alpha float64
	beta  float64

	a   [][]float64 // left operand, m rows of k values
	b   [][]float64 // right operand, k rows of n values
	c   [][]float64 // initial contents of the output matrix, m rows of n values
	ans [][]float64 // expected contents of c after the call
}
21
22 var DgemmCases = []DgemmCase{
23
24         {
25                 m:     4,
26                 n:     3,
27                 k:     2,
28                 alpha: 2,
29                 beta:  0.5,
30                 a: [][]float64{
31                         {1, 2},
32                         {4, 5},
33                         {7, 8},
34                         {10, 11},
35                 },
36                 b: [][]float64{
37                         {1, 5, 6},
38                         {5, -8, 8},
39                 },
40                 c: [][]float64{
41                         {4, 8, -9},
42                         {12, 16, -8},
43                         {1, 5, 15},
44                         {-3, -4, 7},
45                 },
46                 ans: [][]float64{
47                         {24, -18, 39.5},
48                         {64, -32, 124},
49                         {94.5, -55.5, 219.5},
50                         {128.5, -78, 299.5},
51                 },
52         },
53         {
54                 m:     4,
55                 n:     2,
56                 k:     3,
57                 alpha: 2,
58                 beta:  0.5,
59                 a: [][]float64{
60                         {1, 2, 3},
61                         {4, 5, 6},
62                         {7, 8, 9},
63                         {10, 11, 12},
64                 },
65                 b: [][]float64{
66                         {1, 5},
67                         {5, -8},
68                         {6, 2},
69                 },
70                 c: [][]float64{
71                         {4, 8},
72                         {12, 16},
73                         {1, 5},
74                         {-3, -4},
75                 },
76                 ans: [][]float64{
77                         {60, -6},
78                         {136, -8},
79                         {202.5, -19.5},
80                         {272.5, -30},
81                 },
82         },
83         {
84                 m:     3,
85                 n:     2,
86                 k:     4,
87                 alpha: 2,
88                 beta:  0.5,
89                 a: [][]float64{
90                         {1, 2, 3, 4},
91                         {4, 5, 6, 7},
92                         {8, 9, 10, 11},
93                 },
94                 b: [][]float64{
95                         {1, 5},
96                         {5, -8},
97                         {6, 2},
98                         {8, 10},
99                 },
100                 c: [][]float64{
101                         {4, 8},
102                         {12, 16},
103                         {9, -10},
104                 },
105                 ans: [][]float64{
106                         {124, 74},
107                         {248, 132},
108                         {406.5, 191},
109                 },
110         },
111         {
112                 m:     3,
113                 n:     4,
114                 k:     2,
115                 alpha: 2,
116                 beta:  0.5,
117                 a: [][]float64{
118                         {1, 2},
119                         {4, 5},
120                         {8, 9},
121                 },
122                 b: [][]float64{
123                         {1, 5, 2, 1},
124                         {5, -8, 2, 1},
125                 },
126                 c: [][]float64{
127                         {4, 8, 2, 2},
128                         {12, 16, 8, 9},
129                         {9, -10, 10, 10},
130                 },
131                 ans: [][]float64{
132                         {24, -18, 13, 7},
133                         {64, -32, 40, 22.5},
134                         {110.5, -69, 73, 39},
135                 },
136         },
137         {
138                 m:     2,
139                 n:     4,
140                 k:     3,
141                 alpha: 2,
142                 beta:  0.5,
143                 a: [][]float64{
144                         {1, 2, 3},
145                         {4, 5, 6},
146                 },
147                 b: [][]float64{
148                         {1, 5, 8, 8},
149                         {5, -8, 9, 10},
150                         {6, 2, -3, 2},
151                 },
152                 c: [][]float64{
153                         {4, 8, 7, 8},
154                         {12, 16, -2, 6},
155                 },
156                 ans: [][]float64{
157                         {60, -6, 37.5, 72},
158                         {136, -8, 117, 191},
159                 },
160         },
161         {
162                 m:     2,
163                 n:     3,
164                 k:     4,
165                 alpha: 2,
166                 beta:  0.5,
167                 a: [][]float64{
168                         {1, 2, 3, 4},
169                         {4, 5, 6, 7},
170                 },
171                 b: [][]float64{
172                         {1, 5, 8},
173                         {5, -8, 9},
174                         {6, 2, -3},
175                         {8, 10, 2},
176                 },
177                 c: [][]float64{
178                         {4, 8, 1},
179                         {12, 16, 6},
180                 },
181                 ans: [][]float64{
182                         {124, 74, 50.5},
183                         {248, 132, 149},
184                 },
185         },
186 }
187
// transpose returns a newly allocated transpose of a, so that
// result[i][j] == a[j][i]. The input is assumed to be a rectangular matrix
// stored as equal-length rows; the result has len(a[0]) rows of len(a)
// values, and a itself is left untouched.
//
// An empty (or nil) input yields nil rather than panicking on a[0].
func transpose(a [][]float64) [][]float64 {
	if len(a) == 0 {
		return nil
	}
	out := make([][]float64, len(a[0]))
	for i := range out {
		out[i] = make([]float64, len(a))
		for j := range out[i] {
			out[i][j] = a[j][i]
		}
	}
	return out
}
199
200 func TestDgemm(t *testing.T, blasser Dgemmer) {
201         for i, test := range DgemmCases {
202                 // Test that it passes row major
203                 dgemmcomp(i, "RowMajorNoTrans", t, blasser, blas.NoTrans, blas.NoTrans,
204                         test.m, test.n, test.k, test.alpha, test.beta, test.a, test.b, test.c, test.ans)
205                 // Try with A transposed
206                 dgemmcomp(i, "RowMajorTransA", t, blasser, blas.Trans, blas.NoTrans,
207                         test.m, test.n, test.k, test.alpha, test.beta, transpose(test.a), test.b, test.c, test.ans)
208                 // Try with B transposed
209                 dgemmcomp(i, "RowMajorTransB", t, blasser, blas.NoTrans, blas.Trans,
210                         test.m, test.n, test.k, test.alpha, test.beta, test.a, transpose(test.b), test.c, test.ans)
211                 // Try with both transposed
212                 dgemmcomp(i, "RowMajorTransBoth", t, blasser, blas.Trans, blas.Trans,
213                         test.m, test.n, test.k, test.alpha, test.beta, transpose(test.a), transpose(test.b), test.c, test.ans)
214         }
215 }
216
217 func dgemmcomp(i int, name string, t *testing.T, blasser Dgemmer, tA, tB blas.Transpose, m, n, k int,
218         alpha, beta float64, a [][]float64, b [][]float64, c [][]float64, ans [][]float64) {
219
220         aFlat := flatten(a)
221         aCopy := flatten(a)
222         bFlat := flatten(b)
223         bCopy := flatten(b)
224         cFlat := flatten(c)
225         ansFlat := flatten(ans)
226         lda := len(a[0])
227         ldb := len(b[0])
228         ldc := len(c[0])
229
230         // Compute the matrix multiplication
231         blasser.Dgemm(tA, tB, m, n, k, alpha, aFlat, lda, bFlat, ldb, beta, cFlat, ldc)
232
233         if !dSliceEqual(aFlat, aCopy) {
234                 t.Errorf("Test %v case %v: a changed during call to Dgemm", i, name)
235         }
236         if !dSliceEqual(bFlat, bCopy) {
237                 t.Errorf("Test %v case %v: b changed during call to Dgemm", i, name)
238         }
239
240         if !dSliceTolEqual(ansFlat, cFlat) {
241                 t.Errorf("Test %v case %v: answer mismatch. Expected %v, Found %v", i, name, ansFlat, cFlat)
242         }
243         // TODO: Need to add a sub-slice test where don't use up full matrix
244 }