OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / mat / solve.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package mat
6
7 import (
8         "gonum.org/v1/gonum/blas"
9         "gonum.org/v1/gonum/blas/blas64"
10         "gonum.org/v1/gonum/lapack/lapack64"
11 )
12
13 // Solve finds a minimum-norm solution to a system of linear equations defined
14 // by the matrices a and b. If A is singular or near-singular, a Condition error
15 // is returned. See the documentation for Condition for more information.
16 //
17 // The minimization problem solved depends on the input parameters:
18 //  - if m >= n, find X such that ||A*X - B||_2 is minimized,
19 //  - if m < n, find the minimum norm solution of A * X = B.
20 // The solution matrix, X, is stored in-place into the receiver.
21 func (m *Dense) Solve(a, b Matrix) error {
22         ar, ac := a.Dims()
23         br, bc := b.Dims()
24         if ar != br {
25                 panic(ErrShape)
26         }
27         m.reuseAs(ac, bc)
28
29         // TODO(btracey): Add special cases for SymDense, etc.
30         aU, aTrans := untranspose(a)
31         bU, bTrans := untranspose(b)
32         switch rma := aU.(type) {
33         case RawTriangular:
34                 side := blas.Left
35                 tA := blas.NoTrans
36                 if aTrans {
37                         tA = blas.Trans
38                 }
39
40                 switch rm := bU.(type) {
41                 case RawMatrixer:
42                         if m != bU || bTrans {
43                                 if m == bU || m.checkOverlap(rm.RawMatrix()) {
44                                         tmp := getWorkspace(br, bc, false)
45                                         tmp.Copy(b)
46                                         m.Copy(tmp)
47                                         putWorkspace(tmp)
48                                         break
49                                 }
50                                 m.Copy(b)
51                         }
52                 default:
53                         if m != bU {
54                                 m.Copy(b)
55                         } else if bTrans {
56                                 // m and b share data so Copy cannot be used directly.
57                                 tmp := getWorkspace(br, bc, false)
58                                 tmp.Copy(b)
59                                 m.Copy(tmp)
60                                 putWorkspace(tmp)
61                         }
62                 }
63
64                 rm := rma.RawTriangular()
65                 blas64.Trsm(side, tA, 1, rm, m.mat)
66                 work := getFloats(3*rm.N, false)
67                 iwork := getInts(rm.N, false)
68                 cond := lapack64.Trcon(CondNorm, rm, work, iwork)
69                 putFloats(work)
70                 putInts(iwork)
71                 if cond > ConditionTolerance {
72                         return Condition(cond)
73                 }
74                 return nil
75         }
76
77         switch {
78         case ar == ac:
79                 if a == b {
80                         // x = I.
81                         if ar == 1 {
82                                 m.mat.Data[0] = 1
83                                 return nil
84                         }
85                         for i := 0; i < ar; i++ {
86                                 v := m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+ac]
87                                 zero(v)
88                                 v[i] = 1
89                         }
90                         return nil
91                 }
92                 var lu LU
93                 lu.Factorize(a)
94                 return lu.Solve(m, false, b)
95         case ar > ac:
96                 var qr QR
97                 qr.Factorize(a)
98                 return qr.Solve(m, false, b)
99         default:
100                 var lq LQ
101                 lq.Factorize(a)
102                 return lq.Solve(m, false, b)
103         }
104 }
105
106 // SolveVec finds a minimum-norm solution to a system of linear equations defined
107 // by the matrix a and the right-hand side column vector b. If A is singular or
108 // near-singular, a Condition error is returned. See the documentation for
109 // Dense.Solve for more information.
110 func (v *VecDense) SolveVec(a Matrix, b Vector) error {
111         if _, bc := b.Dims(); bc != 1 {
112                 panic(ErrShape)
113         }
114         _, c := a.Dims()
115
116         // The Solve implementation is non-trivial, so rather than duplicate the code,
117         // instead recast the VecDenses as Dense and call the matrix code.
118
119         if rv, ok := b.(RawVectorer); ok {
120                 bmat := rv.RawVector()
121                 if v != b {
122                         v.checkOverlap(bmat)
123                 }
124                 v.reuseAs(c)
125                 m := v.asDense()
126                 // We conditionally create bm as m when b and v are identical
127                 // to prevent the overlap detection code from identifying m
128                 // and bm as overlapping but not identical.
129                 bm := m
130                 if v != b {
131                         b := VecDense{mat: bmat, n: b.Len()}
132                         bm = b.asDense()
133                 }
134                 return m.Solve(a, bm)
135         }
136
137         v.reuseAs(c)
138         m := v.asDense()
139         return m.Solve(a, b)
140 }