1 // Copyright ©2017 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
10 "golang.org/x/exp/rand"
12 "gonum.org/v1/gonum/blas"
13 "gonum.org/v1/gonum/blas/blas64"
14 "gonum.org/v1/gonum/floats"
15 "gonum.org/v1/gonum/lapack"
18 type Dggsvd3er interface {
19 Dggsvd3(jobU, jobV, jobQ lapack.GSVDJob, m, n, p int, a []float64, lda int, b []float64, ldb int, alpha, beta, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, work []float64, lwork int, iwork []int) (k, l int, ok bool)
22 func Dggsvd3Test(t *testing.T, impl Dggsvd3er) {
23 rnd := rand.New(rand.NewSource(1))
24 for cas, test := range []struct {
25 m, p, n, lda, ldb, ldu, ldv, ldq int
29 {m: 3, p: 3, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
30 {m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
31 {m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
32 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
33 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
34 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
35 {m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
36 {m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
37 {m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
38 {m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
39 {m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true},
40 {m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true},
41 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true},
42 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true},
43 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true},
44 {m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true},
45 {m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true},
46 {m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true},
47 {m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true},
73 a := randomGeneral(m, n, lda, rnd)
74 aCopy := cloneGeneral(a)
75 b := randomGeneral(p, n, ldb, rnd)
76 bCopy := cloneGeneral(b)
78 alpha := make([]float64, n)
79 beta := make([]float64, n)
81 u := nanGeneral(m, m, ldu)
82 v := nanGeneral(p, p, ldv)
83 q := nanGeneral(n, n, ldq)
85 iwork := make([]int, n)
88 impl.Dggsvd3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ,
99 work = make([]float64, lwork)
101 k, l, ok := impl.Dggsvd3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ,
113 t.Errorf("test %d unexpectedly did not converge", cas)
118 // Check orthogonality of U, V and Q.
119 if !isOrthonormal(u) {
120 t.Errorf("test %d: U is not orthogonal\n%+v", cas, u)
122 if !isOrthonormal(v) {
123 t.Errorf("test %d: V is not orthogonal\n%+v", cas, v)
125 if !isOrthonormal(q) {
126 t.Errorf("test %d: Q is not orthogonal\n%+v", cas, q)
129 // Check C^2 + S^2 = I.
130 var elements []float64
132 elements = alpha[k : k+l]
134 elements = alpha[k:m]
136 for i := range elements {
138 d := alpha[i]*alpha[i] + beta[i]*beta[i]
139 if !floats.EqualWithinAbsOrRel(d, 1, 1e-14, 1e-14) {
140 t.Errorf("test %d: alpha_%d^2 + beta_%d^2 != 1: got: %v", cas, i, i, d)
144 zeroR, d1, d2 := constructGSVDresults(n, p, m, k, l, a, b, alpha, beta)
146 // Check U^T*A*Q = D1*[ 0 R ].
147 uTmp := nanGeneral(m, n, n)
148 blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, aCopy, 0, uTmp)
149 uAns := nanGeneral(m, n, n)
150 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uTmp, q, 0, uAns)
152 d10r := nanGeneral(m, n, n)
153 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d1, zeroR, 0, d10r)
155 if !equalApproxGeneral(uAns, d10r, 1e-14) {
156 t.Errorf("test %d: U^T*A*Q != D1*[ 0 R ]\nU^T*A*Q:\n%+v\nD1*[ 0 R ]:\n%+v",
160 // Check V^T*B*Q = D2*[ 0 R ].
161 vTmp := nanGeneral(p, n, n)
162 blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, bCopy, 0, vTmp)
163 vAns := nanGeneral(p, n, n)
164 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vTmp, q, 0, vAns)
166 d20r := nanGeneral(p, n, n)
167 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d2, zeroR, 0, d20r)
169 if !equalApproxGeneral(vAns, d20r, 1e-13) {
170 t.Errorf("test %d: V^T*B*Q != D2*[ 0 R ]\nV^T*B*Q:\n%+v\nD2*[ 0 R ]:\n%+v",