OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dbdsqr.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "fmt"
9         "sort"
10         "testing"
11
12         "golang.org/x/exp/rand"
13
14         "gonum.org/v1/gonum/blas"
15         "gonum.org/v1/gonum/blas/blas64"
16         "gonum.org/v1/gonum/floats"
17 )
18
19 type Dbdsqrer interface {
20         Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, vt []float64, ldvt int, u []float64, ldu int, c []float64, ldc int, work []float64) (ok bool)
21 }
22
23 func DbdsqrTest(t *testing.T, impl Dbdsqrer) {
24         rnd := rand.New(rand.NewSource(1))
25         bi := blas64.Implementation()
26         _ = bi
27         for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
28                 for _, test := range []struct {
29                         n, ncvt, nru, ncc, ldvt, ldu, ldc int
30                 }{
31                         {5, 5, 5, 5, 0, 0, 0},
32                         {10, 10, 10, 10, 0, 0, 0},
33                         {10, 11, 12, 13, 0, 0, 0},
34                         {20, 13, 12, 11, 0, 0, 0},
35
36                         {5, 5, 5, 5, 6, 7, 8},
37                         {10, 10, 10, 10, 30, 40, 50},
38                         {10, 12, 11, 13, 30, 40, 50},
39                         {20, 12, 13, 11, 30, 40, 50},
40
41                         {130, 130, 130, 500, 900, 900, 500},
42                 } {
43                         for cas := 0; cas < 10; cas++ {
44                                 n := test.n
45                                 ncvt := test.ncvt
46                                 nru := test.nru
47                                 ncc := test.ncc
48                                 ldvt := test.ldvt
49                                 ldu := test.ldu
50                                 ldc := test.ldc
51                                 if ldvt == 0 {
52                                         ldvt = ncvt
53                                 }
54                                 if ldu == 0 {
55                                         ldu = n
56                                 }
57                                 if ldc == 0 {
58                                         ldc = ncc
59                                 }
60
61                                 d := make([]float64, n)
62                                 for i := range d {
63                                         d[i] = rnd.NormFloat64()
64                                 }
65                                 e := make([]float64, n-1)
66                                 for i := range e {
67                                         e[i] = rnd.NormFloat64()
68                                 }
69                                 dCopy := make([]float64, len(d))
70                                 copy(dCopy, d)
71                                 eCopy := make([]float64, len(e))
72                                 copy(eCopy, e)
73                                 work := make([]float64, 4*n)
74                                 for i := range work {
75                                         work[i] = rnd.NormFloat64()
76                                 }
77
78                                 // First test the decomposition of the bidiagonal matrix. Set
79                                 // pt and u equal to I with the correct size. At the result
80                                 // of Dbdsqr, p and u  will contain the data of P^T and Q, which
81                                 // will be used in the next step to test the multiplication
82                                 // with Q and VT.
83
84                                 q := make([]float64, n*n)
85                                 ldq := n
86                                 pt := make([]float64, n*n)
87                                 ldpt := n
88                                 for i := 0; i < n; i++ {
89                                         q[i*ldq+i] = 1
90                                 }
91                                 for i := 0; i < n; i++ {
92                                         pt[i*ldpt+i] = 1
93                                 }
94
95                                 ok := impl.Dbdsqr(uplo, n, n, n, 0, d, e, pt, ldpt, q, ldq, nil, 0, work)
96
97                                 isUpper := uplo == blas.Upper
98                                 errStr := fmt.Sprintf("isUpper = %v, n = %v, ncvt = %v, nru = %v, ncc = %v", isUpper, n, ncvt, nru, ncc)
99                                 if !ok {
100                                         t.Errorf("Unexpected Dbdsqr failure: %s", errStr)
101                                 }
102
103                                 bMat := constructBidiagonal(uplo, n, dCopy, eCopy)
104                                 sMat := constructBidiagonal(uplo, n, d, e)
105
106                                 tmp := blas64.General{
107                                         Rows:   n,
108                                         Cols:   n,
109                                         Stride: n,
110                                         Data:   make([]float64, n*n),
111                                 }
112                                 ansMat := blas64.General{
113                                         Rows:   n,
114                                         Cols:   n,
115                                         Stride: n,
116                                         Data:   make([]float64, n*n),
117                                 }
118
119                                 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, q, ldq, sMat.Data, sMat.Stride, 0, tmp.Data, tmp.Stride)
120                                 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, tmp.Data, tmp.Stride, pt, ldpt, 0, ansMat.Data, ansMat.Stride)
121
122                                 same := true
123                                 for i := 0; i < n; i++ {
124                                         for j := 0; j < n; j++ {
125                                                 if !floats.EqualWithinAbsOrRel(ansMat.Data[i*ansMat.Stride+j], bMat.Data[i*bMat.Stride+j], 1e-8, 1e-8) {
126                                                         same = false
127                                                 }
128                                         }
129                                 }
130                                 if !same {
131                                         t.Errorf("Bidiagonal mismatch. %s", errStr)
132                                 }
133                                 if !sort.IsSorted(sort.Reverse(sort.Float64Slice(d))) {
134                                         t.Errorf("D is not sorted. %s", errStr)
135                                 }
136
137                                 // The above computed the real P and Q. Now input data for V^T,
138                                 // U, and C to check that the multiplications happen properly.
139                                 dAns := make([]float64, len(d))
140                                 copy(dAns, d)
141                                 eAns := make([]float64, len(e))
142                                 copy(eAns, e)
143
144                                 u := make([]float64, nru*ldu)
145                                 for i := range u {
146                                         u[i] = rnd.NormFloat64()
147                                 }
148                                 uCopy := make([]float64, len(u))
149                                 copy(uCopy, u)
150                                 vt := make([]float64, n*ldvt)
151                                 for i := range vt {
152                                         vt[i] = rnd.NormFloat64()
153                                 }
154                                 vtCopy := make([]float64, len(vt))
155                                 copy(vtCopy, vt)
156                                 c := make([]float64, n*ldc)
157                                 for i := range c {
158                                         c[i] = rnd.NormFloat64()
159                                 }
160                                 cCopy := make([]float64, len(c))
161                                 copy(cCopy, c)
162
163                                 // Reset input data
164                                 copy(d, dCopy)
165                                 copy(e, eCopy)
166                                 impl.Dbdsqr(uplo, n, ncvt, nru, ncc, d, e, vt, ldvt, u, ldu, c, ldc, work)
167
168                                 // Check result.
169                                 if !floats.EqualApprox(d, dAns, 1e-14) {
170                                         t.Errorf("D mismatch second time. %s", errStr)
171                                 }
172                                 if !floats.EqualApprox(e, eAns, 1e-14) {
173                                         t.Errorf("E mismatch second time. %s", errStr)
174                                 }
175                                 ans := make([]float64, len(vtCopy))
176                                 copy(ans, vtCopy)
177                                 ldans := ldvt
178                                 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, ncvt, n, 1, pt, ldpt, vtCopy, ldvt, 0, ans, ldans)
179                                 if !floats.EqualApprox(ans, vt, 1e-10) {
180                                         t.Errorf("Vt result mismatch. %s", errStr)
181                                 }
182                                 ans = make([]float64, len(uCopy))
183                                 copy(ans, uCopy)
184                                 ldans = ldu
185                                 bi.Dgemm(blas.NoTrans, blas.NoTrans, nru, n, n, 1, uCopy, ldu, q, ldq, 0, ans, ldans)
186                                 if !floats.EqualApprox(ans, u, 1e-10) {
187                                         t.Errorf("U result mismatch. %s", errStr)
188                                 }
189                                 ans = make([]float64, len(cCopy))
190                                 copy(ans, cCopy)
191                                 ldans = ldc
192                                 bi.Dgemm(blas.Trans, blas.NoTrans, n, ncc, n, 1, q, ldq, cCopy, ldc, 0, ans, ldans)
193                                 if !floats.EqualApprox(ans, c, 1e-10) {
194                                         t.Errorf("C result mismatch. %s", errStr)
195                                 }
196                         }
197                 }
198         }
199 }