/*\r
- * Copyright (C) 2007 u6k.yu1@gmail.com, All Rights Reserved.\r
+ * Copyright (C) 2008 u6k.yu1@gmail.com, All Rights Reserved.\r
*\r
* Redistribution and use in source and binary forms, with or without\r
* modification, are permitted provided that the following conditions\r
import java.awt.event.MouseListener;\r
import java.awt.event.MouseMotionListener;\r
import java.io.BufferedReader;\r
-import java.io.BufferedWriter;\r
-import java.io.FileOutputStream;\r
import java.io.IOException;\r
import java.io.InputStreamReader;\r
-import java.io.OutputStreamWriter;\r
\r
/**\r
* <p>\r
* 文字列とかの表示基準座標。\r
* </p>\r
*/\r
- private static final int X0 = 10;\r
+ private static final int X0 = 10;\r
\r
/**\r
* <p>\r
* 文字列とかの表示基準座標。\r
* </p>\r
*/\r
- private static final int X1 = 125;\r
+ private static final int X1 = 125;\r
\r
/**\r
* <p>\r
* 文字列とかの表示基準座標。\r
* </p>\r
*/\r
- private static final int Y0 = 55;\r
+ private static final int Y0 = 55;\r
\r
/**\r
* <p>\r
* 文字列とかの表示基準座標。\r
* </p>\r
*/\r
- private static final int Y1 = 70;\r
+ private static final int Y1 = 70;\r
\r
/**\r
* <p>\r
* 文字列とかの表示基準座標。\r
* </p>\r
*/\r
- private static final int Y2 = 160;\r
+ private static final int Y2 = 160;\r
\r
/**\r
* <p>\r
* 文字列とかの表示基準座標。\r
* </p>\r
*/\r
- private static final int Y3 = 240;\r
+ private static final int Y3 = 240;\r
\r
/**\r
* <p>\r
* 文字列とかの表示基準座標。\r
* </p>\r
*/\r
- private static final int Y4 = 305;\r
+ private static final int Y4 = 305;\r
\r
/**\r
* <p>\r
* 文字列とかの表示基準座標。\r
* </p>\r
*/\r
- private static final int RX0 = 30;\r
+ private static final int RX0 = 30;\r
\r
/**\r
* <p>\r
* 文字列とかの表示基準座標。\r
* </p>\r
*/\r
- private static final int RX1 = 60;\r
+ private static final int RX1 = 60;\r
\r
/**\r
* <p>\r
* 文字列とかの表示基準座標。\r
* </p>\r
*/\r
- private static final int RX2 = 210;\r
+ private static final int RX2 = 210;\r
\r
/**\r
* <p>\r
* 文字列とかの表示基準座標。\r
* </p>\r
*/\r
- private static final int RX3 = 260;\r
+ private static final int RX3 = 260;\r
\r
/**\r
* <p>\r
* 文字列とかの表示基準座標。\r
* </p>\r
*/\r
- private static final int RY0 = 225;\r
+ private static final int RY0 = 225;\r
\r
/**\r
* <p>\r
* 文字列とかの表示基準座標。\r
* </p>\r
*/\r
- private static final int RY1 = 240;\r
+ private static final int RY1 = 240;\r
\r
/**\r
* <p>\r
* 入力データの幅。\r
* </p>\r
*/\r
- private static final int WIDTH = 7;\r
+ private static final int WIDTH = 7;\r
\r
/**\r
* <p>\r
* 入力データの高さ。\r
* </p>\r
*/\r
- private static final int HEIGHT = 11;\r
+ private static final int HEIGHT = 11;\r
\r
/**\r
* <p>\r
* 入力層の数(入力データ数)。\r
* </p>\r
*/\r
- private static final int INPUT = SimpleNNApplet.WIDTH * SimpleNNApplet.HEIGHT;\r
+ private static final int INPUT = SimpleNNApplet.WIDTH * SimpleNNApplet.HEIGHT;\r
\r
/**\r
* <p>\r
* 隠れ層の数。\r
* </p>\r
*/\r
- private static final int HIDDEN = 16;\r
+ private static final int HIDDEN = 16;\r
\r
/**\r
* <p>\r
* パターンの種類。\r
* </p>\r
*/\r
- private static final int PATTERN = 10;\r
+ private static final int PATTERN = 10;\r
\r
/**\r
* <p>\r
* 出力層の数(出力データ数)。\r
* </p>\r
*/\r
- private static final int OUTPUT = SimpleNNApplet.PATTERN;\r
+ private static final int OUTPUT = SimpleNNApplet.PATTERN;\r
\r
/**\r
* <p>\r
* 外部サイクル(一連のパターンの繰返し学習)の回数。\r
* </p>\r
*/\r
- private static final int OUTER_CYCLES = 100;\r
+ private static final int OUTER_CYCLES = 100;\r
\r
/**\r
* <p>\r
* 内部サイクル(同一パターンの繰返し学習)の回数。\r
* </p>\r
*/\r
- private static final int INNER_CYCLES = 100;\r
-\r
- /**\r
- * <p>\r
- * 学習の加速係数。\r
- * </p>\r
- */\r
- private static final double ALPHA = 1.2;\r
-\r
- /**\r
- * <p>\r
- * シグモイド曲線の傾斜。\r
- * </p>\r
- */\r
- private static final double BETA = 1.2;\r
-\r
- /**\r
- * <p>\r
- * ネットワークを流れる値の上限(1)の半分。\r
- * </p>\r
- */\r
- private static final double VALUE_HALF = 0.5;\r
+ private static final int INNER_CYCLES = 100;\r
\r
/**\r
* <p>\r
* 「再学習」ボタン。\r
* </p>\r
*/\r
- private Button button1;\r
+ private Button button1;\r
\r
/**\r
* <p>\r
* 「学習終了」ボタン。\r
* </p>\r
*/\r
- private Button button2;\r
+ private Button button2;\r
\r
/**\r
* <p>\r
* 「入力クリア」ボタン。\r
* </p>\r
*/\r
- private Button button3;\r
+ private Button button3;\r
\r
/**\r
* <p>\r
* 「認識」ボタン。\r
* </p>\r
*/\r
- private Button button4;\r
+ private Button button4;\r
\r
/**\r
* <p>\r
* 「状態出力」ボタン。\r
* </p>\r
*/\r
- private Button button5;\r
+ private Button button5;\r
\r
/**\r
* <p>\r
* 学習用入力。\r
* </p>\r
*/\r
- private double[] sampleIn = new double[SimpleNNApplet.INPUT];\r
+ private double[] sampleIn = new double[SimpleNNApplet.INPUT];\r
\r
/**\r
* <p>\r
* 認識用手書き入力。\r
* </p>\r
*/\r
- private double[] writtenIn = new double[SimpleNNApplet.INPUT];\r
-\r
- /**\r
- * <p>\r
- * 入力層と隠れ層の間の重み係数。\r
- * </p>\r
- */\r
- private double[][] weightInHidden = new double[SimpleNNApplet.INPUT][SimpleNNApplet.HIDDEN];\r
-\r
- /**\r
- * <p>\r
- * 隠れ層の閾値。\r
- * </p>\r
- */\r
- private double[] thresholdHidden = new double[SimpleNNApplet.HIDDEN];\r
-\r
- /**\r
- * <p>\r
- * 隠れ層の出力。\r
- * </p>\r
- */\r
- private double[] hiddenOut = new double[SimpleNNApplet.HIDDEN];\r
-\r
- /**\r
- * <p>\r
- * 隠れ層と出力層の間の重み係数。\r
- * </p>\r
- */\r
- private double[][] weightHiddenOut = new double[SimpleNNApplet.HIDDEN][SimpleNNApplet.OUTPUT];\r
-\r
- /**\r
- * <p>\r
- * 出力層の閾値。\r
- * </p>\r
- */\r
- private double[] thresholdOut = new double[SimpleNNApplet.OUTPUT];\r
+ private double[] writtenIn = new double[SimpleNNApplet.INPUT];\r
\r
/**\r
* <p>\r
* 認識出力(出力層の出力)。\r
* </p>\r
*/\r
- private double[] recognizeOut = new double[SimpleNNApplet.OUTPUT];\r
+ private double[] recognizeOut = new double[SimpleNNApplet.OUTPUT];\r
\r
/**\r
* <p>\r
* 教師信号。\r
* </p>\r
*/\r
- private double[] teach = new double[SimpleNNApplet.PATTERN];\r
+ private double[] teach = new double[SimpleNNApplet.PATTERN];\r
\r
/**\r
* <p>\r
* 「学習モード」フラグ。\r
* </p>\r
*/\r
- private boolean learningFlag;\r
+ private boolean learningFlag;\r
\r
/**\r
* <p>\r
* 学習用入力データの基となるパターン。\r
* </p>\r
*/\r
- private double[][] sampleArray;\r
+ private double[][] sampleArray;\r
\r
/**\r
* <p>\r
* パターンと出力すべき教師信号の比較表。\r
* </p>\r
*/\r
- private double[][] teachArray = new double[SimpleNNApplet.PATTERN][SimpleNNApplet.OUTPUT];\r
+ private double[][] teachArray = new double[SimpleNNApplet.PATTERN][SimpleNNApplet.OUTPUT];\r
\r
/**\r
* <p>\r
* 手書き文字入力用座標。\r
* </p>\r
*/\r
- private int xNew;\r
+ private int xNew;\r
\r
/**\r
* <p>\r
* 手書き文字入力用座標。\r
* </p>\r
*/\r
- private int yNew;\r
+ private int yNew;\r
\r
/**\r
* <p>\r
* 手書き文字入力用座標。\r
* </p>\r
*/\r
- private int xOld;\r
+ private int xOld;\r
\r
/**\r
* <p>\r
* 手書き文字入力用座標。\r
* </p>\r
*/\r
- private int yOld;\r
+ private int yOld;\r
+\r
+ /**\r
+ * <p>\r
+ * ニューラル・ネットワーク。\r
+ * </p>\r
+ */\r
+ private SimpleNN simpleNN;\r
\r
/**\r
* <p>\r
// 「状態出力」\r
System.out.println("状態出力開始。");\r
try {\r
- this.resultOutput();\r
+ this.simpleNN.outputState();\r
} catch (IOException e) {\r
e.printStackTrace();\r
}\r
// 学習モード\r
\r
// 閾値と重みの乱数設定\r
- this.initNetwork();\r
+ this.simpleNN = new SimpleNN();\r
\r
// -------------------------- 学習 --------------------------\r
for (p = 0; p < SimpleNNApplet.OUTER_CYCLES; p++) {\r
// 内部サイクル\r
\r
// 順方向演算\r
- this.recognizeOut = this.forwardNeuralNet(this.sampleIn);\r
+ this.recognizeOut = this.simpleNN.forwardNeuralNet(this.sampleIn);\r
\r
// 逆方向演算(バックプロパゲーション)\r
- this.backwardNeuralNet(this.recognizeOut, this.teach);\r
+ this.simpleNN.backwardNeuralNet(this.sampleIn, this.recognizeOut, this.teach);\r
}\r
\r
// 内部二乗誤差の計算\r
\r
// 内部二乗誤差のクリヤー\r
- innerError = this.calcError(this.recognizeOut, this.teach);\r
+ innerError = this.simpleNN.calcError(this.recognizeOut, this.teach);\r
\r
// 外部二乗誤差への累加算\r
outerError += innerError;\r
this.sampleIn = this.sampleArray[q];\r
\r
// 順方向演算\r
- this.recognizeOut = this.forwardNeuralNet(this.sampleIn);\r
+ this.recognizeOut = this.simpleNN.forwardNeuralNet(this.sampleIn);\r
\r
// 結果の表示\r
g.setColor(Color.black);\r
\r
/**\r
* <p>\r
- * 順方向演算を行います。\r
- * </p>\r
- * \r
- * @param input\r
- * 入力データ。\r
- * @return 演算結果。\r
- */\r
- public double[] forwardNeuralNet(double[] input) {\r
- double[] out = new double[SimpleNNApplet.OUTPUT];\r
- double[] hidden = new double[SimpleNNApplet.HIDDEN];\r
-\r
- // 隠れ層出力の計算\r
- for (int j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
- hidden[j] = -this.thresholdHidden[j];\r
- for (int i = 0; i < SimpleNNApplet.INPUT; i++) {\r
- hidden[j] += input[i] * this.weightInHidden[i][j];\r
- }\r
- this.hiddenOut[j] = this.sigmoid(hidden[j]);\r
- }\r
-\r
- // 出力層出力の計算\r
- for (int k = 0; k < SimpleNNApplet.OUTPUT; k++) {\r
- out[k] = -this.thresholdOut[k];\r
- for (int j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
- out[k] += this.hiddenOut[j] * this.weightHiddenOut[j][k];\r
- }\r
- out[k] = this.sigmoid(out[k]);\r
- }\r
-\r
- return out;\r
- }\r
-\r
- /**\r
- * <p>\r
- * 逆方向演算を行います。\r
- * </p>\r
- * \r
- * @param output\r
- * 順方向演算の結果。\r
- * @param teach\r
- * 教師信号。\r
- */\r
- public void backwardNeuralNet(double[] output, double[] teach) {\r
- int i;\r
- int j;\r
- int k;\r
-\r
- // 出力層の誤差\r
- double[] outputError = new double[SimpleNNApplet.OUTPUT];\r
- // 隠れ層の誤差\r
- double[] hiddenError = new double[SimpleNNApplet.HIDDEN];\r
- double tempError;\r
-\r
- // 出力層の誤差の計算\r
- for (k = 0; k < SimpleNNApplet.OUTPUT; k++) {\r
- outputError[k] = (teach[k] - output[k]) * output[k] * (1.0 - output[k]);\r
- }\r
-\r
- // 隠れ層の誤差の計算\r
- for (j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
- tempError = 0;\r
- for (k = 0; k < SimpleNNApplet.OUTPUT; k++) {\r
- tempError += outputError[k] * this.weightHiddenOut[j][k];\r
- }\r
- hiddenError[j] = this.hiddenOut[j] * (1.0 - this.hiddenOut[j]) * tempError;\r
- }\r
-\r
- // 重みの補正\r
- for (k = 0; k < SimpleNNApplet.OUTPUT; k++) {\r
- for (j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
- this.weightHiddenOut[j][k] += SimpleNNApplet.ALPHA * outputError[k] * this.hiddenOut[j];\r
- }\r
- }\r
- for (j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
- for (i = 0; i < SimpleNNApplet.INPUT; i++) {\r
- this.weightInHidden[i][j] += SimpleNNApplet.ALPHA * hiddenError[j] * this.sampleIn[i];\r
- }\r
- }\r
-\r
- // 閾値の補正\r
- for (k = 0; k < SimpleNNApplet.OUTPUT; k++) {\r
- this.thresholdOut[k] -= SimpleNNApplet.ALPHA * outputError[k];\r
- }\r
- for (j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
- this.thresholdHidden[j] -= SimpleNNApplet.ALPHA * hiddenError[j];\r
- }\r
- }\r
-\r
- /**\r
- * <p>\r
- * ニューラル・ネットワークの状態(閾値、重み)を初期化します。\r
- * </p>\r
- */\r
- public void initNetwork() {\r
- // TODO ちゃんとランダムに初期化する。\r
- try {\r
- BufferedReader reader = new BufferedReader(new InputStreamReader(this.getClass().getClassLoader().getResourceAsStream("random.txt")));\r
- try {\r
- for (int j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
- this.thresholdHidden[j] = Double.parseDouble(reader.readLine()) - SimpleNNApplet.VALUE_HALF;\r
- for (int i = 0; i < SimpleNNApplet.INPUT; i++) {\r
- this.weightInHidden[i][j] = Double.parseDouble(reader.readLine()) - SimpleNNApplet.VALUE_HALF;\r
- }\r
- }\r
- for (int k = 0; k < SimpleNNApplet.OUTPUT; k++) {\r
- this.thresholdOut[k] = Double.parseDouble(reader.readLine()) - SimpleNNApplet.VALUE_HALF;\r
- for (int j = 0; j < SimpleNNApplet.HIDDEN; j++) {\r
- this.weightHiddenOut[j][k] = Double.parseDouble(reader.readLine()) - SimpleNNApplet.VALUE_HALF;\r
- }\r
- }\r
- } finally {\r
- reader.close();\r
- }\r
- } catch (IOException e) {\r
- e.printStackTrace();\r
- }\r
- }\r
-\r
- /**\r
- * <p>\r
- * 順方向演算の結果と教師信号とのずれを表す二乗誤差を算出します。\r
- * </p>\r
- * \r
- * @param output\r
- * 順方向演算の結果。\r
- * @param teach\r
- * 教師信号。\r
- * @return 二乗誤差。\r
- */\r
- public double calcError(double[] output, double[] teach) {\r
- double error = 0;\r
-\r
- for (int i = 0; i < output.length; i++) {\r
- error += (teach[i] - output[i]) * (teach[i] - output[i]);\r
- }\r
-\r
- return error;\r
- }\r
-\r
- /**\r
- * <p>\r
- * シグモイド関数です。\r
- * </p>\r
- * \r
- * @param x\r
- * 引数。\r
- * @return 計算結果。\r
- */\r
- public double sigmoid(double x) {\r
- return 1.0 / (1.0 + Math.exp(-SimpleNNApplet.BETA * x));\r
- }\r
-\r
- /**\r
- * <p>\r
* 手書き入力された文字を認識します。\r
* </p>\r
*/\r
Graphics g = this.getGraphics();\r
\r
// 順方向演算\r
- this.recognizeOut = this.forwardNeuralNet(this.writtenIn);\r
+ this.recognizeOut = this.simpleNN.forwardNeuralNet(this.writtenIn);\r
\r
// 結果の表示\r
for (int k = 0; k < SimpleNNApplet.OUTPUT; k++) {\r
}\r
}\r
\r
- private void resultOutput() throws IOException {\r
- BufferedWriter w = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("C:/result.txt")));\r
- try {\r
- for (int i = 0; i < this.weightInHidden.length; i++) {\r
- for (int j = 0; j < this.weightInHidden[i].length; j++) {\r
- w.write("weightInHidden[" + i + "][" + j + "]=" + this.weightInHidden[i][j]);\r
- w.newLine();\r
- }\r
- }\r
- for (int i = 0; i < this.thresholdHidden.length; i++) {\r
- w.write("thresholdHidden[" + i + "]=" + this.thresholdHidden[i]);\r
- w.newLine();\r
- }\r
- for (int i = 0; i < this.weightHiddenOut.length; i++) {\r
- for (int j = 0; j < this.weightHiddenOut[i].length; j++) {\r
- w.write("weightHiddenOut[" + i + "][" + j + "]=" + this.weightHiddenOut[i][j]);\r
- w.newLine();\r
- }\r
- }\r
- for (int i = 0; i < this.thresholdOut.length; i++) {\r
- w.write("thresholdOut[" + i + "]=" + this.thresholdOut[i]);\r
- w.newLine();\r
- }\r
- } finally {\r
- w.close();\r
- }\r
- }\r
-\r
}\r