1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
10 "gonum.org/v1/gonum/blas"
11 "gonum.org/v1/gonum/floats"
14 type Dpotf2er interface {
15 Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
18 func Dpotf2Test(t *testing.T, impl Dpotf2er) {
19 for _, test := range []struct {
33 {4.795831523312719, 7.715033320111766, 7.089490077940543, 6.672461249826393},
34 {0, 3.387958215439679, -1.976308959006481, -1.026654004678691},
35 {0, 0, 3.582364210034111, 2.419258947036024},
36 {0, 0, 0, 3.401680257083044},
46 {2.82842712474619, 0.707106781186547},
47 {0, 1.870828693386971},
51 testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0]), blas.Upper)
52 testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0])+5, blas.Upper)
53 aT := transpose(test.a)
54 L := transpose(test.U)
55 testDpotf2(t, impl, test.pos, aT, L, len(test.a[0]), blas.Lower)
56 testDpotf2(t, impl, test.pos, aT, L, len(test.a[0])+5, blas.Lower)
60 func testDpotf2(t *testing.T, impl Dpotf2er, testPos bool, a, ans [][]float64, stride int, ul blas.Uplo) {
61 aFlat := flattenTri(a, stride, ul)
62 ansFlat := flattenTri(ans, stride, ul)
63 pos := impl.Dpotf2(ul, len(a[0]), aFlat, stride)
65 t.Errorf("Positive definite mismatch: Want %v, Got %v", testPos, pos)
68 if testPos && !floats.EqualApprox(ansFlat, aFlat, 1e-14) {
69 t.Errorf("Result mismatch: Want %v, Got %v", ansFlat, aFlat)
73 // flattenTri flattens the ul triangle of a into a row-major slice with the given
74 // stride. stride must be >= dimension. Puts repeatable nonce values in non-accessed places.
75 func flattenTri(a [][]float64, stride int, ul blas.Uplo) []float64 {
81 upper := ul == blas.Upper
82 v := make([]float64, m*stride)
84 for i := 0; i < m; i++ {
85 for j := 0; j < stride; j++ {
86 if j >= n || (upper && j < i) || (!upper && j > i) {
87 // not accessed, so give a unique crazy number
92 v[i*stride+j] = a[i][j]
98 func transpose(a [][]float64) [][]float64 {
104 aNew := make([][]float64, m)
105 for i := 0; i < m; i++ {
106 aNew[i] = make([]float64, n)
108 for i := 0; i < m; i++ {
112 for j := 0; j < n; j++ {