OSDN Git Service

アプレット・クラスに混入していたニューラル・ネットワークのコードを、とりあえず追い出しました。
[simplenn/repo.git] / simplenn / src / main / java / jp / gr / java_conf / u6k / simplenn / SimpleNNApplet.java
index 3e031a6..e00dafc 100644 (file)
@@ -1,5 +1,5 @@
 /*\r
- * Copyright (C) 2007 u6k.yu1@gmail.com, All Rights Reserved.\r
+ * Copyright (C) 2008 u6k.yu1@gmail.com, All Rights Reserved.\r
  *\r
  * Redistribution and use in source and binary forms, with or without\r
  * modification, are permitted provided that the following conditions\r
@@ -41,11 +41,8 @@ import java.awt.event.MouseEvent;
 import java.awt.event.MouseListener;\r
 import java.awt.event.MouseMotionListener;\r
 import java.io.BufferedReader;\r
-import java.io.BufferedWriter;\r
-import java.io.FileOutputStream;\r
 import java.io.IOException;\r
 import java.io.InputStreamReader;\r
-import java.io.OutputStreamWriter;\r
 \r
 /**\r
  * <p>\r
@@ -63,315 +60,266 @@ public final class SimpleNNApplet extends Applet implements MouseListener, Mouse
      * 文字列とかの表示基準座標。\r
      * </p>\r
      */\r
-    private static final int    X0              = 10;\r
+    private static final int X0           = 10;\r
 \r
     /**\r
      * <p>\r
      * 文字列とかの表示基準座標。\r
      * </p>\r
      */\r
-    private static final int    X1              = 125;\r
+    private static final int X1           = 125;\r
 \r
     /**\r
      * <p>\r
      * 文字列とかの表示基準座標。\r
      * </p>\r
      */\r
-    private static final int    Y0              = 55;\r
+    private static final int Y0           = 55;\r
 \r
     /**\r
      * <p>\r
      * 文字列とかの表示基準座標。\r
      * </p>\r
      */\r
-    private static final int    Y1              = 70;\r
+    private static final int Y1           = 70;\r
 \r
     /**\r
      * <p>\r
      * 文字列とかの表示基準座標。\r
      * </p>\r
      */\r
-    private static final int    Y2              = 160;\r
+    private static final int Y2           = 160;\r
 \r
     /**\r
      * <p>\r
      * 文字列とかの表示基準座標。\r
      * </p>\r
      */\r
-    private static final int    Y3              = 240;\r
+    private static final int Y3           = 240;\r
 \r
     /**\r
      * <p>\r
      * 文字列とかの表示基準座標。\r
      * </p>\r
      */\r
-    private static final int    Y4              = 305;\r
+    private static final int Y4           = 305;\r
 \r
     /**\r
      * <p>\r
      * 文字列とかの表示基準座標。\r
      * </p>\r
      */\r
-    private static final int    RX0             = 30;\r
+    private static final int RX0          = 30;\r
 \r
     /**\r
      * <p>\r
      * 文字列とかの表示基準座標。\r
      * </p>\r
      */\r
-    private static final int    RX1             = 60;\r
+    private static final int RX1          = 60;\r
 \r
     /**\r
      * <p>\r
      * 文字列とかの表示基準座標。\r
      * </p>\r
      */\r
-    private static final int    RX2             = 210;\r
+    private static final int RX2          = 210;\r
 \r
     /**\r
      * <p>\r
      * 文字列とかの表示基準座標。\r
      * </p>\r
      */\r
-    private static final int    RX3             = 260;\r
+    private static final int RX3          = 260;\r
 \r
     /**\r
      * <p>\r
      * 文字列とかの表示基準座標。\r
      * </p>\r
      */\r
-    private static final int    RY0             = 225;\r
+    private static final int RY0          = 225;\r
 \r
     /**\r
      * <p>\r
      * 文字列とかの表示基準座標。\r
      * </p>\r
      */\r
-    private static final int    RY1             = 240;\r
+    private static final int RY1          = 240;\r
 \r
     /**\r
      * <p>\r
      * 入力データの幅。\r
      * </p>\r
      */\r
-    private static final int    WIDTH           = 7;\r
+    private static final int WIDTH        = 7;\r
 \r
     /**\r
      * <p>\r
      * 入力データの高さ。\r
      * </p>\r
      */\r
-    private static final int    HEIGHT          = 11;\r
+    private static final int HEIGHT       = 11;\r
 \r
     /**\r
      * <p>\r
      * 入力層の数(入力データ数)。\r
      * </p>\r
      */\r
