1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
11 "golang.org/x/exp/rand"
13 "gonum.org/v1/gonum/blas"
14 "gonum.org/v1/gonum/blas/blas64"
15 "gonum.org/v1/gonum/floats"
16 "gonum.org/v1/gonum/lapack"
// Dlasrer is implemented by types that provide Dlasr; DlasrTest checks an
// implementation against a direct matrix computation.
//
// As exercised by DlasrTest, Dlasr applies a sequence of plane rotations,
// given by the cosines c[k] and sines s[k], to the row-major matrix a with
// m rows and row stride lda. side selects whether the combined rotation P
// multiplies a from the left (P*A, with len(c) == len(s) == m-1) or from the
// right (A*P, with len(c) == len(s) == n-1). pivot selects the plane that
// rotation k acts in: Variable uses (k, k+1), Top uses (0, k+1), and Bottom
// uses (p-1-k, p-1), where p is the order of P. direct selects the order in
// which the individual rotations are composed into P.
19 type Dlasrer interface {
20 Dlasr(side blas.Side, pivot lapack.Pivot, direct lapack.Direct, m, n int, c, s, a []float64, lda int)
23 func DlasrTest(t *testing.T, impl Dlasrer) {
24 rnd := rand.New(rand.NewSource(1))
25 for _, side := range []blas.Side{blas.Left, blas.Right} {
26 for _, pivot := range []lapack.Pivot{lapack.Variable, lapack.Top, lapack.Bottom} {
27 for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
28 for _, test := range []struct {
45 a := make([]float64, m*lda)
50 if side == blas.Left {
51 s = make([]float64, m-1)
52 c = make([]float64, m-1)
54 s = make([]float64, n-1)
55 c = make([]float64, n-1)
58 theta := rnd.Float64() * 2 * math.Pi
59 s[k] = math.Sin(theta)
60 c[k] = math.Cos(theta)
62 aCopy := make([]float64, len(a))
64 impl.Dlasr(side, pivot, direct, m, n, c, s, a, lda)
67 if side == blas.Right {
74 Data: make([]float64, pSize*pSize),
80 Data: make([]float64, pSize*pSize),
82 ptmp := blas64.General{
86 Data: make([]float64, pSize*pSize),
88 for i := 0; i < pSize; i++ {
89 p.Data[i*p.Stride+i] = 1
90 ptmp.Data[i*p.Stride+i] = 1
92 // Compare to direct computation.
94 for i := range p.Data {
97 for i := 0; i < pSize; i++ {
98 pk.Data[i*p.Stride+i] = 1
100 if pivot == lapack.Variable {
101 pk.Data[k*p.Stride+k] = c[k]
102 pk.Data[k*p.Stride+k+1] = s[k]
103 pk.Data[(k+1)*p.Stride+k] = -s[k]
104 pk.Data[(k+1)*p.Stride+k+1] = c[k]
105 } else if pivot == lapack.Top {
108 pk.Data[(k+1)*p.Stride] = -s[k]
109 pk.Data[(k+1)*p.Stride+k+1] = c[k]
111 pk.Data[(pSize-1-k)*p.Stride+pSize-k-1] = c[k]
112 pk.Data[(pSize-1-k)*p.Stride+pSize-1] = s[k]
113 pk.Data[(pSize-1)*p.Stride+pSize-1-k] = -s[k]
114 pk.Data[(pSize-1)*p.Stride+pSize-1] = c[k]
116 if direct == lapack.Forward {
117 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, pk, ptmp, 0, p)
119 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, ptmp, pk, 0, p)
121 copy(ptmp.Data, p.Data)
124 aMat := blas64.General{
128 Data: make([]float64, m*lda),
131 newA := blas64.General{
135 Data: make([]float64, m*lda),
137 if side == blas.Left {
138 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, p, aMat, 0, newA)
140 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, p, 0, newA)
142 if !floats.EqualApprox(newA.Data, a, 1e-12) {
143 t.Errorf("A update mismatch")