OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dgeql2.go
1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11
12         "gonum.org/v1/gonum/blas"
13         "gonum.org/v1/gonum/blas/blas64"
14         "gonum.org/v1/gonum/floats"
15 )
16
17 type Dgeql2er interface {
18         Dgeql2(m, n int, a []float64, lda int, tau, work []float64)
19 }
20
21 func Dgeql2Test(t *testing.T, impl Dgeql2er) {
22         rnd := rand.New(rand.NewSource(1))
23         // TODO(btracey): Add tests for m < n.
24         for _, test := range []struct {
25                 m, n, lda int
26         }{
27                 {5, 5, 0},
28                 {5, 3, 0},
29                 {5, 4, 0},
30         } {
31                 m := test.m
32                 n := test.n
33                 lda := test.lda
34                 if lda == 0 {
35                         lda = n
36                 }
37                 a := make([]float64, m*lda)
38                 for i := range a {
39                         a[i] = rnd.NormFloat64()
40                 }
41                 tau := nanSlice(min(m, n))
42                 work := nanSlice(n)
43
44                 aCopy := make([]float64, len(a))
45                 copy(aCopy, a)
46                 impl.Dgeql2(m, n, a, lda, tau, work)
47
48                 k := min(m, n)
49                 // Construct Q.
50                 q := blas64.General{
51                         Rows:   m,
52                         Cols:   m,
53                         Stride: m,
54                         Data:   make([]float64, m*m),
55                 }
56                 for i := 0; i < m; i++ {
57                         q.Data[i*q.Stride+i] = 1
58                 }
59                 for i := 0; i < k; i++ {
60                         h := blas64.General{Rows: m, Cols: m, Stride: m, Data: make([]float64, m*m)}
61                         for j := 0; j < m; j++ {
62                                 h.Data[j*h.Stride+j] = 1
63                         }
64                         v := blas64.Vector{Inc: 1, Data: make([]float64, m)}
65                         v.Data[m-k+i] = 1
66                         for j := 0; j < m-k+i; j++ {
67                                 v.Data[j] = a[j*lda+n-k+i]
68                         }
69                         blas64.Ger(-tau[i], v, v, h)
70                         qTmp := blas64.General{Rows: q.Rows, Cols: q.Cols, Stride: q.Stride, Data: make([]float64, len(q.Data))}
71                         copy(qTmp.Data, q.Data)
72                         blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, qTmp, 0, q)
73                 }
74                 if !isOrthonormal(q) {
75                         t.Errorf("Q is not orthonormal")
76                 }
77                 l := blas64.General{
78                         Rows:   m,
79                         Cols:   n,
80                         Stride: n,
81                         Data:   make([]float64, m*n),
82                 }
83                 if m >= n {
84                         for i := m - n; i < m; i++ {
85                                 for j := 0; j <= min(i-(m-n), n-1); j++ {
86                                         l.Data[i*l.Stride+j] = a[i*lda+j]
87                                 }
88                         }
89                 } else {
90                         panic("untested")
91                 }
92                 ans := blas64.General{Rows: m, Cols: n, Stride: lda, Data: make([]float64, len(a))}
93                 copy(ans.Data, a)
94
95                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, l, 0, ans)
96                 if !floats.EqualApprox(ans.Data, aCopy, 1e-10) {
97                         t.Errorf("Reconstruction mismatch: m = %v, n = %v", m, n)
98                 }
99         }
100 }