8 "golang.org/x/exp/rand"
10 "gonum.org/v1/gonum/blas"
11 "gonum.org/v1/gonum/floats"
// TestFlattenBanded runs flattenBanded over a table of cases and checks each
// result against the case's expected "condensed" layout.
//
// Each case supplies a dense matrix together with ku and kl values, and the
// expected layout as a slice of rows. Many expected rows contain math.NaN()
// entries, presumably marking slots of the condensed layout that do not
// correspond to a stored matrix element (TODO confirm against flattenBanded).
// Because of those NaN entries the comparison is done with floats.Same rather
// than floats.Equal.
// NOTE(review): this relies on floats.Same treating NaN entries as matching
// each other; confirm against the gonum floats documentation.
14 func TestFlattenBanded(t *testing.T) {
15 for i, test := range []struct {
22 dense: [][]float64{{3}},
25 condensed: [][]float64{{3}},
33 condensed: [][]float64{
43 condensed: [][]float64{
57 condensed: [][]float64{
61 {math.NaN(), math.NaN()},
62 {math.NaN(), math.NaN()},
75 condensed: [][]float64{
78 {2, math.NaN(), math.NaN()},
79 {math.NaN(), math.NaN(), math.NaN()},
80 {math.NaN(), math.NaN(), math.NaN()},
93 condensed: [][]float64{
94 {math.NaN(), 3, 4, 6},
95 {1, 5, 8, math.NaN()},
96 {6, 2, math.NaN(), math.NaN()},
97 {7, math.NaN(), math.NaN(), math.NaN()},
98 {math.NaN(), math.NaN(), math.NaN(), math.NaN()},
111 condensed: [][]float64{
112 {math.NaN(), math.NaN(), 1, 2},
113 {math.NaN(), 3, 4, 5},
114 {6, 7, 8, math.NaN()},
115 {9, 10, math.NaN(), math.NaN()},
116 {11, math.NaN(), math.NaN(), math.NaN()},
129 condensed: [][]float64{
130 {math.NaN(), math.NaN(), 1},
134 {11, math.NaN(), math.NaN()},
145 condensed: [][]float64{
146 {math.NaN(), math.NaN(), 1},
// Run the function under test on the case's dense matrix; note the argument
// order is ku first, then kl.
152 condensed := flattenBanded(test.dense, test.ku, test.kl)
// Turn the expected slice of rows into a single slice so it can be compared
// directly with the result above.
153 correct := flatten(test.condensed)
// floats.Same is used here because the expected data contains NaN entries.
154 if !floats.Same(condensed, correct) {
155 t.Errorf("Case %v mismatch. Want %v, got %v.", i, correct, condensed)
// TestFlattenTriangular runs flattenTriangular over a table of cases and
// checks each result against the case's expected slice.
//
// Each case supplies a matrix a, a ul value, and the expected flattened
// output ans.
160 func TestFlattenTriangular(t *testing.T) {
161 for i, test := range []struct {
// The two expected outputs visible here are identical; presumably the cases
// differ only in their input matrix and ul value (TODO confirm).
173 ans: []float64{1, 2, 3, 4, 5, 6},
182 ans: []float64{1, 2, 3, 4, 5, 6},
// Run the function under test with the case's matrix and ul value.
185 a := flattenTriangular(test.a, test.ul)
// Compare the result against the expected slice and report the case index
// on mismatch.
186 if !floats.Equal(a, test.ans) {
187 t.Errorf("Case %v. Want %v, got %v.", i, test.ans, a)
192 func TestPackUnpackAsHermitian(t *testing.T) {
193 rnd := rand.New(rand.NewSource(1))
194 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
195 for _, n := range []int{1, 2, 5, 50} {
196 for _, lda := range []int{max(1, n), n + 11} {
197 a := makeZGeneral(nil, n, n, lda)
198 for i := 0; i < n; i++ {
199 for j := i; j < n; j++ {
200 a[i*lda+j] = complex(rnd.NormFloat64(), rnd.NormFloat64())
202 a[j*lda+i] = cmplx.Conj(a[i*lda+j])
206 aCopy := make([]complex128, len(a))
209 ap := zPack(uplo, n, a, lda)
210 if !zsame(a, aCopy) {
211 t.Errorf("Case uplo=%v,n=%v,lda=%v: zPack modified a", uplo, n, lda)
214 apCopy := make([]complex128, len(ap))
217 art := zUnpackAsHermitian(uplo, n, ap)
218 if !zsame(ap, apCopy) {
219 t.Errorf("Case uplo=%v,n=%v,lda=%v: zUnpackAsHermitian modified ap", uplo, n, lda)
222 // Copy the round-tripped A into a matrix with the same stride
224 got := makeZGeneral(nil, n, n, lda)
225 for i := 0; i < n; i++ {
226 copy(got[i*lda:i*lda+n], art[i*n:i*n+n])
229 t.Errorf("Case uplo=%v,n=%v,lda=%v: zPack and zUnpackAsHermitian do not roundtrip", uplo, n, lda)