1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
10 "golang.org/x/exp/rand"
13 func TestSolve(t *testing.T) {
15 for _, test := range []struct {
24 ans: [][]float64{{0.5}},
47 {0.8147, 0.9134, 0.5528},
48 {0.9058, 0.6324, 0.8723},
49 {0.1270, 0.0975, 0.7612},
65 {0.8147, 0.9134, 0.5528},
66 {0.9058, 0.6324, 0.8723},
73 {0.25919787248965376},
74 {-0.25560256266441034},
81 {0.8147, 0.9134, 0.9},
82 {0.9058, 0.6324, 0.9},
83 {0.1270, 0.0975, 0.1},
101 {0.8147, 0.9134, 0.231, -1.65},
102 {0.9058, 0.6324, 0.9, 0.72},
103 {0.1270, 0.0975, 0.1, 1.723},
104 {1.6, 2.8, -3.5, 0.987},
105 {7.231, 9.154, 1.823, 0.9},
115 {1.863006789511373, 44.467887791812750},
116 {-1.127270935407224, -34.073794226035126},
117 {-0.527926457947330, -8.032133759788573},
118 {-0.248621916204897, -2.366366415805275},
161 a := NewDense(flatten(test.a))
162 b := NewDense(flatten(test.b))
166 ans = NewDense(flatten(test.ans))
173 t.Errorf("Unexpected solve error: %s", err)
177 if err == nil && test.shouldErr {
178 t.Errorf("Did not error during solve.")
181 if !EqualApprox(&x, ans, 1e-12) {
182 t.Errorf("Solve answer mismatch. Want %v, got %v", ans, x)
187 for _, test := range []struct {
203 a := NewDense(m, n, nil)
204 for i := 0; i < m; i++ {
205 for j := 0; j < n; j++ {
206 a.Set(i, j, rand.Float64())
210 b := NewDense(br, bc, nil)
211 for i := 0; i < br; i++ {
212 for j := 0; j < bc; j++ {
213 b.Set(i, j, rand.Float64())
219 // Test that the normal equations hold.
220 // A^T * A * x = A^T * b
221 var tmp, lhs, rhs Dense
225 if !EqualApprox(&lhs, &rhs, 1e-10) {
226 t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs)
231 method := func(receiver, a, b Matrix) {
232 type Solver interface {
233 Solve(a, b Matrix) error
235 rd := receiver.(Solver)
238 denseComparison := func(receiver, a, b *Dense) {
241 testTwoInput(t, "Solve", &Dense{}, method, denseComparison, legalTypesAll, legalSizeSolve, 1e-7)
244 func TestSolveVec(t *testing.T) {
245 for _, test := range []struct {
260 a := NewDense(m, n, nil)
261 for i := 0; i < m; i++ {
262 for j := 0; j < n; j++ {
263 a.Set(i, j, rand.Float64())
267 b := NewVecDense(br, nil)
268 for i := 0; i < br; i++ {
269 b.SetVec(i, rand.Float64())
274 // Test that the normal equations hold.
275 // A^T * A * x = A^T * b
276 var tmp, lhs, rhs Dense
280 if !EqualApprox(&lhs, &rhs, 1e-10) {
281 t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs)
286 method := func(receiver, a, b Matrix) {
287 type SolveVecer interface {
288 SolveVec(a Matrix, b Vector) error
290 rd := receiver.(SolveVecer)
291 rd.SolveVec(a, b.(Vector))
293 denseComparison := func(receiver, a, b *Dense) {
296 testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVector, legalSizeSolve, 1e-12)