OSDN Git Service

Hulk did something
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dlasr.go
diff --git a/vendor/gonum.org/v1/gonum/lapack/testlapack/dlasr.go b/vendor/gonum.org/v1/gonum/lapack/testlapack/dlasr.go
new file mode 100644 (file)
index 0000000..8b28c91
--- /dev/null
@@ -0,0 +1,149 @@
+// Copyright ©2015 The Gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package testlapack
+
+import (
+       "math"
+       "testing"
+
+       "golang.org/x/exp/rand"
+
+       "gonum.org/v1/gonum/blas"
+       "gonum.org/v1/gonum/blas/blas64"
+       "gonum.org/v1/gonum/floats"
+       "gonum.org/v1/gonum/lapack"
+)
+
+type Dlasrer interface {
+       Dlasr(side blas.Side, pivot lapack.Pivot, direct lapack.Direct, m, n int, c, s, a []float64, lda int)
+}
+
+func DlasrTest(t *testing.T, impl Dlasrer) {
+       rnd := rand.New(rand.NewSource(1))
+       for _, side := range []blas.Side{blas.Left, blas.Right} {
+               for _, pivot := range []lapack.Pivot{lapack.Variable, lapack.Top, lapack.Bottom} {
+                       for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
+                               for _, test := range []struct {
+                                       m, n, lda int
+                               }{
+                                       {5, 5, 0},
+                                       {5, 10, 0},
+                                       {10, 5, 0},
+
+                                       {5, 5, 20},
+                                       {5, 10, 20},
+                                       {10, 5, 20},
+                               } {
+                                       m := test.m
+                                       n := test.n
+                                       lda := test.lda
+                                       if lda == 0 {
+                                               lda = n
+                                       }
+                                       a := make([]float64, m*lda)
+                                       for i := range a {
+                                               a[i] = rnd.Float64()
+                                       }
+                                       var s, c []float64
+                                       if side == blas.Left {
+                                               s = make([]float64, m-1)
+                                               c = make([]float64, m-1)
+                                       } else {
+                                               s = make([]float64, n-1)
+                                               c = make([]float64, n-1)
+                                       }
+                                       for k := range s {
+                                               theta := rnd.Float64() * 2 * math.Pi
+                                               s[k] = math.Sin(theta)
+                                               c[k] = math.Cos(theta)
+                                       }
+                                       aCopy := make([]float64, len(a))
+                                       copy(a, aCopy)
+                                       impl.Dlasr(side, pivot, direct, m, n, c, s, a, lda)
+
+                                       pSize := m
+                                       if side == blas.Right {
+                                               pSize = n
+                                       }
+                                       p := blas64.General{
+                                               Rows:   pSize,
+                                               Cols:   pSize,
+                                               Stride: pSize,
+                                               Data:   make([]float64, pSize*pSize),
+                                       }
+                                       pk := blas64.General{
+                                               Rows:   pSize,
+                                               Cols:   pSize,
+                                               Stride: pSize,
+                                               Data:   make([]float64, pSize*pSize),
+                                       }
+                                       ptmp := blas64.General{
+                                               Rows:   pSize,
+                                               Cols:   pSize,
+                                               Stride: pSize,
+                                               Data:   make([]float64, pSize*pSize),
+                                       }
+                                       for i := 0; i < pSize; i++ {
+                                               p.Data[i*p.Stride+i] = 1
+                                               ptmp.Data[i*p.Stride+i] = 1
+                                       }
+                                       // Compare to direct computation.
+                                       for k := range s {
+                                               for i := range p.Data {
+                                                       pk.Data[i] = 0
+                                               }
+                                               for i := 0; i < pSize; i++ {
+                                                       pk.Data[i*p.Stride+i] = 1
+                                               }
+                                               if pivot == lapack.Variable {
+                                                       pk.Data[k*p.Stride+k] = c[k]
+                                                       pk.Data[k*p.Stride+k+1] = s[k]
+                                                       pk.Data[(k+1)*p.Stride+k] = -s[k]
+                                                       pk.Data[(k+1)*p.Stride+k+1] = c[k]
+                                               } else if pivot == lapack.Top {
+                                                       pk.Data[0] = c[k]
+                                                       pk.Data[k+1] = s[k]
+                                                       pk.Data[(k+1)*p.Stride] = -s[k]
+                                                       pk.Data[(k+1)*p.Stride+k+1] = c[k]
+                                               } else {
+                                                       pk.Data[(pSize-1-k)*p.Stride+pSize-k-1] = c[k]
+                                                       pk.Data[(pSize-1-k)*p.Stride+pSize-1] = s[k]
+                                                       pk.Data[(pSize-1)*p.Stride+pSize-1-k] = -s[k]
+                                                       pk.Data[(pSize-1)*p.Stride+pSize-1] = c[k]
+                                               }
+                                               if direct == lapack.Forward {
+                                                       blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, pk, ptmp, 0, p)
+                                               } else {
+                                                       blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, ptmp, pk, 0, p)
+                                               }
+                                               copy(ptmp.Data, p.Data)
+                                       }
+
+                                       aMat := blas64.General{
+                                               Rows:   m,
+                                               Cols:   n,
+                                               Stride: lda,
+                                               Data:   make([]float64, m*lda),
+                                       }
+                                       copy(a, aCopy)
+                                       newA := blas64.General{
+                                               Rows:   m,
+                                               Cols:   n,
+                                               Stride: lda,
+                                               Data:   make([]float64, m*lda),
+                                       }
+                                       if side == blas.Left {
+                                               blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, p, aMat, 0, newA)
+                                       } else {
+                                               blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, p, 0, newA)
+                                       }
+                                       if !floats.EqualApprox(newA.Data, a, 1e-12) {
+                                               t.Errorf("A update mismatch")
+                                       }
+                               }
+                       }
+               }
+       }
+}