OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / mat / hogsvd_test.go
1 // Copyright ©2017 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package mat
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11 )
12
13 func TestHOGSVD(t *testing.T) {
14         const tol = 1e-10
15         rnd := rand.New(rand.NewSource(1))
16         for cas, test := range []struct {
17                 r, c int
18         }{
19                 {5, 3},
20                 {5, 5},
21                 {150, 150},
22                 {200, 150},
23
24                 // Calculating A_i*A_j^T and A_j*A_i^T fails for wide matrices.
25                 {3, 5},
26         } {
27                 r := test.r
28                 c := test.c
29                 for n := 3; n < 6; n++ {
30                         data := make([]Matrix, n)
31                         dataCopy := make([]*Dense, n)
32                         for trial := 0; trial < 10; trial++ {
33                                 for i := range data {
34                                         d := NewDense(r, c, nil)
35                                         for j := range d.mat.Data {
36                                                 d.mat.Data[j] = rnd.Float64()
37                                         }
38                                         data[i] = d
39                                         dataCopy[i] = DenseCopyOf(d)
40                                 }
41
42                                 var gsvd HOGSVD
43                                 ok := gsvd.Factorize(data...)
44                                 if r >= c {
45                                         if !ok {
46                                                 t.Errorf("HOGSVD factorization failed for %d %d×%d matrices: %v", n, r, c, gsvd.Err())
47                                                 continue
48                                         }
49                                 } else {
50                                         if ok {
51                                                 t.Errorf("HOGSVD factorization unexpectedly succeeded for for %d %d×%d matrices", n, r, c)
52                                         }
53                                         continue
54                                 }
55                                 for i := range data {
56                                         if !Equal(data[i], dataCopy[i]) {
57                                                 t.Errorf("A changed during call to HOGSVD.Factorize")
58                                         }
59                                 }
60                                 u, s, v := extractHOGSVD(&gsvd)
61                                 for i, want := range data {
62                                         var got Dense
63                                         sigma := NewDense(c, c, nil)
64                                         for j := 0; j < c; j++ {
65                                                 sigma.Set(j, j, s[i][j])
66                                         }
67
68                                         got.Product(u[i], sigma, v.T())
69                                         if !EqualApprox(&got, want, tol) {
70                                                 t.Errorf("test %d n=%d trial %d: unexpected answer\nU_%[4]d * S_%[4]d * V^T:\n% 0.2f\nD_%d:\n% 0.2f",
71                                                         cas, n, trial, i, Formatted(&got, Excerpt(5)), i, Formatted(want, Excerpt(5)))
72                                         }
73                                 }
74                         }
75                 }
76         }
77 }
78
79 func extractHOGSVD(gsvd *HOGSVD) (u []*Dense, s [][]float64, v *Dense) {
80         u = make([]*Dense, gsvd.Len())
81         s = make([][]float64, gsvd.Len())
82         for i := 0; i < gsvd.Len(); i++ {
83                 u[i] = gsvd.UTo(nil, i)
84                 s[i] = gsvd.Values(nil, i)
85         }
86         v = gsvd.VTo(nil)
87         return u, s, v
88 }