1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
12 "golang.org/x/exp/rand"
14 "gonum.org/v1/gonum/blas"
15 "gonum.org/v1/gonum/blas/blas64"
18 type Dlatrder interface {
19 Dlatrd(uplo blas.Uplo, n, nb int, a []float64, lda int, e, tau, w []float64, ldw int)
22 func DlatrdTest(t *testing.T, impl Dlatrder) {
23 rnd := rand.New(rand.NewSource(1))
24 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
25 for _, test := range []struct {
45 a := make([]float64, n*lda)
47 a[i] = rnd.NormFloat64()
50 e := make([]float64, n-1)
54 tau := make([]float64, n-1)
58 w := make([]float64, n*ldw)
63 aCopy := make([]float64, len(a))
66 impl.Dlatrd(uplo, n, nb, a, lda, e, tau, w, ldw)
74 Data: make([]float64, n*ldq),
76 for i := 0; i < n; i++ {
79 if uplo == blas.Upper {
80 for i := n - 1; i >= n-nb; i-- {
85 Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
87 for j := 0; j < n; j++ {
92 Data: make([]float64, n),
94 for j := 0; j < i-1; j++ {
95 v.Data[j] = a[j*lda+i]
99 blas64.Ger(-tau[i-1], v, v, h)
101 qTmp := blas64.General{
102 Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
104 copy(qTmp.Data, q.Data)
105 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qTmp, h, 0, q)
108 for i := 0; i < nb; i++ {
113 Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
115 for j := 0; j < n; j++ {
120 Data: make([]float64, n),
123 for j := i + 2; j < n; j++ {
124 v.Data[j] = a[j*lda+i]
126 blas64.Ger(-tau[i], v, v, h)
128 qTmp := blas64.General{
129 Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
131 copy(qTmp.Data, q.Data)
132 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qTmp, h, 0, q)
135 errStr := fmt.Sprintf("isUpper = %v, n = %v, nb = %v", uplo == blas.Upper, n, nb)
136 if !isOrthonormal(q) {
137 t.Errorf("Q not orthonormal. %s", errStr)
139 aGen := genFromSym(blas64.Symmetric{N: n, Stride: lda, Uplo: uplo, Data: aCopy})
140 if !dlatrdCheckDecomposition(t, uplo, n, nb, e, tau, a, lda, aGen, q) {
141 t.Errorf("Decomposition mismatch. %s", errStr)
147 // dlatrdCheckDecomposition checks that the first nb rows have been successfully
149 func dlatrdCheckDecomposition(t *testing.T, uplo blas.Uplo, n, nb int, e, tau, a []float64, lda int, aGen, q blas64.General) bool {
150 // Compute Q^T * A * Q.
151 tmp := blas64.General{
155 Data: make([]float64, n*n),
158 ans := blas64.General{
162 Data: make([]float64, n*n),
165 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aGen, 0, tmp)
166 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, ans)
169 if uplo == blas.Upper {
170 for i := n - 1; i >= n-nb; i-- {
171 for j := 0; j < n; j++ {
172 v := ans.Data[i*ans.Stride+j]
175 if math.Abs(v-a[i*lda+j]) > 1e-10 {
179 if math.Abs(a[i*lda+j]-1) > 1e-10 {
182 if math.Abs(v-e[i]) > 1e-10 {
187 if math.Abs(v) > 1e-10 {
194 for i := 0; i < nb; i++ {
195 for j := 0; j < n; j++ {
196 v := ans.Data[i*ans.Stride+j]
199 if math.Abs(v-a[i*lda+j]) > 1e-10 {
204 if math.Abs(a[i*lda+j]-1) > 1e-10 {
207 if math.Abs(v-e[i-1]) > 1e-10 {
211 if math.Abs(v) > 1e-10 {
221 // genFromSym constructs a (symmetric) general matrix from the data in the
223 // TODO(btracey): Replace other constructions of this with a call to this function.
224 func genFromSym(a blas64.Symmetric) blas64.General {
232 Data: make([]float64, n*n),
235 for i := 0; i < n; i++ {
236 for j := i; j < n; j++ {
238 if uplo == blas.Lower {