// Copyright ©2015 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package testlapack import ( "testing" "gonum.org/v1/gonum/blas" "gonum.org/v1/gonum/floats" ) type Dpotf2er interface { Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool) } func Dpotf2Test(t *testing.T, impl Dpotf2er) { for _, test := range []struct { a [][]float64 pos bool U [][]float64 }{ { a: [][]float64{ {23, 37, 34, 32}, {108, 71, 48, 48}, {109, 109, 67, 58}, {106, 107, 106, 63}, }, pos: true, U: [][]float64{ {4.795831523312719, 7.715033320111766, 7.089490077940543, 6.672461249826393}, {0, 3.387958215439679, -1.976308959006481, -1.026654004678691}, {0, 0, 3.582364210034111, 2.419258947036024}, {0, 0, 0, 3.401680257083044}, }, }, { a: [][]float64{ {8, 2}, {2, 4}, }, pos: true, U: [][]float64{ {2.82842712474619, 0.707106781186547}, {0, 1.870828693386971}, }, }, } { testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0]), blas.Upper) testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0])+5, blas.Upper) aT := transpose(test.a) L := transpose(test.U) testDpotf2(t, impl, test.pos, aT, L, len(test.a[0]), blas.Lower) testDpotf2(t, impl, test.pos, aT, L, len(test.a[0])+5, blas.Lower) } } func testDpotf2(t *testing.T, impl Dpotf2er, testPos bool, a, ans [][]float64, stride int, ul blas.Uplo) { aFlat := flattenTri(a, stride, ul) ansFlat := flattenTri(ans, stride, ul) pos := impl.Dpotf2(ul, len(a[0]), aFlat, stride) if pos != testPos { t.Errorf("Positive definite mismatch: Want %v, Got %v", testPos, pos) return } if testPos && !floats.EqualApprox(ansFlat, aFlat, 1e-14) { t.Errorf("Result mismatch: Want %v, Got %v", ansFlat, aFlat) } } // flattenTri with a certain stride. stride must be >= dimension. Puts repeatable // nonce values in non-accessed places func flattenTri(a [][]float64, stride int, ul blas.Uplo) []float64 { m := len(a) n := len(a[0]) if stride < n { panic("bad stride") } upper := ul == blas.Upper v := make([]float64, m*stride) count := 1000.0 for i := 0; i < m; i++ { for j := 0; j < stride; j++ { if j >= n || (upper && j < i) || (!upper && j > i) { // not accessed, so give a unique crazy number v[i*stride+j] = count count++ continue } v[i*stride+j] = a[i][j] } } return v } func transpose(a [][]float64) [][]float64 { m := len(a) n := len(a[0]) if m != n { panic("not square") } aNew := make([][]float64, m) for i := 0; i < m; i++ { aNew[i] = make([]float64, n) } for i := 0; i < m; i++ { if len(a[i]) != n { panic("bad n size") } for j := 0; j < n; j++ { aNew[j][i] = a[i][j] } } return aNew }