OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / mat / solve_test.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package mat
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11 )
12
13 func TestSolve(t *testing.T) {
14         // Hand-coded cases.
15         for _, test := range []struct {
16                 a         [][]float64
17                 b         [][]float64
18                 ans       [][]float64
19                 shouldErr bool
20         }{
21                 {
22                         a:         [][]float64{{6}},
23                         b:         [][]float64{{3}},
24                         ans:       [][]float64{{0.5}},
25                         shouldErr: false,
26                 },
27                 {
28                         a: [][]float64{
29                                 {1, 0, 0},
30                                 {0, 1, 0},
31                                 {0, 0, 1},
32                         },
33                         b: [][]float64{
34                                 {3},
35                                 {2},
36                                 {1},
37                         },
38                         ans: [][]float64{
39                                 {3},
40                                 {2},
41                                 {1},
42                         },
43                         shouldErr: false,
44                 },
45                 {
46                         a: [][]float64{
47                                 {0.8147, 0.9134, 0.5528},
48                                 {0.9058, 0.6324, 0.8723},
49                                 {0.1270, 0.0975, 0.7612},
50                         },
51                         b: [][]float64{
52                                 {0.278},
53                                 {0.547},
54                                 {0.958},
55                         },
56                         ans: [][]float64{
57                                 {-0.932687281002860},
58                                 {0.303963920182067},
59                                 {1.375216503507109},
60                         },
61                         shouldErr: false,
62                 },
63                 {
64                         a: [][]float64{
65                                 {0.8147, 0.9134, 0.5528},
66                                 {0.9058, 0.6324, 0.8723},
67                         },
68                         b: [][]float64{
69                                 {0.278},
70                                 {0.547},
71                         },
72                         ans: [][]float64{
73                                 {0.25919787248965376},
74                                 {-0.25560256266441034},
75                                 {0.5432324059702451},
76                         },
77                         shouldErr: false,
78                 },
79                 {
80                         a: [][]float64{
81                                 {0.8147, 0.9134, 0.9},
82                                 {0.9058, 0.6324, 0.9},
83                                 {0.1270, 0.0975, 0.1},
84                                 {1.6, 2.8, -3.5},
85                         },
86                         b: [][]float64{
87                                 {0.278},
88                                 {0.547},
89                                 {-0.958},
90                                 {1.452},
91                         },
92                         ans: [][]float64{
93                                 {0.820970340787782},
94                                 {-0.218604626527306},
95                                 {-0.212938815234215},
96                         },
97                         shouldErr: false,
98                 },
99                 {
100                         a: [][]float64{
101                                 {0.8147, 0.9134, 0.231, -1.65},
102                                 {0.9058, 0.6324, 0.9, 0.72},
103                                 {0.1270, 0.0975, 0.1, 1.723},
104                                 {1.6, 2.8, -3.5, 0.987},
105                                 {7.231, 9.154, 1.823, 0.9},
106                         },
107                         b: [][]float64{
108                                 {0.278, 8.635},
109                                 {0.547, 9.125},
110                                 {-0.958, -0.762},
111                                 {1.452, 1.444},
112                                 {1.999, -7.234},
113                         },
114                         ans: [][]float64{
115                                 {1.863006789511373, 44.467887791812750},
116                                 {-1.127270935407224, -34.073794226035126},
117                                 {-0.527926457947330, -8.032133759788573},
118                                 {-0.248621916204897, -2.366366415805275},
119                         },
120                         shouldErr: false,
121                 },
122                 {
123                         a: [][]float64{
124                                 {0, 0},
125                                 {0, 0},
126                         },
127                         b: [][]float64{
128                                 {3},
129                                 {2},
130                         },
131                         ans:       nil,
132                         shouldErr: true,
133                 },
134                 {
135                         a: [][]float64{
136                                 {0, 0},
137                                 {0, 0},
138                                 {0, 0},
139                         },
140                         b: [][]float64{
141                                 {3},
142                                 {2},
143                                 {1},
144                         },
145                         ans:       nil,
146                         shouldErr: true,
147                 },
148                 {
149                         a: [][]float64{
150                                 {0, 0, 0},
151                                 {0, 0, 0},
152                         },
153                         b: [][]float64{
154                                 {3},
155                                 {2},
156                         },
157                         ans:       nil,
158                         shouldErr: true,
159                 },
160         } {
161                 a := NewDense(flatten(test.a))
162                 b := NewDense(flatten(test.b))
163
164                 var ans *Dense
165                 if test.ans != nil {
166                         ans = NewDense(flatten(test.ans))
167                 }
168
169                 var x Dense
170                 err := x.Solve(a, b)
171                 if err != nil {
172                         if !test.shouldErr {
173                                 t.Errorf("Unexpected solve error: %s", err)
174                         }
175                         continue
176                 }
177                 if err == nil && test.shouldErr {
178                         t.Errorf("Did not error during solve.")
179                         continue
180                 }
181                 if !EqualApprox(&x, ans, 1e-12) {
182                         t.Errorf("Solve answer mismatch. Want %v, got %v", ans, x)
183                 }
184         }
185
186         // Random Cases.
187         for _, test := range []struct {
188                 m, n, bc int
189         }{
190                 {5, 5, 1},
191                 {5, 10, 1},
192                 {10, 5, 1},
193                 {5, 5, 7},
194                 {5, 10, 7},
195                 {10, 5, 7},
196                 {5, 5, 12},
197                 {5, 10, 12},
198                 {10, 5, 12},
199         } {
200                 m := test.m
201                 n := test.n
202                 bc := test.bc
203                 a := NewDense(m, n, nil)
204                 for i := 0; i < m; i++ {
205                         for j := 0; j < n; j++ {
206                                 a.Set(i, j, rand.Float64())
207                         }
208                 }
209                 br := m
210                 b := NewDense(br, bc, nil)
211                 for i := 0; i < br; i++ {
212                         for j := 0; j < bc; j++ {
213                                 b.Set(i, j, rand.Float64())
214                         }
215                 }
216                 var x Dense
217                 x.Solve(a, b)
218
219                 // Test that the normal equations hold.
220                 // A^T * A * x = A^T * b
221                 var tmp, lhs, rhs Dense
222                 tmp.Mul(a.T(), a)
223                 lhs.Mul(&tmp, &x)
224                 rhs.Mul(a.T(), b)
225                 if !EqualApprox(&lhs, &rhs, 1e-10) {
226                         t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs)
227                 }
228         }
229
230         // Use testTwoInput.
231         method := func(receiver, a, b Matrix) {
232                 type Solver interface {
233                         Solve(a, b Matrix) error
234                 }
235                 rd := receiver.(Solver)
236                 rd.Solve(a, b)
237         }
238         denseComparison := func(receiver, a, b *Dense) {
239                 receiver.Solve(a, b)
240         }
241         testTwoInput(t, "Solve", &Dense{}, method, denseComparison, legalTypesAll, legalSizeSolve, 1e-7)
242 }
243
244 func TestSolveVec(t *testing.T) {
245         for _, test := range []struct {
246                 m, n int
247         }{
248                 {5, 5},
249                 {5, 10},
250                 {10, 5},
251                 {5, 5},
252                 {5, 10},
253                 {10, 5},
254                 {5, 5},
255                 {5, 10},
256                 {10, 5},
257         } {
258                 m := test.m
259                 n := test.n
260                 a := NewDense(m, n, nil)
261                 for i := 0; i < m; i++ {
262                         for j := 0; j < n; j++ {
263                                 a.Set(i, j, rand.Float64())
264                         }
265                 }
266                 br := m
267                 b := NewVecDense(br, nil)
268                 for i := 0; i < br; i++ {
269                         b.SetVec(i, rand.Float64())
270                 }
271                 var x VecDense
272                 x.SolveVec(a, b)
273
274                 // Test that the normal equations hold.
275                 // A^T * A * x = A^T * b
276                 var tmp, lhs, rhs Dense
277                 tmp.Mul(a.T(), a)
278                 lhs.Mul(&tmp, &x)
279                 rhs.Mul(a.T(), b)
280                 if !EqualApprox(&lhs, &rhs, 1e-10) {
281                         t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs)
282                 }
283         }
284
285         // Use testTwoInput
286         method := func(receiver, a, b Matrix) {
287                 type SolveVecer interface {
288                         SolveVec(a Matrix, b Vector) error
289                 }
290                 rd := receiver.(SolveVecer)
291                 rd.SolveVec(a, b.(Vector))
292         }
293         denseComparison := func(receiver, a, b *Dense) {
294                 receiver.Solve(a, b)
295         }
296         testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVector, legalSizeSolve, 1e-12)
297 }