OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / gonum / dlasy2.go
1 // Copyright ©2016 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package gonum
6
7 import (
8         "math"
9
10         "gonum.org/v1/gonum/blas/blas64"
11 )
12
13 // Dlasy2 solves the Sylvester matrix equation where the matrices are of order 1
14 // or 2. It computes the unknown n1×n2 matrix X so that
15 //  TL*X   + sgn*X*TR   = scale*B,  if tranl == false and tranr == false,
16 //  TL^T*X + sgn*X*TR   = scale*B,  if tranl == true  and tranr == false,
17 //  TL*X   + sgn*X*TR^T = scale*B,  if tranl == false and tranr == true,
18 //  TL^T*X + sgn*X*TR^T = scale*B,  if tranl == true  and tranr == true,
19 // where TL is n1×n1, TR is n2×n2, B is n1×n2, and 1 <= n1,n2 <= 2.
20 //
21 // isgn must be 1 or -1, and n1 and n2 must be 0, 1, or 2, but these conditions
22 // are not checked.
23 //
24 // Dlasy2 returns three values, a scale factor that is chosen less than or equal
25 // to 1 to prevent the solution overflowing, the infinity norm of the solution,
26 // and an indicator of success. If ok is false, TL and TR have eigenvalues that
27 // are too close, so TL or TR is perturbed to get a non-singular equation.
28 //
29 // Dlasy2 is an internal routine. It is exported for testing purposes.
30 func (impl Implementation) Dlasy2(tranl, tranr bool, isgn, n1, n2 int, tl []float64, ldtl int, tr []float64, ldtr int, b []float64, ldb int, x []float64, ldx int) (scale, xnorm float64, ok bool) {
31         // TODO(vladimir-ch): Add input validation checks conditionally skipped
32         // using the build tag mechanism.
33
34         ok = true
35         // Quick return if possible.
36         if n1 == 0 || n2 == 0 {
37                 return scale, xnorm, ok
38         }
39
40         // Set constants to control overflow.
41         eps := dlamchP
42         smlnum := dlamchS / eps
43         sgn := float64(isgn)
44
45         if n1 == 1 && n2 == 1 {
46                 // 1×1 case: TL11*X + sgn*X*TR11 = B11.
47                 tau1 := tl[0] + sgn*tr[0]
48                 bet := math.Abs(tau1)
49                 if bet <= smlnum {
50                         tau1 = smlnum
51                         bet = smlnum
52                         ok = false
53                 }
54                 scale = 1
55                 gam := math.Abs(b[0])
56                 if smlnum*gam > bet {
57                         scale = 1 / gam
58                 }
59                 x[0] = b[0] * scale / tau1
60                 xnorm = math.Abs(x[0])
61                 return scale, xnorm, ok
62         }
63
64         if n1+n2 == 3 {
65                 // 1×2 or 2×1 case.
66                 var (
67                         smin float64
68                         tmp  [4]float64 // tmp is used as a 2×2 row-major matrix.
69                         btmp [2]float64
70                 )
71                 if n1 == 1 && n2 == 2 {
72                         // 1×2 case: TL11*[X11 X12] + sgn*[X11 X12]*op[TR11 TR12] = [B11 B12].
73                         //                                            [TR21 TR22]
74                         smin = math.Abs(tl[0])
75                         smin = math.Max(smin, math.Max(math.Abs(tr[0]), math.Abs(tr[1])))
76                         smin = math.Max(smin, math.Max(math.Abs(tr[ldtr]), math.Abs(tr[ldtr+1])))
77                         smin = math.Max(eps*smin, smlnum)
78                         tmp[0] = tl[0] + sgn*tr[0]
79                         tmp[3] = tl[0] + sgn*tr[ldtr+1]
80                         if tranr {
81                                 tmp[1] = sgn * tr[1]
82                                 tmp[2] = sgn * tr[ldtr]
83                         } else {
84                                 tmp[1] = sgn * tr[ldtr]
85                                 tmp[2] = sgn * tr[1]
86                         }
87                         btmp[0] = b[0]
88                         btmp[1] = b[1]
89                 } else {
90                         // 2×1 case: op[TL11 TL12]*[X11] + sgn*[X11]*TR11 = [B11].
91                         //             [TL21 TL22]*[X21]       [X21]        [B21]
92                         smin = math.Abs(tr[0])
93                         smin = math.Max(smin, math.Max(math.Abs(tl[0]), math.Abs(tl[1])))
94                         smin = math.Max(smin, math.Max(math.Abs(tl[ldtl]), math.Abs(tl[ldtl+1])))
95                         smin = math.Max(eps*smin, smlnum)
96                         tmp[0] = tl[0] + sgn*tr[0]
97                         tmp[3] = tl[ldtl+1] + sgn*tr[0]
98                         if tranl {
99                                 tmp[1] = tl[ldtl]
100                                 tmp[2] = tl[1]
101                         } else {
102                                 tmp[1] = tl[1]
103                                 tmp[2] = tl[ldtl]
104                         }
105                         btmp[0] = b[0]
106                         btmp[1] = b[ldb]
107                 }
108
109                 // Solve 2×2 system using complete pivoting.
110                 // Set pivots less than smin to smin.
111
112                 bi := blas64.Implementation()
113                 ipiv := bi.Idamax(len(tmp), tmp[:], 1)
114                 // Compute the upper triangular matrix [u11 u12].
115                 //                                     [  0 u22]
116                 u11 := tmp[ipiv]
117                 if math.Abs(u11) <= smin {
118                         ok = false
119                         u11 = smin
120                 }
121                 locu12 := [4]int{1, 0, 3, 2} // Index in tmp of the element on the same row as the pivot.
122                 u12 := tmp[locu12[ipiv]]
123                 locl21 := [4]int{2, 3, 0, 1} // Index in tmp of the element on the same column as the pivot.
124                 l21 := tmp[locl21[ipiv]] / u11
125                 locu22 := [4]int{3, 2, 1, 0} // Index in tmp of the remaining element.
126                 u22 := tmp[locu22[ipiv]] - l21*u12
127                 if math.Abs(u22) <= smin {
128                         ok = false
129                         u22 = smin
130                 }
131                 if ipiv&0x2 != 0 { // true for ipiv equal to 2 and 3.
132                         // The pivot was in the second row, swap the elements of
133                         // the right-hand side.
134                         btmp[0], btmp[1] = btmp[1], btmp[0]-l21*btmp[1]
135                 } else {
136                         btmp[1] -= l21 * btmp[0]
137                 }
138                 scale = 1
139                 if 2*smlnum*math.Abs(btmp[1]) > math.Abs(u22) || 2*smlnum*math.Abs(btmp[0]) > math.Abs(u11) {
140                         scale = 0.5 / math.Max(math.Abs(btmp[0]), math.Abs(btmp[1]))
141                         btmp[0] *= scale
142                         btmp[1] *= scale
143                 }
144                 // Solve the system [u11 u12] [x21] = [ btmp[0] ].
145                 //                  [  0 u22] [x22]   [ btmp[1] ]
146                 x22 := btmp[1] / u22
147                 x21 := btmp[0]/u11 - (u12/u11)*x22
148                 if ipiv&0x1 != 0 { // true for ipiv equal to 1 and 3.
149                         // The pivot was in the second column, swap the elements
150                         // of the solution.
151                         x21, x22 = x22, x21
152                 }
153                 x[0] = x21
154                 if n1 == 1 {
155                         x[1] = x22
156                         xnorm = math.Abs(x[0]) + math.Abs(x[1])
157                 } else {
158                         x[ldx] = x22
159                         xnorm = math.Max(math.Abs(x[0]), math.Abs(x[ldx]))
160                 }
161                 return scale, xnorm, ok
162         }
163
164         // 2×2 case: op[TL11 TL12]*[X11 X12] + SGN*[X11 X12]*op[TR11 TR12] = [B11 B12].
165         //             [TL21 TL22] [X21 X22]       [X21 X22]   [TR21 TR22]   [B21 B22]
166         //
167         // Solve equivalent 4×4 system using complete pivoting.
168         // Set pivots less than smin to smin.
169
170         smin := math.Max(math.Abs(tr[0]), math.Abs(tr[1]))
171         smin = math.Max(smin, math.Max(math.Abs(tr[ldtr]), math.Abs(tr[ldtr+1])))
172         smin = math.Max(smin, math.Max(math.Abs(tl[0]), math.Abs(tl[1])))
173         smin = math.Max(smin, math.Max(math.Abs(tl[ldtl]), math.Abs(tl[ldtl+1])))
174         smin = math.Max(eps*smin, smlnum)
175
176         var t [4][4]float64
177         t[0][0] = tl[0] + sgn*tr[0]
178         t[1][1] = tl[0] + sgn*tr[ldtr+1]
179         t[2][2] = tl[ldtl+1] + sgn*tr[0]
180         t[3][3] = tl[ldtl+1] + sgn*tr[ldtr+1]
181         if tranl {
182                 t[0][2] = tl[ldtl]
183                 t[1][3] = tl[ldtl]
184                 t[2][0] = tl[1]
185                 t[3][1] = tl[1]
186         } else {
187                 t[0][2] = tl[1]
188                 t[1][3] = tl[1]
189                 t[2][0] = tl[ldtl]
190                 t[3][1] = tl[ldtl]
191         }
192         if tranr {
193                 t[0][1] = sgn * tr[1]
194                 t[1][0] = sgn * tr[ldtr]
195                 t[2][3] = sgn * tr[1]
196                 t[3][2] = sgn * tr[ldtr]
197         } else {
198                 t[0][1] = sgn * tr[ldtr]
199                 t[1][0] = sgn * tr[1]
200                 t[2][3] = sgn * tr[ldtr]
201                 t[3][2] = sgn * tr[1]
202         }
203
204         var btmp [4]float64
205         btmp[0] = b[0]
206         btmp[1] = b[1]
207         btmp[2] = b[ldb]
208         btmp[3] = b[ldb+1]
209
210         // Perform elimination.
211         var jpiv [4]int // jpiv records any column swaps for pivoting.
212         for i := 0; i < 3; i++ {
213                 var (
214                         xmax       float64
215                         ipsv, jpsv int
216                 )
217                 for ip := i; ip < 4; ip++ {
218                         for jp := i; jp < 4; jp++ {
219                                 if math.Abs(t[ip][jp]) >= xmax {
220                                         xmax = math.Abs(t[ip][jp])
221                                         ipsv = ip
222                                         jpsv = jp
223                                 }
224                         }
225                 }
226                 if ipsv != i {
227                         // The pivot is not in the top row of the unprocessed
228                         // block, swap rows ipsv and i of t and btmp.
229                         t[ipsv], t[i] = t[i], t[ipsv]
230                         btmp[ipsv], btmp[i] = btmp[i], btmp[ipsv]
231                 }
232                 if jpsv != i {
233                         // The pivot is not in the left column of the
234                         // unprocessed block, swap columns jpsv and i of t.
235                         for k := 0; k < 4; k++ {
236                                 t[k][jpsv], t[k][i] = t[k][i], t[k][jpsv]
237                         }
238                 }
239                 jpiv[i] = jpsv
240                 if math.Abs(t[i][i]) < smin {
241                         ok = false
242                         t[i][i] = smin
243                 }
244                 for k := i + 1; k < 4; k++ {
245                         t[k][i] /= t[i][i]
246                         btmp[k] -= t[k][i] * btmp[i]
247                         for j := i + 1; j < 4; j++ {
248                                 t[k][j] -= t[k][i] * t[i][j]
249                         }
250                 }
251         }
252         if math.Abs(t[3][3]) < smin {
253                 ok = false
254                 t[3][3] = smin
255         }
256         scale = 1
257         if 8*smlnum*math.Abs(btmp[0]) > math.Abs(t[0][0]) ||
258                 8*smlnum*math.Abs(btmp[1]) > math.Abs(t[1][1]) ||
259                 8*smlnum*math.Abs(btmp[2]) > math.Abs(t[2][2]) ||
260                 8*smlnum*math.Abs(btmp[3]) > math.Abs(t[3][3]) {
261
262                 maxbtmp := math.Max(math.Abs(btmp[0]), math.Abs(btmp[1]))
263                 maxbtmp = math.Max(maxbtmp, math.Max(math.Abs(btmp[2]), math.Abs(btmp[3])))
264                 scale = 1 / 8 / maxbtmp
265                 btmp[0] *= scale
266                 btmp[1] *= scale
267                 btmp[2] *= scale
268                 btmp[3] *= scale
269         }
270         // Compute the solution of the upper triangular system t * tmp = btmp.
271         var tmp [4]float64
272         for i := 3; i >= 0; i-- {
273                 temp := 1 / t[i][i]
274                 tmp[i] = btmp[i] * temp
275                 for j := i + 1; j < 4; j++ {
276                         tmp[i] -= temp * t[i][j] * tmp[j]
277                 }
278         }
279         for i := 2; i >= 0; i-- {
280                 if jpiv[i] != i {
281                         tmp[i], tmp[jpiv[i]] = tmp[jpiv[i]], tmp[i]
282                 }
283         }
284         x[0] = tmp[0]
285         x[1] = tmp[1]
286         x[ldx] = tmp[2]
287         x[ldx+1] = tmp[3]
288         xnorm = math.Max(math.Abs(tmp[0])+math.Abs(tmp[1]), math.Abs(tmp[2])+math.Abs(tmp[3]))
289         return scale, xnorm, ok
290 }