1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
10 "golang.org/x/exp/rand"
12 "gonum.org/v1/gonum/blas"
13 "gonum.org/v1/gonum/blas/blas64"
14 "gonum.org/v1/gonum/floats"
15 "gonum.org/v1/gonum/lapack"
// Dlarfber is the implementation interface required by DlarfbTest.
//
// NOTE(review): DlarfbTest also calls Dgeqr2 and Dlarft on its Dlarfber
// argument, so those methods must be supplied by an embedded interface that
// is not visible in this view — confirm.
18 type Dlarfber interface {
// Dlarfb is checked by DlarfbTest against an explicitly constructed H:
// after the call, c must match H*C, H^T*C, C*H or C*H^T (selected by side
// and trans) to within a tolerance of 1e-14.
20 Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct,
21 store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int,
22 c []float64, ldc int, work []float64, ldwork int)
25 func DlarfbTest(t *testing.T, impl Dlarfber) {
26 rnd := rand.New(rand.NewSource(1))
27 for _, store := range []lapack.StoreV{lapack.ColumnWise, lapack.RowWise} {
28 for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
29 for _, side := range []blas.Side{blas.Left, blas.Right} {
30 for _, trans := range []blas.Transpose{blas.Trans, blas.NoTrans} {
31 for cas, test := range []struct {
32 ma, na, cdim, lda, ldt, ldc int
41 {6, 6, 6, 12, 15, 30},
42 {6, 8, 10, 12, 15, 30},
43 {6, 10, 8, 12, 15, 30},
44 {8, 6, 10, 12, 15, 30},
45 {8, 10, 6, 12, 15, 30},
46 {10, 6, 8, 12, 15, 30},
47 {10, 8, 6, 12, 15, 30},
48 {6, 6, 6, 15, 12, 30},
49 {6, 8, 10, 15, 12, 30},
50 {6, 10, 8, 15, 12, 30},
51 {8, 6, 10, 15, 12, 30},
52 {8, 10, 6, 15, 12, 30},
53 {10, 6, 8, 15, 12, 30},
54 {10, 8, 6, 15, 12, 30},
56 // Generate a matrix for QR
63 a := make([]float64, ma*lda)
64 for i := 0; i < ma; i++ {
65 for j := 0; j < lda; j++ {
66 a[i*lda+j] = rnd.Float64()
71 // H is always ma x ma
72 var m, n, rowsWork int
75 panic("not implemented")
76 case side == blas.Left:
80 case side == blas.Right:
86 // Use dgeqr2 to find the v vectors
87 tau := make([]float64, na)
88 work := make([]float64, na)
89 impl.Dgeqr2(ma, k, a, lda, tau, work)
91 // Correct the v vectors based on the direct and store
92 vMatTmp := extractVMat(ma, na, a, lda, lapack.Forward, lapack.ColumnWise)
93 vMat := constructVMat(vMatTmp, store, direct)
97 // Use dlarft to find the t vector
102 tm := make([]float64, k*ldt)
104 impl.Dlarft(direct, store, ma, k, v, ldv, tau, tm, ldt)
111 c := make([]float64, m*ldc)
112 for i := 0; i < m; i++ {
113 for j := 0; j < ldc; j++ {
114 c[i*ldc+j] = rnd.Float64()
117 cCopy := make([]float64, len(c))
121 work = make([]float64, rowsWork*k)
123 // Call Dlarfb with this information
124 impl.Dlarfb(side, trans, direct, store, m, n, k, v, ldv, tm, ldt, c, ldc, work, ldwork)
126 h := constructH(tau, vMat, store, direct)
128 cMat := blas64.General{
132 Data: make([]float64, m*ldc),
134 copy(cMat.Data, cCopy)
135 ans := blas64.General{
139 Data: make([]float64, m*ldc),
141 copy(ans.Data, cMat.Data)
144 panic("not implemented")
145 case side == blas.Left && trans == blas.NoTrans:
146 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, cMat, 0, ans)
147 case side == blas.Left && trans == blas.Trans:
148 blas64.Gemm(blas.Trans, blas.NoTrans, 1, h, cMat, 0, ans)
149 case side == blas.Right && trans == blas.NoTrans:
150 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat, h, 0, ans)
151 case side == blas.Right && trans == blas.Trans:
152 blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMat, h, 0, ans)
154 if !floats.EqualApprox(ans.Data, c, 1e-14) {
155 t.Errorf("Cas %v mismatch. Want %v, got %v.", cas, ans.Data, c)