OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dsytrd.go
1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "fmt"
9         "testing"
10
11         "golang.org/x/exp/rand"
12
13         "gonum.org/v1/gonum/blas"
14         "gonum.org/v1/gonum/blas/blas64"
15 )
16
17 type Dsytrder interface {
18         Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau, work []float64, lwork int)
19
20         Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
21         Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
22 }
23
24 func DsytrdTest(t *testing.T, impl Dsytrder) {
25         const tol = 1e-13
26         rnd := rand.New(rand.NewSource(1))
27         for tc, test := range []struct {
28                 n, lda int
29         }{
30                 {1, 0},
31                 {2, 0},
32                 {3, 0},
33                 {4, 0},
34                 {10, 0},
35                 {50, 0},
36                 {100, 0},
37                 {150, 0},
38                 {300, 0},
39
40                 {1, 3},
41                 {2, 3},
42                 {3, 7},
43                 {4, 9},
44                 {10, 20},
45                 {50, 70},
46                 {100, 120},
47                 {150, 170},
48                 {300, 320},
49         } {
50                 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
51                         for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} {
52                                 n := test.n
53                                 lda := test.lda
54                                 if lda == 0 {
55                                         lda = n
56                                 }
57                                 a := randomGeneral(n, n, lda, rnd)
58                                 for i := 1; i < n; i++ {
59                                         for j := 0; j < i; j++ {
60                                                 a.Data[i*a.Stride+j] = a.Data[j*a.Stride+i]
61                                         }
62                                 }
63                                 aCopy := cloneGeneral(a)
64
65                                 d := nanSlice(n)
66                                 e := nanSlice(n - 1)
67                                 tau := nanSlice(n - 1)
68
69                                 var lwork int
70                                 switch wl {
71                                 case minimumWork:
72                                         lwork = 1
73                                 case mediumWork:
74                                         work := make([]float64, 1)
75                                         impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1)
76                                         lwork = (int(work[0]) + 1) / 2
77                                         lwork = max(1, lwork)
78                                 case optimumWork:
79                                         work := make([]float64, 1)
80                                         impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1)
81                                         lwork = int(work[0])
82                                 }
83                                 work := make([]float64, lwork)
84
85                                 impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, lwork)
86
87                                 prefix := fmt.Sprintf("Case #%v: uplo=%v,n=%v,lda=%v,work=%v",
88                                         tc, uplo, n, lda, wl)
89
90                                 if !generalOutsideAllNaN(a) {
91                                         t.Errorf("%v: out-of-range write to A", prefix)
92                                 }
93
94                                 // Extract Q by doing what Dorgtr does.
95                                 q := cloneGeneral(a)
96                                 if uplo == blas.Upper {
97                                         for j := 0; j < n-1; j++ {
98                                                 for i := 0; i < j; i++ {
99                                                         q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j+1]
100                                                 }
101                                                 q.Data[(n-1)*q.Stride+j] = 0
102                                         }
103                                         for i := 0; i < n-1; i++ {
104                                                 q.Data[i*q.Stride+n-1] = 0
105                                         }
106                                         q.Data[(n-1)*q.Stride+n-1] = 1
107                                         if n > 1 {
108                                                 work = make([]float64, n-1)
109                                                 impl.Dorgql(n-1, n-1, n-1, q.Data, q.Stride, tau, work, len(work))
110                                         }
111                                 } else {
112                                         for j := n - 1; j > 0; j-- {
113                                                 q.Data[j] = 0
114                                                 for i := j + 1; i < n; i++ {
115                                                         q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j-1]
116                                                 }
117                                         }
118                                         q.Data[0] = 1
119                                         for i := 1; i < n; i++ {
120                                                 q.Data[i*q.Stride] = 0
121                                         }
122                                         if n > 1 {
123                                                 work = make([]float64, n-1)
124                                                 impl.Dorgqr(n-1, n-1, n-1, q.Data[q.Stride+1:], q.Stride, tau, work, len(work))
125                                         }
126                                 }
127                                 if !isOrthonormal(q) {
128                                         t.Errorf("%v: Q not orthogonal", prefix)
129                                 }
130
131                                 // Contruct symmetric tridiagonal T from d and e.
132                                 tMat := zeros(n, n, n)
133                                 for i := 0; i < n; i++ {
134                                         tMat.Data[i*tMat.Stride+i] = d[i]
135                                 }
136                                 if uplo == blas.Upper {
137                                         for j := 1; j < n; j++ {
138                                                 tMat.Data[(j-1)*tMat.Stride+j] = e[j-1]
139                                                 tMat.Data[j*tMat.Stride+j-1] = e[j-1]
140                                         }
141                                 } else {
142                                         for j := 0; j < n-1; j++ {
143                                                 tMat.Data[(j+1)*tMat.Stride+j] = e[j]
144                                                 tMat.Data[j*tMat.Stride+j+1] = e[j]
145                                         }
146                                 }
147
148                                 // Compute Q^T * A * Q.
149                                 tmp := zeros(n, n, n)
150                                 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aCopy, 0, tmp)
151                                 got := zeros(n, n, n)
152                                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, got)
153
154                                 // Compare with T.
155                                 if !equalApproxGeneral(got, tMat, tol) {
156                                         t.Errorf("%v: Q^T*A*Q != T", prefix)
157                                 }
158                         }
159                 }
160         }
161 }