OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / mat / gsvd_test.go
1 // Copyright ©2017 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package mat
6
7 import (
8         "testing"
9
10         "golang.org/x/exp/rand"
11
12         "gonum.org/v1/gonum/floats"
13 )
14
15 func TestGSVD(t *testing.T) {
16         const tol = 1e-10
17         rnd := rand.New(rand.NewSource(1))
18         for _, test := range []struct {
19                 m, p, n int
20         }{
21                 {5, 3, 5},
22                 {5, 3, 3},
23                 {3, 3, 5},
24                 {5, 5, 5},
25                 {5, 5, 3},
26                 {3, 5, 5},
27                 {150, 150, 150},
28                 {200, 150, 150},
29                 {150, 150, 200},
30                 {150, 200, 150},
31                 {200, 200, 150},
32                 {150, 200, 200},
33         } {
34                 m := test.m
35                 p := test.p
36                 n := test.n
37                 for trial := 0; trial < 10; trial++ {
38                         a := NewDense(m, n, nil)
39                         for i := range a.mat.Data {
40                                 a.mat.Data[i] = rnd.NormFloat64()
41                         }
42                         aCopy := DenseCopyOf(a)
43
44                         b := NewDense(p, n, nil)
45                         for i := range b.mat.Data {
46                                 b.mat.Data[i] = rnd.NormFloat64()
47                         }
48                         bCopy := DenseCopyOf(b)
49
50                         // Test Full decomposition.
51                         var gsvd GSVD
52                         ok := gsvd.Factorize(a, b, GSVDU|GSVDV|GSVDQ)
53                         if !ok {
54                                 t.Errorf("GSVD factorization failed")
55                         }
56                         if !Equal(a, aCopy) {
57                                 t.Errorf("A changed during call to GSVD.Factorize with GSVDU|GSVDV|GSVDQ")
58                         }
59                         if !Equal(b, bCopy) {
60                                 t.Errorf("B changed during call to GSVD.Factorize with GSVDU|GSVDV|GSVDQ")
61                         }
62                         c, s, sigma1, sigma2, zeroR, u, v, q := extractGSVD(&gsvd)
63                         var ansU, ansV, d1R, d2R Dense
64                         ansU.Product(u.T(), a, q)
65                         ansV.Product(v.T(), b, q)
66                         d1R.Mul(sigma1, zeroR)
67                         d2R.Mul(sigma2, zeroR)
68                         if !EqualApprox(&ansU, &d1R, tol) {
69                                 t.Errorf("Answer mismatch with GSVDU|GSVDV|GSVDQ\nU^T * A * Q:\n% 0.2f\nΣ₁ * [ 0 R ]:\n% 0.2f",
70                                         Formatted(&ansU), Formatted(&d1R))
71                         }
72                         if !EqualApprox(&ansV, &d2R, tol) {
73                                 t.Errorf("Answer mismatch with GSVDU|GSVDV|GSVDQ\nV^T * B  *Q:\n% 0.2f\nΣ₂ * [ 0 R ]:\n% 0.2f",
74                                         Formatted(&d2R), Formatted(&ansV))
75                         }
76
77                         // Check C^2 + S^2 = I.
78                         for i := range c {
79                                 d := c[i]*c[i] + s[i]*s[i]
80                                 if !floats.EqualWithinAbsOrRel(d, 1, 1e-14, 1e-14) {
81                                         t.Errorf("c_%d^2 + s_%d^2 != 1: got: %v", i, i, d)
82                                 }
83                         }
84
85                         // Test None decomposition.
86                         ok = gsvd.Factorize(a, b, GSVDNone)
87                         if !ok {
88                                 t.Errorf("GSVD factorization failed")
89                         }
90                         if !Equal(a, aCopy) {
91                                 t.Errorf("A changed during call to GSVD with GSVDNone")
92                         }
93                         if !Equal(b, bCopy) {
94                                 t.Errorf("B changed during call to GSVD with GSVDNone")
95                         }
96                         cNone := gsvd.ValuesA(nil)
97                         if !floats.EqualApprox(c, cNone, tol) {
98                                 t.Errorf("Singular value mismatch between GSVDU|GSVDV|GSVDQ and GSVDNone decomposition")
99                         }
100                         sNone := gsvd.ValuesB(nil)
101                         if !floats.EqualApprox(s, sNone, tol) {
102                                 t.Errorf("Singular value mismatch between GSVDU|GSVDV|GSVDQ and GSVDNone decomposition")
103                         }
104                 }
105         }
106 }
107
108 func extractGSVD(gsvd *GSVD) (c, s []float64, s1, s2, zR, u, v, q *Dense) {
109         s1 = gsvd.SigmaATo(nil)
110         s2 = gsvd.SigmaBTo(nil)
111         zR = gsvd.ZeroRTo(nil)
112         u = gsvd.UTo(nil)
113         v = gsvd.VTo(nil)
114         q = gsvd.QTo(nil)
115         c = gsvd.ValuesA(nil)
116         s = gsvd.ValuesB(nil)
117         return c, s, s1, s2, zR, u, v, q
118 }