OSDN Git Service

networkモジュールも変更する必要があったのを忘れていました
[deep-learning/learning.git] / network.py
index e6377aa..f52c943 100644 (file)
@@ -50,18 +50,18 @@ class  Comp():
         Y = np.reshape(Y,[1,64])
         for i in range(10):
             self.model1.fit(X,Y)
-            res = self.model1.predict(X,1)
-            while True:
-                s = np.argmax(res)
-                if Y[0][s] == -1:
-                    res[0][s] = -1
-                    print('miss')
-                    continue
-                else:
-                    print('hit!')
-                break
-        else:
-            s = np.argmax(Y)
+        res = self.model1.predict(X,1)
+        while True:
+            s = np.argmax(res)
+            if res[0][s] == 0:
+                s = np.argmax(Y)
+                print('miss')
+            elif Y[0][s] == 0:
+                res[0][s] = 0
+                continue
+            else:
+                print('hit!')
+            break
         print(Y,res)
         self.model1.save_weights(hdf5_file)
         return [s // 8, s % 8]
@@ -75,18 +75,18 @@ class  Comp():
         Y = np.reshape(Y,[1,64])
         for i in range(10):
             self.model2.fit(X,Y)
-            res = self.model2.predict(X,1)
-            while True:
-                s = np.argmax(res)
-                if Y[0][s] == -1:
-                    res[0][s] = -1
-                    print('miss')
-                    continue
-                else:
-                    print('hit!')
-                break
-        else:
-            s = np.argmax(Y)
+        res = self.model2.predict(X,1)
+        while True:
+            s = np.argmax(res)
+            if res[0][s] == 0:
+                s = np.argmax(Y)
+                print('miss')
+            elif Y[0][s] == 0:
+                res[0][s] = 0
+                continue
+            else:
+                print('hit!')
+            break
         print(Y,res)
         hdf5_file ='./gote-model.hdf5'
         self.model2.save_weights(hdf5_file)