OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dgeqp3.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "math"
9         "testing"
10
11         "golang.org/x/exp/rand"
12
13         "gonum.org/v1/gonum/blas"
14         "gonum.org/v1/gonum/blas/blas64"
15 )
16
17 type Dgeqp3er interface {
18         Dlapmter
19         Dgeqp3(m, n int, a []float64, lda int, jpvt []int, tau, work []float64, lwork int)
20 }
21
22 func Dgeqp3Test(t *testing.T, impl Dgeqp3er) {
23         rnd := rand.New(rand.NewSource(1))
24         for c, test := range []struct {
25                 m, n, lda int
26         }{
27                 {1, 1, 0},
28                 {2, 2, 0},
29                 {3, 2, 0},
30                 {2, 3, 0},
31                 {1, 12, 0},
32                 {2, 6, 0},
33                 {3, 4, 0},
34                 {4, 3, 0},
35                 {6, 2, 0},
36                 {12, 1, 0},
37                 {1, 1, 20},
38                 {2, 2, 20},
39                 {3, 2, 20},
40                 {2, 3, 20},
41                 {1, 12, 20},
42                 {2, 6, 20},
43                 {3, 4, 20},
44                 {4, 3, 20},
45                 {6, 2, 20},
46                 {12, 1, 20},
47                 {129, 256, 0},
48                 {256, 129, 0},
49                 {129, 256, 266},
50                 {256, 129, 266},
51         } {
52                 n := test.n
53                 m := test.m
54                 lda := test.lda
55                 if lda == 0 {
56                         lda = test.n
57                 }
58                 const (
59                         all = iota
60                         some
61                         none
62                 )
63                 for _, free := range []int{all, some, none} {
64                         a := make([]float64, m*lda)
65                         for i := range a {
66                                 a[i] = rnd.Float64()
67                         }
68                         aCopy := make([]float64, len(a))
69                         copy(aCopy, a)
70                         jpvt := make([]int, n)
71                         for j := range jpvt {
72                                 switch free {
73                                 case all:
74                                         jpvt[j] = -1
75                                 case some:
76                                         jpvt[j] = rnd.Intn(2) - 1
77                                 case none:
78                                         jpvt[j] = 0
79                                 default:
80                                         panic("bad freedom")
81                                 }
82                         }
83                         k := min(m, n)
84                         tau := make([]float64, k)
85                         for i := range tau {
86                                 tau[i] = rnd.Float64()
87                         }
88                         work := make([]float64, 1)
89                         impl.Dgeqp3(m, n, a, lda, jpvt, tau, work, -1)
90                         lwork := int(work[0])
91                         work = make([]float64, lwork)
92                         for i := range work {
93                                 work[i] = rnd.Float64()
94                         }
95                         impl.Dgeqp3(m, n, a, lda, jpvt, tau, work, lwork)
96
97                         // Test that the QR factorization has completed successfully. Compute
98                         // Q based on the vectors.
99                         q := constructQ("QR", m, n, a, lda, tau)
100
101                         // Check that q is orthonormal
102                         for i := 0; i < m; i++ {
103                                 nrm := blas64.Nrm2(m, blas64.Vector{Inc: 1, Data: q.Data[i*m:]})
104                                 if math.Abs(nrm-1) > 1e-13 {
105                                         t.Errorf("Case %v, q not normal", c)
106                                 }
107                                 for j := 0; j < i; j++ {
108                                         dot := blas64.Dot(m, blas64.Vector{Inc: 1, Data: q.Data[i*m:]}, blas64.Vector{Inc: 1, Data: q.Data[j*m:]})
109                                         if math.Abs(dot) > 1e-14 {
110                                                 t.Errorf("Case %v, q not orthogonal", c)
111                                         }
112                                 }
113                         }
114                         // Check that A * P = Q * R
115                         r := blas64.General{
116                                 Rows:   m,
117                                 Cols:   n,
118                                 Stride: n,
119                                 Data:   make([]float64, m*n),
120                         }
121                         for i := 0; i < m; i++ {
122                                 for j := i; j < n; j++ {
123                                         r.Data[i*n+j] = a[i*lda+j]
124                                 }
125                         }
126                         got := nanGeneral(m, n, lda)
127                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, r, 0, got)
128
129                         want := blas64.General{Rows: m, Cols: n, Stride: lda, Data: aCopy}
130                         impl.Dlapmt(true, want.Rows, want.Cols, want.Data, want.Stride, jpvt)
131                         if !equalApproxGeneral(got, want, 1e-13) {
132                                 t.Errorf("Case %v,  Q*R != A*P\nQ*R=%v\nA*P=%v", c, got, want)
133                         }
134                 }
135         }
136 }