1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
11 "golang.org/x/exp/rand"
13 "gonum.org/v1/gonum/blas"
14 "gonum.org/v1/gonum/blas/blas64"
// Dsytrder is the set of routines that DsytrdTest calls on the implementation
// under test: Dsytrd itself, plus Dorgqr and Dorgql, which the test uses to
// build the matrix Q that it then checks for orthonormality.
17 type Dsytrder interface {
18 Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau, work []float64, lwork int)
20 Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
21 Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
// DsytrdTest exercises impl.Dsytrd on random symmetric matrices, for both
// blas.Upper and blas.Lower and for several workspace lengths. For every
// combination it reports an error if A was written out of range, if the
// matrix Q extracted after the call is not orthonormal, or if Q^T * A * Q
// does not match, within tol, the symmetric tridiagonal matrix T assembled
// from d and e.
24 func DsytrdTest(t *testing.T, impl Dsytrder) {
// Fixed seed for the random source used to generate the test matrices.
26 rnd := rand.New(rand.NewSource(1))
27 for tc, test := range []struct {
50 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
51 for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} {
// Generate a random n×n matrix and make it symmetric by copying each
// element above the diagonal to its mirror position below the diagonal.
57 a := randomGeneral(n, n, lda, rnd)
58 for i := 1; i < n; i++ {
59 for j := 0; j < i; j++ {
60 a.Data[i*a.Stride+j] = a.Data[j*a.Stride+i]
// Keep a copy of A for the final Q^T * A * Q check.
63 aCopy := cloneGeneral(a)
// NOTE(review): nanSlice is defined elsewhere; presumably it returns a
// NaN-filled slice so that unwritten elements can be spotted — confirm.
67 tau := nanSlice(n - 1)
// Workspace query (lwork = -1), then use roughly half of the length
// reported in work[0].
// NOTE(review): presumably this is the mediumWork case — confirm.
74 work := make([]float64, 1)
75 impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1)
76 lwork = (int(work[0]) + 1) / 2
// Workspace query (lwork = -1).
79 work := make([]float64, 1)
80 impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1)
// Allocate a workspace of the chosen length and call the routine under test.
83 work := make([]float64, lwork)
85 impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, lwork)
// Common prefix that identifies this case in every failure message below.
87 prefix := fmt.Sprintf("Case #%v: uplo=%v,n=%v,lda=%v,work=%v",
// NOTE(review): generalOutsideAllNaN is defined elsewhere; judging by the
// error message it detects writes outside the valid part of A — confirm.
90 if !generalOutsideAllNaN(a) {
91 t.Errorf("%v: out-of-range write to A", prefix)
94 // Extract Q by doing what Dorgtr does.
96 if uplo == blas.Upper {
// Shift the elements above the diagonal one column to the left, then zero
// the last row and the last column and put 1 in the bottom-right corner.
97 for j := 0; j < n-1; j++ {
98 for i := 0; i < j; i++ {
99 q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j+1]
101 q.Data[(n-1)*q.Stride+j] = 0
103 for i := 0; i < n-1; i++ {
104 q.Data[i*q.Stride+n-1] = 0
106 q.Data[(n-1)*q.Stride+n-1] = 1
// Run Dorgql on the leading (n-1)×(n-1) block of q with a workspace of
// length n-1.
108 work = make([]float64, n-1)
109 impl.Dorgql(n-1, n-1, n-1, q.Data, q.Stride, tau, work, len(work))
// Shift the elements below the diagonal one column to the right, walking
// the columns from last to first, then zero the first column below row 0.
112 for j := n - 1; j > 0; j-- {
114 for i := j + 1; i < n; i++ {
115 q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j-1]
119 for i := 1; i < n; i++ {
120 q.Data[i*q.Stride] = 0
// Run Dorgqr on the trailing (n-1)×(n-1) block of q, which starts at
// element (1,1), with a workspace of length n-1.
123 work = make([]float64, n-1)
124 impl.Dorgqr(n-1, n-1, n-1, q.Data[q.Stride+1:], q.Stride, tau, work, len(work))
127 if !isOrthonormal(q) {
128 t.Errorf("%v: Q not orthogonal", prefix)
131 // Construct symmetric tridiagonal T from d and e.
132 tMat := zeros(n, n, n)
// d goes on the main diagonal.
133 for i := 0; i < n; i++ {
134 tMat.Data[i*tMat.Stride+i] = d[i]
// e fills both the superdiagonal and the subdiagonal; the two loops below
// perform the same assignments with differently offset indices.
136 if uplo == blas.Upper {
137 for j := 1; j < n; j++ {
138 tMat.Data[(j-1)*tMat.Stride+j] = e[j-1]
139 tMat.Data[j*tMat.Stride+j-1] = e[j-1]
142 for j := 0; j < n-1; j++ {
143 tMat.Data[(j+1)*tMat.Stride+j] = e[j]
144 tMat.Data[j*tMat.Stride+j+1] = e[j]
148 // Compute Q^T * A * Q.
// tmp = Q^T * A, using the copy of A saved before the Dsytrd call.
149 tmp := zeros(n, n, n)
150 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aCopy, 0, tmp)
// got = tmp * Q.
151 got := zeros(n, n, n)
152 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, got)
155 if !equalApproxGeneral(got, tMat, tol) {
156 t.Errorf("%v: Q^T*A*Q != T", prefix)