OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / blas / testblas / ztrsv.go
1 // Copyright ©2017 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testblas
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11         "gonum.org/v1/gonum/blas"
12 )
13
14 type Ztrsver interface {
15         Ztrsv(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n int, a []complex128, lda int, x []complex128, incX int)
16
17         Ztrmver
18 }
19
20 func ZtrsvTest(t *testing.T, impl Ztrsver) {
21         rnd := rand.New(rand.NewSource(1))
22         for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
23                 for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans, blas.ConjTrans} {
24                         for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} {
25                                 for _, n := range []int{0, 1, 2, 3, 4, 10} {
26                                         for _, lda := range []int{max(1, n), n + 11} {
27                                                 for _, incX := range []int{-11, -3, -2, -1, 1, 2, 3, 7} {
28                                                         ztrsvTest(t, impl, uplo, trans, diag, n, lda, incX, rnd)
29                                                 }
30                                         }
31                                 }
32                         }
33                 }
34         }
35 }
36
37 func ztrsvTest(t *testing.T, impl Ztrsver, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, lda, incX int, rnd *rand.Rand) {
38         const tol = 1e-10
39
40         a := makeZGeneral(nil, n, n, lda)
41         if uplo == blas.Upper {
42                 for i := 0; i < n; i++ {
43                         for j := i; j < n; j++ {
44                                 re := rnd.NormFloat64()
45                                 im := rnd.NormFloat64()
46                                 a[i*lda+j] = complex(re, im)
47                         }
48                 }
49         } else {
50                 for i := 0; i < n; i++ {
51                         for j := 0; j <= i; j++ {
52                                 re := rnd.NormFloat64()
53                                 im := rnd.NormFloat64()
54                                 a[i*lda+j] = complex(re, im)
55                         }
56                 }
57         }
58         if diag == blas.Unit {
59                 for i := 0; i < n; i++ {
60                         a[i*lda+i] = znan
61                 }
62         }
63         aCopy := make([]complex128, len(a))
64         copy(aCopy, a)
65
66         xtest := make([]complex128, n)
67         for i := range xtest {
68                 re := rnd.NormFloat64()
69                 im := rnd.NormFloat64()
70                 xtest[i] = complex(re, im)
71         }
72         x := makeZVector(xtest, incX)
73         want := make([]complex128, len(x))
74         copy(want, x)
75
76         impl.Ztrmv(uplo, trans, diag, n, a, lda, x, incX)
77         impl.Ztrsv(uplo, trans, diag, n, a, lda, x, incX)
78
79         if !zsame(a, aCopy) {
80                 t.Errorf("Case uplo=%v,trans=%v,diag=%v,n=%v,lda=%v,incX=%v: unexpected modification of A", uplo, trans, diag, n, lda, incX)
81         }
82         if !zSameAtNonstrided(x, want, incX) {
83                 t.Errorf("Case uplo=%v,trans=%v,diag=%v,n=%v,lda=%v,incX=%v: unexpected modification of x\nwant %v\ngot  %v", uplo, trans, diag, n, lda, incX, want, x)
84         }
85         if !zEqualApproxAtStrided(x, want, incX, tol) {
86                 t.Errorf("Case uplo=%v,trans=%v,diag=%v,n=%v,lda=%v,incX=%v: unexpected result\nwant %v\ngot  %v", uplo, trans, diag, n, lda, incX, want, x)
87         }
88 }