OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dlaexc.go
1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "fmt"
9         "math"
10         "math/cmplx"
11         "testing"
12
13         "golang.org/x/exp/rand"
14
15         "gonum.org/v1/gonum/blas"
16         "gonum.org/v1/gonum/blas/blas64"
17 )
18
19 type Dlaexcer interface {
20         Dlaexc(wantq bool, n int, t []float64, ldt int, q []float64, ldq int, j1, n1, n2 int, work []float64) bool
21 }
22
23 func DlaexcTest(t *testing.T, impl Dlaexcer) {
24         rnd := rand.New(rand.NewSource(1))
25
26         for _, wantq := range []bool{true, false} {
27                 for _, n := range []int{1, 2, 3, 4, 5, 6, 10, 18, 31, 53} {
28                         for _, extra := range []int{0, 1, 11} {
29                                 for cas := 0; cas < 100; cas++ {
30                                         j1 := rnd.Intn(n)
31                                         n1 := min(rnd.Intn(3), n-j1)
32                                         n2 := min(rnd.Intn(3), n-j1-n1)
33                                         testDlaexc(t, impl, wantq, n, j1, n1, n2, extra, rnd)
34                                 }
35                         }
36                 }
37         }
38 }
39
40 func testDlaexc(t *testing.T, impl Dlaexcer, wantq bool, n, j1, n1, n2, extra int, rnd *rand.Rand) {
41         const tol = 1e-14
42
43         tmat := randomGeneral(n, n, n+extra, rnd)
44         // Zero out the lower triangle.
45         for i := 1; i < n; i++ {
46                 for j := 0; j < i; j++ {
47                         tmat.Data[i*tmat.Stride+j] = 0
48                 }
49         }
50         // Make any 2x2 diagonal block to be in Schur canonical form.
51         if n1 == 2 {
52                 // Diagonal elements equal.
53                 tmat.Data[(j1+1)*tmat.Stride+j1+1] = tmat.Data[j1*tmat.Stride+j1]
54                 // Off-diagonal elements of opposite sign.
55                 c := rnd.NormFloat64()
56                 if math.Signbit(c) == math.Signbit(tmat.Data[j1*tmat.Stride+j1+1]) {
57                         c *= -1
58                 }
59                 tmat.Data[(j1+1)*tmat.Stride+j1] = c
60         }
61         if n2 == 2 {
62                 // Diagonal elements equal.
63                 tmat.Data[(j1+n1+1)*tmat.Stride+j1+n1+1] = tmat.Data[(j1+n1)*tmat.Stride+j1+n1]
64                 // Off-diagonal elements of opposite sign.
65                 c := rnd.NormFloat64()
66                 if math.Signbit(c) == math.Signbit(tmat.Data[(j1+n1)*tmat.Stride+j1+n1+1]) {
67                         c *= -1
68                 }
69                 tmat.Data[(j1+n1+1)*tmat.Stride+j1+n1] = c
70         }
71         tmatCopy := cloneGeneral(tmat)
72         var q, qCopy blas64.General
73         if wantq {
74                 q = eye(n, n+extra)
75                 qCopy = cloneGeneral(q)
76         }
77         work := nanSlice(n)
78
79         ok := impl.Dlaexc(wantq, n, tmat.Data, tmat.Stride, q.Data, q.Stride, j1, n1, n2, work)
80
81         prefix := fmt.Sprintf("Case n=%v, j1=%v, n1=%v, n2=%v, wantq=%v, extra=%v", n, j1, n1, n2, wantq, extra)
82
83         if !generalOutsideAllNaN(tmat) {
84                 t.Errorf("%v: out-of-range write to T", prefix)
85         }
86         if wantq && !generalOutsideAllNaN(q) {
87                 t.Errorf("%v: out-of-range write to Q", prefix)
88         }
89
90         if !ok {
91                 if n1 == 1 && n2 == 1 {
92                         t.Errorf("%v: unexpected failure", prefix)
93                 } else {
94                         t.Logf("%v: Dlaexc returned false", prefix)
95                 }
96         }
97
98         if !ok || n1 == 0 || n2 == 0 || j1+n1 >= n {
99                 // Check that T is not modified.
100                 for i := 0; i < n; i++ {
101                         for j := 0; j < n; j++ {
102                                 if tmat.Data[i*tmat.Stride+j] != tmatCopy.Data[i*tmatCopy.Stride+j] {
103                                         t.Errorf("%v: ok == false but T[%v,%v] modified", prefix, i, j)
104                                 }
105                         }
106                 }
107                 if !wantq {
108                         return
109                 }
110                 // Check that Q is not modified.
111                 for i := 0; i < n; i++ {
112                         for j := 0; j < n; j++ {
113                                 if q.Data[i*q.Stride+j] != qCopy.Data[i*qCopy.Stride+j] {
114                                         t.Errorf("%v: ok == false but Q[%v,%v] modified", prefix, i, j)
115                                 }
116                         }
117                 }
118                 return
119         }
120
121         // Check that T is not modified outside of rows and columns [j1:j1+n1+n2].
122         for i := 0; i < n; i++ {
123                 if j1 <= i && i < j1+n1+n2 {
124                         continue
125                 }
126                 for j := 0; j < n; j++ {
127                         if j1 <= j && j < j1+n1+n2 {
128                                 continue
129                         }
130                         diff := tmat.Data[i*tmat.Stride+j] - tmatCopy.Data[i*tmatCopy.Stride+j]
131                         if diff != 0 {
132                                 t.Errorf("%v: unexpected modification of T[%v,%v]", prefix, i, j)
133                         }
134                 }
135         }
136
137         if n1 == 1 {
138                 // 1×1 blocks are swapped exactly.
139                 got := tmat.Data[(j1+n2)*tmat.Stride+j1+n2]
140                 want := tmatCopy.Data[j1*tmatCopy.Stride+j1]
141                 if want != got {
142                         t.Errorf("%v: unexpected value of T[%v,%v]. Want %v, got %v", prefix, j1+n2, j1+n2, want, got)
143                 }
144         } else {
145                 // Check that the swapped 2×2 block is in Schur canonical form.
146                 // The n1×n1 block is now located at T[j1+n2,j1+n2].
147                 a, b, c, d := extract2x2Block(tmat.Data[(j1+n2)*tmat.Stride+j1+n2:], tmat.Stride)
148                 if !isSchurCanonical(a, b, c, d) {
149                         t.Errorf("%v: 2×2 block at T[%v,%v] not in Schur canonical form", prefix, j1+n2, j1+n2)
150                 }
151                 ev1Got, ev2Got := schurBlockEigenvalues(a, b, c, d)
152
153                 // Check that the swapped 2×2 block has the same eigenvalues.
154                 // The n1×n1 block was originally located at T[j1,j1].
155                 a, b, c, d = extract2x2Block(tmatCopy.Data[j1*tmatCopy.Stride+j1:], tmatCopy.Stride)
156                 ev1Want, ev2Want := schurBlockEigenvalues(a, b, c, d)
157                 if cmplx.Abs(ev1Got-ev1Want) > tol {
158                         t.Errorf("%v: unexpected first eigenvalue of 2×2 block at T[%v,%v]. Want %v, got %v",
159                                 prefix, j1+n2, j1+n2, ev1Want, ev1Got)
160                 }
161                 if cmplx.Abs(ev2Got-ev2Want) > tol {
162                         t.Errorf("%v: unexpected second eigenvalue of 2×2 block at T[%v,%v]. Want %v, got %v",
163                                 prefix, j1+n2, j1+n2, ev2Want, ev2Got)
164                 }
165         }
166         if n2 == 1 {
167                 // 1×1 blocks are swapped exactly.
168                 got := tmat.Data[j1*tmat.Stride+j1]
169                 want := tmatCopy.Data[(j1+n1)*tmatCopy.Stride+j1+n1]
170                 if want != got {
171                         t.Errorf("%v: unexpected value of T[%v,%v]. Want %v, got %v", prefix, j1, j1, want, got)
172                 }
173         } else {
174                 // Check that the swapped 2×2 block is in Schur canonical form.
175                 // The n2×n2 block is now located at T[j1,j1].
176                 a, b, c, d := extract2x2Block(tmat.Data[j1*tmat.Stride+j1:], tmat.Stride)
177                 if !isSchurCanonical(a, b, c, d) {
178                         t.Errorf("%v: 2×2 block at T[%v,%v] not in Schur canonical form", prefix, j1, j1)
179                 }
180                 ev1Got, ev2Got := schurBlockEigenvalues(a, b, c, d)
181
182                 // Check that the swapped 2×2 block has the same eigenvalues.
183                 // The n2×n2 block was originally located at T[j1+n1,j1+n1].
184                 a, b, c, d = extract2x2Block(tmatCopy.Data[(j1+n1)*tmatCopy.Stride+j1+n1:], tmatCopy.Stride)
185                 ev1Want, ev2Want := schurBlockEigenvalues(a, b, c, d)
186                 if cmplx.Abs(ev1Got-ev1Want) > tol {
187                         t.Errorf("%v: unexpected first eigenvalue of 2×2 block at T[%v,%v]. Want %v, got %v",
188                                 prefix, j1, j1, ev1Want, ev1Got)
189                 }
190                 if cmplx.Abs(ev2Got-ev2Want) > tol {
191                         t.Errorf("%v: unexpected second eigenvalue of 2×2 block at T[%v,%v]. Want %v, got %v",
192                                 prefix, j1, j1, ev2Want, ev2Got)
193                 }
194         }
195
196         if !wantq {
197                 return
198         }
199
200         if !isOrthonormal(q) {
201                 t.Errorf("%v: Q is not orthogonal", prefix)
202         }
203         // Check that Q is unchanged outside of columns [j1:j1+n1+n2].
204         for i := 0; i < n; i++ {
205                 for j := 0; j < n; j++ {
206                         if j1 <= j && j < j1+n1+n2 {
207                                 continue
208                         }
209                         diff := q.Data[i*q.Stride+j] - qCopy.Data[i*qCopy.Stride+j]
210                         if diff != 0 {
211                                 t.Errorf("%v: unexpected modification of Q[%v,%v]", prefix, i, j)
212                         }
213                 }
214         }
215         // Check that Q^T TOrig Q == T.
216         tq := eye(n, n)
217         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmatCopy, q, 0, tq)
218         qtq := eye(n, n)
219         blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, tq, 0, qtq)
220         for i := 0; i < n; i++ {
221                 for j := 0; j < n; j++ {
222                         diff := qtq.Data[i*qtq.Stride+j] - tmat.Data[i*tmat.Stride+j]
223                         if math.Abs(diff) > tol {
224                                 t.Errorf("%v: unexpected value of T[%v,%v]", prefix, i, j)
225                         }
226                 }
227         }
228 }