--- /dev/null
+// Copyright ©2015 The Gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package testlapack
+
+import (
+ "fmt"
+ "testing"
+
+ "golang.org/x/exp/rand"
+
+ "gonum.org/v1/gonum/blas"
+ "gonum.org/v1/gonum/blas/blas64"
+ "gonum.org/v1/gonum/floats"
+ "gonum.org/v1/gonum/lapack"
+)
+
+type Dgesvder interface {
+ Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, lwork int) (ok bool)
+}
+
+func DgesvdTest(t *testing.T, impl Dgesvder) {
+ rnd := rand.New(rand.NewSource(1))
+ // TODO(btracey): Add tests for all of the cases when the SVD implementation
+ // is finished.
+ // TODO(btracey): Add tests for m > mnthr and n > mnthr when other SVD
+ // conditions are implemented. Right now mnthr is 5,000,000 which is too
+ // large to create a square matrix of that size.
+ for _, test := range []struct {
+ m, n, lda, ldu, ldvt int
+ }{
+ {5, 5, 0, 0, 0},
+ {5, 6, 0, 0, 0},
+ {6, 5, 0, 0, 0},
+ {5, 9, 0, 0, 0},
+ {9, 5, 0, 0, 0},
+
+ {5, 5, 10, 11, 12},
+ {5, 6, 10, 11, 12},
+ {6, 5, 10, 11, 12},
+ {5, 5, 10, 11, 12},
+ {5, 9, 10, 11, 12},
+ {9, 5, 10, 11, 12},
+
+ {300, 300, 0, 0, 0},
+ {300, 400, 0, 0, 0},
+ {400, 300, 0, 0, 0},
+ {300, 600, 0, 0, 0},
+ {600, 300, 0, 0, 0},
+
+ {300, 300, 400, 450, 460},
+ {300, 400, 500, 550, 560},
+ {400, 300, 550, 550, 560},
+ {300, 600, 700, 750, 760},
+ {600, 300, 700, 750, 760},
+ } {
+ jobU := lapack.SVDAll
+ jobVT := lapack.SVDAll
+
+ m := test.m
+ n := test.n
+ lda := test.lda
+ if lda == 0 {
+ lda = n
+ }
+ ldu := test.ldu
+ if ldu == 0 {
+ ldu = m
+ }
+ ldvt := test.ldvt
+ if ldvt == 0 {
+ ldvt = n
+ }
+
+ a := make([]float64, m*lda)
+ for i := range a {
+ a[i] = rnd.NormFloat64()
+ }
+
+ u := make([]float64, m*ldu)
+ for i := range u {
+ u[i] = rnd.NormFloat64()
+ }
+
+ vt := make([]float64, n*ldvt)
+ for i := range vt {
+ vt[i] = rnd.NormFloat64()
+ }
+
+ uAllOrig := make([]float64, len(u))
+ copy(uAllOrig, u)
+ vtAllOrig := make([]float64, len(vt))
+ copy(vtAllOrig, vt)
+ aCopy := make([]float64, len(a))
+ copy(aCopy, a)
+
+ s := make([]float64, min(m, n))
+
+ work := make([]float64, 1)
+ impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, -1)
+
+ if !floats.Equal(a, aCopy) {
+ t.Errorf("a changed during call to get work length")
+ }
+
+ work = make([]float64, int(work[0]))
+ impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work))
+
+ errStr := fmt.Sprintf("m = %v, n = %v, lda = %v, ldu = %v, ldv = %v", m, n, lda, ldu, ldvt)
+ svdCheck(t, false, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda)
+ svdCheckPartial(t, impl, lapack.SVDAll, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false)
+
+ // Test InPlace
+ jobU = lapack.SVDInPlace
+ jobVT = lapack.SVDInPlace
+ copy(a, aCopy)
+ copy(u, uAllOrig)
+ copy(vt, vtAllOrig)
+
+ impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work))
+ svdCheck(t, true, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda)
+ svdCheckPartial(t, impl, lapack.SVDInPlace, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false)
+ }
+}
+
+// svdCheckPartial checks that the singular values and vectors are computed when
+// not all of them are computed.
+func svdCheckPartial(t *testing.T, impl Dgesvder, job lapack.SVDJob, errStr string, uAllOrig, vtAllOrig, aCopy []float64, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, shortWork bool) {
+ rnd := rand.New(rand.NewSource(1))
+ jobU := job
+ jobVT := job
+ // Compare the singular values when computed with {SVDNone, SVDNone.}
+ sCopy := make([]float64, len(s))
+ copy(sCopy, s)
+ copy(a, aCopy)
+ for i := range s {
+ s[i] = rnd.Float64()
+ }
+ tmp1 := make([]float64, 1)
+ tmp2 := make([]float64, 1)
+ jobU = lapack.SVDNone
+ jobVT = lapack.SVDNone
+
+ impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, -1)
+ work = make([]float64, int(work[0]))
+ lwork := len(work)
+ if shortWork {
+ lwork--
+ }
+ ok := impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, lwork)
+ if !ok {
+ t.Errorf("Dgesvd did not complete successfully")
+ }
+ if !floats.EqualApprox(s, sCopy, 1e-10) {
+ t.Errorf("Singular value mismatch when singular vectors not computed: %s", errStr)
+ }
+ // Check that the singular vectors are correctly computed when the other
+ // is none.
+ uAll := make([]float64, len(u))
+ copy(uAll, u)
+ vtAll := make([]float64, len(vt))
+ copy(vtAll, vt)
+
+ // Copy the original vectors so the data outside the matrix bounds is the same.
+ copy(u, uAllOrig)
+ copy(vt, vtAllOrig)
+
+ jobU = job
+ jobVT = lapack.SVDNone
+ copy(a, aCopy)
+ for i := range s {
+ s[i] = rnd.Float64()
+ }
+ impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, -1)
+ work = make([]float64, int(work[0]))
+ lwork = len(work)
+ if shortWork {
+ lwork--
+ }
+ impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, len(work))
+ if !floats.EqualApprox(uAll, u, 1e-10) {
+ t.Errorf("U mismatch when VT is not computed: %s", errStr)
+ }
+ if !floats.EqualApprox(s, sCopy, 1e-10) {
+ t.Errorf("Singular value mismatch when U computed VT not")
+ }
+ jobU = lapack.SVDNone
+ jobVT = job
+ copy(a, aCopy)
+ for i := range s {
+ s[i] = rnd.Float64()
+ }
+ impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, -1)
+ work = make([]float64, int(work[0]))
+ lwork = len(work)
+ if shortWork {
+ lwork--
+ }
+ impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, len(work))
+ if !floats.EqualApprox(vtAll, vt, 1e-10) {
+ t.Errorf("VT mismatch when U is not computed: %s", errStr)
+ }
+ if !floats.EqualApprox(s, sCopy, 1e-10) {
+ t.Errorf("Singular value mismatch when VT computed U not")
+ }
+}
+
+// svdCheck checks that the singular value decomposition correctly multiplies back
+// to the original matrix.
+func svdCheck(t *testing.T, thin bool, errStr string, m, n int, s, a, u []float64, ldu int, vt []float64, ldvt int, aCopy []float64, lda int) {
+ sigma := blas64.General{
+ Rows: m,
+ Cols: n,
+ Stride: n,
+ Data: make([]float64, m*n),
+ }
+ for i := 0; i < min(m, n); i++ {
+ sigma.Data[i*sigma.Stride+i] = s[i]
+ }
+
+ uMat := blas64.General{
+ Rows: m,
+ Cols: m,
+ Stride: ldu,
+ Data: u,
+ }
+ vTMat := blas64.General{
+ Rows: n,
+ Cols: n,
+ Stride: ldvt,
+ Data: vt,
+ }
+ if thin {
+ sigma.Rows = min(m, n)
+ sigma.Cols = min(m, n)
+ uMat.Cols = min(m, n)
+ vTMat.Rows = min(m, n)
+ }
+
+ tmp := blas64.General{
+ Rows: m,
+ Cols: n,
+ Stride: n,
+ Data: make([]float64, m*n),
+ }
+ ans := blas64.General{
+ Rows: m,
+ Cols: n,
+ Stride: lda,
+ Data: make([]float64, m*lda),
+ }
+ copy(ans.Data, a)
+
+ blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uMat, sigma, 0, tmp)
+ blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, vTMat, 0, ans)
+
+ if !floats.EqualApprox(ans.Data, aCopy, 1e-8) {
+ t.Errorf("Decomposition mismatch. Trim = %v, %s", thin, errStr)
+ }
+
+ if !thin {
+ // Check that U and V are orthogonal.
+ for i := 0; i < uMat.Rows; i++ {
+ for j := i + 1; j < uMat.Rows; j++ {
+ dot := blas64.Dot(uMat.Cols,
+ blas64.Vector{Inc: 1, Data: uMat.Data[i*uMat.Stride:]},
+ blas64.Vector{Inc: 1, Data: uMat.Data[j*uMat.Stride:]},
+ )
+ if dot > 1e-8 {
+ t.Errorf("U not orthogonal %s", errStr)
+ }
+ }
+ }
+ for i := 0; i < vTMat.Rows; i++ {
+ for j := i + 1; j < vTMat.Rows; j++ {
+ dot := blas64.Dot(vTMat.Cols,
+ blas64.Vector{Inc: 1, Data: vTMat.Data[i*vTMat.Stride:]},
+ blas64.Vector{Inc: 1, Data: vTMat.Data[j*vTMat.Stride:]},
+ )
+ if dot > 1e-8 {
+ t.Errorf("V not orthogonal %s", errStr)
+ }
+ }
+ }
+ }
+}