OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / blas / cblas128 / conv_symmetric_test.go
1 // Code generated by "go generate gonum.org/v1/gonum/blas”; DO NOT EDIT.
2
3 // Copyright ©2015 The Gonum Authors. All rights reserved.
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file.
6
7 package cblas128
8
9 import (
10         math "math/cmplx"
11         "testing"
12
13         "gonum.org/v1/gonum/blas"
14 )
15
16 func newSymmetricFrom(a SymmetricCols) Symmetric {
17         t := Symmetric{
18                 N:      a.N,
19                 Stride: a.N,
20                 Data:   make([]complex128, a.N*a.N),
21                 Uplo:   a.Uplo,
22         }
23         t.From(a)
24         return t
25 }
26
27 func (m Symmetric) n() int { return m.N }
28 func (m Symmetric) at(i, j int) complex128 {
29         if m.Uplo == blas.Lower && i < j && j < m.N {
30                 i, j = j, i
31         }
32         if m.Uplo == blas.Upper && i > j {
33                 i, j = j, i
34         }
35         return m.Data[i*m.Stride+j]
36 }
37 func (m Symmetric) uplo() blas.Uplo { return m.Uplo }
38
39 func newSymmetricColsFrom(a Symmetric) SymmetricCols {
40         t := SymmetricCols{
41                 N:      a.N,
42                 Stride: a.N,
43                 Data:   make([]complex128, a.N*a.N),
44                 Uplo:   a.Uplo,
45         }
46         t.From(a)
47         return t
48 }
49
50 func (m SymmetricCols) n() int { return m.N }
51 func (m SymmetricCols) at(i, j int) complex128 {
52         if m.Uplo == blas.Lower && i < j {
53                 i, j = j, i
54         }
55         if m.Uplo == blas.Upper && i > j && i < m.N {
56                 i, j = j, i
57         }
58         return m.Data[i+j*m.Stride]
59 }
60 func (m SymmetricCols) uplo() blas.Uplo { return m.Uplo }
61
62 type symmetric interface {
63         n() int
64         at(i, j int) complex128
65         uplo() blas.Uplo
66 }
67
68 func sameSymmetric(a, b symmetric) bool {
69         an := a.n()
70         bn := b.n()
71         if an != bn {
72                 return false
73         }
74         if a.uplo() != b.uplo() {
75                 return false
76         }
77         for i := 0; i < an; i++ {
78                 for j := 0; j < an; j++ {
79                         if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
80                                 return false
81                         }
82                 }
83         }
84         return true
85 }
86
87 var symmetricTests = []Symmetric{
88         {N: 3, Stride: 3, Data: []complex128{
89                 1, 2, 3,
90                 4, 5, 6,
91                 7, 8, 9,
92         }},
93         {N: 3, Stride: 5, Data: []complex128{
94                 1, 2, 3, 0, 0,
95                 4, 5, 6, 0, 0,
96                 7, 8, 9, 0, 0,
97         }},
98 }
99
100 func TestConvertSymmetric(t *testing.T) {
101         for _, test := range symmetricTests {
102                 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
103                         test.Uplo = uplo
104                         colmajor := newSymmetricColsFrom(test)
105                         if !sameSymmetric(colmajor, test) {
106                                 t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
107                                         colmajor, test)
108                         }
109                         rowmajor := newSymmetricFrom(colmajor)
110                         if !sameSymmetric(rowmajor, test) {
111                                 t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
112                                         rowmajor, test)
113                         }
114                 }
115         }
116 }
117 func newSymmetricBandFrom(a SymmetricBandCols) SymmetricBand {
118         t := SymmetricBand{
119                 N:      a.N,
120                 K:      a.K,
121                 Stride: a.K + 1,
122                 Data:   make([]complex128, a.N*(a.K+1)),
123                 Uplo:   a.Uplo,
124         }
125         for i := range t.Data {
126                 t.Data[i] = math.NaN()
127         }
128         t.From(a)
129         return t
130 }
131
132 func (m SymmetricBand) n() (n int) { return m.N }
133 func (m SymmetricBand) at(i, j int) complex128 {
134         b := Band{
135                 Rows: m.N, Cols: m.N,
136                 Stride: m.Stride,
137                 Data:   m.Data,
138         }
139         switch m.Uplo {
140         default:
141                 panic("cblas128: bad BLAS uplo")
142         case blas.Upper:
143                 b.KU = m.K
144                 if i > j {
145                         i, j = j, i
146                 }
147         case blas.Lower:
148                 b.KL = m.K
149                 if i < j {
150                         i, j = j, i
151                 }
152         }
153         return b.at(i, j)
154 }
155 func (m SymmetricBand) bandwidth() (k int) { return m.K }
156 func (m SymmetricBand) uplo() blas.Uplo    { return m.Uplo }
157
158 func newSymmetricBandColsFrom(a SymmetricBand) SymmetricBandCols {
159         t := SymmetricBandCols{
160                 N:      a.N,
161                 K:      a.K,
162                 Stride: a.K + 1,
163                 Data:   make([]complex128, a.N*(a.K+1)),
164                 Uplo:   a.Uplo,
165         }
166         for i := range t.Data {
167                 t.Data[i] = math.NaN()
168         }
169         t.From(a)
170         return t
171 }
172
173 func (m SymmetricBandCols) n() (n int) { return m.N }
174 func (m SymmetricBandCols) at(i, j int) complex128 {
175         b := BandCols{
176                 Rows: m.N, Cols: m.N,
177                 Stride: m.Stride,
178                 Data:   m.Data,
179         }
180         switch m.Uplo {
181         default:
182                 panic("cblas128: bad BLAS uplo")
183         case blas.Upper:
184                 b.KU = m.K
185                 if i > j {
186                         i, j = j, i
187                 }
188         case blas.Lower:
189                 b.KL = m.K
190                 if i < j {
191                         i, j = j, i
192                 }
193         }
194         return b.at(i, j)
195 }
196 func (m SymmetricBandCols) bandwidth() (k int) { return m.K }
197 func (m SymmetricBandCols) uplo() blas.Uplo    { return m.Uplo }
198
199 type symmetricBand interface {
200         n() (n int)
201         at(i, j int) complex128
202         bandwidth() (k int)
203         uplo() blas.Uplo
204 }
205
206 func sameSymmetricBand(a, b symmetricBand) bool {
207         an := a.n()
208         bn := b.n()
209         if an != bn {
210                 return false
211         }
212         if a.uplo() != b.uplo() {
213                 return false
214         }
215         ak := a.bandwidth()
216         bk := b.bandwidth()
217         if ak != bk {
218                 return false
219         }
220         for i := 0; i < an; i++ {
221                 for j := 0; j < an; j++ {
222                         if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
223                                 return false
224                         }
225                 }
226         }
227         return true
228 }
229
230 var symmetricBandTests = []SymmetricBand{
231         {N: 3, K: 0, Stride: 1, Uplo: blas.Upper, Data: []complex128{
232                 1,
233                 2,
234                 3,
235         }},
236         {N: 3, K: 0, Stride: 1, Uplo: blas.Lower, Data: []complex128{
237                 1,
238                 2,
239                 3,
240         }},
241         {N: 3, K: 1, Stride: 2, Uplo: blas.Upper, Data: []complex128{
242                 1, 2,
243                 3, 4,
244                 5, -1,
245         }},
246         {N: 3, K: 1, Stride: 2, Uplo: blas.Lower, Data: []complex128{
247                 -1, 1,
248                 2, 3,
249                 4, 5,
250         }},
251         {N: 3, K: 2, Stride: 3, Uplo: blas.Upper, Data: []complex128{
252                 1, 2, 3,
253                 4, 5, -1,
254                 6, -2, -3,
255         }},
256         {N: 3, K: 2, Stride: 3, Uplo: blas.Lower, Data: []complex128{
257                 -2, -1, 1,
258                 -3, 2, 4,
259                 3, 5, 6,
260         }},
261
262         {N: 3, K: 0, Stride: 5, Uplo: blas.Upper, Data: []complex128{
263                 1, 0, 0, 0, 0,
264                 2, 0, 0, 0, 0,
265                 3, 0, 0, 0, 0,
266         }},
267         {N: 3, K: 0, Stride: 5, Uplo: blas.Lower, Data: []complex128{
268                 1, 0, 0, 0, 0,
269                 2, 0, 0, 0, 0,
270                 3, 0, 0, 0, 0,
271         }},
272         {N: 3, K: 1, Stride: 5, Uplo: blas.Upper, Data: []complex128{
273                 1, 2, 0, 0, 0,
274                 3, 4, 0, 0, 0,
275                 5, -1, 0, 0, 0,
276         }},
277         {N: 3, K: 1, Stride: 5, Uplo: blas.Lower, Data: []complex128{
278                 -1, 1, 0, 0, 0,
279                 2, 3, 0, 0, 0,
280                 4, 5, 0, 0, 0,
281         }},
282         {N: 3, K: 2, Stride: 5, Uplo: blas.Upper, Data: []complex128{
283                 1, 2, 3, 0, 0,
284                 4, 5, -1, 0, 0,
285                 6, -2, -3, 0, 0,
286         }},
287         {N: 3, K: 2, Stride: 5, Uplo: blas.Lower, Data: []complex128{
288                 -2, -1, 1, 0, 0,
289                 -3, 2, 4, 0, 0,
290                 3, 5, 6, 0, 0,
291         }},
292 }
293
294 func TestConvertSymBand(t *testing.T) {
295         for _, test := range symmetricBandTests {
296                 colmajor := newSymmetricBandColsFrom(test)
297                 if !sameSymmetricBand(colmajor, test) {
298                         t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
299                                 colmajor, test)
300                 }
301                 rowmajor := newSymmetricBandFrom(colmajor)
302                 if !sameSymmetricBand(rowmajor, test) {
303                         t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
304                                 rowmajor, test)
305                 }
306         }
307 }