1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
11 "golang.org/x/exp/rand"
13 "gonum.org/v1/gonum/blas"
14 "gonum.org/v1/gonum/blas/blas64"
// Dsytd2er is the interface accepted by Dsytd2Test; the test invokes its
// Dsytd2 method and then validates the contents of a, d, e and tau.
17 type Dsytd2er interface {
// Dsytd2 is called by the test with an n*lda-element slice a, a d of length n,
// and e and tau of length n-1.
18 Dsytd2(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau []float64)
// Dsytd2Test runs impl.Dsytd2 on randomly filled matrices for both blas.Upper
// and blas.Lower. After each call it assembles a matrix Q from the reflector
// vectors stored in a and the scalars in tau, checks that Q is orthonormal,
// and checks that Q^T * A * Q agrees, element by element to within 1e-10,
// with the tridiagonal matrix built from d and e.
21 func Dsytd2Test(t *testing.T, impl Dsytd2er) {
// A fixed seed is used for the random source.
22 rnd := rand.New(rand.NewSource(1))
23 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
24 for _, test := range []struct {
// a is the n*lda input matrix storage, filled with normally distributed values.
40 a := make([]float64, n*lda)
42 a[i] = rnd.NormFloat64()
// NOTE(review): aCopy presumably preserves the input before Dsytd2 overwrites
// a; the copy and its later use are not visible in this excerpt — confirm.
44 aCopy := make([]float64, len(a))
// Output buffers passed to Dsytd2: d has n entries, e and tau have n-1 each.
47 d := make([]float64, n)
51 e := make([]float64, n-1)
55 tau := make([]float64, n-1)
60 impl.Dsytd2(uplo, n, a, lda, d, e, tau)
// qMat accumulates the product of the reflectors; qCopy is scratch storage of
// the same size.
63 qMat := blas64.General{
67 Data: make([]float64, n*n),
69 qCopy := blas64.General{
73 Data: make([]float64, len(qMat.Data)),
// Put ones on the diagonal of the zero-initialized Q.
76 for i := 0; i < n; i++ {
77 qMat.Data[i*qMat.Stride+i] = 1
// One reflector per index i in [0, n-1).
79 for i := 0; i < n-1; i++ {
80 hMat := blas64.General{
84 Data: make([]float64, n*n),
// Put ones on the diagonal of the zero-initialized H.
87 for i := 0; i < n; i++ {
88 hMat.Data[i*hMat.Stride+i] = 1
// Build the vector vi from the entries Dsytd2 left in a. The remaining
// entries of vi are set on lines not visible in this excerpt.
91 if uplo == blas.Upper {
94 Data: make([]float64, n),
// Upper case: entries j < i are read from a[j*lda+i+1].
96 for j := 0; j < i; j++ {
97 vi.Data[j] = a[j*lda+i+1]
103 Data: make([]float64, n),
// Other case: entries j >= i+2 are read from a[j*lda+i].
106 for j := i + 2; j < n; j++ {
107 vi.Data[j] = a[j*lda+i]
// Rank-one update: H = H - tau[i] * vi * vi^T.
110 blas64.Ger(-tau[i], vi, vi, hMat)
// Snapshot Q so that Gemm can read the old value while writing into qMat.
111 copy(qCopy.Data, qMat.Data)
113 // Multiply q by the new h.
// For blas.Upper the product is H * Q; the other call computes Q * H.
114 if uplo == blas.Upper {
115 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hMat, qCopy, 0, qMat)
117 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, hMat, 0, qMat)
120 // Check that Q is orthonormal
// Each pair of rows (i, j) with j >= i is compared through a dot product.
122 for i := 0; i < n; i++ {
123 for j := i; j < n; j++ {
125 blas64.Vector{Inc: 1, Data: qMat.Data[i*qMat.Stride:]},
126 blas64.Vector{Inc: 1, Data: qMat.Data[j*qMat.Stride:]},
// NOTE(review): the conditions selecting between the two checks below are not
// visible here; presumably dot must be 1 when i == j and 0 otherwise — confirm.
129 if math.Abs(dot-1) > 1e-10 {
133 if math.Abs(dot) > 1e-10 {
140 t.Errorf("Q not orthonormal")
143 // Compute Q^T * A * Q.
144 aMat := blas64.General{
148 Data: make([]float64, len(a)),
// Fill aMat symmetrically: each value v is written to both (i, j) and (j, i).
// The source of v, which depends on uplo, is not visible in this excerpt.
151 for i := 0; i < n; i++ {
152 for j := i; j < n; j++ {
154 if uplo == blas.Lower {
157 aMat.Data[i*aMat.Stride+j] = v
158 aMat.Data[j*aMat.Stride+i] = v
162 tmp := blas64.General{
166 Data: make([]float64, n*n),
169 ans := blas64.General{
173 Data: make([]float64, n*n),
// tmp = Q^T * A, then ans = tmp * Q.
176 blas64.Gemm(blas.Trans, blas.NoTrans, 1, qMat, aMat, 0, tmp)
177 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, qMat, 0, ans)
// Build the expected tridiagonal matrix: d on the diagonal, e on the first
// super- and sub-diagonals.
180 tMat := blas64.General{
184 Data: make([]float64, n*n),
186 for i := 0; i < n-1; i++ {
187 tMat.Data[i*tMat.Stride+i] = d[i]
188 tMat.Data[i*tMat.Stride+i+1] = e[i]
189 tMat.Data[(i+1)*tMat.Stride+i] = e[i]
// The loop above stops at n-2, so the last diagonal entry is set separately.
191 tMat.Data[(n-1)*tMat.Stride+n-1] = d[n-1]
// Compare ans with tMat element by element using an absolute tolerance of 1e-10.
194 for i := 0; i < n; i++ {
195 for j := 0; j < n; j++ {
196 if math.Abs(ans.Data[i*ans.Stride+j]-tMat.Data[i*tMat.Stride+j]) > 1e-10 {
202 t.Errorf("Matrix answer mismatch")