OSDN Git Service

細かい見直し 仕上げ
[deep-learning/learning.git] / network.py
1 '''
2 Created on 2017/12/25
3
4 @author: fukemasashi
5 '''
6 from keras.models import Sequential
7 from keras.layers import Dense,Dropout,Activation    
8 import numpy as np
9 import os
10
class Comp:
    """Two small Keras networks that are trained online and pick a board cell.

    Both networks take a flat 64-value input and emit a 64-way softmax.
    ``model1`` uses ReLU hidden layers and is driven by :meth:`sente_stone`;
    ``model2`` uses sigmoid hidden layers and is driven by
    :meth:`gote_stone`.  Each call loads the persisted weights (when the
    file exists), trains on the single sample supplied, chooses a cell and
    writes the weights back to disk.

    NOTE(review): "sente"/"gote" presumably denote the first and second
    player, and the 64 values an 8x8 board -- confirm against the caller.
    """

    _BOARD_CELLS = 64      # length of the flattened input/output vectors
    _BOARD_WIDTH = 8       # used to turn a flat index into [row, col]
    _HIDDEN_UNITS = 100
    _DROPOUT_RATE = 0.25
    _SENTE_WEIGHTS = 'sente-model.hdf5'
    _GOTE_WEIGHTS = 'gote-model.hdf5'

    def __init__(self):
        """Build and compile both networks (model1: ReLU, model2: sigmoid)."""
        self.model1 = self._build_model('relu')
        self.model2 = self._build_model('sigmoid')

    @classmethod
    def _build_model(cls, hidden_activation):
        """Return a compiled 64 -> 100 -> 100 -> 64 softmax classifier.

        hidden_activation: name of the Keras activation applied after each
            of the two hidden Dense layers.
        """
        model = Sequential()

        model.add(Dense(cls._HIDDEN_UNITS, input_shape=(cls._BOARD_CELLS,)))
        model.add(Activation(hidden_activation))
        model.add(Dropout(cls._DROPOUT_RATE))

        model.add(Dense(cls._HIDDEN_UNITS))
        model.add(Activation(hidden_activation))
        model.add(Dropout(cls._DROPOUT_RATE))

        model.add(Dense(cls._BOARD_CELLS))
        model.add(Activation('softmax'))

        model.compile(
            loss='categorical_crossentropy',
            optimizer='adam',
            metrics=['accuracy'])
        return model

    def sente_stone(self, X_train, Y_train):
        """Train ``model1`` on one sample and return the chosen ``[row, col]``.

        X_train: 64 numeric values (any shape reshapeable to ``(1, 64)``).
        Y_train: 64 numeric values used both as the training target and as
            the mask of cells that may be returned (non-zero = allowed).

        Raises ValueError (from numpy) if either argument does not hold
        exactly 64 values.
        """
        return self._train_and_pick(
            self.model1, self._SENTE_WEIGHTS, X_train, Y_train)

    def gote_stone(self, X_train, Y_train):
        """Train ``model2`` on one sample and return the chosen ``[row, col]``.

        Arguments, return value and errors are the same as for
        :meth:`sente_stone`; only the model and weights file differ.
        """
        return self._train_and_pick(
            self.model2, self._GOTE_WEIGHTS, X_train, Y_train)

    def _train_and_pick(self, model, weights_file, X_train, Y_train):
        """Load weights, fit on the sample, pick a cell, save weights."""
        if os.path.exists(weights_file):
            model.load_weights(weights_file)

        sample_shape = (1, self._BOARD_CELLS)
        board = np.asarray(X_train, dtype=np.float32).reshape(sample_shape)
        target = np.asarray(Y_train, dtype=np.float32).reshape(sample_shape)

        model.fit(board, target)
        scores = model.predict(board, batch_size=None, verbose=0)

        cell = self._select_move(
            np.asarray(scores).reshape(-1), target.reshape(-1))

        model.save_weights(weights_file)
        return [cell // self._BOARD_WIDTH, cell % self._BOARD_WIDTH]

    @staticmethod
    def _select_move(scores, allowed):
        """Return the flat index of the best-scoring allowed cell.

        Cells whose ``allowed`` entry is zero are ignored.  When no allowed
        cell has a non-zero score, 'miss' is printed and the arg-max of
        ``allowed`` itself is returned (index 0 if ``allowed`` is all zero).
        """
        # Zeroing the disallowed scores in one step replaces the original
        # loop that cleared the current arg-max and retried.
        masked = np.where(allowed != 0, scores, 0)
        best = np.argmax(masked)
        if masked[best] == 0:
            print('miss')
            return np.argmax(allowed)
        return best