-    private static final int    INPUT           = SimpleNNApplet.WIDTH * SimpleNNApplet.HEIGHT;\r
+    private static final int INPUT        = SimpleNNApplet.WIDTH * SimpleNNApplet.HEIGHT;\r
 \r
     /**\r
      * <p>\r
      * 隠れ層の数。\r
      * </p>\r
      */\r
-    private static final int    HIDDEN          = 16;\r
+    private static final int HIDDEN       = 16;\r
 \r
     /**\r
      * <p>\r
      * パターンの種類。\r
      * </p>\r
      */\r
-    private static final int    PATTERN         = 10;\r
+    private static final int PATTERN      = 10;\r
 \r
     /**\r
      * <p>\r
      * 出力層の数(出力データ数)。\r
      * </p>\r
      */\r
-    private static final int    OUTPUT          = SimpleNNApplet.PATTERN;\r
+    private static final int OUTPUT       = SimpleNNApplet.PATTERN;\r
 \r
     /**\r
      * <p>\r
      * 外部サイクル(一連のパターンの繰返し学習)の回数。\r
      * </p>\r
      */\r
-    private static final int    OUTER_CYCLES    = 100;\r
+    private static final int OUTER_CYCLES = 100;\r
 \r
     /**\r
      * <p>\r
      * 内部サイクル(同一パターンの繰返し学習)の回数。\r
      * </p>\r
      */\r
-    private static final int    INNER_CYCLES    = 100;\r
-\r
-    /**\r
-     * <p>\r
-     * 学習の加速係数。\r
-     * </p>\r
-     */\r
-    private static final double ALPHA           = 1.2;\r
-\r
-    /**\r
-     * <p>\r
-     * シグモイド曲線の傾斜。\r
-     * </p>\r
-     */\r
-    private static final double BETA            = 1.2;\r
-\r
-    /**\r
-     * <p>\r
-     * ネットワークを流れる値の上限(1)の半分。\r
-     * </p>\r
-     */\r
-    private static final double VALUE_HALF      = 0.5;\r
+    private static final int INNER_CYCLES = 100;\r
 \r
     /**\r
      * <p>\r
      * 「再学習」ボタン。\r
      * </p>\r
      */\r
-    private Button              button1;\r
+    private Button           button1;\r
 \r
     /**\r
      * <p>\r
      * 「学習終了」ボタン。\r
      * </p>\r
      */\r
-    private Button              button2;\r
+    private Button           button2;\r
 \r
     /**\r
      * <p>\r
      * 「入力クリア」ボタン。\r
      * </p>\r
      */\r
-    private Button              button3;\r
+    private Button           button3;\r
 \r
     /**\r
      * <p>\r
      * 「認識」ボタン。\r
      * </p>\r
      */\r
-    private Button              button4;\r
+    private Button           button4;\r
 \r
     /**\r
      * <p>\r
      * 「状態出力」ボタン。\r
      * </p>\r
      */\r
-    private Button              button5;\r
+    private Button           button5;\r
 \r
     /**\r
      * <p>\r
      * 学習用入力。\r
      * </p>\r
      */\r
-    private double[]            sampleIn        = new double[SimpleNNApplet.INPUT];\r
+    private double[]         sampleIn     = new double[SimpleNNApplet.INPUT];\r
 \r
     /**\r
      * <p>\r
      * 認識用手書き入力。\r
      * </p>\r
      */\r
-    private double[]            writtenIn       = new double[SimpleNNApplet.INPUT];\r
-\r
-    /**\r
-     * <p>\r
-     * 入力層と隠れ層の間の重み係数。\r
-     * </p>\r
-     */\r
-    private double[][]          weightInHidden  = new double[SimpleNNApplet.INPUT][SimpleNNApplet.HIDDEN];\r
-\r
-    /**\r
-     * <p>\r
-     * 隠れ層の閾値。\r
-     * </p>\r
-     */\r
-    private double[]            thresholdHidden = new double[SimpleNNApplet.HIDDEN];\r
-\r
-    /**\r
-     * <p>\r
-     * 隠れ層の出力。\r
-     * </p>\r
-     */\r
-    private double[]            hiddenOut       = new double[SimpleNNApplet.HIDDEN];\r
-\r
-    /**\r
-     * <p>\r
-     * 隠れ層と出力層の間の重み係数。\r
-     * </p>\r
-     */\r
-    private double[][]          weightHiddenOut = new double[SimpleNNApplet.HIDDEN][SimpleNNApplet.OUTPUT];\r
-\r
-    /**\r
-     * <p>\r
-     * 出力層の閾値。\r
-     * </p>\r
-     */\r
-    private double[]            thresholdOut    = new double[SimpleNNApplet.OUTPUT];\r
+    private double[]         writtenIn    = new double[SimpleNNApplet.INPUT];\r
 \r
     /**\r
      * <p>\r
      * 認識出力(出力層の出力)。\r
      * </p>\r
      */\r
