OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dsteqr.go
1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11
12         "gonum.org/v1/gonum/blas"
13         "gonum.org/v1/gonum/blas/blas64"
14         "gonum.org/v1/gonum/floats"
15         "gonum.org/v1/gonum/lapack"
16 )
17
18 type Dsteqrer interface {
19         Dsteqr(compz lapack.EVComp, n int, d, e, z []float64, ldz int, work []float64) (ok bool)
20         Dorgtrer
21 }
22
23 func DsteqrTest(t *testing.T, impl Dsteqrer) {
24         rnd := rand.New(rand.NewSource(1))
25         for _, compz := range []lapack.EVComp{lapack.OriginalEV, lapack.TridiagEV} {
26                 for _, test := range []struct {
27                         n, lda int
28                 }{
29                         {1, 0},
30                         {4, 0},
31                         {8, 0},
32                         {10, 0},
33
34                         {2, 10},
35                         {8, 10},
36                         {10, 20},
37                 } {
38                         for cas := 0; cas < 100; cas++ {
39                                 n := test.n
40                                 lda := test.lda
41                                 if lda == 0 {
42                                         lda = n
43                                 }
44                                 d := make([]float64, n)
45                                 for i := range d {
46                                         d[i] = rnd.Float64()
47                                 }
48                                 e := make([]float64, n-1)
49                                 for i := range e {
50                                         e[i] = rnd.Float64()
51                                 }
52                                 a := make([]float64, n*lda)
53                                 for i := range a {
54                                         a[i] = rnd.Float64()
55                                 }
56                                 dCopy := make([]float64, len(d))
57                                 copy(dCopy, d)
58                                 eCopy := make([]float64, len(e))
59                                 copy(eCopy, e)
60                                 aCopy := make([]float64, len(a))
61                                 copy(aCopy, a)
62                                 if compz == lapack.OriginalEV {
63                                         // Compute triangular decomposition and orthonormal matrix.
64                                         uplo := blas.Upper
65                                         tau := make([]float64, n)
66                                         work := make([]float64, 1)
67                                         impl.Dsytrd(blas.Upper, n, a, lda, d, e, tau, work, -1)
68                                         work = make([]float64, int(work[0]))
69                                         impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, len(work))
70                                         impl.Dorgtr(uplo, n, a, lda, tau, work, len(work))
71                                 } else {
72                                         for i := 0; i < n; i++ {
73                                                 for j := 0; j < n; j++ {
74                                                         a[i*lda+j] = 0
75                                                         if i == j {
76                                                                 a[i*lda+j] = 1
77                                                         }
78                                                 }
79                                         }
80                                 }
81                                 work := make([]float64, 2*n)
82
83                                 aDecomp := make([]float64, len(a))
84                                 copy(aDecomp, a)
85                                 dDecomp := make([]float64, len(d))
86                                 copy(dDecomp, d)
87                                 eDecomp := make([]float64, len(e))
88                                 copy(eDecomp, e)
89                                 impl.Dsteqr(compz, n, d, e, a, lda, work)
90                                 dAns := make([]float64, len(d))
91                                 copy(dAns, d)
92
93                                 var truth blas64.General
94                                 if compz == lapack.OriginalEV {
95                                         truth = blas64.General{
96                                                 Rows:   n,
97                                                 Cols:   n,
98                                                 Stride: n,
99                                                 Data:   make([]float64, n*n),
100                                         }
101                                         for i := 0; i < n; i++ {
102                                                 for j := i; j < n; j++ {
103                                                         v := aCopy[i*lda+j]
104                                                         truth.Data[i*truth.Stride+j] = v
105                                                         truth.Data[j*truth.Stride+i] = v
106                                                 }
107                                         }
108                                 } else {
109                                         truth = blas64.General{
110                                                 Rows:   n,
111                                                 Cols:   n,
112                                                 Stride: n,
113                                                 Data:   make([]float64, n*n),
114                                         }
115                                         for i := 0; i < n; i++ {
116                                                 truth.Data[i*truth.Stride+i] = dCopy[i]
117                                                 if i != n-1 {
118                                                         truth.Data[(i+1)*truth.Stride+i] = eCopy[i]
119                                                         truth.Data[i*truth.Stride+i+1] = eCopy[i]
120                                                 }
121                                         }
122                                 }
123
124                                 V := blas64.General{
125                                         Rows:   n,
126                                         Cols:   n,
127                                         Stride: lda,
128                                         Data:   a,
129                                 }
130                                 if !eigenDecompCorrect(d, truth, V) {
131                                         t.Errorf("Eigen reconstruction mismatch. fromFull = %v, n = %v",
132                                                 compz == lapack.OriginalEV, n)
133                                 }
134
135                                 // Compare eigenvalues when not computing eigenvectors.
136                                 for i := range work {
137                                         work[i] = rnd.Float64()
138                                 }
139                                 impl.Dsteqr(lapack.None, n, dDecomp, eDecomp, aDecomp, lda, work)
140                                 if !floats.EqualApprox(d, dAns, 1e-8) {
141                                         t.Errorf("Eigenvalue mismatch when eigenvectors not computed")
142                                 }
143                         }
144                 }
145         }
146 }
147
148 // eigenDecompCorrect returns whether the eigen decomposition is correct.
149 // It checks if
150 //  A * v ≈ λ * v
151 // where the eigenvalues λ are stored in values, and the eigenvectors are stored
152 // in the columns of v.
153 func eigenDecompCorrect(values []float64, A, V blas64.General) bool {
154         n := A.Rows
155         for i := 0; i < n; i++ {
156                 lambda := values[i]
157                 vector := make([]float64, n)
158                 ans2 := make([]float64, n)
159                 for j := range vector {
160                         v := V.Data[j*V.Stride+i]
161                         vector[j] = v
162                         ans2[j] = lambda * v
163                 }
164                 v := blas64.Vector{Inc: 1, Data: vector}
165                 ans1 := blas64.Vector{Inc: 1, Data: make([]float64, n)}
166                 blas64.Gemv(blas.NoTrans, 1, A, v, 0, ans1)
167                 if !floats.EqualApprox(ans1.Data, ans2, 1e-8) {
168                         return false
169                 }
170         }
171         return true
172 }