OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / blas / testblas / dtxmv.go
1 package testblas
2
3 import (
4         "testing"
5
6         "gonum.org/v1/gonum/blas"
7 )
8
9 type Dtxmver interface {
10         Dtrmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int)
11         Dtbmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n, k int, a []float64, lda int, x []float64, incX int)
12         Dtpmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, x []float64, incX int)
13 }
14
15 type vec struct {
16         data []float64
17         inc  int
18 }
19
20 var cases = []struct {
21         n, k       int
22         ul         blas.Uplo
23         d          blas.Diag
24         ldab       int
25         tr, tb, tp []float64
26         ins        []vec
27         solNoTrans []float64
28         solTrans   []float64
29 }{
30         {
31                 n:    3,
32                 k:    1,
33                 ul:   blas.Upper,
34                 d:    blas.NonUnit,
35                 tr:   []float64{1, 2, 0, 0, 3, 4, 0, 0, 5},
36                 tb:   []float64{1, 2, 3, 4, 5, 0},
37                 ldab: 2,
38                 tp:   []float64{1, 2, 0, 3, 4, 5},
39                 ins: []vec{
40                         {[]float64{2, 3, 4}, 1},
41                         {[]float64{2, 1, 3, 1, 4}, 2},
42                         {[]float64{4, 1, 3, 1, 2}, -2},
43                 },
44                 solNoTrans: []float64{8, 25, 20},
45                 solTrans:   []float64{2, 13, 32},
46         },
47         {
48                 n:    3,
49                 k:    1,
50                 ul:   blas.Upper,
51                 d:    blas.Unit,
52                 tr:   []float64{1, 2, 0, 0, 3, 4, 0, 0, 5},
53                 tb:   []float64{1, 2, 3, 4, 5, 0},
54                 ldab: 2,
55                 tp:   []float64{1, 2, 0, 3, 4, 5},
56                 ins: []vec{
57                         {[]float64{2, 3, 4}, 1},
58                         {[]float64{2, 1, 3, 1, 4}, 2},
59                         {[]float64{4, 1, 3, 1, 2}, -2},
60                 },
61                 solNoTrans: []float64{8, 19, 4},
62                 solTrans:   []float64{2, 7, 16},
63         },
64         {
65                 n:    3,
66                 k:    1,
67                 ul:   blas.Lower,
68                 d:    blas.NonUnit,
69                 tr:   []float64{1, 0, 0, 2, 3, 0, 0, 4, 5},
70                 tb:   []float64{0, 1, 2, 3, 4, 5},
71                 ldab: 2,
72                 tp:   []float64{1, 2, 3, 0, 4, 5},
73                 ins: []vec{
74                         {[]float64{2, 3, 4}, 1},
75                         {[]float64{2, 1, 3, 1, 4}, 2},
76                         {[]float64{4, 1, 3, 1, 2}, -2},
77                 },
78                 solNoTrans: []float64{2, 13, 32},
79                 solTrans:   []float64{8, 25, 20},
80         },
81         {
82                 n:    3,
83                 k:    1,
84                 ul:   blas.Lower,
85                 d:    blas.Unit,
86                 tr:   []float64{1, 0, 0, 2, 3, 0, 0, 4, 5},
87                 tb:   []float64{0, 1, 2, 3, 4, 5},
88                 ldab: 2,
89                 tp:   []float64{1, 2, 3, 0, 4, 5},
90                 ins: []vec{
91                         {[]float64{2, 3, 4}, 1},
92                         {[]float64{2, 1, 3, 1, 4}, 2},
93                         {[]float64{4, 1, 3, 1, 2}, -2},
94                 },
95                 solNoTrans: []float64{2, 7, 16},
96                 solTrans:   []float64{8, 19, 4},
97         },
98 }
99
100 func DtxmvTest(t *testing.T, blasser Dtxmver) {
101
102         for nc, c := range cases {
103                 for nx, x := range c.ins {
104                         in := make([]float64, len(x.data))
105                         copy(in, x.data)
106                         blasser.Dtrmv(c.ul, blas.NoTrans, c.d, c.n, c.tr, c.n, in, x.inc)
107                         if !dStridedSliceTolEqual(c.n, in, x.inc, c.solNoTrans, 1) {
108                                 t.Error("Wrong Dtrmv result for: NoTrans  in Case:", nc, "input:", nx)
109                         }
110
111                         in = make([]float64, len(x.data))
112                         copy(in, x.data)
113                         blasser.Dtrmv(c.ul, blas.Trans, c.d, c.n, c.tr, c.n, in, x.inc)
114                         if !dStridedSliceTolEqual(c.n, in, x.inc, c.solTrans, 1) {
115                                 t.Error("Wrong Dtrmv result for: Trans in Case:", nc, "input:", nx)
116                         }
117                         in = make([]float64, len(x.data))
118                         copy(in, x.data)
119                         blasser.Dtbmv(c.ul, blas.NoTrans, c.d, c.n, c.k, c.tb, c.ldab, in, x.inc)
120                         if !dStridedSliceTolEqual(c.n, in, x.inc, c.solNoTrans, 1) {
121                                 t.Error("Wrong Dtbmv result for: NoTrans  in Case:", nc, "input:", nx)
122                         }
123
124                         in = make([]float64, len(x.data))
125                         copy(in, x.data)
126                         blasser.Dtbmv(c.ul, blas.Trans, c.d, c.n, c.k, c.tb, c.ldab, in, x.inc)
127                         if !dStridedSliceTolEqual(c.n, in, x.inc, c.solTrans, 1) {
128                                 t.Error("Wrong Dtbmv result for: Trans in Case:", nc, "input:", nx)
129                         }
130                         in = make([]float64, len(x.data))
131                         copy(in, x.data)
132                         blasser.Dtpmv(c.ul, blas.NoTrans, c.d, c.n, c.tp, in, x.inc)
133                         if !dStridedSliceTolEqual(c.n, in, x.inc, c.solNoTrans, 1) {
134                                 t.Error("Wrong Dtpmv result for:  NoTrans  in Case:", nc, "input:", nx)
135                         }
136
137                         in = make([]float64, len(x.data))
138                         copy(in, x.data)
139                         blasser.Dtpmv(c.ul, blas.Trans, c.d, c.n, c.tp, in, x.inc)
140                         if !dStridedSliceTolEqual(c.n, in, x.inc, c.solTrans, 1) {
141                                 t.Error("Wrong Dtpmv result for: Trans in Case:", nc, "input:", nx)
142                         }
143                 }
144         }
145 }