X-Git-Url: http://git.osdn.net/view?a=blobdiff_plain;f=network.py;fp=network.py;h=f52c9433b245d71bd57eb54d9c989d02650e8d63;hb=24b4c153c1bfb909713c494d3d2fddb5d0281e37;hp=e6377aa0f165cf4c7f15ca122ce74ba77d5de470;hpb=403f5b69271ec9bb2cbd8853d8b476e50a0dfab9;p=deep-learning%2Flearning.git diff --git a/network.py b/network.py index e6377aa..f52c943 100644 --- a/network.py +++ b/network.py @@ -50,18 +50,18 @@ class Comp(): Y = np.reshape(Y,[1,64]) for i in range(10): self.model1.fit(X,Y) - res = self.model1.predict(X,1) - while True: - s = np.argmax(res) - if Y[0][s] == -1: - res[0][s] = -1 - print('miss') - continue - else: - print('hit!') - break - else: - s = np.argmax(Y) + res = self.model1.predict(X,1) + while True: + s = np.argmax(res) + if res[0][s] == 0: + s = np.argmax(Y) + print('miss') + elif Y[0][s] == 0: + res[0][s] = 0 + continue + else: + print('hit!') + break print(Y,res) self.model1.save_weights(hdf5_file) return [s // 8, s % 8] @@ -75,18 +75,18 @@ class Comp(): Y = np.reshape(Y,[1,64]) for i in range(10): self.model2.fit(X,Y) - res = self.model2.predict(X,1) - while True: - s = np.argmax(res) - if Y[0][s] == -1: - res[0][s] = -1 - print('miss') - continue - else: - print('hit!') - break - else: - s = np.argmax(Y) + res = self.model2.predict(X,1) + while True: + s = np.argmax(res) + if res[0][s] == 0: + s = np.argmax(Y) + print('miss') + elif Y[0][s] == 0: + res[0][s] = 0 + continue + else: + print('hit!') + break print(Y,res) hdf5_file ='./gote-model.hdf5' self.model2.save_weights(hdf5_file)