-    private double[]            recognizeOut    = new double[SimpleNNApplet.OUTPUT];\r
+    private double[]         recognizeOut = new double[SimpleNNApplet.OUTPUT];\r
 \r
     /**\r
      * <p>\r
      * 教師信号。\r
      * </p>\r
      */\r
-    private double[]            teach           = new double[SimpleNNApplet.PATTERN];\r
+    private double[]         teach        = new double[SimpleNNApplet.PATTERN];\r
 \r
     /**\r
      * <p>\r
      * 「学習モード」フラグ。\r
      * </p>\r
      */\r
-    private boolean             learningFlag;\r
+    private boolean          learningFlag;\r
 \r
     /**\r
      * <p>\r
      * 学習用入力データの基となるパターン。\r
      * </p>\r
      */\r
-    private double[][]          sampleArray;\r
+    private double[][]       sampleArray;\r
 \r
     /**\r
      * <p>\r
      * パターンと出力すべき教師信号の比較表。\r
      * </p>\r
      */\r
-    private double[][]          teachArray      = new double[SimpleNNApplet.PATTERN][SimpleNNApplet.OUTPUT];\r
+    private double[][]       teachArray   = new double[SimpleNNApplet.PATTERN][SimpleNNApplet.OUTPUT];\r
 \r
     /**\r
      * <p>\r
      * 手書き文字入力用座標。\r
      * </p>\r
      */\r
-    private int                 xNew;\r
+    private int              xNew;\r
 \r
     /**\r
      * <p>\r
      * 手書き文字入力用座標。\r
      * </p>\r
      */\r
-    private int                 yNew;\r
+    private int              yNew;\r
 \r
     /**\r
      * <p>\r
      * 手書き文字入力用座標。\r
      * </p>\r
      */\r
-    private int                 xOld;\r
+    private int              xOld;\r
 \r
     /**\r
      * <p>\r
      * 手書き文字入力用座標。\r
      * </p>\r
      */\r
