1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
13 "golang.org/x/exp/rand"
15 "gonum.org/v1/gonum/blas"
16 "gonum.org/v1/gonum/blas/blas64"
17 "gonum.org/v1/gonum/floats"
// TestNewSymmetric checks NewSymDense against a table of cases. For each case:
// the reported dimensions must be n×n, the constructed value must DeepEqual
// the expected test.mat, and its backing data must hold the same values as a
// Dense built from the same input slice. It also checks that a data slice of
// the wrong length panics with ErrShape.
20 func TestNewSymmetric(t *testing.T) {
21 for i, test := range []struct {
34 mat: blas64.Symmetric{
38 Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9},
44 sym := NewSymDense(test.n, test.data)
45 rows, cols := sym.Dims()
48 t.Errorf("unexpected number of rows for test %d: got: %d want: %d", i, rows, test.n)
51 t.Errorf("unexpected number of cols for test %d: got: %d want: %d", i, cols, test.n)
53 if !reflect.DeepEqual(sym, test.mat) {
54 t.Errorf("unexpected data slice for test %d: got: %v want: %v", i, sym, test.mat)
// The symmetric matrix's backing data must match, value for value, the
// backing data of a general n×n Dense built from the same slice.
57 m := NewDense(test.n, test.n, test.data)
58 if !reflect.DeepEqual(sym.mat.Data, m.mat.Data) {
59 t.Errorf("unexpected data slice mismatch for test %d: got: %v want: %v", i, sym.mat.Data, m.mat.Data)
// A two-element slice is too short for a 3×3 matrix and must panic with ErrShape.
63 panicked, message := panics(func() { NewSymDense(3, []float64{1, 2}) })
64 if !panicked || message != ErrShape.Error() {
65 t.Error("expected panic for invalid data slice length")
// TestSymAtSet checks element access on a SymDense. It covers two things:
//   - At and SetSym panic with ErrRowAccess/ErrColAccess for out-of-range
//     indices;
//   - a value written with SetSym is visible through both At(i, j) and At(j, i).
69 func TestSymAtSet(t *testing.T) {
71 mat: blas64.Symmetric{
75 Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9},
79 rows, cols := sym.Dims()
// Out-of-range indices probed: one below zero, exactly the dimension, and one past it.
81 // Check At out of bounds
82 for _, row := range []int{-1, rows, rows + 1} {
83 panicked, message := panics(func() { sym.At(row, 0) })
84 if !panicked || message != ErrRowAccess.Error() {
85 t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row)
88 for _, col := range []int{-1, cols, cols + 1} {
89 panicked, message := panics(func() { sym.At(0, col) })
90 if !panicked || message != ErrColAccess.Error() {
91 t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col)
95 // Check Set out of bounds
96 for _, row := range []int{-1, rows, rows + 1} {
97 panicked, message := panics(func() { sym.SetSym(row, 0, 1.2) })
98 if !panicked || message != ErrRowAccess.Error() {
99 t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row)
102 for _, col := range []int{-1, cols, cols + 1} {
103 panicked, message := panics(func() { sym.SetSym(0, col, 1.2) })
104 if !panicked || message != ErrColAccess.Error() {
105 t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col)
// Each case reads the element at (row, col) and (col, row), writes it with
// SetSym, then reads both positions again. The cases run in order on the
// same matrix.
109 for _, st := range []struct {
// The second case swaps the indices of the first and expects as its original
// value the one the first case wrote, so SetSym must update both triangles.
113 {row: 1, col: 2, orig: 6, new: 15},
114 {row: 2, col: 1, orig: 15, new: 12},
116 if e := sym.At(st.row, st.col); e != st.orig {
117 t.Errorf("unexpected value for At(%d, %d): got: %v want: %v", st.row, st.col, e, st.orig)
119 if e := sym.At(st.col, st.row); e != st.orig {
120 t.Errorf("unexpected value for At(%d, %d): got: %v want: %v", st.col, st.row, e, st.orig)
122 sym.SetSym(st.row, st.col, st.new)
// The indexed verbs (%[1]d etc.) reuse the argument list so the message can
// show both the read position and the SetSym call that preceded it.
123 if e := sym.At(st.row, st.col); e != st.new {
124 t.Errorf("unexpected value for At(%d, %d) after SetSym(%[1]d, %[2]d, %[4]v): got: %[3]v want: %v", st.row, st.col, e, st.new)
126 if e := sym.At(st.col, st.row); e != st.new {
127 t.Errorf("unexpected value for At(%d, %d) after SetSym(%[2]d, %[1]d, %[4]v): got: %[3]v want: %v", st.col, st.row, e, st.new)
// TestSymAdd checks SymDense.AddSym on two randomly filled n×n symmetric
// operands. It compares the upper triangle of the result element by element,
// first with a new receiver and then with an "equal receiver". It finally runs
// the generic two-input harness (testTwoInput) against a Dense comparison
// function with tolerance 1e-14.
132 func TestSymAdd(t *testing.T) {
133 for _, test := range []struct {
144 a := NewSymDense(n, nil)
145 for i := range a.mat.Data {
146 a.mat.Data[i] = rand.Float64()
148 b := NewSymDense(n, nil)
// NOTE(review): this ranges over a.mat.Data while filling b. That is equivalent
// because both were created with NewSymDense(n, nil), but ranging over
// b.mat.Data would state the intent directly.
149 for i := range a.mat.Data {
150 b.mat.Data[i] = rand.Float64()
155 // Check with new receiver
// Only the upper triangle (j >= i) is compared.
158 for i := 0; i < n; i++ {
159 for j := i; j < n; j++ {
161 if got := s.At(i, j); got != want {
162 t.Errorf("unexpected value for At(%d, %d): got: %v want: %v", i, j, got, want)
// NOTE(review): "equal receiver" presumably means the receiver aliases one of
// the operands; confirm against the receiver setup.
167 // Check with equal receiver
170 for i := 0; i < n; i++ {
171 for j := i; j < n; j++ {
173 if got := s.At(i, j); got != want {
174 t.Errorf("unexpected value for At(%d, %d): got: %v want: %v", i, j, got, want)
// method adapts AddSym to the harness's generic (receiver, a, b Matrix) shape.
180 method := func(receiver, a, b Matrix) {
181 type addSymer interface {
182 AddSym(a, b Symmetric)
184 rd := receiver.(addSymer)
185 rd.AddSym(a.(Symmetric), b.(Symmetric))
187 denseComparison := func(receiver, a, b *Dense) {
190 testTwoInput(t, "AddSym", &SymDense{}, method, denseComparison, legalTypesSym, legalSizeSameSquare, 1e-14)
// TestCopy fills an n×n SymDense a with random values. It then checks, element
// by element on the upper triangle, that a freshly created n×n SymDense s holds
// the expected values.
// NOTE(review): the copy into s is presumably done with CopySym; confirm.
193 func TestCopy(t *testing.T) {
194 for _, test := range []struct {
205 a := NewSymDense(n, nil)
206 for i := range a.mat.Data {
207 a.mat.Data[i] = rand.Float64()
209 s := NewSymDense(n, nil)
211 for i := 0; i < n; i++ {
212 for j := i; j < n; j++ {
214 if got := s.At(i, j); got != want {
215 t.Errorf("unexpected value for At(%d, %d): got: %v want: %v", i, j, got, want)
222 // TODO(kortschak) Roll this into testOneInput when it exists.
223 // https://github.com/gonum/matrix/issues/171
// TestSymCopyPanic checks that calling CopySym on a 1×1 SymDense with &a as
// the source does not panic and reports that zero elements were copied.
// NOTE(review): a is presumably an empty (zero-size) SymDense; confirm against
// its declaration.
224 func TestSymCopyPanic(t *testing.T) {
229 m := NewSymDense(1, nil)
230 panicked, message := panics(func() { n = m.CopySym(&a) })
232 t.Errorf("unexpected panic: %v", message)
235 t.Errorf("unexpected n: got: %d want: 0", n)
// TestSymRankOne checks SymDense.SymRankOne with a random symmetric a and a
// random vector x. The expected value is built from the Dense outer product
// x*xᵀ. The upper triangle is checked first with a new receiver, then with the
// receiver reused as the input matrix. It finally runs the generic two-input
// harness, whose Dense reference computes a + alpha*tmp.
239 func TestSymRankOne(t *testing.T) {
240 for _, test := range []struct {
252 a := NewSymDense(n, nil)
253 for i := range a.mat.Data {
254 a.mat.Data[i] = rand.Float64()
256 x := make([]float64, n)
258 x[i] = rand.Float64()
// x is treated as an n×1 column so that x*xᵀ is the n×n outer product.
261 xMat := NewDense(n, 1, x)
263 m.Mul(xMat, xMat.T())
267 // Check with new receiver
268 s := NewSymDense(n, nil)
269 s.SymRankOne(a, alpha, NewVecDense(len(x), x))
270 for i := 0; i < n; i++ {
271 for j := i; j < n; j++ {
273 if got := s.At(i, j); got != want {
274 t.Errorf("unexpected value for At(%d, %d): got: %v want: %v", i, j, got, want)
// Reset s to a's values and pass s as both receiver and input, to exercise
// the in-place path.
279 // Check with reused receiver
280 copy(s.mat.Data, a.mat.Data)
281 s.SymRankOne(s, alpha, NewVecDense(len(x), x))
282 for i := 0; i < n; i++ {
283 for j := i; j < n; j++ {
285 if got := s.At(i, j); got != want {
286 t.Errorf("unexpected value for At(%d, %d): got: %v want: %v", i, j, got, want)
// method adapts SymRankOne to the harness's generic (receiver, a, b Matrix)
// shape. The second operand is passed as the Vector x.
293 method := func(receiver, a, b Matrix) {
294 type SymRankOner interface {
295 SymRankOne(a Symmetric, alpha float64, x Vector)
297 rd := receiver.(SymRankOner)
298 rd.SymRankOne(a.(Symmetric), alpha, b.(Vector))
300 denseComparison := func(receiver, a, b *Dense) {
303 tmp.Scale(alpha, &tmp)
304 receiver.Add(a, &tmp)
// The first operand must be Symmetric.
306 legalTypes := func(a, b Matrix) bool {
307 _, ok := a.(Symmetric)
314 legalSize := func(ar, ac, br, bc int) bool {
320 testTwoInput(t, "SymRankOne", &SymDense{}, method, denseComparison, legalTypes, legalSize, 1e-14)
// TestIssue250SymRankOne is a regression test. s1 receives SymRankOne once and
// s2 receives it twice, each call with a freshly constructed
// NewSymDense(5, nil) input. The two results must be Equal, so repeating the
// call must not accumulate into the receiver's previous contents.
323 func TestIssue250SymRankOne(t *testing.T) {
324 x := NewVecDense(5, []float64{1, 2, 3, 4, 5})
326 s1.SymRankOne(NewSymDense(5, nil), 1, x)
327 s2.SymRankOne(NewSymDense(5, nil), 1, x)
328 s2.SymRankOne(NewSymDense(5, nil), 1, x)
329 if !Equal(&s1, &s2) {
330 t.Error("unexpected result from repeat")
// TestRankTwo checks SymDense.RankTwo with a random symmetric a and random
// vectors x and y. The Dense reference m is built from the outer products x*yᵀ
// and y*xᵀ. The upper triangle is checked first with a new receiver, then with
// the receiver reused as the input matrix. Elements are compared with a
// combined absolute/relative tolerance of 1e-14.
334 func TestRankTwo(t *testing.T) {
335 for _, test := range []struct {
347 a := NewSymDense(n, nil)
348 for i := range a.mat.Data {
349 a.mat.Data[i] = rand.Float64()
351 x := make([]float64, n)
352 y := make([]float64, n)
354 x[i] = rand.Float64()
355 y[i] = rand.Float64()
// x and y are treated as n×1 columns so that x*yᵀ and y*xᵀ are n×n.
358 xMat := NewDense(n, 1, x)
359 yMat := NewDense(n, 1, y)
// NOTE(review): m and tmp presumably combine to m = a + alpha*(x*yᵀ + y*xᵀ);
// confirm against the surrounding statements.
361 m.Mul(xMat, yMat.T())
363 tmp.Mul(yMat, xMat.T())
368 // Check with new receiver
369 s := NewSymDense(n, nil)
370 s.RankTwo(a, alpha, NewVecDense(len(x), x), NewVecDense(len(y), y))
371 for i := 0; i < n; i++ {
372 for j := i; j < n; j++ {
373 if !floats.EqualWithinAbsOrRel(s.At(i, j), m.At(i, j), 1e-14, 1e-14) {
// s is the result under test and m is the reference, so s is reported as "got".
374 t.Errorf("unexpected element value at (%d,%d): got: %f want: %f", i, j, s.At(i, j), m.At(i, j))
// Reset s to a's values and pass s as both receiver and input, to exercise
// the in-place path.
379 // Check with reused receiver
380 copy(s.mat.Data, a.mat.Data)
381 s.RankTwo(s, alpha, NewVecDense(len(x), x), NewVecDense(len(y), y))
382 for i := 0; i < n; i++ {
383 for j := i; j < n; j++ {
384 if !floats.EqualWithinAbsOrRel(s.At(i, j), m.At(i, j), 1e-14, 1e-14) {
// As above: s is "got", m is "want".
385 t.Errorf("unexpected element value at (%d,%d): got: %f want: %f", i, j, s.At(i, j), m.At(i, j))
// TestSymRankK runs the generic two-input harness against SymDense.SymRankK.
// The Dense reference computes receiver = a + alpha*tmp, and results are
// compared with tolerance 1e-14.
// NOTE(review): tmp is presumably b*bᵀ; confirm against the reference function
// body.
392 func TestSymRankK(t *testing.T) {
// method adapts SymRankK to the harness's generic (receiver, a, b Matrix) shape.
394 method := func(receiver, a, b Matrix) {
395 type SymRankKer interface {
396 SymRankK(a Symmetric, alpha float64, x Matrix)
398 rd := receiver.(SymRankKer)
399 rd.SymRankK(a.(Symmetric), alpha, b)
401 denseComparison := func(receiver, a, b *Dense) {
404 tmp.Scale(alpha, &tmp)
405 receiver.Add(a, &tmp)
// The first operand must be Symmetric.
407 legalTypes := func(a, b Matrix) bool {
408 _, ok := a.(Symmetric)
411 legalSize := func(ar, ac, br, bc int) bool {
417 testTwoInput(t, "SymRankK", &SymDense{}, method, denseComparison, legalTypes, legalSize, 1e-14)
// TestSymOuterK runs the generic one-input harness against SymDense.SymOuterK
// for each scale factor f in {0.5, 1, 3}. The Dense reference computes
// f * x*xᵀ. Any input type and size are accepted, and results are compared
// with tolerance 1e-14.
420 func TestSymOuterK(t *testing.T) {
421 for _, f := range []float64{0.5, 1, 3} {
// NOTE(review): method presumably calls rd.SymOuterK(f, x); confirm.
422 method := func(receiver, x Matrix) {
423 type SymOuterKer interface {
424 SymOuterK(alpha float64, x Matrix)
426 rd := receiver.(SymOuterKer)
429 denseComparison := func(receiver, x *Dense) {
430 receiver.Mul(x, x.T())
431 receiver.Scale(f, receiver)
433 testOneInput(t, "SymOuterK", &SymDense{}, method, denseComparison, isAnyType, isAnySize, 1e-14)
// TestIssue250SymOuterK is a regression test that requires s1 and s2 to be
// Equal after a repeated operation involving x.
// NOTE(review): by analogy with TestIssue250SymRankOne, SymOuterK is presumably
// applied once to s1 and twice to s2; confirm.
437 func TestIssue250SymOuterK(t *testing.T) {
438 x := NewVecDense(5, []float64{1, 2, 3, 4, 5})
443 if !Equal(&s1, &s2) {
444 t.Error("unexpected result from repeat")
// TestScaleSym runs the generic one-input harness against SymDense.ScaleSym
// for each scale factor f in {0.5, 1, 3}. Inputs must be symmetric and square,
// and results are compared with tolerance 1e-14.
448 func TestScaleSym(t *testing.T) {
449 for _, f := range []float64{0.5, 1, 3} {
// method adapts ScaleSym to the harness's generic (receiver, a Matrix) shape.
450 method := func(receiver, a Matrix) {
451 type ScaleSymer interface {
452 ScaleSym(f float64, a Symmetric)
454 rd := receiver.(ScaleSymer)
455 rd.ScaleSym(f, a.(Symmetric))
457 denseComparison := func(receiver, a *Dense) {
460 testOneInput(t, "ScaleSym", &SymDense{}, method, denseComparison, legalTypeSym, isSquare, 1e-14)
// TestSubsetSym checks SymDense.SubsetSym in two parts:
//   - a table of fixed cases compared with Equal against expected matrices,
//     including a case with a repeated index (dims {1, 1, 1});
//   - the generic one-input harness with exact (zero-tolerance) comparison.
// In the harness, the Dense reference sets element (i, j) of the result to
// a.At(dims[i], dims[j]).
464 func TestSubsetSym(t *testing.T) {
465 for _, test := range []struct {
471 a: NewSymDense(3, []float64{
477 ans: NewSymDense(2, []float64{
483 a: NewSymDense(3, []float64{
489 ans: NewSymDense(2, []float64{
495 a: NewSymDense(3, []float64{
// A repeated index yields a larger matrix whose entries all come from one element of a.
500 dims: []int{1, 1, 1},
501 ans: NewSymDense(3, []float64{
509 s.SubsetSym(test.a, test.dims)
510 if !Equal(&s, test.ans) {
511 t.Errorf("SubsetSym mismatch dims %v\nGot:\n% v\nWant:\n% v\n", test.dims, s, test.ans)
// NOTE(review): this loop presumably computes maxDim, the largest index in dims; confirm.
517 for _, v := range dims {
// method adapts SubsetSym to the harness's generic (receiver, a Matrix) shape.
522 method := func(receiver, a Matrix) {
523 type SubsetSymer interface {
524 SubsetSym(a Symmetric, set []int)
526 rd := receiver.(SubsetSymer)
527 rd.SubsetSym(a.(Symmetric), dims)
529 denseComparison := func(receiver, a *Dense) {
530 *receiver = *NewDense(len(dims), len(dims), nil)
532 for i := 0; i < sz; i++ {
533 for j := 0; j < sz; j++ {
534 receiver.Set(i, j, a.At(dims[i], dims[j]))
// Only square inputs large enough to contain every index in dims are legal.
538 legalSize := func(ar, ac int) bool {
539 return ar == ac && ar > maxDim
542 testOneInput(t, "SubsetSym", &SymDense{}, method, denseComparison, legalTypeSym, legalSize, 0)
// TestViewGrowSquare checks SliceSquare views and GrowSquare on SymDense. It
// checks that:
//   - a view, and a view of that view, read the matching elements of the
//     original;
//   - a write through the view is visible in the original;
//   - growing the inner view back to the original's bottom-right corner
//     reproduces those elements;
//   - growing a view of the grown matrix preserves the elements it already had.
545 func TestViewGrowSquare(t *testing.T) {
546 // n is the size of the original SymDense.
547 // The first view uses start1, span1. The second view uses start2, span2 on
549 for _, test := range []struct {
550 n, start1, span1, start2, span2 int
559 s := NewSymDense(n, nil)
// Fill the upper triangle with distinct values so a misplaced element is detectable.
560 for i := 0; i < n; i++ {
561 for j := i; j < n; j++ {
562 s.SetSym(i, j, float64((i+1)*n+j+1))
566 // Take a subset and check the view matches.
567 start1 := test.start1
569 v := s.SliceSquare(start1, start1+span1).(*SymDense)
570 for i := 0; i < span1; i++ {
571 for j := i; j < span1; j++ {
572 if v.At(i, j) != s.At(start1+i, start1+j) {
573 t.Errorf("View mismatch")
// The second view's offsets are relative to v, so in s they are start1+start2.
578 start2 := test.start2
580 v2 := v.SliceSquare(start2, start2+span2).(*SymDense)
582 for i := 0; i < span2; i++ {
583 for j := i; j < span2; j++ {
584 if v2.At(i, j) != s.At(start1+start2+i, start1+start2+j) {
585 t.Errorf("Second view mismatch")
590 // Check that a write to the view is reflected in the original.
592 if s.At(start1+start2, start1+start2) != 1.2 {
593 t.Errorf("Write to view not reflected in original")
// Grow v2 by enough rows/columns to reach the bottom-right corner of s, so
// g is gn×gn and aligned with s at offset start1+start2.
596 // Grow the matrix back to the original view
597 gn := n - start1 - start2
598 g := v2.GrowSquare(gn - v2.Symmetric()).(*SymDense)
601 for i := 0; i < gn; i++ {
602 for j := 0; j < gn; j++ {
603 if g.At(i, j) != s.At(start1+start2+i, start1+start2+j) {
604 t.Errorf("Grow mismatch")
// NOTE(review): fmt.Printf writes to stdout rather than the test log;
// t.Logf would attach this output to the failing test.
606 fmt.Printf("g=\n% v\n", Formatted(g))
607 fmt.Printf("s=\n% v\n", Formatted(s))
// gv is (gn-1)×(gn-1); growing it by 2 gives (gn+1)×(gn+1). Its leading gn×gn
// block must equal g.
613 // View g, then grow it and make sure all the elements were copied.
614 gv := g.SliceSquare(0, gn-1).(*SymDense)
616 gg := gv.GrowSquare(2)
617 for i := 0; i < gn; i++ {
618 for j := 0; j < gn; j++ {
619 if g.At(i, j) != gg.At(i, j) {
620 t.Errorf("Expand mismatch")
// TestPowPSD checks SymDense.PowPSD in two parts:
//   - a table of small matrices whose expected results come from Matlab,
//     compared within 1e-10;
//   - a sweep over dimensions 2..9 and integer powers 2..5, comparing
//     against a Dense result within 1e-10.
// The sweep uses random inputs from a fixed seed (rand.NewSource(1)), so it is
// reproducible.
627 func TestPowPSD(t *testing.T) {
628 for cas, test := range []struct {
633 // Comparison with Matlab.
635 a: NewSymDense(2, []float64{10, 5, 5, 12}),
637 ans: NewSymDense(2, []float64{3.065533767740645, 0.776210486171016, 0.776210486171016, 3.376017962209052}),
640 a: NewSymDense(2, []float64{11, -1, -1, 8}),
642 ans: NewSymDense(2, []float64{3.312618742210524, -0.162963396980939, -0.162963396980939, 2.823728551267709}),
645 a: NewSymDense(2, []float64{10, 5, 5, 12}),
647 ans: NewSymDense(2, []float64{0.346372134547712, -0.079637515547296, -0.079637515547296, 0.314517128328794}),
650 a: NewSymDense(3, []float64{15, -1, -3, -1, 8, 6, -3, 6, 14}),
652 ans: NewSymDense(3, []float64{
653 5.051214323034288, -0.163162161893975, -0.612153996497505,
654 -0.163162161893976, 3.283474884617009, 1.432842761381493,
655 -0.612153996497505, 1.432842761381494, 4.695873060862573,
660 err := s.PowPSD(test.a, test.pow)
// NOTE(review): fmt.Println writes to stdout rather than the test log;
// t.Log would attach this output to the failing test.
664 if !EqualApprox(&s, test.ans, 1e-10) {
665 t.Errorf("Case %d, pow mismatch", cas)
666 fmt.Println(Formatted(&s))
667 fmt.Println(Formatted(test.ans))
671 // Compare with Dense.Pow
672 rnd := rand.New(rand.NewSource(1))
673 for dim := 2; dim < 10; dim++ {
674 for pow := 2; pow < 6; pow++ {
675 a := NewDense(dim, dim, nil)
676 for i := 0; i < dim; i++ {
677 for j := 0; j < dim; j++ {
678 a.Set(i, j, rnd.Float64())
// NOTE(review): PowPSD returns an error (as the table cases above show), but
// it is discarded here. A failure would then surface only as a value mismatch.
685 sym.PowPSD(&mat, float64(pow))
// NOTE(review): the failure message reports dim but not pow; including pow
// would identify the failing case.
690 if !EqualApprox(&sym, &dense, 1e-10) {
691 t.Errorf("Dim %d: pow mismatch", dim)