OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / mat / lu_test.go
1 // Copyright ©2013 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package mat
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11 )
12
13 func TestLUD(t *testing.T) {
14         for _, n := range []int{1, 5, 10, 11, 50} {
15                 a := NewDense(n, n, nil)
16                 for i := 0; i < n; i++ {
17                         for j := 0; j < n; j++ {
18                                 a.Set(i, j, rand.NormFloat64())
19                         }
20                 }
21                 var want Dense
22                 want.Clone(a)
23
24                 var lu LU
25                 lu.Factorize(a)
26
27                 l := lu.LTo(nil)
28                 u := lu.UTo(nil)
29                 var p Dense
30                 pivot := lu.Pivot(nil)
31                 p.Permutation(n, pivot)
32                 var got Dense
33                 got.Product(&p, l, u)
34                 if !EqualApprox(&got, &want, 1e-12) {
35                         t.Errorf("PLU does not equal original matrix.\nWant: %v\n Got: %v", want, got)
36                 }
37         }
38 }
39
40 func TestLURankOne(t *testing.T) {
41         for _, pivoting := range []bool{true} {
42                 for _, n := range []int{3, 10, 50} {
43                         // Construct a random LU factorization
44                         lu := &LU{}
45                         lu.lu = NewDense(n, n, nil)
46                         for i := 0; i < n; i++ {
47                                 for j := 0; j < n; j++ {
48                                         lu.lu.Set(i, j, rand.Float64())
49                                 }
50                         }
51                         lu.pivot = make([]int, n)
52                         for i := range lu.pivot {
53                                 lu.pivot[i] = i
54                         }
55                         if pivoting {
56                                 // For each row, randomly swap with itself or a row after (like is done)
57                                 // in the actual LU factorization.
58                                 for i := range lu.pivot {
59                                         idx := i + rand.Intn(n-i)
60                                         lu.pivot[i], lu.pivot[idx] = lu.pivot[idx], lu.pivot[i]
61                                 }
62                         }
63                         // Apply a rank one update. Ensure the update magnitude is larger than
64                         // the equal tolerance.
65                         alpha := rand.Float64() + 1
66                         x := NewVecDense(n, nil)
67                         y := NewVecDense(n, nil)
68                         for i := 0; i < n; i++ {
69                                 x.setVec(i, rand.Float64()+1)
70                                 y.setVec(i, rand.Float64()+1)
71                         }
72                         a := luReconstruct(lu)
73                         a.RankOne(a, alpha, x, y)
74
75                         var luNew LU
76                         luNew.RankOne(lu, alpha, x, y)
77                         lu.RankOne(lu, alpha, x, y)
78
79                         aR1New := luReconstruct(&luNew)
80                         aR1 := luReconstruct(lu)
81
82                         if !Equal(aR1, aR1New) {
83                                 t.Error("Different answer when new receiver")
84                         }
85                         if !EqualApprox(aR1, a, 1e-10) {
86                                 t.Errorf("Rank one mismatch, pivot %v.\nWant: %v\nGot:%v\n", pivoting, a, aR1)
87                         }
88                 }
89         }
90 }
91
92 // luReconstruct reconstructs the original A matrix from an LU decomposition.
93 func luReconstruct(lu *LU) *Dense {
94         var L, U TriDense
95         lu.LTo(&L)
96         lu.UTo(&U)
97         var P Dense
98         pivot := lu.Pivot(nil)
99         P.Permutation(len(pivot), pivot)
100
101         var a Dense
102         a.Mul(&L, &U)
103         a.Mul(&P, &a)
104         return &a
105 }
106
107 func TestSolveLU(t *testing.T) {
108         for _, test := range []struct {
109                 n, bc int
110         }{
111                 {5, 5},
112                 {5, 10},
113                 {10, 5},
114         } {
115                 n := test.n
116                 bc := test.bc
117                 a := NewDense(n, n, nil)
118                 for i := 0; i < n; i++ {
119                         for j := 0; j < n; j++ {
120                                 a.Set(i, j, rand.NormFloat64())
121                         }
122                 }
123                 b := NewDense(n, bc, nil)
124                 for i := 0; i < n; i++ {
125                         for j := 0; j < bc; j++ {
126                                 b.Set(i, j, rand.NormFloat64())
127                         }
128                 }
129                 var lu LU
130                 lu.Factorize(a)
131                 var x Dense
132                 if err := lu.Solve(&x, false, b); err != nil {
133                         continue
134                 }
135                 var got Dense
136                 got.Mul(a, &x)
137                 if !EqualApprox(&got, b, 1e-12) {
138                         t.Errorf("Solve mismatch for non-singular matrix. n = %v, bc = %v.\nWant: %v\nGot: %v", n, bc, b, got)
139                 }
140         }
141         // TODO(btracey): Add testOneInput test when such a function exists.
142 }
143
144 func TestSolveLUCond(t *testing.T) {
145         for _, test := range []*Dense{
146                 NewDense(2, 2, []float64{1, 0, 0, 1e-20}),
147         } {
148                 m, _ := test.Dims()
149                 var lu LU
150                 lu.Factorize(test)
151                 b := NewDense(m, 2, nil)
152                 var x Dense
153                 if err := lu.Solve(&x, false, b); err == nil {
154                         t.Error("No error for near-singular matrix in matrix solve.")
155                 }
156
157                 bvec := NewVecDense(m, nil)
158                 var xvec VecDense
159                 if err := lu.SolveVec(&xvec, false, bvec); err == nil {
160                         t.Error("No error for near-singular matrix in matrix solve.")
161                 }
162         }
163 }
164
165 func TestSolveLUVec(t *testing.T) {
166         for _, n := range []int{5, 10} {
167                 a := NewDense(n, n, nil)
168                 for i := 0; i < n; i++ {
169                         for j := 0; j < n; j++ {
170                                 a.Set(i, j, rand.NormFloat64())
171                         }
172                 }
173                 b := NewVecDense(n, nil)
174                 for i := 0; i < n; i++ {
175                         b.SetVec(i, rand.NormFloat64())
176                 }
177                 var lu LU
178                 lu.Factorize(a)
179                 var x VecDense
180                 if err := lu.SolveVec(&x, false, b); err != nil {
181                         continue
182                 }
183                 var got VecDense
184                 got.MulVec(a, &x)
185                 if !EqualApprox(&got, b, 1e-12) {
186                         t.Errorf("Solve mismatch n = %v.\nWant: %v\nGot: %v", n, b, got)
187                 }
188         }
189         // TODO(btracey): Add testOneInput test when such a function exists.
190 }