-    private int                 yOld;\r
+    private int              yOld;\r
+\r
+    /**\r
+     * <p>\r
+     * ニューラル・ネットワーク。\r
+     * </p>\r
+     */\r
+    private SimpleNN         simpleNN;\r
 \r
     /**\r
      * <p>\r
@@ -482,7 +430,7 @@ public final class SimpleNNApplet extends Applet implements MouseListener, Mouse
             // 「状態出力」\r
             System.out.println("状態出力開始。");\r
             try {\r
-                this.resultOutput();\r
+                this.simpleNN.outputState();\r
             } catch (IOException e) {\r
                 e.printStackTrace();\r
             }\r
@@ -646,7 +594,7 @@ public final class SimpleNNApplet extends Applet implements MouseListener, Mouse
             // 学習モード\r
 \r
             // 閾値と重みの乱数設定\r
-            this.initNetwork();\r
+            this.simpleNN = new SimpleNN();\r
 \r
             // -------------------------- 学習 --------------------------\r
             for (p = 0; p < SimpleNNApplet.OUTER_CYCLES; p++) {\r
@@ -666,16 +614,16 @@ public final class SimpleNNApplet extends Applet implements MouseListener, Mouse
                         // 内部サイクル\r
 \r
                         // 順方向演算\r
-                        this.recognizeOut = this.forwardNeuralNet(this.sampleIn);\r
+                        this.recognizeOut = this.simpleNN.forwardNeuralNet(this.sampleIn);\r
 \r
                         // 逆方向演算(バックプロパゲーション)\r
-                        this.backwardNeuralNet(this.recognizeOut, this.teach);\r
+                        this.simpleNN.backwardNeuralNet(this.sampleIn, this.recognizeOut, this.teach);\r
                     }\r
 \r
                     // 内部二乗誤差の計算\r
 \r
                     // 内部二乗誤差のクリヤー\r
-                    innerError = this.calcError(this.recognizeOut, this.teach);\r
+                    innerError = this.simpleNN.calcError(this.recognizeOut, this.teach);\r
 \r
                     // 外部二乗誤差への累加算\r
                     outerError += innerError;\r
@@ -705,7 +653,7 @@ public final class SimpleNNApplet extends Applet implements MouseListener, Mouse
                 this.sampleIn = this.sampleArray[q];\r
 \r
                 // 順方向演算\r
-                this.recognizeOut = this.forwardNeuralNet(this.sampleIn);\r
+                this.recognizeOut = this.simpleNN.forwardNeuralNet(this.sampleIn);\r
 \r
                 // 結果の表示\r
                 g.setColor(Color.black);\r
@@ -752,160 +700,6 @@ public final class SimpleNNApplet extends Applet implements MouseListener, Mouse
 \r
     /**\r
      * <p>\r
-     * 順方向演算を行います。\r
-     * </p>\r
-     * \r
-     * @param input\r
-     *            入力データ。\r
-     * @return 演算結果。\r
-     */\r
-    public double[] forwardNeuralNet(double[] input) {\r
-        double[] out = new double[SimpleNNApplet.OUTPUT];\r
-        double[] hidden = new double[SimpleNNApplet.HIDDEN];\r
-\r
-        // 隠れ層出力の計算\r
-        for (int j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
-            hidden[j] = -this.thresholdHidden[j];\r
-            for (int i = 0; i < SimpleNNApplet.INPUT; i++) {\r
-                hidden[j] += input[i] * this.weightInHidden[i][j];\r
-            }\r
-            this.hiddenOut[j] = this.sigmoid(hidden[j]);\r
-        }\r
-\r
-        // 出力層出力の計算\r
-        for (int k = 0; k < SimpleNNApplet.OUTPUT; k++) {\r
-            out[k] = -this.thresholdOut[k];\r
-            for (int j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
-                out[k] += this.hiddenOut[j] * this.weightHiddenOut[j][k];\r
-            }\r
-            out[k] = this.sigmoid(out[k]);\r
-        }\r
-\r
-        return out;\r
-    }\r
-\r
-    /**\r
-     * <p>\r
-     * 逆方向演算を行います。\r
-     * </p>\r
-     * \r
-     * @param output\r
-     *            順方向演算の結果。\r
-     * @param teach\r
-     *            教師信号。\r
-     */\r
-    public void backwardNeuralNet(double[] output, double[] teach) {\r
-        int i;\r
-        int j;\r
-        int k;\r
-\r
-        // 出力層の誤差\r
-        double[] outputError = new double[SimpleNNApplet.OUTPUT];\r
-        // 隠れ層の誤差\r
-        double[] hiddenError = new double[SimpleNNApplet.HIDDEN];\r
-        double tempError;\r
-\r
-        // 出力層の誤差の計算\r
-        for (k = 0; k < SimpleNNApplet.OUTPUT; k++) {\r
-            outputError[k] = (teach[k] - output[k]) * output[k] * (1.0 - output[k]);\r
-        }\r
-\r
-        // 隠れ層の誤差の計算\r
-        for (j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
-            tempError = 0;\r
-            for (k = 0; k < SimpleNNApplet.OUTPUT; k++) {\r
-                tempError += outputError[k] * this.weightHiddenOut[j][k];\r
-            }\r
-            hiddenError[j] = this.hiddenOut[j] * (1.0 - this.hiddenOut[j]) * tempError;\r
-        }\r
-\r
-        // 重みの補正\r
-        for (k = 0; k < SimpleNNApplet.OUTPUT; k++) {\r
-            for (j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
-                this.weightHiddenOut[j][k] += SimpleNNApplet.ALPHA * outputError[k] * this.hiddenOut[j];\r
-            }\r
-        }\r
-        for (j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
-            for (i = 0; i < SimpleNNApplet.INPUT; i++) {\r
-                this.weightInHidden[i][j] += SimpleNNApplet.ALPHA * hiddenError[j] * this.sampleIn[i];\r
-            }\r
-        }\r
-\r
-        // 閾値の補正\r
-        for (k = 0; k < SimpleNNApplet.OUTPUT; k++) {\r
-            this.thresholdOut[k] -= SimpleNNApplet.ALPHA * outputError[k];\r
-        }\r
-        for (j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
-            this.thresholdHidden[j] -= SimpleNNApplet.ALPHA * hiddenError[j];\r
-        }\r
-    }\r
-\r
-    /**\r
-     * <p>\r
-     * ニューラル・ネットワークの状態(閾値、重み)を初期化します。\r
-     * </p>\r
-     */\r
-    public void initNetwork() {\r
-        // TODO ちゃんとランダムに初期化する。\r
-        try {\r
-            BufferedReader reader = new BufferedReader(new InputStreamReader(this.getClass().getClassLoader().getResourceAsStream("random.txt")));\r
-            try {\r
-                for (int j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
-                    this.thresholdHidden[j] = Double.parseDouble(reader.readLine()) - SimpleNNApplet.VALUE_HALF;\r
-                    for (int i = 0; i < SimpleNNApplet.INPUT; i++) {\r
-                        this.weightInHidden[i][j] = Double.parseDouble(reader.readLine()) - SimpleNNApplet.VALUE_HALF;\r
-                    }\r
-                }\r
-                for (int k = 0; k < SimpleNNApplet.OUTPUT; k++) {\r
-                    this.thresholdOut[k] = Double.parseDouble(reader.readLine()) - SimpleNNApplet.VALUE_HALF;\r
-                    for (int j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
-                        this.weightHiddenOut[j][k] = Double.parseDouble(reader.readLine()) - SimpleNNApplet.VALUE_HALF;\r
-                    }\r
-                }\r
-            } finally {\r
-                reader.close();\r
-            }\r
-        } catch (IOException e) {\r
-            e.printStackTrace();\r
-        }\r
-    }\r
-\r
-    /**\r
-     * <p>\r
-     * 順方向演算の結果と教師信号とのずれを表す二乗誤差を算出します。\r
-     * </p>\r
-     * \r
-     * @param output\r
-     *            順方向演算の結果。\r
-     * @param teach\r
-     *            教師信号。\r
-     * @return 二乗誤差。\r
-     */\r
-    public double calcError(double[] output, double[] teach) {\r
-        double error = 0;\r
-\r
-        for (int i = 0; i < output.length; i++) {\r
-            error += (teach[i] - output[i]) * (teach[i] - output[i]);\r
-        }\r
-\r
-        return error;\r
-    }\r
-\r
-    /**\r
-     * <p>\r
-     * シグモイド関数です。\r
-     * </p>\r
-     * \r
-     * @param x\r
-     *            引数。\r
-     * @return 計算結果。\r
-     */\r
-    public double sigmoid(double x) {\r
-        return 1.0 / (1.0 + Math.exp(-SimpleNNApplet.BETA * x));\r
-    }\r
-\r
-    /**\r
-     * <p>\r
      * 手書き入力された文字を認識します。\r
      * </p>\r
      */\r
