OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dpotf2.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "testing"
9
10         "gonum.org/v1/gonum/blas"
11         "gonum.org/v1/gonum/floats"
12 )
13
14 type Dpotf2er interface {
15         Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
16 }
17
18 func Dpotf2Test(t *testing.T, impl Dpotf2er) {
19         for _, test := range []struct {
20                 a   [][]float64
21                 pos bool
22                 U   [][]float64
23         }{
24                 {
25                         a: [][]float64{
26                                 {23, 37, 34, 32},
27                                 {108, 71, 48, 48},
28                                 {109, 109, 67, 58},
29                                 {106, 107, 106, 63},
30                         },
31                         pos: true,
32                         U: [][]float64{
33                                 {4.795831523312719, 7.715033320111766, 7.089490077940543, 6.672461249826393},
34                                 {0, 3.387958215439679, -1.976308959006481, -1.026654004678691},
35                                 {0, 0, 3.582364210034111, 2.419258947036024},
36                                 {0, 0, 0, 3.401680257083044},
37                         },
38                 },
39                 {
40                         a: [][]float64{
41                                 {8, 2},
42                                 {2, 4},
43                         },
44                         pos: true,
45                         U: [][]float64{
46                                 {2.82842712474619, 0.707106781186547},
47                                 {0, 1.870828693386971},
48                         },
49                 },
50         } {
51                 testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0]), blas.Upper)
52                 testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0])+5, blas.Upper)
53                 aT := transpose(test.a)
54                 L := transpose(test.U)
55                 testDpotf2(t, impl, test.pos, aT, L, len(test.a[0]), blas.Lower)
56                 testDpotf2(t, impl, test.pos, aT, L, len(test.a[0])+5, blas.Lower)
57         }
58 }
59
60 func testDpotf2(t *testing.T, impl Dpotf2er, testPos bool, a, ans [][]float64, stride int, ul blas.Uplo) {
61         aFlat := flattenTri(a, stride, ul)
62         ansFlat := flattenTri(ans, stride, ul)
63         pos := impl.Dpotf2(ul, len(a[0]), aFlat, stride)
64         if pos != testPos {
65                 t.Errorf("Positive definite mismatch: Want %v, Got %v", testPos, pos)
66                 return
67         }
68         if testPos && !floats.EqualApprox(ansFlat, aFlat, 1e-14) {
69                 t.Errorf("Result mismatch: Want %v, Got  %v", ansFlat, aFlat)
70         }
71 }
72
73 // flattenTri  with a certain stride. stride must be >= dimension. Puts repeatable
74 // nonce values in non-accessed places
75 func flattenTri(a [][]float64, stride int, ul blas.Uplo) []float64 {
76         m := len(a)
77         n := len(a[0])
78         if stride < n {
79                 panic("bad stride")
80         }
81         upper := ul == blas.Upper
82         v := make([]float64, m*stride)
83         count := 1000.0
84         for i := 0; i < m; i++ {
85                 for j := 0; j < stride; j++ {
86                         if j >= n || (upper && j < i) || (!upper && j > i) {
87                                 // not accessed, so give a unique crazy number
88                                 v[i*stride+j] = count
89                                 count++
90                                 continue
91                         }
92                         v[i*stride+j] = a[i][j]
93                 }
94         }
95         return v
96 }
97
98 func transpose(a [][]float64) [][]float64 {
99         m := len(a)
100         n := len(a[0])
101         if m != n {
102                 panic("not square")
103         }
104         aNew := make([][]float64, m)
105         for i := 0; i < m; i++ {
106                 aNew[i] = make([]float64, n)
107         }
108         for i := 0; i < m; i++ {
109                 if len(a[i]) != n {
110                         panic("bad n size")
111                 }
112                 for j := 0; j < n; j++ {
113                         aNew[j][i] = a[i][j]
114                 }
115         }
116         return aNew
117 }