1 // Copyright ©2017 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
10 "golang.org/x/exp/rand"
12 "gonum.org/v1/gonum/floats"
15 func TestGSVD(t *testing.T) {
17 rnd := rand.New(rand.NewSource(1))
18 for _, test := range []struct {
37 for trial := 0; trial < 10; trial++ {
38 a := NewDense(m, n, nil)
39 for i := range a.mat.Data {
40 a.mat.Data[i] = rnd.NormFloat64()
42 aCopy := DenseCopyOf(a)
44 b := NewDense(p, n, nil)
45 for i := range b.mat.Data {
46 b.mat.Data[i] = rnd.NormFloat64()
48 bCopy := DenseCopyOf(b)
50 // Test Full decomposition.
52 ok := gsvd.Factorize(a, b, GSVDU|GSVDV|GSVDQ)
54 t.Errorf("GSVD factorization failed")
57 t.Errorf("A changed during call to GSVD.Factorize with GSVDU|GSVDV|GSVDQ")
60 t.Errorf("B changed during call to GSVD.Factorize with GSVDU|GSVDV|GSVDQ")
62 c, s, sigma1, sigma2, zeroR, u, v, q := extractGSVD(&gsvd)
63 var ansU, ansV, d1R, d2R Dense
64 ansU.Product(u.T(), a, q)
65 ansV.Product(v.T(), b, q)
66 d1R.Mul(sigma1, zeroR)
67 d2R.Mul(sigma2, zeroR)
68 if !EqualApprox(&ansU, &d1R, tol) {
69 t.Errorf("Answer mismatch with GSVDU|GSVDV|GSVDQ\nU^T * A * Q:\n% 0.2f\nΣ₁ * [ 0 R ]:\n% 0.2f",
70 Formatted(&ansU), Formatted(&d1R))
72 if !EqualApprox(&ansV, &d2R, tol) {
73 t.Errorf("Answer mismatch with GSVDU|GSVDV|GSVDQ\nV^T * B *Q:\n% 0.2f\nΣ₂ * [ 0 R ]:\n% 0.2f",
74 Formatted(&d2R), Formatted(&ansV))
77 // Check C^2 + S^2 = I.
79 d := c[i]*c[i] + s[i]*s[i]
80 if !floats.EqualWithinAbsOrRel(d, 1, 1e-14, 1e-14) {
81 t.Errorf("c_%d^2 + s_%d^2 != 1: got: %v", i, i, d)
85 // Test None decomposition.
86 ok = gsvd.Factorize(a, b, GSVDNone)
88 t.Errorf("GSVD factorization failed")
91 t.Errorf("A changed during call to GSVD with GSVDNone")
94 t.Errorf("B changed during call to GSVD with GSVDNone")
96 cNone := gsvd.ValuesA(nil)
97 if !floats.EqualApprox(c, cNone, tol) {
98 t.Errorf("Singular value mismatch between GSVDU|GSVDV|GSVDQ and GSVDNone decomposition")
100 sNone := gsvd.ValuesB(nil)
101 if !floats.EqualApprox(s, sNone, tol) {
102 t.Errorf("Singular value mismatch between GSVDU|GSVDV|GSVDQ and GSVDNone decomposition")
108 func extractGSVD(gsvd *GSVD) (c, s []float64, s1, s2, zR, u, v, q *Dense) {
109 s1 = gsvd.SigmaATo(nil)
110 s2 = gsvd.SigmaBTo(nil)
111 zR = gsvd.ZeroRTo(nil)
115 c = gsvd.ValuesA(nil)
116 s = gsvd.ValuesB(nil)
117 return c, s, s1, s2, zR, u, v, q