2 package jp.gr.java_conf.u6k.simplenn;
\r
4 import java.applet.Applet;
\r
5 import java.awt.Button;
\r
6 import java.awt.Color;
\r
7 import java.awt.Graphics;
\r
8 import java.awt.event.ActionEvent;
\r
9 import java.awt.event.ActionListener;
\r
10 import java.awt.event.MouseEvent;
\r
11 import java.awt.event.MouseListener;
\r
12 import java.awt.event.MouseMotionListener;
\r
13 import java.io.BufferedReader;
\r
14 import java.io.IOException;
\r
15 import java.io.InputStreamReader;
\r
19 * バックプロパゲーション法で学習するニューラル・ネットワークのデモ・アプレットです。このソースコードは、<a href="http://codezine.jp/">CodeZine</a>の記事「<a href="http://codezine.jp/a/article/aid/372.aspx">ニューラルネットワークを用いたパターン認識</a>」を参考にしています。
\r
23 * @see http://codezine.jp/a/article/aid/372.aspx
\r
25 public class SimpleNNApplet extends Applet implements MouseListener, MouseMotionListener, ActionListener {
\r
27 Button button1, button2, button3, button4;
\r
29 int X0 = 10, X1 = 125;
\r
31 int Y0 = 55, Y1 = 70, Y2 = 160, Y3 = 240, Y4 = 305;
\r
33 int RX0 = 30, RX1 = 60, RX2 = 210, RX3 = 260;
\r
35 int RY0 = 225, RY1 = 240;
\r
37 int WIDTH = 7; // 入力データの幅
\r
39 int HEIGHT = 11; // 入力データの高さ
\r
41 int INPUT = WIDTH * HEIGHT; // 入力層の数(入力データ数)
\r
43 int HIDDEN = 16; // 隠れ層の数
\r
45 int PATTERN = 10; // パターンの種類
\r
47 int OUTPUT = PATTERN; // 出力層の数(出力データ数)
\r
49 int OUTER_CYCLES = 100; // 外部サイクル(一連のパターンの繰返し学習)の回数
\r
51 int INNER_CYCLES = 100; // 内部サイクル(同一パターンの繰返し学習)の回数
\r
53 float ALPHA = 1.2f; // 学習の加速係数
\r
55 float BETA = 1.2f; // シグモイド曲線の傾斜
\r
57 int[] sample_in = new int[INPUT]; // 学習用入力
\r
59 int[] written_in = new int[INPUT]; // 認識用手書き入力
\r
61 float[][] weight_ih = new float[INPUT][HIDDEN]; // 入力層と隠れ層の間の重み係数
\r
63 float[] thresh_h = new float[HIDDEN]; // 隠れ層の閾値
\r
65 float[] hidden_out = new float[HIDDEN]; // 隠れ層の出力
\r
67 float[][] weight_ho = new float[HIDDEN][OUTPUT]; // 隠れ層と出力層の間の重み係数
\r
69 float[] thresh_o = new float[OUTPUT]; // 出力層の閾値
\r
71 float[] recog_out = new float[OUTPUT]; // 認識出力(出力層の出力)
\r
73 int[] teach = new int[PATTERN]; // 教師信号
\r
75 boolean learning_flag; // 「学習モード」フラグ
\r
77 int[][] sample_array; // 学習用入力データの基となるパターン
\r
79 int[][] teach_array = new int[PATTERN][OUTPUT]; // パターンと出力すべき教師信号の比較表
\r
81 int x_new, y_new, x_old, y_old; // 手書き文字入力用座標
\r
83 public void init() {
\r
84 // 学習用入力データの元となるパターンの読み込み
\r
85 this.sample_array = new int[10][WIDTH * HEIGHT];
\r
86 for (int i = 0; i <= 9; i++) {
\r
87 BufferedReader r = new BufferedReader(new InputStreamReader(this.getClass().getClassLoader().getResourceAsStream(i + ".txt")));
\r
92 while ((line = r.readLine()) != null) {
\r
93 for (char c : line.toCharArray()) {
\r
94 this.sample_array[i][j] = Integer.parseInt(Character.toString(c));
\r
101 } catch (IOException e) {
\r
102 throw new RuntimeException(e);
\r
106 setBackground(Color.gray);
\r
109 add(button1 = new Button(" 再学習 "));
\r
110 add(button2 = new Button(" 学習終了 "));
\r
111 add(button3 = new Button("入力クリヤ"));
\r
112 add(button4 = new Button(" 認 識 "));
\r
113 button1.addActionListener(this);
\r
114 button2.addActionListener(this);
\r
115 button3.addActionListener(this);
\r
116 button4.addActionListener(this);
\r
119 addMouseListener(this);
\r
120 addMouseMotionListener(this);
\r
123 for (int q = 0; q < PATTERN; q++)
\r
124 for (int k = 0; k < OUTPUT; k++) {
\r
126 teach_array[q][k] = 1;
\r
128 teach_array[q][k] = 0;
\r
132 learning_flag = true;
\r
136 // ------------------- ボタン関係のメソッド ------------------
\r
138 public void actionPerformed(ActionEvent ae) {
\r
140 if (ae.getSource() == button1) { // 「再学習」
\r
141 learning_flag = true;
\r
144 if (ae.getSource() == button2) { // 「学習終了」
\r
145 learning_flag = false;
\r
148 if (ae.getSource() == button3) { // 「入力クリヤ」
\r
149 if (!learning_flag)
\r
152 if (ae.getSource() == button4) { // 「認識」
\r
153 if (!learning_flag)
\r
154 recognizeCharacter();
\r
159 // ---------- マウス関係のメソッド(手書き文字入力)----------
\r
161 public void mousePressed(MouseEvent me) {
\r
164 if (!learning_flag && x >= RX1 && x <= RX1 + WIDTH * 10 && y >= RY1 && y <= RY1 + HEIGHT * 10) {
\r
167 written_in[(y_old - RY1) / 10 * WIDTH + (x_old - RX1) / 10] = 1;
\r
171 public void mouseClicked(MouseEvent me) {
\r
174 public void mouseEntered(MouseEvent me) {
\r
177 public void mouseExited(MouseEvent me) {
\r
180 public void mouseReleased(MouseEvent me) {
\r
183 public void mouseDragged(MouseEvent me) {
\r
186 if (!learning_flag && x >= RX1 && x <= RX1 + WIDTH * 10 && y >= RY1 && y <= RY1 + HEIGHT * 10) {
\r
187 Graphics g = getGraphics();
\r
190 g.drawLine(x_old, y_old, x_new, y_new);
\r
193 written_in[(y_old - RY1) / 10 * WIDTH + (x_old - RX1) / 10] = 1;
\r
198 public void mouseMoved(MouseEvent me) {
\r
201 // ---------- 起動時およびrepaint()で呼び出されるメソッド ----------
\r
203 public void paint(Graphics g) {
\r
205 int i, j, k, p, q, r, x;
\r
209 float outer_error; // 外部サイクルエラー累計
\r
210 float inner_error; // 内部サイクルエラー累計
\r
211 float temp_error; // 隠れ層の誤差の累計
\r
214 if (learning_flag) {
\r
215 g.setColor(new Color(255, 255, 192));
\r
216 g.fillRect(5, 35, 590, 460);
\r
217 g.setColor(Color.black);
\r
218 g.drawString("学習モード", 500, 55);
\r
223 g.setColor(new Color(192, 255, 255));
\r
224 g.fillRect(5, 35, 590, 460);
\r
225 g.setColor(Color.black);
\r
226 g.drawString("認識モード", 500, 55);
\r
230 g.drawString("使用している学習用パターン", X0, Y0);
\r
231 for (q = 0; q < PATTERN; q++) {
\r
233 for (j = 0; j < HEIGHT; j++)
\r
234 for (i = 0; i < WIDTH; i++) {
\r
235 if (sample_array[q][WIDTH * j + i] == 1)
\r
236 g.setColor(Color.red);
\r
238 g.setColor(Color.cyan);
\r
239 g.fillRect(X0 + x + 6 * i, Y1 + 6 * j, 5, 5);
\r
242 g.setColor(Color.black);
\r
244 // -------------------------------------------------------------------
\r
245 // --------------------------- 学習モード ----------------------------
\r
246 // -------------------------------------------------------------------
\r
247 if (learning_flag) {
\r
250 for (j = 0; j < HIDDEN; j++) {
\r
251 thresh_h[j] = (float) Math.random() - 0.5f;
\r
252 for (i = 0; i < INPUT; i++)
\r
253 weight_ih[i][j] = (float) Math.random() - 0.5f;
\r
255 for (k = 0; k < OUTPUT; k++) {
\r
256 thresh_o[k] = (float) Math.random() - 0.5f;
\r
257 for (j = 0; j < HIDDEN; j++)
\r
258 weight_ho[j][k] = (float) Math.random() - 0.5f;
\r
261 // -------------------------- 学習 --------------------------
\r
263 for (p = 0; p < OUTER_CYCLES; p++) { // 外部サイクル
\r
265 outer_error = 0.0f; // 外部二乗誤差のクリヤー
\r
267 for (q = 0; q < PATTERN; q++) { // パターンの切り替え
\r
269 // パターンに対応した入力と教師信号の設定
\r
270 sample_in = sample_array[q];
\r
271 teach = teach_array[q];
\r
273 for (r = 0; r < INNER_CYCLES; r++) { // 内部サイクル
\r
276 forwardNeuralNet(sample_in, recog_out);
\r
278 // 逆方向演算(バックプロパゲーション)
\r
279 backwardNeuralNet();
\r
284 inner_error = 0.0f; // 内部二乗誤差のクリヤー
\r
285 for (k = 0; k < OUTPUT; k++)
\r
286 inner_error += (teach[k] - recog_out[k]) * (teach[k] - recog_out[k]);
\r
288 outer_error += inner_error; // 外部二乗誤差への累加算
\r
292 // 外部サイクルの回数と外部二乗誤差の表示
\r
293 g.drawString("実行中の外部サイクルの回数と二乗誤差", X0, Y2);
\r
294 g.setColor(new Color(255, 255, 192));
\r
295 g.fillRect(X0 + 5, Y2 + 10, 200, 50); // 以前の表示を消去
\r
296 g.setColor(Color.black);
\r
297 g.drawString("OuterCycles=" + String.valueOf(p), X0 + 10, Y2 + 25);
\r
298 g.drawString("TotalSquaredError=" + String.valueOf(outer_error), X0 + 10, Y2 + 45);
\r
302 // --------------------- 学習結果の確認 ---------------------
\r
304 g.drawString("学習結果の確認", X0, Y3);
\r
305 for (k = 0; k < OUTPUT; k++) {
\r
306 g.drawString("Output", X1 + 45 * k, Y3 + 25);
\r
307 g.drawString(" [" + String.valueOf(k) + "]", X1 + 5 + 45 * k, Y3 + 40);
\r
310 for (q = 0; q < PATTERN; q++) {
\r
313 sample_in = sample_array[q];
\r
316 forwardNeuralNet(sample_in, recog_out);
\r
319 g.setColor(Color.black);
\r
320 g.drawString("TestPattern[" + String.valueOf(q) + "]", X0 + 10, Y4 + 20 * q);
\r
321 for (k = 0; k < OUTPUT; k++) {
\r
322 if (recog_out[k] > 0.99) { // 99% より大は、赤で YES と表示
\r
323 g.setColor(Color.red);
\r
325 } else if (recog_out[k] < 0.01) { // 1% より小は、青で NO と表示
\r
326 g.setColor(Color.blue);
\r
328 } else { // 1% 以上 99% 以下は、黒で ? と表示
\r
329 g.setColor(Color.black);
\r
332 g.drawString(string, X1 + 10 + 45 * k, Y4 + 20 * q);
\r
338 // -------------------------------------------------------------------
\r
339 // --------------------------- 認識モード ----------------------------
\r
340 // -------------------------------------------------------------------
\r
342 g.setColor(Color.black);
\r
343 g.drawString("マウスで数字を描いて下さい", RX0, RY0);
\r
344 g.drawRect(RX1 - 1, RY1 - 1, WIDTH * 10 + 2, HEIGHT * 10 + 2); // 外枠
\r
345 g.setColor(Color.gray);
\r
346 for (j = 1; j < HEIGHT; j++)
\r
347 g.drawLine(RX1, RY1 + 10 * j, RX1 + WIDTH * 10, RY1 + 10 * j); // 横方向区切り
\r
348 for (i = 1; i < WIDTH; i++)
\r
349 g.drawLine(RX1 + 10 * i, RY1, RX1 + 10 * i, RY1 + HEIGHT * 10); // 縦方向区切り
\r
350 for (i = 0; i < INPUT; i++)
\r
351 written_in[i] = 0; // 手書き入力データのクリヤ
\r
357 public void forwardNeuralNet(int[] input, float[] output) {
\r
359 float[] out = new float[OUTPUT];
\r
360 float[] hidden = new float[HIDDEN];
\r
363 for (int j = 0; j < HIDDEN; j++) {
\r
364 hidden[j] = -thresh_h[j];
\r
365 for (int i = 0; i < INPUT; i++)
\r
366 hidden[j] += input[i] * weight_ih[i][j];
\r
367 hidden_out[j] = sigmoid(hidden[j]);
\r
371 for (int k = 0; k < OUTPUT; k++) {
\r
372 out[k] = -thresh_o[k];
\r
373 for (int j = 0; j < HIDDEN; j++)
\r
374 out[k] += hidden_out[j] * weight_ho[j][k];
\r
375 output[k] = sigmoid(out[k]);
\r
381 public void backwardNeuralNet() {
\r
385 float[] output_error = new float[OUTPUT]; // 出力層の誤差
\r
386 float[] hidden_error = new float[HIDDEN]; // 隠れ層の誤差
\r
391 for (k = 0; k < OUTPUT; k++)
\r
392 output_error[k] = (teach[k] - recog_out[k]) * recog_out[k] * (1.0f - recog_out[k]);
\r
395 for (j = 0; j < HIDDEN; j++) {
\r
397 for (k = 0; k < OUTPUT; k++)
\r
398 temp_error += output_error[k] * weight_ho[j][k];
\r
399 hidden_error[j] = hidden_out[j] * (1.0f - hidden_out[j]) * temp_error;
\r
403 for (k = 0; k < OUTPUT; k++)
\r
404 for (j = 0; j < HIDDEN; j++)
\r
405 weight_ho[j][k] += ALPHA * output_error[k] * hidden_out[j];
\r
406 for (j = 0; j < HIDDEN; j++)
\r
407 for (i = 0; i < INPUT; i++)
\r
408 weight_ih[i][j] += ALPHA * hidden_error[j] * sample_in[i];
\r
411 for (k = 0; k < OUTPUT; k++)
\r
412 thresh_o[k] -= ALPHA * output_error[k];
\r
413 for (j = 0; j < HIDDEN; j++)
\r
414 thresh_h[j] -= ALPHA * hidden_error[j];
\r
418 // Sigmoid関数を計算するメソッド
\r
419 public float sigmoid(float x) {
\r
421 return 1.0f / (1.0f + (float) Math.exp(-BETA * x));
\r
426 public void recognizeCharacter() {
\r
428 Graphics g = getGraphics();
\r
432 forwardNeuralNet(written_in, recog_out);
\r
435 for (int k = 0; k < OUTPUT; k++) {
\r
436 g.setColor(Color.black);
\r
437 g.drawString(String.valueOf(k) + "である", RX2, RY1 + 20 * k);
\r
438 if (recog_out[k] > 0.8f)
\r
439 g.setColor(Color.red);
\r
441 g.setColor(Color.black);
\r
443 g.fillRect(RX3, RY1 - 10 + 20 * k, (int) (200 * recog_out[k]), 10);
\r
444 g.drawString(String.valueOf((int) (100 * recog_out[k] + 0.5f)) + "%", RX3 + (int) (200 * recog_out[k]) + 10, RY1 + 20 * k);
\r