OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dormbr.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11
12         "gonum.org/v1/gonum/blas"
13         "gonum.org/v1/gonum/blas/blas64"
14         "gonum.org/v1/gonum/floats"
15         "gonum.org/v1/gonum/lapack"
16 )
17
18 type Dormbrer interface {
19         Dormbr(vect lapack.DecompUpdate, side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
20         Dgebrder
21 }
22
23 func DormbrTest(t *testing.T, impl Dormbrer) {
24         rnd := rand.New(rand.NewSource(1))
25         bi := blas64.Implementation()
26         for _, vect := range []lapack.DecompUpdate{lapack.ApplyQ, lapack.ApplyP} {
27                 for _, side := range []blas.Side{blas.Left, blas.Right} {
28                         for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
29                                 for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} {
30                                         for _, test := range []struct {
31                                                 m, n, k, lda, ldc int
32                                         }{
33                                                 {3, 4, 5, 0, 0},
34                                                 {3, 5, 4, 0, 0},
35                                                 {4, 3, 5, 0, 0},
36                                                 {4, 5, 3, 0, 0},
37                                                 {5, 3, 4, 0, 0},
38                                                 {5, 4, 3, 0, 0},
39
40                                                 {3, 4, 5, 10, 12},
41                                                 {3, 5, 4, 10, 12},
42                                                 {4, 3, 5, 10, 12},
43                                                 {4, 5, 3, 10, 12},
44                                                 {5, 3, 4, 10, 12},
45                                                 {5, 4, 3, 10, 12},
46
47                                                 {150, 140, 130, 0, 0},
48                                         } {
49                                                 m := test.m
50                                                 n := test.n
51                                                 k := test.k
52                                                 ldc := test.ldc
53                                                 if ldc == 0 {
54                                                         ldc = n
55                                                 }
56                                                 nq := n
57                                                 nw := m
58                                                 if side == blas.Left {
59                                                         nq = m
60                                                         nw = n
61                                                 }
62
63                                                 // Compute a decomposition.
64                                                 var ma, na int
65                                                 var a []float64
66                                                 if vect == lapack.ApplyQ {
67                                                         ma = nq
68                                                         na = k
69                                                 } else {
70                                                         ma = k
71                                                         na = nq
72                                                 }
73                                                 lda := test.lda
74                                                 if lda == 0 {
75                                                         lda = na
76                                                 }
77                                                 a = make([]float64, ma*lda)
78                                                 for i := range a {
79                                                         a[i] = rnd.NormFloat64()
80                                                 }
81                                                 nTau := min(nq, k)
82                                                 tauP := make([]float64, nTau)
83                                                 tauQ := make([]float64, nTau)
84                                                 d := make([]float64, nTau)
85                                                 e := make([]float64, nTau)
86
87                                                 work := make([]float64, 1)
88                                                 impl.Dgebrd(ma, na, a, lda, d, e, tauQ, tauP, work, -1)
89                                                 work = make([]float64, int(work[0]))
90                                                 impl.Dgebrd(ma, na, a, lda, d, e, tauQ, tauP, work, len(work))
91
92                                                 // Apply and compare update.
93                                                 c := make([]float64, m*ldc)
94                                                 for i := range c {
95                                                         c[i] = rnd.NormFloat64()
96                                                 }
97                                                 cCopy := make([]float64, len(c))
98                                                 copy(cCopy, c)
99
100                                                 var lwork int
101                                                 switch wl {
102                                                 case minimumWork:
103                                                         lwork = nw
104                                                 case optimumWork:
105                                                         impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauQ, c, ldc, work, -1)
106                                                         lwork = int(work[0])
107                                                 case mediumWork:
108                                                         work := make([]float64, 1)
109                                                         impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauQ, c, ldc, work, -1)
110                                                         lwork = (int(work[0]) + nw) / 2
111                                                 }
112                                                 lwork = max(1, lwork)
113                                                 work = make([]float64, lwork)
114
115                                                 if vect == lapack.ApplyQ {
116                                                         impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauQ, c, ldc, work, lwork)
117                                                 } else {
118                                                         impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauP, c, ldc, work, lwork)
119                                                 }
120
121                                                 // Check that the multiplication was correct.
122                                                 cOrig := blas64.General{
123                                                         Rows:   m,
124                                                         Cols:   n,
125                                                         Stride: ldc,
126                                                         Data:   make([]float64, len(cCopy)),
127                                                 }
128                                                 copy(cOrig.Data, cCopy)
129                                                 cAns := blas64.General{
130                                                         Rows:   m,
131                                                         Cols:   n,
132                                                         Stride: ldc,
133                                                         Data:   make([]float64, len(cCopy)),
134                                                 }
135                                                 copy(cAns.Data, cCopy)
136                                                 nb := min(ma, na)
137                                                 var mulMat blas64.General
138                                                 if vect == lapack.ApplyQ {
139                                                         mulMat = constructQPBidiagonal(lapack.ApplyQ, ma, na, nb, a, lda, tauQ)
140                                                 } else {
141                                                         mulMat = constructQPBidiagonal(lapack.ApplyP, ma, na, nb, a, lda, tauP)
142                                                 }
143
144                                                 mulTrans := trans
145
146                                                 if side == blas.Left {
147                                                         bi.Dgemm(mulTrans, blas.NoTrans, m, n, m, 1, mulMat.Data, mulMat.Stride, cOrig.Data, cOrig.Stride, 0, cAns.Data, cAns.Stride)
148                                                 } else {
149                                                         bi.Dgemm(blas.NoTrans, mulTrans, m, n, n, 1, cOrig.Data, cOrig.Stride, mulMat.Data, mulMat.Stride, 0, cAns.Data, cAns.Stride)
150                                                 }
151
152                                                 if !floats.EqualApprox(cAns.Data, c, 1e-13) {
153                                                         isApplyQ := vect == lapack.ApplyQ
154                                                         isLeft := side == blas.Left
155                                                         isTrans := trans == blas.Trans
156
157                                                         t.Errorf("C mismatch. isApplyQ: %v, isLeft: %v, isTrans: %v, m = %v, n = %v, k = %v, lda = %v, ldc = %v",
158                                                                 isApplyQ, isLeft, isTrans, m, n, k, lda, ldc)
159                                                 }
160                                         }
161                                 }
162                         }
163                 }
164         }
165 }