1 // Copyright ©2017 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
11 "golang.org/x/exp/rand"
13 "gonum.org/v1/gonum/blas"
14 "gonum.org/v1/gonum/blas/blas64"
15 "gonum.org/v1/gonum/floats"
// Dlags2er is the interface a value must satisfy to be exercised by
// Dlags2Test.
18 type Dlags2er interface {
// Dlags2 takes the flag upper and the scalars a1, a2, a3 and b1, b2, b3, and
// returns the pairs (csu, snu), (csv, snv) and (csq, snq). Dlags2Test uses
// each pair (cs, sn) as the entries of the 2×2 matrix [cs sn; -sn cs],
// referred to there as U, V and Q respectively.
19 Dlags2(upper bool, a1, a2, a3, b1, b2, b3 float64) (csu, snu, csv, snv, csq, snq float64)
// Dlags2Test checks impl.Dlags2 for both values of upper, calling it 100
// times per value. For every call it verifies that the returned pairs define
// 2×2 matrices U, V and Q with determinant of absolute value 1, and that a
// designated entry of U^T*A*Q and of V^T*B*Q is zero to within 1e-14.
//
// The random source is seeded with the constant 1, so the sequence it
// produces is the same on every run.
22 func Dlags2Test(t *testing.T, impl Dlags2er) {
23 rnd := rand.New(rand.NewSource(1))
24 for _, upper := range []bool{true, false} {
25 for i := 0; i < 100; i++ {
// NOTE(review): a1..a3 and b1..b3 are assigned on lines not visible in this
// excerpt; presumably they are drawn from rnd — confirm.
33 csu, snu, csv, snv, csq, snq := impl.Dlags2(upper, a1, a2, a3, b1, b2, b3)
// det2x2(cs, sn, -sn, cs) evaluates to cs*cs + sn*sn, so each of the three
// checks below requires cs^2 + sn^2 to equal 1 within 1e-14.
35 detU := det2x2(csu, snu, -snu, csu)
36 if !floats.EqualWithinAbsOrRel(math.Abs(detU), 1, 1e-14, 1e-14) {
37 t.Errorf("U not orthogonal: det(U)=%v", detU)
39 detV := det2x2(csv, snv, -snv, csv)
40 if !floats.EqualWithinAbsOrRel(math.Abs(detV), 1, 1e-14, 1e-14) {
41 t.Errorf("V not orthogonal: det(V)=%v", detV)
43 detQ := det2x2(csq, snq, -snq, csq)
44 if !floats.EqualWithinAbsOrRel(math.Abs(detQ), 1, 1e-14, 1e-14) {
45 t.Errorf("Q not orthogonal: det(Q)=%v", detQ)
// The next three Data fields hold the entries [cs sn; -sn cs] for the three
// returned pairs. NOTE(review): the enclosing declarations are not visible
// here; presumably they define the matrices u, v and q used below — confirm.
52 Data: []float64{csu, snu, -snu, csu},
58 Data: []float64{csv, snv, -snv, csv},
64 Data: []float64{csq, snq, -snq, csq},
// a and b are 2×2 with Stride 2; their Data is filled in below.
67 a := blas64.General{Rows: 2, Cols: 2, Stride: 2}
68 b := blas64.General{Rows: 2, Cols: 2, Stride: 2}
// First layout: a2 (b2) is the second element and the third element is 0.
// With row-major storage that is an upper triangular matrix.
// NOTE(review): the condition choosing between the two layouts is not
// visible; presumably it is the upper flag — confirm.
70 a.Data = []float64{a1, a2, 0, a3}
71 b.Data = []float64{b1, b2, 0, b3}
// Second layout: the second element is 0 and a2 (b2) is the third element,
// i.e. the lower triangular counterpart.
73 a.Data = []float64{a1, 0, a2, a3}
74 b.Data = []float64{b1, 0, b2, b3}
// tmp is scratch space for the intermediate product of each two-step
// multiplication.
77 tmp := blas64.General{Rows: 2, Cols: 2, Stride: 2, Data: make([]float64, 4)}
// tmp = U^T * A, then a = tmp * Q: a is overwritten with U^T*A*Q.
78 blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, a, 0, tmp)
79 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, a)
// tmp = V^T * B, then b = tmp * Q: b is overwritten with V^T*B*Q.
80 blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, b, 0, tmp)
81 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, b)
// NOTE(review): gotA and gotB are assigned on lines not visible here; going by
// the error messages below they hold the entry of each product that is
// expected to be zero — confirm which element is read for each value of upper.
83 var gotA, gotB float64
91 if !floats.EqualWithinAbsOrRel(gotA, 0, 1e-14, 1e-14) {
92 t.Errorf("unexpected non-zero value for zero triangle of U^T*A*Q: %v", gotA)
94 if !floats.EqualWithinAbsOrRel(gotB, 0, 1e-14, 1e-14) {
95 t.Errorf("unexpected non-zero value for zero triangle of V^T*B*Q: %v", gotB)
// det2x2 returns the determinant of the 2×2 matrix
//
//	[a b]
//	[c d]
//
// computed as the product of the main diagonal minus the product of the
// anti-diagonal.
func det2x2(a, b, c, d float64) float64 {
	mainDiag := a * d
	antiDiag := b * c
	return mainDiag - antiDiag
}