OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dgels.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11
12         "gonum.org/v1/gonum/blas"
13         "gonum.org/v1/gonum/blas/blas64"
14         "gonum.org/v1/gonum/floats"
15 )
16
17 type Dgelser interface {
18         Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool
19 }
20
21 func DgelsTest(t *testing.T, impl Dgelser) {
22         rnd := rand.New(rand.NewSource(1))
23         for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
24                 for _, test := range []struct {
25                         m, n, nrhs, lda, ldb int
26                 }{
27                         {3, 4, 5, 0, 0},
28                         {3, 5, 4, 0, 0},
29                         {4, 3, 5, 0, 0},
30                         {4, 5, 3, 0, 0},
31                         {5, 3, 4, 0, 0},
32                         {5, 4, 3, 0, 0},
33                         {3, 4, 5, 10, 20},
34                         {3, 5, 4, 10, 20},
35                         {4, 3, 5, 10, 20},
36                         {4, 5, 3, 10, 20},
37                         {5, 3, 4, 10, 20},
38                         {5, 4, 3, 10, 20},
39                         {3, 4, 5, 20, 10},
40                         {3, 5, 4, 20, 10},
41                         {4, 3, 5, 20, 10},
42                         {4, 5, 3, 20, 10},
43                         {5, 3, 4, 20, 10},
44                         {5, 4, 3, 20, 10},
45                         {200, 300, 400, 0, 0},
46                         {200, 400, 300, 0, 0},
47                         {300, 200, 400, 0, 0},
48                         {300, 400, 200, 0, 0},
49                         {400, 200, 300, 0, 0},
50                         {400, 300, 200, 0, 0},
51                         {200, 300, 400, 500, 600},
52                         {200, 400, 300, 500, 600},
53                         {300, 200, 400, 500, 600},
54                         {300, 400, 200, 500, 600},
55                         {400, 200, 300, 500, 600},
56                         {400, 300, 200, 500, 600},
57                         {200, 300, 400, 600, 500},
58                         {200, 400, 300, 600, 500},
59                         {300, 200, 400, 600, 500},
60                         {300, 400, 200, 600, 500},
61                         {400, 200, 300, 600, 500},
62                         {400, 300, 200, 600, 500},
63                 } {
64                         m := test.m
65                         n := test.n
66                         nrhs := test.nrhs
67
68                         lda := test.lda
69                         if lda == 0 {
70                                 lda = n
71                         }
72                         a := make([]float64, m*lda)
73                         for i := range a {
74                                 a[i] = rnd.Float64()
75                         }
76                         aCopy := make([]float64, len(a))
77                         copy(aCopy, a)
78
79                         // Size of b is the same trans or no trans, because the number of rows
80                         // has to be the max of (m,n).
81                         mb := max(m, n)
82                         nb := nrhs
83                         ldb := test.ldb
84                         if ldb == 0 {
85                                 ldb = nb
86                         }
87                         b := make([]float64, mb*ldb)
88                         for i := range b {
89                                 b[i] = rnd.Float64()
90                         }
91                         bCopy := make([]float64, len(b))
92                         copy(bCopy, b)
93
94                         // Find optimal work length.
95                         work := make([]float64, 1)
96                         impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, -1)
97
98                         // Perform linear solve
99                         work = make([]float64, int(work[0]))
100                         lwork := len(work)
101                         for i := range work {
102                                 work[i] = rnd.Float64()
103                         }
104                         impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, lwork)
105
106                         // Check that the answer is correct by comparing to the normal equations.
107                         aMat := blas64.General{
108                                 Rows:   m,
109                                 Cols:   n,
110                                 Stride: lda,
111                                 Data:   make([]float64, len(aCopy)),
112                         }
113                         copy(aMat.Data, aCopy)
114                         szAta := n
115                         if trans == blas.Trans {
116                                 szAta = m
117                         }
118                         aTA := blas64.General{
119                                 Rows:   szAta,
120                                 Cols:   szAta,
121                                 Stride: szAta,
122                                 Data:   make([]float64, szAta*szAta),
123                         }
124
125                         // Compute A^T * A if notrans and A * A^T otherwise.
126                         if trans == blas.NoTrans {
127                                 blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, aMat, 0, aTA)
128                         } else {
129                                 blas64.Gemm(blas.NoTrans, blas.Trans, 1, aMat, aMat, 0, aTA)
130                         }
131
132                         // Multiply by X.
133                         X := blas64.General{
134                                 Rows:   szAta,
135                                 Cols:   nrhs,
136                                 Stride: ldb,
137                                 Data:   b,
138                         }
139                         ans := blas64.General{
140                                 Rows:   aTA.Rows,
141                                 Cols:   X.Cols,
142                                 Stride: X.Cols,
143                                 Data:   make([]float64, aTA.Rows*X.Cols),
144                         }
145                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aTA, X, 0, ans)
146
147                         B := blas64.General{
148                                 Rows:   szAta,
149                                 Cols:   nrhs,
150                                 Stride: ldb,
151                                 Data:   make([]float64, len(bCopy)),
152                         }
153
154                         copy(B.Data, bCopy)
155                         var ans2 blas64.General
156                         if trans == blas.NoTrans {
157                                 ans2 = blas64.General{
158                                         Rows:   aMat.Cols,
159                                         Cols:   B.Cols,
160                                         Stride: B.Cols,
161                                         Data:   make([]float64, aMat.Cols*B.Cols),
162                                 }
163                         } else {
164                                 ans2 = blas64.General{
165                                         Rows:   aMat.Rows,
166                                         Cols:   B.Cols,
167                                         Stride: B.Cols,
168                                         Data:   make([]float64, aMat.Rows*B.Cols),
169                                 }
170                         }
171
172                         // Compute A^T B if Trans or A * B otherwise
173                         if trans == blas.NoTrans {
174                                 blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, B, 0, ans2)
175                         } else {
176                                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, B, 0, ans2)
177                         }
178                         if !floats.EqualApprox(ans.Data, ans2.Data, 1e-12) {
179                                 t.Errorf("Normal equations not satisfied")
180                         }
181                 }
182         }
183 }