OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dlatrd.go
1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "fmt"
9         "math"
10         "testing"
11
12         "golang.org/x/exp/rand"
13
14         "gonum.org/v1/gonum/blas"
15         "gonum.org/v1/gonum/blas/blas64"
16 )
17
18 type Dlatrder interface {
19         Dlatrd(uplo blas.Uplo, n, nb int, a []float64, lda int, e, tau, w []float64, ldw int)
20 }
21
22 func DlatrdTest(t *testing.T, impl Dlatrder) {
23         rnd := rand.New(rand.NewSource(1))
24         for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
25                 for _, test := range []struct {
26                         n, nb, lda, ldw int
27                 }{
28                         {5, 2, 0, 0},
29                         {5, 5, 0, 0},
30
31                         {5, 3, 10, 11},
32                         {5, 5, 10, 11},
33                 } {
34                         n := test.n
35                         nb := test.nb
36                         lda := test.lda
37                         if lda == 0 {
38                                 lda = n
39                         }
40                         ldw := test.ldw
41                         if ldw == 0 {
42                                 ldw = nb
43                         }
44
45                         a := make([]float64, n*lda)
46                         for i := range a {
47                                 a[i] = rnd.NormFloat64()
48                         }
49
50                         e := make([]float64, n-1)
51                         for i := range e {
52                                 e[i] = math.NaN()
53                         }
54                         tau := make([]float64, n-1)
55                         for i := range tau {
56                                 tau[i] = math.NaN()
57                         }
58                         w := make([]float64, n*ldw)
59                         for i := range w {
60                                 w[i] = math.NaN()
61                         }
62
63                         aCopy := make([]float64, len(a))
64                         copy(aCopy, a)
65
66                         impl.Dlatrd(uplo, n, nb, a, lda, e, tau, w, ldw)
67
68                         // Construct Q.
69                         ldq := n
70                         q := blas64.General{
71                                 Rows:   n,
72                                 Cols:   n,
73                                 Stride: ldq,
74                                 Data:   make([]float64, n*ldq),
75                         }
76                         for i := 0; i < n; i++ {
77                                 q.Data[i*ldq+i] = 1
78                         }
79                         if uplo == blas.Upper {
80                                 for i := n - 1; i >= n-nb; i-- {
81                                         if i == 0 {
82                                                 continue
83                                         }
84                                         h := blas64.General{
85                                                 Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
86                                         }
87                                         for j := 0; j < n; j++ {
88                                                 h.Data[j*n+j] = 1
89                                         }
90                                         v := blas64.Vector{
91                                                 Inc:  1,
92                                                 Data: make([]float64, n),
93                                         }
94                                         for j := 0; j < i-1; j++ {
95                                                 v.Data[j] = a[j*lda+i]
96                                         }
97                                         v.Data[i-1] = 1
98
99                                         blas64.Ger(-tau[i-1], v, v, h)
100
101                                         qTmp := blas64.General{
102                                                 Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
103                                         }
104                                         copy(qTmp.Data, q.Data)
105                                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qTmp, h, 0, q)
106                                 }
107                         } else {
108                                 for i := 0; i < nb; i++ {
109                                         if i == n-1 {
110                                                 continue
111                                         }
112                                         h := blas64.General{
113                                                 Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
114                                         }
115                                         for j := 0; j < n; j++ {
116                                                 h.Data[j*n+j] = 1
117                                         }
118                                         v := blas64.Vector{
119                                                 Inc:  1,
120                                                 Data: make([]float64, n),
121                                         }
122                                         v.Data[i+1] = 1
123                                         for j := i + 2; j < n; j++ {
124                                                 v.Data[j] = a[j*lda+i]
125                                         }
126                                         blas64.Ger(-tau[i], v, v, h)
127
128                                         qTmp := blas64.General{
129                                                 Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
130                                         }
131                                         copy(qTmp.Data, q.Data)
132                                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qTmp, h, 0, q)
133                                 }
134                         }
135                         errStr := fmt.Sprintf("isUpper = %v, n = %v, nb = %v", uplo == blas.Upper, n, nb)
136                         if !isOrthonormal(q) {
137                                 t.Errorf("Q not orthonormal. %s", errStr)
138                         }
139                         aGen := genFromSym(blas64.Symmetric{N: n, Stride: lda, Uplo: uplo, Data: aCopy})
140                         if !dlatrdCheckDecomposition(t, uplo, n, nb, e, tau, a, lda, aGen, q) {
141                                 t.Errorf("Decomposition mismatch. %s", errStr)
142                         }
143                 }
144         }
145 }
146
147 // dlatrdCheckDecomposition checks that the first nb rows have been successfully
148 // reduced.
149 func dlatrdCheckDecomposition(t *testing.T, uplo blas.Uplo, n, nb int, e, tau, a []float64, lda int, aGen, q blas64.General) bool {
150         // Compute Q^T * A * Q.
151         tmp := blas64.General{
152                 Rows:   n,
153                 Cols:   n,
154                 Stride: n,
155                 Data:   make([]float64, n*n),
156         }
157
158         ans := blas64.General{
159                 Rows:   n,
160                 Cols:   n,
161                 Stride: n,
162                 Data:   make([]float64, n*n),
163         }
164
165         blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aGen, 0, tmp)
166         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, ans)
167
168         // Compare with T.
169         if uplo == blas.Upper {
170                 for i := n - 1; i >= n-nb; i-- {
171                         for j := 0; j < n; j++ {
172                                 v := ans.Data[i*ans.Stride+j]
173                                 switch {
174                                 case i == j:
175                                         if math.Abs(v-a[i*lda+j]) > 1e-10 {
176                                                 return false
177                                         }
178                                 case i == j-1:
179                                         if math.Abs(a[i*lda+j]-1) > 1e-10 {
180                                                 return false
181                                         }
182                                         if math.Abs(v-e[i]) > 1e-10 {
183                                                 return false
184                                         }
185                                 case i == j+1:
186                                 default:
187                                         if math.Abs(v) > 1e-10 {
188                                                 return false
189                                         }
190                                 }
191                         }
192                 }
193         } else {
194                 for i := 0; i < nb; i++ {
195                         for j := 0; j < n; j++ {
196                                 v := ans.Data[i*ans.Stride+j]
197                                 switch {
198                                 case i == j:
199                                         if math.Abs(v-a[i*lda+j]) > 1e-10 {
200                                                 return false
201                                         }
202                                 case i == j-1:
203                                 case i == j+1:
204                                         if math.Abs(a[i*lda+j]-1) > 1e-10 {
205                                                 return false
206                                         }
207                                         if math.Abs(v-e[i-1]) > 1e-10 {
208                                                 return false
209                                         }
210                                 default:
211                                         if math.Abs(v) > 1e-10 {
212                                                 return false
213                                         }
214                                 }
215                         }
216                 }
217         }
218         return true
219 }
220
221 // genFromSym constructs a (symmetric) general matrix from the data in the
222 // symmetric.
223 // TODO(btracey): Replace other constructions of this with a call to this function.
224 func genFromSym(a blas64.Symmetric) blas64.General {
225         n := a.N
226         lda := a.Stride
227         uplo := a.Uplo
228         b := blas64.General{
229                 Rows:   n,
230                 Cols:   n,
231                 Stride: n,
232                 Data:   make([]float64, n*n),
233         }
234
235         for i := 0; i < n; i++ {
236                 for j := i; j < n; j++ {
237                         v := a.Data[i*lda+j]
238                         if uplo == blas.Lower {
239                                 v = a.Data[j*lda+i]
240                         }
241                         b.Data[i*n+j] = v
242                         b.Data[j*n+i] = v
243                 }
244         }
245         return b
246 }