OSDN Git Service

改行コードの削除を追加
authorunknown <shupeluter@hotmail.com>
Sun, 10 Jun 2018 04:56:36 +0000 (13:56 +0900)
committerunknown <shupeluter@hotmail.com>
Sun, 10 Jun 2018 04:56:36 +0000 (13:56 +0900)
src/main/Python/DataReader.py

index 3975be8..f273dc7 100644 (file)
@@ -1,31 +1,63 @@
 import os
 import Data
 import yaml
+import numpy
+from typing import List
+from Exceptions import IleagalDataException
+
+data_size = 10
+data_length = 50
 
 
 class DataReader:
-    DATA_SOURCE  = ""
+    DATA_SOURCE = ""
 
     def __init__(self):
         with open('config\\toolconf.yml', 'r') as config:
             confdata = yaml.load(config)
             self.DATA_SOURCE = confdata['dataPath']
 
+    def get_learning_data(self):
+        target_data = self.parse_data_files()
+        datasets: List(numpy.ndarray) = []
+        labels: List(str) =[]
+
+        for cdata in target_data:
+            datasets.append(cdata.get_array_data())
+            labels.append(cdata.getLable())
+
+        return numpy.array(datasets), numpy.array(labels)
+
+    def __check_data(self, data: Data):
+
+        # データ元ファイルのパスを保持していること
+        if data.get_org_file() == "":
+            raise IleagalDataException(data, "データファイルパスが設定されちません。")
+        if data.getLable() == "":
+            raise IleagalDataException(data, "ラベルが設定されてません。")
+        if len(data.getData()) != data_size :
+            raise IleagalDataException(data, "要素数(=行数)が想定と異なります。")
+
+        for current in data.get_array_data():
+            if len(current) != data_length:
+                raise IleagalDataException(data, "データのサイズが想定と異なります。")
+        return True
+
     def parse_data_files(self):
         data = []
 
-        for file in self.walkDataDirectory(self.DATA_SOURCE):
-            data.append(self.read(file))
+        for file in self.__walkDataDirectory(self.DATA_SOURCE):
+            data.append(self._read(file))
 
         return data
 
-    def walkDataDirectory(self, directory):
+    def __walkDataDirectory(self, directory):
         for root, dirs, files in os.walk(directory):
 
             for file in files:
-                yield os.path.join(root,file)
+                yield os.path.join(root, file)
 
-    def read(self, file: str):
+    def _read(self, file: str):
         result = Data.Data()
         if os.path.isfile(file):
             datafile = open(file)
@@ -35,8 +67,8 @@ class DataReader:
             datafile.close()
             result.setLabel(int(orgdata[0].strip()))
 
-            for i in range(1,len(orgdata)):
-                line_data = orgdata[i].replace("\n","")
+            for i in range(1, len(orgdata)):
+                line_data = orgdata[i].replace("\n", "")
                 data.append(line_data.split(','))
             result.setData(data)
             result.set_org_data(file)