+import java.applet.Applet;\r
+import java.awt.Button;\r
+import java.awt.Color;\r
+import java.awt.Graphics;\r
+import java.awt.event.ActionEvent;\r
+import java.awt.event.ActionListener;\r
+import java.awt.event.MouseEvent;\r
+import java.awt.event.MouseListener;\r
+import java.awt.event.MouseMotionListener;\r
+\r
+/**\r
+ * <p>\r
+ * バックプロパゲーション法で学習するニューラル・ネットワークのデモ・アプレットです。このソースコードは、<a href="http://codezine.jp/">CodeZine</a>の記事「<a href="http://codezine.jp/a/article/aid/372.aspx">ニューラルネットワークを用いたパターン認識</a>」を参考にしています。\r
+ * </p>\r
+ * \r
+ * @version $Id$\r
+ * @see http://codezine.jp/a/article/aid/372.aspx\r
+ */\r
+public class SimpleNNApplet extends Applet implements MouseListener, MouseMotionListener, ActionListener {\r
+\r
+ Button button1, button2, button3, button4;\r
+\r
+ int X0 = 10, X1 = 125;\r
+\r
+ int Y0 = 55, Y1 = 70, Y2 = 160, Y3 = 240, Y4 = 305;\r
+\r
+ int RX0 = 30, RX1 = 60, RX2 = 210, RX3 = 260;\r
+\r
+ int RY0 = 225, RY1 = 240;\r
+\r
+ int WIDTH = 7; // 入力データの幅\r
+\r
+ int HEIGHT = 11; // 入力データの高さ\r
+\r
+ int INPUT = WIDTH * HEIGHT; // 入力層の数(入力データ数)\r
+\r
+ int HIDDEN = 16; // 隠れ層の数\r
+\r
+ int PATTERN = 10; // パターンの種類\r
+\r
+ int OUTPUT = PATTERN; // 出力層の数(出力データ数)\r
+\r
+ int OUTER_CYCLES = 100; // 外部サイクル(一連のパターンの繰返し学習)の回数\r
+\r
+ int INNER_CYCLES = 100; // 内部サイクル(同一パターンの繰返し学習)の回数\r
+\r
+ float ALPHA = 1.2f; // 学習の加速係数\r
+\r
+ float BETA = 1.2f; // シグモイド曲線の傾斜\r
+\r
+ int[] sample_in = new int[INPUT]; // 学習用入力\r
+\r
+ int[] written_in = new int[INPUT]; // 認識用手書き入力\r
+\r
+ float[][] weight_ih = new float[INPUT][HIDDEN]; // 入力層と隠れ層の間の重み係数\r
+\r
+ float[] thresh_h = new float[HIDDEN]; // 隠れ層の閾値\r
+\r
+ float[] hidden_out = new float[HIDDEN]; // 隠れ層の出力\r
+\r
+ float[][] weight_ho = new float[HIDDEN][OUTPUT]; // 隠れ層と出力層の間の重み係数\r
+\r
+ float[] thresh_o = new float[OUTPUT]; // 出力層の閾値\r
+\r
+ float[] recog_out = new float[OUTPUT]; // 認識出力(出力層の出力)\r
+\r
+ int[] teach = new int[PATTERN]; // 教師信号\r
+\r
+ boolean learning_flag; // 「学習モード」フラグ\r
+\r
+ // 学習用入力データの基となるパターン\r
+ int[][] sample_array = { { 0, 0, 1, 1, 1, 0, 0, // '0'\r
+ 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0 },\r
+\r
+ { 0, 0, 0, 1, 0, 0, 0, // '1'\r
+ 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 },\r
+\r
+ { 0, 0, 1, 1, 1, 0, 0, // '2'\r
+ 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1 },\r
+\r
+ { 0, 0, 1, 1, 1, 0, 0, // '3'\r
+ 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0 },\r
+\r
+ { 0, 0, 0, 0, 1, 0, 0, // '4'\r
+ 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0 },\r
+\r
+ { 1, 1, 1, 1, 1, 1, 1, // '5'\r
+ 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0 },\r
+\r
+ { 0, 0, 0, 0, 1, 1, 0, // '6'\r
+ 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0 },\r
+\r
+ { 1, 1, 1, 1, 1, 1, 1, // '7'\r
+ 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 },\r
+\r
+ { 0, 0, 1, 1, 1, 0, 0, // '8'\r
+ 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0 },\r
+\r
+ { 0, 1, 1, 1, 1, 1, 0, // '9'\r
+ 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0 } };\r
+\r
+ int[][] teach_array = new int[PATTERN][OUTPUT]; // パターンと出力すべき教師信号の比較表\r
+\r
+ int x_new, y_new, x_old, y_old; // 手書き文字入力用座標\r
+\r
+ public void init() {\r
+\r
+ setBackground(Color.gray);\r
+\r
+ // ボタンの設定\r
+ add(button1 = new Button(" 再学習 "));\r
+ add(button2 = new Button(" 学習終了 "));\r
+ add(button3 = new Button("入力クリヤ"));\r
+ add(button4 = new Button(" 認 識 "));\r
+ button1.addActionListener(this);\r
+ button2.addActionListener(this);\r
+ button3.addActionListener(this);\r
+ button4.addActionListener(this);\r
+\r
+ // マウスの設定\r
+ addMouseListener(this);\r
+ addMouseMotionListener(this);\r
+\r
+ // 教師信号の設定\r
+ for (int q = 0; q < PATTERN; q++)\r
+ for (int k = 0; k < OUTPUT; k++) {\r
+ if (q == k)\r
+ teach_array[q][k] = 1;\r
+ else\r
+ teach_array[q][k] = 0;\r
+ }\r
+\r
+ // モードの初期設定\r
+ learning_flag = true;\r
+\r
+ }\r
+\r
+ // ------------------- ボタン関係のメソッド ------------------\r
+\r
+ public void actionPerformed(ActionEvent ae) {\r
+\r
+ if (ae.getSource() == button1) { // 「再学習」\r
+ learning_flag = true;\r
+ repaint();\r
+ }\r
+ if (ae.getSource() == button2) { // 「学習終了」\r
+ learning_flag = false;\r
+ repaint();\r
+ }\r
+ if (ae.getSource() == button3) { // 「入力クリヤ」\r
+ if (!learning_flag)\r
+ repaint();\r
+ }\r
+ if (ae.getSource() == button4) { // 「認識」\r
+ if (!learning_flag)\r
+ recognizeCharacter();\r
+ }\r
+\r
+ }\r
+\r
+ // ---------- マウス関係のメソッド(手書き文字入力)----------\r
+\r
+ public void mousePressed(MouseEvent me) {\r
+ int x = me.getX();\r
+ int y = me.getY();\r
+ if (!learning_flag && x >= RX1 && x <= RX1 + WIDTH * 10 && y >= RY1 && y <= RY1 + HEIGHT * 10) {\r
+ x_old = me.getX();\r
+ y_old = me.getY();\r
+ written_in[(y_old - RY1) / 10 * WIDTH + (x_old - RX1) / 10] = 1;\r
+ }\r
+ }\r
+\r
+ public void mouseClicked(MouseEvent me) {\r
+ }\r
+\r
+ public void mouseEntered(MouseEvent me) {\r
+ }\r
+\r
+ public void mouseExited(MouseEvent me) {\r
+ }\r
+\r
+ public void mouseReleased(MouseEvent me) {\r
+ }\r
+\r
+ public void mouseDragged(MouseEvent me) {\r
+ int x = me.getX();\r
+ int y = me.getY();\r
+ if (!learning_flag && x >= RX1 && x <= RX1 + WIDTH * 10 && y >= RY1 && y <= RY1 + HEIGHT * 10) {\r
+ Graphics g = getGraphics();\r
+ x_new = me.getX();\r
+ y_new = me.getY();\r
+ g.drawLine(x_old, y_old, x_new, y_new);\r
+ x_old = x_new;\r
+ y_old = y_new;\r
+ written_in[(y_old - RY1) / 10 * WIDTH + (x_old - RX1) / 10] = 1;\r
+ }\r
+\r
+ }\r
+\r
+ public void mouseMoved(MouseEvent me) {\r
+ }\r
+\r
+ // ---------- 起動時およびrepaint()で呼び出されるメソッド ----------\r
+\r
+ public void paint(Graphics g) {\r
+\r
+ int i, j, k, p, q, r, x;\r
+\r
+ String string;\r
+\r
+ float outer_error; // 外部サイクルエラー累計\r
+ float inner_error; // 内部サイクルエラー累計\r
+ float temp_error; // 隠れ層の誤差の累計\r
+\r
+ // 学習モードの背景\r
+ if (learning_flag) {\r
+ g.setColor(new Color(255, 255, 192));\r
+ g.fillRect(5, 35, 590, 460);\r
+ g.setColor(Color.black);\r
+ g.drawString("学習モード", 500, 55);\r
+ }\r
+\r
+ // 認識モードの背景\r
+ else {\r
+ g.setColor(new Color(192, 255, 255));\r
+ g.fillRect(5, 35, 590, 460);\r
+ g.setColor(Color.black);\r
+ g.drawString("認識モード", 500, 55);\r
+ }\r
+\r
+ // 学習用パターンの表示\r
+ g.drawString("使用している学習用パターン", X0, Y0);\r
+ for (q = 0; q < PATTERN; q++) {\r
+ x = 56 * q;\r
+ for (j = 0; j < HEIGHT; j++)\r
+ for (i = 0; i < WIDTH; i++) {\r
+ if (sample_array[q][WIDTH * j + i] == 1)\r
+ g.setColor(Color.red);\r
+ else\r
+ g.setColor(Color.cyan);\r
+ g.fillRect(X0 + x + 6 * i, Y1 + 6 * j, 5, 5);\r
+ }\r
+ }\r
+ g.setColor(Color.black);\r
+\r
+ // -------------------------------------------------------------------\r
+ // --------------------------- 学習モード ----------------------------\r
+ // -------------------------------------------------------------------\r
+ if (learning_flag) {\r
+\r
+ // 閾値と重みの乱数設定\r
+ for (j = 0; j < HIDDEN; j++) {\r
+ thresh_h[j] = (float) Math.random() - 0.5f;\r
+ for (i = 0; i < INPUT; i++)\r
+ weight_ih[i][j] = (float) Math.random() - 0.5f;\r
+ }\r
+ for (k = 0; k < OUTPUT; k++) {\r
+ thresh_o[k] = (float) Math.random() - 0.5f;\r
+ for (j = 0; j < HIDDEN; j++)\r
+ weight_ho[j][k] = (float) Math.random() - 0.5f;\r
+ }\r
+\r
+ // -------------------------- 学習 --------------------------\r
+\r
+ for (p = 0; p < OUTER_CYCLES; p++) { // 外部サイクル\r
+\r
+ outer_error = 0.0f; // 外部二乗誤差のクリヤー\r
+\r
+ for (q = 0; q < PATTERN; q++) { // パターンの切り替え\r
+\r
+ // パターンに対応した入力と教師信号の設定\r
+ sample_in = sample_array[q];\r
+ teach = teach_array[q];\r
+\r
+ for (r = 0; r < INNER_CYCLES; r++) { // 内部サイクル\r
+\r
+ // 順方向演算\r
+ forwardNeuralNet(sample_in, recog_out);\r
+\r
+ // 逆方向演算(バックプロパゲーション)\r
+ backwardNeuralNet();\r
+\r
+ }\r
+\r
+ // 内部二乗誤差の計算\r
+ inner_error = 0.0f; // 内部二乗誤差のクリヤー\r
+ for (k = 0; k < OUTPUT; k++)\r
+ inner_error += (teach[k] - recog_out[k]) * (teach[k] - recog_out[k]);\r
+\r
+ outer_error += inner_error; // 外部二乗誤差への累加算\r
+\r
+ }\r
+\r
+ // 外部サイクルの回数と外部二乗誤差の表示\r
+ g.drawString("実行中の外部サイクルの回数と二乗誤差", X0, Y2);\r
+ g.setColor(new Color(255, 255, 192));\r
+ g.fillRect(X0 + 5, Y2 + 10, 200, 50); // 以前の表示を消去\r
+ g.setColor(Color.black);\r
+ g.drawString("OuterCycles=" + String.valueOf(p), X0 + 10, Y2 + 25);\r
+ g.drawString("TotalSquaredError=" + String.valueOf(outer_error), X0 + 10, Y2 + 45);\r
+\r
+ }\r
+\r
+ // --------------------- 学習結果の確認 ---------------------\r
+\r
+ g.drawString("学習結果の確認", X0, Y3);\r
+ for (k = 0; k < OUTPUT; k++) {\r
+ g.drawString("Output", X1 + 45 * k, Y3 + 25);\r
+ g.drawString(" [" + String.valueOf(k) + "]", X1 + 5 + 45 * k, Y3 + 40);\r
+ }\r
+\r
+ for (q = 0; q < PATTERN; q++) {\r
+\r
+ // 入力パターンの設定\r
+ sample_in = sample_array[q];\r
+\r
+ // 順方向演算\r
+ forwardNeuralNet(sample_in, recog_out);\r
+\r
+ // 結果の表示\r
+ g.setColor(Color.black);\r
+ g.drawString("TestPattern[" + String.valueOf(q) + "]", X0 + 10, Y4 + 20 * q);\r
+ for (k = 0; k < OUTPUT; k++) {\r
+ if (recog_out[k] > 0.99) { // 99% より大は、赤で YES と表示\r
+ g.setColor(Color.red);\r
+ string = "YES";\r
+ } else if (recog_out[k] < 0.01) { // 1% より小は、青で NO と表示\r
+ g.setColor(Color.blue);\r
+ string = "NO ";\r
+ } else { // 1% 以上 99% 以下は、黒で ? と表示\r
+ g.setColor(Color.black);\r
+ string = " ? ";\r
+ }\r
+ g.drawString(string, X1 + 10 + 45 * k, Y4 + 20 * q);\r
+ }\r
+\r
+ }\r
+ }\r
+\r
+ // -------------------------------------------------------------------\r
+ // --------------------------- 認識モード ----------------------------\r
+ // -------------------------------------------------------------------\r
+ else {\r
+ g.setColor(Color.black);\r
+ g.drawString("マウスで数字を描いて下さい", RX0, RY0);\r
+ g.drawRect(RX1 - 1, RY1 - 1, WIDTH * 10 + 2, HEIGHT * 10 + 2); // 外枠\r
+ g.setColor(Color.gray);\r
+ for (j = 1; j < HEIGHT; j++)\r
+ g.drawLine(RX1, RY1 + 10 * j, RX1 + WIDTH * 10, RY1 + 10 * j); // 横方向区切り\r
+ for (i = 1; i < WIDTH; i++)\r
+ g.drawLine(RX1 + 10 * i, RY1, RX1 + 10 * i, RY1 + HEIGHT * 10); // 縦方向区切り\r
+ for (i = 0; i < INPUT; i++)\r
+ written_in[i] = 0; // 手書き入力データのクリヤ\r
+ }\r
+\r
+ }\r
+\r
+ // 順方向演算のメソッド\r
+ public void forwardNeuralNet(int[] input, float[] output) {\r
+\r
+ float[] out = new float[OUTPUT];\r
+ float[] hidden = new float[HIDDEN];\r
+\r
+ // 隠れ層出力の計算\r
+ for (int j = 0; j < HIDDEN; j++) {\r
+ hidden[j] = -thresh_h[j];\r
+ for (int i = 0; i < INPUT; i++)\r
+ hidden[j] += input[i] * weight_ih[i][j];\r
+ hidden_out[j] = sigmoid(hidden[j]);\r
+ }\r
+\r
+ // 出力層出力の計算\r
+ for (int k = 0; k < OUTPUT; k++) {\r
+ out[k] = -thresh_o[k];\r
+ for (int j = 0; j < HIDDEN; j++)\r
+ out[k] += hidden_out[j] * weight_ho[j][k];\r
+ output[k] = sigmoid(out[k]);\r
+ }\r
+\r
+ }\r
+\r
+ // 逆方向演算のメソッド\r
+ public void backwardNeuralNet() {\r
+\r
+ int i, j, k;\r
+\r
+ float[] output_error = new float[OUTPUT]; // 出力層の誤差\r
+ float[] hidden_error = new float[HIDDEN]; // 隠れ層の誤差\r
+\r
+ float temp_error;\r
+\r
+ // 出力層の誤差の計算\r
+ for (k = 0; k < OUTPUT; k++)\r
+ output_error[k] = (teach[k] - recog_out[k]) * recog_out[k] * (1.0f - recog_out[k]);\r
+\r
+ // 隠れ層の誤差の計算\r
+ for (j = 0; j < HIDDEN; j++) {\r
+ temp_error = 0.0f;\r
+ for (k = 0; k < OUTPUT; k++)\r
+ temp_error += output_error[k] * weight_ho[j][k];\r
+ hidden_error[j] = hidden_out[j] * (1.0f - hidden_out[j]) * temp_error;\r
+ }\r
+\r
+ // 重みの補正\r
+ for (k = 0; k < OUTPUT; k++)\r
+ for (j = 0; j < HIDDEN; j++)\r
+ weight_ho[j][k] += ALPHA * output_error[k] * hidden_out[j];\r
+ for (j = 0; j < HIDDEN; j++)\r
+ for (i = 0; i < INPUT; i++)\r
+ weight_ih[i][j] += ALPHA * hidden_error[j] * sample_in[i];\r
+\r
+ // 閾値の補正\r
+ for (k = 0; k < OUTPUT; k++)\r
+ thresh_o[k] -= ALPHA * output_error[k];\r
+ for (j = 0; j < HIDDEN; j++)\r
+ thresh_h[j] -= ALPHA * hidden_error[j];\r
+\r
+ }\r
+\r
+ // Sigmoid関数を計算するメソッド\r
+ public float sigmoid(float x) {\r
+\r
+ return 1.0f / (1.0f + (float) Math.exp(-BETA * x));\r
+\r
+ }\r
+\r
+ // 入力文字を認識するメソッド\r
+ public void recognizeCharacter() {\r
+\r
+ Graphics g = getGraphics();\r
+ String string;\r
+\r
+ // 順方向演算\r
+ forwardNeuralNet(written_in, recog_out);\r
+\r
+ // 結果の表示\r
+ for (int k = 0; k < OUTPUT; k++) {\r
+ g.setColor(Color.black);\r
+ g.drawString(String.valueOf(k) + "である", RX2, RY1 + 20 * k);\r
+ if (recog_out[k] > 0.8f)\r
+ g.setColor(Color.red);\r
+ else\r
+ g.setColor(Color.black);\r
+\r
+ g.fillRect(RX3, RY1 - 10 + 20 * k, (int) (200 * recog_out[k]), 10);\r
+ g.drawString(String.valueOf((int) (100 * recog_out[k] + 0.5f)) + "%", RX3 + (int) (200 * recog_out[k]) + 10, RY1 + 20 * k);\r
+ }\r