OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dlarfg.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "math"
9         "testing"
10
11         "golang.org/x/exp/rand"
12
13         "gonum.org/v1/gonum/blas"
14         "gonum.org/v1/gonum/blas/blas64"
15 )
16
// Dlarfger is implemented by types that provide Dlarfg, following the
// conventions of LAPACK's DLARFG: it generates an elementary reflector
// H = I - tau * v * v^T of order n such that H * [alpha; x] = [beta; 0], and
// on return x holds the trailing n-1 elements of v (v[0] is implicitly 1).
// incX is the increment between consecutive elements of x.
type Dlarfger interface {
	Dlarfg(n int, alpha float64, x []float64, incX int) (beta, tau float64)
}
20
21 func DlarfgTest(t *testing.T, impl Dlarfger) {
22         rnd := rand.New(rand.NewSource(1))
23         for i, test := range []struct {
24                 alpha float64
25                 n     int
26                 x     []float64
27         }{
28                 {
29                         alpha: 4,
30                         n:     3,
31                 },
32                 {
33                         alpha: -2,
34                         n:     3,
35                 },
36                 {
37                         alpha: 0,
38                         n:     3,
39                 },
40                 {
41                         alpha: 1,
42                         n:     1,
43                 },
44                 {
45                         alpha: 1,
46                         n:     4,
47                         x:     []float64{4, 5, 6},
48                 },
49                 {
50                         alpha: 1,
51                         n:     4,
52                         x:     []float64{0, 0, 0},
53                 },
54                 {
55                         alpha: dlamchS,
56                         n:     4,
57                         x:     []float64{dlamchS, dlamchS, dlamchS},
58                 },
59         } {
60                 n := test.n
61                 incX := 1
62                 var x []float64
63                 if test.x == nil {
64                         x = make([]float64, n-1)
65                         for i := range x {
66                                 x[i] = rnd.Float64()
67                         }
68                 } else {
69                         if len(test.x) != n-1 {
70                                 panic("bad test")
71                         }
72                         x = make([]float64, n-1)
73                         copy(x, test.x)
74                 }
75                 xcopy := make([]float64, n-1)
76                 copy(xcopy, x)
77                 alpha := test.alpha
78                 beta, tau := impl.Dlarfg(n, alpha, x, incX)
79
80                 // Verify the returns and the values in v. Construct h and perform
81                 // the explicit multiplication.
82                 h := make([]float64, n*n)
83                 for i := 0; i < n; i++ {
84                         h[i*n+i] = 1
85                 }
86                 hmat := blas64.General{
87                         Rows:   n,
88                         Cols:   n,
89                         Stride: n,
90                         Data:   h,
91                 }
92                 v := make([]float64, n)
93                 copy(v[1:], x)
94                 v[0] = 1
95                 vVec := blas64.Vector{
96                         Inc:  1,
97                         Data: v,
98                 }
99                 blas64.Ger(-tau, vVec, vVec, hmat)
100                 eye := blas64.General{
101                         Rows:   n,
102                         Cols:   n,
103                         Stride: n,
104                         Data:   make([]float64, n*n),
105                 }
106                 blas64.Gemm(blas.Trans, blas.NoTrans, 1, hmat, hmat, 0, eye)
107                 iseye := true
108                 for i := 0; i < n; i++ {
109                         for j := 0; j < n; j++ {
110                                 if i == j {
111                                         if math.Abs(eye.Data[i*n+j]-1) > 1e-14 {
112                                                 iseye = false
113                                         }
114                                 } else {
115                                         if math.Abs(eye.Data[i*n+j]) > 1e-14 {
116                                                 iseye = false
117                                         }
118                                 }
119                         }
120                 }
121                 if !iseye {
122                         t.Errorf("H^T * H is not I %v", eye)
123                 }
124
125                 xVec := blas64.Vector{
126                         Inc:  1,
127                         Data: make([]float64, n),
128                 }
129                 xVec.Data[0] = test.alpha
130                 copy(xVec.Data[1:], xcopy)
131
132                 ans := make([]float64, n)
133                 ansVec := blas64.Vector{
134                         Inc:  1,
135                         Data: ans,
136                 }
137                 blas64.Gemv(blas.NoTrans, 1, hmat, xVec, 0, ansVec)
138                 if math.Abs(ans[0]-beta) > 1e-14 {
139                         t.Errorf("Case %v, beta mismatch. Want %v, got %v", i, ans[0], beta)
140                 }
141                 for i := 1; i < n; i++ {
142                         if math.Abs(ans[i]) > 1e-14 {
143                                 t.Errorf("Case %v, nonzero answer %v", i, ans)
144                                 break
145                         }
146                 }
147         }
148 }