OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dlasr.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "math"
9         "testing"
10
11         "golang.org/x/exp/rand"
12
13         "gonum.org/v1/gonum/blas"
14         "gonum.org/v1/gonum/blas/blas64"
15         "gonum.org/v1/gonum/floats"
16         "gonum.org/v1/gonum/lapack"
17 )
18
19 type Dlasrer interface {
20         Dlasr(side blas.Side, pivot lapack.Pivot, direct lapack.Direct, m, n int, c, s, a []float64, lda int)
21 }
22
23 func DlasrTest(t *testing.T, impl Dlasrer) {
24         rnd := rand.New(rand.NewSource(1))
25         for _, side := range []blas.Side{blas.Left, blas.Right} {
26                 for _, pivot := range []lapack.Pivot{lapack.Variable, lapack.Top, lapack.Bottom} {
27                         for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
28                                 for _, test := range []struct {
29                                         m, n, lda int
30                                 }{
31                                         {5, 5, 0},
32                                         {5, 10, 0},
33                                         {10, 5, 0},
34
35                                         {5, 5, 20},
36                                         {5, 10, 20},
37                                         {10, 5, 20},
38                                 } {
39                                         m := test.m
40                                         n := test.n
41                                         lda := test.lda
42                                         if lda == 0 {
43                                                 lda = n
44                                         }
45                                         a := make([]float64, m*lda)
46                                         for i := range a {
47                                                 a[i] = rnd.Float64()
48                                         }
49                                         var s, c []float64
50                                         if side == blas.Left {
51                                                 s = make([]float64, m-1)
52                                                 c = make([]float64, m-1)
53                                         } else {
54                                                 s = make([]float64, n-1)
55                                                 c = make([]float64, n-1)
56                                         }
57                                         for k := range s {
58                                                 theta := rnd.Float64() * 2 * math.Pi
59                                                 s[k] = math.Sin(theta)
60                                                 c[k] = math.Cos(theta)
61                                         }
62                                         aCopy := make([]float64, len(a))
63                                         copy(a, aCopy)
64                                         impl.Dlasr(side, pivot, direct, m, n, c, s, a, lda)
65
66                                         pSize := m
67                                         if side == blas.Right {
68                                                 pSize = n
69                                         }
70                                         p := blas64.General{
71                                                 Rows:   pSize,
72                                                 Cols:   pSize,
73                                                 Stride: pSize,
74                                                 Data:   make([]float64, pSize*pSize),
75                                         }
76                                         pk := blas64.General{
77                                                 Rows:   pSize,
78                                                 Cols:   pSize,
79                                                 Stride: pSize,
80                                                 Data:   make([]float64, pSize*pSize),
81                                         }
82                                         ptmp := blas64.General{
83                                                 Rows:   pSize,
84                                                 Cols:   pSize,
85                                                 Stride: pSize,
86                                                 Data:   make([]float64, pSize*pSize),
87                                         }
88                                         for i := 0; i < pSize; i++ {
89                                                 p.Data[i*p.Stride+i] = 1
90                                                 ptmp.Data[i*p.Stride+i] = 1
91                                         }
92                                         // Compare to direct computation.
93                                         for k := range s {
94                                                 for i := range p.Data {
95                                                         pk.Data[i] = 0
96                                                 }
97                                                 for i := 0; i < pSize; i++ {
98                                                         pk.Data[i*p.Stride+i] = 1
99                                                 }
100                                                 if pivot == lapack.Variable {
101                                                         pk.Data[k*p.Stride+k] = c[k]
102                                                         pk.Data[k*p.Stride+k+1] = s[k]
103                                                         pk.Data[(k+1)*p.Stride+k] = -s[k]
104                                                         pk.Data[(k+1)*p.Stride+k+1] = c[k]
105                                                 } else if pivot == lapack.Top {
106                                                         pk.Data[0] = c[k]
107                                                         pk.Data[k+1] = s[k]
108                                                         pk.Data[(k+1)*p.Stride] = -s[k]
109                                                         pk.Data[(k+1)*p.Stride+k+1] = c[k]
110                                                 } else {
111                                                         pk.Data[(pSize-1-k)*p.Stride+pSize-k-1] = c[k]
112                                                         pk.Data[(pSize-1-k)*p.Stride+pSize-1] = s[k]
113                                                         pk.Data[(pSize-1)*p.Stride+pSize-1-k] = -s[k]
114                                                         pk.Data[(pSize-1)*p.Stride+pSize-1] = c[k]
115                                                 }
116                                                 if direct == lapack.Forward {
117                                                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, pk, ptmp, 0, p)
118                                                 } else {
119                                                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, ptmp, pk, 0, p)
120                                                 }
121                                                 copy(ptmp.Data, p.Data)
122                                         }
123
124                                         aMat := blas64.General{
125                                                 Rows:   m,
126                                                 Cols:   n,
127                                                 Stride: lda,
128                                                 Data:   make([]float64, m*lda),
129                                         }
130                                         copy(a, aCopy)
131                                         newA := blas64.General{
132                                                 Rows:   m,
133                                                 Cols:   n,
134                                                 Stride: lda,
135                                                 Data:   make([]float64, m*lda),
136                                         }
137                                         if side == blas.Left {
138                                                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, p, aMat, 0, newA)
139                                         } else {
140                                                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, p, 0, newA)
141                                         }
142                                         if !floats.EqualApprox(newA.Data, a, 1e-12) {
143                                                 t.Errorf("A update mismatch")
144                                         }
145                                 }
146                         }
147                 }
148         }
149 }