OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / mat / svd_test.go
1 // Copyright ©2013 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package mat
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11
12         "gonum.org/v1/gonum/floats"
13 )
14
15 func TestSVD(t *testing.T) {
16         // Hand coded tests
17         for _, test := range []struct {
18                 a *Dense
19                 u *Dense
20                 v *Dense
21                 s []float64
22         }{
23                 {
24                         a: NewDense(4, 2, []float64{2, 4, 1, 3, 0, 0, 0, 0}),
25                         u: NewDense(4, 2, []float64{
26                                 -0.8174155604703632, -0.5760484367663209,
27                                 -0.5760484367663209, 0.8174155604703633,
28                                 0, 0,
29                                 0, 0,
30                         }),
31                         v: NewDense(2, 2, []float64{
32                                 -0.4045535848337571, -0.9145142956773044,
33                                 -0.9145142956773044, 0.4045535848337571,
34                         }),
35                         s: []float64{5.464985704219041, 0.365966190626258},
36                 },
37                 {
38                         // Issue #5.
39                         a: NewDense(3, 11, []float64{
40                                 1, 1, 0, 1, 0, 0, 0, 0, 0, 11, 1,
41                                 1, 0, 0, 0, 0, 0, 1, 0, 0, 12, 2,
42                                 1, 1, 0, 0, 0, 0, 0, 0, 1, 13, 3,
43                         }),
44                         u: NewDense(3, 3, []float64{
45                                 -0.5224167862273765, 0.7864430360363114, 0.3295270133658976,
46                                 -0.5739526766688285, -0.03852203026050301, -0.8179818935216693,
47                                 -0.6306021141833781, -0.6164603833618163, 0.4715056408282468,
48                         }),
49                         v: NewDense(11, 3, []float64{
50                                 -0.08123293141915189, 0.08528085505260324, -0.013165501690885152,
51                                 -0.05423546426886932, 0.1102707844980355, 0.622210623111631,
52                                 0, 0, 0,
53                                 -0.0245733326078166, 0.510179651760153, 0.25596360803140994,
54                                 0, 0, 0,
55                                 0, 0, 0,
56                                 -0.026997467150282436, -0.024989929445430496, -0.6353761248025164,
57                                 0, 0, 0,
58                                 -0.029662131661052707, -0.3999088672621176, 0.3662470150802212,
59                                 -0.9798839760830571, 0.11328174160898856, -0.047702613241813366,
60                                 -0.16755466189153964, -0.7395268089170608, 0.08395240366704032,
61                         }),
62                         s: []float64{21.259500881097434, 1.5415021616856566, 1.2873979074613628},
63                 },
64         } {
65                 var svd SVD
66                 ok := svd.Factorize(test.a, SVDThin)
67                 if !ok {
68                         t.Errorf("SVD failed")
69                 }
70                 s, u, v := extractSVD(&svd)
71                 if !floats.EqualApprox(s, test.s, 1e-10) {
72                         t.Errorf("Singular value mismatch. Got %v, want %v.", s, test.s)
73                 }
74                 if !EqualApprox(u, test.u, 1e-10) {
75                         t.Errorf("U mismatch.\nGot:\n%v\nWant:\n%v", Formatted(u), Formatted(test.u))
76                 }
77                 if !EqualApprox(v, test.v, 1e-10) {
78                         t.Errorf("V mismatch.\nGot:\n%v\nWant:\n%v", Formatted(v), Formatted(test.v))
79                 }
80                 m, n := test.a.Dims()
81                 sigma := NewDense(min(m, n), min(m, n), nil)
82                 for i := 0; i < min(m, n); i++ {
83                         sigma.Set(i, i, s[i])
84                 }
85
86                 var ans Dense
87                 ans.Product(u, sigma, v.T())
88                 if !EqualApprox(test.a, &ans, 1e-10) {
89                         t.Errorf("A reconstruction mismatch.\nGot:\n%v\nWant:\n%v\n", Formatted(&ans), Formatted(test.a))
90                 }
91         }
92
93         for _, test := range []struct {
94                 m, n int
95         }{
96                 {5, 5},
97                 {5, 3},
98                 {3, 5},
99                 {150, 150},
100                 {200, 150},
101                 {150, 200},
102         } {
103                 m := test.m
104                 n := test.n
105                 for trial := 0; trial < 10; trial++ {
106                         a := NewDense(m, n, nil)
107                         for i := range a.mat.Data {
108                                 a.mat.Data[i] = rand.NormFloat64()
109                         }
110                         aCopy := DenseCopyOf(a)
111
112                         // Test Full decomposition.
113                         var svd SVD
114                         ok := svd.Factorize(a, SVDFull)
115                         if !ok {
116                                 t.Errorf("SVD factorization failed")
117                         }
118                         if !Equal(a, aCopy) {
119                                 t.Errorf("A changed during call to SVD with full")
120                         }
121                         s, u, v := extractSVD(&svd)
122                         sigma := NewDense(m, n, nil)
123                         for i := 0; i < min(m, n); i++ {
124                                 sigma.Set(i, i, s[i])
125                         }
126                         var ansFull Dense
127                         ansFull.Product(u, sigma, v.T())
128                         if !EqualApprox(&ansFull, a, 1e-8) {
129                                 t.Errorf("Answer mismatch when SVDFull")
130                         }
131
132                         // Test Thin decomposition.
133                         ok = svd.Factorize(a, SVDThin)
134                         if !ok {
135                                 t.Errorf("SVD factorization failed")
136                         }
137                         if !Equal(a, aCopy) {
138                                 t.Errorf("A changed during call to SVD with Thin")
139                         }
140                         sThin, u, v := extractSVD(&svd)
141                         if !floats.EqualApprox(s, sThin, 1e-8) {
142                                 t.Errorf("Singular value mismatch between Full and Thin decomposition")
143                         }
144                         sigma = NewDense(min(m, n), min(m, n), nil)
145                         for i := 0; i < min(m, n); i++ {
146                                 sigma.Set(i, i, sThin[i])
147                         }
148                         ansFull.Reset()
149                         ansFull.Product(u, sigma, v.T())
150                         if !EqualApprox(&ansFull, a, 1e-8) {
151                                 t.Errorf("Answer mismatch when SVDFull")
152                         }
153
154                         // Test None decomposition.
155                         ok = svd.Factorize(a, SVDNone)
156                         if !ok {
157                                 t.Errorf("SVD factorization failed")
158                         }
159                         if !Equal(a, aCopy) {
160                                 t.Errorf("A changed during call to SVD with none")
161                         }
162                         sNone := make([]float64, min(m, n))
163                         svd.Values(sNone)
164                         if !floats.EqualApprox(s, sNone, 1e-8) {
165                                 t.Errorf("Singular value mismatch between Full and None decomposition")
166                         }
167                 }
168         }
169 }
170
171 func extractSVD(svd *SVD) (s []float64, u, v *Dense) {
172         return svd.Values(nil), svd.UTo(nil), svd.VTo(nil)
173 }