OSDN Git Service

379d53d175bb7a8b20801d0149464ae4363076ae
[deep-learning/learning.git] / network.py
1 '''
2 Created on 2017/12/25
3
4 @author: fukemasashi
5 '''
6 from keras.models import Sequential
7 from keras.layers import Dense,Dropout,Activation    
8 import numpy as np
9
class Comp:
    """Two small dense networks that each pick one cell of an 8x8 (64-cell) grid.

    ``model1`` backs :meth:`sente_stone` and ``model2`` backs
    :meth:`gote_stone` (presumably the first and second player respectively).
    Both map 64 input values to a softmax over the 64 cells.
    """

    def __init__(self):
        # model1: 64 -> 50 (relu) -> 100 (relu) -> 64 (softmax)
        # model2: 64 -> 50 (relu) -> 64 (softmax)
        self.model1 = self._build_model((50, 100))
        self.model2 = self._build_model((50,))

    @staticmethod
    def _build_model(hidden_units):
        """Build and compile a 64-input, 64-output softmax MLP.

        Args:
            hidden_units: non-empty sequence of ReLU hidden-layer widths, in order.

        Returns:
            A compiled Keras ``Sequential`` model (categorical cross-entropy, Adam).
        """
        model = Sequential()
        # Only the first layer declares the input shape; Keras infers the rest.
        model.add(Dense(hidden_units[0], input_shape=(64,)))
        model.add(Activation('relu'))
        for units in hidden_units[1:]:
            model.add(Dense(units))
            model.add(Activation('relu'))
        model.add(Dense(64))
        model.add(Activation('softmax'))
        model.compile(
            loss='categorical_crossentropy',
            optimizer='adam',
            metrics=['accuracy'])
        return model

    @staticmethod
    def _fit_and_choose(model, X_train, Y_train):
        """Fit ``model`` on one (X, Y) example, then return its best cell.

        Args:
            model: object with Keras-style ``fit(x, y)`` and ``predict(x)``.
            X_train: exactly 64 input values (any nesting, e.g. an 8x8 list).
            Y_train: exactly 64 target values for the same example.

        Returns:
            ``[index % 8, index // 8]`` for the highest-scoring flat cell index.
        """
        X = np.reshape(np.array(X_train), [1, 64])
        Y = np.reshape(np.array(Y_train), [1, 64])
        model.fit(X, Y)
        # predict() takes only the inputs; its second positional parameter is
        # batch_size, so the targets must not be passed here.
        scores = np.ravel(model.predict(X))
        # Softmax scores are essentially never exactly 0, so "first non-zero
        # cell" is meaningless; choose the most probable cell instead.
        row, col = divmod(int(np.argmax(scores)), 8)
        return [col, row]

    def sente_stone(self, X_train, Y_train):
        """Train ``model1`` on a single example and return its chosen cell.

        Args:
            X_train: exactly 64 input values (any nesting, e.g. an 8x8 list).
            Y_train: exactly 64 target values for the same example.

        Returns:
            ``[index % 8, index // 8]`` of the cell ``model1`` scores highest.
        """
        return self._fit_and_choose(self.model1, X_train, Y_train)

    def gote_stone(self, X_train, Y_train):
        """Train ``model2`` on a single example and return its chosen cell.

        Args:
            X_train: exactly 64 input values (any nesting, e.g. an 8x8 list).
            Y_train: exactly 64 target values for the same example.

        Returns:
            ``[index % 8, index // 8]`` of the cell ``model2`` scores highest.
        """
        return self._fit_and_choose(self.model2, X_train, Y_train)