OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / mat / symband_test.go
1 // Copyright ©2017 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package mat
6
7 import (
8         "reflect"
9         "testing"
10
11         "gonum.org/v1/gonum/blas"
12         "gonum.org/v1/gonum/blas/blas64"
13 )
14
15 func TestNewSymBand(t *testing.T) {
16         for i, test := range []struct {
17                 data  []float64
18                 n     int
19                 k     int
20                 mat   *SymBandDense
21                 dense *Dense
22         }{
23                 {
24                         data: []float64{
25                                 1, 2, 3,
26                                 4, 5, 6,
27                                 7, 8, 9,
28                                 10, 11, 12,
29                                 13, 14, -1,
30                                 15, -1, -1,
31                         },
32                         n: 6,
33                         k: 2,
34                         mat: &SymBandDense{
35                                 mat: blas64.SymmetricBand{
36                                         N:      6,
37                                         K:      2,
38                                         Stride: 3,
39                                         Uplo:   blas.Upper,
40                                         Data: []float64{
41                                                 1, 2, 3,
42                                                 4, 5, 6,
43                                                 7, 8, 9,
44                                                 10, 11, 12,
45                                                 13, 14, -1,
46                                                 15, -1, -1,
47                                         },
48                                 },
49                         },
50                         dense: NewDense(6, 6, []float64{
51                                 1, 2, 3, 0, 0, 0,
52                                 2, 4, 5, 6, 0, 0,
53                                 3, 5, 7, 8, 9, 0,
54                                 0, 6, 8, 10, 11, 12,
55                                 0, 0, 9, 11, 13, 14,
56                                 0, 0, 0, 12, 14, 15,
57                         }),
58                 },
59         } {
60                 band := NewSymBandDense(test.n, test.k, test.data)
61                 rows, cols := band.Dims()
62
63                 if rows != test.n {
64                         t.Errorf("unexpected number of rows for test %d: got: %d want: %d", i, rows, test.n)
65                 }
66                 if cols != test.n {
67                         t.Errorf("unexpected number of cols for test %d: got: %d want: %d", i, cols, test.n)
68                 }
69                 if !reflect.DeepEqual(band, test.mat) {
70                         t.Errorf("unexpected value via reflect for test %d: got: %v want: %v", i, band, test.mat)
71                 }
72                 if !Equal(band, test.mat) {
73                         t.Errorf("unexpected value via mat.Equal for test %d: got: %v want: %v", i, band, test.mat)
74                 }
75                 if !Equal(band, test.dense) {
76                         t.Errorf("unexpected value via mat.Equal(band, dense) for test %d:\ngot:\n% v\nwant:\n% v", i, Formatted(band), Formatted(test.dense))
77                 }
78         }
79 }
80
81 func TestNewDiagonal(t *testing.T) {
82         for i, test := range []struct {
83                 data  []float64
84                 n     int
85                 mat   *SymBandDense
86                 dense *Dense
87         }{
88                 {
89                         data: []float64{1, 2, 3, 4, 5, 6},
90                         n:    6,
91                         mat: &SymBandDense{
92                                 mat: blas64.SymmetricBand{
93                                         N:      6,
94                                         Stride: 1,
95                                         Uplo:   blas.Upper,
96                                         Data:   []float64{1, 2, 3, 4, 5, 6},
97                                 },
98                         },
99                         dense: NewDense(6, 6, []float64{
100                                 1, 0, 0, 0, 0, 0,
101                                 0, 2, 0, 0, 0, 0,
102                                 0, 0, 3, 0, 0, 0,
103                                 0, 0, 0, 4, 0, 0,
104                                 0, 0, 0, 0, 5, 0,
105                                 0, 0, 0, 0, 0, 6,
106                         }),
107                 },
108         } {
109                 band := NewDiagonal(test.n, test.data)
110                 rows, cols := band.Dims()
111
112                 if rows != test.n {
113                         t.Errorf("unexpected number of rows for test %d: got: %d want: %d", i, rows, test.n)
114                 }
115                 if cols != test.n {
116                         t.Errorf("unexpected number of cols for test %d: got: %d want: %d", i, cols, test.n)
117                 }
118                 if !reflect.DeepEqual(band, test.mat) {
119                         t.Errorf("unexpected value via reflect for test %d: got: %v want: %v", i, band, test.mat)
120                 }
121                 if !Equal(band, test.mat) {
122                         t.Errorf("unexpected value via mat.Equal for test %d: got: %v want: %v", i, band, test.mat)
123                 }
124                 if !Equal(band, test.dense) {
125                         t.Errorf("unexpected value via mat.Equal(band, dense) for test %d:\ngot:\n% v\nwant:\n% v", i, Formatted(band), Formatted(test.dense))
126                 }
127         }
128 }
129
130 func TestSymBandAtSet(t *testing.T) {
131         // 1  2  3  0  0  0
132         // 2  4  5  6  0  0
133         // 3  5  7  8  9  0
134         // 0  6  8 10 11 12
135         // 0  0  9 11 13 14
136         // 0  0  0 12 14 16
137         band := NewSymBandDense(6, 2, []float64{
138                 1, 2, 3,
139                 4, 5, 6,
140                 7, 8, 9,
141                 10, 11, 12,
142                 13, 14, -1,
143                 16, -1, -1,
144         })
145
146         rows, cols := band.Dims()
147         kl, ku := band.Bandwidth()
148
149         // Explicitly test all indexes.
150         want := bandImplicit{rows, cols, kl, ku, func(i, j int) float64 {
151                 if i > j {
152                         i, j = j, i
153                 }
154                 return float64(i*ku + j + 1)
155         }}
156         for i := 0; i < 6; i++ {
157                 for j := 0; j < 6; j++ {
158                         if band.At(i, j) != want.At(i, j) {
159                                 t.Errorf("unexpected value for band.At(%d, %d): got:%v want:%v", i, j, band.At(i, j), want.At(i, j))
160                         }
161                 }
162         }
163         // Do that same thing via a call to Equal.
164         if !Equal(band, want) {
165                 t.Errorf("unexpected value via mat.Equal:\ngot:\n% v\nwant:\n% v", Formatted(band), Formatted(want))
166         }
167
168         // Check At out of bounds
169         for _, row := range []int{-1, rows, rows + 1} {
170                 panicked, message := panics(func() { band.At(row, 0) })
171                 if !panicked || message != ErrRowAccess.Error() {
172                         t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row)
173                 }
174         }
175         for _, col := range []int{-1, cols, cols + 1} {
176                 panicked, message := panics(func() { band.At(0, col) })
177                 if !panicked || message != ErrColAccess.Error() {
178                         t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col)
179                 }
180         }
181
182         // Check Set out of bounds
183         for _, row := range []int{-1, rows, rows + 1} {
184                 panicked, message := panics(func() { band.SetSymBand(row, 0, 1.2) })
185                 if !panicked || message != ErrRowAccess.Error() {
186                         t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row)
187                 }
188         }
189         for _, col := range []int{-1, cols, cols + 1} {
190                 panicked, message := panics(func() { band.SetSymBand(0, col, 1.2) })
191                 if !panicked || message != ErrColAccess.Error() {
192                         t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col)
193                 }
194         }
195
196         for _, st := range []struct {
197                 row, col int
198         }{
199                 {row: 0, col: 3},
200                 {row: 0, col: 4},
201                 {row: 0, col: 5},
202                 {row: 1, col: 4},
203                 {row: 1, col: 5},
204                 {row: 2, col: 5},
205                 {row: 3, col: 0},
206                 {row: 4, col: 1},
207                 {row: 5, col: 2},
208         } {
209                 panicked, message := panics(func() { band.SetSymBand(st.row, st.col, 1.2) })
210                 if !panicked || message != ErrBandSet.Error() {
211                         t.Errorf("expected panic for %+v %s", st, message)
212                 }
213         }
214
215         for _, st := range []struct {
216                 row, col  int
217                 orig, new float64
218         }{
219                 {row: 1, col: 2, orig: 5, new: 15},
220                 {row: 2, col: 3, orig: 8, new: 15},
221         } {
222                 if e := band.At(st.row, st.col); e != st.orig {
223                         t.Errorf("unexpected value for At(%d, %d): got: %v want: %v", st.row, st.col, e, st.orig)
224                 }
225                 band.SetSymBand(st.row, st.col, st.new)
226                 if e := band.At(st.row, st.col); e != st.new {
227                         t.Errorf("unexpected value for At(%d, %d) after SetSymBand(%[1]d, %d, %v): got: %v want: %[3]v", st.row, st.col, st.new, e)
228                 }
229         }
230 }