OSDN Git Service

Hulk did something
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dtrtri.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "math"
9         "testing"
10
11         "golang.org/x/exp/rand"
12
13         "gonum.org/v1/gonum/blas"
14         "gonum.org/v1/gonum/blas/blas64"
15 )
16
17 type Dtrtrier interface {
18         Dtrconer
19         Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) bool
20 }
21
22 func DtrtriTest(t *testing.T, impl Dtrtrier) {
23         const tol = 1e-10
24         rnd := rand.New(rand.NewSource(1))
25         bi := blas64.Implementation()
26         for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
27                 for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} {
28                         for _, test := range []struct {
29                                 n, lda int
30                         }{
31                                 {3, 0},
32                                 {70, 0},
33                                 {200, 0},
34                                 {3, 5},
35                                 {70, 92},
36                                 {200, 205},
37                         } {
38                                 n := test.n
39                                 lda := test.lda
40                                 if lda == 0 {
41                                         lda = n
42                                 }
43                                 a := make([]float64, n*lda)
44                                 for i := range a {
45                                         a[i] = rnd.Float64()
46                                 }
47                                 for i := 0; i < n; i++ {
48                                         // This keeps the matrices well conditioned.
49                                         a[i*lda+i] += float64(n)
50                                 }
51                                 aCopy := make([]float64, len(a))
52                                 copy(aCopy, a)
53                                 impl.Dtrtri(uplo, diag, n, a, lda)
54                                 if uplo == blas.Upper {
55                                         for i := 1; i < n; i++ {
56                                                 for j := 0; j < i; j++ {
57                                                         aCopy[i*lda+j] = 0
58                                                         a[i*lda+j] = 0
59                                                 }
60                                         }
61                                 } else {
62                                         for i := 0; i < n; i++ {
63                                                 for j := i + 1; j < n; j++ {
64                                                         aCopy[i*lda+j] = 0
65                                                         a[i*lda+j] = 0
66                                                 }
67                                         }
68                                 }
69                                 if diag == blas.Unit {
70                                         for i := 0; i < n; i++ {
71                                                 a[i*lda+i] = 1
72                                                 aCopy[i*lda+i] = 1
73                                         }
74                                 }
75                                 ans := make([]float64, len(a))
76                                 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda)
77                                 iseye := true
78                                 for i := 0; i < n; i++ {
79                                         for j := 0; j < n; j++ {
80                                                 if i == j {
81                                                         if math.Abs(ans[i*lda+i]-1) > tol {
82                                                                 iseye = false
83                                                                 break
84                                                         }
85                                                 } else {
86                                                         if math.Abs(ans[i*lda+j]) > tol {
87                                                                 iseye = false
88                                                                 break
89                                                         }
90                                                 }
91                                         }
92                                 }
93                                 if !iseye {
94                                         t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, n = %v, lda = %v",
95                                                 uplo == blas.Upper, diag == blas.Unit, n, lda)
96                                 }
97                         }
98                 }
99         }
100 }