OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dsytd2.go
1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "math"
9         "testing"
10
11         "golang.org/x/exp/rand"
12
13         "gonum.org/v1/gonum/blas"
14         "gonum.org/v1/gonum/blas/blas64"
15 )
16
17 type Dsytd2er interface {
18         Dsytd2(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau []float64)
19 }
20
21 func Dsytd2Test(t *testing.T, impl Dsytd2er) {
22         rnd := rand.New(rand.NewSource(1))
23         for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
24                 for _, test := range []struct {
25                         n, lda int
26                 }{
27                         {3, 0},
28                         {4, 0},
29                         {5, 0},
30
31                         {3, 10},
32                         {4, 10},
33                         {5, 10},
34                 } {
35                         n := test.n
36                         lda := test.lda
37                         if lda == 0 {
38                                 lda = n
39                         }
40                         a := make([]float64, n*lda)
41                         for i := range a {
42                                 a[i] = rnd.NormFloat64()
43                         }
44                         aCopy := make([]float64, len(a))
45                         copy(aCopy, a)
46
47                         d := make([]float64, n)
48                         for i := range d {
49                                 d[i] = math.NaN()
50                         }
51                         e := make([]float64, n-1)
52                         for i := range e {
53                                 e[i] = math.NaN()
54                         }
55                         tau := make([]float64, n-1)
56                         for i := range tau {
57                                 tau[i] = math.NaN()
58                         }
59
60                         impl.Dsytd2(uplo, n, a, lda, d, e, tau)
61
62                         // Construct Q
63                         qMat := blas64.General{
64                                 Rows:   n,
65                                 Cols:   n,
66                                 Stride: n,
67                                 Data:   make([]float64, n*n),
68                         }
69                         qCopy := blas64.General{
70                                 Rows:   n,
71                                 Cols:   n,
72                                 Stride: n,
73                                 Data:   make([]float64, len(qMat.Data)),
74                         }
75                         // Set Q to I.
76                         for i := 0; i < n; i++ {
77                                 qMat.Data[i*qMat.Stride+i] = 1
78                         }
79                         for i := 0; i < n-1; i++ {
80                                 hMat := blas64.General{
81                                         Rows:   n,
82                                         Cols:   n,
83                                         Stride: n,
84                                         Data:   make([]float64, n*n),
85                                 }
86                                 // Set H to I.
87                                 for i := 0; i < n; i++ {
88                                         hMat.Data[i*hMat.Stride+i] = 1
89                                 }
90                                 var vi blas64.Vector
91                                 if uplo == blas.Upper {
92                                         vi = blas64.Vector{
93                                                 Inc:  1,
94                                                 Data: make([]float64, n),
95                                         }
96                                         for j := 0; j < i; j++ {
97                                                 vi.Data[j] = a[j*lda+i+1]
98                                         }
99                                         vi.Data[i] = 1
100                                 } else {
101                                         vi = blas64.Vector{
102                                                 Inc:  1,
103                                                 Data: make([]float64, n),
104                                         }
105                                         vi.Data[i+1] = 1
106                                         for j := i + 2; j < n; j++ {
107                                                 vi.Data[j] = a[j*lda+i]
108                                         }
109                                 }
110                                 blas64.Ger(-tau[i], vi, vi, hMat)
111                                 copy(qCopy.Data, qMat.Data)
112
113                                 // Multiply q by the new h.
114                                 if uplo == blas.Upper {
115                                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hMat, qCopy, 0, qMat)
116                                 } else {
117                                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, hMat, 0, qMat)
118                                 }
119                         }
120                         // Check that Q is orthonormal
121                         othonormal := true
122                         for i := 0; i < n; i++ {
123                                 for j := i; j < n; j++ {
124                                         dot := blas64.Dot(n,
125                                                 blas64.Vector{Inc: 1, Data: qMat.Data[i*qMat.Stride:]},
126                                                 blas64.Vector{Inc: 1, Data: qMat.Data[j*qMat.Stride:]},
127                                         )
128                                         if i == j {
129                                                 if math.Abs(dot-1) > 1e-10 {
130                                                         othonormal = false
131                                                 }
132                                         } else {
133                                                 if math.Abs(dot) > 1e-10 {
134                                                         othonormal = false
135                                                 }
136                                         }
137                                 }
138                         }
139                         if !othonormal {
140                                 t.Errorf("Q not orthonormal")
141                         }
142
143                         // Compute Q^T * A * Q.
144                         aMat := blas64.General{
145                                 Rows:   n,
146                                 Cols:   n,
147                                 Stride: n,
148                                 Data:   make([]float64, len(a)),
149                         }
150
151                         for i := 0; i < n; i++ {
152                                 for j := i; j < n; j++ {
153                                         v := aCopy[i*lda+j]
154                                         if uplo == blas.Lower {
155                                                 v = aCopy[j*lda+i]
156                                         }
157                                         aMat.Data[i*aMat.Stride+j] = v
158                                         aMat.Data[j*aMat.Stride+i] = v
159                                 }
160                         }
161
162                         tmp := blas64.General{
163                                 Rows:   n,
164                                 Cols:   n,
165                                 Stride: n,
166                                 Data:   make([]float64, n*n),
167                         }
168
169                         ans := blas64.General{
170                                 Rows:   n,
171                                 Cols:   n,
172                                 Stride: n,
173                                 Data:   make([]float64, n*n),
174                         }
175
176                         blas64.Gemm(blas.Trans, blas.NoTrans, 1, qMat, aMat, 0, tmp)
177                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, qMat, 0, ans)
178
179                         // Compare with T.
180                         tMat := blas64.General{
181                                 Rows:   n,
182                                 Cols:   n,
183                                 Stride: n,
184                                 Data:   make([]float64, n*n),
185                         }
186                         for i := 0; i < n-1; i++ {
187                                 tMat.Data[i*tMat.Stride+i] = d[i]
188                                 tMat.Data[i*tMat.Stride+i+1] = e[i]
189                                 tMat.Data[(i+1)*tMat.Stride+i] = e[i]
190                         }
191                         tMat.Data[(n-1)*tMat.Stride+n-1] = d[n-1]
192
193                         same := true
194                         for i := 0; i < n; i++ {
195                                 for j := 0; j < n; j++ {
196                                         if math.Abs(ans.Data[i*ans.Stride+j]-tMat.Data[i*tMat.Stride+j]) > 1e-10 {
197                                                 same = false
198                                         }
199                                 }
200                         }
201                         if !same {
202                                 t.Errorf("Matrix answer mismatch")
203                         }
204                 }
205         }
206 }