OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dorgtr.go
1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11
12         "gonum.org/v1/gonum/blas"
13         "gonum.org/v1/gonum/blas/blas64"
14         "gonum.org/v1/gonum/floats"
15 )
16
17 type Dorgtrer interface {
18         Dorgtr(uplo blas.Uplo, n int, a []float64, lda int, tau, work []float64, lwork int)
19         Dsytrder
20 }
21
22 func DorgtrTest(t *testing.T, impl Dorgtrer) {
23         rnd := rand.New(rand.NewSource(1))
24         for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
25                 for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} {
26                         for _, test := range []struct {
27                                 n, lda int
28                         }{
29                                 {1, 0},
30                                 {2, 0},
31                                 {3, 0},
32                                 {6, 0},
33                                 {33, 0},
34                                 {100, 0},
35
36                                 {1, 3},
37                                 {2, 5},
38                                 {3, 7},
39                                 {6, 10},
40                                 {33, 50},
41                                 {100, 120},
42                         } {
43                                 n := test.n
44                                 lda := test.lda
45                                 if lda == 0 {
46                                         lda = n
47                                 }
48                                 a := make([]float64, n*lda)
49                                 for i := range a {
50                                         a[i] = rnd.NormFloat64()
51                                 }
52                                 aCopy := make([]float64, len(a))
53                                 copy(aCopy, a)
54
55                                 d := make([]float64, n)
56                                 e := make([]float64, n-1)
57                                 tau := make([]float64, n-1)
58                                 work := make([]float64, 1)
59                                 impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, -1)
60                                 work = make([]float64, int(work[0]))
61                                 impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, len(work))
62
63                                 var lwork int
64                                 switch wl {
65                                 case minimumWork:
66                                         lwork = max(1, n-1)
67                                 case mediumWork:
68                                         work := make([]float64, 1)
69                                         impl.Dorgtr(uplo, n, a, lda, tau, work, -1)
70                                         lwork = (int(work[0]) + n - 1) / 2
71                                         lwork = max(1, lwork)
72                                 case optimumWork:
73                                         work := make([]float64, 1)
74                                         impl.Dorgtr(uplo, n, a, lda, tau, work, -1)
75                                         lwork = int(work[0])
76                                 }
77                                 work = nanSlice(lwork)
78
79                                 impl.Dorgtr(uplo, n, a, lda, tau, work, len(work))
80
81                                 q := blas64.General{
82                                         Rows:   n,
83                                         Cols:   n,
84                                         Stride: lda,
85                                         Data:   a,
86                                 }
87                                 tri := blas64.General{
88                                         Rows:   n,
89                                         Cols:   n,
90                                         Stride: n,
91                                         Data:   make([]float64, n*n),
92                                 }
93                                 for i := 0; i < n; i++ {
94                                         tri.Data[i*tri.Stride+i] = d[i]
95                                         if i != n-1 {
96                                                 tri.Data[i*tri.Stride+i+1] = e[i]
97                                                 tri.Data[(i+1)*tri.Stride+i] = e[i]
98                                         }
99                                 }
100
101                                 aMat := blas64.General{
102                                         Rows:   n,
103                                         Cols:   n,
104                                         Stride: n,
105                                         Data:   make([]float64, n*n),
106                                 }
107                                 if uplo == blas.Upper {
108                                         for i := 0; i < n; i++ {
109                                                 for j := i; j < n; j++ {
110                                                         v := aCopy[i*lda+j]
111                                                         aMat.Data[i*aMat.Stride+j] = v
112                                                         aMat.Data[j*aMat.Stride+i] = v
113                                                 }
114                                         }
115                                 } else {
116                                         for i := 0; i < n; i++ {
117                                                 for j := 0; j <= i; j++ {
118                                                         v := aCopy[i*lda+j]
119                                                         aMat.Data[i*aMat.Stride+j] = v
120                                                         aMat.Data[j*aMat.Stride+i] = v
121                                                 }
122                                         }
123                                 }
124
125                                 tmp := blas64.General{Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n)}
126                                 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, q, 0, tmp)
127
128                                 ans := blas64.General{Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n)}
129                                 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, tmp, 0, ans)
130
131                                 if !floats.EqualApprox(ans.Data, tri.Data, 1e-13) {
132                                         t.Errorf("Recombination mismatch. n = %v, isUpper = %v", n, uplo == blas.Upper)
133                                 }
134                         }
135                 }
136         }
137 }