OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dggsvp3.go
1 // Copyright ©2017 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11
12         "gonum.org/v1/gonum/blas"
13         "gonum.org/v1/gonum/blas/blas64"
14         "gonum.org/v1/gonum/lapack"
15 )
16
17 type Dggsvp3er interface {
18         Dlanger
19         Dggsvp3(jobU, jobV, jobQ lapack.GSVDJob, m, p, n int, a []float64, lda int, b []float64, ldb int, tola, tolb float64, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, iwork []int, tau, work []float64, lwork int) (k, l int)
20 }
21
22 func Dggsvp3Test(t *testing.T, impl Dggsvp3er) {
23         rnd := rand.New(rand.NewSource(1))
24         for cas, test := range []struct {
25                 m, p, n, lda, ldb, ldu, ldv, ldq int
26         }{
27                 {m: 3, p: 3, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
28                 {m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
29                 {m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
30                 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
31                 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
32                 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
33                 {m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
34                 {m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
35                 {m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
36                 {m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
37                 {m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10},
38                 {m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10},
39                 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20},
40                 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20},
41                 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20},
42                 {m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10},
43                 {m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10},
44                 {m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20},
45                 {m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20},
46         } {
47                 m := test.m
48                 p := test.p
49                 n := test.n
50                 lda := test.lda
51                 if lda == 0 {
52                         lda = n
53                 }
54                 ldb := test.ldb
55                 if ldb == 0 {
56                         ldb = n
57                 }
58                 ldu := test.ldu
59                 if ldu == 0 {
60                         ldu = m
61                 }
62                 ldv := test.ldv
63                 if ldv == 0 {
64                         ldv = p
65                 }
66                 ldq := test.ldq
67                 if ldq == 0 {
68                         ldq = n
69                 }
70
71                 a := randomGeneral(m, n, lda, rnd)
72                 aCopy := cloneGeneral(a)
73                 b := randomGeneral(p, n, ldb, rnd)
74                 bCopy := cloneGeneral(b)
75
76                 tola := float64(max(m, n)) * impl.Dlange(lapack.NormFrob, m, n, a.Data, a.Stride, nil) * dlamchE
77                 tolb := float64(max(p, n)) * impl.Dlange(lapack.NormFrob, p, n, b.Data, b.Stride, nil) * dlamchE
78
79                 u := nanGeneral(m, m, ldu)
80                 v := nanGeneral(p, p, ldv)
81                 q := nanGeneral(n, n, ldq)
82
83                 iwork := make([]int, n)
84                 tau := make([]float64, n)
85
86                 work := []float64{0}
87                 impl.Dggsvp3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ,
88                         m, p, n,
89                         a.Data, a.Stride,
90                         b.Data, b.Stride,
91                         tola, tolb,
92                         u.Data, u.Stride,
93                         v.Data, v.Stride,
94                         q.Data, q.Stride,
95                         iwork, tau,
96                         work, -1)
97
98                 lwork := int(work[0])
99                 work = make([]float64, lwork)
100
101                 k, l := impl.Dggsvp3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ,
102                         m, p, n,
103                         a.Data, a.Stride,
104                         b.Data, b.Stride,
105                         tola, tolb,
106                         u.Data, u.Stride,
107                         v.Data, v.Stride,
108                         q.Data, q.Stride,
109                         iwork, tau,
110                         work, lwork)
111
112                 // Check orthogonality of U, V and Q.
113                 if !isOrthonormal(u) {
114                         t.Errorf("test %d: U is not orthogonal\n%+v", cas, u)
115                 }
116                 if !isOrthonormal(v) {
117                         t.Errorf("test %d: V is not orthogonal\n%+v", cas, v)
118                 }
119                 if !isOrthonormal(q) {
120                         t.Errorf("test %d: Q is not orthogonal\n%+v", cas, q)
121                 }
122
123                 zeroA, zeroB := constructGSVPresults(n, p, m, k, l, a, b)
124
125                 // Check U^T*A*Q = [ 0 RA ].
126                 uTmp := nanGeneral(m, n, n)
127                 blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, aCopy, 0, uTmp)
128                 uAns := nanGeneral(m, n, n)
129                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uTmp, q, 0, uAns)
130
131                 if !equalApproxGeneral(uAns, zeroA, 1e-14) {
132                         t.Errorf("test %d: U^T*A*Q != [ 0 RA ]\nU^T*A*Q:\n%+v\n[ 0 RA ]:\n%+v",
133                                 cas, uAns, zeroA)
134                 }
135
136                 // Check V^T*B*Q = [ 0 RB ].
137                 vTmp := nanGeneral(p, n, n)
138                 blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, bCopy, 0, vTmp)
139                 vAns := nanGeneral(p, n, n)
140                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vTmp, q, 0, vAns)
141
142                 if !equalApproxGeneral(vAns, zeroB, 1e-14) {
143                         t.Errorf("test %d: V^T*B*Q != [ 0 RB ]\nV^T*B*Q:\n%+v\n[ 0 RB ]:\n%+v",
144                                 cas, vAns, zeroB)
145                 }
146         }
147 }