OSDN Git Service

openJtalk結合
[nlite/nlite.git] / open_jtalk_lib / open_jtalk / mecab / src / learner.cpp
1 //  MeCab -- Yet Another Part-of-Speech and Morphological Analyzer
2 //
3 //
4 //  Copyright(C) 2001-2006 Taku Kudo <taku@chasen.org>
5 //  Copyright(C) 2004-2006 Nippon Telegraph and Telephone Corporation
6 #include <vector>
7 #include <string>
8 #include <fstream>
9 #include "param.h"
10 #include "common.h"
11 #include "lbfgs.h"
12 #include "utils.h"
13 #include "thread.h"
14 #include "learner_tagger.h"
15 #include "freelist.h"
16 #include "feature_index.h"
17 #include "string_buffer.h"
18
19 namespace {
20 double toLogProb(double f1, double f2) {
21   return std::log(1.0 * f1 / f2) - VERY_SMALL_LOGPROB;  // avoid 0
22 }
23 }
24
25 namespace MeCab {
26
27 #define DCONF(file) create_filename(dicdir, std::string(file)).c_str()
28
29 class HMMLearner {
30  public:
31   static int run(Param *param) {
32     DictionaryRewriter rewrite;
33
34
35     const std::string dicdir = param->get<std::string>("dicdir");
36     CHECK_DIE(param->load(DCONF(DICRC)))
37         << "no such file or directory: " << DCONF(DICRC);
38
39     CHECK_DIE(rewrite.open(DCONF(REWRITE_FILE)))
40         << "no such file or directory: " << DCONF(REWRITE_FILE);
41
42     const std::vector<std::string> files = param->rest_args();
43     if (files.size() != 2) {
44       std::cout << "Usage: " <<
45           param->program_name() << " corpus model" << std::endl;
46       return -1;
47     }
48
49     const std::string ifile = files[0];
50     const std::string model = files[1];
51
52     const bool text_only = param->get<bool>("text-only");
53     const bool em_hmm = param->get<bool>("em-hmm");
54     const std::string bos_feature = param->get<std::string>("bos-feature");
55
56     CHECK_DIE(!bos_feature.empty()) << "bos-feature is empty";
57
58     char line[BUF_SIZE];
59     char *col[8];
60     std::string word, feature;
61     std::string ufeature, lfeature, rfeature;
62     std::string plfeature, prfeature;
63     std::map<std::string, std::map<std::string, double> > emission;
64     std::map<std::string, std::map<std::string, double> > transition;
65
66     // corpus
67     if (em_hmm) {
68       std::ifstream ifs(ifile.c_str());
69       CHECK_DIE(ifs) << "no such file or directory: " << ifile;
70       size_t size = 0;
71       while (ifs.getline(line, sizeof(line))) {
72         if (std::strcmp("EOS", line) == 0) {
73           if (++size % 100 == 0)
74             std::cout << size << "... " << std::flush;
75           continue;
76         }
77
78         CHECK_DIE(tokenize(line, "\t", col, 4) == 4)
79             << "format error\n";
80         CHECK_DIE(std::strcmp("B", col[0]) == 0 &&
81                   std::strcmp("U", col[0]) == 0)
82             << "format error\n";
83         if (col[0][0] == 'B') {  // bigram
84           feature = col[1];
85           CHECK_DIE(rewrite.rewrite(feature,
86                                     &ufeature,
87                                     &lfeature,
88                                     &rfeature))
89               << "rewrite failed";
90           prfeature = rfeature;
91           feature = col[2];
92           CHECK_DIE(rewrite.rewrite(feature,
93                                     &ufeature,
94                                     &lfeature,
95                                     &rfeature))
96               << "rewrite failed";
97           plfeature = lfeature;
98           transition[prfeature][plfeature] += std::atof(col[3]);
99         } else {   // unigram
100           feature = col[2];
101           CHECK_DIE(rewrite.rewrite(feature,
102                                     &ufeature,
103                                     &lfeature,
104                                     &rfeature))
105               << "rewrite failed";
106           std::strncpy(line, ufeature.c_str(), sizeof(line));
107           size_t n = tokenize2(line, "\t ", col, 2);
108           CHECK_DIE(n == 2) << "format error in rewrite.def: " << ufeature;
109           ufeature = col[0];
110           word = col[1];
111           emission[ufeature][word] += atof(col[3]);
112         }
113         ++size;
114       }
115     } else {
116       std::ifstream ifs(ifile.c_str());
117       CHECK_DIE(ifs) << "no such file or directory: " << ifile;
118
119       CHECK_DIE(rewrite.rewrite(bos_feature,
120                                 &ufeature,
121                                 &plfeature,
122                                 &prfeature)) << "rewrite failed";
123
124       size_t size = 0;
125       while (ifs.getline(line, sizeof(line))) {
126         if (std::strcmp("EOS", line) == 0) {
127           if (++size % 100 == 0)
128             std::cout << size << "... " << std::flush;
129           feature = bos_feature;
130         } else {
131           CHECK_DIE(tokenize(line, "\t", col, 2) == 2)
132               << "format error\n";
133           feature = col[1];
134         }
135
136         CHECK_DIE(rewrite.rewrite(feature,
137                                   &ufeature,
138                                   &lfeature,
139                                   &rfeature))
140             << "rewrite failed";
141
142         std::strncpy(line, ufeature.c_str(), sizeof(line));
143         // unigram rule must contain ' '
144         const size_t n = tokenize2(line, "\t ", col, 2);
145         CHECK_DIE(n == 2) << "format error in rewrite.def: " << ufeature;
146         ufeature = col[0];
147         word = col[1];
148         transition[prfeature][lfeature] += 1.0;
149         emission[ufeature][word] += 1.0;
150         plfeature = lfeature;
151         prfeature = rfeature;
152       }
153     }
154
155     // dictionary
156     {
157       std::vector<std::string> dic;
158       enum_csv_dictionaries(dicdir.c_str(), &dic);
159
160       const double freq = param->get<double>("default-emission-freq");
161       CHECK_DIE(freq >= 0.0) << " default-emission-freq must be >= 0 "
162                              << freq;
163
164       for (std::vector<std::string>::const_iterator it = dic.begin();
165            it != dic.end(); ++it) {
166         std::cout << "reading " << *it << " ... " << std::flush;
167
168         std::ifstream ifs(it->c_str());
169         CHECK_DIE(ifs) << "no such file or directory: " << *it;
170
171         while (ifs.getline(line, sizeof(line))) {
172           CHECK_DIE(tokenizeCSV(line, col, 5) == 5) << "format error";
173           feature = col[4];
174           CHECK_DIE(rewrite.rewrite(feature,
175                                     &ufeature,
176                                     &lfeature,
177                                     &rfeature)) << "rewrite failed";
178           std::strncpy(line, ufeature.c_str(), sizeof(line));
179           const size_t n = tokenize2(line, "\t ", col, 2);
180           CHECK_DIE(n == 2) << "format error: " << ufeature;
181           ufeature = col[0];
182           word = col[1];
183           emission[ufeature][word] += freq;
184         }
185
186         std::cout << std::endl;
187       }
188     }
189
190     {
191       std::cout << std::endl;
192       std::string txtfile = model;
193       txtfile += ".txt";
194
195       std::ofstream ofs(txtfile.c_str());
196       CHECK_DIE(ofs) << "permission denied: " << model;
197
198       ofs.setf(std::ios::fixed, std::ios::floatfield);
199       ofs.precision(24);
200
201       // bigram
202       for (std::map<std::string, std::map<std::string, double> >
203                ::const_iterator
204                it = transition.begin();
205            it != transition.end(); ++it) {
206         double freq = 0.0;
207         for (std::map<std::string, double>::
208                  const_iterator it2 = it->second.begin();
209              it2 != it->second.end(); ++it2) {
210           freq += it2->second;
211         }
212
213         for (std::map<std::string, double>
214                  ::const_iterator it2 = it->second.begin();
215              it2 != it->second.end(); ++it2)
216           ofs << toLogProb(it2->second, freq) << '\t'
217               << 'B' << ':' << it->first << '/' << it2->first << std::endl;
218       }
219
220       // unigram
221       for (std::map<std::string, std::map<std::string, double> >
222                ::const_iterator
223                it = emission.begin();
224            it != emission.end(); ++it) {
225         double freq = 0.0;
226         for (std::map<std::string, double>
227                  ::const_iterator it2 = it->second.begin();
228              it2 != it->second.end(); ++it2)
229           freq += it2->second;
230
231         for (std::map<std::string, double>
232                  ::const_iterator it2 = it->second.begin();
233              it2 != it->second.end(); ++it2) {
234           std::string w = it2->first;
235           CHECK_DIE(escape_csv_element(&w));
236           ofs << toLogProb(it2->second, freq) << '\t'
237               << 'U' << ':' << it->first << ' ' << w << std::endl;
238         }
239       }
240
241       ofs.close();
242
243       if (!text_only) {
244         EncoderFeatureIndex feature_index;
245         CHECK_DIE(feature_index.convert(txtfile.c_str(), model.c_str()))
246             << "unexpected error in LBFGS routin";
247       }
248
249       std::cout << "Done!" << std::endl;
250     }
251
252     return 0;
253   }
254 };
255
256 class OLLearner {
257  public:
258   static int run(Param *param) {
259     const std::string dicdir = param->get<std::string>("dicdir");
260     CHECK_DIE(param->load(DCONF(DICRC)))
261         << "no such file or directory: " << DCONF(DICRC);
262
263     const std::vector<std::string> files = param->rest_args();
264     if (files.size() != 2) {
265       std::cout << "Usage: " <<
266           param->program_name() << " corpus model" << std::endl;
267       return -1;
268     }
269
270     const std::string ifile = files[0];
271     const std::string model = files[1];
272
273     const double C = param->get<double>("cost");
274     const bool   text_only = param->get<bool>("text-only");
275     const size_t eval_size = param->get<size_t>("eval-size");
276     const size_t unk_eval_size = param->get<size_t>("unk-eval-size");
277     const size_t iter = param->get<size_t>("iteration");
278     const size_t freq = param->get<size_t>("freq");
279
280     EncoderFeatureIndex feature_index;
281     LearnerTokenizer tokenizer;
282     FreeList<LearnerPath> path_freelist(PATH_FREELIST_SIZE);
283     std::vector<double> expected;
284     std::vector<double> observed;
285     std::vector<double> alpha;
286
287     std::cout.setf(std::ios::fixed, std::ios::floatfield);
288     std::cout.precision(5);
289
290     {
291       CHECK_DIE(C > 0) << "cost parameter is out of range: " << C;
292       CHECK_DIE(eval_size > 0) << "eval-size is out of range: " << eval_size;
293       CHECK_DIE(unk_eval_size > 0) <<
294           "unk-eval-size is out of range: " << unk_eval_size;
295       CHECK_DIE(tokenizer.open(*param)) << tokenizer.what();
296       CHECK_DIE(feature_index.open(*param)) << feature_index.what();
297       CHECK_DIE(iter >= 1 && iter <= 100) << "iteration should be <= 100";
298       CHECK_DIE(freq == 1) << "freq must be 1";
299     }
300
301     std::cout << "reading corpus ..." << std::flush;
302
303     EncoderLearnerTagger x;
304     for (size_t i = 0; i < 10; ++i) {
305       std::ifstream ifs(ifile.c_str());
306       CHECK_DIE(ifs) << "no such file or directory: " << ifile;
307       while (ifs) {
308         path_freelist.free();
309         tokenizer.clear();
310         std::fill(expected.begin(), expected.end(), 0.0);
311         std::fill(observed.begin(), observed.end(), 0.0);
312
313         CHECK_DIE(x.open(&tokenizer,
314                          &path_freelist,
315                          &feature_index,
316                          eval_size,
317                          unk_eval_size)) << x.what();
318         CHECK_DIE(x.read(&ifs, &observed)) << x.what();
319
320         if (x.empty()) {
321           continue;
322         }
323
324         alpha.resize(feature_index.size());
325         expected.resize(feature_index.size());
326         observed.resize(feature_index.size());
327         feature_index.set_alpha(&alpha[0]);
328
329         x.online_update(&expected[0]);
330
331         size_t micro_p = 0;
332         size_t micro_r = 0;
333         size_t micro_c = 0;
334         size_t err = x.eval(&micro_c, &micro_p, &micro_r);
335         std::cout << micro_p << " " << micro_r << " " << micro_c << " " << err << std::endl;
336
337         // gradient
338         double margin = 0.0;
339         double s = 0.0;
340         for (size_t k = 0; k < feature_index.size(); ++k) {
341           const double tmp = (observed[k] - expected[k]);
342           margin += alpha[k] * tmp;
343           s += tmp * tmp;
344         }
345
346         // Passive Aggressive I algorithm
347         if (s > 0.0) {
348           const double diff = _max(0.0, 10 - margin) / s;
349           if (diff > 0.0) {
350             for (size_t k = 0; k < feature_index.size(); ++k) {
351               alpha[k] += diff * (observed[k] - expected[k]);
352             }
353           }
354         }
355       }
356     }
357
358     std::cout << "\nDone! writing model file ... " << std::endl;
359
360     std::string txtfile = model;
361     txtfile += ".txt";
362
363     CHECK_DIE(feature_index.save(txtfile.c_str()))
364         << feature_index.what();
365
366     if (!text_only) {
367       CHECK_DIE(feature_index.convert(txtfile.c_str(), model.c_str()))
368           << feature_index.what();
369     }
370
371     return true;
372   }
373 };
374
375
376 #ifdef MECAB_USE_THREAD
377 class learner_thread: public thread {
378  public:
379   unsigned short start_i;
380   unsigned short thread_num;
381   size_t size;
382   size_t micro_p;
383   size_t micro_r;
384   size_t micro_c;
385   size_t err;
386   double f;
387   EncoderLearnerTagger **x;
388   std::vector<double> expected;
389   void run() {
390     micro_p = micro_r = micro_c = err = 0;
391     f = 0.0;
392     std::fill(expected.begin(), expected.end(), 0.0);
393     for (size_t i = start_i; i < size; i += thread_num) {
394       f += x[i]->gradient(&expected[0]);
395       err += x[i]->eval(&micro_c, &micro_p, &micro_r);
396     }
397   }
398 };
399 #endif
400
401 class CRFLearner {
402  public:
403   static int run(Param *param) {
404     const std::string dicdir = param->get<std::string>("dicdir");
405     CHECK_DIE(param->load(DCONF(DICRC)))
406         << "no such file or directory: " << DCONF(DICRC);
407
408     const std::vector<std::string> &files = param->rest_args();
409     if (files.size() != 2) {
410       std::cout << "Usage: " <<
411           param->program_name() << " corpus model" << std::endl;
412       return -1;
413     }
414
415     const std::string ifile = files[0];
416     const std::string model = files[1];
417
418     const double C = param->get<double>("cost");
419     const double eta = param->get<double>("eta");
420     const bool text_only = param->get<bool>("text-only");
421     const size_t eval_size = param->get<size_t>("eval-size");
422     const size_t unk_eval_size = param->get<size_t>("unk-eval-size");
423     const size_t freq = param->get<size_t>("freq");
424     const size_t thread_num = param->get<size_t>("thread");
425
426     EncoderFeatureIndex feature_index;
427     LearnerTokenizer tokenizer;
428     FreeList<LearnerPath> path_freelist(PATH_FREELIST_SIZE);
429     std::vector<double> expected;
430     std::vector<double> observed;
431     std::vector<double> alpha;
432     std::vector<EncoderLearnerTagger *> x_;
433
434     std::cout.setf(std::ios::fixed, std::ios::floatfield);
435     std::cout.precision(5);
436
437     std::ifstream ifs(ifile.c_str());
438     {
439       CHECK_DIE(C > 0) << "cost parameter is out of range: " << C;
440       CHECK_DIE(eta > 0) "eta is out of range: " << eta;
441       CHECK_DIE(eval_size > 0) << "eval-size is out of range: " << eval_size;
442       CHECK_DIE(unk_eval_size > 0) <<
443           "unk-eval-size is out of range: " << unk_eval_size;
444       CHECK_DIE(freq > 0) <<
445           "freq is out of range: " << unk_eval_size;
446       CHECK_DIE(thread_num > 0 && thread_num <= 512)
447           << "# thread is invalid: " << thread_num;
448       CHECK_DIE(tokenizer.open(*param)) << tokenizer.what();
449       CHECK_DIE(feature_index.open(*param)) << feature_index.what();
450       CHECK_DIE(ifs) << "no such file or directory: " << ifile;
451     }
452
453     std::cout << "reading corpus ..." << std::flush;
454
455     while (ifs) {
456       EncoderLearnerTagger *_x = new EncoderLearnerTagger();
457
458       CHECK_DIE(_x->open(&tokenizer, &path_freelist,
459                          &feature_index,
460                          eval_size,
461                          unk_eval_size))
462           << _x->what();
463
464       CHECK_DIE(_x->read(&ifs, &observed)) << _x->what();
465
466       if (!_x->empty())
467         x_.push_back(_x);
468       else
469         delete _x;
470
471       if (x_.size() % 100 == 0)
472         std::cout << x_.size() << "... " << std::flush;
473     }
474
475     feature_index.shrink(freq, &observed);
476     feature_index.clearcache();
477
478     int converge = 0;
479     double old_f = 0.0;
480     size_t psize = feature_index.size();
481     observed.resize(psize);
482     LBFGS lbfgs;
483
484     alpha.resize(psize);
485     expected.resize(psize);
486     std::fill(alpha.begin(), alpha.end(), 0.0);
487
488     feature_index.set_alpha(&alpha[0]);
489
490     std::cout << std::endl;
491     std::cout << "Number of sentences: " << x_.size() << std::endl;
492     std::cout << "Number of features:  " << psize     << std::endl;
493     std::cout << "eta:                 " << eta       << std::endl;
494     std::cout << "freq:                " << freq      << std::endl;
495 #ifdef MECAB_USE_THREAD
496     std::cout << "threads:             " << thread_num << std::endl;
497 #endif
498     std::cout << "C(sigma^2):          " << C          << std::endl
499               << std::endl;
500
501 #ifdef MECAB_USE_THREAD
502     std::vector<learner_thread> thread;
503     if (thread_num > 1) {
504       thread.resize(thread_num);
505       for (size_t i = 0; i < thread_num; ++i) {
506         thread[i].start_i = i;
507         thread[i].size = x_.size();
508         thread[i].thread_num = thread_num;
509         thread[i].x = &x_[0];
510         thread[i].expected.resize(expected.size());
511       }
512     }
513 #endif
514
515     for (size_t itr = 0; ;  ++itr) {
516       std::fill(expected.begin(), expected.end(), 0.0);
517
518       double f = 0.0;
519       size_t err = 0;
520       size_t micro_p = 0;
521       size_t micro_r = 0;
522       size_t micro_c = 0;
523
524 #ifdef MECAB_USE_THREAD
525       if (thread_num > 1) {
526         for (size_t i = 0; i < thread_num; ++i)
527           thread[i].start();
528
529         for (size_t i = 0; i < thread_num; ++i)
530           thread[i].join();
531
532         for (size_t i = 0; i < thread_num; ++i) {
533           f += thread[i].f;
534           err += thread[i].err;
535           micro_r += thread[i].micro_r;
536           micro_p += thread[i].micro_p;
537           micro_c += thread[i].micro_c;
538           for (size_t k = 0; k < psize; ++k)
539             expected[k] += thread[i].expected[k];
540         }
541       }
542       else
543 #endif
544       {
545         for (size_t i = 0; i < x_.size(); ++i) {
546           f += x_[i]->gradient(&expected[0]);
547           err += x_[i]->eval(&micro_c, &micro_p, &micro_r);
548         }
549       }
550
551       const double p = 1.0 * micro_c / micro_p;
552       const double r = 1.0 * micro_c / micro_r;
553       const double micro_f = 2 * p * r /(p + r);
554
555       for (size_t i = 0; i < psize; ++i) {
556         f += (alpha[i] * alpha[i]/(2.0 * C));
557         expected[i] = expected[i] - observed[i] + alpha[i]/C;
558       }
559
560       double diff = (itr == 0 ? 1.0 : std::fabs(1.0 *(old_f - f) )/old_f);
561       std::cout << "iter="    << itr
562                 << " err="    << 1.0 * err/x_.size()
563                 << " F="      << micro_f
564                 << " target=" << f
565                 << " diff="   << diff << std::endl;
566       old_f = f;
567
568       if (diff < eta)
569         converge++;
570       else
571         converge = 0;
572
573       if (converge == 3)
574         break;  // 3 is ad-hoc
575
576       int ret = lbfgs.optimize(psize, &alpha[0], f, &expected[0], false, C);
577
578       CHECK_DIE(ret > 0) << "unexpected error in LBFGS routin";
579     }
580
581     std::cout << "\nDone! writing model file ... " << std::endl;
582
583     std::string txtfile = model;
584     txtfile += ".txt";
585
586     CHECK_DIE(feature_index.save(txtfile.c_str()))
587         << feature_index.what();
588
589     if (!text_only) {
590       CHECK_DIE(feature_index.convert(txtfile.c_str(), model.c_str()))
591           << feature_index.what();
592     }
593
594     return 0;
595   }
596 };
597
598 class Learner {
599  public:
600   static bool run(int argc, char **argv) {
601     static const MeCab::Option long_options[] = {
602       { "dicdir",   'd',  ".",     "DIR",
603         "set DIR as dicdir(default \".\" )" },
604       { "cost",     'c',  "1.0",   "FLOAT",
605         "set FLOAT for cost C for constraints violatoin" },
606       { "training-algorithm",   'a',  "crf",
607         "(crf|hmm|oll)", "set training algorithm" },
608       { "em-hmm", 'E', 0, 0,       "use EM in HMM training (experimental)" },
609       { "freq",     'f',  "1",     "INT",
610         "set the frequency cut-off (default 1)" },
611       { "default-emission-freq", 'E',  "0.5",     "FLOAT",
612         "set the default emission frequency for HMM (default 0.5)" },
613       { "default-transition-freq", 'T',  "0.5",     "FLOAT",
614         "set the default transition frequency for HMM (default 0.0)" },
615       { "eta",      'e',  "0.001", "DIR",
616         "set FLOAT for tolerance of termination criterion" },
617       { "iteration", 'N', "10",    "INT",
618         "numer of iterations in online learning (default 1)" },
619       { "thread",   'p',  "1",     "INT",    "number of threads(default 1)" },
620       { "build",    'b',  0,  0,   "build binary model from text model"},
621       { "text-only", 'y',  0,  0,   "output text model only" },
622       { "version",  'v',  0,   0,  "show the version and exit"  },
623       { "help",     'h',  0,   0,  "show this help and exit."      },
624       { 0, 0, 0, 0 }
625     };
626
627     Param param;
628
629     if (!param.open(argc, argv, long_options)) {
630       std::cout << param.what() << "\n\n" <<  COPYRIGHT
631                 << "\ntry '--help' for more information." << std::endl;
632       return -1;
633     }
634
635     if (!param.help_version()) {
636       return 0;
637     }
638
639     // build mode
640     {
641       const bool build = param.get<bool>("build");
642       if (build) {
643         const std::vector<std::string> files = param.rest_args();
644         if (files.size() != 2) {
645           std::cout << "Usage: " <<
646               param.program_name() << " corpus model" << std::endl;
647           return -1;
648         }
649         const std::string ifile = files[0];
650         const std::string model = files[1];
651         EncoderFeatureIndex feature_index;
652         CHECK_DIE(feature_index.convert(ifile.c_str(), model.c_str()))
653             << feature_index.what();
654         return 0;
655       }
656     }
657
658     std::string type = param.get<std::string>("training-algorithm");
659     toLower(&type);
660     if (type == "crf") {
661       return CRFLearner::run(&param);
662     } else if (type == "hmm") {
663       return HMMLearner::run(&param);
664     } else if (type == "oll") {
665       return OLLearner::run(&param);
666     } else {
667       std::cerr << "unknown type: " << type << std::endl;
668       return -1;
669     }
670
671     return 0;
672   }
673 };
674 }
675
676 int mecab_cost_train(int argc, char **argv) {
677   return MeCab::Learner::run(argc, argv);
678 }