2 * Copyright (C) 2008 u6k.yu1@gmail.com, All Rights Reserved.
\r
4 * Redistribution and use in source and binary forms, with or without
\r
5 * modification, are permitted provided that the following conditions
\r
8 * 1. Redistributions of source code must retain the above copyright
\r
9 * notice, this list of conditions and the following disclaimer.
\r
11 * 2. Redistributions in binary form must reproduce the above copyright
\r
12 * notice, this list of conditions and the following disclaimer in the
\r
13 * documentation and/or other materials provided with the distribution.
\r
15 * 3. Neither the name of Clarkware Consulting, Inc. nor the names of its
\r
16 * contributors may be used to endorse or promote products derived
\r
17 * from this software without prior written permission. For written
\r
18 * permission, please contact clarkware@clarkware.com.
\r
20 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES,
\r
21 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
\r
22 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
\r
23 * CLARKWARE CONSULTING OR ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
\r
24 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
\r
25 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
\r
26 * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
\r
27 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
\r
28 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
\r
29 * EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\r
32 package jp.gr.java_conf.u6k.simplenn;
\r
34 import java.io.BufferedReader;
\r
35 import java.io.IOException;
\r
36 import java.io.InputStreamReader;
\r
/**
 * A small, easy-to-use feed-forward neural network with one hidden layer, trained by
 * backpropagation.
 *
 * <p>
 * Each hidden and output neuron computes the logistic sigmoid of its weighted input sum minus its
 * threshold, so every value produced by the network is in the range 0 to 1.
 * </p>
 * <p>
 * This code is based on the <a href="http://codezine.jp/">CodeZine</a> article
 * <a href="http://codezine.jp/a/article/aid/372.aspx">"Pattern recognition using neural
 * networks"</a>.
 * </p>
 * <p>
 * Instances are mutable ({@link #learn(double[], double[])} updates the weights and thresholds in
 * place) and perform no synchronization.
 * </p>
 *
 * @see <a href="http://codezine.jp/a/article/aid/372.aspx">http://codezine.jp/a/article/aid/372.aspx</a>
 */
public final class SimpleNN {

    /**
     * Half of the upper bound (1) of the values flowing through the network. Subtracted from every
     * initial value so that initial weights and thresholds are centred on zero.
     */
    private static final double VALUE_HALF = 0.5;

    /**
     * Name of the optional class-path resource that supplies the initial values, one decimal number
     * per line.
     */
    private static final String RANDOM_RESOURCE = "random.txt";

    /** Number of neurons in the input layer. */
    private final int inputNumber;

    /** Number of neurons in the hidden layer. */
    private final int hiddenNumber;

    /** Number of neurons in the output layer. */
    private final int outputNumber;

    /**
     * Input-to-hidden weights, stored flat: the weight from input {@code j} to hidden neuron
     * {@code i} is at index {@code j * hiddenNumber + i}.
     */
    private final double[] weightInHidden;

    /** Thresholds of the hidden neurons. */
    private final double[] thresholdHidden;

    /**
     * Hidden-to-output weights, stored flat: the weight from hidden neuron {@code j} to output
     * neuron {@code i} is at index {@code j * outputNumber + i}.
     */
    private final double[] weightHiddenOut;

    /** Thresholds of the output neurons. */
    private final double[] thresholdOut;

    /** Learning coefficient (step size of each backpropagation update). */
    private final double learningCoefficient;

