1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
10 "golang.org/x/exp/rand"
12 "gonum.org/v1/gonum/blas"
13 "gonum.org/v1/gonum/blas/blas64"
14 "gonum.org/v1/gonum/floats"
17 type Dgetrser interface {
19 Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int)
22 func DgetrsTest(t *testing.T, impl Dgetrser) {
23 rnd := rand.New(rand.NewSource(1))
24 // TODO(btracey): Put more thought into creating more regularized matrices
25 // and what correct tolerances should be. The random source above is seeded
26 // locally, so the generated matrices do not depend on random-number usage
27 // in other parts of the suite.
28 for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
29 for _, test := range []struct {
41 {300, 300, 0, 0, 1e-8},
42 {300, 500, 0, 0, 1e-8},
43 {500, 300, 0, 0, 1e-6},
45 {300, 300, 700, 600, 1e-8},
46 {300, 500, 700, 600, 1e-8},
47 {500, 300, 700, 600, 1e-6},
59 a := make([]float64, n*lda)
63 b := make([]float64, n*ldb)
67 aCopy := make([]float64, len(a))
69 bCopy := make([]float64, len(b))
72 ipiv := make([]int, n)
77 // Compute the LU factorization.
78 impl.Dgetrf(n, n, a, lda, ipiv)
79 // Solve the system of equations given the result.
80 impl.Dgetrs(trans, n, nrhs, a, lda, ipiv, b, ldb)
82 // Check that the system of equations holds.
101 tmp := blas64.General{
105 Data: make([]float64, n*ldb),
107 copy(tmp.Data, bCopy)
108 blas64.Gemm(trans, blas.NoTrans, 1, A, X, 0, B)
109 if !floats.EqualApprox(tmp.Data, bCopy, test.tol) {
110 t.Errorf("Linear solve mismatch. trans = %v, n = %v, nrhs = %v, lda = %v, ldb = %v", trans, n, nrhs, lda, ldb)