OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / blas / blas64 / conv_test.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package blas64
6
7 import (
8         "math"
9         "testing"
10
11         "gonum.org/v1/gonum/blas"
12 )
13
14 func newGeneralFrom(a GeneralCols) General {
15         t := General{
16                 Rows:   a.Rows,
17                 Cols:   a.Cols,
18                 Stride: a.Cols,
19                 Data:   make([]float64, a.Rows*a.Cols),
20         }
21         t.From(a)
22         return t
23 }
24
25 func (m General) dims() (r, c int)    { return m.Rows, m.Cols }
26 func (m General) at(i, j int) float64 { return m.Data[i*m.Stride+j] }
27
28 func newGeneralColsFrom(a General) GeneralCols {
29         t := GeneralCols{
30                 Rows:   a.Rows,
31                 Cols:   a.Cols,
32                 Stride: a.Rows,
33                 Data:   make([]float64, a.Rows*a.Cols),
34         }
35         t.From(a)
36         return t
37 }
38
39 func (m GeneralCols) dims() (r, c int)    { return m.Rows, m.Cols }
40 func (m GeneralCols) at(i, j int) float64 { return m.Data[i+j*m.Stride] }
41
42 type general interface {
43         dims() (r, c int)
44         at(i, j int) float64
45 }
46
47 func sameGeneral(a, b general) bool {
48         ar, ac := a.dims()
49         br, bc := b.dims()
50         if ar != br || ac != bc {
51                 return false
52         }
53         for i := 0; i < ar; i++ {
54                 for j := 0; j < ac; j++ {
55                         if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
56                                 return false
57                         }
58                 }
59         }
60         return true
61 }
62
63 var generalTests = []General{
64         {Rows: 2, Cols: 3, Stride: 3, Data: []float64{
65                 1, 2, 3,
66                 4, 5, 6,
67         }},
68         {Rows: 3, Cols: 2, Stride: 2, Data: []float64{
69                 1, 2,
70                 3, 4,
71                 5, 6,
72         }},
73         {Rows: 3, Cols: 3, Stride: 3, Data: []float64{
74                 1, 2, 3,
75                 4, 5, 6,
76                 7, 8, 9,
77         }},
78         {Rows: 2, Cols: 3, Stride: 5, Data: []float64{
79                 1, 2, 3, 0, 0,
80                 4, 5, 6, 0, 0,
81         }},
82         {Rows: 3, Cols: 2, Stride: 5, Data: []float64{
83                 1, 2, 0, 0, 0,
84                 3, 4, 0, 0, 0,
85                 5, 6, 0, 0, 0,
86         }},
87         {Rows: 3, Cols: 3, Stride: 5, Data: []float64{
88                 1, 2, 3, 0, 0,
89                 4, 5, 6, 0, 0,
90                 7, 8, 9, 0, 0,
91         }},
92 }
93
94 func TestConvertGeneral(t *testing.T) {
95         for _, test := range generalTests {
96                 colmajor := newGeneralColsFrom(test)
97                 if !sameGeneral(colmajor, test) {
98                         t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
99                                 colmajor, test)
100                 }
101                 rowmajor := newGeneralFrom(colmajor)
102                 if !sameGeneral(rowmajor, test) {
103                         t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
104                                 rowmajor, test)
105                 }
106         }
107 }
108
109 func newTriangularFrom(a TriangularCols) Triangular {
110         t := Triangular{
111                 N:      a.N,
112                 Stride: a.N,
113                 Data:   make([]float64, a.N*a.N),
114                 Diag:   a.Diag,
115                 Uplo:   a.Uplo,
116         }
117         t.From(a)
118         return t
119 }
120
121 func (m Triangular) n() int { return m.N }
122 func (m Triangular) at(i, j int) float64 {
123         if m.Diag == blas.Unit && i == j {
124                 return 1
125         }
126         if m.Uplo == blas.Lower && i < j && j < m.N {
127                 return 0
128         }
129         if m.Uplo == blas.Upper && i > j {
130                 return 0
131         }
132         return m.Data[i*m.Stride+j]
133 }
134 func (m Triangular) uplo() blas.Uplo { return m.Uplo }
135 func (m Triangular) diag() blas.Diag { return m.Diag }
136
137 func newTriangularColsFrom(a Triangular) TriangularCols {
138         t := TriangularCols{
139                 N:      a.N,
140                 Stride: a.N,
141                 Data:   make([]float64, a.N*a.N),
142                 Diag:   a.Diag,
143                 Uplo:   a.Uplo,
144         }
145         t.From(a)
146         return t
147 }
148
149 func (m TriangularCols) n() int { return m.N }
150 func (m TriangularCols) at(i, j int) float64 {
151         if m.Diag == blas.Unit && i == j {
152                 return 1
153         }
154         if m.Uplo == blas.Lower && i < j {
155                 return 0
156         }
157         if m.Uplo == blas.Upper && i > j && i < m.N {
158                 return 0
159         }
160         return m.Data[i+j*m.Stride]
161 }
162 func (m TriangularCols) uplo() blas.Uplo { return m.Uplo }
163 func (m TriangularCols) diag() blas.Diag { return m.Diag }
164
165 type triangular interface {
166         n() int
167         at(i, j int) float64
168         uplo() blas.Uplo
169         diag() blas.Diag
170 }
171
172 func sameTriangular(a, b triangular) bool {
173         an := a.n()
174         bn := b.n()
175         if an != bn {
176                 return false
177         }
178         for i := 0; i < an; i++ {
179                 for j := 0; j < an; j++ {
180                         if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
181                                 return false
182                         }
183                 }
184         }
185         return true
186 }
187
188 var triangularTests = []Triangular{
189         {N: 3, Stride: 3, Data: []float64{
190                 1, 2, 3,
191                 4, 5, 6,
192                 7, 8, 9,
193         }},
194         {N: 3, Stride: 5, Data: []float64{
195                 1, 2, 3, 0, 0,
196                 4, 5, 6, 0, 0,
197                 7, 8, 9, 0, 0,
198         }},
199 }
200
201 func TestConvertTriangular(t *testing.T) {
202         for _, test := range triangularTests {
203                 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower, blas.All} {
204                         for _, diag := range []blas.Diag{blas.Unit, blas.NonUnit} {
205                                 test.Uplo = uplo
206                                 test.Diag = diag
207                                 colmajor := newTriangularColsFrom(test)
208                                 if !sameTriangular(colmajor, test) {
209                                         t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
210                                                 colmajor, test)
211                                 }
212                                 rowmajor := newTriangularFrom(colmajor)
213                                 if !sameTriangular(rowmajor, test) {
214                                         t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
215                                                 rowmajor, test)
216                                 }
217                         }
218                 }
219         }
220 }
221
222 func newBandFrom(a BandCols) Band {
223         t := Band{
224                 Rows:   a.Rows,
225                 Cols:   a.Cols,
226                 KL:     a.KL,
227                 KU:     a.KU,
228                 Stride: a.KL + a.KU + 1,
229                 Data:   make([]float64, a.Rows*(a.KL+a.KU+1)),
230         }
231         for i := range t.Data {
232                 t.Data[i] = math.NaN()
233         }
234         t.From(a)
235         return t
236 }
237
238 func (m Band) dims() (r, c int) { return m.Rows, m.Cols }
239 func (m Band) at(i, j int) float64 {
240         pj := j + m.KL - i
241         if pj < 0 || m.KL+m.KU+1 <= pj {
242                 return 0
243         }
244         return m.Data[i*m.Stride+pj]
245 }
246 func (m Band) bandwidth() (kl, ku int) { return m.KL, m.KU }
247
248 func newBandColsFrom(a Band) BandCols {
249         t := BandCols{
250                 Rows:   a.Rows,
251                 Cols:   a.Cols,
252                 KL:     a.KL,
253                 KU:     a.KU,
254                 Stride: a.KL + a.KU + 1,
255                 Data:   make([]float64, a.Cols*(a.KL+a.KU+1)),
256         }
257         for i := range t.Data {
258                 t.Data[i] = math.NaN()
259         }
260         t.From(a)
261         return t
262 }
263
264 func (m BandCols) dims() (r, c int) { return m.Rows, m.Cols }
265 func (m BandCols) at(i, j int) float64 {
266         pj := i + m.KU - j
267         if pj < 0 || m.KL+m.KU+1 <= pj {
268                 return 0
269         }
270         return m.Data[j*m.Stride+pj]
271 }
272 func (m BandCols) bandwidth() (kl, ku int) { return m.KL, m.KU }
273
274 type band interface {
275         dims() (r, c int)
276         at(i, j int) float64
277         bandwidth() (kl, ku int)
278 }
279
280 func sameBand(a, b band) bool {
281         ar, ac := a.dims()
282         br, bc := b.dims()
283         if ar != br || ac != bc {
284                 return false
285         }
286         akl, aku := a.bandwidth()
287         bkl, bku := b.bandwidth()
288         if akl != bkl || aku != bku {
289                 return false
290         }
291         for i := 0; i < ar; i++ {
292                 for j := 0; j < ac; j++ {
293                         if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
294                                 return false
295                         }
296                 }
297         }
298         return true
299 }
300
301 var bandTests = []Band{
302         {Rows: 3, Cols: 4, KL: 0, KU: 0, Stride: 1, Data: []float64{
303                 1,
304                 2,
305                 3,
306         }},
307         {Rows: 3, Cols: 3, KL: 0, KU: 0, Stride: 1, Data: []float64{
308                 1,
309                 2,
310                 3,
311         }},
312         {Rows: 4, Cols: 3, KL: 0, KU: 0, Stride: 1, Data: []float64{
313                 1,
314                 2,
315                 3,
316         }},
317         {Rows: 4, Cols: 3, KL: 0, KU: 1, Stride: 2, Data: []float64{
318                 1, 2,
319                 3, 4,
320                 5, 6,
321         }},
322         {Rows: 3, Cols: 4, KL: 0, KU: 1, Stride: 2, Data: []float64{
323                 1, 2,
324                 3, 4,
325                 5, 6,
326         }},
327         {Rows: 3, Cols: 4, KL: 1, KU: 1, Stride: 3, Data: []float64{
328                 -1, 2, 3,
329                 4, 5, 6,
330                 7, 8, 9,
331         }},
332         {Rows: 4, Cols: 3, KL: 1, KU: 1, Stride: 3, Data: []float64{
333                 -1, 2, 3,
334                 4, 5, 6,
335                 7, 8, -2,
336                 9, -3, -4,
337         }},
338         {Rows: 3, Cols: 4, KL: 2, KU: 1, Stride: 4, Data: []float64{
339                 -2, -1, 3, 4,
340                 -3, 5, 6, 7,
341                 8, 9, 10, 11,
342         }},
343         {Rows: 4, Cols: 3, KL: 2, KU: 1, Stride: 4, Data: []float64{
344                 -2, -1, 2, 3,
345                 -3, 4, 5, 6,
346                 7, 8, 9, -4,
347                 10, 11, -5, -6,
348         }},
349
350         {Rows: 3, Cols: 4, KL: 0, KU: 0, Stride: 5, Data: []float64{
351                 1, 0, 0, 0, 0,
352                 2, 0, 0, 0, 0,
353                 3, 0, 0, 0, 0,
354         }},
355         {Rows: 3, Cols: 3, KL: 0, KU: 0, Stride: 5, Data: []float64{
356                 1, 0, 0, 0, 0,
357                 2, 0, 0, 0, 0,
358                 3, 0, 0, 0, 0,
359         }},
360         {Rows: 4, Cols: 3, KL: 0, KU: 0, Stride: 5, Data: []float64{
361                 1, 0, 0, 0, 0,
362                 2, 0, 0, 0, 0,
363                 3, 0, 0, 0, 0,
364         }},
365         {Rows: 4, Cols: 3, KL: 0, KU: 1, Stride: 5, Data: []float64{
366                 1, 2, 0, 0, 0,
367                 3, 4, 0, 0, 0,
368                 5, 6, 0, 0, 0,
369         }},
370         {Rows: 3, Cols: 4, KL: 0, KU: 1, Stride: 5, Data: []float64{
371                 1, 2, 0, 0, 0,
372                 3, 4, 0, 0, 0,
373                 5, 6, 0, 0, 0,
374         }},
375         {Rows: 3, Cols: 4, KL: 1, KU: 1, Stride: 5, Data: []float64{
376                 -1, 2, 3, 0, 0,
377                 4, 5, 6, 0, 0,
378                 7, 8, 9, 0, 0,
379         }},
380         {Rows: 4, Cols: 3, KL: 1, KU: 1, Stride: 5, Data: []float64{
381                 -1, 2, 3, 0, 0,
382                 4, 5, 6, 0, 0,
383                 7, 8, -2, 0, 0,
384                 9, -3, -4, 0, 0,
385         }},
386         {Rows: 3, Cols: 4, KL: 2, KU: 1, Stride: 5, Data: []float64{
387                 -2, -1, 3, 4, 0,
388                 -3, 5, 6, 7, 0,
389                 8, 9, 10, 11, 0,
390         }},
391         {Rows: 4, Cols: 3, KL: 2, KU: 1, Stride: 5, Data: []float64{
392                 -2, -1, 2, 3, 0,
393                 -3, 4, 5, 6, 0,
394                 7, 8, 9, -4, 0,
395                 10, 11, -5, -6, 0,
396         }},
397 }
398
399 func TestConvertBand(t *testing.T) {
400         for _, test := range bandTests {
401                 colmajor := newBandColsFrom(test)
402                 if !sameBand(colmajor, test) {
403                         t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
404                                 colmajor, test)
405                 }
406                 rowmajor := newBandFrom(colmajor)
407                 if !sameBand(rowmajor, test) {
408                         t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
409                                 rowmajor, test)
410                 }
411         }
412 }
413
414 func newTriangularBandFrom(a TriangularBandCols) TriangularBand {
415         t := TriangularBand{
416                 N:      a.N,
417                 K:      a.K,
418                 Stride: a.K + 1,
419                 Data:   make([]float64, a.N*(a.K+1)),
420                 Uplo:   a.Uplo,
421                 Diag:   a.Diag,
422         }
423         for i := range t.Data {
424                 t.Data[i] = math.NaN()
425         }
426         t.From(a)
427         return t
428 }
429
430 func (m TriangularBand) n() (n int) { return m.N }
431 func (m TriangularBand) at(i, j int) float64 {
432         if m.Diag == blas.Unit && i == j {
433                 return 1
434         }
435         b := Band{
436                 Rows: m.N, Cols: m.N,
437                 Stride: m.Stride,
438                 Data:   m.Data,
439         }
440         switch m.Uplo {
441         default:
442                 panic("blas64: bad BLAS uplo")
443         case blas.Upper:
444                 if i > j {
445                         return 0
446                 }
447                 b.KU = m.K
448         case blas.Lower:
449                 if i < j {
450                         return 0
451                 }
452                 b.KL = m.K
453         }
454         return b.at(i, j)
455 }
456 func (m TriangularBand) bandwidth() (k int) { return m.K }
457 func (m TriangularBand) uplo() blas.Uplo    { return m.Uplo }
458 func (m TriangularBand) diag() blas.Diag    { return m.Diag }
459
460 func newTriangularBandColsFrom(a TriangularBand) TriangularBandCols {
461         t := TriangularBandCols{
462                 N:      a.N,
463                 K:      a.K,
464                 Stride: a.K + 1,
465                 Data:   make([]float64, a.N*(a.K+1)),
466                 Uplo:   a.Uplo,
467                 Diag:   a.Diag,
468         }
469         for i := range t.Data {
470                 t.Data[i] = math.NaN()
471         }
472         t.From(a)
473         return t
474 }
475
476 func (m TriangularBandCols) n() (n int) { return m.N }
477 func (m TriangularBandCols) at(i, j int) float64 {
478         if m.Diag == blas.Unit && i == j {
479                 return 1
480         }
481         b := BandCols{
482                 Rows: m.N, Cols: m.N,
483                 Stride: m.Stride,
484                 Data:   m.Data,
485         }
486         switch m.Uplo {
487         default:
488                 panic("blas64: bad BLAS uplo")
489         case blas.Upper:
490                 if i > j {
491                         return 0
492                 }
493                 b.KU = m.K
494         case blas.Lower:
495                 if i < j {
496                         return 0
497                 }
498                 b.KL = m.K
499         }
500         return b.at(i, j)
501 }
502 func (m TriangularBandCols) bandwidth() (k int) { return m.K }
503 func (m TriangularBandCols) uplo() blas.Uplo    { return m.Uplo }
504 func (m TriangularBandCols) diag() blas.Diag    { return m.Diag }
505
506 type triangularBand interface {
507         n() (n int)
508         at(i, j int) float64
509         bandwidth() (k int)
510         uplo() blas.Uplo
511         diag() blas.Diag
512 }
513
514 func sameTriangularBand(a, b triangularBand) bool {
515         an := a.n()
516         bn := b.n()
517         if an != bn {
518                 return false
519         }
520         if a.uplo() != b.uplo() {
521                 return false
522         }
523         if a.diag() != b.diag() {
524                 return false
525         }
526         ak := a.bandwidth()
527         bk := b.bandwidth()
528         if ak != bk {
529                 return false
530         }
531         for i := 0; i < an; i++ {
532                 for j := 0; j < an; j++ {
533                         if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
534                                 return false
535                         }
536                 }
537         }
538         return true
539 }
540
541 var triangularBandTests = []TriangularBand{
542         {N: 3, K: 0, Stride: 1, Uplo: blas.Upper, Data: []float64{
543                 1,
544                 2,
545                 3,
546         }},
547         {N: 3, K: 0, Stride: 1, Uplo: blas.Lower, Data: []float64{
548                 1,
549                 2,
550                 3,
551         }},
552         {N: 3, K: 1, Stride: 2, Uplo: blas.Upper, Data: []float64{
553                 1, 2,
554                 3, 4,
555                 5, -1,
556         }},
557         {N: 3, K: 1, Stride: 2, Uplo: blas.Lower, Data: []float64{
558                 -1, 1,
559                 2, 3,
560                 4, 5,
561         }},
562         {N: 3, K: 2, Stride: 3, Uplo: blas.Upper, Data: []float64{
563                 1, 2, 3,
564                 4, 5, -1,
565                 6, -2, -3,
566         }},
567         {N: 3, K: 2, Stride: 3, Uplo: blas.Lower, Data: []float64{
568                 -2, -1, 1,
569                 -3, 2, 4,
570                 3, 5, 6,
571         }},
572
573         {N: 3, K: 0, Stride: 5, Uplo: blas.Upper, Data: []float64{
574                 1, 0, 0, 0, 0,
575                 2, 0, 0, 0, 0,
576                 3, 0, 0, 0, 0,
577         }},
578         {N: 3, K: 0, Stride: 5, Uplo: blas.Lower, Data: []float64{
579                 1, 0, 0, 0, 0,
580                 2, 0, 0, 0, 0,
581                 3, 0, 0, 0, 0,
582         }},
583         {N: 3, K: 1, Stride: 5, Uplo: blas.Upper, Data: []float64{
584                 1, 2, 0, 0, 0,
585                 3, 4, 0, 0, 0,
586                 5, -1, 0, 0, 0,
587         }},
588         {N: 3, K: 1, Stride: 5, Uplo: blas.Lower, Data: []float64{
589                 -1, 1, 0, 0, 0,
590                 2, 3, 0, 0, 0,
591                 4, 5, 0, 0, 0,
592         }},
593         {N: 3, K: 2, Stride: 5, Uplo: blas.Upper, Data: []float64{
594                 1, 2, 3, 0, 0,
595                 4, 5, -1, 0, 0,
596                 6, -2, -3, 0, 0,
597         }},
598         {N: 3, K: 2, Stride: 5, Uplo: blas.Lower, Data: []float64{
599                 -2, -1, 1, 0, 0,
600                 -3, 2, 4, 0, 0,
601                 3, 5, 6, 0, 0,
602         }},
603 }
604
605 func TestConvertTriBand(t *testing.T) {
606         for _, test := range triangularBandTests {
607                 colmajor := newTriangularBandColsFrom(test)
608                 if !sameTriangularBand(colmajor, test) {
609                         t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
610                                 colmajor, test)
611                 }
612                 rowmajor := newTriangularBandFrom(colmajor)
613                 if !sameTriangularBand(rowmajor, test) {
614                         t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
615                                 rowmajor, test)
616                 }
617         }
618 }