1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
12 "golang.org/x/exp/rand"
14 "gonum.org/v1/gonum/blas"
15 "gonum.org/v1/gonum/blas/blas64"
16 "gonum.org/v1/gonum/floats"
// Dbdsqrer is implemented by types that provide the LAPACK routine Dbdsqr.
//
// The expectations below are the ones checked by DbdsqrTest. The input
// bidiagonal matrix B is n×n, with diagonal d (length n) and off-diagonal e
// (length n-1); uplo selects upper or lower bidiagonal form. Dbdsqr
// overwrites d and e so that B = Q * S * P^T, where S is the bidiagonal
// matrix rebuilt from the returned d and e, and d is left in non-increasing
// order. NOTE(review): S is presumably diagonal on success (e driven to
// zero); the test only checks the product, so confirm against the
// implementation.
//
// The orthogonal factors are applied to the remaining inputs in place:
//
//	vt (n×ncvt, leading dimension ldvt) is replaced by P^T * vt,
//	u  (nru×n,  leading dimension ldu)  is replaced by u * Q,
//	c  (n×ncc,  leading dimension ldc)  is replaced by Q^T * c.
//
// c may be nil when ncc is 0. DbdsqrTest supplies a work slice of length
// 4*n. The returned ok reports success; DbdsqrTest records an error when it
// is false.
19 type Dbdsqrer interface {
20 Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, vt []float64, ldvt int, u []float64, ldu int, c []float64, ldc int, work []float64) (ok bool)
23 func DbdsqrTest(t *testing.T, impl Dbdsqrer) {
24 rnd := rand.New(rand.NewSource(1))
25 bi := blas64.Implementation()
27 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
28 for _, test := range []struct {
29 n, ncvt, nru, ncc, ldvt, ldu, ldc int
31 {5, 5, 5, 5, 0, 0, 0},
32 {10, 10, 10, 10, 0, 0, 0},
33 {10, 11, 12, 13, 0, 0, 0},
34 {20, 13, 12, 11, 0, 0, 0},
36 {5, 5, 5, 5, 6, 7, 8},
37 {10, 10, 10, 10, 30, 40, 50},
38 {10, 12, 11, 13, 30, 40, 50},
39 {20, 12, 13, 11, 30, 40, 50},
41 {130, 130, 130, 500, 900, 900, 500},
43 for cas := 0; cas < 10; cas++ {
61 d := make([]float64, n)
63 d[i] = rnd.NormFloat64()
65 e := make([]float64, n-1)
67 e[i] = rnd.NormFloat64()
69 dCopy := make([]float64, len(d))
71 eCopy := make([]float64, len(e))
73 work := make([]float64, 4*n)
75 work[i] = rnd.NormFloat64()
78 // First test the decomposition of the bidiagonal matrix. Set
79 // pt and u equal to I with the correct size. At the result
80 // of Dbdsqr, p and u will contain the data of P^T and Q, which
81 // will be used in the next step to test the multiplication
84 q := make([]float64, n*n)
86 pt := make([]float64, n*n)
88 for i := 0; i < n; i++ {
91 for i := 0; i < n; i++ {
95 ok := impl.Dbdsqr(uplo, n, n, n, 0, d, e, pt, ldpt, q, ldq, nil, 0, work)
97 isUpper := uplo == blas.Upper
98 errStr := fmt.Sprintf("isUpper = %v, n = %v, ncvt = %v, nru = %v, ncc = %v", isUpper, n, ncvt, nru, ncc)
100 t.Errorf("Unexpected Dbdsqr failure: %s", errStr)
103 bMat := constructBidiagonal(uplo, n, dCopy, eCopy)
104 sMat := constructBidiagonal(uplo, n, d, e)
106 tmp := blas64.General{
110 Data: make([]float64, n*n),
112 ansMat := blas64.General{
116 Data: make([]float64, n*n),
119 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, q, ldq, sMat.Data, sMat.Stride, 0, tmp.Data, tmp.Stride)
120 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, tmp.Data, tmp.Stride, pt, ldpt, 0, ansMat.Data, ansMat.Stride)
123 for i := 0; i < n; i++ {
124 for j := 0; j < n; j++ {
125 if !floats.EqualWithinAbsOrRel(ansMat.Data[i*ansMat.Stride+j], bMat.Data[i*bMat.Stride+j], 1e-8, 1e-8) {
131 t.Errorf("Bidiagonal mismatch. %s", errStr)
133 if !sort.IsSorted(sort.Reverse(sort.Float64Slice(d))) {
134 t.Errorf("D is not sorted. %s", errStr)
137 // The above computed the real P and Q. Now input data for V^T,
138 // U, and C to check that the multiplications happen properly.
139 dAns := make([]float64, len(d))
141 eAns := make([]float64, len(e))
144 u := make([]float64, nru*ldu)
146 u[i] = rnd.NormFloat64()
148 uCopy := make([]float64, len(u))
150 vt := make([]float64, n*ldvt)
152 vt[i] = rnd.NormFloat64()
154 vtCopy := make([]float64, len(vt))
156 c := make([]float64, n*ldc)
158 c[i] = rnd.NormFloat64()
160 cCopy := make([]float64, len(c))
166 impl.Dbdsqr(uplo, n, ncvt, nru, ncc, d, e, vt, ldvt, u, ldu, c, ldc, work)
169 if !floats.EqualApprox(d, dAns, 1e-14) {
170 t.Errorf("D mismatch second time. %s", errStr)
172 if !floats.EqualApprox(e, eAns, 1e-14) {
173 t.Errorf("E mismatch second time. %s", errStr)
175 ans := make([]float64, len(vtCopy))
178 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, ncvt, n, 1, pt, ldpt, vtCopy, ldvt, 0, ans, ldans)
179 if !floats.EqualApprox(ans, vt, 1e-10) {
180 t.Errorf("Vt result mismatch. %s", errStr)
182 ans = make([]float64, len(uCopy))
185 bi.Dgemm(blas.NoTrans, blas.NoTrans, nru, n, n, 1, uCopy, ldu, q, ldq, 0, ans, ldans)
186 if !floats.EqualApprox(ans, u, 1e-10) {
187 t.Errorf("U result mismatch. %s", errStr)
189 ans = make([]float64, len(cCopy))
192 bi.Dgemm(blas.Trans, blas.NoTrans, n, ncc, n, 1, q, ldq, cCopy, ldc, 0, ans, ldans)
193 if !floats.EqualApprox(ans, c, 1e-10) {
194 t.Errorf("C result mismatch. %s", errStr)