OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / blas / testblas / dger.go
1 package testblas
2
3 import (
4         "math"
5         "testing"
6 )
7
// Dgerer is implemented by types that provide the BLAS Dger routine: the
// rank-one update A += alpha * x * yᵀ of an m×n matrix A stored in a with
// leading dimension lda, where x and y are read with strides incX and incY.
type Dgerer interface {
	Dger(m, n int,
		alpha float64,
		x []float64, incX int,
		y []float64, incY int,
		a []float64, lda int)
}
11
12 func DgerTest(t *testing.T, blasser Dgerer) {
13         for _, test := range []struct {
14                 name string
15                 a    [][]float64
16                 m    int
17                 n    int
18                 x    []float64
19                 y    []float64
20                 incX int
21                 incY int
22
23                 trueAns [][]float64
24         }{
25                 {
26                         name: "M gt N inc 1",
27                         m:    5,
28                         n:    3,
29                         a: [][]float64{
30                                 {1.3, 2.4, 3.5},
31                                 {2.6, 2.8, 3.3},
32                                 {-1.3, -4.3, -9.7},
33                                 {8, 9, -10},
34                                 {-12, -14, -6},
35                         },
36                         x:       []float64{-2, -3, 0, 1, 2},
37                         y:       []float64{-1.1, 5, 0},
38                         incX:    1,
39                         incY:    1,
40                         trueAns: [][]float64{{3.5, -7.6, 3.5}, {5.9, -12.2, 3.3}, {-1.3, -4.3, -9.7}, {6.9, 14, -10}, {-14.2, -4, -6}},
41                 },
42                 {
43                         name: "M eq N inc 1",
44                         m:    3,
45                         n:    3,
46                         a: [][]float64{
47                                 {1.3, 2.4, 3.5},
48                                 {2.6, 2.8, 3.3},
49                                 {-1.3, -4.3, -9.7},
50                         },
51                         x:       []float64{-2, -3, 0},
52                         y:       []float64{-1.1, 5, 0},
53                         incX:    1,
54                         incY:    1,
55                         trueAns: [][]float64{{3.5, -7.6, 3.5}, {5.9, -12.2, 3.3}, {-1.3, -4.3, -9.7}},
56                 },
57
58                 {
59                         name: "M lt N inc 1",
60                         m:    3,
61                         n:    6,
62                         a: [][]float64{
63                                 {1.3, 2.4, 3.5, 4.8, 1.11, -9},
64                                 {2.6, 2.8, 3.3, -3.4, 6.2, -8.7},
65                                 {-1.3, -4.3, -9.7, -3.1, 8.9, 8.9},
66                         },
67                         x:       []float64{-2, -3, 0},
68                         y:       []float64{-1.1, 5, 0, 9, 19, 22},
69                         incX:    1,
70                         incY:    1,
71                         trueAns: [][]float64{{3.5, -7.6, 3.5, -13.2, -36.89, -53}, {5.9, -12.2, 3.3, -30.4, -50.8, -74.7}, {-1.3, -4.3, -9.7, -3.1, 8.9, 8.9}},
72                 },
73                 {
74                         name: "M gt N inc not 1",
75                         m:    5,
76                         n:    3,
77                         a: [][]float64{
78                                 {1.3, 2.4, 3.5},
79                                 {2.6, 2.8, 3.3},
80                                 {-1.3, -4.3, -9.7},
81                                 {8, 9, -10},
82                                 {-12, -14, -6},
83                         },
84                         x:       []float64{-2, -3, 0, 1, 2, 6, 0, 9, 7},
85                         y:       []float64{-1.1, 5, 0, 8, 7, -5, 7},
86                         incX:    2,
87                         incY:    3,
88                         trueAns: [][]float64{{3.5, -13.6, -10.5}, {2.6, 2.8, 3.3}, {-3.5, 11.7, 4.3}, {8, 9, -10}, {-19.700000000000003, 42, 43}},
89                 },
90                 {
91                         name: "M eq N inc not 1",
92                         m:    3,
93                         n:    3,
94                         a: [][]float64{
95                                 {1.3, 2.4, 3.5},
96                                 {2.6, 2.8, 3.3},
97                                 {-1.3, -4.3, -9.7},
98                         },
99                         x:       []float64{-2, -3, 0, 8, 7, -9, 7, -6, 12, 6, 6, 6, -11},
100                         y:       []float64{-1.1, 5, 0, 0, 9, 8, 6},
101                         incX:    4,
102                         incY:    3,
103                         trueAns: [][]float64{{3.5, 2.4, -8.5}, {-5.1, 2.8, 45.3}, {-14.5, -4.3, 62.3}},
104                 },
105                 {
106                         name: "M lt N inc not 1",
107                         m:    3,
108                         n:    6,
109                         a: [][]float64{
110                                 {1.3, 2.4, 3.5, 4.8, 1.11, -9},
111                                 {2.6, 2.8, 3.3, -3.4, 6.2, -8.7},
112                                 {-1.3, -4.3, -9.7, -3.1, 8.9, 8.9},
113                         },
114                         x:       []float64{-2, -3, 0, 0, 8, 0, 9, -3},
115                         y:       []float64{-1.1, 5, 0, 9, 19, 22, 11, -8.11, -9.22, 9.87, 7},
116                         incX:    3,
117                         incY:    2,
118                         trueAns: [][]float64{{3.5, 2.4, -34.5, -17.2, 19.55, -23}, {2.6, 2.8, 3.3, -3.4, 6.2, -8.7}, {-11.2, -4.3, 161.3, 95.9, -74.08, 71.9}},
119                 },
120                 {
121                         name:    "Y NaN element",
122                         m:       1,
123                         n:       1,
124                         a:       [][]float64{{1.3}},
125                         x:       []float64{1.3},
126                         y:       []float64{math.NaN()},
127                         incX:    1,
128                         incY:    1,
129                         trueAns: [][]float64{{math.NaN()}},
130                 },
131         } {
132                 // TODO: Add tests where a is longer
133                 // TODO: Add panic tests
134                 // TODO: Add negative increment tests
135
136                 x := sliceCopy(test.x)
137                 y := sliceCopy(test.y)
138
139                 a := sliceOfSliceCopy(test.a)
140
141                 // Test with row major
142                 alpha := 1.0
143                 aFlat := flatten(a)
144                 blasser.Dger(test.m, test.n, alpha, x, test.incX, y, test.incY, aFlat, test.n)
145                 ans := unflatten(aFlat, test.m, test.n)
146                 dgercomp(t, x, test.x, y, test.y, ans, test.trueAns, test.name+" row maj")
147
148                 // Test with different alpha
149                 alpha = 4.0
150                 aFlat = flatten(a)
151                 blasser.Dger(test.m, test.n, alpha, x, test.incX, y, test.incY, aFlat, test.n)
152                 ans = unflatten(aFlat, test.m, test.n)
153                 trueCopy := sliceOfSliceCopy(test.trueAns)
154                 for i := range trueCopy {
155                         for j := range trueCopy[i] {
156                                 trueCopy[i][j] = alpha*(trueCopy[i][j]-a[i][j]) + a[i][j]
157                         }
158                 }
159                 dgercomp(t, x, test.x, y, test.y, ans, trueCopy, test.name+" row maj alpha")
160         }
161 }
162
163 func dgercomp(t *testing.T, x, xCopy, y, yCopy []float64, ans [][]float64, trueAns [][]float64, name string) {
164         if !dSliceEqual(x, xCopy) {
165                 t.Errorf("case %v: x modified during call to dger\n%v\n%v", name, x, xCopy)
166         }
167         if !dSliceEqual(y, yCopy) {
168                 t.Errorf("case %v: y modified during call to dger\n%v\n%v", name, y, yCopy)
169         }
170
171         for i := range ans {
172                 if !dSliceTolEqual(ans[i], trueAns[i]) {
173                         t.Errorf("case %v: answer mismatch at %v. Expected %v, Found %v", name, i, trueAns, ans)
174                         break
175                 }
176         }
177 }