OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dgerqf.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "math"
9         "testing"
10
11         "golang.org/x/exp/rand"
12
13         "gonum.org/v1/gonum/blas"
14         "gonum.org/v1/gonum/blas/blas64"
15 )
16
// Dgerqfer is implemented by types that provide Dgerqf, the routine whose RQ
// factorization of an m×n matrix is exercised by DgerqfTest. DgerqfTest
// relies on a call with lwork == -1 storing the optimal workspace length in
// work[0].
type Dgerqfer interface {
	Dgerqf(m, n int,
		a []float64, lda int,
		tau []float64,
		work []float64, lwork int)
}
20
21 func DgerqfTest(t *testing.T, impl Dgerqfer) {
22         const tol = 1e-13
23
24         rnd := rand.New(rand.NewSource(1))
25         for c, test := range []struct {
26                 m, n, lda int
27         }{
28                 {1, 1, 0},
29                 {2, 2, 0},
30                 {3, 2, 0},
31                 {2, 3, 0},
32                 {1, 12, 0},
33                 {2, 6, 0},
34                 {3, 4, 0},
35                 {4, 3, 0},
36                 {6, 2, 0},
37                 {12, 1, 0},
38                 {200, 180, 0},
39                 {180, 200, 0},
40                 {200, 200, 0},
41                 {1, 1, 20},
42                 {2, 2, 20},
43                 {3, 2, 20},
44                 {2, 3, 20},
45                 {1, 12, 20},
46                 {2, 6, 20},
47                 {3, 4, 20},
48                 {4, 3, 20},
49                 {6, 2, 20},
50                 {12, 1, 20},
51                 {200, 180, 220},
52                 {180, 200, 220},
53                 {200, 200, 220},
54         } {
55                 n := test.n
56                 m := test.m
57                 lda := test.lda
58                 if lda == 0 {
59                         lda = test.n
60                 }
61                 a := make([]float64, m*lda)
62                 for i := range a {
63                         a[i] = rnd.Float64()
64                 }
65                 aCopy := make([]float64, len(a))
66                 copy(aCopy, a)
67                 k := min(m, n)
68                 tau := make([]float64, k)
69                 for i := range tau {
70                         tau[i] = rnd.Float64()
71                 }
72                 work := []float64{0}
73                 impl.Dgerqf(m, n, a, lda, tau, work, -1)
74                 lwkopt := int(work[0])
75                 for _, wk := range []struct {
76                         name   string
77                         length int
78                 }{
79                         {name: "short", length: m},
80                         {name: "medium", length: lwkopt - 1},
81                         {name: "long", length: lwkopt},
82                 } {
83                         if wk.length < max(1, m) {
84                                 continue
85                         }
86                         lwork := wk.length
87                         work = make([]float64, lwork)
88                         for i := range work {
89                                 work[i] = rnd.Float64()
90                         }
91                         copy(a, aCopy)
92                         impl.Dgerqf(m, n, a, lda, tau, work, lwork)
93
94                         // Test that the RQ factorization has completed successfully. Compute
95                         // Q based on the vectors.
96                         q := constructQ("RQ", m, n, a, lda, tau)
97
98                         // Check that q is orthonormal
99                         for i := 0; i < q.Rows; i++ {
100                                 nrm := blas64.Nrm2(q.Cols, blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]})
101                                 if math.IsNaN(nrm) || math.Abs(nrm-1) > 1e-14 {
102                                         t.Errorf("Case %v, q not normal", c)
103                                 }
104                                 for j := 0; j < i; j++ {
105                                         dot := blas64.Dot(q.Cols, blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]}, blas64.Vector{Inc: 1, Data: q.Data[j*q.Stride:]})
106                                         if math.IsNaN(dot) || math.Abs(dot) > 1e-14 {
107                                                 t.Errorf("Case %v, q not orthogonal", c)
108                                         }
109                                 }
110                         }
111                         // Check that A = R * Q
112                         r := blas64.General{
113                                 Rows:   m,
114                                 Cols:   n,
115                                 Stride: n,
116                                 Data:   make([]float64, m*n),
117                         }
118                         for i := 0; i < m; i++ {
119                                 off := m - n
120                                 for j := max(0, i-off); j < n; j++ {
121                                         r.Data[i*r.Stride+j] = a[i*lda+j]
122                                 }
123                         }
124
125                         got := blas64.General{
126                                 Rows:   m,
127                                 Cols:   n,
128                                 Stride: lda,
129                                 Data:   make([]float64, m*lda),
130                         }
131                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, r, q, 0, got)
132                         want := blas64.General{
133                                 Rows:   m,
134                                 Cols:   n,
135                                 Stride: lda,
136                                 Data:   aCopy,
137                         }
138                         if !equalApproxGeneral(got, want, tol) {
139                                 t.Errorf("Case %d, R*Q != a %s\ngot: %+v\nwant:%+v", c, wk.name, got, want)
140                         }
141                 }
142         }
143 }