OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dgebal.go
1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "fmt"
9         "testing"
10
11         "golang.org/x/exp/rand"
12
13         "gonum.org/v1/gonum/blas"
14         "gonum.org/v1/gonum/blas/blas64"
15         "gonum.org/v1/gonum/lapack"
16 )
17
18 type Dgebaler interface {
19         Dgebal(job lapack.Job, n int, a []float64, lda int, scale []float64) (int, int)
20 }
21
22 func DgebalTest(t *testing.T, impl Dgebaler) {
23         rnd := rand.New(rand.NewSource(1))
24
25         for _, job := range []lapack.Job{lapack.None, lapack.Permute, lapack.Scale, lapack.PermuteScale} {
26                 for _, n := range []int{0, 1, 2, 3, 4, 5, 6, 10, 18, 31, 53, 100} {
27                         for _, extra := range []int{0, 11} {
28                                 for cas := 0; cas < 100; cas++ {
29                                         a := unbalancedSparseGeneral(n, n, n+extra, 2*n, rnd)
30                                         testDgebal(t, impl, job, a)
31                                 }
32                         }
33                 }
34         }
35 }
36
37 func testDgebal(t *testing.T, impl Dgebaler, job lapack.Job, a blas64.General) {
38         const tol = 1e-14
39
40         n := a.Rows
41         extra := a.Stride - n
42
43         var scale []float64
44         if n > 0 {
45                 scale = nanSlice(n)
46         }
47
48         want := cloneGeneral(a)
49
50         ilo, ihi := impl.Dgebal(job, n, a.Data, a.Stride, scale)
51
52         prefix := fmt.Sprintf("Case job=%v, n=%v, extra=%v", job, n, extra)
53
54         if !generalOutsideAllNaN(a) {
55                 t.Errorf("%v: out-of-range write to A\n%v", prefix, a.Data)
56         }
57
58         if n == 0 {
59                 if ilo != 0 {
60                         t.Errorf("%v: unexpected ilo when n=0. Want 0, got %v", prefix, ilo)
61                 }
62                 if ihi != -1 {
63                         t.Errorf("%v: unexpected ihi when n=0. Want -1, got %v", prefix, ihi)
64                 }
65                 return
66         }
67
68         if job == lapack.None {
69                 if ilo != 0 {
70                         t.Errorf("%v: unexpected ilo when job=None. Want 0, got %v", prefix, ilo)
71                 }
72                 if ihi != n-1 {
73                         t.Errorf("%v: unexpected ihi when job=None. Want %v, got %v", prefix, n-1, ihi)
74                 }
75                 k := -1
76                 for i := range scale {
77                         if scale[i] != 1 {
78                                 k = i
79                                 break
80                         }
81                 }
82                 if k != -1 {
83                         t.Errorf("%v: unexpected scale[%v] when job=None. Want 1, got %v", prefix, k, scale[k])
84                 }
85                 if !equalApproxGeneral(a, want, 0) {
86                         t.Errorf("%v: unexpected modification of A when job=None", prefix)
87                 }
88                 return
89         }
90
91         if ilo < 0 || ihi < ilo || n <= ihi {
92                 t.Errorf("%v: invalid ordering of ilo=%v and ihi=%v", prefix, ilo, ihi)
93         }
94
95         if ilo >= 2 && !isUpperTriangular(blas64.General{ilo - 1, ilo - 1, a.Stride, a.Data}) {
96                 t.Errorf("%v: T1 is not upper triangular", prefix)
97         }
98         m := n - ihi - 1 // Order of T2.
99         k := ihi + 1
100         if m >= 2 && !isUpperTriangular(blas64.General{m, m, a.Stride, a.Data[k*a.Stride+k:]}) {
101                 t.Errorf("%v: T2 is not upper triangular", prefix)
102         }
103
104         if job == lapack.Permute || job == lapack.PermuteScale {
105                 // Check that all rows in [ilo:ihi+1] have at least one nonzero
106                 // off-diagonal element.
107                 zeroRow := -1
108                 for i := ilo; i <= ihi; i++ {
109                         onlyZeros := true
110                         for j := ilo; j <= ihi; j++ {
111                                 if i != j && a.Data[i*a.Stride+j] != 0 {
112                                         onlyZeros = false
113                                         break
114                                 }
115                         }
116                         if onlyZeros {
117                                 zeroRow = i
118                                 break
119                         }
120                 }
121                 if zeroRow != -1 && ilo != ihi {
122                         t.Errorf("%v: row %v has only zero off-diagonal elements, ilo=%v, ihi=%v", prefix, zeroRow, ilo, ihi)
123                 }
124                 // Check that all columns in [ilo:ihi+1] have at least one nonzero
125                 // off-diagonal element.
126                 zeroCol := -1
127                 for j := ilo; j <= ihi; j++ {
128                         onlyZeros := true
129                         for i := ilo; i <= ihi; i++ {
130                                 if i != j && a.Data[i*a.Stride+j] != 0 {
131                                         onlyZeros = false
132                                         break
133                                 }
134                         }
135                         if onlyZeros {
136                                 zeroCol = j
137                                 break
138                         }
139                 }
140                 if zeroCol != -1 && ilo != ihi {
141                         t.Errorf("%v: column %v has only zero off-diagonal elements, ilo=%v, ihi=%v", prefix, zeroCol, ilo, ihi)
142                 }
143
144                 // Create the permutation matrix P.
145                 p := eye(n, n)
146                 for j := n - 1; j > ihi; j-- {
147                         blas64.Swap(n,
148                                 blas64.Vector{p.Stride, p.Data[j:]},
149                                 blas64.Vector{p.Stride, p.Data[int(scale[j]):]})
150                 }
151                 for j := 0; j < ilo; j++ {
152                         blas64.Swap(n,
153                                 blas64.Vector{p.Stride, p.Data[j:]},
154                                 blas64.Vector{p.Stride, p.Data[int(scale[j]):]})
155                 }
156                 // Compute P^T*A*P and store into want.
157                 ap := zeros(n, n, n)
158                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, want, p, 0, ap)
159                 blas64.Gemm(blas.Trans, blas.NoTrans, 1, p, ap, 0, want)
160         }
161         if job == lapack.Scale || job == lapack.PermuteScale {
162                 // Modify want by D and D^{-1}.
163                 d := eye(n, n)
164                 dinv := eye(n, n)
165                 for i := ilo; i <= ihi; i++ {
166                         d.Data[i*d.Stride+i] = scale[i]
167                         dinv.Data[i*dinv.Stride+i] = 1 / scale[i]
168                 }
169                 ad := zeros(n, n, n)
170                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, want, d, 0, ad)
171                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, dinv, ad, 0, want)
172         }
173         if !equalApproxGeneral(want, a, tol) {
174                 t.Errorf("%v: unexpected value of A, ilo=%v, ihi=%v", prefix, ilo, ihi)
175         }
176 }