1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
10 "golang.org/x/exp/rand"
12 "gonum.org/v1/gonum/blas"
13 "gonum.org/v1/gonum/blas/blas64"
14 "gonum.org/v1/gonum/floats"
15 "gonum.org/v1/gonum/lapack"
18 type Dsteqrer interface {
19 Dsteqr(compz lapack.EVComp, n int, d, e, z []float64, ldz int, work []float64) (ok bool)
23 func DsteqrTest(t *testing.T, impl Dsteqrer) {
24 rnd := rand.New(rand.NewSource(1))
25 for _, compz := range []lapack.EVComp{lapack.OriginalEV, lapack.TridiagEV} {
26 for _, test := range []struct {
38 for cas := 0; cas < 100; cas++ {
44 d := make([]float64, n)
48 e := make([]float64, n-1)
52 a := make([]float64, n*lda)
56 dCopy := make([]float64, len(d))
58 eCopy := make([]float64, len(e))
60 aCopy := make([]float64, len(a))
62 if compz == lapack.OriginalEV {
63 // Compute triangular decomposition and orthonormal matrix.
65 tau := make([]float64, n)
66 work := make([]float64, 1)
67 impl.Dsytrd(blas.Upper, n, a, lda, d, e, tau, work, -1)
68 work = make([]float64, int(work[0]))
69 impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, len(work))
70 impl.Dorgtr(uplo, n, a, lda, tau, work, len(work))
72 for i := 0; i < n; i++ {
73 for j := 0; j < n; j++ {
81 work := make([]float64, 2*n)
83 aDecomp := make([]float64, len(a))
85 dDecomp := make([]float64, len(d))
87 eDecomp := make([]float64, len(e))
89 impl.Dsteqr(compz, n, d, e, a, lda, work)
90 dAns := make([]float64, len(d))
93 var truth blas64.General
94 if compz == lapack.OriginalEV {
95 truth = blas64.General{
99 Data: make([]float64, n*n),
101 for i := 0; i < n; i++ {
102 for j := i; j < n; j++ {
104 truth.Data[i*truth.Stride+j] = v
105 truth.Data[j*truth.Stride+i] = v
109 truth = blas64.General{
113 Data: make([]float64, n*n),
115 for i := 0; i < n; i++ {
116 truth.Data[i*truth.Stride+i] = dCopy[i]
118 truth.Data[(i+1)*truth.Stride+i] = eCopy[i]
119 truth.Data[i*truth.Stride+i+1] = eCopy[i]
130 if !eigenDecompCorrect(d, truth, V) {
131 t.Errorf("Eigen reconstruction mismatch. fromFull = %v, n = %v",
132 compz == lapack.OriginalEV, n)
135 // Compare eigenvalues when not computing eigenvectors.
136 for i := range work {
137 work[i] = rnd.Float64()
139 impl.Dsteqr(lapack.None, n, dDecomp, eDecomp, aDecomp, lda, work)
140 if !floats.EqualApprox(d, dAns, 1e-8) {
141 t.Errorf("Eigenvalue mismatch when eigenvectors not computed")
148 // eigenDecompCorrect returns whether the eigen decomposition is correct.
151 // where the eigenvalues λ are stored in values, and the eigenvectors are stored
152 // in the columns of v.
153 func eigenDecompCorrect(values []float64, A, V blas64.General) bool {
155 for i := 0; i < n; i++ {
157 vector := make([]float64, n)
158 ans2 := make([]float64, n)
159 for j := range vector {
160 v := V.Data[j*V.Stride+i]
164 v := blas64.Vector{Inc: 1, Data: vector}
165 ans1 := blas64.Vector{Inc: 1, Data: make([]float64, n)}
166 blas64.Gemv(blas.NoTrans, 1, A, v, 0, ans1)
167 if !floats.EqualApprox(ans1.Data, ans2, 1e-8) {