1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
10 "golang.org/x/exp/rand"
12 "gonum.org/v1/gonum/blas"
13 "gonum.org/v1/gonum/blas/blas64"
14 "gonum.org/v1/gonum/floats"
17 type Dormr2er interface {
18 Dgerqf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
19 Dormr2(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64)
22 func Dormr2Test(t *testing.T, impl Dormr2er) {
23 rnd := rand.New(rand.NewSource(1))
24 for _, side := range []blas.Side{blas.Left, blas.Right} {
25 for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
26 for _, test := range []struct {
27 common, adim, cdim, lda, ldc int
51 if side == blas.Left {
59 // Generate a random matrix
64 a := make([]float64, ma*lda)
72 // Compute random C matrix
73 c := make([]float64, mc*ldc)
80 tau := make([]float64, k)
81 work := make([]float64, 1)
82 impl.Dgerqf(ma, na, a, lda, tau, work, -1)
83 work = make([]float64, int(work[0]))
84 impl.Dgerqf(ma, na, a, lda, tau, work, len(work))
86 // Build Q from result
87 q := constructQ("RQ", ma, na, a, lda, tau)
89 cMat := blas64.General{
93 Data: make([]float64, len(c)),
96 cMatCopy := blas64.General{
100 Data: make([]float64, len(cMat.Data)),
102 copy(cMatCopy.Data, cMat.Data)
106 case side == blas.Left && trans == blas.NoTrans:
107 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
108 case side == blas.Left && trans == blas.Trans:
109 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
110 case side == blas.Right && trans == blas.NoTrans:
111 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMatCopy, q, 0, cMat)
112 case side == blas.Right && trans == blas.Trans:
113 blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMatCopy, q, 0, cMat)
115 // Do Dormr2 and compare
116 if side == blas.Left {
117 work = make([]float64, nc)
119 work = make([]float64, mc)
121 aCopy := make([]float64, len(a))
123 tauCopy := make([]float64, len(tau))
125 impl.Dormr2(side, trans, mc, nc, k, a[(ma-k)*lda:], lda, tau, c, ldc, work)
126 if !floats.Equal(a, aCopy) {
127 t.Errorf("a changed in call")
129 if !floats.Equal(tau, tauCopy) {
130 t.Errorf("tau changed in call")
132 if !floats.EqualApprox(cMat.Data, c, 1e-14) {
133 t.Errorf("Multiplication mismatch.\n Want %v \n got %v.", cMat.Data, c)