OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dggsvd3.go
1 // Copyright ©2017 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11
12         "gonum.org/v1/gonum/blas"
13         "gonum.org/v1/gonum/blas/blas64"
14         "gonum.org/v1/gonum/floats"
15         "gonum.org/v1/gonum/lapack"
16 )
17
18 type Dggsvd3er interface {
19         Dggsvd3(jobU, jobV, jobQ lapack.GSVDJob, m, n, p int, a []float64, lda int, b []float64, ldb int, alpha, beta, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, work []float64, lwork int, iwork []int) (k, l int, ok bool)
20 }
21
22 func Dggsvd3Test(t *testing.T, impl Dggsvd3er) {
23         rnd := rand.New(rand.NewSource(1))
24         for cas, test := range []struct {
25                 m, p, n, lda, ldb, ldu, ldv, ldq int
26
27                 ok bool
28         }{
29                 {m: 3, p: 3, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
30                 {m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
31                 {m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
32                 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
33                 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
34                 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
35                 {m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
36                 {m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
37                 {m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
38                 {m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
39                 {m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true},
40                 {m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true},
41                 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true},
42                 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true},
43                 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true},
44                 {m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true},
45                 {m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true},
46                 {m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true},
47                 {m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true},
48         } {
49                 m := test.m
50                 p := test.p
51                 n := test.n
52                 lda := test.lda
53                 if lda == 0 {
54                         lda = n
55                 }
56                 ldb := test.ldb
57                 if ldb == 0 {
58                         ldb = n
59                 }
60                 ldu := test.ldu
61                 if ldu == 0 {
62                         ldu = m
63                 }
64                 ldv := test.ldv
65                 if ldv == 0 {
66                         ldv = p
67                 }
68                 ldq := test.ldq
69                 if ldq == 0 {
70                         ldq = n
71                 }
72
73                 a := randomGeneral(m, n, lda, rnd)
74                 aCopy := cloneGeneral(a)
75                 b := randomGeneral(p, n, ldb, rnd)
76                 bCopy := cloneGeneral(b)
77
78                 alpha := make([]float64, n)
79                 beta := make([]float64, n)
80
81                 u := nanGeneral(m, m, ldu)
82                 v := nanGeneral(p, p, ldv)
83                 q := nanGeneral(n, n, ldq)
84
85                 iwork := make([]int, n)
86
87                 work := []float64{0}
88                 impl.Dggsvd3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ,
89                         m, n, p,
90                         a.Data, a.Stride,
91                         b.Data, b.Stride,
92                         alpha, beta,
93                         u.Data, u.Stride,
94                         v.Data, v.Stride,
95                         q.Data, q.Stride,
96                         work, -1, iwork)
97
98                 lwork := int(work[0])
99                 work = make([]float64, lwork)
100
101                 k, l, ok := impl.Dggsvd3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ,
102                         m, n, p,
103                         a.Data, a.Stride,
104                         b.Data, b.Stride,
105                         alpha, beta,
106                         u.Data, u.Stride,
107                         v.Data, v.Stride,
108                         q.Data, q.Stride,
109                         work, lwork, iwork)
110
111                 if !ok {
112                         if test.ok {
113                                 t.Errorf("test %d unexpectedly did not converge", cas)
114                         }
115                         continue
116                 }
117
118                 // Check orthogonality of U, V and Q.
119                 if !isOrthonormal(u) {
120                         t.Errorf("test %d: U is not orthogonal\n%+v", cas, u)
121                 }
122                 if !isOrthonormal(v) {
123                         t.Errorf("test %d: V is not orthogonal\n%+v", cas, v)
124                 }
125                 if !isOrthonormal(q) {
126                         t.Errorf("test %d: Q is not orthogonal\n%+v", cas, q)
127                 }
128
129                 // Check C^2 + S^2 = I.
130                 var elements []float64
131                 if m-k-l >= 0 {
132                         elements = alpha[k : k+l]
133                 } else {
134                         elements = alpha[k:m]
135                 }
136                 for i := range elements {
137                         i += k
138                         d := alpha[i]*alpha[i] + beta[i]*beta[i]
139                         if !floats.EqualWithinAbsOrRel(d, 1, 1e-14, 1e-14) {
140                                 t.Errorf("test %d: alpha_%d^2 + beta_%d^2 != 1: got: %v", cas, i, i, d)
141                         }
142                 }
143
144                 zeroR, d1, d2 := constructGSVDresults(n, p, m, k, l, a, b, alpha, beta)
145
146                 // Check U^T*A*Q = D1*[ 0 R ].
147                 uTmp := nanGeneral(m, n, n)
148                 blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, aCopy, 0, uTmp)
149                 uAns := nanGeneral(m, n, n)
150                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uTmp, q, 0, uAns)
151
152                 d10r := nanGeneral(m, n, n)
153                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d1, zeroR, 0, d10r)
154
155                 if !equalApproxGeneral(uAns, d10r, 1e-14) {
156                         t.Errorf("test %d: U^T*A*Q != D1*[ 0 R ]\nU^T*A*Q:\n%+v\nD1*[ 0 R ]:\n%+v",
157                                 cas, uAns, d10r)
158                 }
159
160                 // Check V^T*B*Q = D2*[ 0 R ].
161                 vTmp := nanGeneral(p, n, n)
162                 blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, bCopy, 0, vTmp)
163                 vAns := nanGeneral(p, n, n)
164                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vTmp, q, 0, vAns)
165
166                 d20r := nanGeneral(p, n, n)
167                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d2, zeroR, 0, d20r)
168
169                 if !equalApproxGeneral(vAns, d20r, 1e-13) {
170                         t.Errorf("test %d: V^T*B*Q != D2*[ 0 R ]\nV^T*B*Q:\n%+v\nD2*[ 0 R ]:\n%+v",
171                                 cas, vAns, d20r)
172                 }
173         }
174 }