2 package jp.gr.java_conf.u6k.simplenn;
\r
4 import java.applet.Applet;
\r
5 import java.awt.Button;
\r
6 import java.awt.Color;
\r
7 import java.awt.Graphics;
\r
8 import java.awt.event.ActionEvent;
\r
9 import java.awt.event.ActionListener;
\r
10 import java.awt.event.MouseEvent;
\r
11 import java.awt.event.MouseListener;
\r
12 import java.awt.event.MouseMotionListener;
\r
16 * バックプロパゲーション法で学習するニューラル・ネットワークのデモ・アプレットです。このソースコードは、<a href="http://codezine.jp/">CodeZine</a>の記事「<a href="http://codezine.jp/a/article/aid/372.aspx">ニューラルネットワークを用いたパターン認識</a>」を参考にしています。
\r
20 * @see http://codezine.jp/a/article/aid/372.aspx
\r
22 public class SimpleNNApplet extends Applet implements MouseListener, MouseMotionListener, ActionListener {
\r
24 Button button1, button2, button3, button4;
\r
26 int X0 = 10, X1 = 125;
\r
28 int Y0 = 55, Y1 = 70, Y2 = 160, Y3 = 240, Y4 = 305;
\r
30 int RX0 = 30, RX1 = 60, RX2 = 210, RX3 = 260;
\r
32 int RY0 = 225, RY1 = 240;
\r
34 int WIDTH = 7; // 入力データの幅
\r
36 int HEIGHT = 11; // 入力データの高さ
\r
38 int INPUT = WIDTH * HEIGHT; // 入力層の数(入力データ数)
\r
40 int HIDDEN = 16; // 隠れ層の数
\r
42 int PATTERN = 10; // パターンの種類
\r
44 int OUTPUT = PATTERN; // 出力層の数(出力データ数)
\r
46 int OUTER_CYCLES = 100; // 外部サイクル(一連のパターンの繰返し学習)の回数
\r
48 int INNER_CYCLES = 100; // 内部サイクル(同一パターンの繰返し学習)の回数
\r
50 float ALPHA = 1.2f; // 学習の加速係数
\r
52 float BETA = 1.2f; // シグモイド曲線の傾斜
\r
54 int[] sample_in = new int[INPUT]; // 学習用入力
\r
56 int[] written_in = new int[INPUT]; // 認識用手書き入力
\r
58 float[][] weight_ih = new float[INPUT][HIDDEN]; // 入力層と隠れ層の間の重み係数
\r
60 float[] thresh_h = new float[HIDDEN]; // 隠れ層の閾値
\r
62 float[] hidden_out = new float[HIDDEN]; // 隠れ層の出力
\r
64 float[][] weight_ho = new float[HIDDEN][OUTPUT]; // 隠れ層と出力層の間の重み係数
\r
66 float[] thresh_o = new float[OUTPUT]; // 出力層の閾値
\r
68 float[] recog_out = new float[OUTPUT]; // 認識出力(出力層の出力)
\r
70 int[] teach = new int[PATTERN]; // 教師信号
\r
72 boolean learning_flag; // 「学習モード」フラグ
\r
74 // 学習用入力データの基となるパターン
\r
75 int[][] sample_array = { { 0, 0, 1, 1, 1, 0, 0, // '0'
\r
76 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0 },
\r
78 { 0, 0, 0, 1, 0, 0, 0, // '1'
\r
79 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 },
\r
81 { 0, 0, 1, 1, 1, 0, 0, // '2'
\r
82 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1 },
\r
84 { 0, 0, 1, 1, 1, 0, 0, // '3'
\r
85 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0 },
\r
87 { 0, 0, 0, 0, 1, 0, 0, // '4'
\r
88 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0 },
\r
90 { 1, 1, 1, 1, 1, 1, 1, // '5'
\r
91 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0 },
\r
93 { 0, 0, 0, 0, 1, 1, 0, // '6'
\r
94 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0 },
\r
96 { 1, 1, 1, 1, 1, 1, 1, // '7'
\r
97 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 },
\r
99 { 0, 0, 1, 1, 1, 0, 0, // '8'
\r
100 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0 },
\r
102 { 0, 1, 1, 1, 1, 1, 0, // '9'
\r
103 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0 } };
\r
105 int[][] teach_array = new int[PATTERN][OUTPUT]; // パターンと出力すべき教師信号の比較表
\r
107 int x_new, y_new, x_old, y_old; // 手書き文字入力用座標
\r
109 public void init() {
\r
111 setBackground(Color.gray);
\r
114 add(button1 = new Button(" 再学習 "));
\r
115 add(button2 = new Button(" 学習終了 "));
\r
116 add(button3 = new Button("入力クリヤ"));
\r
117 add(button4 = new Button(" 認 識 "));
\r
118 button1.addActionListener(this);
\r
119 button2.addActionListener(this);
\r
120 button3.addActionListener(this);
\r
121 button4.addActionListener(this);
\r
124 addMouseListener(this);
\r
125 addMouseMotionListener(this);
\r
128 for (int q = 0; q < PATTERN; q++)
\r
129 for (int k = 0; k < OUTPUT; k++) {
\r
131 teach_array[q][k] = 1;
\r
133 teach_array[q][k] = 0;
\r
137 learning_flag = true;
\r
141 // ------------------- ボタン関係のメソッド ------------------
\r
143 public void actionPerformed(ActionEvent ae) {
\r
145 if (ae.getSource() == button1) { // 「再学習」
\r
146 learning_flag = true;
\r
149 if (ae.getSource() == button2) { // 「学習終了」
\r
150 learning_flag = false;
\r
153 if (ae.getSource() == button3) { // 「入力クリヤ」
\r
154 if (!learning_flag)
\r
157 if (ae.getSource() == button4) { // 「認識」
\r
158 if (!learning_flag)
\r
159 recognizeCharacter();
\r
164 // ---------- マウス関係のメソッド(手書き文字入力)----------
\r
166 public void mousePressed(MouseEvent me) {
\r
169 if (!learning_flag && x >= RX1 && x <= RX1 + WIDTH * 10 && y >= RY1 && y <= RY1 + HEIGHT * 10) {
\r
172 written_in[(y_old - RY1) / 10 * WIDTH + (x_old - RX1) / 10] = 1;
\r
176 public void mouseClicked(MouseEvent me) {
\r
179 public void mouseEntered(MouseEvent me) {
\r
182 public void mouseExited(MouseEvent me) {
\r
185 public void mouseReleased(MouseEvent me) {
\r
188 public void mouseDragged(MouseEvent me) {
\r
191 if (!learning_flag && x >= RX1 && x <= RX1 + WIDTH * 10 && y >= RY1 && y <= RY1 + HEIGHT * 10) {
\r
192 Graphics g = getGraphics();
\r
195 g.drawLine(x_old, y_old, x_new, y_new);
\r
198 written_in[(y_old - RY1) / 10 * WIDTH + (x_old - RX1) / 10] = 1;
\r
203 public void mouseMoved(MouseEvent me) {
\r
206 // ---------- 起動時およびrepaint()で呼び出されるメソッド ----------
\r
208 public void paint(Graphics g) {
\r
210 int i, j, k, p, q, r, x;
\r
214 float outer_error; // 外部サイクルエラー累計
\r
215 float inner_error; // 内部サイクルエラー累計
\r
216 float temp_error; // 隠れ層の誤差の累計
\r
219 if (learning_flag) {
\r
220 g.setColor(new Color(255, 255, 192));
\r
221 g.fillRect(5, 35, 590, 460);
\r
222 g.setColor(Color.black);
\r
223 g.drawString("学習モード", 500, 55);
\r
228 g.setColor(new Color(192, 255, 255));
\r
229 g.fillRect(5, 35, 590, 460);
\r
230 g.setColor(Color.black);
\r
231 g.drawString("認識モード", 500, 55);
\r
235 g.drawString("使用している学習用パターン", X0, Y0);
\r
236 for (q = 0; q < PATTERN; q++) {
\r
238 for (j = 0; j < HEIGHT; j++)
\r
239 for (i = 0; i < WIDTH; i++) {
\r
240 if (sample_array[q][WIDTH * j + i] == 1)
\r
241 g.setColor(Color.red);
\r
243 g.setColor(Color.cyan);
\r
244 g.fillRect(X0 + x + 6 * i, Y1 + 6 * j, 5, 5);
\r
247 g.setColor(Color.black);
\r
249 // -------------------------------------------------------------------
\r
250 // --------------------------- 学習モード ----------------------------
\r
251 // -------------------------------------------------------------------
\r
252 if (learning_flag) {
\r
255 for (j = 0; j < HIDDEN; j++) {
\r
256 thresh_h[j] = (float) Math.random() - 0.5f;
\r
257 for (i = 0; i < INPUT; i++)
\r
258 weight_ih[i][j] = (float) Math.random() - 0.5f;
\r
260 for (k = 0; k < OUTPUT; k++) {
\r
261 thresh_o[k] = (float) Math.random() - 0.5f;
\r
262 for (j = 0; j < HIDDEN; j++)
\r
263 weight_ho[j][k] = (float) Math.random() - 0.5f;
\r
266 // -------------------------- 学習 --------------------------
\r
268 for (p = 0; p < OUTER_CYCLES; p++) { // 外部サイクル
\r
270 outer_error = 0.0f; // 外部二乗誤差のクリヤー
\r
272 for (q = 0; q < PATTERN; q++) { // パターンの切り替え
\r
274 // パターンに対応した入力と教師信号の設定
\r
275 sample_in = sample_array[q];
\r
276 teach = teach_array[q];
\r
278 for (r = 0; r < INNER_CYCLES; r++) { // 内部サイクル
\r
281 forwardNeuralNet(sample_in, recog_out);
\r
283 // 逆方向演算(バックプロパゲーション)
\r
284 backwardNeuralNet();
\r
289 inner_error = 0.0f; // 内部二乗誤差のクリヤー
\r
290 for (k = 0; k < OUTPUT; k++)
\r
291 inner_error += (teach[k] - recog_out[k]) * (teach[k] - recog_out[k]);
\r
293 outer_error += inner_error; // 外部二乗誤差への累加算
\r
297 // 外部サイクルの回数と外部二乗誤差の表示
\r
298 g.drawString("実行中の外部サイクルの回数と二乗誤差", X0, Y2);
\r
299 g.setColor(new Color(255, 255, 192));
\r
300 g.fillRect(X0 + 5, Y2 + 10, 200, 50); // 以前の表示を消去
\r
301 g.setColor(Color.black);
\r
302 g.drawString("OuterCycles=" + String.valueOf(p), X0 + 10, Y2 + 25);
\r
303 g.drawString("TotalSquaredError=" + String.valueOf(outer_error), X0 + 10, Y2 + 45);
\r
307 // --------------------- 学習結果の確認 ---------------------
\r
309 g.drawString("学習結果の確認", X0, Y3);
\r
310 for (k = 0; k < OUTPUT; k++) {
\r
311 g.drawString("Output", X1 + 45 * k, Y3 + 25);
\r
312 g.drawString(" [" + String.valueOf(k) + "]", X1 + 5 + 45 * k, Y3 + 40);
\r
315 for (q = 0; q < PATTERN; q++) {
\r
318 sample_in = sample_array[q];
\r
321 forwardNeuralNet(sample_in, recog_out);
\r
324 g.setColor(Color.black);
\r
325 g.drawString("TestPattern[" + String.valueOf(q) + "]", X0 + 10, Y4 + 20 * q);
\r
326 for (k = 0; k < OUTPUT; k++) {
\r
327 if (recog_out[k] > 0.99) { // 99% より大は、赤で YES と表示
\r
328 g.setColor(Color.red);
\r
330 } else if (recog_out[k] < 0.01) { // 1% より小は、青で NO と表示
\r
331 g.setColor(Color.blue);
\r
333 } else { // 1% 以上 99% 以下は、黒で ? と表示
\r
334 g.setColor(Color.black);
\r
337 g.drawString(string, X1 + 10 + 45 * k, Y4 + 20 * q);
\r
343 // -------------------------------------------------------------------
\r
344 // --------------------------- 認識モード ----------------------------
\r
345 // -------------------------------------------------------------------
\r
347 g.setColor(Color.black);
\r
348 g.drawString("マウスで数字を描いて下さい", RX0, RY0);
\r
349 g.drawRect(RX1 - 1, RY1 - 1, WIDTH * 10 + 2, HEIGHT * 10 + 2); // 外枠
\r
350 g.setColor(Color.gray);
\r
351 for (j = 1; j < HEIGHT; j++)
\r
352 g.drawLine(RX1, RY1 + 10 * j, RX1 + WIDTH * 10, RY1 + 10 * j); // 横方向区切り
\r
353 for (i = 1; i < WIDTH; i++)
\r
354 g.drawLine(RX1 + 10 * i, RY1, RX1 + 10 * i, RY1 + HEIGHT * 10); // 縦方向区切り
\r
355 for (i = 0; i < INPUT; i++)
\r
356 written_in[i] = 0; // 手書き入力データのクリヤ
\r
362 public void forwardNeuralNet(int[] input, float[] output) {
\r
364 float[] out = new float[OUTPUT];
\r
365 float[] hidden = new float[HIDDEN];
\r
368 for (int j = 0; j < HIDDEN; j++) {
\r
369 hidden[j] = -thresh_h[j];
\r
370 for (int i = 0; i < INPUT; i++)
\r
371 hidden[j] += input[i] * weight_ih[i][j];
\r
372 hidden_out[j] = sigmoid(hidden[j]);
\r
376 for (int k = 0; k < OUTPUT; k++) {
\r
377 out[k] = -thresh_o[k];
\r
378 for (int j = 0; j < HIDDEN; j++)
\r
379 out[k] += hidden_out[j] * weight_ho[j][k];
\r
380 output[k] = sigmoid(out[k]);
\r
386 public void backwardNeuralNet() {
\r
390 float[] output_error = new float[OUTPUT]; // 出力層の誤差
\r
391 float[] hidden_error = new float[HIDDEN]; // 隠れ層の誤差
\r
396 for (k = 0; k < OUTPUT; k++)
\r
397 output_error[k] = (teach[k] - recog_out[k]) * recog_out[k] * (1.0f - recog_out[k]);
\r
400 for (j = 0; j < HIDDEN; j++) {
\r
402 for (k = 0; k < OUTPUT; k++)
\r
403 temp_error += output_error[k] * weight_ho[j][k];
\r
404 hidden_error[j] = hidden_out[j] * (1.0f - hidden_out[j]) * temp_error;
\r
408 for (k = 0; k < OUTPUT; k++)
\r
409 for (j = 0; j < HIDDEN; j++)
\r
410 weight_ho[j][k] += ALPHA * output_error[k] * hidden_out[j];
\r
411 for (j = 0; j < HIDDEN; j++)
\r
412 for (i = 0; i < INPUT; i++)
\r
413 weight_ih[i][j] += ALPHA * hidden_error[j] * sample_in[i];
\r
416 for (k = 0; k < OUTPUT; k++)
\r
417 thresh_o[k] -= ALPHA * output_error[k];
\r
418 for (j = 0; j < HIDDEN; j++)
\r
419 thresh_h[j] -= ALPHA * hidden_error[j];
\r
423 // Sigmoid関数を計算するメソッド
\r
424 public float sigmoid(float x) {
\r
426 return 1.0f / (1.0f + (float) Math.exp(-BETA * x));
\r
431 public void recognizeCharacter() {
\r
433 Graphics g = getGraphics();
\r
437 forwardNeuralNet(written_in, recog_out);
\r
440 for (int k = 0; k < OUTPUT; k++) {
\r
441 g.setColor(Color.black);
\r
442 g.drawString(String.valueOf(k) + "である", RX2, RY1 + 20 * k);
\r
443 if (recog_out[k] > 0.8f)
\r
444 g.setColor(Color.red);
\r
446 g.setColor(Color.black);
\r
448 g.fillRect(RX3, RY1 - 10 + 20 * k, (int) (200 * recog_out[k]), 10);
\r
449 g.drawString(String.valueOf((int) (100 * recog_out[k] + 0.5f)) + "%", RX3 + (int) (200 * recog_out[k]) + 10, RY1 + 20 * k);
\r