OSDN Git Service

Initial commit
[wordring-tm/wordring-tm.git] / third_party / mecab-0.996 / src / lbfgs.cpp
1 //   MeCab: Yet Another Part-of-Speech and Morphological Analyzer
2 //
3 //
4 //   lbfgs.c was ported from the FORTRAN code of lbfgs.m to C
5 //   using f2c converter
6 //
7 //   http://www.ece.northwestern.edu/~nocedal/lbfgs.html
8 //
9 //   Software for Large-scale Unconstrained Optimization
10 //   L-BFGS is a limited-memory quasi-Newton code for unconstrained
11 //   optimization.
12 //   The code has been developed at the Optimization Technology Center,
13 //   a joint venture of Argonne National Laboratory and Northwestern University.
14 //
15 //   Authors
16 //   Jorge Nocedal
17 //
18 //   References
19 //   - J. Nocedal. Updating Quasi-Newton Matrices with Limited Storage(1980),
20 //   Mathematics of Computation 35, pp. 773-782.
21 //   - D.C. Liu and J. Nocedal. On the Limited Memory Method for
22 //   Large Scale Optimization(1989),
23 //   Mathematical Programming B, 45, 3, pp. 503-528.
24 #include <cmath>
25 #include <iostream>
26 #include <numeric>
27 #include "lbfgs.h"
28 #include "common.h"
29
30 namespace {
31 static const double ftol = 1e-4;
32 static const double xtol = 1e-16;
33 static const double eps  = 1e-7;
34 static const double lb3_1_gtol = 0.9;
35 static const double lb3_1_stpmin = 1e-20;
36 static const double lb3_1_stpmax = 1e20;
37 static const int lb3_1_mp = 6;
38 static const int lb3_1_lp = 6;
39
40 inline double sigma(double x) {
41   if (x > 0) {
42     return 1.0;
43   } else if (x < 0) {
44     return -1.0;
45   }
46   return 0.0;
47 }
48
49 inline double pi(double x, double y) {
50   return sigma(x) == sigma(y) ? x : 0.0;
51 }
52
53 inline void daxpy_(int n, double da, const double *dx, double *dy) {
54   for (int i = 0; i < n; ++i) {
55     dy[i] += da * dx[i];
56   }
57 }
58
59 inline double ddot_(int size, const double *dx, const double *dy) {
60   return std::inner_product(dx, dx + size, dy, 0.0);
61 }
62
63 void mcstep(double *stx, double *fx, double *dx,
64             double *sty, double *fy, double *dy,
65             double *stp, double fp, double dp,
66             int *brackt,
67             double stpmin, double stpmax,
68             int *info) {
69   bool bound = true;
70   double p, q, s, d1, d2, d3, r, gamma, theta, stpq, stpc, stpf;
71   *info = 0;
72
73   if (*brackt && ((*stp <= std::min(*stx, *sty) ||
74                    *stp >= std::max(*stx, *sty)) ||
75                   *dx * (*stp - *stx) >= 0.0 || stpmax < stpmin)) {
76     return;
77   }
78
79   double sgnd = dp * (*dx / std::abs(*dx));
80
81   if (fp > *fx) {
82     *info = 1;
83     bound = true;
84     theta =(*fx - fp) * 3 / (*stp - *stx) + *dx + dp;
85     d1 = std::abs(theta);
86     d2 = std::abs(*dx);
87     d1 = std::max(d1, d2);
88     d2 = std::abs(dp);
89     s = std::max(d1, d2);
90     d1 = theta / s;
91     gamma = s * std::sqrt(d1 * d1 - *dx / s *(dp / s));
92     if (*stp < *stx) {
93       gamma = -gamma;
94     }
95     p = gamma - *dx + theta;
96     q = gamma - *dx + gamma + dp;
97     r = p / q;
98     stpc = *stx + r * (*stp - *stx);
99     stpq = *stx + *dx / ((*fx - fp) /
100                          (*stp - *stx) + *dx) / 2 * (*stp - *stx);
101     if ((d1 = stpc - *stx, std::abs(d1)) < (d2 = stpq - *stx, std::abs(d2))) {
102       stpf = stpc;
103     } else {
104       stpf = stpc + (stpq - stpc) / 2;
105     }
106     *brackt = true;
107   } else if (sgnd < 0.0) {
108     *info = 2;
109     bound = false;
110     theta = (*fx - fp) * 3 / (*stp - *stx) + *dx + dp;
111     d1 = std::abs(theta);
112     d2 = std::abs(*dx);
113     d1 = std::max(d1, d2);
114     d2 = std::abs(dp);
115     s = std::max(d1, d2);
116     d1 = theta / s;
117     gamma = s * std::sqrt(d1 * d1 - *dx / s * (dp / s));
118     if (*stp > *stx) {
119       gamma = -gamma;
120     }
121     p = gamma - dp + theta;
122     q = gamma - dp + gamma + *dx;
123     r = p / q;
124     stpc = *stp + r *(*stx - *stp);
125     stpq = *stp + dp /(dp - *dx) * (*stx - *stp);
126     if ((d1 = stpc - *stp, std::abs(d1)) > (d2 = stpq - *stp, std::abs(d2))) {
127       stpf = stpc;
128     } else {
129       stpf = stpq;
130     }
131     *brackt = true;
132   } else if (std::abs(dp) < std::abs(*dx)) {
133     *info = 3;
134     bound = true;
135     theta = (*fx - fp) * 3 / (*stp - *stx) + *dx + dp;
136     d1 = std::abs(theta);
137     d2 = std::abs(*dx);
138     d1 = std::max(d1, d2);
139     d2 = std::abs(dp);
140     s = std::max(d1, d2);
141     d3 = theta / s;
142     d1 = 0.0;
143     d2 = d3 * d3 - *dx / s *(dp / s);
144     gamma = s * std::sqrt((std::max(d1, d2)));
145     if (*stp > *stx) {
146       gamma = -gamma;
147     }
148     p = gamma - dp + theta;
149     q = gamma + (*dx - dp) + gamma;
150     r = p / q;
151     if (r < 0.0 && gamma != 0.0) {
152       stpc = *stp + r *(*stx - *stp);
153     } else if (*stp > *stx) {
154       stpc = stpmax;
155     } else {
156       stpc = stpmin;
157     }
158     stpq = *stp + dp /(dp - *dx) * (*stx - *stp);
159     if (*brackt) {
160       if ((d1 = *stp - stpc, std::abs(d1)) <
161           (d2 = *stp - stpq, std::abs(d2))) {
162         stpf = stpc;
163       } else {
164         stpf = stpq;
165       }
166     } else {
167       if ((d1 = *stp - stpc, std::abs(d1)) >
168           (d2 = *stp - stpq, std::abs(d2))) {
169         stpf = stpc;
170       } else {
171         stpf = stpq;
172       }
173     }
174   } else {
175     *info = 4;
176     bound = false;
177     if (*brackt) {
178       theta =(fp - *fy) * 3 / (*sty - *stp) + *dy + dp;
179       d1 = std::abs(theta);
180       d2 = std::abs(*dy);
181       d1 = std::max(d1, d2);
182       d2 = std::abs(dp);
183       s = std::max(d1, d2);
184       d1 = theta / s;
185       gamma = s * std::sqrt(d1 * d1 - *dy / s * (dp / s));
186       if (*stp > *sty) {
187         gamma = -gamma;
188       }
189       p = gamma - dp + theta;
190       q = gamma - dp + gamma + *dy;
191       r = p / q;
192       stpc = *stp + r * (*sty - *stp);
193       stpf = stpc;
194     } else if (*stp > *stx) {
195       stpf = stpmax;
196     } else {
197       stpf = stpmin;
198     }
199   }
200
201   if (fp > *fx) {
202     *sty = *stp;
203     *fy = fp;
204     *dy = dp;
205   } else {
206     if (sgnd < 0.0) {
207       *sty = *stx;
208       *fy = *fx;
209       *dy = *dx;
210     }
211     *stx = *stp;
212     *fx = fp;
213     *dx = dp;
214   }
215
216   stpf = std::min(stpmax, stpf);
217   stpf = std::max(stpmin, stpf);
218   *stp = stpf;
219   if (*brackt && bound) {
220     if (*sty > *stx) {
221       d1 = *stx + (*sty - *stx) * 0.66;
222       *stp = std::min(d1, *stp);
223     } else {
224       d1 = *stx + (*sty - *stx) * 0.66;
225       *stp = std::max(d1, *stp);
226     }
227   }
228
229   return;
230 }
231 }
232
233 namespace MeCab {
234
235 class LBFGS::Mcsrch {
236  private:
237   int infoc, stage1, brackt;
238   double finit, dginit, dgtest, width, width1;
239   double stx, fx, dgx, sty, fy, dgy, stmin, stmax;
240
241  public:
242   Mcsrch():
243       infoc(0),
244       stage1(0),
245       brackt(0),
246       finit(0.0), dginit(0.0), dgtest(0.0), width(0.0), width1(0.0),
247       stx(0.0), fx(0.0), dgx(0.0), sty(0.0), fy(0.0), dgy(0.0),
248       stmin(0.0), stmax(0.0) {}
249
250   void mcsrch(int size,
251               double *x,
252               double f, const double *g, double *s,
253               double *stp,
254               int *info, int *nfev, double *wa, bool orthant, double C) {
255     const double p5 = 0.5;
256     const double p66 = 0.66;
257     const double xtrapf = 4.0;
258     const int maxfev = 20;
259
260     /* Parameter adjustments */
261     --wa;
262     --s;
263     --g;
264     --x;
265
266     if (*info == -1) {
267       goto L45;
268     }
269     infoc = 1;
270
271     if (size <= 0 || *stp <= 0.0) {
272       return;
273     }
274
275     dginit = ddot_(size, &g[1], &s[1]);
276     if (dginit >= 0.0) {
277       return;
278     }
279
280     brackt = false;
281     stage1 = true;
282     *nfev = 0;
283     finit = f;
284     dgtest = ftol * dginit;
285     width = lb3_1_stpmax - lb3_1_stpmin;
286     width1 = width / p5;
287     for (int j = 1; j <= size; ++j) {
288       wa[j] = x[j];
289     }
290
291     stx = 0.0;
292     fx = finit;
293     dgx = dginit;
294     sty = 0.0;
295     fy = finit;
296     dgy = dginit;
297
298     while (true) {
299       if (brackt) {
300         stmin = std::min(stx, sty);
301         stmax = std::max(stx, sty);
302       } else {
303         stmin = stx;
304         stmax = *stp + xtrapf * (*stp - stx);
305       }
306
307       *stp = std::max(*stp, lb3_1_stpmin);
308       *stp = std::min(*stp, lb3_1_stpmax);
309
310       if ((brackt && ((*stp <= stmin || *stp >= stmax) ||
311                       *nfev >= maxfev - 1 || infoc == 0)) ||
312           (brackt && (stmax - stmin <= xtol * stmax))) {
313         *stp = stx;
314       }
315
316       if (orthant) {
317         for (int j = 1; j <= size; ++j) {
318           double grad_neg = 0.0;
319           double grad_pos = 0.0;
320           double grad = 0.0;
321           if (wa[j] == 0.0) {
322             grad_neg = g[j] - 1.0 / C;
323             grad_pos = g[j] + 1.0 / C;
324           } else {
325             grad_pos = grad_neg = g[j] + 1.0 * sigma(wa[j]) / C;
326           }
327           if (grad_neg > 0.0) {
328             grad = grad_neg;
329           } else if (grad_pos < 0.0) {
330             grad = grad_pos;
331           } else {
332             grad = 0.0;
333           }
334           const double p = pi(s[j], -grad);
335           const double xi = wa[j] == 0.0 ? sigma(-grad) : sigma(wa[j]);
336           x[j] = pi(wa[j] + *stp * p, xi);
337         }
338       } else {
339         for (int j = 1; j <= size; ++j) {
340           x[j] = wa[j] + *stp * s[j];
341         }
342       }
343       *info = -1;
344       return;
345
346    L45:
347       *info = 0;
348       ++(*nfev);
349       double dg = ddot_(size, &g[1], &s[1]);
350       double ftest1 = finit + *stp * dgtest;
351
352       if (brackt && ((*stp <= stmin || *stp >= stmax) || infoc == 0)) {
353         *info = 6;
354       }
355       if (*stp == lb3_1_stpmax && f <= ftest1 && dg <= dgtest) {
356         *info = 5;
357       }
358       if (*stp == lb3_1_stpmin && (f > ftest1 || dg >= dgtest)) {
359         *info = 4;
360       }
361       if (*nfev >= maxfev) {
362         *info = 3;
363       }
364       if (brackt && stmax - stmin <= xtol * stmax) {
365         *info = 2;
366       }
367       if (f <= ftest1 && std::abs(dg) <= lb3_1_gtol * (-dginit)) {
368         *info = 1;
369       }
370
371       if (*info != 0) {
372         return;
373       }
374
375       if (stage1 && f <= ftest1 && dg >= std::min(ftol, lb3_1_gtol) * dginit) {
376         stage1 = false;
377       }
378
379       if (stage1 && f <= fx && f > ftest1) {
380         double fm = f - *stp * dgtest;
381         double fxm = fx - stx * dgtest;
382         double fym = fy - sty * dgtest;
383         double dgm = dg - dgtest;
384         double dgxm = dgx - dgtest;
385         double dgym = dgy - dgtest;
386         mcstep(&stx, &fxm, &dgxm, &sty, &fym, &dgym, stp, fm, dgm, &brackt,
387                stmin, stmax, &infoc);
388         fx = fxm + stx * dgtest;
389         fy = fym + sty * dgtest;
390         dgx = dgxm + dgtest;
391         dgy = dgym + dgtest;
392       } else {
393         mcstep(&stx, &fx, &dgx, &sty, &fy, &dgy, stp, f, dg, &brackt,
394                stmin, stmax, &infoc);
395       }
396
397       if (brackt) {
398         double d1 = 0.0;
399         if ((d1 = sty - stx, std::abs(d1)) >= p66 * width1) {
400           *stp = stx + p5 * (sty - stx);
401         }
402         width1 = width;
403         width = (d1 = sty - stx, std::abs(d1));
404       }
405     }
406
407     return;
408   }
409 };
410
411 void LBFGS::clear() {
412   iflag_ = iscn = nfev = iycn = point = npt =
413       iter = info = ispt = isyt = iypt = 0;
414   stp = stp1 = 0.0;
415   diag_.clear();
416   w_.clear();
417   delete mcsrch_;
418   mcsrch_ = 0;
419 }
420
421 void LBFGS::lbfgs_optimize(int size,
422                            int msize,
423                            double *x,
424                            double f,
425                            const double *g,
426                            double *diag,
427                            double *w,
428                            bool orthant,
429                            double C,
430                            int *iflag) {
431   double yy = 0.0;
432   double ys = 0.0;
433   int bound = 0;
434   int cp = 0;
435
436   --diag;
437   --g;
438   --x;
439   --w;
440
441   if (!mcsrch_) {
442     mcsrch_ = new Mcsrch;
443   }
444
445   if (*iflag == 1) {
446     goto L172;
447   }
448   if (*iflag == 2) {
449     goto L100;
450   }
451
452   // initialization
453   if (*iflag == 0) {
454     point = 0;
455     for (int i = 1; i <= size; ++i) {
456       diag[i] = 1.0;
457     }
458     ispt = size + (msize << 1);
459     iypt = ispt + size * msize;
460     for (int i = 1; i <= size; ++i) {
461       w[ispt + i] = -g[i] * diag[i];
462     }
463     stp1 = 1.0 / std::sqrt(ddot_(size, &g[1], &g[1]));
464   }
465
466   // MAIN ITERATION LOOP
467   while (true) {
468     ++iter;
469     info = 0;
470     if (iter == 1) goto L165;
471     if (iter > size) bound = size;
472
473     // COMPUTE -H*G USING THE FORMULA GIVEN IN: Nocedal, J. 1980,
474     // "Updating quasi-Newton matrices with limited storage",
475     // Mathematics of Computation, Vol.24, No.151, pp. 773-782.
476     ys = ddot_(size, &w[iypt + npt + 1], &w[ispt + npt + 1]);
477     yy = ddot_(size, &w[iypt + npt + 1], &w[iypt + npt + 1]);
478     for (int i = 1; i <= size; ++i) {
479       diag[i] = ys / yy;
480     }
481
482  L100:
483     cp = point;
484     if (point == 0) cp = msize;
485     w[size + cp] = 1.0 / ys;
486
487     for (int i = 1; i <= size; ++i) {
488       w[i] = -g[i];
489     }
490
491     bound = std::min(iter - 1, msize);
492
493     cp = point;
494     for (int i = 1; i <= bound; ++i) {
495       --cp;
496       if (cp == -1) cp = msize - 1;
497       double sq = ddot_(size, &w[ispt + cp * size + 1], &w[1]);
498       int inmc = size + msize + cp + 1;
499       iycn = iypt + cp * size;
500       w[inmc] = w[size + cp + 1] * sq;
501       double d = -w[inmc];
502       daxpy_(size, d, &w[iycn + 1], &w[1]);
503     }
504
505     for (int i = 1; i <= size; ++i) {
506       w[i] = diag[i] * w[i];
507     }
508
509     for (int i = 1; i <= bound; ++i) {
510       double yr = ddot_(size, &w[iypt + cp * size + 1], &w[1]);
511       double beta = w[size + cp + 1] * yr;
512       int inmc = size + msize + cp + 1;
513       beta = w[inmc] - beta;
514       iscn = ispt + cp * size;
515       daxpy_(size, beta, &w[iscn + 1], &w[1]);
516       ++cp;
517       if (cp == msize) {
518         cp = 0;
519       }
520     }
521
522     // STORE THE NEW SEARCH DIRECTION
523     for (int i = 1; i <= size; ++i) {
524       w[ispt + point * size + i] = w[i];
525     }
526
527  L165:
528     // OBTAIN THE ONE-DIMENSIONAL MINIMIZER OF THE FUNCTION
529     // BY USING THE LINE SEARCH ROUTINE MCSRCH
530     nfev = 0;
531     stp = 1.0;
532     if (iter == 1) {
533       stp = stp1;
534     }
535     for (int i = 1; i <= size; ++i) {
536       w[i] = g[i];
537     }
538
539  L172:
540     mcsrch_->mcsrch(size, &x[1], f, &g[1], &w[ispt + point * size + 1],
541                     &stp, &info, &nfev, &diag[1], orthant, C);
542     if (info == -1) {
543       *iflag = 1;  // next value
544       return;
545     }
546     if (info != 1) {
547       std::cerr << "The line search routine mcsrch failed: error code:"
548                 << info << std::endl;
549       *iflag = -1;
550       return;
551     }
552
553     // COMPUTE THE NEW STEP AND GRADIENT CHANGE
554     npt = point * size;
555     for (int i = 1; i <= size; ++i) {
556       w[ispt + npt + i] = stp * w[ispt + npt + i];
557       w[iypt + npt + i] = g[i] - w[i];
558     }
559     ++point;
560     if (point == msize) {
561       point = 0;
562     }
563
564     double gnorm = std::sqrt(ddot_(size, &g[1], &g[1]));
565     double xnorm = std::max(1.0, std::sqrt(ddot_(size, &x[1], &x[1])));
566     if (gnorm / xnorm <= eps) {
567       *iflag = 0;  // OK terminated
568       return;
569     }
570   }
571 }
572 }