OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / blas / blas64 / conv_symmetric_test.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package blas64
6
7 import (
8         "math"
9         "testing"
10
11         "gonum.org/v1/gonum/blas"
12 )
13
14 func newSymmetricFrom(a SymmetricCols) Symmetric {
15         t := Symmetric{
16                 N:      a.N,
17                 Stride: a.N,
18                 Data:   make([]float64, a.N*a.N),
19                 Uplo:   a.Uplo,
20         }
21         t.From(a)
22         return t
23 }
24
25 func (m Symmetric) n() int { return m.N }
26 func (m Symmetric) at(i, j int) float64 {
27         if m.Uplo == blas.Lower && i < j && j < m.N {
28                 i, j = j, i
29         }
30         if m.Uplo == blas.Upper && i > j {
31                 i, j = j, i
32         }
33         return m.Data[i*m.Stride+j]
34 }
35 func (m Symmetric) uplo() blas.Uplo { return m.Uplo }
36
37 func newSymmetricColsFrom(a Symmetric) SymmetricCols {
38         t := SymmetricCols{
39                 N:      a.N,
40                 Stride: a.N,
41                 Data:   make([]float64, a.N*a.N),
42                 Uplo:   a.Uplo,
43         }
44         t.From(a)
45         return t
46 }
47
48 func (m SymmetricCols) n() int { return m.N }
49 func (m SymmetricCols) at(i, j int) float64 {
50         if m.Uplo == blas.Lower && i < j {
51                 i, j = j, i
52         }
53         if m.Uplo == blas.Upper && i > j && i < m.N {
54                 i, j = j, i
55         }
56         return m.Data[i+j*m.Stride]
57 }
58 func (m SymmetricCols) uplo() blas.Uplo { return m.Uplo }
59
60 type symmetric interface {
61         n() int
62         at(i, j int) float64
63         uplo() blas.Uplo
64 }
65
66 func sameSymmetric(a, b symmetric) bool {
67         an := a.n()
68         bn := b.n()
69         if an != bn {
70                 return false
71         }
72         if a.uplo() != b.uplo() {
73                 return false
74         }
75         for i := 0; i < an; i++ {
76                 for j := 0; j < an; j++ {
77                         if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
78                                 return false
79                         }
80                 }
81         }
82         return true
83 }
84
85 var symmetricTests = []Symmetric{
86         {N: 3, Stride: 3, Data: []float64{
87                 1, 2, 3,
88                 4, 5, 6,
89                 7, 8, 9,
90         }},
91         {N: 3, Stride: 5, Data: []float64{
92                 1, 2, 3, 0, 0,
93                 4, 5, 6, 0, 0,
94                 7, 8, 9, 0, 0,
95         }},
96 }
97
98 func TestConvertSymmetric(t *testing.T) {
99         for _, test := range symmetricTests {
100                 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
101                         test.Uplo = uplo
102                         colmajor := newSymmetricColsFrom(test)
103                         if !sameSymmetric(colmajor, test) {
104                                 t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
105                                         colmajor, test)
106                         }
107                         rowmajor := newSymmetricFrom(colmajor)
108                         if !sameSymmetric(rowmajor, test) {
109                                 t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
110                                         rowmajor, test)
111                         }
112                 }
113         }
114 }
115 func newSymmetricBandFrom(a SymmetricBandCols) SymmetricBand {
116         t := SymmetricBand{
117                 N:      a.N,
118                 K:      a.K,
119                 Stride: a.K + 1,
120                 Data:   make([]float64, a.N*(a.K+1)),
121                 Uplo:   a.Uplo,
122         }
123         for i := range t.Data {
124                 t.Data[i] = math.NaN()
125         }
126         t.From(a)
127         return t
128 }
129
130 func (m SymmetricBand) n() (n int) { return m.N }
131 func (m SymmetricBand) at(i, j int) float64 {
132         b := Band{
133                 Rows: m.N, Cols: m.N,
134                 Stride: m.Stride,
135                 Data:   m.Data,
136         }
137         switch m.Uplo {
138         default:
139                 panic("blas64: bad BLAS uplo")
140         case blas.Upper:
141                 b.KU = m.K
142                 if i > j {
143                         i, j = j, i
144                 }
145         case blas.Lower:
146                 b.KL = m.K
147                 if i < j {
148                         i, j = j, i
149                 }
150         }
151         return b.at(i, j)
152 }
153 func (m SymmetricBand) bandwidth() (k int) { return m.K }
154 func (m SymmetricBand) uplo() blas.Uplo    { return m.Uplo }
155
156 func newSymmetricBandColsFrom(a SymmetricBand) SymmetricBandCols {
157         t := SymmetricBandCols{
158                 N:      a.N,
159                 K:      a.K,
160                 Stride: a.K + 1,
161                 Data:   make([]float64, a.N*(a.K+1)),
162                 Uplo:   a.Uplo,
163         }
164         for i := range t.Data {
165                 t.Data[i] = math.NaN()
166         }
167         t.From(a)
168         return t
169 }
170
171 func (m SymmetricBandCols) n() (n int) { return m.N }
172 func (m SymmetricBandCols) at(i, j int) float64 {
173         b := BandCols{
174                 Rows: m.N, Cols: m.N,
175                 Stride: m.Stride,
176                 Data:   m.Data,
177         }
178         switch m.Uplo {
179         default:
180                 panic("blas64: bad BLAS uplo")
181         case blas.Upper:
182                 b.KU = m.K
183                 if i > j {
184                         i, j = j, i
185                 }
186         case blas.Lower:
187                 b.KL = m.K
188                 if i < j {
189                         i, j = j, i
190                 }
191         }
192         return b.at(i, j)
193 }
194 func (m SymmetricBandCols) bandwidth() (k int) { return m.K }
195 func (m SymmetricBandCols) uplo() blas.Uplo    { return m.Uplo }
196
197 type symmetricBand interface {
198         n() (n int)
199         at(i, j int) float64
200         bandwidth() (k int)
201         uplo() blas.Uplo
202 }
203
204 func sameSymmetricBand(a, b symmetricBand) bool {
205         an := a.n()
206         bn := b.n()
207         if an != bn {
208                 return false
209         }
210         if a.uplo() != b.uplo() {
211                 return false
212         }
213         ak := a.bandwidth()
214         bk := b.bandwidth()
215         if ak != bk {
216                 return false
217         }
218         for i := 0; i < an; i++ {
219                 for j := 0; j < an; j++ {
220                         if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
221                                 return false
222                         }
223                 }
224         }
225         return true
226 }
227
228 var symmetricBandTests = []SymmetricBand{
229         {N: 3, K: 0, Stride: 1, Uplo: blas.Upper, Data: []float64{
230                 1,
231                 2,
232                 3,
233         }},
234         {N: 3, K: 0, Stride: 1, Uplo: blas.Lower, Data: []float64{
235                 1,
236                 2,
237                 3,
238         }},
239         {N: 3, K: 1, Stride: 2, Uplo: blas.Upper, Data: []float64{
240                 1, 2,
241                 3, 4,
242                 5, -1,
243         }},
244         {N: 3, K: 1, Stride: 2, Uplo: blas.Lower, Data: []float64{
245                 -1, 1,
246                 2, 3,
247                 4, 5,
248         }},
249         {N: 3, K: 2, Stride: 3, Uplo: blas.Upper, Data: []float64{
250                 1, 2, 3,
251                 4, 5, -1,
252                 6, -2, -3,
253         }},
254         {N: 3, K: 2, Stride: 3, Uplo: blas.Lower, Data: []float64{
255                 -2, -1, 1,
256                 -3, 2, 4,
257                 3, 5, 6,
258         }},
259
260         {N: 3, K: 0, Stride: 5, Uplo: blas.Upper, Data: []float64{
261                 1, 0, 0, 0, 0,
262                 2, 0, 0, 0, 0,
263                 3, 0, 0, 0, 0,
264         }},
265         {N: 3, K: 0, Stride: 5, Uplo: blas.Lower, Data: []float64{
266                 1, 0, 0, 0, 0,
267                 2, 0, 0, 0, 0,
268                 3, 0, 0, 0, 0,
269         }},
270         {N: 3, K: 1, Stride: 5, Uplo: blas.Upper, Data: []float64{
271                 1, 2, 0, 0, 0,
272                 3, 4, 0, 0, 0,
273                 5, -1, 0, 0, 0,
274         }},
275         {N: 3, K: 1, Stride: 5, Uplo: blas.Lower, Data: []float64{
276                 -1, 1, 0, 0, 0,
277                 2, 3, 0, 0, 0,
278                 4, 5, 0, 0, 0,
279         }},
280         {N: 3, K: 2, Stride: 5, Uplo: blas.Upper, Data: []float64{
281                 1, 2, 3, 0, 0,
282                 4, 5, -1, 0, 0,
283                 6, -2, -3, 0, 0,
284         }},
285         {N: 3, K: 2, Stride: 5, Uplo: blas.Lower, Data: []float64{
286                 -2, -1, 1, 0, 0,
287                 -3, 2, 4, 0, 0,
288                 3, 5, 6, 0, 0,
289         }},
290 }
291
292 func TestConvertSymBand(t *testing.T) {
293         for _, test := range symmetricBandTests {
294                 colmajor := newSymmetricBandColsFrom(test)
295                 if !sameSymmetricBand(colmajor, test) {
296                         t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
297                                 colmajor, test)
298                 }
299                 rowmajor := newSymmetricBandFrom(colmajor)
300                 if !sameSymmetricBand(rowmajor, test) {
301                         t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
302                                 rowmajor, test)
303                 }
304         }
305 }