OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / mat / mul_test.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package mat
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11
12         "gonum.org/v1/gonum/blas"
13         "gonum.org/v1/gonum/blas/blas64"
14         "gonum.org/v1/gonum/floats"
15 )
16
17 // TODO: Need to add tests where one is overwritten.
18 func TestMulTypes(t *testing.T) {
19         for _, test := range []struct {
20                 ar     int
21                 ac     int
22                 br     int
23                 bc     int
24                 Panics bool
25         }{
26                 {
27                         ar:     5,
28                         ac:     5,
29                         br:     5,
30                         bc:     5,
31                         Panics: false,
32                 },
33                 {
34                         ar:     10,
35                         ac:     5,
36                         br:     5,
37                         bc:     3,
38                         Panics: false,
39                 },
40                 {
41                         ar:     10,
42                         ac:     5,
43                         br:     5,
44                         bc:     8,
45                         Panics: false,
46                 },
47                 {
48                         ar:     8,
49                         ac:     10,
50                         br:     10,
51                         bc:     3,
52                         Panics: false,
53                 },
54                 {
55                         ar:     8,
56                         ac:     3,
57                         br:     3,
58                         bc:     10,
59                         Panics: false,
60                 },
61                 {
62                         ar:     5,
63                         ac:     8,
64                         br:     8,
65                         bc:     10,
66                         Panics: false,
67                 },
68                 {
69                         ar:     5,
70                         ac:     12,
71                         br:     12,
72                         bc:     8,
73                         Panics: false,
74                 },
75                 {
76                         ar:     5,
77                         ac:     7,
78                         br:     8,
79                         bc:     10,
80                         Panics: true,
81                 },
82         } {
83                 ar := test.ar
84                 ac := test.ac
85                 br := test.br
86                 bc := test.bc
87
88                 // Generate random matrices
89                 avec := make([]float64, ar*ac)
90                 randomSlice(avec)
91                 a := NewDense(ar, ac, avec)
92
93                 bvec := make([]float64, br*bc)
94                 randomSlice(bvec)
95
96                 b := NewDense(br, bc, bvec)
97
98                 // Check that it panics if it is supposed to
99                 if test.Panics {
100                         c := NewDense(0, 0, nil)
101                         fn := func() {
102                                 c.Mul(a, b)
103                         }
104                         pan, _ := panics(fn)
105                         if !pan {
106                                 t.Errorf("Mul did not panic with dimension mismatch")
107                         }
108                         continue
109                 }
110
111                 cvec := make([]float64, ar*bc)
112
113                 // Get correct matrix multiply answer from blas64.Gemm
114                 blas64.Gemm(blas.NoTrans, blas.NoTrans,
115                         1, a.mat, b.mat,
116                         0, blas64.General{Rows: ar, Cols: bc, Stride: bc, Data: cvec},
117                 )
118
119                 avecCopy := append([]float64{}, avec...)
120                 bvecCopy := append([]float64{}, bvec...)
121                 cvecCopy := append([]float64{}, cvec...)
122
123                 acomp := matComp{r: ar, c: ac, data: avecCopy}
124                 bcomp := matComp{r: br, c: bc, data: bvecCopy}
125                 ccomp := matComp{r: ar, c: bc, data: cvecCopy}
126
127                 // Do normal multiply with empty dense
128                 d := NewDense(0, 0, nil)
129
130                 testMul(t, a, b, d, acomp, bcomp, ccomp, false, "zero receiver")
131
132                 // Normal multiply with existing receiver
133                 c := NewDense(ar, bc, cvec)
134                 randomSlice(cvec)
135                 testMul(t, a, b, c, acomp, bcomp, ccomp, false, "existing receiver")
136
137                 // Cast a as a basic matrix
138                 am := (*basicMatrix)(a)
139                 bm := (*basicMatrix)(b)
140                 d.Reset()
141                 testMul(t, am, b, d, acomp, bcomp, ccomp, true, "a is basic, receiver is zero")
142                 d.Reset()
143                 testMul(t, a, bm, d, acomp, bcomp, ccomp, true, "b is basic, receiver is zero")
144                 d.Reset()
145                 testMul(t, am, bm, d, acomp, bcomp, ccomp, true, "both basic, receiver is zero")
146                 randomSlice(cvec)
147                 testMul(t, am, b, d, acomp, bcomp, ccomp, true, "a is basic, receiver is full")
148                 randomSlice(cvec)
149                 testMul(t, a, bm, d, acomp, bcomp, ccomp, true, "b is basic, receiver is full")
150                 randomSlice(cvec)
151                 testMul(t, am, bm, d, acomp, bcomp, ccomp, true, "both basic, receiver is full")
152         }
153 }
154
155 func randomSlice(s []float64) {
156         for i := range s {
157                 s[i] = rand.NormFloat64()
158         }
159 }
160
161 type matComp struct {
162         r, c int
163         data []float64
164 }
165
166 func testMul(t *testing.T, a, b Matrix, c *Dense, acomp, bcomp, ccomp matComp, cvecApprox bool, name string) {
167         c.Mul(a, b)
168         var aDense *Dense
169         switch t := a.(type) {
170         case *Dense:
171                 aDense = t
172         case *basicMatrix:
173                 aDense = (*Dense)(t)
174         }
175
176         var bDense *Dense
177         switch t := b.(type) {
178         case *Dense:
179                 bDense = t
180         case *basicMatrix:
181                 bDense = (*Dense)(t)
182         }
183
184         if !denseEqual(aDense, acomp) {
185                 t.Errorf("a changed unexpectedly for %v", name)
186         }
187         if !denseEqual(bDense, bcomp) {
188                 t.Errorf("b changed unexpectedly for %v", name)
189         }
190         if cvecApprox {
191                 if !denseEqualApprox(c, ccomp, 1e-14) {
192                         t.Errorf("mul answer not within tol for %v", name)
193                 }
194                 return
195         }
196
197         if !denseEqual(c, ccomp) {
198                 t.Errorf("mul answer not equal for %v", name)
199         }
200 }
201
202 type basicMatrix Dense
203
204 func (m *basicMatrix) At(r, c int) float64 {
205         return (*Dense)(m).At(r, c)
206 }
207
208 func (m *basicMatrix) Dims() (r, c int) {
209         return (*Dense)(m).Dims()
210 }
211
212 func (m *basicMatrix) T() Matrix {
213         return Transpose{m}
214 }
215
216 type basicSymmetric SymDense
217
218 var _ Symmetric = &basicSymmetric{}
219
220 func (m *basicSymmetric) At(r, c int) float64 {
221         return (*SymDense)(m).At(r, c)
222 }
223
224 func (m *basicSymmetric) Dims() (r, c int) {
225         return (*SymDense)(m).Dims()
226 }
227
228 func (m *basicSymmetric) T() Matrix {
229         return m
230 }
231
232 func (m *basicSymmetric) Symmetric() int {
233         return (*SymDense)(m).Symmetric()
234 }
235
236 type basicTriangular TriDense
237
238 func (m *basicTriangular) At(r, c int) float64 {
239         return (*TriDense)(m).At(r, c)
240 }
241
242 func (m *basicTriangular) Dims() (r, c int) {
243         return (*TriDense)(m).Dims()
244 }
245
246 func (m *basicTriangular) T() Matrix {
247         return Transpose{m}
248 }
249
250 func (m *basicTriangular) Triangle() (int, TriKind) {
251         return (*TriDense)(m).Triangle()
252 }
253
254 func (m *basicTriangular) TTri() Triangular {
255         return TransposeTri{m}
256 }
257
258 func denseEqual(a *Dense, acomp matComp) bool {
259         ar2, ac2 := a.Dims()
260         if ar2 != acomp.r {
261                 return false
262         }
263         if ac2 != acomp.c {
264                 return false
265         }
266         if !floats.Equal(a.mat.Data, acomp.data) {
267                 return false
268         }
269         return true
270 }
271
272 func denseEqualApprox(a *Dense, acomp matComp, tol float64) bool {
273         ar2, ac2 := a.Dims()
274         if ar2 != acomp.r {
275                 return false
276         }
277         if ac2 != acomp.c {
278                 return false
279         }
280         if !floats.EqualApprox(a.mat.Data, acomp.data, tol) {
281                 return false
282         }
283         return true
284 }