OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dlahr2.go
1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "compress/gzip"
9         "encoding/json"
10         "fmt"
11         "log"
12         "math"
13         "os"
14         "path/filepath"
15         "testing"
16
17         "golang.org/x/exp/rand"
18
19         "gonum.org/v1/gonum/blas"
20         "gonum.org/v1/gonum/blas/blas64"
21         "gonum.org/v1/gonum/floats"
22 )
23
// Dlahr2er is implemented by types that provide a Dlahr2 method with the
// signature below. Dlahr2Test exercises such an implementation.
type Dlahr2er interface {
	Dlahr2(
		n, k, nb int,
		a []float64, lda int,
		tau []float64,
		t []float64, ldt int,
		y []float64, ldy int,
	)
}
27
// Dlahr2test is a single stored test case for Dlahr2. Dlahr2Test decodes a
// slice of these from a gzipped JSON data file and compares the outputs of
// an implementation against the *Want fields.
type Dlahr2test struct {
	// Problem dimensions passed to Dlahr2.
	N  int
	K  int
	NB int
	// A is the input matrix; Dlahr2Test copies it as an n×(n-k+1) matrix
	// into a strided buffer before calling Dlahr2.
	A []float64

	// Expected contents of A, T, Y and tau after the call.
	AWant   []float64
	TWant   []float64
	YWant   []float64
	TauWant []float64
}
37
38 func Dlahr2Test(t *testing.T, impl Dlahr2er) {
39         rnd := rand.New(rand.NewSource(1))
40         for _, test := range []struct {
41                 n, k, nb int
42         }{
43                 {3, 0, 3},
44                 {3, 1, 2},
45                 {3, 1, 1},
46
47                 {5, 0, 5},
48                 {5, 1, 4},
49                 {5, 1, 3},
50                 {5, 1, 2},
51                 {5, 1, 1},
52                 {5, 2, 3},
53                 {5, 2, 2},
54                 {5, 2, 1},
55                 {5, 3, 2},
56                 {5, 3, 1},
57
58                 {7, 3, 4},
59                 {7, 3, 3},
60                 {7, 3, 2},
61                 {7, 3, 1},
62
63                 {10, 0, 10},
64                 {10, 1, 9},
65                 {10, 1, 5},
66                 {10, 1, 1},
67                 {10, 5, 5},
68                 {10, 5, 3},
69                 {10, 5, 1},
70         } {
71                 for cas := 0; cas < 100; cas++ {
72                         for _, extraStride := range []int{0, 1, 10} {
73                                 n := test.n
74                                 k := test.k
75                                 nb := test.nb
76
77                                 a := randomGeneral(n, n-k+1, n-k+1+extraStride, rnd)
78                                 aCopy := a
79                                 aCopy.Data = make([]float64, len(a.Data))
80                                 copy(aCopy.Data, a.Data)
81                                 tmat := nanTriangular(blas.Upper, nb, nb+extraStride)
82                                 y := nanGeneral(n, nb, nb+extraStride)
83                                 tau := nanSlice(nb)
84
85                                 impl.Dlahr2(n, k, nb, a.Data, a.Stride, tau, tmat.Data, tmat.Stride, y.Data, y.Stride)
86
87                                 prefix := fmt.Sprintf("Case n=%v, k=%v, nb=%v, ldex=%v", n, k, nb, extraStride)
88
89                                 if !generalOutsideAllNaN(a) {
90                                         t.Errorf("%v: out-of-range write to A\n%v", prefix, a.Data)
91                                 }
92                                 if !triangularOutsideAllNaN(tmat) {
93                                         t.Errorf("%v: out-of-range write to T\n%v", prefix, tmat.Data)
94                                 }
95                                 if !generalOutsideAllNaN(y) {
96                                         t.Errorf("%v: out-of-range write to Y\n%v", prefix, y.Data)
97                                 }
98
99                                 // Check that A[:k,:] and A[:,nb:] blocks were not modified.
100                                 for i := 0; i < n; i++ {
101                                         for j := 0; j < n-k+1; j++ {
102                                                 if i >= k && j < nb {
103                                                         continue
104                                                 }
105                                                 if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] {
106                                                         t.Errorf("%v: unexpected write to A[%v,%v]", prefix, i, j)
107                                                 }
108                                         }
109                                 }
110
111                                 // Check that all elements of tau were assigned.
112                                 for i, v := range tau {
113                                         if math.IsNaN(v) {
114                                                 t.Errorf("%v: tau[%v] not assigned", prefix, i)
115                                         }
116                                 }
117
118                                 // Extract V from a.
119                                 v := blas64.General{
120                                         Rows:   n - k + 1,
121                                         Cols:   nb,
122                                         Stride: nb,
123                                         Data:   make([]float64, (n-k+1)*nb),
124                                 }
125                                 for j := 0; j < v.Cols; j++ {
126                                         v.Data[(j+1)*v.Stride+j] = 1
127                                         for i := j + 2; i < v.Rows; i++ {
128                                                 v.Data[i*v.Stride+j] = a.Data[(i+k-1)*a.Stride+j]
129                                         }
130                                 }
131
132                                 // VT = V.
133                                 vt := v
134                                 vt.Data = make([]float64, len(v.Data))
135                                 copy(vt.Data, v.Data)
136                                 // VT = V * T.
137                                 blas64.Trmm(blas.Right, blas.NoTrans, 1, tmat, vt)
138                                 // YWant = A * V * T.
139                                 ywant := blas64.General{
140                                         Rows:   n,
141                                         Cols:   nb,
142                                         Stride: nb,
143                                         Data:   make([]float64, n*nb),
144                                 }
145                                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aCopy, vt, 0, ywant)
146
147                                 // Compare Y and YWant.
148                                 for i := 0; i < n; i++ {
149                                         for j := 0; j < nb; j++ {
150                                                 diff := math.Abs(ywant.Data[i*ywant.Stride+j] - y.Data[i*y.Stride+j])
151                                                 if diff > 1e-14 {
152                                                         t.Errorf("%v: unexpected Y[%v,%v], diff=%v", prefix, i, j, diff)
153                                                 }
154                                         }
155                                 }
156
157                                 // Construct Q directly from the first nb columns of a.
158                                 q := constructQ("QR", n-k, nb, a.Data[k*a.Stride:], a.Stride, tau)
159                                 if !isOrthonormal(q) {
160                                         t.Errorf("%v: Q is not orthogonal", prefix)
161                                 }
162                                 // Construct Q as the product Q = I - V*T*V^T.
163                                 qwant := blas64.General{
164                                         Rows:   n - k + 1,
165                                         Cols:   n - k + 1,
166                                         Stride: n - k + 1,
167                                         Data:   make([]float64, (n-k+1)*(n-k+1)),
168                                 }
169                                 for i := 0; i < qwant.Rows; i++ {
170                                         qwant.Data[i*qwant.Stride+i] = 1
171                                 }
172                                 blas64.Gemm(blas.NoTrans, blas.Trans, -1, vt, v, 1, qwant)
173                                 if !isOrthonormal(qwant) {
174                                         t.Errorf("%v: Q = I - V*T*V^T is not orthogonal", prefix)
175                                 }
176
177                                 // Compare Q and QWant. Note that since Q is
178                                 // (n-k)×(n-k) and QWant is (n-k+1)×(n-k+1), we
179                                 // ignore the first row and column of QWant.
180                                 for i := 0; i < n-k; i++ {
181                                         for j := 0; j < n-k; j++ {
182                                                 diff := math.Abs(q.Data[i*q.Stride+j] - qwant.Data[(i+1)*qwant.Stride+j+1])
183                                                 if diff > 1e-14 {
184                                                         t.Errorf("%v: unexpected Q[%v,%v], diff=%v", prefix, i, j, diff)
185                                                 }
186                                         }
187                                 }
188                         }
189                 }
190         }
191
192         // Go runs tests from the source directory, so unfortunately we need to
193         // include the "../testlapack" part.
194         file, err := os.Open(filepath.FromSlash("../testlapack/testdata/dlahr2data.json.gz"))
195         if err != nil {
196                 log.Fatal(err)
197         }
198         defer file.Close()
199         r, err := gzip.NewReader(file)
200         if err != nil {
201                 log.Fatal(err)
202         }
203         defer r.Close()
204
205         var tests []Dlahr2test
206         json.NewDecoder(r).Decode(&tests)
207         for _, test := range tests {
208                 tau := make([]float64, len(test.TauWant))
209                 for _, ldex := range []int{0, 1, 20} {
210                         n := test.N
211                         k := test.K
212                         nb := test.NB
213
214                         lda := n - k + 1 + ldex
215                         a := make([]float64, (n-1)*lda+n-k+1)
216                         copyMatrix(n, n-k+1, a, lda, test.A)
217
218                         ldt := nb + ldex
219                         tmat := make([]float64, (nb-1)*ldt+nb)
220
221                         ldy := nb + ldex
222                         y := make([]float64, (n-1)*ldy+nb)
223
224                         impl.Dlahr2(n, k, nb, a, lda, tau, tmat, ldt, y, ldy)
225
226                         prefix := fmt.Sprintf("Case n=%v, k=%v, nb=%v, ldex=%v", n, k, nb, ldex)
227                         if !equalApprox(n, n-k+1, a, lda, test.AWant, 1e-14) {
228                                 t.Errorf("%v: unexpected matrix A\n got=%v\nwant=%v", prefix, a, test.AWant)
229                         }
230                         if !equalApproxTriangular(true, nb, tmat, ldt, test.TWant, 1e-14) {
231                                 t.Errorf("%v: unexpected matrix T\n got=%v\nwant=%v", prefix, tmat, test.TWant)
232                         }
233                         if !equalApprox(n, nb, y, ldy, test.YWant, 1e-14) {
234                                 t.Errorf("%v: unexpected matrix Y\n got=%v\nwant=%v", prefix, y, test.YWant)
235                         }
236                         if !floats.EqualApprox(tau, test.TauWant, 1e-14) {
237                                 t.Errorf("%v: unexpected slice tau\n got=%v\nwant=%v", prefix, tau, test.TauWant)
238                         }
239                 }
240         }
241 }