OSDN Git Service

git-svn-id: svn+ssh://svn.sourceforge.jp/svnroot/simplenn/trunk@8 dd34cd95-496f-4e97...
[simplenn/repo.git] / simplenn / src / main / java / jp / gr / java_conf / u6k / simplenn / SimpleNNApplet.java
1 \r
2 package jp.gr.java_conf.u6k.simplenn;\r
3 \r
4 import java.applet.Applet;\r
5 import java.awt.Button;\r
6 import java.awt.Color;\r
7 import java.awt.Graphics;\r
8 import java.awt.event.ActionEvent;\r
9 import java.awt.event.ActionListener;\r
10 import java.awt.event.MouseEvent;\r
11 import java.awt.event.MouseListener;\r
12 import java.awt.event.MouseMotionListener;\r
13 import java.io.BufferedReader;\r
14 import java.io.IOException;\r
15 import java.io.InputStreamReader;\r
16 \r
17 /**\r
18  * <p>\r
19  * バックプロパゲーション法で学習するニューラル・ネットワークのデモ・アプレットです。このソースコードは、<a href="http://codezine.jp/">CodeZine</a>の記事「<a href="http://codezine.jp/a/article/aid/372.aspx">ニューラルネットワークを用いたパターン認識</a>」を参考にしています。\r
20  * </p>\r
21  * \r
22  * @version $Id$\r
23  * @see http://codezine.jp/a/article/aid/372.aspx\r
24  */\r
25 public class SimpleNNApplet extends Applet implements MouseListener, MouseMotionListener, ActionListener {\r
26 \r
27     Button    button1, button2, button3, button4;\r
28 \r
29     int       X0           = 10, X1 = 125;\r
30 \r
31     int       Y0           = 55, Y1 = 70, Y2 = 160, Y3 = 240, Y4 = 305;\r
32 \r
33     int       RX0          = 30, RX1 = 60, RX2 = 210, RX3 = 260;\r
34 \r
35     int       RY0          = 225, RY1 = 240;\r
36 \r
37     int       WIDTH        = 7;                                        // 入力データの幅\r
38 \r
39     int       HEIGHT       = 11;                                       // 入力データの高さ\r
40 \r
41     int       INPUT        = WIDTH * HEIGHT;                           // 入力層の数(入力データ数)\r
42 \r
43     int       HIDDEN       = 16;                                       // 隠れ層の数\r
44 \r
45     int       PATTERN      = 10;                                       // パターンの種類\r
46 \r
47     int       OUTPUT       = PATTERN;                                  // 出力層の数(出力データ数)\r
48 \r
49     int       OUTER_CYCLES = 100;                                      // 外部サイクル(一連のパターンの繰返し学習)の回数\r
50 \r
51     int       INNER_CYCLES = 100;                                      // 内部サイクル(同一パターンの繰返し学習)の回数\r
52 \r
53     float     ALPHA        = 1.2f;                                     // 学習の加速係数\r
54 \r
55     float     BETA         = 1.2f;                                     // シグモイド曲線の傾斜\r
56 \r
57     int[]     sample_in    = new int[INPUT];                           // 学習用入力\r
58 \r
59     int[]     written_in   = new int[INPUT];                           // 認識用手書き入力\r
60 \r
61     float[][] weight_ih    = new float[INPUT][HIDDEN];                 // 入力層と隠れ層の間の重み係数\r
62 \r
63     float[]   thresh_h     = new float[HIDDEN];                        // 隠れ層の閾値\r
64 \r
65     float[]   hidden_out   = new float[HIDDEN];                        // 隠れ層の出力\r
66 \r
67     float[][] weight_ho    = new float[HIDDEN][OUTPUT];                // 隠れ層と出力層の間の重み係数\r
68 \r
69     float[]   thresh_o     = new float[OUTPUT];                        // 出力層の閾値\r
70 \r
71     float[]   recog_out    = new float[OUTPUT];                        // 認識出力(出力層の出力)\r
72 \r
73     int[]     teach        = new int[PATTERN];                         // 教師信号\r
74 \r
75     boolean   learning_flag;                                           // 「学習モード」フラグ\r
76 \r
77     int[][]   sample_array;                                            // 学習用入力データの基となるパターン\r
78 \r
79     int[][]   teach_array  = new int[PATTERN][OUTPUT];                 // パターンと出力すべき教師信号の比較表\r
80 \r
81     int       x_new, y_new, x_old, y_old;                              // 手書き文字入力用座標\r
82 \r
83     public void init() {\r
84         // 学習用入力データの元となるパターンの読み込み\r
85         this.sample_array = new int[10][WIDTH * HEIGHT];\r
86         for (int i = 0; i <= 9; i++) {\r
87             BufferedReader r = new BufferedReader(new InputStreamReader(this.getClass().getClassLoader().getResourceAsStream(i + ".txt")));\r
88             try {\r
89                 try {\r
90                     int j = 0;\r
91                     String line;\r
92                     while ((line = r.readLine()) != null) {\r
93                         for (char c : line.toCharArray()) {\r
94                             this.sample_array[i][j] = Integer.parseInt(Character.toString(c));\r
95                             j++;\r
96                         }\r
97                     }\r
98                 } finally {\r
99                     r.close();\r
100                 }\r
101             } catch (IOException e) {\r
102                 throw new RuntimeException(e);\r
103             }\r
104         }\r
105 \r
106         setBackground(Color.gray);\r
107 \r
108         // ボタンの設定\r
109         add(button1 = new Button("  再学習  "));\r
110         add(button2 = new Button(" 学習終了 "));\r
111         add(button3 = new Button("入力クリヤ"));\r
112         add(button4 = new Button("  認  識  "));\r
113         button1.addActionListener(this);\r
114         button2.addActionListener(this);\r
115         button3.addActionListener(this);\r
116         button4.addActionListener(this);\r
117 \r
118         // マウスの設定\r
119         addMouseListener(this);\r
120         addMouseMotionListener(this);\r
121 \r
122         // 教師信号の設定\r
123         for (int q = 0; q < PATTERN; q++)\r
124             for (int k = 0; k < OUTPUT; k++) {\r
125                 if (q == k)\r
126                     teach_array[q][k] = 1;\r
127                 else\r
128                     teach_array[q][k] = 0;\r
129             }\r
130 \r
131         // モードの初期設定\r
132         learning_flag = true;\r
133 \r
134     }\r
135 \r
136     // ------------------- ボタン関係のメソッド ------------------\r
137 \r
138     public void actionPerformed(ActionEvent ae) {\r
139 \r
140         if (ae.getSource() == button1) { // 「再学習」\r
141             learning_flag = true;\r
142             repaint();\r
143         }\r
144         if (ae.getSource() == button2) { // 「学習終了」\r
145             learning_flag = false;\r
146             repaint();\r
147         }\r
148         if (ae.getSource() == button3) { // 「入力クリヤ」\r
149             if (!learning_flag)\r
150                 repaint();\r
151         }\r
152         if (ae.getSource() == button4) { // 「認識」\r
153             if (!learning_flag)\r
154                 recognizeCharacter();\r
155         }\r
156 \r
157     }\r
158 \r
159     // ---------- マウス関係のメソッド(手書き文字入力)----------\r
160 \r
161     public void mousePressed(MouseEvent me) {\r
162         int x = me.getX();\r
163         int y = me.getY();\r
164         if (!learning_flag && x >= RX1 && x <= RX1 + WIDTH * 10 && y >= RY1 && y <= RY1 + HEIGHT * 10) {\r
165             x_old = me.getX();\r
166             y_old = me.getY();\r
167             written_in[(y_old - RY1) / 10 * WIDTH + (x_old - RX1) / 10] = 1;\r
168         }\r
169     }\r
170 \r
171     public void mouseClicked(MouseEvent me) {\r
172     }\r
173 \r
174     public void mouseEntered(MouseEvent me) {\r
175     }\r
176 \r
177     public void mouseExited(MouseEvent me) {\r
178     }\r
179 \r
180     public void mouseReleased(MouseEvent me) {\r
181     }\r
182 \r
183     public void mouseDragged(MouseEvent me) {\r
184         int x = me.getX();\r
185         int y = me.getY();\r
186         if (!learning_flag && x >= RX1 && x <= RX1 + WIDTH * 10 && y >= RY1 && y <= RY1 + HEIGHT * 10) {\r
187             Graphics g = getGraphics();\r
188             x_new = me.getX();\r
189             y_new = me.getY();\r
190             g.drawLine(x_old, y_old, x_new, y_new);\r
191             x_old = x_new;\r
192             y_old = y_new;\r
193             written_in[(y_old - RY1) / 10 * WIDTH + (x_old - RX1) / 10] = 1;\r
194         }\r
195 \r
196     }\r
197 \r
198     public void mouseMoved(MouseEvent me) {\r
199     }\r
200 \r
201     // ---------- 起動時およびrepaint()で呼び出されるメソッド ----------\r
202 \r
203     public void paint(Graphics g) {\r
204 \r
205         int i, j, k, p, q, r, x;\r
206 \r
207         String string;\r
208 \r
209         float outer_error; // 外部サイクルエラー累計\r
210         float inner_error; // 内部サイクルエラー累計\r
211         float temp_error; // 隠れ層の誤差の累計\r
212 \r
213         // 学習モードの背景\r
214         if (learning_flag) {\r
215             g.setColor(new Color(255, 255, 192));\r
216             g.fillRect(5, 35, 590, 460);\r
217             g.setColor(Color.black);\r
218             g.drawString("学習モード", 500, 55);\r
219         }\r
220 \r
221         // 認識モードの背景\r
222         else {\r
223             g.setColor(new Color(192, 255, 255));\r
224             g.fillRect(5, 35, 590, 460);\r
225             g.setColor(Color.black);\r
226             g.drawString("認識モード", 500, 55);\r
227         }\r
228 \r
229         // 学習用パターンの表示\r
230         g.drawString("使用している学習用パターン", X0, Y0);\r
231         for (q = 0; q < PATTERN; q++) {\r
232             x = 56 * q;\r
233             for (j = 0; j < HEIGHT; j++)\r
234                 for (i = 0; i < WIDTH; i++) {\r
235                     if (sample_array[q][WIDTH * j + i] == 1)\r
236                         g.setColor(Color.red);\r
237                     else\r
238                         g.setColor(Color.cyan);\r
239                     g.fillRect(X0 + x + 6 * i, Y1 + 6 * j, 5, 5);\r
240                 }\r
241         }\r
242         g.setColor(Color.black);\r
243 \r
244         // -------------------------------------------------------------------\r
245         // --------------------------- 学習モード ----------------------------\r
246         // -------------------------------------------------------------------\r
247         if (learning_flag) {\r
248 \r
249             // 閾値と重みの乱数設定\r
250             for (j = 0; j < HIDDEN; j++) {\r
251                 thresh_h[j] = (float) Math.random() - 0.5f;\r
252                 for (i = 0; i < INPUT; i++)\r
253                     weight_ih[i][j] = (float) Math.random() - 0.5f;\r
254             }\r
255             for (k = 0; k < OUTPUT; k++) {\r
256                 thresh_o[k] = (float) Math.random() - 0.5f;\r
257                 for (j = 0; j < HIDDEN; j++)\r
258                     weight_ho[j][k] = (float) Math.random() - 0.5f;\r
259             }\r
260 \r
261             // -------------------------- 学習 --------------------------\r
262 \r
263             for (p = 0; p < OUTER_CYCLES; p++) { // 外部サイクル\r
264 \r
265                 outer_error = 0.0f; // 外部二乗誤差のクリヤー\r
266 \r
267                 for (q = 0; q < PATTERN; q++) { // パターンの切り替え\r
268 \r
269                     // パターンに対応した入力と教師信号の設定\r
270                     sample_in = sample_array[q];\r
271                     teach = teach_array[q];\r
272 \r
273                     for (r = 0; r < INNER_CYCLES; r++) { // 内部サイクル\r
274 \r
275                         // 順方向演算\r
276                         forwardNeuralNet(sample_in, recog_out);\r
277 \r
278                         // 逆方向演算(バックプロパゲーション)\r
279                         backwardNeuralNet();\r
280 \r
281                     }\r
282 \r
283                     // 内部二乗誤差の計算\r
284                     inner_error = 0.0f; // 内部二乗誤差のクリヤー\r
285                     for (k = 0; k < OUTPUT; k++)\r
286                         inner_error += (teach[k] - recog_out[k]) * (teach[k] - recog_out[k]);\r
287 \r
288                     outer_error += inner_error; // 外部二乗誤差への累加算\r
289 \r
290                 }\r
291 \r
292                 // 外部サイクルの回数と外部二乗誤差の表示\r
293                 g.drawString("実行中の外部サイクルの回数と二乗誤差", X0, Y2);\r
294                 g.setColor(new Color(255, 255, 192));\r
295                 g.fillRect(X0 + 5, Y2 + 10, 200, 50); // 以前の表示を消去\r
296                 g.setColor(Color.black);\r
297                 g.drawString("OuterCycles=" + String.valueOf(p), X0 + 10, Y2 + 25);\r
298                 g.drawString("TotalSquaredError=" + String.valueOf(outer_error), X0 + 10, Y2 + 45);\r
299 \r
300             }\r
301 \r
302             // --------------------- 学習結果の確認 ---------------------\r
303 \r
304             g.drawString("学習結果の確認", X0, Y3);\r
305             for (k = 0; k < OUTPUT; k++) {\r
306                 g.drawString("Output", X1 + 45 * k, Y3 + 25);\r
307                 g.drawString("  [" + String.valueOf(k) + "]", X1 + 5 + 45 * k, Y3 + 40);\r
308             }\r
309 \r
310             for (q = 0; q < PATTERN; q++) {\r
311 \r
312                 // 入力パターンの設定\r
313                 sample_in = sample_array[q];\r
314 \r
315                 // 順方向演算\r
316                 forwardNeuralNet(sample_in, recog_out);\r
317 \r
318                 // 結果の表示\r
319                 g.setColor(Color.black);\r
320                 g.drawString("TestPattern[" + String.valueOf(q) + "]", X0 + 10, Y4 + 20 * q);\r
321                 for (k = 0; k < OUTPUT; k++) {\r
322                     if (recog_out[k] > 0.99) { // 99% より大は、赤で YES と表示\r
323                         g.setColor(Color.red);\r
324                         string = "YES";\r
325                     } else if (recog_out[k] < 0.01) { // 1% より小は、青で NO と表示\r
326                         g.setColor(Color.blue);\r
327                         string = "NO ";\r
328                     } else { // 1% 以上 99% 以下は、黒で ? と表示\r
329                         g.setColor(Color.black);\r
330                         string = " ? ";\r
331                     }\r
332                     g.drawString(string, X1 + 10 + 45 * k, Y4 + 20 * q);\r
333                 }\r
334 \r
335             }\r
336         }\r
337 \r
338         // -------------------------------------------------------------------\r
339         // --------------------------- 認識モード ----------------------------\r
340         // -------------------------------------------------------------------\r
341         else {\r
342             g.setColor(Color.black);\r
343             g.drawString("マウスで数字を描いて下さい", RX0, RY0);\r
344             g.drawRect(RX1 - 1, RY1 - 1, WIDTH * 10 + 2, HEIGHT * 10 + 2); // 外枠\r
345             g.setColor(Color.gray);\r
346             for (j = 1; j < HEIGHT; j++)\r
347                 g.drawLine(RX1, RY1 + 10 * j, RX1 + WIDTH * 10, RY1 + 10 * j); // 横方向区切り\r
348             for (i = 1; i < WIDTH; i++)\r
349                 g.drawLine(RX1 + 10 * i, RY1, RX1 + 10 * i, RY1 + HEIGHT * 10); // 縦方向区切り\r
350             for (i = 0; i < INPUT; i++)\r
351                 written_in[i] = 0; // 手書き入力データのクリヤ\r
352         }\r
353 \r
354     }\r
355 \r
356     // 順方向演算のメソッド\r
357     public void forwardNeuralNet(int[] input, float[] output) {\r
358 \r
359         float[] out = new float[OUTPUT];\r
360         float[] hidden = new float[HIDDEN];\r
361 \r
362         // 隠れ層出力の計算\r
363         for (int j = 0; j < HIDDEN; j++) {\r
364             hidden[j] = -thresh_h[j];\r
365             for (int i = 0; i < INPUT; i++)\r
366                 hidden[j] += input[i] * weight_ih[i][j];\r
367             hidden_out[j] = sigmoid(hidden[j]);\r
368         }\r
369 \r
370         // 出力層出力の計算\r
371         for (int k = 0; k < OUTPUT; k++) {\r
372             out[k] = -thresh_o[k];\r
373             for (int j = 0; j < HIDDEN; j++)\r
374                 out[k] += hidden_out[j] * weight_ho[j][k];\r
375             output[k] = sigmoid(out[k]);\r
376         }\r
377 \r
378     }\r
379 \r
380     // 逆方向演算のメソッド\r
381     public void backwardNeuralNet() {\r
382 \r
383         int i, j, k;\r
384 \r
385         float[] output_error = new float[OUTPUT]; // 出力層の誤差\r
386         float[] hidden_error = new float[HIDDEN]; // 隠れ層の誤差\r
387 \r
388         float temp_error;\r
389 \r
390         // 出力層の誤差の計算\r
391         for (k = 0; k < OUTPUT; k++)\r
392             output_error[k] = (teach[k] - recog_out[k]) * recog_out[k] * (1.0f - recog_out[k]);\r
393 \r
394         // 隠れ層の誤差の計算\r
395         for (j = 0; j < HIDDEN; j++) {\r
396             temp_error = 0.0f;\r
397             for (k = 0; k < OUTPUT; k++)\r
398                 temp_error += output_error[k] * weight_ho[j][k];\r
399             hidden_error[j] = hidden_out[j] * (1.0f - hidden_out[j]) * temp_error;\r
400         }\r
401 \r
402         // 重みの補正\r
403         for (k = 0; k < OUTPUT; k++)\r
404             for (j = 0; j < HIDDEN; j++)\r
405                 weight_ho[j][k] += ALPHA * output_error[k] * hidden_out[j];\r
406         for (j = 0; j < HIDDEN; j++)\r
407             for (i = 0; i < INPUT; i++)\r
408                 weight_ih[i][j] += ALPHA * hidden_error[j] * sample_in[i];\r
409 \r
410         // 閾値の補正\r
411         for (k = 0; k < OUTPUT; k++)\r
412             thresh_o[k] -= ALPHA * output_error[k];\r
413         for (j = 0; j < HIDDEN; j++)\r
414             thresh_h[j] -= ALPHA * hidden_error[j];\r
415 \r
416     }\r
417 \r
418     // Sigmoid関数を計算するメソッド\r
419     public float sigmoid(float x) {\r
420 \r
421         return 1.0f / (1.0f + (float) Math.exp(-BETA * x));\r
422 \r
423     }\r
424 \r
425     // 入力文字を認識するメソッド\r
426     public void recognizeCharacter() {\r
427 \r
428         Graphics g = getGraphics();\r
429         String string;\r
430 \r
431         // 順方向演算\r
432         forwardNeuralNet(written_in, recog_out);\r
433 \r
434         // 結果の表示\r
435         for (int k = 0; k < OUTPUT; k++) {\r
436             g.setColor(Color.black);\r
437             g.drawString(String.valueOf(k) + "である", RX2, RY1 + 20 * k);\r
438             if (recog_out[k] > 0.8f)\r
439                 g.setColor(Color.red);\r
440             else\r
441                 g.setColor(Color.black);\r
442 \r
443             g.fillRect(RX3, RY1 - 10 + 20 * k, (int) (200 * recog_out[k]), 10);\r
444             g.drawString(String.valueOf((int) (100 * recog_out[k] + 0.5f)) + "%", RX3 + (int) (200 * recog_out[k]) + 10, RY1 + 20 * k);\r
445         }\r
446 \r
447     }\r
448 \r
449 }\r