OSDN Git Service

Merge pull request #201 from Bytom/v0.1
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dgesvd.go
diff --git a/vendor/gonum.org/v1/gonum/lapack/testlapack/dgesvd.go b/vendor/gonum.org/v1/gonum/lapack/testlapack/dgesvd.go
deleted file mode 100644 (file)
index 4042e15..0000000
+++ /dev/null
@@ -1,287 +0,0 @@
-// Copyright ©2015 The Gonum Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package testlapack
-
-import (
-       "fmt"
-       "testing"
-
-       "golang.org/x/exp/rand"
-
-       "gonum.org/v1/gonum/blas"
-       "gonum.org/v1/gonum/blas/blas64"
-       "gonum.org/v1/gonum/floats"
-       "gonum.org/v1/gonum/lapack"
-)
-
-type Dgesvder interface {
-       Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, lwork int) (ok bool)
-}
-
-func DgesvdTest(t *testing.T, impl Dgesvder) {
-       rnd := rand.New(rand.NewSource(1))
-       // TODO(btracey): Add tests for all of the cases when the SVD implementation
-       // is finished.
-       // TODO(btracey): Add tests for m > mnthr and n > mnthr when other SVD
-       // conditions are implemented. Right now mnthr is 5,000,000 which is too
-       // large to create a square matrix of that size.
-       for _, test := range []struct {
-               m, n, lda, ldu, ldvt int
-       }{
-               {5, 5, 0, 0, 0},
-               {5, 6, 0, 0, 0},
-               {6, 5, 0, 0, 0},
-               {5, 9, 0, 0, 0},
-               {9, 5, 0, 0, 0},
-
-               {5, 5, 10, 11, 12},
-               {5, 6, 10, 11, 12},
-               {6, 5, 10, 11, 12},
-               {5, 5, 10, 11, 12},
-               {5, 9, 10, 11, 12},
-               {9, 5, 10, 11, 12},
-
-               {300, 300, 0, 0, 0},
-               {300, 400, 0, 0, 0},
-               {400, 300, 0, 0, 0},
-               {300, 600, 0, 0, 0},
-               {600, 300, 0, 0, 0},
-
-               {300, 300, 400, 450, 460},
-               {300, 400, 500, 550, 560},
-               {400, 300, 550, 550, 560},
-               {300, 600, 700, 750, 760},
-               {600, 300, 700, 750, 760},
-       } {
-               jobU := lapack.SVDAll
-               jobVT := lapack.SVDAll
-
-               m := test.m
-               n := test.n
-               lda := test.lda
-               if lda == 0 {
-                       lda = n
-               }
-               ldu := test.ldu
-               if ldu == 0 {
-                       ldu = m
-               }
-               ldvt := test.ldvt
-               if ldvt == 0 {
-                       ldvt = n
-               }
-
-               a := make([]float64, m*lda)
-               for i := range a {
-                       a[i] = rnd.NormFloat64()
-               }
-
-               u := make([]float64, m*ldu)
-               for i := range u {
-                       u[i] = rnd.NormFloat64()
-               }
-
-               vt := make([]float64, n*ldvt)
-               for i := range vt {
-                       vt[i] = rnd.NormFloat64()
-               }
-
-               uAllOrig := make([]float64, len(u))
-               copy(uAllOrig, u)
-               vtAllOrig := make([]float64, len(vt))
-               copy(vtAllOrig, vt)
-               aCopy := make([]float64, len(a))
-               copy(aCopy, a)
-
-               s := make([]float64, min(m, n))
-
-               work := make([]float64, 1)
-               impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, -1)
-
-               if !floats.Equal(a, aCopy) {
-                       t.Errorf("a changed during call to get work length")
-               }
-
-               work = make([]float64, int(work[0]))
-               impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work))
-
-               errStr := fmt.Sprintf("m = %v, n = %v, lda = %v, ldu = %v, ldv = %v", m, n, lda, ldu, ldvt)
-               svdCheck(t, false, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda)
-               svdCheckPartial(t, impl, lapack.SVDAll, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false)
-
-               // Test InPlace
-               jobU = lapack.SVDInPlace
-               jobVT = lapack.SVDInPlace
-               copy(a, aCopy)
-               copy(u, uAllOrig)
-               copy(vt, vtAllOrig)
-
-               impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work))
-               svdCheck(t, true, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda)
-               svdCheckPartial(t, impl, lapack.SVDInPlace, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false)
-       }
-}
-
-// svdCheckPartial checks that the singular values and vectors are computed when
-// not all of them are computed.
-func svdCheckPartial(t *testing.T, impl Dgesvder, job lapack.SVDJob, errStr string, uAllOrig, vtAllOrig, aCopy []float64, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, shortWork bool) {
-       rnd := rand.New(rand.NewSource(1))
-       jobU := job
-       jobVT := job
-       // Compare the singular values when computed with {SVDNone, SVDNone.}
-       sCopy := make([]float64, len(s))
-       copy(sCopy, s)
-       copy(a, aCopy)
-       for i := range s {
-               s[i] = rnd.Float64()
-       }
-       tmp1 := make([]float64, 1)
-       tmp2 := make([]float64, 1)
-       jobU = lapack.SVDNone
-       jobVT = lapack.SVDNone
-
-       impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, -1)
-       work = make([]float64, int(work[0]))
-       lwork := len(work)
-       if shortWork {
-               lwork--
-       }
-       ok := impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, lwork)
-       if !ok {
-               t.Errorf("Dgesvd did not complete successfully")
-       }
-       if !floats.EqualApprox(s, sCopy, 1e-10) {
-               t.Errorf("Singular value mismatch when singular vectors not computed: %s", errStr)
-       }
-       // Check that the singular vectors are correctly computed when the other
-       // is none.
-       uAll := make([]float64, len(u))
-       copy(uAll, u)
-       vtAll := make([]float64, len(vt))
-       copy(vtAll, vt)
-
-       // Copy the original vectors so the data outside the matrix bounds is the same.
-       copy(u, uAllOrig)
-       copy(vt, vtAllOrig)
-
-       jobU = job
-       jobVT = lapack.SVDNone
-       copy(a, aCopy)
-       for i := range s {
-               s[i] = rnd.Float64()
-       }
-       impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, -1)
-       work = make([]float64, int(work[0]))
-       lwork = len(work)
-       if shortWork {
-               lwork--
-       }
-       impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, len(work))
-       if !floats.EqualApprox(uAll, u, 1e-10) {
-               t.Errorf("U mismatch when VT is not computed: %s", errStr)
-       }
-       if !floats.EqualApprox(s, sCopy, 1e-10) {
-               t.Errorf("Singular value mismatch when U computed VT not")
-       }
-       jobU = lapack.SVDNone
-       jobVT = job
-       copy(a, aCopy)
-       for i := range s {
-               s[i] = rnd.Float64()
-       }
-       impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, -1)
-       work = make([]float64, int(work[0]))
-       lwork = len(work)
-       if shortWork {
-               lwork--
-       }
-       impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, len(work))
-       if !floats.EqualApprox(vtAll, vt, 1e-10) {
-               t.Errorf("VT mismatch when U is not computed: %s", errStr)
-       }
-       if !floats.EqualApprox(s, sCopy, 1e-10) {
-               t.Errorf("Singular value mismatch when VT computed U not")
-       }
-}
-
-// svdCheck checks that the singular value decomposition correctly multiplies back
-// to the original matrix.
-func svdCheck(t *testing.T, thin bool, errStr string, m, n int, s, a, u []float64, ldu int, vt []float64, ldvt int, aCopy []float64, lda int) {
-       sigma := blas64.General{
-               Rows:   m,
-               Cols:   n,
-               Stride: n,
-               Data:   make([]float64, m*n),
-       }
-       for i := 0; i < min(m, n); i++ {
-               sigma.Data[i*sigma.Stride+i] = s[i]
-       }
-
-       uMat := blas64.General{
-               Rows:   m,
-               Cols:   m,
-               Stride: ldu,
-               Data:   u,
-       }
-       vTMat := blas64.General{
-               Rows:   n,
-               Cols:   n,
-               Stride: ldvt,
-               Data:   vt,
-       }
-       if thin {
-               sigma.Rows = min(m, n)
-               sigma.Cols = min(m, n)
-               uMat.Cols = min(m, n)
-               vTMat.Rows = min(m, n)
-       }
-
-       tmp := blas64.General{
-               Rows:   m,
-               Cols:   n,
-               Stride: n,
-               Data:   make([]float64, m*n),
-       }
-       ans := blas64.General{
-               Rows:   m,
-               Cols:   n,
-               Stride: lda,
-               Data:   make([]float64, m*lda),
-       }
-       copy(ans.Data, a)
-
-       blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uMat, sigma, 0, tmp)
-       blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, vTMat, 0, ans)
-
-       if !floats.EqualApprox(ans.Data, aCopy, 1e-8) {
-               t.Errorf("Decomposition mismatch. Trim = %v, %s", thin, errStr)
-       }
-
-       if !thin {
-               // Check that U and V are orthogonal.
-               for i := 0; i < uMat.Rows; i++ {
-                       for j := i + 1; j < uMat.Rows; j++ {
-                               dot := blas64.Dot(uMat.Cols,
-                                       blas64.Vector{Inc: 1, Data: uMat.Data[i*uMat.Stride:]},
-                                       blas64.Vector{Inc: 1, Data: uMat.Data[j*uMat.Stride:]},
-                               )
-                               if dot > 1e-8 {
-                                       t.Errorf("U not orthogonal %s", errStr)
-                               }
-                       }
-               }
-               for i := 0; i < vTMat.Rows; i++ {
-                       for j := i + 1; j < vTMat.Rows; j++ {
-                               dot := blas64.Dot(vTMat.Cols,
-                                       blas64.Vector{Inc: 1, Data: vTMat.Data[i*vTMat.Stride:]},
-                                       blas64.Vector{Inc: 1, Data: vTMat.Data[j*vTMat.Stride:]},
-                               )
-                               if dot > 1e-8 {
-                                       t.Errorf("V not orthogonal %s", errStr)
-                               }
-                       }
-               }
-       }
-}