OSDN Git Service

SVNから移行
[nmecab/NMeCabRepo2.git] / trunk / src / LibNMeCab / Core / Viterbi.cs
1 //  MeCab -- Yet Another Part-of-Speech and Morphological Analyzer
2 //
3 //  Copyright(C) 2001-2006 Taku Kudo <taku@chasen.org>
4 //  Copyright(C) 2004-2006 Nippon Telegraph and Telephone Corporation
5 using System;
6 using System.Collections.Generic;
7 using System.Text;
8 using System.Text.RegularExpressions;
9
namespace NMeCab.Core
{
    /// <summary>
    /// Viterbi search over the morpheme lattice. Candidate nodes come from the
    /// <see cref="Tokenizer"/>, connection costs from the <see cref="Connector"/>;
    /// depending on <see cref="LatticeLevel"/> it can also record every connection
    /// path and run forward-backward to compute marginal probabilities.
    /// </summary>
    public class Viterbi : IDisposable
    {
        #region InnerClass

        /// <summary>
        /// Working state for a single <see cref="Analyze"/> call (a new instance
        /// is created per call).
        /// </summary>
        private class ThreadData
        {
            /// <summary>End-of-sentence node.</summary>
            public MeCabNode EosNode;
            /// <summary>Beginning-of-sentence node.</summary>
            public MeCabNode BosNode;
            /// <summary>EndNodeList[i]: chain (linked via ENext) of nodes ending at position i.</summary>
            public MeCabNode[] EndNodeList;
            /// <summary>BeginNodeList[i]: chain (linked via BNext) of nodes beginning at position i.</summary>
            public MeCabNode[] BeginNodeList;
            /// <summary>Log normalizer: the forward score (Alpha) of the EOS node.</summary>
            public float Z;
        }

        #endregion

        #region Field/Property

        private readonly Tokenizer tokenizer = new Tokenizer();
        private readonly Connector connector = new Connector();

        private MeCabLatticeLevel level;

        // Stored already divided by costFactor; the Theta property exposes the scaled value.
        private float theta;

        private int costFactor;

        /// <summary>
        /// Factor applied to connection costs in forward-backward (scores are
        /// <c>-theta * cost</c>). Stored internally divided by the cost factor, so the
        /// cost factor must be assigned first (as <see cref="Open"/> does).
        /// </summary>
        public float Theta
        {
            get { return this.theta * this.costFactor; }
            set { this.theta = value / this.costFactor; }
        }

        /// <summary>
        /// Lattice level. At <c>One</c> or above every connection between nodes is
        /// recorded as a <see cref="MeCabPath"/>; at <c>Two</c> or above forward-backward
        /// is also run to fill in node probabilities.
        /// </summary>
        public unsafe MeCabLatticeLevel LatticeLevel
        {
            get
            {
                return this.level;
            }
            set
            {
                this.level = value;
                this.connect = this.ConnectNormal;
                this.analyze = this.DoViterbi;
                if (value >= MeCabLatticeLevel.One)
                    this.connect = this.ConnectWithAllPath;
                // Level Two implies level One, so the paths forward-backward needs exist.
                if (value >= MeCabLatticeLevel.Two)
                    this.analyze = this.ForwardBackward;
            }
        }

        /// <summary>
        /// When true, the input to <see cref="Analyze"/> is read as constraint lines
        /// (see <see cref="InitConstraints"/>) rather than as a plain sentence.
        /// </summary>
        public bool Partial { get; set; }

        /// <summary>
        /// When true, <see cref="Analyze"/> returns every node of the lattice linked
        /// through Next/Prev; otherwise only the best path is linked.
        /// </summary>
        public bool AllMorphs
        {
            get
            {
                return this.buildLattice == this.BuildAllLattice;
            }
            set
            {
                if (value)
                    this.buildLattice = this.BuildAllLattice;
                else
                    this.buildLattice = this.BuildBestLattice;
            }
        }

        #endregion

        #region Open/Clear

        /// <summary>
        /// Opens the tokenizer and connector and copies the analysis settings from
        /// <paramref name="param"/>.
        /// </summary>
        public void Open(MeCabParam param)
        {
            tokenizer.Open(param);
            connector.Open(param);

            // costFactor must be set before Theta: the Theta setter divides by it.
            this.costFactor = param.CostFactor;
            this.Theta = param.Theta;
            this.LatticeLevel = param.LatticeLevel;
            this.Partial = param.Partial;
            this.AllMorphs = param.AllMorphs;
        }

#if NeedId
        public void Clear()
        {
            this.tokenizer.Clear();
        }
#endif

        #endregion

        #region AnalyzeStart

