// Copyright ©2015 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package testlapack import ( "testing" "golang.org/x/exp/rand" "gonum.org/v1/gonum/blas" "gonum.org/v1/gonum/blas/blas64" "gonum.org/v1/gonum/floats" "gonum.org/v1/gonum/lapack" ) type Dlarfter interface { Dgeqr2er Dlarft(direct lapack.Direct, store lapack.StoreV, n, k int, v []float64, ldv int, tau []float64, t []float64, ldt int) } func DlarftTest(t *testing.T, impl Dlarfter) { rnd := rand.New(rand.NewSource(1)) for _, store := range []lapack.StoreV{lapack.ColumnWise, lapack.RowWise} { for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} { for _, test := range []struct { m, n, ldv, ldt int }{ {6, 6, 0, 0}, {8, 6, 0, 0}, {6, 8, 0, 0}, {6, 6, 10, 15}, {8, 6, 10, 15}, {6, 8, 10, 15}, {6, 6, 15, 10}, {8, 6, 15, 10}, {6, 8, 15, 10}, } { // Generate a matrix m := test.m n := test.n lda := n if lda == 0 { lda = n } a := make([]float64, m*lda) for i := 0; i < m; i++ { for j := 0; j < lda; j++ { a[i*lda+j] = rnd.Float64() } } // Use dgeqr2 to find the v vectors tau := make([]float64, n) work := make([]float64, n) impl.Dgeqr2(m, n, a, lda, tau, work) // Construct H using these answers vMatTmp := extractVMat(m, n, a, lda, lapack.Forward, lapack.ColumnWise) vMat := constructVMat(vMatTmp, store, direct) v := vMat.Data ldv := vMat.Stride h := constructH(tau, vMat, store, direct) k := min(m, n) ldt := test.ldt if ldt == 0 { ldt = k } // Find T from the actual function tm := make([]float64, k*ldt) for i := range tm { tm[i] = 100 + rnd.Float64() } // The v data has been put into a. impl.Dlarft(direct, store, m, k, v, ldv, tau, tm, ldt) tData := make([]float64, len(tm)) copy(tData, tm) if direct == lapack.Forward { // Zero out the lower traingular portion. for i := 0; i < k; i++ { for j := 0; j < i; j++ { tData[i*ldt+j] = 0 } } } else { // Zero out the upper traingular portion. for i := 0; i < k; i++ { for j := i + 1; j < k; j++ { tData[i*ldt+j] = 0 } } } T := blas64.General{ Rows: k, Cols: k, Stride: ldt, Data: tData, } vMatT := blas64.General{ Rows: vMat.Cols, Cols: vMat.Rows, Stride: vMat.Rows, Data: make([]float64, vMat.Cols*vMat.Rows), } for i := 0; i < vMat.Rows; i++ { for j := 0; j < vMat.Cols; j++ { vMatT.Data[j*vMatT.Stride+i] = vMat.Data[i*vMat.Stride+j] } } var comp blas64.General if store == lapack.ColumnWise { // H = I - V * T * V^T tmp := blas64.General{ Rows: T.Rows, Cols: vMatT.Cols, Stride: vMatT.Cols, Data: make([]float64, T.Rows*vMatT.Cols), } // T * V^T blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, T, vMatT, 0, tmp) comp = blas64.General{ Rows: vMat.Rows, Cols: tmp.Cols, Stride: tmp.Cols, Data: make([]float64, vMat.Rows*tmp.Cols), } // V * (T * V^T) blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vMat, tmp, 0, comp) } else { // H = I - V^T * T * V tmp := blas64.General{ Rows: T.Rows, Cols: vMat.Cols, Stride: vMat.Cols, Data: make([]float64, T.Rows*vMat.Cols), } // T * V blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, T, vMat, 0, tmp) comp = blas64.General{ Rows: vMatT.Rows, Cols: tmp.Cols, Stride: tmp.Cols, Data: make([]float64, vMatT.Rows*tmp.Cols), } // V^T * (T * V) blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vMatT, tmp, 0, comp) } // I - V^T * T * V for i := 0; i < comp.Rows; i++ { for j := 0; j < comp.Cols; j++ { comp.Data[i*m+j] *= -1 if i == j { comp.Data[i*m+j] += 1 } } } if !floats.EqualApprox(comp.Data, h.Data, 1e-14) { t.Errorf("T does not construct proper H. Store = %v, Direct = %v.\nWant %v\ngot %v.", string(store), string(direct), h.Data, comp.Data) } } } } }