    /**
     * Creates a network and initializes its state (thresholds and weights).
     *
     * <p>
     * If a class-path resource named {@code random.txt} exists, the initial values are read from it
     * line by line (threshold of hidden neuron 0, its input weights, threshold of hidden neuron 1,
     * ..., then the same for the output layer) and 0.5 is subtracted from each. The file presumably
     * holds numbers in [0, 1) -- this is not verified here. If the resource is absent, values are
     * drawn from {@link Math#random()} instead, again shifted by -0.5.
     * </p>
     *
     * @param inputNumber
     *            number of input neurons.
     * @param hiddenNumber
     *            number of hidden neurons.
     * @param outputNumber
     *            number of output neurons.
     * @param learningCoefficient
     *            learning coefficient.
     * @throws IllegalArgumentException
     *             if {@code inputNumber}, {@code hiddenNumber}, {@code outputNumber} or
     *             {@code learningCoefficient} is 0 or less.
     * @throws IllegalStateException
     *             if {@code random.txt} exists but cannot be read or holds too few values.
     * @throws NumberFormatException
     *             if a line of {@code random.txt} is not a parsable number.
     */
    public SimpleNN(int inputNumber, int hiddenNumber, int outputNumber, double learningCoefficient) {
        /*
         * Validate the arguments.
         */
        if (inputNumber <= 0) {
            throw new IllegalArgumentException("inputNumber <= 0");
        }
        if (hiddenNumber <= 0) {
            throw new IllegalArgumentException("hiddenNumber <= 0");
        }
        if (outputNumber <= 0) {
            throw new IllegalArgumentException("outputNumber <= 0");
        }
        if (learningCoefficient <= 0) {
            throw new IllegalArgumentException("learningCoefficient <= 0");
        }

        /*
         * Store the dimensions BEFORE filling the arrays: the index arithmetic below must use the
         * real layer sizes, not the fields' default value of 0.
         */
        this.inputNumber = inputNumber;
        this.hiddenNumber = hiddenNumber;
        this.outputNumber = outputNumber;
        this.learningCoefficient = learningCoefficient;

        /*
         * Initialize the network state.
         */
        this.thresholdHidden = new double[hiddenNumber];
        this.weightInHidden = new double[inputNumber * hiddenNumber];
        this.thresholdOut = new double[outputNumber];
        this.weightHiddenOut = new double[hiddenNumber * outputNumber];

        BufferedReader reader = null;
        try {
            java.io.InputStream stream = SimpleNN.class.getClassLoader().getResourceAsStream(RANDOM_RESOURCE);
            if (stream != null) {
                reader = new BufferedReader(new InputStreamReader(stream, "UTF-8"));
            }

            for (int i = 0; i < hiddenNumber; i++) {
                this.thresholdHidden[i] = nextInitialValue(reader);
                for (int j = 0; j < inputNumber; j++) {
                    this.weightInHidden[j * hiddenNumber + i] = nextInitialValue(reader);
                }
            }
            for (int i = 0; i < outputNumber; i++) {
                this.thresholdOut[i] = nextInitialValue(reader);
                for (int j = 0; j < hiddenNumber; j++) {
                    this.weightHiddenOut[j * outputNumber + i] = nextInitialValue(reader);
                }
            }
        } catch (IOException e) {
            // A partially initialized network is useless, so fail loudly and keep the cause.
            throw new IllegalStateException("failed to read " + RANDOM_RESOURCE, e);
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException ignored) {
                    // Nothing useful can be done when closing a read-only resource fails.
                }
            }
        }
    }

    /**
     * Updates the network state (learns) from one input pattern and its teacher signal.
     *
     * @param input
     *            input data; one value per input neuron.
     * @param teach
     *            teacher signal; one value per output neuron.
     * @throws NullPointerException
     *             if {@code input} or {@code teach} is null.
     * @throws IllegalArgumentException
     *             if the length of {@code input} differs from the number of input neurons, or the
     *             length of {@code teach} differs from the number of output neurons.
     */
    public void learn(double[] input, double[] teach) {
        double[] output = new double[this.outputNumber];
        double[] hiddenOutput = new double[this.hiddenNumber];

        this.calcForward(input, output, hiddenOutput);
        this.calcBackward(input, output, hiddenOutput, teach);
    }

    /**
     * Runs the input data through the network.
     *
     * @param input
     *            input data; one value per input neuron.
     * @return the output data; a new array with one value per output neuron.
     * @throws NullPointerException
     *             if {@code input} is null.
     * @throws IllegalArgumentException
     *             if the length of {@code input} differs from the number of input neurons.
     */
    public double[] calculate(double[] input) {
        double[] output = new double[this.outputNumber];

        this.calcForward(input, output, new double[this.hiddenNumber]);

        return output;
    }

    /**
     * Computes the squared error between the output derived from the input data and the teacher
     * signal. The network state is not modified.
     *
     * @param input
     *            input data; one value per input neuron.
     * @param teach
     *            teacher signal; one value per output neuron.
     * @return the sum of the squared differences between teacher signal and output.
     * @throws NullPointerException
     *             if {@code input} or {@code teach} is null.
     * @throws IllegalArgumentException
     *             if the length of {@code input} differs from the number of input neurons, or the
     *             length of {@code teach} differs from the number of output neurons.
     */
    public double reportError(double[] input, double[] teach) {
        double[] output = this.calculate(input);
        double err = this.calcError(output, teach);

        return err;
    }

    /**
     * Forward pass: computes the hidden-layer and output-layer activations for the input data.
     *
     * @param input
     *            input data.
     * @param output
     *            array that receives the output-layer activations.
     * @param hiddenOutput
     *            array that receives the hidden-layer activations produced along the way.
     * @throws NullPointerException
     *             if {@code input}, {@code output} or {@code hiddenOutput} is null.
     * @throws IllegalArgumentException
     *             if the length of {@code input}, {@code output} or {@code hiddenOutput} differs
     *             from the size of the input, output or hidden layer respectively.
     */
    private void calcForward(double[] input, double[] output, double[] hiddenOutput) {
        /*
         * Validate the arguments (all null checks first, then all length checks).
         */
        requireNonNull(input, "input");
        requireNonNull(output, "output");
        requireNonNull(hiddenOutput, "hiddenOutput");
        requireLength(input, this.inputNumber, "input", "inputNumber");
        requireLength(output, this.outputNumber, "output", "outputNumber");
        requireLength(hiddenOutput, this.hiddenNumber, "hiddenOutput", "hiddenNumber");

        /*
         * Hidden layer: sigmoid(sum of weighted inputs - threshold).
         */
        for (int i = 0; i < hiddenOutput.length; i++) {
            double sum = -this.thresholdHidden[i];
            for (int j = 0; j < input.length; j++) {
                sum += input[j] * this.weightInHidden[j * this.hiddenNumber + i];
            }
            hiddenOutput[i] = sigmoid(sum);
        }

        /*
         * Output layer: sigmoid(sum of weighted hidden activations - threshold).
         */
        for (int i = 0; i < output.length; i++) {
            double sum = -this.thresholdOut[i];
            for (int j = 0; j < hiddenOutput.length; j++) {
                sum += hiddenOutput[j] * this.weightHiddenOut[j * this.outputNumber + i];
            }
            output[i] = sigmoid(sum);
        }
    }

    /**
     * Backward pass: adjusts weights and thresholds from the result of a forward pass and the
     * teacher signal.
     *
     * @param input
     *            input data used for the forward pass.
     * @param output
     *            output-layer activations of the forward pass.
     * @param hiddenOutput
     *            hidden-layer activations of the forward pass.
     * @param teach
     *            teacher signal.
     * @throws NullPointerException
     *             if {@code input}, {@code output}, {@code hiddenOutput} or {@code teach} is null.
     * @throws IllegalArgumentException
     *             if the length of {@code input}, {@code output}, {@code hiddenOutput} or
     *             {@code teach} differs from the size of the input, output, hidden or output layer
     *             respectively.
     */
    private void calcBackward(double[] input, double[] output, double[] hiddenOutput, double[] teach) {
        /*
         * Validate the arguments (all null checks first, then all length checks).
         */
        requireNonNull(input, "input");
        requireNonNull(output, "output");
        requireNonNull(hiddenOutput, "hiddenOutput");
        requireNonNull(teach, "teach");
        requireLength(input, this.inputNumber, "input", "inputNumber");
        requireLength(output, this.outputNumber, "output", "outputNumber");
        requireLength(hiddenOutput, this.hiddenNumber, "hiddenOutput", "hiddenNumber");
        requireLength(teach, this.outputNumber, "teach", "outputNumber");

        /*
         * Output-layer error terms: (teach - out) * out * (1 - out).
         */
        double[] outputError = new double[output.length];
        for (int i = 0; i < outputError.length; i++) {
            outputError[i] = (teach[i] - output[i]) * output[i] * (1.0 - output[i]);
        }

        /*
         * Hidden-layer error terms. These must be computed with the hidden-to-output weights as
         * they were during the forward pass, i.e. before the updates below.
         */
        double[] hiddenError = new double[hiddenOutput.length];
        for (int i = 0; i < hiddenError.length; i++) {
            double err = 0;
            for (int j = 0; j < output.length; j++) {
                err += outputError[j] * this.weightHiddenOut[i * this.outputNumber + j];
            }
            hiddenError[i] = hiddenOutput[i] * (1.0 - hiddenOutput[i]) * err;
        }

        /*
         * Update the weights.
         */
        for (int i = 0; i < outputError.length; i++) {
            for (int j = 0; j < hiddenOutput.length; j++) {
                this.weightHiddenOut[j * this.outputNumber + i] += this.learningCoefficient * outputError[i] * hiddenOutput[j];
            }
        }
        for (int i = 0; i < hiddenError.length; i++) {
            for (int j = 0; j < input.length; j++) {
                this.weightInHidden[j * this.hiddenNumber + i] += this.learningCoefficient * hiddenError[i] * input[j];
            }
        }

        /*
         * Update the thresholds. They enter the weighted sum with a negative sign, hence "-=".
         */
        for (int i = 0; i < this.thresholdOut.length; i++) {
            this.thresholdOut[i] -= this.learningCoefficient * outputError[i];
        }
        for (int i = 0; i < this.thresholdHidden.length; i++) {
            this.thresholdHidden[i] -= this.learningCoefficient * hiddenError[i];
        }
    }

    /**
     * Computes the squared error between a forward-pass result and the teacher signal.
     *
     * @param output
     *            output-layer activations of a forward pass.
     * @param teach
     *            teacher signal.
     * @return the sum over all outputs of {@code (teach[i] - output[i])}, squared.
     * @throws NullPointerException
     *             if {@code output} or {@code teach} is null.
     * @throws IllegalArgumentException
     *             if {@code teach} and {@code output} have different lengths.
     */
    private double calcError(double[] output, double[] teach) {
        /*
         * Validate the arguments.
         */
        requireNonNull(output, "output");
        requireNonNull(teach, "teach");
        if (teach.length != output.length) {
            throw new IllegalArgumentException("teach.length != output.length");
        }

        /*
         * Accumulate the squared differences.
         */
        double error = 0;
        for (int i = 0; i < output.length; i++) {
            double diff = teach[i] - output[i];
            error += diff * diff;
        }

        return error;
    }

    /**
     * Logistic sigmoid function.
     *
     * @param x
     *            input value.
     * @return {@code 1 / (1 + e^-x)}.
     */
    private static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Supplies the next initial weight or threshold value.
     *
     * @param reader
     *            reader over the initial-value resource, or null to draw from
     *            {@link Math#random()}.
     * @return the next value, shifted by -0.5.
     * @throws IOException
     *             if reading from {@code reader} fails.
     * @throws IllegalStateException
     *             if {@code reader} has no more lines.
     */
    private static double nextInitialValue(BufferedReader reader) throws IOException {
        if (reader == null) {
            return Math.random() - VALUE_HALF;
        }

        String line = reader.readLine();
        if (line == null) {
            throw new IllegalStateException(RANDOM_RESOURCE + " does not contain enough values");
        }

        return Double.parseDouble(line) - VALUE_HALF;
    }

    /**
     * Throws a {@link NullPointerException} with the message {@code "<name> == null"} if the value
     * is null.
     */
    private static void requireNonNull(Object value, String name) {
        if (value == null) {
            throw new NullPointerException(name + " == null");
        }
    }

    /**
     * Throws an {@link IllegalArgumentException} with the message
     * {@code "<arrayName>.length != <sizeName>"} if the array does not have the expected length.
     */
    private static void requireLength(double[] array, int expected, String arrayName, String sizeName) {
        if (array.length != expected) {
            throw new IllegalArgumentException(arrayName + ".length != " + sizeName);
        }
    }

}