OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dlarft.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11
12         "gonum.org/v1/gonum/blas"
13         "gonum.org/v1/gonum/blas/blas64"
14         "gonum.org/v1/gonum/floats"
15         "gonum.org/v1/gonum/lapack"
16 )
17
18 type Dlarfter interface {
19         Dgeqr2er
20         Dlarft(direct lapack.Direct, store lapack.StoreV, n, k int, v []float64, ldv int, tau []float64, t []float64, ldt int)
21 }
22
23 func DlarftTest(t *testing.T, impl Dlarfter) {
24         rnd := rand.New(rand.NewSource(1))
25         for _, store := range []lapack.StoreV{lapack.ColumnWise, lapack.RowWise} {
26                 for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
27                         for _, test := range []struct {
28                                 m, n, ldv, ldt int
29                         }{
30                                 {6, 6, 0, 0},
31                                 {8, 6, 0, 0},
32                                 {6, 8, 0, 0},
33                                 {6, 6, 10, 15},
34                                 {8, 6, 10, 15},
35                                 {6, 8, 10, 15},
36                                 {6, 6, 15, 10},
37                                 {8, 6, 15, 10},
38                                 {6, 8, 15, 10},
39                         } {
40                                 // Generate a matrix
41                                 m := test.m
42                                 n := test.n
43                                 lda := n
44                                 if lda == 0 {
45                                         lda = n
46                                 }
47
48                                 a := make([]float64, m*lda)
49                                 for i := 0; i < m; i++ {
50                                         for j := 0; j < lda; j++ {
51                                                 a[i*lda+j] = rnd.Float64()
52                                         }
53                                 }
54                                 // Use dgeqr2 to find the v vectors
55                                 tau := make([]float64, n)
56                                 work := make([]float64, n)
57                                 impl.Dgeqr2(m, n, a, lda, tau, work)
58
59                                 // Construct H using these answers
60                                 vMatTmp := extractVMat(m, n, a, lda, lapack.Forward, lapack.ColumnWise)
61                                 vMat := constructVMat(vMatTmp, store, direct)
62                                 v := vMat.Data
63                                 ldv := vMat.Stride
64
65                                 h := constructH(tau, vMat, store, direct)
66
67                                 k := min(m, n)
68                                 ldt := test.ldt
69                                 if ldt == 0 {
70                                         ldt = k
71                                 }
72                                 // Find T from the actual function
73                                 tm := make([]float64, k*ldt)
74                                 for i := range tm {
75                                         tm[i] = 100 + rnd.Float64()
76                                 }
77                                 // The v data has been put into a.
78                                 impl.Dlarft(direct, store, m, k, v, ldv, tau, tm, ldt)
79
80                                 tData := make([]float64, len(tm))
81                                 copy(tData, tm)
82                                 if direct == lapack.Forward {
83                                         // Zero out the lower traingular portion.
84                                         for i := 0; i < k; i++ {
85                                                 for j := 0; j < i; j++ {
86                                                         tData[i*ldt+j] = 0
87                                                 }
88                                         }
89                                 } else {
90                                         // Zero out the upper traingular portion.
91                                         for i := 0; i < k; i++ {
92                                                 for j := i + 1; j < k; j++ {
93                                                         tData[i*ldt+j] = 0
94                                                 }
95                                         }
96                                 }
97
98                                 T := blas64.General{
99                                         Rows:   k,
100                                         Cols:   k,
101                                         Stride: ldt,
102                                         Data:   tData,
103                                 }
104
105                                 vMatT := blas64.General{
106                                         Rows:   vMat.Cols,
107                                         Cols:   vMat.Rows,
108                                         Stride: vMat.Rows,
109                                         Data:   make([]float64, vMat.Cols*vMat.Rows),
110                                 }
111                                 for i := 0; i < vMat.Rows; i++ {
112                                         for j := 0; j < vMat.Cols; j++ {
113                                                 vMatT.Data[j*vMatT.Stride+i] = vMat.Data[i*vMat.Stride+j]
114                                         }
115                                 }
116                                 var comp blas64.General
117                                 if store == lapack.ColumnWise {
118                                         // H = I - V * T * V^T
119                                         tmp := blas64.General{
120                                                 Rows:   T.Rows,
121                                                 Cols:   vMatT.Cols,
122                                                 Stride: vMatT.Cols,
123                                                 Data:   make([]float64, T.Rows*vMatT.Cols),
124                                         }
125                                         // T * V^T
126                                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, T, vMatT, 0, tmp)
127                                         comp = blas64.General{
128                                                 Rows:   vMat.Rows,
129                                                 Cols:   tmp.Cols,
130                                                 Stride: tmp.Cols,
131                                                 Data:   make([]float64, vMat.Rows*tmp.Cols),
132                                         }
133                                         // V * (T * V^T)
134                                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vMat, tmp, 0, comp)
135                                 } else {
136                                         // H = I - V^T * T * V
137                                         tmp := blas64.General{
138                                                 Rows:   T.Rows,
139                                                 Cols:   vMat.Cols,
140                                                 Stride: vMat.Cols,
141                                                 Data:   make([]float64, T.Rows*vMat.Cols),
142                                         }
143                                         // T * V
144                                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, T, vMat, 0, tmp)
145                                         comp = blas64.General{
146                                                 Rows:   vMatT.Rows,
147                                                 Cols:   tmp.Cols,
148                                                 Stride: tmp.Cols,
149                                                 Data:   make([]float64, vMatT.Rows*tmp.Cols),
150                                         }
151                                         // V^T * (T * V)
152                                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vMatT, tmp, 0, comp)
153                                 }
154                                 // I - V^T * T * V
155                                 for i := 0; i < comp.Rows; i++ {
156                                         for j := 0; j < comp.Cols; j++ {
157                                                 comp.Data[i*m+j] *= -1
158                                                 if i == j {
159                                                         comp.Data[i*m+j] += 1
160                                                 }
161                                         }
162                                 }
163                                 if !floats.EqualApprox(comp.Data, h.Data, 1e-14) {
164                                         t.Errorf("T does not construct proper H. Store = %v, Direct = %v.\nWant %v\ngot %v.", string(store), string(direct), h.Data, comp.Data)
165                                 }
166                         }
167                 }
168         }
169 }