OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dorgql.go
1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "fmt"
9         "testing"
10
11         "golang.org/x/exp/rand"
12
13         "gonum.org/v1/gonum/blas"
14         "gonum.org/v1/gonum/blas/blas64"
15 )
16
17 type Dorgqler interface {
18         Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
19
20         Dlarfger
21 }
22
23 func DorgqlTest(t *testing.T, impl Dorgqler) {
24         const tol = 1e-14
25
26         type Dorg2ler interface {
27                 Dorg2l(m, n, k int, a []float64, lda int, tau, work []float64)
28         }
29         dorg2ler, hasDorg2l := impl.(Dorg2ler)
30
31         rnd := rand.New(rand.NewSource(1))
32         for _, m := range []int{0, 1, 2, 3, 4, 5, 7, 10, 15, 30, 50, 150} {
33                 for _, extra := range []int{0, 11} {
34                         for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} {
35                                 var k int
36                                 if m >= 129 {
37                                         // For large matrices make sure that k
38                                         // is large enough to trigger blocked
39                                         // path.
40                                         k = 129 + rnd.Intn(m-129+1)
41                                 } else {
42                                         k = rnd.Intn(m + 1)
43                                 }
44                                 n := k + rnd.Intn(m-k+1)
45                                 if m == 0 || n == 0 {
46                                         m = 0
47                                         n = 0
48                                         k = 0
49                                 }
50
51                                 // Generate k elementary reflectors in the last
52                                 // k columns of A.
53                                 a := nanGeneral(m, n, n+extra)
54                                 tau := make([]float64, k)
55                                 for l := 0; l < k; l++ {
56                                         jj := m - k + l
57                                         v := randomSlice(jj, rnd)
58                                         _, tau[l] = impl.Dlarfg(len(v)+1, rnd.NormFloat64(), v, 1)
59                                         j := n - k + l
60                                         for i := 0; i < jj; i++ {
61                                                 a.Data[i*a.Stride+j] = v[i]
62                                         }
63                                 }
64                                 aCopy := cloneGeneral(a)
65
66                                 // Compute the full matrix Q by forming the
67                                 // Householder reflectors explicitly.
68                                 q := eye(m, m)
69                                 qCopy := eye(m, m)
70                                 for l := 0; l < k; l++ {
71                                         h := eye(m, m)
72                                         jj := m - k + l
73                                         j := n - k + l
74                                         v := blas64.Vector{1, make([]float64, m)}
75                                         for i := 0; i < jj; i++ {
76                                                 v.Data[i] = a.Data[i*a.Stride+j]
77                                         }
78                                         v.Data[jj] = 1
79                                         blas64.Ger(-tau[l], v, v, h)
80                                         copy(qCopy.Data, q.Data)
81                                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, qCopy, 0, q)
82                                 }
83                                 // View the last n columns of Q as 'want'.
84                                 want := blas64.General{
85                                         Rows:   m,
86                                         Cols:   n,
87                                         Stride: q.Stride,
88                                         Data:   q.Data[m-n:],
89                                 }
90
91                                 var lwork int
92                                 switch wl {
93                                 case minimumWork:
94                                         lwork = max(1, n)
95                                 case mediumWork:
96                                         work := make([]float64, 1)
97                                         impl.Dorgql(m, n, k, nil, a.Stride, nil, work, -1)
98                                         lwork = (int(work[0]) + n) / 2
99                                         lwork = max(1, lwork)
100                                 case optimumWork:
101                                         work := make([]float64, 1)
102                                         impl.Dorgql(m, n, k, nil, a.Stride, nil, work, -1)
103                                         lwork = int(work[0])
104                                 }
105                                 work := make([]float64, lwork)
106
107                                 // Compute the last n columns of Q by a call to
108                                 // Dorgql.
109                                 impl.Dorgql(m, n, k, a.Data, a.Stride, tau, work, len(work))
110
111                                 prefix := fmt.Sprintf("Case m=%v,n=%v,k=%v,wl=%v", m, n, k, wl)
112                                 if !generalOutsideAllNaN(a) {
113                                         t.Errorf("%v: out-of-range write to A", prefix)
114                                 }
115                                 if !equalApproxGeneral(want, a, tol) {
116                                         t.Errorf("%v: unexpected Q", prefix)
117                                 }
118
119                                 // Compute the last n columns of Q by a call to
120                                 // Dorg2l and check that we get the same result.
121                                 if !hasDorg2l {
122                                         continue
123                                 }
124                                 dorg2ler.Dorg2l(m, n, k, aCopy.Data, aCopy.Stride, tau, work)
125                                 if !equalApproxGeneral(aCopy, a, tol) {
126                                         t.Errorf("%v: mismatch between Dorgql and Dorg2l", prefix)
127                                 }
128                         }
129                 }
130         }
131 }