OSDN Git Service

519e29bbe50542741900b540089dde0b2e0e7068
[simplenn/repo.git] / simplenn / src / main / java / jp / gr / java_conf / u6k / simplenn / SimpleNN.java
1 /*\r
2  * Copyright (C) 2008 u6k.yu1@gmail.com, All Rights Reserved.\r
3  *\r
4  * Redistribution and use in source and binary forms, with or without\r
5  * modification, are permitted provided that the following conditions\r
6  * are met:\r
7  *\r
8  *    1. Redistributions of source code must retain the above copyright\r
9  *       notice, this list of conditions and the following disclaimer.\r
10  *\r
11  *    2. Redistributions in binary form must reproduce the above copyright\r
12  *       notice, this list of conditions and the following disclaimer in the\r
13  *       documentation and/or other materials provided with the distribution.\r
14  *\r
15  *    3. Neither the name of Clarkware Consulting, Inc. nor the names of its\r
16  *       contributors may be used to endorse or promote products derived\r
17  *       from this software without prior written permission. For written\r
18  *       permission, please contact clarkware@clarkware.com.\r
19  *\r
20  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES,\r
21  * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND\r
22  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL\r
23  * CLARKWARE CONSULTING OR ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,\r
24  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\r
25  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,\r
26  * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF\r
27  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING\r
28  * NEGLIGENCE OR OTHERWISE) ARISING IN  ANY WAY OUT OF THE USE OF THIS SOFTWARE,\r
29  * EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\r
30  */\r
31 \r
32 package jp.gr.java_conf.u6k.simplenn;\r
33 \r
34 import java.io.BufferedReader;\r
35 import java.io.IOException;\r
36 import java.io.InputStreamReader;\r
37 \r
38 /**\r
39  * <p>\r
40  * 簡単に使用できるニューラル・ネットワークの実装クラスです。学習方法はバックプロパゲーション法です。\r
41  * </p>\r
42  * <p>\r
43  * このソースコードは、<a href="http://codezine.jp/">CodeZine</a>の記事「<a href="http://codezine.jp/a/article/aid/372.aspx">ニューラルネットワークを用いたパターン認識</a>」を参考にしています。\r
44  * </p>\r
45  * \r
46  * @version $Id$\r
47  * @see http://codezine.jp/a/article/aid/372.aspx\r
48  */\r
49 public final class SimpleNN {\r
50 \r
51     /**\r
52      * <p>\r
53      * ネットワークを流れる値の上限(1)の半分。\r
54      * </p>\r
55      */\r
56     private static final double VALUE_HALF = 0.5;\r
57 \r
58     /**\r
59      * <p>\r
60      * 入力層のニューロン数。\r
61      * </p>\r
62      */\r
63     private int                 inputNumber;\r
64 \r
65     /**\r
66      * <p>\r
67      * 隠れ層のニューロン数。\r
68      * </p>\r
69      */\r
70     private int                 hiddenNumber;\r
71 \r
72     /**\r
73      * <p>\r
74      * 出力層のニューロン数。\r
75      * </p>\r
76      */\r
77     private int                 outputNumber;\r
78 \r
79     /**\r
80      * <p>\r
81      * 入力層と隠れ層の間の重み係数。\r
82      * </p>\r
83      */\r
84     private double[]            weightInHidden;\r
85 \r
86     /**\r
87      * <p>\r
88      * 隠れ層の閾値。\r
89      * </p>\r
90      */\r
91     private double[]            thresholdHidden;\r
92 \r
93     /**\r
94      * <p>\r
95      * 隠れ層と出力層の間の重み係数。\r
96      * </p>\r
97      */\r
98     private double[]            weightHiddenOut;\r
99 \r
100     /**\r
101      * <p>\r
102      * 出力層の閾値。\r
103      * </p>\r
104      */\r
105     private double[]            thresholdOut;\r
106 \r
107     /**\r
108      * <p>\r
109      * 学習係数(learning coefficient)。\r
110      * </p>\r
111      */\r
112     private double              learningCoefficient;\r
113 \r
114     /**\r
115      * <p>\r
116      * ニューラル・ネットワークの状態(閾値、重み)を初期化します。\r
117      * </p>\r
118      * \r
119      * @param inputNumber\r
120      *            入力層のニューロン数。\r
121      * @param hiddenNumber\r
122      *            隠れ層のニューロン数。\r
123      * @param outputNumber\r
124      *            出力層のニューロン数。\r
125      * @param learningCoefficient\r
126      *            学習係数。\r
127      * @throws IllegalArgumentException\r
128      *             inputNumber引数、hiddenNumber引数、outputNumber引数、learningCoefficient引数が0以下の場合。\r
129      */\r
130     public SimpleNN(int inputNumber, int hiddenNumber, int outputNumber, double learningCoefficient) {\r
131         /*\r
132          * 引数を確認します。\r
133          */\r
134         if (inputNumber <= 0) {\r
135             throw new IllegalArgumentException("inputNumber <= 0");\r
136         }\r
137         if (hiddenNumber <= 0) {\r
138             throw new IllegalArgumentException("hiddenNumber <= 0");\r
139         }\r
140         if (outputNumber <= 0) {\r
141             throw new IllegalArgumentException("outputNumber <= 0");\r
142         }\r
143         if (learningCoefficient <= 0) {\r
144             throw new IllegalArgumentException("learningCoefficient <= 0");\r
145         }\r
146 \r
147         /*\r
148          * ニューラル・ネットワークの状態を初期化します。\r
149          */\r
150         this.thresholdHidden = new double[hiddenNumber];\r
151         this.weightInHidden = new double[inputNumber * hiddenNumber];\r
152         this.thresholdOut = new double[outputNumber];\r
153         this.weightHiddenOut = new double[hiddenNumber * outputNumber];\r
154         // TODO ちゃんとランダムに初期化する。\r
155         try {\r
156             BufferedReader r = new BufferedReader(new InputStreamReader(this.getClass().getClassLoader().getResourceAsStream("random.txt")));\r
157             try {\r
158                 for (int i = 0; i < hiddenNumber; i++) {\r
159                     this.thresholdHidden[i] = Double.parseDouble(r.readLine()) - SimpleNN.VALUE_HALF;\r
160                     for (int j = 0; j < inputNumber; j++) {\r
161                         this.weightInHidden[j * this.hiddenNumber + i] = Double.parseDouble(r.readLine()) - SimpleNN.VALUE_HALF;\r
162                     }\r
163                 }\r
164                 for (int i = 0; i < outputNumber; i++) {\r
165                     this.thresholdOut[i] = Double.parseDouble(r.readLine()) - SimpleNN.VALUE_HALF;\r
166                     for (int j = 0; j < hiddenNumber; j++) {\r
167                         this.weightHiddenOut[j * this.outputNumber + i] = Double.parseDouble(r.readLine()) - SimpleNN.VALUE_HALF;\r
168                     }\r
169                 }\r
170             } finally {\r
171                 r.close();\r
172             }\r
173         } catch (IOException e) {\r
174             e.printStackTrace();\r
175         }\r
176 \r
177         this.inputNumber = inputNumber;\r
178         this.hiddenNumber = hiddenNumber;\r
179         this.outputNumber = outputNumber;\r
180         this.learningCoefficient = learningCoefficient;\r
181     }\r
182 \r
183     /**\r
184      * <p>\r
185      * 入力データと教師信号を用いて、ニューラル・ネットワークの状態を更新します(学習します)。\r
186      * </p>\r
187      * \r
188      * @param input\r
189      *            入力データ。\r
190      * @param teach\r
191      *            教師信号。\r
192      * @throws NullPointerException\r
193      *             input引数、teach引数がnullの場合。\r
194      * @throws IllegalArgumentException\r
195      *             input引数の配列要素数が入力層のニューロン数と異なる場合。teach引数の配列要素数が出力層のニューロン数と異なる場合。\r
196      */\r
197     public void learn(double[] input, double[] teach) {\r
198         double[] output = new double[this.outputNumber];\r
199         double[] hiddenOutput = new double[this.hiddenNumber];\r
200 \r
201         this.calcForward(input, output, hiddenOutput);\r
202         this.calcBackward(input, output, hiddenOutput, teach);\r
203     }\r
204 \r
205     /**\r
206      * <p>\r
207      * 入力データをニューラル・ネットワークを用いて計算します。\r
208      * </p>\r
209      * \r
210      * @param input\r
211      *            入力データ。\r
212      * @return 計算結果の出力データ。\r
213      * @throws NullPointerException\r
214      *             input引数がnullの場合。\r
215      * @throws IllegalArgumentException\r
216      *             input引数の配列要素数が入力層のニューロン数と異なる場合。\r
217      */\r
218     public double[] calculate(double[] input) {\r
219         double[] output = new double[this.outputNumber];\r
220 \r
221         this.calcForward(input, output, new double[this.hiddenNumber]);\r
222 \r
223         return output;\r
224     }\r
225 \r
226     /**\r
227      * <p>\r
228      * 入力データから導き出される出力データと教師信号とのずれを表す、二乗誤差を算出します。\r
229      * </p>\r
230      * \r
231      * @param input\r
232      *            入力データ。\r
233      * @param teach\r
234      *            教師信号。\r
235      * @return 二乗誤差。\r
236      * @throws NullPointerException\r
237      *             input引数、teach引数がnullの場合。\r
238      * @throws IllegalArgumentException\r
239      *             input引数の配列要素数が入力層のニューロン数と異なる場合。teach引数の配列要素数が出力層のニューロン数と異なる場合。\r
240      */\r
241     public double reportError(double[] input, double[] teach) {\r
242         double[] output = this.calculate(input);\r
243         double err = this.calcError(output, teach);\r
244 \r
245         return err;\r
246     }\r
247 \r
248     /**\r
249      * <p>\r
250      * 順方向演算を行います。\r
251      * </p>\r
252      * \r
253      * @param input\r
254      *            入力データ。\r
255      * @param output\r
256      *            順方向演算の結果を格納する配列。\r
257      * @param hiddenOutput\r
258      *            順方向演算の過程の隠れ層出力を格納する配列。\r
259      * @throws NullPointerException\r
260      *             input引数、output引数、hiddenOutput引数がnullの場合。\r
261      * @throws IllegalArgumentException\r
262      *             input引数の配列要素数が入力層のニューロン数と異なる場合。output引数の配列要素数が出力層のニューロン数と異なる場合。hiddenOutput引数の配列要素数が隠れ層のニューロン数と異なる場合。\r
263      */\r
264     private void calcForward(double[] input, double[] output, double[] hiddenOutput) {\r
265         /*\r
266          * 引数を確認します。\r
267          */\r
268         if (input == null) {\r
269             throw new NullPointerException("input == null");\r
270         }\r
271         if (output == null) {\r
272             throw new NullPointerException("output == null");\r
273         }\r
274         if (hiddenOutput == null) {\r
275             throw new NullPointerException("hiddenOutput == null");\r
276         }\r
277         if (input.length != this.inputNumber) {\r
278             throw new IllegalArgumentException("input.length != inputNumber");\r
279         }\r
280         if (output.length != this.outputNumber) {\r
281             throw new IllegalArgumentException("output.length != outputNumber");\r
282         }\r
283         if (hiddenOutput.length != this.hiddenNumber) {\r
284             throw new IllegalArgumentException("hiddenOutput.length != hiddenNumber");\r
285         }\r
286 \r
287         /*\r
288          * 隠れ層の出力を計算します。\r
289          */\r
290         for (int i = 0; i < hiddenOutput.length; i++) {\r
291             hiddenOutput[i] = -this.thresholdHidden[i];\r
292             for (int j = 0; j < input.length; j++) {\r
293                 hiddenOutput[i] += input[j] * this.weightInHidden[j * this.hiddenNumber + i];\r
294             }\r
295             hiddenOutput[i] = this.sigmoid(hiddenOutput[i]);\r
296         }\r
297 \r
298         /*\r
299          * 出力層の出力を計算します。\r
300          */\r
301         for (int i = 0; i < output.length; i++) {\r
302             output[i] = -this.thresholdOut[i];\r
303             for (int j = 0; j < hiddenOutput.length; j++) {\r
304                 output[i] += hiddenOutput[j] * this.weightHiddenOut[j * this.outputNumber + i];\r
305             }\r
306             output[i] = this.sigmoid(output[i]);\r
307         }\r
308     }\r
309 \r
310     /**\r
311      * <p>\r
312      * 逆方向演算を行います。\r
313      * </p>\r
314      * \r
315      * @param input\r
316      *            順方向演算の入力データ。\r
317      * @param output\r
318      *            順方向演算の結果。\r
319      * @param hiddenOutput\r
320      *            順方向演算の過程の隠れ層出力を格納する配列。\r
321      * @param teach\r
322      *            教師信号。\r
323      * @throws NullPointerException\r
324      *             input引数、output引数、hiddenOutput引数、teach引数がnullの場合。\r
325      * @throws IllegalArgumentException\r
326      *             input引数の配列要素数が入力層のニューロン数と異なる場合。output引数の配列要素数が出力層のニューロン数と異なる場合。hiddenOutput引数の配列要素数が隠れ層のニューロン数と異なる場合。teach引数の配列要素数が出力層のニューロン数と異なる場合。\r
327      */\r
328     private void calcBackward(double[] input, double[] output, double[] hiddenOutput, double[] teach) {\r
329         /*\r
330          * 引数を確認します。\r
331          */\r
332         if (input == null) {\r
333             throw new NullPointerException("input == null");\r
334         }\r
335         if (output == null) {\r
336             throw new NullPointerException("output == null");\r
337         }\r
338         if (hiddenOutput == null) {\r
339             throw new NullPointerException("hiddenOutput == null");\r
340         }\r
341         if (teach == null) {\r
342             throw new NullPointerException("teach == null");\r
343         }\r
344         if (input.length != this.inputNumber) {\r
345             throw new IllegalArgumentException("input.length != inputNumber");\r
346         }\r
347         if (output.length != this.outputNumber) {\r
348             throw new IllegalArgumentException("output.length != outputNumber");\r
349         }\r
350         if (hiddenOutput.length != this.hiddenNumber) {\r
351             throw new IllegalArgumentException("hiddenOutput.length != hiddenNumber");\r
352         }\r
353         if (teach.length != this.outputNumber) {\r
354             throw new IllegalArgumentException("teach.length != outputNumber");\r
355         }\r
356 \r
357         /*\r
358          * 出力層の誤差を計算します。\r
359          */\r
360         double[] outputError = new double[output.length];\r
361         for (int i = 0; i < outputError.length; i++) {\r
362             outputError[i] = (teach[i] - output[i]) * output[i] * (1.0 - output[i]);\r
363         }\r
364 \r
365         /*\r
366          * 隠れ層の誤差を計算します。\r
367          */\r
368         double[] hiddenError = new double[hiddenOutput.length];\r
369         for (int i = 0; i < hiddenError.length; i++) {\r
370             double err = 0;\r
371             for (int j = 0; j < output.length; j++) {\r
372                 err += outputError[j] * this.weightHiddenOut[i * this.outputNumber + j];\r
373             }\r
374             hiddenError[i] = hiddenOutput[i] * (1.0 - hiddenOutput[i]) * err;\r
375         }\r
376 \r
377         /*\r
378          * 重みを補正します。\r
379          */\r
380         for (int i = 0; i < outputError.length; i++) {\r
381             for (int j = 0; j < hiddenOutput.length; j++) {\r
382                 this.weightHiddenOut[j * this.outputNumber + i] += this.learningCoefficient * outputError[i] * hiddenOutput[j];\r
383             }\r
384         }\r
385         for (int i = 0; i < hiddenError.length; i++) {\r
386             for (int j = 0; j < input.length; j++) {\r
387                 this.weightInHidden[j * this.hiddenNumber + i] += this.learningCoefficient * hiddenError[i] * input[j];\r
388             }\r
389         }\r
390 \r
391         /*\r
392          * 閾値を補正します。\r
393          */\r
394         for (int i = 0; i < this.thresholdOut.length; i++) {\r
395             this.thresholdOut[i] -= this.learningCoefficient * outputError[i];\r
396         }\r
397         for (int i = 0; i < this.thresholdHidden.length; i++) {\r
398             this.thresholdHidden[i] -= this.learningCoefficient * hiddenError[i];\r
399         }\r
400     }\r
401 \r
402     /**\r
403      * <p>\r
404      * 順方向演算の結果と教師信号とのずれを表す二乗誤差を計算します。\r
405      * </p>\r
406      * \r
407      * @param output\r
408      *            順方向演算の結果。\r
409      * @param teach\r
410      *            教師信号。\r
411      * @return 二乗誤差。\r
412      * @throws NullPointerException\r
413      *             output引数、teach引数がnullの場合。\r
414      */\r
415     private double calcError(double[] output, double[] teach) {\r
416         /*\r
417          * 引数を確認します。\r
418          */\r
419         if (output == null) {\r
420             throw new NullPointerException("output == null");\r
421         }\r
422         if (teach == null) {\r
423             throw new NullPointerException("teach == null");\r
424         }\r
425 \r
426         /*\r
427          * 二乗誤差を計算します。\r
428          */\r
429         double error = 0;\r
430         for (int i = 0; i < output.length; i++) {\r
431             error += (teach[i] - output[i]) * (teach[i] - output[i]);\r
432         }\r
433 \r
434         return error;\r
435     }\r
436 \r
437     /**\r
438      * <p>\r
439      * シグモイド関数です。\r
440      * </p>\r
441      * \r
442      * @param x\r
443      *            引数。\r
444      * @return 計算結果。\r
445      */\r
446     private double sigmoid(double x) {\r
447         return 1.0 / (1.0 + Math.exp(-x));\r
448     }\r
449 \r
450 }\r