OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dorg2l.go
1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "math"
9         "testing"
10
11         "golang.org/x/exp/rand"
12
13         "gonum.org/v1/gonum/blas/blas64"
14 )
15
16 type Dorg2ler interface {
17         Dorg2l(m, n, k int, a []float64, lda int, tau, work []float64)
18         Dgeql2er
19 }
20
21 func Dorg2lTest(t *testing.T, impl Dorg2ler) {
22         rnd := rand.New(rand.NewSource(1))
23         for _, test := range []struct {
24                 m, n, k, lda int
25         }{
26                 {5, 4, 3, 0},
27                 {5, 4, 4, 0},
28                 {3, 3, 2, 0},
29                 {5, 5, 5, 0},
30         } {
31                 m := test.m
32                 n := test.n
33                 k := test.k
34                 lda := test.lda
35                 if lda == 0 {
36                         lda = n
37                 }
38
39                 a := make([]float64, m*lda)
40                 for i := range a {
41                         a[i] = rnd.NormFloat64()
42                 }
43                 tau := nanSlice(max(m, n))
44                 work := make([]float64, n)
45                 impl.Dgeql2(m, n, a, lda, tau, work)
46
47                 aCopy := make([]float64, len(a))
48                 copy(aCopy, a)
49                 impl.Dorg2l(m, n, k, a, lda, tau[n-k:], work)
50                 if !hasOrthonormalColumns(m, n, a, lda) {
51                         t.Errorf("Q is not orthonormal. m = %v, n = %v, k = %v", m, n, k)
52                 }
53         }
54 }
55
56 // hasOrthornormalColumns checks that the columns of a are orthonormal.
57 func hasOrthonormalColumns(m, n int, a []float64, lda int) bool {
58         for i := 0; i < n; i++ {
59                 for j := i; j < n; j++ {
60                         dot := blas64.Dot(m,
61                                 blas64.Vector{Inc: lda, Data: a[i:]},
62                                 blas64.Vector{Inc: lda, Data: a[j:]},
63                         )
64                         if i == j {
65                                 if math.Abs(dot-1) > 1e-10 {
66                                         return false
67                                 }
68                         } else {
69                                 if math.Abs(dot) > 1e-10 {
70                                         return false
71                                 }
72                         }
73                 }
74         }
75         return true
76 }