        /// <summary>
        /// Analyzes <paramref name="len"/> characters starting at <paramref name="str"/>
        /// and returns the BOS node of the resulting lattice.
        /// </summary>
        /// <remarks>
        /// In <see cref="Partial"/> mode the input is first rewritten by
        /// <see cref="InitConstraints"/> and the rewritten sentence is analyzed.
        /// </remarks>
        public unsafe MeCabNode Analyze(char* str, int len)
        {
#if NeedId
            this.Clear();
#endif

            // Positions 0..len are used (len is the EOS slot). The extra slack also
            // covers the Partial-mode sentence, which can be slightly longer than the
            // raw input (a leading space plus one space after each token).
            ThreadData work = new ThreadData()
            {
                EndNodeList = new MeCabNode[len + 4],
                BeginNodeList = new MeCabNode[len + 4]
            };

            if (this.Partial)
            {
                string newStr = this.InitConstraints(str, len, work);
                fixed (char* pNewStr = newStr)
                {
                    this.analyze(pNewStr, newStr.Length, work);
                    return this.buildLattice(work);
                }
            }

            this.analyze(str, len, work);
            return this.buildLattice(work);
        }

        #endregion

        #region Analyze

        private unsafe delegate void AnalyzeAction(char* str, int len, ThreadData work);

        private AnalyzeAction analyze;

        /// <summary>
        /// Runs the Viterbi search, then forward (Alpha) and backward (Beta) passes in
        /// log space, and stores each node's marginal probability
        /// <c>exp(Alpha + Beta - Z)</c> in <c>Prob</c>.
        /// </summary>
        private unsafe void ForwardBackward(char* sentence, int len, ThreadData work)
        {
            this.DoViterbi(sentence, len, work);

            // Address BOS/EOS directly: when len == 0, EOS is pushed onto the head of
            // EndNodeList[0], so EndNodeList[0] is not necessarily BOS.
            work.BosNode.Alpha = 0f;
            for (int pos = 0; pos <= len; pos++)
                for (MeCabNode node = work.BeginNodeList[pos]; node != null; node = node.BNext)
                    this.CalcAlpha(node, this.theta);

            work.EosNode.Beta = 0f;
            for (int pos = len; pos >= 0; pos--)
                for (MeCabNode node = work.EndNodeList[pos]; node != null; node = node.ENext)
                    this.CalcBeta(node, this.theta);

            work.Z = work.EosNode.Alpha;

            for (int pos = 0; pos <= len; pos++)
                for (MeCabNode node = work.BeginNodeList[pos]; node != null; node = node.BNext)
                    node.Prob = (float)Math.Exp(node.Alpha + node.Beta - work.Z);
        }

        /// <summary>
        /// Forward score: log-sum-exp over incoming paths of
        /// <c>-theta * path.Cost + path.LNode.Alpha</c>.
        /// </summary>
        private void CalcAlpha(MeCabNode n, double theta)
        {
            n.Alpha = 0f;
            for (MeCabPath path = n.LPath; path != null; path = path.LNext)
            {
                // The last argument flags the first term, so the initial 0 is not summed in.
                n.Alpha = (float)Utils.LogSumExp(n.Alpha,
                                                 -theta * path.Cost + path.LNode.Alpha,
                                                 path == n.LPath);
            }
        }

        /// <summary>
        /// Backward score: log-sum-exp over outgoing paths of
        /// <c>-theta * path.Cost + path.RNode.Beta</c>.
        /// </summary>
        private void CalcBeta(MeCabNode n, double theta)
        {
            n.Beta = 0f;
            for (MeCabPath path = n.RPath; path != null; path = path.RNext)
            {
                n.Beta = (float)Utils.LogSumExp(n.Beta,
                                                -theta * path.Cost + path.RNode.Beta,
                                                path == n.RPath);
            }
        }

        /// <summary>
        /// Builds the lattice left to right and keeps, for every node, its cheapest
        /// predecessor (Prev) and accumulated cost; finally connects the EOS node.
        /// </summary>
        private unsafe void DoViterbi(char* sentence, int len, ThreadData work)
        {
            work.BosNode = this.tokenizer.GetBosNode();
            work.BosNode.Length = len;

            char* begin = sentence;
            char* end = begin + len;
            work.BosNode.Surface = new string(begin, 0, len);
            work.EndNodeList[0] = work.BosNode;

            for (int pos = 0; pos < len; pos++)
            {
                // New nodes are only started where some node already ends.
                if (work.EndNodeList[pos] == null) continue;

                MeCabNode rNode = tokenizer.Lookup(begin + pos, end);
                rNode = this.FilterNode(rNode, pos, work);

                // Stamp positions on every candidate in the chain, not only the head.
                for (MeCabNode n = rNode; n != null; n = n.BNext)
                {
                    n.BPos = pos;
                    n.EPos = pos + n.RLength;
                }

                work.BeginNodeList[pos] = rNode;
                this.connect(pos, rNode, work);
            }

            work.EosNode = tokenizer.GetEosNode();
            // EOS covers no characters. Dereferencing 'end' here would read one char
            // past the analyzed range (beyond the caller's buffer in general).
            work.EosNode.Surface = string.Empty;
            work.BeginNodeList[len] = work.EosNode;

            // Connect EOS after the right-most position at which some node ends.
            for (int pos = len; pos >= 0; pos--)
            {
                if (work.EndNodeList[pos] != null)
                {
                    this.connect(pos, work.EosNode, work);
                    break;
                }
            }
        }

