OSDN Git Service

temp committ
[deep-learning/learning.git] / network.py
index f52c943..6acb15f 100644 (file)
@@ -12,7 +12,7 @@ class  Comp():
     def __init__(self):
         self.model1,self.model2 = Sequential(),Sequential()
 
-        self.model1.add(Dense(50,input_shape=(64,)))
+        self.model1.add(Dense(100,input_shape=(64,)))
         self.model1.add(Activation('relu'))
         self.model1.add(Dropout(0.25))
     
@@ -28,7 +28,7 @@ class  Comp():
             optimizer='adam',
             metrics=['accuracy'])
 
-        self.model2.add(Dense(50,input_shape=(64,)))
+        self.model2.add(Dense(100,input_shape=(64,)))
         self.model2.add(Activation('sigmoid'))
         self.model2.add(Dropout(0.25))
         self.model2.add(Dense(100))
@@ -46,11 +46,10 @@ class  Comp():
         if os.path.exists(hdf5_file):
             self.model1.load_weights(hdf5_file)
         X,Y = np.array(X_train),np.array(Y_train) 
-        X = np.reshape(X,[1,64])
-        Y = np.reshape(Y,[1,64])
-        for i in range(10):
-            self.model1.fit(X,Y)
-        res = self.model1.predict(X,1)
+        X = np.reshape(np.float32(X),(1,64))
+        Y = np.reshape(np.float32(Y),(1,64))
+        self.model1.fit(X,Y)
+        res = self.model1.predict(X,0)
         while True:
             s = np.argmax(res)
             if res[0][s] == 0:
@@ -59,10 +58,7 @@ class  Comp():
             elif Y[0][s] == 0:
                 res[0][s] = 0
                 continue
-            else:
-                print('hit!')
             break
-        print(Y,res)
         self.model1.save_weights(hdf5_file)
         return [s // 8, s % 8]
         
@@ -71,11 +67,10 @@ class  Comp():
         if os.path.exists(hdf5_file):
             self.model2.load_weights(hdf5_file) 
         X,Y=np.array(X_train),np.array(Y_train)
-        X = np.reshape(X,[1,64])
-        Y = np.reshape(Y,[1,64])
-        for i in range(10):
-            self.model2.fit(X,Y)
-        res = self.model2.predict(X,1)
+        X = np.reshape(np.float32(X),(1,64))
+        Y = np.reshape(np.float32(Y),(1,64))
+        self.model2.fit(X,Y)
+        res = self.model2.predict(X,0)
         while True:
             s = np.argmax(res)
             if res[0][s] == 0:
@@ -84,10 +79,6 @@ class  Comp():
             elif Y[0][s] == 0:
                 res[0][s] = 0
                 continue
-            else:
-                print('hit!')
             break
-        print(Y,res)
-        hdf5_file ='./gote-model.hdf5'
         self.model2.save_weights(hdf5_file)
         return [s // 8, s % 8]