OSDN Git Service

git-svn-id: svn+ssh://svn.sourceforge.jp/svnroot/simplenn/trunk@18 dd34cd95-496f...
authoru6k <u6k@dd34cd95-496f-4e97-851c-65f59ff6709a>
Sun, 20 Apr 2008 18:01:42 +0000 (18:01 +0000)
committeru6k <u6k@dd34cd95-496f-4e97-851c-65f59ff6709a>
Sun, 20 Apr 2008 18:01:42 +0000 (18:01 +0000)
simplenn/src/main/java/jp/gr/java_conf/u6k/simplenn/SimpleNN.java
simplenn/src/main/java/jp/gr/java_conf/u6k/simplenn/SimpleNNApplet.java

index 2bf37c7..5ffb654 100644 (file)
@@ -53,114 +53,106 @@ public final class SimpleNN {
 \r
     /**\r
      * <p>\r
-     * 入力データの幅。\r
-     * </p>\r
-     */\r
-    private static final int    WIDTH           = 7;\r
-\r
-    /**\r
-     * <p>\r
-     * 入力データの高さ。\r
-     * </p>\r
-     */\r
-    private static final int    HEIGHT          = 11;\r
-\r
-    /**\r
-     * <p>\r
-     * 入力層の数(入力データ数)。\r
-     * </p>\r
-     */\r
-    private static final int    INPUT           = SimpleNN.WIDTH * SimpleNN.HEIGHT;\r
-\r
-    /**\r
-     * <p>\r
-     * 隠れ層の数。\r
-     * </p>\r
-     */\r
-    private static final int    HIDDEN          = 16;\r
-\r
-    /**\r
-     * <p>\r
-     * パターンの種類。\r
-     * </p>\r
-     */\r
-    private static final int    PATTERN         = 10;\r
-\r
-    /**\r
-     * <p>\r
-     * 出力層の数(出力データ数)。\r
-     * </p>\r
-     */\r
-    private static final int    OUTPUT          = SimpleNN.PATTERN;\r
-\r
-    /**\r
-     * <p>\r
-     * 学習の加速係数。\r
-     * </p>\r
-     */\r
-    private static final double ALPHA           = 1.2;\r
-\r
-    /**\r
-     * <p>\r
      * シグモイド曲線の傾斜。\r
      * </p>\r
      */\r
-    private static final double BETA            = 1.2;\r
+    private static final double BETA       = 1.2;\r
 \r
     /**\r
      * <p>\r
      * ネットワークを流れる値の上限(1)の半分。\r
      * </p>\r
      */\r
-    private static final double VALUE_HALF      = 0.5;\r
+    private static final double VALUE_HALF = 0.5;\r
 \r
     /**\r
      * <p>\r
      * 入力層と隠れ層の間の重み係数。\r
      * </p>\r
      */\r
-    private double[][]          weightInHidden  = new double[SimpleNN.INPUT][SimpleNN.HIDDEN];\r
+    private double[][]          weightInHidden;\r
 \r
     /**\r
      * <p>\r
      * 隠れ層の閾値。\r
      * </p>\r
      */\r
-    private double[]            thresholdHidden = new double[SimpleNN.HIDDEN];\r
+    private double[]            thresholdHidden;\r
 \r
     /**\r
      * <p>\r
      * 隠れ層と出力層の間の重み係数。\r
      * </p>\r
      */\r
-    private double[][]          weightHiddenOut = new double[SimpleNN.HIDDEN][SimpleNN.OUTPUT];\r
+    private double[][]          weightHiddenOut;\r
 \r
     /**\r
      * <p>\r
      * 出力層の閾値。\r
      * </p>\r
      */\r
-    private double[]            thresholdOut    = new double[SimpleNN.OUTPUT];\r
+    private double[]            thresholdOut;\r
+\r
+    /**\r
+     * <p>\r
+     * 学習係数(learning coefficient)。\r
+     * </p>\r
+     */\r
+    private double              learningCoefficient;\r
 \r
     /**\r
      * <p>\r
      * ニューラル・ネットワークの状態(閾値、重み)を初期化します。\r
      * </p>\r
+     * \r
+     * @param inputNumber\r
+     *            入力層のニューロン数。\r
+     * @param hiddenNumber\r
+     *            隠れ層のニューロン数。\r
+     * @param outputNumber\r
+     *            出力層のニューロン数。\r
+     * @param learningCoefficient\r
+     *            学習係数。\r
+     * @throws IllegalArgumentException\r
+     *             inputNumber引数、hiddenNumber引数、outputNumber引数、learningCoefficient引数が0以下の場合。\r
      */\r
-    public SimpleNN() {\r
+    public SimpleNN(int inputNumber, int hiddenNumber, int outputNumber, double learningCoefficient) {\r
+        /*\r
+         * 引数を確認します。\r
+         */\r
+        if (inputNumber <= 0) {\r
+            throw new IllegalArgumentException("inputNumber <= 0");\r
+        }\r
+        if (hiddenNumber <= 0) {\r
+            throw new IllegalArgumentException("hiddenNumber <= 0");\r
+        }\r
+        if (outputNumber <= 0) {\r
+            throw new IllegalArgumentException("outputNumber <= 0");\r
+        }\r
+        if (learningCoefficient <= 0) {\r
+            throw new IllegalArgumentException("learningCoefficient <= 0");\r
+        }\r
+\r
+        /*\r
+         * ニューラル・ネットワークの状態を初期化します。\r
+         */\r
+        this.thresholdHidden = new double[hiddenNumber];\r
+        this.weightInHidden = new double[inputNumber][hiddenNumber];\r
+        this.thresholdOut = new double[outputNumber];\r
+        this.weightHiddenOut = new double[hiddenNumber][outputNumber];\r
         // TODO ちゃんとランダムに初期化する。\r
         try {\r
             BufferedReader r = new BufferedReader(new InputStreamReader(this.getClass().getClassLoader().getResourceAsStream("random.txt")));\r
             try {\r
-                for (int i = 0; i < SimpleNN.HIDDEN; i++) {\r
+                for (int i = 0; i < hiddenNumber; i++) {\r
                     this.thresholdHidden[i] = Double.parseDouble(r.readLine()) - SimpleNN.VALUE_HALF;\r
-                    for (int j = 0; j < SimpleNN.INPUT; j++) {\r
+                    for (int j = 0; j < inputNumber; j++) {\r
                         this.weightInHidden[j][i] = Double.parseDouble(r.readLine()) - SimpleNN.VALUE_HALF;\r
                     }\r
                 }\r
-                for (int i = 0; i < SimpleNN.OUTPUT; i++) {\r
+                for (int i = 0; i < outputNumber; i++) {\r
                     this.thresholdOut[i] = Double.parseDouble(r.readLine()) - SimpleNN.VALUE_HALF;\r
-                    for (int j = 0; j < SimpleNN.HIDDEN; j++) {\r
+                    for (int j = 0; j < hiddenNumber; j++) {\r
                         this.weightHiddenOut[j][i] = Double.parseDouble(r.readLine()) - SimpleNN.VALUE_HALF;\r
                     }\r
                 }\r
@@ -170,6 +162,8 @@ public final class SimpleNN {
         } catch (IOException e) {\r
             e.printStackTrace();\r
         }\r
+\r
+        this.learningCoefficient = learningCoefficient;\r
     }\r
 \r
     /**\r
@@ -183,27 +177,43 @@ public final class SimpleNN {
      *            順方向演算の結果を格納する配列。\r
      * @param hiddenOutput\r
      *            順方向演算の過程の隠れ層出力を格納する配列。\r
+     * @throws NullPointerException\r
+     *             input引数、output引数、hiddenOutput引数がnullの場合。\r
      */\r
     public void forwardNeuralNet(double[] input, double[] output, double[] hiddenOutput) {\r
-        double[] out = new double[SimpleNN.OUTPUT];\r
-        double[] hidden = new double[SimpleNN.HIDDEN];\r
+        /*\r
+         * 引数を確認します。\r
+         */\r
+        if (input == null) {\r
+            throw new NullPointerException("input == null");\r
+        }\r
+        if (output == null) {\r
+            throw new NullPointerException("output == null");\r
+        }\r
+        if (hiddenOutput == null) {\r
+            throw new NullPointerException("hiddenOutput == null");\r
+        }\r
 \r
-        // 隠れ層出力の計算\r
-        for (int i = 0; i < SimpleNN.HIDDEN; i++) {\r
-            hidden[i] = -this.thresholdHidden[i];\r
-            for (int j = 0; j < SimpleNN.INPUT; j++) {\r
-                hidden[i] += input[j] * this.weightInHidden[j][i];\r
+        /*\r
+         * 隠れ層の出力を計算します。\r
+         */\r
+        for (int i = 0; i < hiddenOutput.length; i++) {\r
+            hiddenOutput[i] = -this.thresholdHidden[i];\r
+            for (int j = 0; j < input.length; j++) {\r
+                hiddenOutput[i] += input[j] * this.weightInHidden[j][i];\r
             }\r
-            hiddenOutput[i] = this.sigmoid(hidden[i]);\r
+            hiddenOutput[i] = this.sigmoid(hiddenOutput[i]);\r
         }\r
 \r
-        // 出力層出力の計算\r
-        for (int i = 0; i < SimpleNN.OUTPUT; i++) {\r
-            out[i] = -this.thresholdOut[i];\r
-            for (int j = 0; j < SimpleNN.HIDDEN; j++) {\r
-                out[i] += hiddenOutput[j] * this.weightHiddenOut[j][i];\r
+        /*\r
+         * 出力層の出力を計算します。\r
+         */\r
+        for (int i = 0; i < output.length; i++) {\r
+            output[i] = -this.thresholdOut[i];\r
+            for (int j = 0; j < hiddenOutput.length; j++) {\r
+                output[i] += hiddenOutput[j] * this.weightHiddenOut[j][i];\r
             }\r
-            output[i] = this.sigmoid(out[i]);\r
+            output[i] = this.sigmoid(output[i]);\r
         }\r
     }\r
 \r
@@ -220,51 +230,74 @@ public final class SimpleNN {
      *            順方向演算の過程の隠れ層出力を格納する配列。\r
      * @param teach\r
      *            教師信号。\r
+     * @throws NullPointerException\r
+     *             input引数、output引数、hiddenOutput引数、teach引数がnullの場合。\r
      */\r
     public void backwardNeuralNet(double[] input, double[] output, double[] hiddenOutput, double[] teach) {\r
-        // 出力層の誤差\r
-        double[] outputError = new double[SimpleNN.OUTPUT];\r
-        // 隠れ層の誤差\r
-        double[] hiddenError = new double[SimpleNN.HIDDEN];\r
+        /*\r
+         * 引数を確認します。\r
+         */\r
+        if (input == null) {\r
+            throw new NullPointerException("input == null");\r
+        }\r
+        if (output == null) {\r
+            throw new NullPointerException("output == null");\r
+        }\r
+        if (hiddenOutput == null) {\r
+            throw new NullPointerException("hiddenOutput == null");\r
+        }\r
+        if (teach == null) {\r
+            throw new NullPointerException("teach == null");\r
+        }\r
 \r
-        // 出力層の誤差の計算\r
-        for (int i = 0; i < SimpleNN.OUTPUT; i++) {\r
+        /*\r
+         * 出力層の誤差を計算します。\r
+         */\r
+        double[] outputError = new double[output.length];\r
+        for (int i = 0; i < outputError.length; i++) {\r
             outputError[i] = (teach[i] - output[i]) * output[i] * (1.0 - output[i]);\r
         }\r
 \r
-        // 隠れ層の誤差の計算\r
-        for (int i = 0; i < SimpleNN.HIDDEN; i++) {\r
+        /*\r
+         * 隠れ層の誤差を計算します。\r
+         */\r
+        double[] hiddenError = new double[hiddenOutput.length];\r
+        for (int i = 0; i < hiddenError.length; i++) {\r
             double err = 0;\r
-            for (int j = 0; j < SimpleNN.OUTPUT; j++) {\r
+            for (int j = 0; j < output.length; j++) {\r
                 err += outputError[j] * this.weightHiddenOut[i][j];\r
             }\r
             hiddenError[i] = hiddenOutput[i] * (1.0 - hiddenOutput[i]) * err;\r
         }\r
 \r
-        // 重みの補正\r
-        for (int i = 0; i < SimpleNN.OUTPUT; i++) {\r
-            for (int j = 0; j < SimpleNN.HIDDEN; j++) {\r
-                this.weightHiddenOut[j][i] += SimpleNN.ALPHA * outputError[i] * hiddenOutput[j];\r
+        /*\r
+         * 重みを補正します。\r
+         */\r
+        for (int i = 0; i < outputError.length; i++) {\r
+            for (int j = 0; j < hiddenOutput.length; j++) {\r
+                this.weightHiddenOut[j][i] += this.learningCoefficient * outputError[i] * hiddenOutput[j];\r
             }\r
         }\r
-        for (int i = 0; i < SimpleNN.HIDDEN; i++) {\r
-            for (int j = 0; j < SimpleNN.INPUT; j++) {\r
-                this.weightInHidden[j][i] += SimpleNN.ALPHA * hiddenError[i] * input[j];\r
+        for (int i = 0; i < hiddenError.length; i++) {\r
+            for (int j = 0; j < input.length; j++) {\r
+                this.weightInHidden[j][i] += this.learningCoefficient * hiddenError[i] * input[j];\r
             }\r
         }\r
 \r
-        // 閾値の補正\r
-        for (int i = 0; i < SimpleNN.OUTPUT; i++) {\r
-            this.thresholdOut[i] -= SimpleNN.ALPHA * outputError[i];\r
+        /*\r
+         * 閾値を補正します。\r
+         */\r
+        for (int i = 0; i < this.thresholdOut.length; i++) {\r
+            this.thresholdOut[i] -= this.learningCoefficient * outputError[i];\r
         }\r
-        for (int i = 0; i < SimpleNN.HIDDEN; i++) {\r
-            this.thresholdHidden[i] -= SimpleNN.ALPHA * hiddenError[i];\r
+        for (int i = 0; i < this.thresholdHidden.length; i++) {\r
+            this.thresholdHidden[i] -= this.learningCoefficient * hiddenError[i];\r
         }\r
     }\r
 \r
     /**\r
      * <p>\r
-     * 順方向演算の結果と教師信号とのずれを表す二乗誤差を算出します。\r
+     * 順方向演算の結果と教師信号とのずれを表す二乗誤差を計算します。\r
      * </p>\r
      * \r
      * @param output\r
@@ -272,10 +305,24 @@ public final class SimpleNN {
      * @param teach\r
      *            教師信号。\r
      * @return 二乗誤差。\r
+     * @throws NullPointerException\r
+     *             output引数、teach引数がnullの場合。\r
      */\r
     public double calcError(double[] output, double[] teach) {\r
-        double error = 0;\r
+        /*\r
+         * 引数を確認します。\r
+         */\r
+        if (output == null) {\r
+            throw new NullPointerException("output == null");\r
+        }\r
+        if (teach == null) {\r
+            throw new NullPointerException("teach == null");\r
+        }\r
 \r
+        /*\r
+         * 二乗誤差を計算します。\r
+         */\r
+        double error = 0;\r
         for (int i = 0; i < output.length; i++) {\r
             error += (teach[i] - output[i]) * (teach[i] - output[i]);\r
         }\r
index 8c08f3a..42d27c2 100644 (file)
@@ -596,7 +596,7 @@ public final class SimpleNNApplet extends Applet implements MouseListener, Mouse
             // 学習モード\r
 \r
             // 閾値と重みの乱数設定\r
-            this.simpleNN = new SimpleNN();\r
+            this.simpleNN = new SimpleNN(SimpleNNApplet.INPUT, SimpleNNApplet.HIDDEN, SimpleNNApplet.OUTPUT, 1.2);\r
 \r
             // -------------------------- 学習 --------------------------\r
             for (p = 0; p < SimpleNNApplet.OUTER_CYCLES; p++) {\r