OSDN Git Service

git-svn-id: svn+ssh://svn.sourceforge.jp/svnroot/simplenn/trunk@20 dd34cd95-496f...
authoru6k <u6k@dd34cd95-496f-4e97-851c-65f59ff6709a>
Mon, 21 Apr 2008 04:00:57 +0000 (04:00 +0000)
committeru6k <u6k@dd34cd95-496f-4e97-851c-65f59ff6709a>
Mon, 21 Apr 2008 04:00:57 +0000 (04:00 +0000)
simplenn/src/main/java/jp/gr/java_conf/u6k/simplenn/SimpleNN.java
simplenn/src/main/java/jp/gr/java_conf/u6k/simplenn/SimpleNNApplet.java

index 5ffb654..0e7339b 100644 (file)
@@ -67,6 +67,27 @@ public final class SimpleNN {
 \r
     /**\r
      * <p>\r
+     * 入力層のニューロン数。\r
+     * </p>\r
+     */\r
+    private int                 inputNumber;\r
+\r
+    /**\r
+     * <p>\r
+     * 隠れ層のニューロン数。\r
+     * </p>\r
+     */\r
+    private int                 hiddenNumber;\r
+\r
+    /**\r
+     * <p>\r
+     * 出力層のニューロン数。\r
+     * </p>\r
+     */\r
+    private int                 outputNumber;\r
+\r
+    /**\r
+     * <p>\r
      * 入力層と隠れ層の間の重み係数。\r
      * </p>\r
      */\r
@@ -163,11 +184,79 @@ public final class SimpleNN {
             e.printStackTrace();\r
         }\r
 \r
+        this.inputNumber = inputNumber;\r
+        this.hiddenNumber = hiddenNumber;\r
+        this.outputNumber = outputNumber;\r
         this.learningCoefficient = learningCoefficient;\r
     }\r
 \r
     /**\r
      * <p>\r
+     * 入力データと教師信号を用いて、ニューラル・ネットワークの状態を更新します(学習します)。\r
+     * </p>\r
+     * \r
+     * @param input\r
+     *            入力データ。\r
+     * @param teach\r
+     *            教師信号。\r
+     * @throws NullPointerException\r
+     *             input引数、teach引数がnullの場合。\r
+     * @throws IllegalArgumentException\r
+     *             input引数の配列要素数が入力層のニューロン数と異なる場合。teach引数の配列要素数が出力層のニューロン数と異なる場合。\r
+     */\r
+    public void learn(double[] input, double[] teach) {\r
+        double[] output = new double[this.outputNumber];\r
+        double[] hiddenOutput = new double[this.hiddenNumber];\r
+\r
+        this.calcForward(input, output, hiddenOutput);\r
+        this.calcBackward(input, output, hiddenOutput, teach);\r
+    }\r
+\r
+    /**\r
+     * <p>\r
+     * 入力データをニューラル・ネットワークを用いて計算します。\r
+     * </p>\r
+     * \r
+     * @param input\r
+     *            入力データ。\r
+     * @return 計算結果の出力データ。\r
+     * @throws NullPointerException\r
+     *             input引数がnullの場合。\r
+     * @throws IllegalArgumentException\r
+     *             input引数の配列要素数が入力層のニューロン数と異なる場合。\r
+     */\r
+    public double[] calculate(double[] input) {\r
+        double[] output = new double[this.outputNumber];\r
+\r
+        this.calcForward(input, output, new double[this.hiddenNumber]);\r
+\r
+        return output;\r
+    }\r
+\r
+    /**\r
+     * <p>\r
+     * 入力データから導き出される出力データと教師信号とのずれを表す、二乗誤差を算出します。\r
+     * </p>\r
+     * \r
+     * @param input\r
+     *            入力データ。\r
+     * @param teach\r
+     *            教師信号。\r
+     * @return 二乗誤差。\r
+     * @throws NullPointerException\r
+     *             input引数、teach引数がnullの場合。\r
+     * @throws IllegalArgumentException\r
+     *             input引数の配列要素数が入力層のニューロン数と異なる場合。teach引数の配列要素数が出力層のニューロン数と異なる場合。\r
+     */\r
+    public double reportError(double[] input, double[] teach) {\r
+        double[] output = this.calculate(input);\r
+        double err = this.calcError(output, teach);\r
+\r
+        return err;\r
+    }\r
+\r
+    /**\r
+     * <p>\r
      * 順方向演算を行います。\r
      * </p>\r
      * \r
@@ -179,8 +268,10 @@ public final class SimpleNN {
      *            順方向演算の過程の隠れ層出力を格納する配列。\r
      * @throws NullPointerException\r
      *             input引数、output引数、hiddenOutput引数がnullの場合。\r
+     * @throws IllegalArgumentException\r
+     *             input引数の配列要素数が入力層のニューロン数と異なる場合。output引数の配列要素数が出力層のニューロン数と異なる場合。hiddenOutput引数の配列要素数が隠れ層のニューロン数と異なる場合。\r
      */\r
