1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
11 "golang.org/x/exp/rand"
13 "gonum.org/v1/gonum/blas"
14 "gonum.org/v1/gonum/blas/blas64"
15 "gonum.org/v1/gonum/floats"
16 "gonum.org/v1/gonum/lapack"
// Dgesvder is implemented by types that provide the LAPACK routine Dgesvd,
// which computes the singular value decomposition A = U·Σ·Vᵀ of the m×n
// matrix stored in a with leading dimension lda.
//
// jobU and jobVT select how much of U and Vᵀ is computed; the tests in this
// file use lapack.SVDAll, lapack.SVDInPlace and lapack.SVDNone. Calling
// Dgesvd with lwork == -1 is used here as a workspace query: the required
// workspace length is read back from work[0]. A false ok result is reported
// by the tests as a failure to complete.
19 type Dgesvder interface {
20 Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, lwork int) (ok bool)
// DgesvdTest runs impl.Dgesvd on random matrices of several shapes and
// leading dimensions and checks the results. For each case it verifies that:
//   - the workspace query (lwork == -1) does not modify a,
//   - the computed decomposition multiplies back to the input (svdCheck),
//   - computing only some outputs gives the same values (svdCheckPartial).
// Both lapack.SVDAll and lapack.SVDInPlace are exercised for jobU and jobVT.
23 func DgesvdTest(t *testing.T, impl Dgesvder) {
24 rnd := rand.New(rand.NewSource(1))
25 // TODO(btracey): Add tests for all of the cases when the SVD implementation
27 // TODO(btracey): Add tests for m > mnthr and n > mnthr when other SVD
28 // conditions are implemented. Right now mnthr is 5,000,000 which is too
29 // large to create a square matrix of that size.
30 for _, test := range []struct {
31 m, n, lda, ldu, ldvt int
52 {300, 300, 400, 450, 460},
53 {300, 400, 500, 550, 560},
54 {400, 300, 550, 550, 560},
55 {300, 600, 700, 750, 760},
56 {600, 300, 700, 750, 760},
59 jobVT := lapack.SVDAll
76 a := make([]float64, m*lda)
78 a[i] = rnd.NormFloat64()
81 u := make([]float64, m*ldu)
83 u[i] = rnd.NormFloat64()
86 vt := make([]float64, n*ldvt)
88 vt[i] = rnd.NormFloat64()
91 uAllOrig := make([]float64, len(u))
93 vtAllOrig := make([]float64, len(vt))
95 aCopy := make([]float64, len(a))
98 s := make([]float64, min(m, n))
// Workspace query: with lwork == -1 the required workspace length is
// returned in work[0], and a must be left untouched by the call.
100 work := make([]float64, 1)
101 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, -1)
103 if !floats.Equal(a, aCopy) {
104 t.Errorf("a changed during call to get work length")
107 work = make([]float64, int(work[0]))
108 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work))
// errStr identifies the test case in every failure message below.
110 errStr := fmt.Sprintf("m = %v, n = %v, lda = %v, ldu = %v, ldvt = %v", m, n, lda, ldu, ldvt)
111 svdCheck(t, false, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda)
112 svdCheckPartial(t, impl, lapack.SVDAll, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false)
// Repeat with lapack.SVDInPlace for both jobs. svdCheck is called with
// thin = true so that the reduced decomposition is checked.
115 jobU = lapack.SVDInPlace
116 jobVT = lapack.SVDInPlace
121 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work))
122 svdCheck(t, true, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda)
123 svdCheckPartial(t, impl, lapack.SVDInPlace, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false)
127 // svdCheckPartial checks that Dgesvd gives consistent results when only
128 // some of its outputs are requested.
//
// It recomputes the decomposition three times:
//   - with jobU = jobVT = lapack.SVDNone,
//   - with jobVT = lapack.SVDNone,
//   - with jobU = lapack.SVDNone.
// It compares the new s, u and vt against saved reference slices (sCopy,
// uAll, vtAll). Outputs that are not requested are passed as 1-element
// placeholders (tmp1, tmp2). job is the SVDJob the caller used to produce
// the reference decomposition.
// NOTE(review): the code filling the reference slices, and the handling of
// shortWork/lwork, are not visible in this excerpt; presumably shortWork
// shortens lwork for the first recomputation — confirm.
129 func svdCheckPartial(t *testing.T, impl Dgesvder, job lapack.SVDJob, errStr string, uAllOrig, vtAllOrig, aCopy []float64, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, shortWork bool) {
130 rnd := rand.New(rand.NewSource(1))
133 // Compare the singular values when computed with {SVDNone, SVDNone}.
134 sCopy := make([]float64, len(s))
140 tmp1 := make([]float64, 1)
141 tmp2 := make([]float64, 1)
142 jobU = lapack.SVDNone
143 jobVT = lapack.SVDNone
145 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, -1)
146 work = make([]float64, int(work[0]))
151 ok := impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, lwork)
153 t.Errorf("Dgesvd did not complete successfully: %s", errStr)
155 if !floats.EqualApprox(s, sCopy, 1e-10) {
156 t.Errorf("Singular value mismatch when singular vectors not computed: %s", errStr)
158 // Check that the singular vectors are correctly computed when the other
160 uAll := make([]float64, len(u))
162 vtAll := make([]float64, len(vt))
165 // Copy the original vectors so the data outside the matrix bounds is the same.
170 jobVT = lapack.SVDNone
175 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, -1)
176 work = make([]float64, int(work[0]))
181 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, len(work))
182 if !floats.EqualApprox(uAll, u, 1e-10) {
183 t.Errorf("U mismatch when VT is not computed: %s", errStr)
185 if !floats.EqualApprox(s, sCopy, 1e-10) {
186 t.Errorf("Singular value mismatch when U computed VT not: %s", errStr)
188 jobU = lapack.SVDNone
194 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, -1)
195 work = make([]float64, int(work[0]))
200 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, len(work))
201 if !floats.EqualApprox(vtAll, vt, 1e-10) {
202 t.Errorf("VT mismatch when U is not computed: %s", errStr)
204 if !floats.EqualApprox(s, sCopy, 1e-10) {
205 t.Errorf("Singular value mismatch when VT computed U not: %s", errStr)
209 // svdCheck checks that the singular value decomposition correctly multiplies back
210 // to the original matrix.
//
// It forms U·Σ·Vᵀ from the singular values s and the vectors stored in u
// (stride ldu) and vt (stride ldvt), and requires the product to match
// aCopy to within 1e-8. It also checks that the rows of U and of Vᵀ are
// mutually orthogonal. When thin is true, the check is presumably restricted
// to min(m, n) singular vectors (Σ is min(m,n)×min(m,n), U has min(m,n)
// columns, Vᵀ has min(m,n) rows); the guarding condition is not visible in
// this excerpt — confirm. errStr is appended to every failure message.
211 func svdCheck(t *testing.T, thin bool, errStr string, m, n int, s, a, u []float64, ldu int, vt []float64, ldvt int, aCopy []float64, lda int) {
// sigma holds the singular values on its diagonal.
212 sigma := blas64.General{
216 Data: make([]float64, m*n),
218 for i := 0; i < min(m, n); i++ {
219 sigma.Data[i*sigma.Stride+i] = s[i]
222 uMat := blas64.General{
228 vTMat := blas64.General{
235 sigma.Rows = min(m, n)
236 sigma.Cols = min(m, n)
237 uMat.Cols = min(m, n)
238 vTMat.Rows = min(m, n)
241 tmp := blas64.General{
245 Data: make([]float64, m*n),
247 ans := blas64.General{
251 Data: make([]float64, m*lda),
// ans = U·Σ·Vᵀ, formed in two steps through tmp = U·Σ.
255 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uMat, sigma, 0, tmp)
256 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, vTMat, 0, ans)
258 if !floats.EqualApprox(ans.Data, aCopy, 1e-8) {
259 t.Errorf("Decomposition mismatch. Thin = %v, %s", thin, errStr)
263 // Check that U and V are orthogonal.
// NOTE(review): the tolerance comparison on dot is not visible here. It
// should test |dot| so that negative dot products are caught. In the thin
// case with m > n, the rows of U are not orthogonal in general, so these
// loops are presumably skipped when thin is true — confirm.
264 for i := 0; i < uMat.Rows; i++ {
265 for j := i + 1; j < uMat.Rows; j++ {
266 dot := blas64.Dot(uMat.Cols,
267 blas64.Vector{Inc: 1, Data: uMat.Data[i*uMat.Stride:]},
268 blas64.Vector{Inc: 1, Data: uMat.Data[j*uMat.Stride:]},
271 t.Errorf("U not orthogonal %s", errStr)
275 for i := 0; i < vTMat.Rows; i++ {
276 for j := i + 1; j < vTMat.Rows; j++ {
277 dot := blas64.Dot(vTMat.Cols,
278 blas64.Vector{Inc: 1, Data: vTMat.Data[i*vTMat.Stride:]},
279 blas64.Vector{Inc: 1, Data: vTMat.Data[j*vTMat.Stride:]},
282 t.Errorf("V not orthogonal %s", errStr)