--- /dev/null
+// Copyright ©2013 The Gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package mat
+
+import (
+ "testing"
+
+ "golang.org/x/exp/rand"
+)
+
+func TestLUD(t *testing.T) {
+ for _, n := range []int{1, 5, 10, 11, 50} {
+ a := NewDense(n, n, nil)
+ for i := 0; i < n; i++ {
+ for j := 0; j < n; j++ {
+ a.Set(i, j, rand.NormFloat64())
+ }
+ }
+ var want Dense
+ want.Clone(a)
+
+ var lu LU
+ lu.Factorize(a)
+
+ l := lu.LTo(nil)
+ u := lu.UTo(nil)
+ var p Dense
+ pivot := lu.Pivot(nil)
+ p.Permutation(n, pivot)
+ var got Dense
+ got.Product(&p, l, u)
+ if !EqualApprox(&got, &want, 1e-12) {
+ t.Errorf("PLU does not equal original matrix.\nWant: %v\n Got: %v", want, got)
+ }
+ }
+}
+
+func TestLURankOne(t *testing.T) {
+ for _, pivoting := range []bool{true} {
+ for _, n := range []int{3, 10, 50} {
+ // Construct a random LU factorization
+ lu := &LU{}
+ lu.lu = NewDense(n, n, nil)
+ for i := 0; i < n; i++ {
+ for j := 0; j < n; j++ {
+ lu.lu.Set(i, j, rand.Float64())
+ }
+ }
+ lu.pivot = make([]int, n)
+ for i := range lu.pivot {
+ lu.pivot[i] = i
+ }
+ if pivoting {
+ // For each row, randomly swap with itself or a row after (like is done)
+ // in the actual LU factorization.
+ for i := range lu.pivot {
+ idx := i + rand.Intn(n-i)
+ lu.pivot[i], lu.pivot[idx] = lu.pivot[idx], lu.pivot[i]
+ }
+ }
+ // Apply a rank one update. Ensure the update magnitude is larger than
+ // the equal tolerance.
+ alpha := rand.Float64() + 1
+ x := NewVecDense(n, nil)
+ y := NewVecDense(n, nil)
+ for i := 0; i < n; i++ {
+ x.setVec(i, rand.Float64()+1)
+ y.setVec(i, rand.Float64()+1)
+ }
+ a := luReconstruct(lu)
+ a.RankOne(a, alpha, x, y)
+
+ var luNew LU
+ luNew.RankOne(lu, alpha, x, y)
+ lu.RankOne(lu, alpha, x, y)
+
+ aR1New := luReconstruct(&luNew)
+ aR1 := luReconstruct(lu)
+
+ if !Equal(aR1, aR1New) {
+ t.Error("Different answer when new receiver")
+ }
+ if !EqualApprox(aR1, a, 1e-10) {
+ t.Errorf("Rank one mismatch, pivot %v.\nWant: %v\nGot:%v\n", pivoting, a, aR1)
+ }
+ }
+ }
+}
+
+// luReconstruct reconstructs the original A matrix from an LU decomposition.
+func luReconstruct(lu *LU) *Dense {
+ var L, U TriDense
+ lu.LTo(&L)
+ lu.UTo(&U)
+ var P Dense
+ pivot := lu.Pivot(nil)
+ P.Permutation(len(pivot), pivot)
+
+ var a Dense
+ a.Mul(&L, &U)
+ a.Mul(&P, &a)
+ return &a
+}
+
+func TestSolveLU(t *testing.T) {
+ for _, test := range []struct {
+ n, bc int
+ }{
+ {5, 5},
+ {5, 10},
+ {10, 5},
+ } {
+ n := test.n
+ bc := test.bc
+ a := NewDense(n, n, nil)
+ for i := 0; i < n; i++ {
+ for j := 0; j < n; j++ {
+ a.Set(i, j, rand.NormFloat64())
+ }
+ }
+ b := NewDense(n, bc, nil)
+ for i := 0; i < n; i++ {
+ for j := 0; j < bc; j++ {
+ b.Set(i, j, rand.NormFloat64())
+ }
+ }
+ var lu LU
+ lu.Factorize(a)
+ var x Dense
+ if err := lu.Solve(&x, false, b); err != nil {
+ continue
+ }
+ var got Dense
+ got.Mul(a, &x)
+ if !EqualApprox(&got, b, 1e-12) {
+ t.Errorf("Solve mismatch for non-singular matrix. n = %v, bc = %v.\nWant: %v\nGot: %v", n, bc, b, got)
+ }
+ }
+ // TODO(btracey): Add testOneInput test when such a function exists.
+}
+
+func TestSolveLUCond(t *testing.T) {
+ for _, test := range []*Dense{
+ NewDense(2, 2, []float64{1, 0, 0, 1e-20}),
+ } {
+ m, _ := test.Dims()
+ var lu LU
+ lu.Factorize(test)
+ b := NewDense(m, 2, nil)
+ var x Dense
+ if err := lu.Solve(&x, false, b); err == nil {
+ t.Error("No error for near-singular matrix in matrix solve.")
+ }
+
+ bvec := NewVecDense(m, nil)
+ var xvec VecDense
+ if err := lu.SolveVec(&xvec, false, bvec); err == nil {
+ t.Error("No error for near-singular matrix in matrix solve.")
+ }
+ }
+}
+
+func TestSolveLUVec(t *testing.T) {
+ for _, n := range []int{5, 10} {
+ a := NewDense(n, n, nil)
+ for i := 0; i < n; i++ {
+ for j := 0; j < n; j++ {
+ a.Set(i, j, rand.NormFloat64())
+ }
+ }
+ b := NewVecDense(n, nil)
+ for i := 0; i < n; i++ {
+ b.SetVec(i, rand.NormFloat64())
+ }
+ var lu LU
+ lu.Factorize(a)
+ var x VecDense
+ if err := lu.SolveVec(&x, false, b); err != nil {
+ continue
+ }
+ var got VecDense
+ got.MulVec(a, &x)
+ if !EqualApprox(&got, b, 1e-12) {
+ t.Errorf("Solve mismatch n = %v.\nWant: %v\nGot: %v", n, b, got)
+ }
+ }
+ // TODO(btracey): Add testOneInput test when such a function exists.
+}