OSDN Git Service

a51126b84e581d69e9f736c05ad89dfa9b2201ba
[simplenn/repo.git] / simplenn / src / main / java / Back.java
1 import java.applet.Applet;\r
2 import java.awt.*;\r
3 import java.awt.event.*;\r
4 \r
5 public class Back extends Applet implements MouseListener,MouseMotionListener,ActionListener{\r
6 \r
7    Button button1,button2,button3,button4;\r
8 \r
9    int X0=10,X1=125;\r
10    int Y0=55,Y1=70,Y2=160,Y3=240,Y4=305;\r
11 \r
12    int RX0=30,RX1=60,RX2=210,RX3=260;\r
13    int RY0=225,RY1=240;\r
14 \r
15    int WIDTH=7;              //入力データの幅\r
16    int HEIGHT=11;            //入力データの高さ\r
17    int INPUT=WIDTH*HEIGHT;   //入力層の数(入力データ数)\r
18    int HIDDEN=16;            //隠れ層の数\r
19    int PATTERN=10;           //パターンの種類\r
20    int OUTPUT=PATTERN;       //出力層の数(出力データ数)\r
21    int OUTER_CYCLES=200;     //外部サイクル(一連のパターンの繰返し学習)の回数\r
22    int INNER_CYCLES=200;     //内部サイクル(同一パターンの繰返し学習)の回数\r
23    float ALPHA=1.2f;         //学習の加速係数\r
24    float BETA=1.2f;          //シグモイド曲線の傾斜\r
25 \r
26    int[] sample_in=new int[INPUT];                  //学習用入力\r
27    int[] written_in=new int[INPUT];                 //認識用手書き入力\r
28 \r
29    float[][] weight_ih=new float[INPUT][HIDDEN];    //入力層と隠れ層の間の重み係数\r
30    float[] thresh_h=new float[HIDDEN];              //隠れ層の閾値\r
31    float[] hidden_out=new float[HIDDEN];            //隠れ層の出力\r
32 \r
33    float[][] weight_ho=new float[HIDDEN][OUTPUT];   //隠れ層と出力層の間の重み係数\r
34    float[] thresh_o=new float[OUTPUT];              //出力層の閾値\r
35    float[] recog_out=new float[OUTPUT];             //認識出力(出力層の出力)\r
36 \r
37    int[] teach=new int[PATTERN];                    //教師信号\r
38 \r
39 \r
40 \r
41 \r
42 \r
43    boolean learning_flag;  //「学習モード」フラグ\r
44 \r
45    //学習用入力データの基となるパターン\r
46    int[][] sample_array={{0,0,1,1,1,0,0,  //'0'\r
47                          0,1,0,0,0,1,0,\r
48                          1,0,0,0,0,0,1,\r
49                          1,0,0,0,0,0,1,\r
50                          1,0,0,0,0,0,1,\r
51                          1,0,0,0,0,0,1,\r
52                          1,0,0,0,0,0,1,\r
53                          1,0,0,0,0,0,1,\r
54                          1,0,0,0,0,0,1,\r
55                          0,1,0,0,0,1,0,\r
56                          0,0,1,1,1,0,0},\r
57   \r
58                         {0,0,0,1,0,0,0,  //'1'\r
59                          0,0,0,1,0,0,0,\r
60                          0,0,0,1,0,0,0,\r
61                          0,0,0,1,0,0,0,\r
62                          0,0,0,1,0,0,0,\r
63                          0,0,0,1,0,0,0,\r
64                          0,0,0,1,0,0,0,\r
65                          0,0,0,1,0,0,0,\r
66                          0,0,0,1,0,0,0,\r
67                          0,0,0,1,0,0,0,\r
68                          0,0,0,1,0,0,0},\r
69  \r
70                         {0,0,1,1,1,0,0,  //'2'\r
71                          0,1,0,0,0,1,0,\r
72                          1,0,0,0,0,0,1,\r
73                          0,0,0,0,0,0,1,\r
74                          0,0,0,0,0,0,1,\r
75                          0,0,0,0,0,1,0,\r
76                          0,0,0,0,1,0,0,\r
77                          0,0,0,1,0,0,0,\r
78                          0,0,1,0,0,0,0,\r
79                          0,1,0,0,0,0,0,\r
80                          1,1,1,1,1,1,1},\r
81   \r
82                         {0,0,1,1,1,0,0,  //'3'\r
83                          0,1,0,0,0,1,0,\r
84                          1,0,0,0,0,0,1,\r
85                          0,0,0,0,0,1,0,\r
86                          0,0,0,0,1,0,0,\r
87                          0,0,0,0,0,1,0,\r
88                          0,0,0,0,0,0,1,\r
89                          0,0,0,0,0,0,1,\r
90                          1,0,0,0,0,0,1,\r
91                          0,1,0,0,0,1,0,\r
92                          0,0,1,1,1,0,0},\r
93  \r
94                         {0,0,0,0,1,0,0,  //'4'\r
95                          0,0,0,1,1,0,0,\r
96                          0,0,1,0,1,0,0,\r
97                          0,0,1,0,1,0,0,\r
98                          0,1,0,0,1,0,0,\r
99                          0,1,0,0,1,0,0,\r
100                          1,0,0,0,1,0,0,\r
101                          1,1,1,1,1,1,1,\r
102                          0,0,0,0,1,0,0,\r
103                          0,0,0,0,1,0,0,\r
104                          0,0,0,0,1,0,0},\r
105  \r
106                         {1,1,1,1,1,1,1,  //'5'\r
107                          1,0,0,0,0,0,0,\r
108                          1,0,0,0,0,0,0,\r
109                          1,0,0,0,0,0,0,\r
110                          1,1,1,1,1,0,0,\r
111                          0,0,0,0,0,1,0,\r
112                          0,0,0,0,0,0,1,\r
113                          0,0,0,0,0,0,1,\r
114                          0,0,0,0,0,0,1,\r
115                          1,0,0,0,0,1,0,\r
116                          0,1,1,1,1,1,0},\r
117  \r
118                         {0,0,0,0,1,1,0,  //'6'\r
119                          0,0,0,1,0,0,0,\r
120                          0,0,1,0,0,0,0,\r
121                          0,1,0,0,0,0,0,\r
122                          0,1,0,0,0,0,0,\r
123                          1,0,0,0,0,0,0,\r
124                          1,0,1,1,1,0,0,\r
125                          1,1,0,0,0,1,0,\r
126                          1,0,0,0,0,0,1,\r
127                          0,1,0,0,0,1,0,\r
128                          0,0,1,1,1,0,0},\r
129  \r
130                         {1,1,1,1,1,1,1,  //'7'\r
131                          0,0,0,0,0,0,1,\r
132                          0,0,0,0,0,0,1,\r
133                          0,0,0,0,0,1,0,\r
134                          0,0,0,0,0,1,0,\r
135                          0,0,0,0,1,0,0,\r
136                          0,0,0,0,1,0,0,\r
137                          0,0,0,1,0,0,0,\r
138                          0,0,0,1,0,0,0,\r
139                          0,0,1,0,0,0,0,\r
140                          0,0,1,0,0,0,0},\r
141  \r
142                         {0,0,1,1,1,0,0,  //'8'\r
143                          0,1,0,0,0,1,0,\r
144                          1,0,0,0,0,0,1,\r
145                          1,0,0,0,0,0,1,\r
146                          0,1,0,0,0,1,0,\r
147                          0,0,1,1,1,0,0,\r
148                          0,1,0,0,0,1,0,\r
149                          1,0,0,0,0,0,1,\r
150                          1,0,0,0,0,0,1,\r
151                          1,0,0,0,0,0,1,\r
152                          0,1,1,1,1,1,0},\r
153                  \r
154                         {0,1,1,1,1,1,0,  //'9'\r
155                          1,0,0,0,0,0,1,\r
156                          1,0,0,0,0,0,1,\r
157                          1,0,0,0,0,0,1,\r
158                          0,1,1,1,1,1,1,\r
159                          0,0,0,0,0,0,1,\r
160                          0,0,0,0,0,0,1,\r
161                          0,0,0,0,0,0,1,\r
162                          0,0,0,0,0,0,1,\r
163                          1,0,0,0,0,0,1,\r
164                          0,1,1,1,1,1,0}};\r
165 \r
166    int[][] teach_array=new int[PATTERN][OUTPUT];  //パターンと出力すべき教師信号の比較表\r
167 \r
168    int x_new,y_new,x_old,y_old;           //手書き文字入力用座標\r
169 \r
170 \r
171    public void init(){\r
172 \r
173       setBackground(Color.gray);\r
174 \r
175       //ボタンの設定\r
176       add(button1=new Button("  再学習  "));\r
177       add(button2=new Button(" 学習終了 "));\r
178       add(button3=new Button("入力クリヤ"));\r
179       add(button4=new Button("  認  識  "));\r
180       button1.addActionListener(this);\r
181       button2.addActionListener(this);\r
182       button3.addActionListener(this);\r
183       button4.addActionListener(this);\r
184 \r
185       //マウスの設定\r
186       addMouseListener(this);\r
187       addMouseMotionListener(this);\r
188 \r
189       //教師信号の設定\r
190       for(int q=0;q<PATTERN;q++)\r
191          for(int k=0;k<OUTPUT;k++){\r
192             if(q==k) teach_array[q][k]=1;\r
193             else     teach_array[q][k]=0;\r
194          }\r
195 \r
196       //モードの初期設定\r
197       learning_flag=true;\r
198 \r
199    }\r
200 \r
201    //------------------- ボタン関係のメソッド ------------------\r
202 \r
203    public void actionPerformed(ActionEvent ae){\r
204 \r
205       if(ae.getSource()==button1){      //「再学習」\r
206          learning_flag=true;\r
207          repaint();\r
208       }\r
209       if(ae.getSource()==button2){      //「学習終了」\r
210          learning_flag=false;\r
211          repaint();\r
212       }\r
213       if(ae.getSource()==button3){      //「入力クリヤ」\r
214          if(!learning_flag)\r
215             repaint();\r
216       }\r
217       if(ae.getSource()==button4){      //「認識」\r
218          if(!learning_flag)\r
219             recognizeCharacter();\r
220       }\r
221 \r
222    }\r
223 \r
224    //---------- マウス関係のメソッド(手書き文字入力)----------\r
225 \r
226    public void mousePressed(MouseEvent me){\r
227       int x=me.getX();\r
228       int y=me.getY();\r
229       if(!learning_flag && x>=RX1 && x<=RX1+WIDTH*10 && y>=RY1 && y<=RY1+HEIGHT*10){\r
230          x_old=me.getX();\r
231          y_old=me.getY();\r
232          written_in[(y_old-RY1)/10*WIDTH+(x_old-RX1)/10]=1;\r
233       }\r
234    }\r
235 \r
236    public void mouseClicked(MouseEvent me){}\r
237    public void mouseEntered(MouseEvent me){}\r
238    public void mouseExited(MouseEvent me){}\r
239    public void mouseReleased(MouseEvent me){}\r
240 \r
241    public void mouseDragged(MouseEvent me){\r
242       int x=me.getX();\r
243       int y=me.getY();\r
244       if(!learning_flag && x>=RX1 && x<=RX1+WIDTH*10 && y>=RY1 && y<=RY1+HEIGHT*10){\r
245          Graphics g=getGraphics(); \r
246          x_new=me.getX();\r
247          y_new=me.getY();\r
248          g.drawLine(x_old,y_old,x_new,y_new);\r
249          x_old=x_new;\r
250          y_old=y_new;\r
251          written_in[(y_old-RY1)/10*WIDTH+(x_old-RX1)/10]=1;\r
252       }\r
253  \r
254    }\r
255 \r
256    public void mouseMoved(MouseEvent me){}\r
257  \r
258 \r
259 \r
260    //---------- 起動時およびrepaint()で呼び出されるメソッド ----------\r
261  \r
262    public void paint(Graphics g){\r
263 \r
264       int i,j,k,p,q,r,x;\r
265 \r
266       String string;\r
267 \r
268       float outer_error;          //外部サイクルエラー累計\r
269       float inner_error;          //内部サイクルエラー累計\r
270       float temp_error;           //隠れ層の誤差の累計 \r
271 \r
272       //学習モードの背景\r
273       if(learning_flag){\r
274          g.setColor(new Color(255,255,192));\r
275          g.fillRect(5,35,590,460);\r
276          g.setColor(Color.black);\r
277          g.drawString("学習モード",500,55);\r
278       }\r
279 \r
280       //認識モードの背景\r
281       else{\r
282          g.setColor(new Color(192,255,255));\r
283          g.fillRect(5,35,590,460);\r
284          g.setColor(Color.black);\r
285          g.drawString("認識モード",500,55);\r
286       }\r
287 \r
288       //学習用パターンの表示\r
289       g.drawString("使用している学習用パターン",X0,Y0);\r
290       for(q=0;q<PATTERN;q++){\r
291          x=56*q;\r
292          for(j=0;j<HEIGHT;j++)\r
293             for(i=0;i<WIDTH;i++){\r
294                if(sample_array[q][WIDTH*j+i]==1)     g.setColor(Color.red);\r
295                else                                  g.setColor(Color.cyan);\r
296                g.fillRect(X0+x+6*i,Y1+6*j,5,5);\r
297             }\r
298       }\r
299       g.setColor(Color.black);\r
300 \r
301       //-------------------------------------------------------------------\r
302       //--------------------------- 学習モード ----------------------------\r
303       //-------------------------------------------------------------------\r
304       if(learning_flag){\r
305 \r
306          //閾値と重みの乱数設定\r
307          for(j=0;j<HIDDEN;j++){\r
308             thresh_h[j]=(float)Math.random()-0.5f;\r
309             for(i=0;i<INPUT;i++)\r
310                weight_ih[i][j]=(float)Math.random()-0.5f;\r
311          }\r
312          for(k=0;k<OUTPUT;k++){\r
313             thresh_o[k]=(float)Math.random()-0.5f;\r
314             for(j=0;j<HIDDEN;j++)\r
315                weight_ho[j][k]=(float)Math.random()-0.5f;\r
316          }\r
317 \r
318          //-------------------------- 学習 --------------------------\r
319 \r
320          for(p=0;p<OUTER_CYCLES;p++){     //外部サイクル\r
321 \r
322             outer_error=0.0f;         //外部二乗誤差のクリヤー\r
323 \r
324             for(q=0;q<PATTERN;q++){   //パターンの切り替え\r
325 \r
326                //パターンに対応した入力と教師信号の設定\r
327                sample_in=sample_array[q];\r
328                teach=teach_array[q];\r
329 \r
330                for(r=0;r<INNER_CYCLES;r++){   //内部サイクル\r
331 \r
332                   //順方向演算\r
333                   forwardNeuralNet(sample_in,recog_out);       \r
334 \r
335                   //逆方向演算(バックプロパゲーション)\r
336                   backwardNeuralNet();\r
337 \r
338                }\r
339 \r
340                //内部二乗誤差の計算\r
341                inner_error=0.0f;   //内部二乗誤差のクリヤー\r
342                for(k=0;k<OUTPUT;k++)\r
343                   inner_error+=(teach[k]-recog_out[k])*(teach[k]-recog_out[k]);\r
344 \r
345                outer_error+=inner_error;   //外部二乗誤差への累加算\r
346 \r
347             }\r
348 \r
349             //外部サイクルの回数と外部二乗誤差の表示\r
350             g.drawString("実行中の外部サイクルの回数と二乗誤差",X0,Y2);\r
351             g.setColor(new Color(255,255,192));\r
352             g.fillRect(X0+5,Y2+10,200,50);   //以前の表示を消去\r
353             g.setColor(Color.black);\r
354             g.drawString("OuterCycles="+String.valueOf(p),X0+10,Y2+25);\r
355             g.drawString("TotalSquaredError="+String.valueOf(outer_error),X0+10,Y2+45);\r
356 \r
357          } \r
358 \r
359 \r
360          //--------------------- 学習結果の確認 ---------------------\r
361 \r
362          g.drawString("学習結果の確認",X0,Y3);\r
363          for(k=0;k<OUTPUT;k++){\r
364             g.drawString("Output",X1+45*k,Y3+25);\r
365             g.drawString("  ["+String.valueOf(k)+"]",X1+5+45*k,Y3+40);\r
366          }      \r
367 \r
368          for(q=0;q<PATTERN;q++){\r
369 \r
370             //入力パターンの設定\r
371             sample_in=sample_array[q];\r
372 \r
373             //順方向演算\r
374             forwardNeuralNet(sample_in,recog_out);\r
375 \r
376             //結果の表示\r
377             g.setColor(Color.black);\r
378             g.drawString("TestPattern["+String.valueOf(q)+"]",X0+10,Y4+20*q);\r
379             for(k=0;k<OUTPUT;k++){\r
380                if(recog_out[k]>0.99){        //99% より大は、赤で YES と表示\r
381                   g.setColor(Color.red);\r
382                   string="YES";\r
383                }\r
384                else if(recog_out[k]<0.01){   // 1% より小は、青で NO と表示\r
385                   g.setColor(Color.blue);\r
386                   string="NO ";\r
387                }\r
388                else{                         // 1% 以上 99% 以下は、黒で ? と表示\r
389                   g.setColor(Color.black);\r
390                   string=" ? ";\r
391                }\r
392                g.drawString(string,X1+10+45*k,Y4+20*q);\r
393             }\r
394 \r
395          }\r
396       }\r
397 \r
398       //-------------------------------------------------------------------\r
399       //--------------------------- 認識モード ----------------------------\r
400       //-------------------------------------------------------------------\r
401       else{\r
402          g.setColor(Color.black);\r
403          g.drawString("マウスで数字を描いて下さい",RX0,RY0);\r
404          g.drawRect(RX1-1,RY1-1,WIDTH*10+2,HEIGHT*10+2);     //外枠\r
405          g.setColor(Color.gray);\r
406          for(j=1;j<HEIGHT;j++)\r
407             g.drawLine(RX1,RY1+10*j,RX1+WIDTH*10,RY1+10*j);  //横方向区切り\r
408          for(i=1;i<WIDTH;i++)\r
409             g.drawLine(RX1+10*i,RY1,RX1+10*i,RY1+HEIGHT*10);  //縦方向区切り\r
410          for(i=0;i<INPUT;i++)\r
411             written_in[i]=0;     //手書き入力データのクリヤ\r
412       }\r
413 \r
414    }\r
415 \r
416    //順方向演算のメソッド\r
417    public void forwardNeuralNet(int[] input,float[] output){\r
418 \r
419       float[] out=new float[OUTPUT];\r
420       float[] hidden=new float[HIDDEN];\r
421 \r
422       //隠れ層出力の計算\r
423       for(int j=0;j<HIDDEN;j++){\r
424          hidden[j]=-thresh_h[j];\r
425          for(int i=0;i<INPUT;i++)\r
426             hidden[j]+=input[i]*weight_ih[i][j];\r
427          hidden_out[j]=sigmoid(hidden[j]);\r
428       }\r
429 \r
430       //出力層出力の計算\r
431       for(int k=0;k<OUTPUT;k++){\r
432          out[k]=-thresh_o[k];\r
433          for(int j=0;j<HIDDEN;j++)\r
434             out[k]+=hidden_out[j]*weight_ho[j][k];\r
435          output[k]=sigmoid(out[k]);\r
436       }\r
437 \r
438    }\r
439 \r
440    //逆方向演算のメソッド\r
441    public void backwardNeuralNet(){\r
442 \r
443       int i,j,k;\r
444 \r
445       float[] output_error=new float[OUTPUT];       //出力層の誤差\r
446       float[] hidden_error=new float[HIDDEN];       //隠れ層の誤差\r
447 \r
448       float temp_error;\r
449 \r
450       //出力層の誤差の計算\r
451       for(k=0;k<OUTPUT;k++)\r
452          output_error[k]=(teach[k]-recog_out[k])*recog_out[k]*(1.0f-recog_out[k]);\r
453 \r
454       //隠れ層の誤差の計算\r
455       for(j=0;j<HIDDEN;j++){\r
456          temp_error=0.0f;\r
457          for(k=0;k<OUTPUT;k++)\r
458             temp_error+=output_error[k]*weight_ho[j][k];\r
459          hidden_error[j]=hidden_out[j]*(1.0f-hidden_out[j])*temp_error;\r
460       }\r
461 \r
462       //重みの補正\r
463       for(k=0;k<OUTPUT;k++)\r
464          for(j=0;j<HIDDEN;j++)\r
465             weight_ho[j][k]+=ALPHA*output_error[k]*hidden_out[j];\r
466       for(j=0;j<HIDDEN;j++)\r
467          for(i=0;i<INPUT;i++)\r
468             weight_ih[i][j]+=ALPHA*hidden_error[j]*sample_in[i];\r
469 \r
470       //閾値の補正\r
471       for(k=0;k<OUTPUT;k++)\r
472          thresh_o[k]-=ALPHA*output_error[k];\r
473       for(j=0;j<HIDDEN;j++)\r
474          thresh_h[j]-=ALPHA*hidden_error[j];\r
475 \r
476    }\r
477   \r
478    //Sigmoid関数を計算するメソッド\r
479    public float sigmoid(float x){\r
480 \r
481       return 1.0f/(1.0f+(float)Math.exp(-BETA*x));\r
482 \r
483    }\r
484 \r
485    //入力文字を認識するメソッド\r
486    public void recognizeCharacter(){\r
487 \r
488       Graphics g=getGraphics();\r
489       String string;\r
490 \r
491       //順方向演算\r
492       forwardNeuralNet(written_in,recog_out);\r
493 \r
494       //結果の表示\r
495       for(int k=0;k<OUTPUT;k++){\r
496           g.setColor(Color.black);\r
497           g.drawString(String.valueOf(k)+"である",RX2,RY1+20*k);\r
498           if(recog_out[k]>0.8f)  g.setColor(Color.red);\r
499           else                   g.setColor(Color.black);\r
500 \r
501           g.fillRect(RX3,RY1-10+20*k,(int)(200*recog_out[k]),10);\r
502           g.drawString(String.valueOf((int)(100*recog_out[k]+0.5f))+"%",RX3+(int)(200*recog_out[k])+10,RY1+20*k);\r
503        }\r
504 \r
505    }\r
506 \r
507 }\r
508 \r
509 \r
510 \r