package jp.gr.java_conf.u6k.simplenn;\r
\r
import java.io.BufferedReader;\r
-import java.io.BufferedWriter;\r
-import java.io.FileOutputStream;\r
import java.io.IOException;\r
import java.io.InputStreamReader;\r
-import java.io.OutputStreamWriter;\r
\r
/**\r
* <p>\r
* 入力層と隠れ層の間の重み係数。\r
* </p>\r
*/\r
- private double[][] weightInHidden;\r
+ private double[] weightInHidden;\r
\r
/**\r
* <p>\r
* 隠れ層と出力層の間の重み係数。\r
* </p>\r
*/\r
- private double[][] weightHiddenOut;\r
+ private double[] weightHiddenOut;\r
\r
/**\r
* <p>\r
* ニューラル・ネットワークの状態を初期化します。\r
*/\r
this.thresholdHidden = new double[hiddenNumber];\r
- this.weightInHidden = new double[inputNumber][hiddenNumber];\r
+ this.weightInHidden = new double[inputNumber * hiddenNumber];\r
this.thresholdOut = new double[outputNumber];\r
- this.weightHiddenOut = new double[hiddenNumber][outputNumber];\r
+ this.weightHiddenOut = new double[hiddenNumber * outputNumber];\r
// TODO ちゃんとランダムに初期化する。\r
try {\r
BufferedReader r = new BufferedReader(new InputStreamReader(this.getClass().getClassLoader().getResourceAsStream("random.txt")));\r
for (int i = 0; i < hiddenNumber; i++) {\r
this.thresholdHidden[i] = Double.parseDouble(r.readLine()) - SimpleNN.VALUE_HALF;\r
for (int j = 0; j < inputNumber; j++) {\r
- this.weightInHidden[j][i] = Double.parseDouble(r.readLine()) - SimpleNN.VALUE_HALF;\r
+ this.weightInHidden[j * this.hiddenNumber + i] = Double.parseDouble(r.readLine()) - SimpleNN.VALUE_HALF;\r
}\r
}\r
for (int i = 0; i < outputNumber; i++) {\r
this.thresholdOut[i] = Double.parseDouble(r.readLine()) - SimpleNN.VALUE_HALF;\r
for (int j = 0; j < hiddenNumber; j++) {\r
- this.weightHiddenOut[j][i] = Double.parseDouble(r.readLine()) - SimpleNN.VALUE_HALF;\r
+ this.weightHiddenOut[j * this.outputNumber + i] = Double.parseDouble(r.readLine()) - SimpleNN.VALUE_HALF;\r
}\r
}\r
} finally {\r
for (int i = 0; i < hiddenOutput.length; i++) {\r
hiddenOutput[i] = -this.thresholdHidden[i];\r
for (int j = 0; j < input.length; j++) {\r
- hiddenOutput[i] += input[j] * this.weightInHidden[j][i];\r
+ hiddenOutput[i] += input[j] * this.weightInHidden[j * this.hiddenNumber + i];\r
}\r
hiddenOutput[i] = this.sigmoid(hiddenOutput[i]);\r
}\r
for (int i = 0; i < output.length; i++) {\r
output[i] = -this.thresholdOut[i];\r
for (int j = 0; j < hiddenOutput.length; j++) {\r
- output[i] += hiddenOutput[j] * this.weightHiddenOut[j][i];\r
+ output[i] += hiddenOutput[j] * this.weightHiddenOut[j * this.outputNumber + i];\r
}\r
output[i] = this.sigmoid(output[i]);\r
}\r
for (int i = 0; i < hiddenError.length; i++) {\r
double err = 0;\r
for (int j = 0; j < output.length; j++) {\r
- err += outputError[j] * this.weightHiddenOut[i][j];\r
+ err += outputError[j] * this.weightHiddenOut[i * this.outputNumber + j];\r
}\r
hiddenError[i] = hiddenOutput[i] * (1.0 - hiddenOutput[i]) * err;\r
}\r
*/\r
for (int i = 0; i < outputError.length; i++) {\r
for (int j = 0; j < hiddenOutput.length; j++) {\r
- this.weightHiddenOut[j][i] += this.learningCoefficient * outputError[i] * hiddenOutput[j];\r
+ this.weightHiddenOut[j * this.outputNumber + i] += this.learningCoefficient * outputError[i] * hiddenOutput[j];\r
}\r
}\r
for (int i = 0; i < hiddenError.length; i++) {\r
for (int j = 0; j < input.length; j++) {\r
- this.weightInHidden[j][i] += this.learningCoefficient * hiddenError[i] * input[j];\r
+ this.weightInHidden[j * this.hiddenNumber + i] += this.learningCoefficient * hiddenError[i] * input[j];\r
}\r
}\r
\r
return 1.0 / (1.0 + Math.exp(-x));\r
}\r
\r
- /**\r
- * <p>\r
- * インスタンスの状態を「C:/result.txt」に出力します。\r
- * </p>\r
- * \r
- * @throws IOException\r
- * 出力に失敗した場合。\r
- */\r
- public void outputState() throws IOException {\r
- BufferedWriter w = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("C:/result.txt")));\r
- try {\r
- for (int i = 0; i < this.weightInHidden.length; i++) {\r
- for (int j = 0; j < this.weightInHidden[i].length; j++) {\r
- w.write("weightInHidden[" + i + "][" + j + "]=" + this.weightInHidden[i][j]);\r
- w.newLine();\r
- }\r
- }\r
- for (int i = 0; i < this.thresholdHidden.length; i++) {\r
- w.write("thresholdHidden[" + i + "]=" + this.thresholdHidden[i]);\r
- w.newLine();\r
- }\r
- for (int i = 0; i < this.weightHiddenOut.length; i++) {\r
- for (int j = 0; j < this.weightHiddenOut[i].length; j++) {\r
- w.write("weightHiddenOut[" + i + "][" + j + "]=" + this.weightHiddenOut[i][j]);\r
- w.newLine();\r
- }\r
- }\r
- for (int i = 0; i < this.thresholdOut.length; i++) {\r
- w.write("thresholdOut[" + i + "]=" + this.thresholdOut[i]);\r
- w.newLine();\r
- }\r
- } finally {\r
- w.close();\r
- }\r
- }\r
-\r
}\r