@@ -913,7 +707,7 @@ public final class SimpleNNApplet extends Applet implements MouseListener, Mouse
         Graphics g = this.getGraphics();\r
 \r
         // 順方向演算\r
-        this.recognizeOut = this.forwardNeuralNet(this.writtenIn);\r
+        this.recognizeOut = this.simpleNN.forwardNeuralNet(this.writtenIn);\r
 \r
         // 結果の表示\r
         for (int k = 0; k < SimpleNNApplet.OUTPUT; k++) {\r
@@ -930,32 +724,4 @@ public final class SimpleNNApplet extends Applet implements MouseListener, Mouse
         }\r
     }\r
 \r
-    private void resultOutput() throws IOException {\r
-        BufferedWriter w = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("C:/result.txt")));\r
-        try {\r
-            for (int i = 0; i < this.weightInHidden.length; i++) {\r
-                for (int j = 0; j < this.weightInHidden[i].length; j++) {\r
-                    w.write("weightInHidden[" + i + "][" + j + "]=" + this.weightInHidden[i][j]);\r
-                    w.newLine();\r
-                }\r
-            }\r
-            for (int i = 0; i < this.thresholdHidden.length; i++) {\r
-                w.write("thresholdHidden[" + i + "]=" + this.thresholdHidden[i]);\r
-                w.newLine();\r
-            }\r
-            for (int i = 0; i < this.weightHiddenOut.length; i++) {\r
-                for (int j = 0; j < this.weightHiddenOut[i].length; j++) {\r
-                    w.write("weightHiddenOut[" + i + "][" + j + "]=" + this.weightHiddenOut[i][j]);\r
-                    w.newLine();\r
-                }\r
-            }\r
-            for (int i = 0; i < this.thresholdOut.length; i++) {\r
-                w.write("thresholdOut[" + i + "]=" + this.thresholdOut[i]);\r
-                w.newLine();\r
-            }\r
-        } finally {\r
-            w.close();\r
-        }\r
-    }\r
-\r
 }\r