// Copyright ©2015 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package testlapack import ( "fmt" "testing" "golang.org/x/exp/rand" "gonum.org/v1/gonum/blas" "gonum.org/v1/gonum/blas/blas64" "gonum.org/v1/gonum/floats" "gonum.org/v1/gonum/lapack" ) type Dgesvder interface { Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, lwork int) (ok bool) } func DgesvdTest(t *testing.T, impl Dgesvder) { rnd := rand.New(rand.NewSource(1)) // TODO(btracey): Add tests for all of the cases when the SVD implementation // is finished. // TODO(btracey): Add tests for m > mnthr and n > mnthr when other SVD // conditions are implemented. Right now mnthr is 5,000,000 which is too // large to create a square matrix of that size. for _, test := range []struct { m, n, lda, ldu, ldvt int }{ {5, 5, 0, 0, 0}, {5, 6, 0, 0, 0}, {6, 5, 0, 0, 0}, {5, 9, 0, 0, 0}, {9, 5, 0, 0, 0}, {5, 5, 10, 11, 12}, {5, 6, 10, 11, 12}, {6, 5, 10, 11, 12}, {5, 5, 10, 11, 12}, {5, 9, 10, 11, 12}, {9, 5, 10, 11, 12}, {300, 300, 0, 0, 0}, {300, 400, 0, 0, 0}, {400, 300, 0, 0, 0}, {300, 600, 0, 0, 0}, {600, 300, 0, 0, 0}, {300, 300, 400, 450, 460}, {300, 400, 500, 550, 560}, {400, 300, 550, 550, 560}, {300, 600, 700, 750, 760}, {600, 300, 700, 750, 760}, } { jobU := lapack.SVDAll jobVT := lapack.SVDAll m := test.m n := test.n lda := test.lda if lda == 0 { lda = n } ldu := test.ldu if ldu == 0 { ldu = m } ldvt := test.ldvt if ldvt == 0 { ldvt = n } a := make([]float64, m*lda) for i := range a { a[i] = rnd.NormFloat64() } u := make([]float64, m*ldu) for i := range u { u[i] = rnd.NormFloat64() } vt := make([]float64, n*ldvt) for i := range vt { vt[i] = rnd.NormFloat64() } uAllOrig := make([]float64, len(u)) copy(uAllOrig, u) vtAllOrig := make([]float64, len(vt)) copy(vtAllOrig, vt) aCopy := make([]float64, len(a)) copy(aCopy, a) s := make([]float64, min(m, n)) work := make([]float64, 1) impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, -1) if !floats.Equal(a, aCopy) { t.Errorf("a changed during call to get work length") } work = make([]float64, int(work[0])) impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work)) errStr := fmt.Sprintf("m = %v, n = %v, lda = %v, ldu = %v, ldv = %v", m, n, lda, ldu, ldvt) svdCheck(t, false, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda) svdCheckPartial(t, impl, lapack.SVDAll, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false) // Test InPlace jobU = lapack.SVDInPlace jobVT = lapack.SVDInPlace copy(a, aCopy) copy(u, uAllOrig) copy(vt, vtAllOrig) impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work)) svdCheck(t, true, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda) svdCheckPartial(t, impl, lapack.SVDInPlace, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false) } } // svdCheckPartial checks that the singular values and vectors are computed when // not all of them are computed. func svdCheckPartial(t *testing.T, impl Dgesvder, job lapack.SVDJob, errStr string, uAllOrig, vtAllOrig, aCopy []float64, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, shortWork bool) { rnd := rand.New(rand.NewSource(1)) jobU := job jobVT := job // Compare the singular values when computed with {SVDNone, SVDNone.} sCopy := make([]float64, len(s)) copy(sCopy, s) copy(a, aCopy) for i := range s { s[i] = rnd.Float64() } tmp1 := make([]float64, 1) tmp2 := make([]float64, 1) jobU = lapack.SVDNone jobVT = lapack.SVDNone impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, -1) work = make([]float64, int(work[0])) lwork := len(work) if shortWork { lwork-- } ok := impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, lwork) if !ok { t.Errorf("Dgesvd did not complete successfully") } if !floats.EqualApprox(s, sCopy, 1e-10) { t.Errorf("Singular value mismatch when singular vectors not computed: %s", errStr) } // Check that the singular vectors are correctly computed when the other // is none. uAll := make([]float64, len(u)) copy(uAll, u) vtAll := make([]float64, len(vt)) copy(vtAll, vt) // Copy the original vectors so the data outside the matrix bounds is the same. copy(u, uAllOrig) copy(vt, vtAllOrig) jobU = job jobVT = lapack.SVDNone copy(a, aCopy) for i := range s { s[i] = rnd.Float64() } impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, -1) work = make([]float64, int(work[0])) lwork = len(work) if shortWork { lwork-- } impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, len(work)) if !floats.EqualApprox(uAll, u, 1e-10) { t.Errorf("U mismatch when VT is not computed: %s", errStr) } if !floats.EqualApprox(s, sCopy, 1e-10) { t.Errorf("Singular value mismatch when U computed VT not") } jobU = lapack.SVDNone jobVT = job copy(a, aCopy) for i := range s { s[i] = rnd.Float64() } impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, -1) work = make([]float64, int(work[0])) lwork = len(work) if shortWork { lwork-- } impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, len(work)) if !floats.EqualApprox(vtAll, vt, 1e-10) { t.Errorf("VT mismatch when U is not computed: %s", errStr) } if !floats.EqualApprox(s, sCopy, 1e-10) { t.Errorf("Singular value mismatch when VT computed U not") } } // svdCheck checks that the singular value decomposition correctly multiplies back // to the original matrix. func svdCheck(t *testing.T, thin bool, errStr string, m, n int, s, a, u []float64, ldu int, vt []float64, ldvt int, aCopy []float64, lda int) { sigma := blas64.General{ Rows: m, Cols: n, Stride: n, Data: make([]float64, m*n), } for i := 0; i < min(m, n); i++ { sigma.Data[i*sigma.Stride+i] = s[i] } uMat := blas64.General{ Rows: m, Cols: m, Stride: ldu, Data: u, } vTMat := blas64.General{ Rows: n, Cols: n, Stride: ldvt, Data: vt, } if thin { sigma.Rows = min(m, n) sigma.Cols = min(m, n) uMat.Cols = min(m, n) vTMat.Rows = min(m, n) } tmp := blas64.General{ Rows: m, Cols: n, Stride: n, Data: make([]float64, m*n), } ans := blas64.General{ Rows: m, Cols: n, Stride: lda, Data: make([]float64, m*lda), } copy(ans.Data, a) blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uMat, sigma, 0, tmp) blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, vTMat, 0, ans) if !floats.EqualApprox(ans.Data, aCopy, 1e-8) { t.Errorf("Decomposition mismatch. Trim = %v, %s", thin, errStr) } if !thin { // Check that U and V are orthogonal. for i := 0; i < uMat.Rows; i++ { for j := i + 1; j < uMat.Rows; j++ { dot := blas64.Dot(uMat.Cols, blas64.Vector{Inc: 1, Data: uMat.Data[i*uMat.Stride:]}, blas64.Vector{Inc: 1, Data: uMat.Data[j*uMat.Stride:]}, ) if dot > 1e-8 { t.Errorf("U not orthogonal %s", errStr) } } } for i := 0; i < vTMat.Rows; i++ { for j := i + 1; j < vTMat.Rows; j++ { dot := blas64.Dot(vTMat.Cols, blas64.Vector{Inc: 1, Data: vTMat.Data[i*vTMat.Stride:]}, blas64.Vector{Inc: 1, Data: vTMat.Data[j*vTMat.Stride:]}, ) if dot > 1e-8 { t.Errorf("V not orthogonal %s", errStr) } } } } }