OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dlaqr5.go
1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "compress/gzip"
9         "encoding/json"
10         "fmt"
11         "log"
12         "math"
13         "os"
14         "path/filepath"
15         "testing"
16
17         "golang.org/x/exp/rand"
18
19         "gonum.org/v1/gonum/blas"
20         "gonum.org/v1/gonum/blas/blas64"
21 )
22
23 type Dlaqr5er interface {
24         Dlaqr5(wantt, wantz bool, kacc22 int, n, ktop, kbot, nshfts int, sr, si []float64, h []float64, ldh int, iloz, ihiz int, z []float64, ldz int, v []float64, ldv int, u []float64, ldu int, nh int, wh []float64, ldwh int, nv int, wv []float64, ldwv int)
25 }
26
27 type Dlaqr5test struct {
28         WantT          bool
29         N              int
30         NShifts        int
31         KTop, KBot     int
32         ShiftR, ShiftI []float64
33         H              []float64
34
35         HWant []float64
36         ZWant []float64
37 }
38
39 func Dlaqr5Test(t *testing.T, impl Dlaqr5er) {
40         // Test without using reference data.
41         rnd := rand.New(rand.NewSource(1))
42         for _, n := range []int{1, 2, 3, 4, 5, 6, 10, 30} {
43                 for _, extra := range []int{0, 1, 20} {
44                         for _, kacc22 := range []int{0, 1, 2} {
45                                 for cas := 0; cas < 100; cas++ {
46                                         testDlaqr5(t, impl, n, extra, kacc22, rnd)
47                                 }
48                         }
49                 }
50         }
51
52         // Test using reference data computed by the reference netlib
53         // implementation.
54         file, err := os.Open(filepath.FromSlash("../testlapack/testdata/dlaqr5data.json.gz"))
55         if err != nil {
56                 log.Fatal(err)
57         }
58         defer file.Close()
59         r, err := gzip.NewReader(file)
60         if err != nil {
61                 log.Fatal(err)
62         }
63         defer r.Close()
64
65         var tests []Dlaqr5test
66         json.NewDecoder(r).Decode(&tests)
67         for _, test := range tests {
68                 wantt := test.WantT
69                 n := test.N
70                 nshfts := test.NShifts
71                 ktop := test.KTop
72                 kbot := test.KBot
73                 sr := test.ShiftR
74                 si := test.ShiftI
75
76                 for _, extra := range []int{0, 1, 10} {
77                         v := randomGeneral(nshfts/2, 3, 3+extra, rnd)
78                         u := randomGeneral(3*nshfts-3, 3*nshfts-3, 3*nshfts-3+extra, rnd)
79                         nh := n
80                         wh := randomGeneral(3*nshfts-3, n, n+extra, rnd)
81                         nv := n
82                         wv := randomGeneral(n, 3*nshfts-3, 3*nshfts-3+extra, rnd)
83
84                         h := nanGeneral(n, n, n+extra)
85
86                         for _, kacc22 := range []int{0, 1, 2} {
87                                 copyMatrix(n, n, h.Data, h.Stride, test.H)
88                                 z := eye(n, n+extra)
89
90                                 impl.Dlaqr5(wantt, true, kacc22,
91                                         n, ktop, kbot,
92                                         nshfts, sr, si,
93                                         h.Data, h.Stride,
94                                         0, n-1, z.Data, z.Stride,
95                                         v.Data, v.Stride,
96                                         u.Data, u.Stride,
97                                         nv, wv.Data, wv.Stride,
98                                         nh, wh.Data, wh.Stride)
99
100                                 prefix := fmt.Sprintf("wantt=%v, n=%v, nshfts=%v, ktop=%v, kbot=%v, extra=%v, kacc22=%v",
101                                         wantt, n, nshfts, ktop, kbot, extra, kacc22)
102                                 if !equalApprox(n, n, h.Data, h.Stride, test.HWant, 1e-13) {
103                                         t.Errorf("Case %v: unexpected matrix H\nh    =%v\nhwant=%v", prefix, h.Data, test.HWant)
104                                 }
105                                 if !equalApprox(n, n, z.Data, z.Stride, test.ZWant, 1e-13) {
106                                         t.Errorf("Case %v: unexpected matrix Z\nz    =%v\nzwant=%v", prefix, z.Data, test.ZWant)
107                                 }
108                         }
109                 }
110         }
111 }
112
113 func testDlaqr5(t *testing.T, impl Dlaqr5er, n, extra, kacc22 int, rnd *rand.Rand) {
114         wantt := true
115         wantz := true
116         nshfts := 2 * n
117         sr := make([]float64, nshfts)
118         si := make([]float64, nshfts)
119         for i := 0; i < n; i++ {
120                 re := rnd.NormFloat64()
121                 im := rnd.NormFloat64()
122                 sr[2*i], sr[2*i+1] = re, re
123                 si[2*i], si[2*i+1] = im, -im
124         }
125         ktop := rnd.Intn(n)
126         kbot := rnd.Intn(n)
127         if kbot < ktop {
128                 ktop, kbot = kbot, ktop
129         }
130
131         v := randomGeneral(nshfts/2, 3, 3+extra, rnd)
132         u := randomGeneral(3*nshfts-3, 3*nshfts-3, 3*nshfts-3+extra, rnd)
133         nh := n
134         wh := randomGeneral(3*nshfts-3, n, n+extra, rnd)
135         nv := n
136         wv := randomGeneral(n, 3*nshfts-3, 3*nshfts-3+extra, rnd)
137
138         h := randomHessenberg(n, n+extra, rnd)
139         if ktop > 0 {
140                 h.Data[ktop*h.Stride+ktop-1] = 0
141         }
142         if kbot < n-1 {
143                 h.Data[(kbot+1)*h.Stride+kbot] = 0
144         }
145         hCopy := h
146         hCopy.Data = make([]float64, len(h.Data))
147         copy(hCopy.Data, h.Data)
148
149         z := eye(n, n+extra)
150
151         impl.Dlaqr5(wantt, wantz, kacc22,
152                 n, ktop, kbot,
153                 nshfts, sr, si,
154                 h.Data, h.Stride,
155                 0, n-1, z.Data, z.Stride,
156                 v.Data, v.Stride,
157                 u.Data, u.Stride,
158                 nv, wv.Data, wv.Stride,
159                 nh, wh.Data, wh.Stride)
160
161         prefix := fmt.Sprintf("Case n=%v, extra=%v, kacc22=%v", n, extra, kacc22)
162
163         if !generalOutsideAllNaN(h) {
164                 t.Errorf("%v: out-of-range write to H\n%v", prefix, h.Data)
165         }
166         if !generalOutsideAllNaN(z) {
167                 t.Errorf("%v: out-of-range write to Z\n%v", prefix, z.Data)
168         }
169         if !generalOutsideAllNaN(u) {
170                 t.Errorf("%v: out-of-range write to U\n%v", prefix, u.Data)
171         }
172         if !generalOutsideAllNaN(v) {
173                 t.Errorf("%v: out-of-range write to V\n%v", prefix, v.Data)
174         }
175         if !generalOutsideAllNaN(wh) {
176                 t.Errorf("%v: out-of-range write to WH\n%v", prefix, wh.Data)
177         }
178         if !generalOutsideAllNaN(wv) {
179                 t.Errorf("%v: out-of-range write to WV\n%v", prefix, wv.Data)
180         }
181
182         for i := 0; i < n; i++ {
183                 for j := 0; j < i-1; j++ {
184                         if h.Data[i*h.Stride+j] != 0 {
185                                 t.Errorf("%v: H is not Hessenberg, H[%v,%v]!=0", prefix, i, j)
186                         }
187                 }
188         }
189         if !isOrthonormal(z) {
190                 t.Errorf("%v: Z is not orthogonal", prefix)
191         }
192         // Construct Z^T * HOrig * Z and check that it is equal to H from Dlaqr5.
193         hz := blas64.General{
194                 Rows:   n,
195                 Cols:   n,
196                 Stride: n,
197                 Data:   make([]float64, n*n),
198         }
199         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hCopy, z, 0, hz)
200         zhz := blas64.General{
201                 Rows:   n,
202                 Cols:   n,
203                 Stride: n,
204                 Data:   make([]float64, n*n),
205         }
206         blas64.Gemm(blas.Trans, blas.NoTrans, 1, z, hz, 0, zhz)
207         for i := 0; i < n; i++ {
208                 for j := 0; j < n; j++ {
209                         diff := zhz.Data[i*zhz.Stride+j] - h.Data[i*h.Stride+j]
210                         if math.Abs(diff) > 1e-13 {
211                                 t.Errorf("%v: Z^T*HOrig*Z and H are not equal, diff at [%v,%v]=%v", prefix, i, j, diff)
212                         }
213                 }
214         }
215 }