OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dsyev.go
1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11
12         "gonum.org/v1/gonum/blas"
13         "gonum.org/v1/gonum/blas/blas64"
14         "gonum.org/v1/gonum/floats"
15         "gonum.org/v1/gonum/lapack"
16 )
17
18 type Dsyever interface {
19         Dsyev(jobz lapack.EVJob, uplo blas.Uplo, n int, a []float64, lda int, w, work []float64, lwork int) (ok bool)
20 }
21
22 func DsyevTest(t *testing.T, impl Dsyever) {
23         rnd := rand.New(rand.NewSource(1))
24         for _, uplo := range []blas.Uplo{blas.Lower, blas.Upper} {
25                 for _, test := range []struct {
26                         n, lda int
27                 }{
28                         {1, 0},
29                         {2, 0},
30                         {5, 0},
31                         {10, 0},
32                         {100, 0},
33
34                         {1, 5},
35                         {2, 5},
36                         {5, 10},
37                         {10, 20},
38                         {100, 110},
39                 } {
40                         for cas := 0; cas < 10; cas++ {
41                                 n := test.n
42                                 lda := test.lda
43                                 if lda == 0 {
44                                         lda = n
45                                 }
46                                 a := make([]float64, n*lda)
47                                 for i := range a {
48                                         a[i] = rnd.NormFloat64()
49                                 }
50                                 aCopy := make([]float64, len(a))
51                                 copy(aCopy, a)
52                                 w := make([]float64, n)
53                                 for i := range w {
54                                         w[i] = rnd.NormFloat64()
55                                 }
56
57                                 work := make([]float64, 1)
58                                 impl.Dsyev(lapack.ComputeEV, uplo, n, a, lda, w, work, -1)
59                                 work = make([]float64, int(work[0]))
60                                 impl.Dsyev(lapack.ComputeEV, uplo, n, a, lda, w, work, len(work))
61
62                                 // Check that the decomposition is correct
63                                 orig := blas64.General{
64                                         Rows:   n,
65                                         Cols:   n,
66                                         Stride: n,
67                                         Data:   make([]float64, n*n),
68                                 }
69                                 if uplo == blas.Upper {
70                                         for i := 0; i < n; i++ {
71                                                 for j := i; j < n; j++ {
72                                                         v := aCopy[i*lda+j]
73                                                         orig.Data[i*orig.Stride+j] = v
74                                                         orig.Data[j*orig.Stride+i] = v
75                                                 }
76                                         }
77                                 } else {
78                                         for i := 0; i < n; i++ {
79                                                 for j := 0; j <= i; j++ {
80                                                         v := aCopy[i*lda+j]
81                                                         orig.Data[i*orig.Stride+j] = v
82                                                         orig.Data[j*orig.Stride+i] = v
83                                                 }
84                                         }
85                                 }
86
87                                 V := blas64.General{
88                                         Rows:   n,
89                                         Cols:   n,
90                                         Stride: lda,
91                                         Data:   a,
92                                 }
93
94                                 if !eigenDecompCorrect(w, orig, V) {
95                                         t.Errorf("Decomposition mismatch")
96                                 }
97
98                                 // Check that the decomposition is correct when the eigenvectors
99                                 // are not computed.
100                                 wAns := make([]float64, len(w))
101                                 copy(wAns, w)
102                                 copy(a, aCopy)
103                                 for i := range w {
104                                         w[i] = rnd.Float64()
105                                 }
106                                 for i := range work {
107                                         work[i] = rnd.Float64()
108                                 }
109                                 impl.Dsyev(lapack.None, uplo, n, a, lda, w, work, len(work))
110                                 if !floats.EqualApprox(w, wAns, 1e-8) {
111                                         t.Errorf("Eigenvalue mismatch when vectors not computed")
112                                 }
113                         }
114                 }
115         }
116 }