-    public void forwardNeuralNet(double[] input, double[] output, double[] hiddenOutput) {\r
+    private void calcForward(double[] input, double[] output, double[] hiddenOutput) {\r
         /*\r
          * 引数を確認します。\r
          */\r
@@ -193,6 +284,15 @@ public final class SimpleNN {
         if (hiddenOutput == null) {\r
             throw new NullPointerException("hiddenOutput == null");\r
         }\r
+        if (input.length != this.inputNumber) {\r
+            throw new IllegalArgumentException("input.length != inputNumber");\r
+        }\r
+        if (output.length != this.outputNumber) {\r
+            throw new IllegalArgumentException("output.length != outputNumber");\r
+        }\r
+        if (hiddenOutput.length != this.hiddenNumber) {\r
+            throw new IllegalArgumentException("hiddenOutput.length != hiddenNumber");\r
+        }\r
 \r
         /*\r
          * 隠れ層の出力を計算します。\r
@@ -232,8 +332,10 @@ public final class SimpleNN {
      *            教師信号。\r
      * @throws NullPointerException\r
      *             input引数、output引数、hiddenOutput引数、teach引数がnullの場合。\r
+     * @throws IllegalArgumentException\r
+     *             input引数の配列要素数が入力層のニューロン数と異なる場合。output引数の配列要素数が出力層のニューロン数と異なる場合。hiddenOutput引数の配列要素数が隠れ層のニューロン数と異なる場合。teach引数の配列要素数が出力層のニューロン数と異なる場合。\r
      */\r
-    public void backwardNeuralNet(double[] input, double[] output, double[] hiddenOutput, double[] teach) {\r
+    private void calcBackward(double[] input, double[] output, double[] hiddenOutput, double[] teach) {\r
         /*\r
          * 引数を確認します。\r
          */\r
@@ -249,6 +351,18 @@ public final class SimpleNN {
         if (teach == null) {\r
             throw new NullPointerException("teach == null");\r
         }\r
+        if (input.length != this.inputNumber) {\r
+            throw new IllegalArgumentException("input.length != inputNumber");\r
+        }\r
+        if (output.length != this.outputNumber) {\r
+            throw new IllegalArgumentException("output.length != outputNumber");\r
+        }\r
+        if (hiddenOutput.length != this.hiddenNumber) {\r
+            throw new IllegalArgumentException("hiddenOutput.length != hiddenNumber");\r
+        }\r
+        if (teach.length != this.outputNumber) {\r
+            throw new IllegalArgumentException("teach.length != outputNumber");\r
+        }\r
 \r
         /*\r
          * 出力層の誤差を計算します。\r
@@ -308,7 +422,7 @@ public final class SimpleNN {
      * @throws NullPointerException\r
      *             output引数、teach引数がnullの場合。\r
      */\r
-    public double calcError(double[] output, double[] teach) {\r
+    private double calcError(double[] output, double[] teach) {\r
         /*\r
          * 引数を確認します。\r
          */\r
@@ -339,7 +453,7 @@ public final class SimpleNN {
      *            引数。\r
      * @return 計算結果。\r
      */\r
-    public double sigmoid(double x) {\r
+    private double sigmoid(double x) {\r
         return 1.0 / (1.0 + Math.exp(-SimpleNN.BETA * x));\r
     }\r
 \r
index 42d27c2..06c40e4 100644 (file)
@@ -615,17 +615,11 @@ public final class SimpleNNApplet extends Applet implements MouseListener, Mouse
                     for (r = 0; r < SimpleNNApplet.INNER_CYCLES; r++) {\r
                         // 内部サイクル\r
 \r
-                        // 順方向演算\r
-                        this.simpleNN.forwardNeuralNet(this.sampleIn, this.recognizeOut, hiddenOutput);\r
-\r
-                        // 逆方向演算(バックプロパゲーション)\r
-                        this.simpleNN.backwardNeuralNet(this.sampleIn, this.recognizeOut, hiddenOutput, this.teach);\r
+                        this.simpleNN.learn(this.sampleIn, this.teach);\r
                     }\r
 \r
                     // 内部二乗誤差の計算\r
-\r
-                    // 内部二乗誤差のクリヤー\r
-                    innerError = this.simpleNN.calcError(this.recognizeOut, this.teach);\r
+                    innerError = this.simpleNN.reportError(this.sampleIn, this.teach);\r
 \r
                     // 外部二乗誤差への累加算\r
                     outerError += innerError;\r
@@ -655,7 +649,7 @@ public final class SimpleNNApplet extends Applet implements MouseListener, Mouse
                 this.sampleIn = this.sampleArray[q];\r
 \r
                 // 順方向演算\r
-                this.simpleNN.forwardNeuralNet(this.sampleIn, this.recognizeOut, hiddenOutput);\r
+                this.recognizeOut = this.simpleNN.calculate(this.sampleIn);\r
 \r
                 // 結果の表示\r
                 g.setColor(Color.black);\r
@@ -709,7 +703,7 @@ public final class SimpleNNApplet extends Applet implements MouseListener, Mouse
         Graphics g = this.getGraphics();\r
 \r
         // 順方向演算\r
-        this.simpleNN.forwardNeuralNet(this.writtenIn, this.recognizeOut, new double[SimpleNNApplet.HIDDEN]);\r
+        this.recognizeOut = this.simpleNN.calculate(this.sampleIn);\r
 \r
         // 結果の表示\r
         for (int k = 0; k < SimpleNNApplet.OUTPUT; k++) {\r