OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / gonum / dlatrs.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package gonum
6
7 import (
8         "math"
9
10         "gonum.org/v1/gonum/blas"
11         "gonum.org/v1/gonum/blas/blas64"
12 )
13
14 // Dlatrs solves a triangular system of equations scaled to prevent overflow. It
15 // solves
16 //  A * x = scale * b if trans == blas.NoTrans
17 //  A^T * x = scale * b if trans == blas.Trans
18 // where the scale s is set for numeric stability.
19 //
20 // A is an n×n triangular matrix. On entry, the slice x contains the values of
21 // of b, and on exit it contains the solution vector x.
22 //
23 // If normin == true, cnorm is an input and cnorm[j] contains the norm of the off-diagonal
24 // part of the j^th column of A. If trans == blas.NoTrans, cnorm[j] must be greater
25 // than or equal to the infinity norm, and greater than or equal to the one-norm
26 // otherwise. If normin == false, then cnorm is treated as an output, and is set
27 // to contain the 1-norm of the off-diagonal part of the j^th column of A.
28 //
29 // Dlatrs is an internal routine. It is exported for testing purposes.
30 func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n int, a []float64, lda int, x []float64, cnorm []float64) (scale float64) {
31         if uplo != blas.Upper && uplo != blas.Lower {
32                 panic(badUplo)
33         }
34         if trans != blas.Trans && trans != blas.NoTrans {
35                 panic(badTrans)
36         }
37         if diag != blas.Unit && diag != blas.NonUnit {
38                 panic(badDiag)
39         }
40         upper := uplo == blas.Upper
41         noTrans := trans == blas.NoTrans
42         nonUnit := diag == blas.NonUnit
43
44         if n < 0 {
45                 panic(nLT0)
46         }
47         checkMatrix(n, n, a, lda)
48         checkVector(n, x, 1)
49         checkVector(n, cnorm, 1)
50
51         if n == 0 {
52                 return 0
53         }
54         smlnum := dlamchS / dlamchP
55         bignum := 1 / smlnum
56         scale = 1
57         bi := blas64.Implementation()
58         if !normin {
59                 if upper {
60                         cnorm[0] = 0
61                         for j := 1; j < n; j++ {
62                                 cnorm[j] = bi.Dasum(j, a[j:], lda)
63                         }
64                 } else {
65                         for j := 0; j < n-1; j++ {
66                                 cnorm[j] = bi.Dasum(n-j-1, a[(j+1)*lda+j:], lda)
67                         }
68                         cnorm[n-1] = 0
69                 }
70         }
71         // Scale the column norms by tscal if the maximum element in cnorm is greater than bignum.
72         imax := bi.Idamax(n, cnorm, 1)
73         tmax := cnorm[imax]
74         var tscal float64
75         if tmax <= bignum {
76                 tscal = 1
77         } else {
78                 tscal = 1 / (smlnum * tmax)
79                 bi.Dscal(n, tscal, cnorm, 1)
80         }
81
82         // Compute a bound on the computed solution vector to see if bi.Dtrsv can be used.
83         j := bi.Idamax(n, x, 1)
84         xmax := math.Abs(x[j])
85         xbnd := xmax
86         var grow float64
87         var jfirst, jlast, jinc int
88         if noTrans {
89                 if upper {
90                         jfirst = n - 1
91                         jlast = -1
92                         jinc = -1
93                 } else {
94                         jfirst = 0
95                         jlast = n
96                         jinc = 1
97                 }
98                 // Compute the growth in A * x = b.
99                 if tscal != 1 {
100                         grow = 0
101                         goto Solve
102                 }
103                 if nonUnit {
104                         grow = 1 / math.Max(xbnd, smlnum)
105                         xbnd = grow
106                         for j := jfirst; j != jlast; j += jinc {
107                                 if grow <= smlnum {
108                                         goto Solve
109                                 }
110                                 tjj := math.Abs(a[j*lda+j])
111                                 xbnd = math.Min(xbnd, math.Min(1, tjj)*grow)
112                                 if tjj+cnorm[j] >= smlnum {
113                                         grow *= tjj / (tjj + cnorm[j])
114                                 } else {
115                                         grow = 0
116                                 }
117                         }
118                         grow = xbnd
119                 } else {
120                         grow = math.Min(1, 1/math.Max(xbnd, smlnum))
121                         for j := jfirst; j != jlast; j += jinc {
122                                 if grow <= smlnum {
123                                         goto Solve
124                                 }
125                                 grow *= 1 / (1 + cnorm[j])
126                         }
127                 }
128         } else {
129                 if upper {
130                         jfirst = 0
131                         jlast = n
132                         jinc = 1
133                 } else {
134                         jfirst = n - 1
135                         jlast = -1
136                         jinc = -1
137                 }
138                 if tscal != 1 {
139                         grow = 0
140                         goto Solve
141                 }
142                 if nonUnit {
143                         grow = 1 / (math.Max(xbnd, smlnum))
144                         xbnd = grow
145                         for j := jfirst; j != jlast; j += jinc {
146                                 if grow <= smlnum {
147                                         goto Solve
148                                 }
149                                 xj := 1 + cnorm[j]
150                                 grow = math.Min(grow, xbnd/xj)
151                                 tjj := math.Abs(a[j*lda+j])
152                                 if xj > tjj {
153                                         xbnd *= tjj / xj
154                                 }
155                         }
156                         grow = math.Min(grow, xbnd)
157                 } else {
158                         grow = math.Min(1, 1/math.Max(xbnd, smlnum))
159                         for j := jfirst; j != jlast; j += jinc {
160                                 if grow <= smlnum {
161                                         goto Solve
162                                 }
163                                 xj := 1 + cnorm[j]
164                                 grow /= xj
165                         }
166                 }
167         }
168
169 Solve:
170         if grow*tscal > smlnum {
171                 // Use the Level 2 BLAS solve if the reciprocal of the bound on
172                 // elements of X is not too small.
173                 bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1)
174                 if tscal != 1 {
175                         bi.Dscal(n, 1/tscal, cnorm, 1)
176                 }
177                 return scale
178         }
179
180         // Use a Level 1 BLAS solve, scaling intermediate results.
181         if xmax > bignum {
182                 scale = bignum / xmax
183                 bi.Dscal(n, scale, x, 1)
184                 xmax = bignum
185         }
186         if noTrans {
187                 for j := jfirst; j != jlast; j += jinc {
188                         xj := math.Abs(x[j])
189                         var tjj, tjjs float64
190                         if nonUnit {
191                                 tjjs = a[j*lda+j] * tscal
192                         } else {
193                                 tjjs = tscal
194                                 if tscal == 1 {
195                                         goto Skip1
196                                 }
197                         }
198                         tjj = math.Abs(tjjs)
199                         if tjj > smlnum {
200                                 if tjj < 1 {
201                                         if xj > tjj*bignum {
202                                                 rec := 1 / xj
203                                                 bi.Dscal(n, rec, x, 1)
204                                                 scale *= rec
205                                                 xmax *= rec
206                                         }
207                                 }
208                                 x[j] /= tjjs
209                                 xj = math.Abs(x[j])
210                         } else if tjj > 0 {
211                                 if xj > tjj*bignum {
212                                         rec := (tjj * bignum) / xj
213                                         if cnorm[j] > 1 {
214                                                 rec /= cnorm[j]
215                                         }
216                                         bi.Dscal(n, rec, x, 1)
217                                         scale *= rec
218                                         xmax *= rec
219                                 }
220                                 x[j] /= tjjs
221                                 xj = math.Abs(x[j])
222                         } else {
223                                 for i := 0; i < n; i++ {
224                                         x[i] = 0
225                                 }
226                                 x[j] = 1
227                                 xj = 1
228                                 scale = 0
229                                 xmax = 0
230                         }
231                 Skip1:
232                         if xj > 1 {
233                                 rec := 1 / xj
234                                 if cnorm[j] > (bignum-xmax)*rec {
235                                         rec *= 0.5
236                                         bi.Dscal(n, rec, x, 1)
237                                         scale *= rec
238                                 }
239                         } else if xj*cnorm[j] > bignum-xmax {
240                                 bi.Dscal(n, 0.5, x, 1)
241                                 scale *= 0.5
242                         }
243                         if upper {
244                                 if j > 0 {
245                                         bi.Daxpy(j, -x[j]*tscal, a[j:], lda, x, 1)
246                                         i := bi.Idamax(j, x, 1)
247                                         xmax = math.Abs(x[i])
248                                 }
249                         } else {
250                                 if j < n-1 {
251                                         bi.Daxpy(n-j-1, -x[j]*tscal, a[(j+1)*lda+j:], lda, x[j+1:], 1)
252                                         i := j + bi.Idamax(n-j-1, x[j+1:], 1)
253                                         xmax = math.Abs(x[i])
254                                 }
255                         }
256                 }
257         } else {
258                 for j := jfirst; j != jlast; j += jinc {
259                         xj := math.Abs(x[j])
260                         uscal := tscal
261                         rec := 1 / math.Max(xmax, 1)
262                         var tjjs float64
263                         if cnorm[j] > (bignum-xj)*rec {
264                                 rec *= 0.5
265                                 if nonUnit {
266                                         tjjs = a[j*lda+j] * tscal
267                                 } else {
268                                         tjjs = tscal
269                                 }
270                                 tjj := math.Abs(tjjs)
271                                 if tjj > 1 {
272                                         rec = math.Min(1, rec*tjj)
273                                         uscal /= tjjs
274                                 }
275                                 if rec < 1 {
276                                         bi.Dscal(n, rec, x, 1)
277                                         scale *= rec
278                                         xmax *= rec
279                                 }
280                         }
281                         var sumj float64
282                         if uscal == 1 {
283                                 if upper {
284                                         sumj = bi.Ddot(j, a[j:], lda, x, 1)
285                                 } else if j < n-1 {
286                                         sumj = bi.Ddot(n-j-1, a[(j+1)*lda+j:], lda, x[j+1:], 1)
287                                 }
288                         } else {
289                                 if upper {
290                                         for i := 0; i < j; i++ {
291                                                 sumj += (a[i*lda+j] * uscal) * x[i]
292                                         }
293                                 } else if j < n {
294                                         for i := j + 1; i < n; i++ {
295                                                 sumj += (a[i*lda+j] * uscal) * x[i]
296                                         }
297                                 }
298                         }
299                         if uscal == tscal {
300                                 x[j] -= sumj
301                                 xj := math.Abs(x[j])
302                                 var tjjs float64
303                                 if nonUnit {
304                                         tjjs = a[j*lda+j] * tscal
305                                 } else {
306                                         tjjs = tscal
307                                         if tscal == 1 {
308                                                 goto Skip2
309                                         }
310                                 }
311                                 tjj := math.Abs(tjjs)
312                                 if tjj > smlnum {
313                                         if tjj < 1 {
314                                                 if xj > tjj*bignum {
315                                                         rec = 1 / xj
316                                                         bi.Dscal(n, rec, x, 1)
317                                                         scale *= rec
318                                                         xmax *= rec
319                                                 }
320                                         }
321                                         x[j] /= tjjs
322                                 } else if tjj > 0 {
323                                         if xj > tjj*bignum {
324                                                 rec = (tjj * bignum) / xj
325                                                 bi.Dscal(n, rec, x, 1)
326                                                 scale *= rec
327                                                 xmax *= rec
328                                         }
329                                         x[j] /= tjjs
330                                 } else {
331                                         for i := 0; i < n; i++ {
332                                                 x[i] = 0
333                                         }
334                                         x[j] = 1
335                                         scale = 0
336                                         xmax = 0
337                                 }
338                         } else {
339                                 x[j] = x[j]/tjjs - sumj
340                         }
341                 Skip2:
342                         xmax = math.Max(xmax, math.Abs(x[j]))
343                 }
344         }
345         scale /= tscal
346         if tscal != 1 {
347                 bi.Dscal(n, 1/tscal, cnorm, 1)
348         }
349         return scale
350 }