OSDN Git Service

Hulk did something
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dgesvd.go
diff --git a/vendor/gonum.org/v1/gonum/lapack/testlapack/dgesvd.go b/vendor/gonum.org/v1/gonum/lapack/testlapack/dgesvd.go
new file mode 100644 (file)
index 0000000..4042e15
--- /dev/null
@@ -0,0 +1,287 @@
+// Copyright ©2015 The Gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package testlapack
+
+import (
+       "fmt"
+       "testing"
+
+       "golang.org/x/exp/rand"
+
+       "gonum.org/v1/gonum/blas"
+       "gonum.org/v1/gonum/blas/blas64"
+       "gonum.org/v1/gonum/floats"
+       "gonum.org/v1/gonum/lapack"
+)
+
+type Dgesvder interface {
+       Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, lwork int) (ok bool)
+}
+
+func DgesvdTest(t *testing.T, impl Dgesvder) {
+       rnd := rand.New(rand.NewSource(1))
+       // TODO(btracey): Add tests for all of the cases when the SVD implementation
+       // is finished.
+       // TODO(btracey): Add tests for m > mnthr and n > mnthr when other SVD
+       // conditions are implemented. Right now mnthr is 5,000,000 which is too
+       // large to create a square matrix of that size.
+       for _, test := range []struct {
+               m, n, lda, ldu, ldvt int
+       }{
+               {5, 5, 0, 0, 0},
+               {5, 6, 0, 0, 0},
+               {6, 5, 0, 0, 0},
+               {5, 9, 0, 0, 0},
+               {9, 5, 0, 0, 0},
+
+               {5, 5, 10, 11, 12},
+               {5, 6, 10, 11, 12},
+               {6, 5, 10, 11, 12},
+               {5, 5, 10, 11, 12},
+               {5, 9, 10, 11, 12},
+               {9, 5, 10, 11, 12},
+
+               {300, 300, 0, 0, 0},
+               {300, 400, 0, 0, 0},
+               {400, 300, 0, 0, 0},
+               {300, 600, 0, 0, 0},
+               {600, 300, 0, 0, 0},
+
+               {300, 300, 400, 450, 460},
+               {300, 400, 500, 550, 560},
+               {400, 300, 550, 550, 560},
+               {300, 600, 700, 750, 760},
+               {600, 300, 700, 750, 760},
+       } {
+               jobU := lapack.SVDAll
+               jobVT := lapack.SVDAll
+
+               m := test.m
+               n := test.n
+               lda := test.lda
+               if lda == 0 {
+                       lda = n
+               }
+               ldu := test.ldu
+               if ldu == 0 {
+                       ldu = m
+               }
+               ldvt := test.ldvt
+               if ldvt == 0 {
+                       ldvt = n
+               }
+
+               a := make([]float64, m*lda)
+               for i := range a {
+                       a[i] = rnd.NormFloat64()
+               }
+
+               u := make([]float64, m*ldu)
+               for i := range u {
+                       u[i] = rnd.NormFloat64()
+               }
+
+               vt := make([]float64, n*ldvt)
+               for i := range vt {
+                       vt[i] = rnd.NormFloat64()
+               }
+
+               uAllOrig := make([]float64, len(u))
+               copy(uAllOrig, u)
+               vtAllOrig := make([]float64, len(vt))
+               copy(vtAllOrig, vt)
+               aCopy := make([]float64, len(a))
+               copy(aCopy, a)
+
+               s := make([]float64, min(m, n))
+
+               work := make([]float64, 1)
+               impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, -1)
+
+               if !floats.Equal(a, aCopy) {
+                       t.Errorf("a changed during call to get work length")
+               }
+
+               work = make([]float64, int(work[0]))
+               impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work))
+
+               errStr := fmt.Sprintf("m = %v, n = %v, lda = %v, ldu = %v, ldv = %v", m, n, lda, ldu, ldvt)
+               svdCheck(t, false, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda)
+               svdCheckPartial(t, impl, lapack.SVDAll, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false)
+
+               // Test InPlace
+               jobU = lapack.SVDInPlace
+               jobVT = lapack.SVDInPlace
+               copy(a, aCopy)
+               copy(u, uAllOrig)
+               copy(vt, vtAllOrig)
+
+               impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work))
+               svdCheck(t, true, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda)
+               svdCheckPartial(t, impl, lapack.SVDInPlace, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false)
+       }
+}
+
+// svdCheckPartial checks that the singular values and vectors are computed when
+// not all of them are computed.
+func svdCheckPartial(t *testing.T, impl Dgesvder, job lapack.SVDJob, errStr string, uAllOrig, vtAllOrig, aCopy []float64, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, shortWork bool) {
+       rnd := rand.New(rand.NewSource(1))
+       jobU := job
+       jobVT := job
+       // Compare the singular values when computed with {SVDNone, SVDNone.}
+       sCopy := make([]float64, len(s))
+       copy(sCopy, s)
+       copy(a, aCopy)
+       for i := range s {
+               s[i] = rnd.Float64()
+       }
+       tmp1 := make([]float64, 1)
+       tmp2 := make([]float64, 1)
+       jobU = lapack.SVDNone
+       jobVT = lapack.SVDNone
+
+       impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, -1)
+       work = make([]float64, int(work[0]))
+       lwork := len(work)
+       if shortWork {
+               lwork--
+       }
+       ok := impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, lwork)
+       if !ok {
+               t.Errorf("Dgesvd did not complete successfully")
+       }
+       if !floats.EqualApprox(s, sCopy, 1e-10) {
+               t.Errorf("Singular value mismatch when singular vectors not computed: %s", errStr)
+       }
+       // Check that the singular vectors are correctly computed when the other
+       // is none.
+       uAll := make([]float64, len(u))
+       copy(uAll, u)
+       vtAll := make([]float64, len(vt))
+       copy(vtAll, vt)
+
+       // Copy the original vectors so the data outside the matrix bounds is the same.
+       copy(u, uAllOrig)
+       copy(vt, vtAllOrig)
+
+       jobU = job
+       jobVT = lapack.SVDNone
+       copy(a, aCopy)
+       for i := range s {
+               s[i] = rnd.Float64()
+       }
+       impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, -1)
+       work = make([]float64, int(work[0]))
+       lwork = len(work)
+       if shortWork {
+               lwork--
+       }
+       impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, len(work))
+       if !floats.EqualApprox(uAll, u, 1e-10) {
+               t.Errorf("U mismatch when VT is not computed: %s", errStr)
+       }
+       if !floats.EqualApprox(s, sCopy, 1e-10) {
+               t.Errorf("Singular value mismatch when U computed VT not")
+       }
+       jobU = lapack.SVDNone
+       jobVT = job
+       copy(a, aCopy)
+       for i := range s {
+               s[i] = rnd.Float64()
+       }
+       impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, -1)
+       work = make([]float64, int(work[0]))
+       lwork = len(work)
+       if shortWork {
+               lwork--
+       }
+       impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, len(work))
+       if !floats.EqualApprox(vtAll, vt, 1e-10) {
+               t.Errorf("VT mismatch when U is not computed: %s", errStr)
+       }
+       if !floats.EqualApprox(s, sCopy, 1e-10) {
+               t.Errorf("Singular value mismatch when VT computed U not")
+       }
+}
+
+// svdCheck checks that the singular value decomposition correctly multiplies back
+// to the original matrix.
+func svdCheck(t *testing.T, thin bool, errStr string, m, n int, s, a, u []float64, ldu int, vt []float64, ldvt int, aCopy []float64, lda int) {
+       sigma := blas64.General{
+               Rows:   m,
+               Cols:   n,
+               Stride: n,
+               Data:   make([]float64, m*n),
+       }
+       for i := 0; i < min(m, n); i++ {
+               sigma.Data[i*sigma.Stride+i] = s[i]
+       }
+
+       uMat := blas64.General{
+               Rows:   m,
+               Cols:   m,
+               Stride: ldu,
+               Data:   u,
+       }
+       vTMat := blas64.General{
+               Rows:   n,
+               Cols:   n,
+               Stride: ldvt,
+               Data:   vt,
+       }
+       if thin {
+               sigma.Rows = min(m, n)
+               sigma.Cols = min(m, n)
+               uMat.Cols = min(m, n)
+               vTMat.Rows = min(m, n)
+       }
+
+       tmp := blas64.General{
+               Rows:   m,
+               Cols:   n,
+               Stride: n,
+               Data:   make([]float64, m*n),
+       }
+       ans := blas64.General{
+               Rows:   m,
+               Cols:   n,
+               Stride: lda,
+               Data:   make([]float64, m*lda),
+       }
+       copy(ans.Data, a)
+
+       blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uMat, sigma, 0, tmp)
+       blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, vTMat, 0, ans)
+
+       if !floats.EqualApprox(ans.Data, aCopy, 1e-8) {
+               t.Errorf("Decomposition mismatch. Trim = %v, %s", thin, errStr)
+       }
+
+       if !thin {
+               // Check that U and V are orthogonal.
+               for i := 0; i < uMat.Rows; i++ {
+                       for j := i + 1; j < uMat.Rows; j++ {
+                               dot := blas64.Dot(uMat.Cols,
+                                       blas64.Vector{Inc: 1, Data: uMat.Data[i*uMat.Stride:]},
+                                       blas64.Vector{Inc: 1, Data: uMat.Data[j*uMat.Stride:]},
+                               )
+                               if dot > 1e-8 {
+                                       t.Errorf("U not orthogonal %s", errStr)
+                               }
+                       }
+               }
+               for i := 0; i < vTMat.Rows; i++ {
+                       for j := i + 1; j < vTMat.Rows; j++ {
+                               dot := blas64.Dot(vTMat.Cols,
+                                       blas64.Vector{Inc: 1, Data: vTMat.Data[i*vTMat.Stride:]},
+                                       blas64.Vector{Inc: 1, Data: vTMat.Data[j*vTMat.Stride:]},
+                               )
+                               if dot > 1e-8 {
+                                       t.Errorf("V not orthogonal %s", errStr)
+                               }
+                       }
+               }
+       }
+}