OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dlags2.go
1 // Copyright ©2017 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "math"
9         "testing"
10
11         "golang.org/x/exp/rand"
12
13         "gonum.org/v1/gonum/blas"
14         "gonum.org/v1/gonum/blas/blas64"
15         "gonum.org/v1/gonum/floats"
16 )
17
18 type Dlags2er interface {
19         Dlags2(upper bool, a1, a2, a3, b1, b2, b3 float64) (csu, snu, csv, snv, csq, snq float64)
20 }
21
22 func Dlags2Test(t *testing.T, impl Dlags2er) {
23         rnd := rand.New(rand.NewSource(1))
24         for _, upper := range []bool{true, false} {
25                 for i := 0; i < 100; i++ {
26                         a1 := rnd.Float64()
27                         a2 := rnd.Float64()
28                         a3 := rnd.Float64()
29                         b1 := rnd.Float64()
30                         b2 := rnd.Float64()
31                         b3 := rnd.Float64()
32
33                         csu, snu, csv, snv, csq, snq := impl.Dlags2(upper, a1, a2, a3, b1, b2, b3)
34
35                         detU := det2x2(csu, snu, -snu, csu)
36                         if !floats.EqualWithinAbsOrRel(math.Abs(detU), 1, 1e-14, 1e-14) {
37                                 t.Errorf("U not orthogonal: det(U)=%v", detU)
38                         }
39                         detV := det2x2(csv, snv, -snv, csv)
40                         if !floats.EqualWithinAbsOrRel(math.Abs(detV), 1, 1e-14, 1e-14) {
41                                 t.Errorf("V not orthogonal: det(V)=%v", detV)
42                         }
43                         detQ := det2x2(csq, snq, -snq, csq)
44                         if !floats.EqualWithinAbsOrRel(math.Abs(detQ), 1, 1e-14, 1e-14) {
45                                 t.Errorf("Q not orthogonal: det(Q)=%v", detQ)
46                         }
47
48                         u := blas64.General{
49                                 Rows:   2,
50                                 Cols:   2,
51                                 Stride: 2,
52                                 Data:   []float64{csu, snu, -snu, csu},
53                         }
54                         v := blas64.General{
55                                 Rows:   2,
56                                 Cols:   2,
57                                 Stride: 2,
58                                 Data:   []float64{csv, snv, -snv, csv},
59                         }
60                         q := blas64.General{
61                                 Rows:   2,
62                                 Cols:   2,
63                                 Stride: 2,
64                                 Data:   []float64{csq, snq, -snq, csq},
65                         }
66
67                         a := blas64.General{Rows: 2, Cols: 2, Stride: 2}
68                         b := blas64.General{Rows: 2, Cols: 2, Stride: 2}
69                         if upper {
70                                 a.Data = []float64{a1, a2, 0, a3}
71                                 b.Data = []float64{b1, b2, 0, b3}
72                         } else {
73                                 a.Data = []float64{a1, 0, a2, a3}
74                                 b.Data = []float64{b1, 0, b2, b3}
75                         }
76
77                         tmp := blas64.General{Rows: 2, Cols: 2, Stride: 2, Data: make([]float64, 4)}
78                         blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, a, 0, tmp)
79                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, a)
80                         blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, b, 0, tmp)
81                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, b)
82
83                         var gotA, gotB float64
84                         if upper {
85                                 gotA = a.Data[1]
86                                 gotB = b.Data[1]
87                         } else {
88                                 gotA = a.Data[2]
89                                 gotB = b.Data[2]
90                         }
91                         if !floats.EqualWithinAbsOrRel(gotA, 0, 1e-14, 1e-14) {
92                                 t.Errorf("unexpected non-zero value for zero triangle of U^T*A*Q: %v", gotA)
93                         }
94                         if !floats.EqualWithinAbsOrRel(gotB, 0, 1e-14, 1e-14) {
95                                 t.Errorf("unexpected non-zero value for zero triangle of V^T*B*Q: %v", gotB)
96                         }
97                 }
98         }
99 }
100
101 func det2x2(a, b, c, d float64) float64 { return a*d - b*c }