OSDN Git Service

モデル作成、データ読み込みを完了
authorshupeluter@hotmail.com <shupeluter@hotmail.com>
Sat, 10 Mar 2018 14:30:56 +0000 (23:30 +0900)
committershupeluter@hotmail.com <shupeluter@hotmail.com>
Sat, 10 Mar 2018 14:30:56 +0000 (23:30 +0900)
src/main/Python/Lern.py

index e69de29..2e5df25 100644 (file)
@@ -0,0 +1,50 @@
+from chainer import Link,Chain,ChainList,report,optimizers
+import chainer.functions as F
+import chainer.links as L
+from DataReader import DataReader
+
+
+class MyChain(Chain):
+    def __init__(self):
+        super(Chain,self).__init__(
+            l1=L.Linear(50,30),
+            l2=L.Linear(30,9)
+        )
+
+    def __call__(self,x):
+        h = F.sigmoid(self.l1(x))
+        o = self.l2(h)
+        return o
+
+class MyClassifer(Chain):
+    def __init__(self,predictor):
+        super(MyClassifer,self).__init__()
+        with self.init_scope():
+            self.predictor = predictor
+    def __call__(self,x,t):
+        y = self.predictor(x)
+        loss = F.softmax_cross_entropy(y,t)
+        accuracy = F.accuracy(y,t)
+        report({'loss': loss, 'accuracy': accuracy}, self)
+        return loss;
+
+def main():
+#    try:
+        #モデルを準備
+        model = L.Classifier(MyChain)
+
+        #オプティマイザを準備
+        optimizer = optimizers.Adam
+        optimizer(model)
+
+        #元データ生成
+        reader = DataReader() #type DataReader
+        dataList = []
+        dataList = reader.createLearningData()
+        print(len(dataList))
+
+#    except:
+#        print("an error occured")
+
+main()
+