1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
10 "golang.org/x/exp/rand"
12 "gonum.org/v1/gonum/blas"
13 "gonum.org/v1/gonum/blas/blas64"
14 "gonum.org/v1/gonum/floats"
17 type Dorml2er interface {
19 Dorml2(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64)
22 func Dorml2Test(t *testing.T, impl Dorml2er) {
23 rnd := rand.New(rand.NewSource(1))
24 // TODO(btracey): This test is not complete, because it
25 // doesn't test individual values of m, n, and k, instead only testing
26 // a specific subset of possible k values.
27 for _, side := range []blas.Side{blas.Left, blas.Right} {
28 for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
29 for _, test := range []struct {
30 common, adim, cdim, lda, ldc int
52 var ma, na, mc, nc int
53 if side == blas.Left {
64 // Generate a random matrix
69 a := make([]float64, ma*lda)
77 // Compute random C matrix
78 c := make([]float64, mc*ldc)
85 tau := make([]float64, k)
86 work := make([]float64, 1)
87 impl.Dgelqf(ma, na, a, lda, tau, work, -1)
88 work = make([]float64, int(work[0]))
89 impl.Dgelqf(ma, na, a, lda, tau, work, len(work))
91 // Build Q from result
92 q := constructQ("LQ", ma, na, a, lda, tau)
94 cMat := blas64.General{
98 Data: make([]float64, len(c)),
101 cMatCopy := blas64.General{
105 Data: make([]float64, len(cMat.Data)),
107 copy(cMatCopy.Data, cMat.Data)
111 case side == blas.Left && trans == blas.NoTrans:
112 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
113 case side == blas.Left && trans == blas.Trans:
114 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
115 case side == blas.Right && trans == blas.NoTrans:
116 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMatCopy, q, 0, cMat)
117 case side == blas.Right && trans == blas.Trans:
118 blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMatCopy, q, 0, cMat)
120 // Do Dorml2 and compare
121 if side == blas.Left {
122 work = make([]float64, nc)
124 work = make([]float64, mc)
126 aCopy := make([]float64, len(a))
128 tauCopy := make([]float64, len(tau))
130 impl.Dorml2(side, trans, mc, nc, k, a, lda, tau, c, ldc, work)
131 if !floats.Equal(a, aCopy) {
132 t.Errorf("a changed in call")
134 if !floats.Equal(tau, tauCopy) {
135 t.Errorf("tau changed in call")
137 if !floats.EqualApprox(cMat.Data, c, 1e-14) {
138 isLeft := side == blas.Left
139 isTrans := trans == blas.Trans
140 t.Errorf("Multiplication mismatch. IsLeft = %v. IsTrans = %v", isLeft, isTrans)