// Copyright ©2016 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package testlapack import ( "compress/gzip" "encoding/json" "fmt" "log" "math" "os" "path/filepath" "testing" "golang.org/x/exp/rand" "gonum.org/v1/gonum/blas" "gonum.org/v1/gonum/blas/blas64" "gonum.org/v1/gonum/floats" ) type Dlahr2er interface { Dlahr2(n, k, nb int, a []float64, lda int, tau, t []float64, ldt int, y []float64, ldy int) } type Dlahr2test struct { N, K, NB int A []float64 AWant []float64 TWant []float64 YWant []float64 TauWant []float64 } func Dlahr2Test(t *testing.T, impl Dlahr2er) { rnd := rand.New(rand.NewSource(1)) for _, test := range []struct { n, k, nb int }{ {3, 0, 3}, {3, 1, 2}, {3, 1, 1}, {5, 0, 5}, {5, 1, 4}, {5, 1, 3}, {5, 1, 2}, {5, 1, 1}, {5, 2, 3}, {5, 2, 2}, {5, 2, 1}, {5, 3, 2}, {5, 3, 1}, {7, 3, 4}, {7, 3, 3}, {7, 3, 2}, {7, 3, 1}, {10, 0, 10}, {10, 1, 9}, {10, 1, 5}, {10, 1, 1}, {10, 5, 5}, {10, 5, 3}, {10, 5, 1}, } { for cas := 0; cas < 100; cas++ { for _, extraStride := range []int{0, 1, 10} { n := test.n k := test.k nb := test.nb a := randomGeneral(n, n-k+1, n-k+1+extraStride, rnd) aCopy := a aCopy.Data = make([]float64, len(a.Data)) copy(aCopy.Data, a.Data) tmat := nanTriangular(blas.Upper, nb, nb+extraStride) y := nanGeneral(n, nb, nb+extraStride) tau := nanSlice(nb) impl.Dlahr2(n, k, nb, a.Data, a.Stride, tau, tmat.Data, tmat.Stride, y.Data, y.Stride) prefix := fmt.Sprintf("Case n=%v, k=%v, nb=%v, ldex=%v", n, k, nb, extraStride) if !generalOutsideAllNaN(a) { t.Errorf("%v: out-of-range write to A\n%v", prefix, a.Data) } if !triangularOutsideAllNaN(tmat) { t.Errorf("%v: out-of-range write to T\n%v", prefix, tmat.Data) } if !generalOutsideAllNaN(y) { t.Errorf("%v: out-of-range write to Y\n%v", prefix, y.Data) } // Check that A[:k,:] and A[:,nb:] blocks were not modified. for i := 0; i < n; i++ { for j := 0; j < n-k+1; j++ { if i >= k && j < nb { continue } if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] { t.Errorf("%v: unexpected write to A[%v,%v]", prefix, i, j) } } } // Check that all elements of tau were assigned. for i, v := range tau { if math.IsNaN(v) { t.Errorf("%v: tau[%v] not assigned", prefix, i) } } // Extract V from a. v := blas64.General{ Rows: n - k + 1, Cols: nb, Stride: nb, Data: make([]float64, (n-k+1)*nb), } for j := 0; j < v.Cols; j++ { v.Data[(j+1)*v.Stride+j] = 1 for i := j + 2; i < v.Rows; i++ { v.Data[i*v.Stride+j] = a.Data[(i+k-1)*a.Stride+j] } } // VT = V. vt := v vt.Data = make([]float64, len(v.Data)) copy(vt.Data, v.Data) // VT = V * T. blas64.Trmm(blas.Right, blas.NoTrans, 1, tmat, vt) // YWant = A * V * T. ywant := blas64.General{ Rows: n, Cols: nb, Stride: nb, Data: make([]float64, n*nb), } blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aCopy, vt, 0, ywant) // Compare Y and YWant. for i := 0; i < n; i++ { for j := 0; j < nb; j++ { diff := math.Abs(ywant.Data[i*ywant.Stride+j] - y.Data[i*y.Stride+j]) if diff > 1e-14 { t.Errorf("%v: unexpected Y[%v,%v], diff=%v", prefix, i, j, diff) } } } // Construct Q directly from the first nb columns of a. q := constructQ("QR", n-k, nb, a.Data[k*a.Stride:], a.Stride, tau) if !isOrthonormal(q) { t.Errorf("%v: Q is not orthogonal", prefix) } // Construct Q as the product Q = I - V*T*V^T. qwant := blas64.General{ Rows: n - k + 1, Cols: n - k + 1, Stride: n - k + 1, Data: make([]float64, (n-k+1)*(n-k+1)), } for i := 0; i < qwant.Rows; i++ { qwant.Data[i*qwant.Stride+i] = 1 } blas64.Gemm(blas.NoTrans, blas.Trans, -1, vt, v, 1, qwant) if !isOrthonormal(qwant) { t.Errorf("%v: Q = I - V*T*V^T is not orthogonal", prefix) } // Compare Q and QWant. Note that since Q is // (n-k)×(n-k) and QWant is (n-k+1)×(n-k+1), we // ignore the first row and column of QWant. for i := 0; i < n-k; i++ { for j := 0; j < n-k; j++ { diff := math.Abs(q.Data[i*q.Stride+j] - qwant.Data[(i+1)*qwant.Stride+j+1]) if diff > 1e-14 { t.Errorf("%v: unexpected Q[%v,%v], diff=%v", prefix, i, j, diff) } } } } } } // Go runs tests from the source directory, so unfortunately we need to // include the "../testlapack" part. file, err := os.Open(filepath.FromSlash("../testlapack/testdata/dlahr2data.json.gz")) if err != nil { log.Fatal(err) } defer file.Close() r, err := gzip.NewReader(file) if err != nil { log.Fatal(err) } defer r.Close() var tests []Dlahr2test json.NewDecoder(r).Decode(&tests) for _, test := range tests { tau := make([]float64, len(test.TauWant)) for _, ldex := range []int{0, 1, 20} { n := test.N k := test.K nb := test.NB lda := n - k + 1 + ldex a := make([]float64, (n-1)*lda+n-k+1) copyMatrix(n, n-k+1, a, lda, test.A) ldt := nb + ldex tmat := make([]float64, (nb-1)*ldt+nb) ldy := nb + ldex y := make([]float64, (n-1)*ldy+nb) impl.Dlahr2(n, k, nb, a, lda, tau, tmat, ldt, y, ldy) prefix := fmt.Sprintf("Case n=%v, k=%v, nb=%v, ldex=%v", n, k, nb, ldex) if !equalApprox(n, n-k+1, a, lda, test.AWant, 1e-14) { t.Errorf("%v: unexpected matrix A\n got=%v\nwant=%v", prefix, a, test.AWant) } if !equalApproxTriangular(true, nb, tmat, ldt, test.TWant, 1e-14) { t.Errorf("%v: unexpected matrix T\n got=%v\nwant=%v", prefix, tmat, test.TWant) } if !equalApprox(n, nb, y, ldy, test.YWant, 1e-14) { t.Errorf("%v: unexpected matrix Y\n got=%v\nwant=%v", prefix, y, test.YWant) } if !floats.EqualApprox(tau, test.TauWant, 1e-14) { t.Errorf("%v: unexpected slice tau\n got=%v\nwant=%v", prefix, tau, test.TauWant) } } } }