1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
13 "golang.org/x/exp/rand"
15 "gonum.org/v1/gonum/blas"
16 "gonum.org/v1/gonum/blas/blas64"
// Dlaexcer is the interface that an implementation must satisfy to be
// exercised by DlaexcTest. Its single method, Dlaexc, receives the matrices
// T and Q as flat slices with strides ldt and ldq, the position j1 and the
// sizes n1 and n2 of two diagonal blocks, and a work slice. The tests treat
// the returned bool as a success flag.
19 type Dlaexcer interface {
20 Dlaexc(wantq bool, n int, t []float64, ldt int, q []float64, ldq int, j1, n1, n2 int, work []float64) bool
// DlaexcTest exercises impl.Dlaexc on randomly generated inputs. For both
// values of wantq, a fixed list of matrix orders n and several values of
// extra, it runs 100 cases each and delegates every case to testDlaexc.
// The random source is seeded with 1, so the generated cases are the same
// on every run.
23 func DlaexcTest(t *testing.T, impl Dlaexcer) {
24 rnd := rand.New(rand.NewSource(1))
26 for _, wantq := range []bool{true, false} {
27 for _, n := range []int{1, 2, 3, 4, 5, 6, 10, 18, 31, 53} {
28 for _, extra := range []int{0, 1, 11} {
29 for cas := 0; cas < 100; cas++ {
// Draw block sizes from {0, 1, 2} and clamp them so that
// j1+n1 <= n and j1+n1+n2 <= n, i.e. both blocks fit in T.
31 n1 := min(rnd.Intn(3), n-j1)
32 n2 := min(rnd.Intn(3), n-j1-n1)
33 testDlaexc(t, impl, wantq, n, j1, n1, n2, extra, rnd)
// testDlaexc runs a single Dlaexc case and reports problems through t.
// It generates a random n×n matrix T, zeroes the part below the diagonal,
// adjusts the diagonal blocks at T[j1,j1] and T[j1+n1,j1+n1], calls
// impl.Dlaexc, and then compares the resulting T (and Q) against copies
// taken before the call. Every message is prefixed with the case
// parameters so that a failing case can be identified.
40 func testDlaexc(t *testing.T, impl Dlaexcer, wantq bool, n, j1, n1, n2, extra int, rnd *rand.Rand) {
// Generate T; n+extra is passed as the third argument of randomGeneral.
43 tmat := randomGeneral(n, n, n+extra, rnd)
44 // Zero out the lower triangle.
45 for i := 1; i < n; i++ {
46 for j := 0; j < i; j++ {
47 tmat.Data[i*tmat.Stride+j] = 0
50 // Put any 2×2 diagonal block into Schur canonical form.
52 // Diagonal elements equal.
53 tmat.Data[(j1+1)*tmat.Stride+j1+1] = tmat.Data[j1*tmat.Stride+j1]
54 // Off-diagonal elements of opposite sign.
55 c := rnd.NormFloat64()
56 if math.Signbit(c) == math.Signbit(tmat.Data[j1*tmat.Stride+j1+1]) {
59 tmat.Data[(j1+1)*tmat.Stride+j1] = c
// Same adjustment for the block starting at T[j1+n1,j1+n1].
62 // Diagonal elements equal.
63 tmat.Data[(j1+n1+1)*tmat.Stride+j1+n1+1] = tmat.Data[(j1+n1)*tmat.Stride+j1+n1]
64 // Off-diagonal elements of opposite sign.
65 c := rnd.NormFloat64()
66 if math.Signbit(c) == math.Signbit(tmat.Data[(j1+n1)*tmat.Stride+j1+n1+1]) {
69 tmat.Data[(j1+n1+1)*tmat.Stride+j1+n1] = c
// Keep a copy of T as it is passed to Dlaexc for the comparisons below.
71 tmatCopy := cloneGeneral(tmat)
72 var q, qCopy blas64.General
// Keep a copy of Q for the comparisons below.
75 qCopy = cloneGeneral(q)
79 ok := impl.Dlaexc(wantq, n, tmat.Data, tmat.Stride, q.Data, q.Stride, j1, n1, n2, work)
81 prefix := fmt.Sprintf("Case n=%v, j1=%v, n1=%v, n2=%v, wantq=%v, extra=%v", n, j1, n1, n2, wantq, extra)
83 if !generalOutsideAllNaN(tmat) {
84 t.Errorf("%v: out-of-range write to T", prefix)
86 if wantq && !generalOutsideAllNaN(q) {
87 t.Errorf("%v: out-of-range write to Q", prefix)
91 if n1 == 1 && n2 == 1 {
92 t.Errorf("%v: unexpected failure", prefix)
94 t.Logf("%v: Dlaexc returned false", prefix)
// NOTE(review): this branch is also entered when ok is true but n1 == 0,
// n2 == 0 or j1+n1 >= n, so the "ok == false" wording in the messages
// below can be misleading in those cases.
98 if !ok || n1 == 0 || n2 == 0 || j1+n1 >= n {
99 // Check that T is not modified.
100 for i := 0; i < n; i++ {
101 for j := 0; j < n; j++ {
102 if tmat.Data[i*tmat.Stride+j] != tmatCopy.Data[i*tmatCopy.Stride+j] {
103 t.Errorf("%v: ok == false but T[%v,%v] modified", prefix, i, j)
110 // Check that Q is not modified.
111 for i := 0; i < n; i++ {
112 for j := 0; j < n; j++ {
113 if q.Data[i*q.Stride+j] != qCopy.Data[i*qCopy.Stride+j] {
114 t.Errorf("%v: ok == false but Q[%v,%v] modified", prefix, i, j)
121 // Check that T is not modified outside of rows and columns [j1:j1+n1+n2].
122 for i := 0; i < n; i++ {
123 if j1 <= i && i < j1+n1+n2 {
126 for j := 0; j < n; j++ {
127 if j1 <= j && j < j1+n1+n2 {
130 diff := tmat.Data[i*tmat.Stride+j] - tmatCopy.Data[i*tmatCopy.Stride+j]
132 t.Errorf("%v: unexpected modification of T[%v,%v]", prefix, i, j)
138 // 1×1 blocks are swapped exactly.
139 got := tmat.Data[(j1+n2)*tmat.Stride+j1+n2]
140 want := tmatCopy.Data[j1*tmatCopy.Stride+j1]
142 t.Errorf("%v: unexpected value of T[%v,%v]. Want %v, got %v", prefix, j1+n2, j1+n2, want, got)
145 // Check that the swapped 2×2 block is in Schur canonical form.
146 // The n1×n1 block is now located at T[j1+n2,j1+n2].
147 a, b, c, d := extract2x2Block(tmat.Data[(j1+n2)*tmat.Stride+j1+n2:], tmat.Stride)
148 if !isSchurCanonical(a, b, c, d) {
149 t.Errorf("%v: 2×2 block at T[%v,%v] not in Schur canonical form", prefix, j1+n2, j1+n2)
151 ev1Got, ev2Got := schurBlockEigenvalues(a, b, c, d)
153 // Check that the swapped 2×2 block has the same eigenvalues.
154 // The n1×n1 block was originally located at T[j1,j1].
155 a, b, c, d = extract2x2Block(tmatCopy.Data[j1*tmatCopy.Stride+j1:], tmatCopy.Stride)
156 ev1Want, ev2Want := schurBlockEigenvalues(a, b, c, d)
157 if cmplx.Abs(ev1Got-ev1Want) > tol {
158 t.Errorf("%v: unexpected first eigenvalue of 2×2 block at T[%v,%v]. Want %v, got %v",
159 prefix, j1+n2, j1+n2, ev1Want, ev1Got)
161 if cmplx.Abs(ev2Got-ev2Want) > tol {
162 t.Errorf("%v: unexpected second eigenvalue of 2×2 block at T[%v,%v]. Want %v, got %v",
163 prefix, j1+n2, j1+n2, ev2Want, ev2Got)
167 // 1×1 blocks are swapped exactly.
168 got := tmat.Data[j1*tmat.Stride+j1]
169 want := tmatCopy.Data[(j1+n1)*tmatCopy.Stride+j1+n1]
171 t.Errorf("%v: unexpected value of T[%v,%v]. Want %v, got %v", prefix, j1, j1, want, got)
174 // Check that the swapped 2×2 block is in Schur canonical form.
175 // The n2×n2 block is now located at T[j1,j1].
176 a, b, c, d := extract2x2Block(tmat.Data[j1*tmat.Stride+j1:], tmat.Stride)
177 if !isSchurCanonical(a, b, c, d) {
178 t.Errorf("%v: 2×2 block at T[%v,%v] not in Schur canonical form", prefix, j1, j1)
180 ev1Got, ev2Got := schurBlockEigenvalues(a, b, c, d)
182 // Check that the swapped 2×2 block has the same eigenvalues.
183 // The n2×n2 block was originally located at T[j1+n1,j1+n1].
184 a, b, c, d = extract2x2Block(tmatCopy.Data[(j1+n1)*tmatCopy.Stride+j1+n1:], tmatCopy.Stride)
185 ev1Want, ev2Want := schurBlockEigenvalues(a, b, c, d)
186 if cmplx.Abs(ev1Got-ev1Want) > tol {
187 t.Errorf("%v: unexpected first eigenvalue of 2×2 block at T[%v,%v]. Want %v, got %v",
188 prefix, j1, j1, ev1Want, ev1Got)
190 if cmplx.Abs(ev2Got-ev2Want) > tol {
191 t.Errorf("%v: unexpected second eigenvalue of 2×2 block at T[%v,%v]. Want %v, got %v",
192 prefix, j1, j1, ev2Want, ev2Got)
200 if !isOrthonormal(q) {
201 t.Errorf("%v: Q is not orthogonal", prefix)
203 // Check that Q is unchanged outside of columns [j1:j1+n1+n2].
204 for i := 0; i < n; i++ {
205 for j := 0; j < n; j++ {
206 if j1 <= j && j < j1+n1+n2 {
209 diff := q.Data[i*q.Stride+j] - qCopy.Data[i*qCopy.Stride+j]
211 t.Errorf("%v: unexpected modification of Q[%v,%v]", prefix, i, j)
215 // Check that Q^T TOrig Q == T.
// tq = TOrig * Q.
217 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmatCopy, q, 0, tq)
// qtq = Q^T * tq, compared element-wise with the returned T within tol.
219 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, tq, 0, qtq)
220 for i := 0; i < n; i++ {
221 for j := 0; j < n; j++ {
222 diff := qtq.Data[i*qtq.Stride+j] - tmat.Data[i*tmat.Stride+j]
223 if math.Abs(diff) > tol {
224 t.Errorf("%v: unexpected value of T[%v,%v]", prefix, i, j)