        #endregion

        #region Connect

        private delegate void ConnectAction(int pos, MeCabNode rNode, ThreadData work);

        private ConnectAction connect;

        /// <summary>
        /// Connects every node in the <paramref name="rNode"/> chain (via BNext) to
        /// the nodes ending at <paramref name="pos"/>, recording each connection as a
        /// <see cref="MeCabPath"/> in addition to the best predecessor.
        /// </summary>
        /// <exception cref="ArgumentException">No predecessor had a cost below int.MaxValue.</exception>
        private void ConnectWithAllPath(int pos, MeCabNode rNode, ThreadData work)
        {
            for (; rNode != null; rNode = rNode.BNext)
            {
                long bestCost = int.MaxValue; // 2147483647

                MeCabNode bestNode = null;

                for (MeCabNode lNode = work.EndNodeList[pos]; lNode != null; lNode = lNode.ENext)
                {
                    int lCost = this.connector.Cost(lNode, rNode); // local cost
                    long cost = lNode.Cost + lCost;

                    if (cost < bestCost)
                    {
                        bestNode = lNode;
                        bestCost = cost;
                    }

                    // Prepend the new path to both rNode's left list and lNode's right list.
                    MeCabPath path = new MeCabPath()
                    {
                        Cost = lCost,
                        RNode = rNode,
                        LNode = lNode,
                        LNext = rNode.LPath,
                        RNext = lNode.RPath
                    };
                    rNode.LPath = path;
                    lNode.RPath = path;
                }

                // NOTE(review): ConnectNormal throws MeCabException for the same condition;
                // kept as ArgumentException because callers may already catch it.
                if (bestNode == null) throw new ArgumentException("too long sentence.");

                RegisterBest(pos, rNode, bestNode, bestCost, work);
            }
        }

        /// <summary>
        /// Connects every node in the <paramref name="rNode"/> chain (via BNext) to
        /// its cheapest predecessor among the nodes ending at <paramref name="pos"/>.
        /// </summary>
        /// <exception cref="MeCabException">No predecessor had a cost below int.MaxValue.</exception>
        private void ConnectNormal(int pos, MeCabNode rNode, ThreadData work)
        {
            for (; rNode != null; rNode = rNode.BNext)
            {
                long bestCost = int.MaxValue; // 2147483647

                MeCabNode bestNode = null;

                for (MeCabNode lNode = work.EndNodeList[pos]; lNode != null; lNode = lNode.ENext)
                {
                    long cost = lNode.Cost + this.connector.Cost(lNode, rNode);

                    if (cost < bestCost)
                    {
                        bestNode = lNode;
                        bestCost = cost;
                    }
                }

                if (bestNode == null) throw new MeCabException("too long sentence.");

                RegisterBest(pos, rNode, bestNode, bestCost, work);
            }
        }

        /// <summary>
        /// Records <paramref name="bestNode"/> as the predecessor of
        /// <paramref name="rNode"/> and pushes rNode onto the end-node chain of the
        /// position where it ends (<c>pos + RLength</c>).
        /// </summary>
        private static void RegisterBest(int pos, MeCabNode rNode, MeCabNode bestNode, long bestCost, ThreadData work)
        {
            rNode.Prev = bestNode;
            rNode.Next = null;
            rNode.Cost = bestCost;
            int endPos = rNode.RLength + pos;
            rNode.ENext = work.EndNodeList[endPos];
            work.EndNodeList[endPos] = rNode;
        }

        #endregion

        #region Lattice

        private delegate MeCabNode BuildLatticeFunc(ThreadData work);

        private BuildLatticeFunc buildLattice;

        /// <summary>
        /// Marks the best path, then links every node in the lattice (in order of
        /// begin position) through Next/Prev starting at BOS, and computes each
        /// path's marginal probability.
        /// </summary>
        private MeCabNode BuildAllLattice(ThreadData work)
        {
            if (this.BuildBestLattice(work) == null) return null;

            MeCabNode prev = work.BosNode;

            for (int pos = 0; pos < work.BeginNodeList.Length; pos++)
            {
                for (MeCabNode node = work.BeginNodeList[pos]; node != null; node = node.BNext)
                {
                    prev.Next = node;
                    node.Prev = prev;
                    prev = node;
                    for (MeCabPath path = node.LPath; path != null; path = path.LNext)
                    {
                        // Marginal probability of the path (exponentiated, like node.Prob).
                        path.Prob = (float)Math.Exp(path.LNode.Alpha
                                                    - this.theta * path.Cost
                                                    + path.RNode.Beta - work.Z);
                    }
                }
            }

            return work.BosNode;
        }

