OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dgesvd.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "fmt"
9         "testing"
10
11         "golang.org/x/exp/rand"
12
13         "gonum.org/v1/gonum/blas"
14         "gonum.org/v1/gonum/blas/blas64"
15         "gonum.org/v1/gonum/floats"
16         "gonum.org/v1/gonum/lapack"
17 )
18
19 type Dgesvder interface {
20         Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, lwork int) (ok bool)
21 }
22
23 func DgesvdTest(t *testing.T, impl Dgesvder) {
24         rnd := rand.New(rand.NewSource(1))
25         // TODO(btracey): Add tests for all of the cases when the SVD implementation
26         // is finished.
27         // TODO(btracey): Add tests for m > mnthr and n > mnthr when other SVD
28         // conditions are implemented. Right now mnthr is 5,000,000 which is too
29         // large to create a square matrix of that size.
30         for _, test := range []struct {
31                 m, n, lda, ldu, ldvt int
32         }{
33                 {5, 5, 0, 0, 0},
34                 {5, 6, 0, 0, 0},
35                 {6, 5, 0, 0, 0},
36                 {5, 9, 0, 0, 0},
37                 {9, 5, 0, 0, 0},
38
39                 {5, 5, 10, 11, 12},
40                 {5, 6, 10, 11, 12},
41                 {6, 5, 10, 11, 12},
42                 {5, 5, 10, 11, 12},
43                 {5, 9, 10, 11, 12},
44                 {9, 5, 10, 11, 12},
45
46                 {300, 300, 0, 0, 0},
47                 {300, 400, 0, 0, 0},
48                 {400, 300, 0, 0, 0},
49                 {300, 600, 0, 0, 0},
50                 {600, 300, 0, 0, 0},
51
52                 {300, 300, 400, 450, 460},
53                 {300, 400, 500, 550, 560},
54                 {400, 300, 550, 550, 560},
55                 {300, 600, 700, 750, 760},
56                 {600, 300, 700, 750, 760},
57         } {
58                 jobU := lapack.SVDAll
59                 jobVT := lapack.SVDAll
60
61                 m := test.m
62                 n := test.n
63                 lda := test.lda
64                 if lda == 0 {
65                         lda = n
66                 }
67                 ldu := test.ldu
68                 if ldu == 0 {
69                         ldu = m
70                 }
71                 ldvt := test.ldvt
72                 if ldvt == 0 {
73                         ldvt = n
74                 }
75
76                 a := make([]float64, m*lda)
77                 for i := range a {
78                         a[i] = rnd.NormFloat64()
79                 }
80
81                 u := make([]float64, m*ldu)
82                 for i := range u {
83                         u[i] = rnd.NormFloat64()
84                 }
85
86                 vt := make([]float64, n*ldvt)
87                 for i := range vt {
88                         vt[i] = rnd.NormFloat64()
89                 }
90
91                 uAllOrig := make([]float64, len(u))
92                 copy(uAllOrig, u)
93                 vtAllOrig := make([]float64, len(vt))
94                 copy(vtAllOrig, vt)
95                 aCopy := make([]float64, len(a))
96                 copy(aCopy, a)
97
98                 s := make([]float64, min(m, n))
99
100                 work := make([]float64, 1)
101                 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, -1)
102
103                 if !floats.Equal(a, aCopy) {
104                         t.Errorf("a changed during call to get work length")
105                 }
106
107                 work = make([]float64, int(work[0]))
108                 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work))
109
110                 errStr := fmt.Sprintf("m = %v, n = %v, lda = %v, ldu = %v, ldv = %v", m, n, lda, ldu, ldvt)
111                 svdCheck(t, false, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda)
112                 svdCheckPartial(t, impl, lapack.SVDAll, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false)
113
114                 // Test InPlace
115                 jobU = lapack.SVDInPlace
116                 jobVT = lapack.SVDInPlace
117                 copy(a, aCopy)
118                 copy(u, uAllOrig)
119                 copy(vt, vtAllOrig)
120
121                 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work))
122                 svdCheck(t, true, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda)
123                 svdCheckPartial(t, impl, lapack.SVDInPlace, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false)
124         }
125 }
126
127 // svdCheckPartial checks that the singular values and vectors are computed when
128 // not all of them are computed.
129 func svdCheckPartial(t *testing.T, impl Dgesvder, job lapack.SVDJob, errStr string, uAllOrig, vtAllOrig, aCopy []float64, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, shortWork bool) {
130         rnd := rand.New(rand.NewSource(1))
131         jobU := job
132         jobVT := job
133         // Compare the singular values when computed with {SVDNone, SVDNone.}
134         sCopy := make([]float64, len(s))
135         copy(sCopy, s)
136         copy(a, aCopy)
137         for i := range s {
138                 s[i] = rnd.Float64()
139         }
140         tmp1 := make([]float64, 1)
141         tmp2 := make([]float64, 1)
142         jobU = lapack.SVDNone
143         jobVT = lapack.SVDNone
144
145         impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, -1)
146         work = make([]float64, int(work[0]))
147         lwork := len(work)
148         if shortWork {
149                 lwork--
150         }
151         ok := impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, lwork)
152         if !ok {
153                 t.Errorf("Dgesvd did not complete successfully")
154         }
155         if !floats.EqualApprox(s, sCopy, 1e-10) {
156                 t.Errorf("Singular value mismatch when singular vectors not computed: %s", errStr)
157         }
158         // Check that the singular vectors are correctly computed when the other
159         // is none.
160         uAll := make([]float64, len(u))
161         copy(uAll, u)
162         vtAll := make([]float64, len(vt))
163         copy(vtAll, vt)
164
165         // Copy the original vectors so the data outside the matrix bounds is the same.
166         copy(u, uAllOrig)
167         copy(vt, vtAllOrig)
168
169         jobU = job
170         jobVT = lapack.SVDNone
171         copy(a, aCopy)
172         for i := range s {
173                 s[i] = rnd.Float64()
174         }
175         impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, -1)
176         work = make([]float64, int(work[0]))
177         lwork = len(work)
178         if shortWork {
179                 lwork--
180         }
181         impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, len(work))
182         if !floats.EqualApprox(uAll, u, 1e-10) {
183                 t.Errorf("U mismatch when VT is not computed: %s", errStr)
184         }
185         if !floats.EqualApprox(s, sCopy, 1e-10) {
186                 t.Errorf("Singular value mismatch when U computed VT not")
187         }
188         jobU = lapack.SVDNone
189         jobVT = job
190         copy(a, aCopy)
191         for i := range s {
192                 s[i] = rnd.Float64()
193         }
194         impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, -1)
195         work = make([]float64, int(work[0]))
196         lwork = len(work)
197         if shortWork {
198                 lwork--
199         }
200         impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, len(work))
201         if !floats.EqualApprox(vtAll, vt, 1e-10) {
202                 t.Errorf("VT mismatch when U is not computed: %s", errStr)
203         }
204         if !floats.EqualApprox(s, sCopy, 1e-10) {
205                 t.Errorf("Singular value mismatch when VT computed U not")
206         }
207 }
208
209 // svdCheck checks that the singular value decomposition correctly multiplies back
210 // to the original matrix.
211 func svdCheck(t *testing.T, thin bool, errStr string, m, n int, s, a, u []float64, ldu int, vt []float64, ldvt int, aCopy []float64, lda int) {
212         sigma := blas64.General{
213                 Rows:   m,
214                 Cols:   n,
215                 Stride: n,
216                 Data:   make([]float64, m*n),
217         }
218         for i := 0; i < min(m, n); i++ {
219                 sigma.Data[i*sigma.Stride+i] = s[i]
220         }
221
222         uMat := blas64.General{
223                 Rows:   m,
224                 Cols:   m,
225                 Stride: ldu,
226                 Data:   u,
227         }
228         vTMat := blas64.General{
229                 Rows:   n,
230                 Cols:   n,
231                 Stride: ldvt,
232                 Data:   vt,
233         }
234         if thin {
235                 sigma.Rows = min(m, n)
236                 sigma.Cols = min(m, n)
237                 uMat.Cols = min(m, n)
238                 vTMat.Rows = min(m, n)
239         }
240
241         tmp := blas64.General{
242                 Rows:   m,
243                 Cols:   n,
244                 Stride: n,
245                 Data:   make([]float64, m*n),
246         }
247         ans := blas64.General{
248                 Rows:   m,
249                 Cols:   n,
250                 Stride: lda,
251                 Data:   make([]float64, m*lda),
252         }
253         copy(ans.Data, a)
254
255         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uMat, sigma, 0, tmp)
256         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, vTMat, 0, ans)
257
258         if !floats.EqualApprox(ans.Data, aCopy, 1e-8) {
259                 t.Errorf("Decomposition mismatch. Trim = %v, %s", thin, errStr)
260         }
261
262         if !thin {
263                 // Check that U and V are orthogonal.
264                 for i := 0; i < uMat.Rows; i++ {
265                         for j := i + 1; j < uMat.Rows; j++ {
266                                 dot := blas64.Dot(uMat.Cols,
267                                         blas64.Vector{Inc: 1, Data: uMat.Data[i*uMat.Stride:]},
268                                         blas64.Vector{Inc: 1, Data: uMat.Data[j*uMat.Stride:]},
269                                 )
270                                 if dot > 1e-8 {
271                                         t.Errorf("U not orthogonal %s", errStr)
272                                 }
273                         }
274                 }
275                 for i := 0; i < vTMat.Rows; i++ {
276                         for j := i + 1; j < vTMat.Rows; j++ {
277                                 dot := blas64.Dot(vTMat.Cols,
278                                         blas64.Vector{Inc: 1, Data: vTMat.Data[i*vTMat.Stride:]},
279                                         blas64.Vector{Inc: 1, Data: vTMat.Data[j*vTMat.Stride:]},
280                                 )
281                                 if dot > 1e-8 {
282                                         t.Errorf("V not orthogonal %s", errStr)
283                                 }
284                         }
285                 }
286         }
287 }