OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dgetrs.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11
12         "gonum.org/v1/gonum/blas"
13         "gonum.org/v1/gonum/blas/blas64"
14         "gonum.org/v1/gonum/floats"
15 )
16
17 type Dgetrser interface {
18         Dgetrfer
19         Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int)
20 }
21
22 func DgetrsTest(t *testing.T, impl Dgetrser) {
23         rnd := rand.New(rand.NewSource(1))
24         // TODO(btracey): Put more thought into creating more regularized matrices
25         // and what correct tolerances should be. Consider also seeding the random
26         // number in this test to make it more robust to code changes in other
27         // parts of the suite.
28         for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
29                 for _, test := range []struct {
30                         n, nrhs, lda, ldb int
31                         tol               float64
32                 }{
33                         {3, 3, 0, 0, 1e-12},
34                         {3, 5, 0, 0, 1e-12},
35                         {5, 3, 0, 0, 1e-12},
36
37                         {3, 3, 8, 10, 1e-12},
38                         {3, 5, 8, 10, 1e-12},
39                         {5, 3, 8, 10, 1e-12},
40
41                         {300, 300, 0, 0, 1e-8},
42                         {300, 500, 0, 0, 1e-8},
43                         {500, 300, 0, 0, 1e-6},
44
45                         {300, 300, 700, 600, 1e-8},
46                         {300, 500, 700, 600, 1e-8},
47                         {500, 300, 700, 600, 1e-6},
48                 } {
49                         n := test.n
50                         nrhs := test.nrhs
51                         lda := test.lda
52                         if lda == 0 {
53                                 lda = n
54                         }
55                         ldb := test.ldb
56                         if ldb == 0 {
57                                 ldb = nrhs
58                         }
59                         a := make([]float64, n*lda)
60                         for i := range a {
61                                 a[i] = rnd.Float64()
62                         }
63                         b := make([]float64, n*ldb)
64                         for i := range b {
65                                 b[i] = rnd.Float64()
66                         }
67                         aCopy := make([]float64, len(a))
68                         copy(aCopy, a)
69                         bCopy := make([]float64, len(b))
70                         copy(bCopy, b)
71
72                         ipiv := make([]int, n)
73                         for i := range ipiv {
74                                 ipiv[i] = rnd.Int()
75                         }
76
77                         // Compute the LU factorization.
78                         impl.Dgetrf(n, n, a, lda, ipiv)
79                         // Solve the system of equations given the result.
80                         impl.Dgetrs(trans, n, nrhs, a, lda, ipiv, b, ldb)
81
82                         // Check that the system of equations holds.
83                         A := blas64.General{
84                                 Rows:   n,
85                                 Cols:   n,
86                                 Stride: lda,
87                                 Data:   aCopy,
88                         }
89                         B := blas64.General{
90                                 Rows:   n,
91                                 Cols:   nrhs,
92                                 Stride: ldb,
93                                 Data:   bCopy,
94                         }
95                         X := blas64.General{
96                                 Rows:   n,
97                                 Cols:   nrhs,
98                                 Stride: ldb,
99                                 Data:   b,
100                         }
101                         tmp := blas64.General{
102                                 Rows:   n,
103                                 Cols:   nrhs,
104                                 Stride: ldb,
105                                 Data:   make([]float64, n*ldb),
106                         }
107                         copy(tmp.Data, bCopy)
108                         blas64.Gemm(trans, blas.NoTrans, 1, A, X, 0, B)
109                         if !floats.EqualApprox(tmp.Data, bCopy, test.tol) {
110                                 t.Errorf("Linear solve mismatch. trans = %v, n = %v, nrhs = %v, lda = %v, ldb = %v", trans, n, nrhs, lda, ldb)
111                         }
112                 }
113         }
114 }