        /// <summary>
        /// Walks Prev pointers back from EOS, setting IsBest and the forward Next
        /// links along the best path; returns BOS.
        /// </summary>
        private MeCabNode BuildBestLattice(ThreadData work)
        {
            MeCabNode node = work.EosNode;
            for (MeCabNode prevNode; node.Prev != null; )
            {
                node.IsBest = true;
                prevNode = node.Prev;
                prevNode.Next = node;
                node = prevNode;
            }
            return work.BosNode;
        }

        #endregion

        #region Partial

        /// <summary>
        /// Parses Partial-mode input: one token per line, either <c>surface</c> or
        /// <c>surface\tfeature</c>; empty lines are skipped and a line <c>EOS</c>
        /// stops parsing. Returns the surfaces joined as <c>" s1 s2 ... "</c>, and for
        /// each token with a feature stores a constraint node in
        /// <c>work.BeginNodeList</c> at the offset of the space preceding it.
        /// </summary>
        /// <exception cref="ArgumentException">A line has a tab but an empty feature.</exception>
        private unsafe string InitConstraints(char* sentence, int sentenceLen, ThreadData work)
        {
            string str = new string(sentence, 0, sentenceLen);
            StringBuilder os = new StringBuilder();
            os.Append(' ');
            int pos = 0;

            foreach (string line in str.Split('\r', '\n'))
            {
                if (line == "") continue;
                if (line == "EOS") break;

                string[] column = line.Split('\t');
                os.Append(column[0]).Append(' ');
                int len = column[0].Length;

                if (column.Length == 2)
                {
                    // Port of MeCab's strcmp(column[1], "\0") check: the feature must be non-empty.
                    if (column[1].Length == 0) throw new ArgumentException("use \\t as separator");
                    MeCabNode c = this.tokenizer.GetNewNode();
                    c.Surface = column[0];
                    c.Feature = column[1];
                    c.Length = len;
                    // RLength also covers the preceding space.
                    c.RLength = len + 1;
                    c.BNext = null;
                    c.WCost = 0;
                    work.BeginNodeList[pos] = c;
                }

                pos += len + 1;
            }

            return os.ToString();
        }

        /// <summary>
        /// In Partial mode, narrows the candidate chain at <paramref name="pos"/> to
        /// the nodes whose surface equals the constraint's and whose feature is
        /// <c>"*"</c> or partially matches it. If none match, the constraint node
        /// itself is used. Without a constraint the chain is returned unchanged.
        /// </summary>
        private MeCabNode FilterNode(MeCabNode node, int pos, ThreadData work)
        {
            if (!this.Partial) return node;

            MeCabNode c = work.BeginNodeList[pos];
            if (c == null) return node;
            bool wild = (c.Feature == "*");

            MeCabNode prev = null;
            MeCabNode result = null;

            for (MeCabNode n = node; n != null; n = n.BNext)
            {
                if (c.Surface == n.Surface
                    && (wild || this.PartialMatch(c.Feature, n.Feature)))
                {
                    if (prev != null)
                    {
                        prev.BNext = n;
                        prev = n;
                    }
                    else
                    {
                        result = n;
                        prev = result;
                    }
                }
            }
            if (result == null) result = c;
            // Cut off any non-matching nodes that followed the last match.
            if (prev != null) prev.BNext = null;

            return result;
        }

        /// <summary>
        /// Compares two comma-separated feature strings field by field over the
        /// shorter of the two; a <c>"*"</c> on either side matches any value.
        /// </summary>
        private bool PartialMatch(string f1, string f2)
        {
            string[] c1 = f1.Split(',');
            string[] c2 = f2.Split(',');

            int n = Math.Min(c1.Length, c2.Length);

            for (int i = 0; i < n; i++)
                if (c1[i] != "*" && c2[i] != "*" && c1[i] != c2[i]) return false;

            return true;
        }

        #endregion

        #region Dispose

        private bool disposed;

        /// <summary>
        /// Releases the resources in use.
        /// </summary>
        public void Dispose()
        {
            this.Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Disposes the tokenizer and connector when <paramref name="disposing"/> is
        /// true; repeated calls are ignored.
        /// </summary>
        protected virtual void Dispose(bool disposing)
        {
            if (disposed) return;

            if (disposing)
            {
                this.tokenizer.Dispose(); // no null check needed: readonly field initialized inline
                this.connector.Dispose(); // no null check needed: readonly field initialized inline
            }

            this.disposed = true;
        }

        ~Viterbi()
        {
            this.Dispose(false);
        }

        #endregion
    }
}