Przeglądaj źródła

fixed bug in pipeline selector

tanja 3 lat temu
rodzic
commit
0b1c98ed1a
1 zmienionych plików z 11 dodań i 7 usunięć
  1. 11 7
      cdplib/pipeline_selector/PipelineSelector.py

+ 11 - 7
cdplib/pipeline_selector/PipelineSelector.py

@@ -179,6 +179,9 @@ class PipelineSelector(ABC):
             self.start_tuning_time = datetime.datetime.today()
             self.total_tuning_time = None
             self.finished_tuning = False
+            
+            self._additional_metrics = additional_metrics
+            self._additional_averaging_funcs = additional_averaging_funcs
 
         except Exception as e:
             err = ("Failed to initialize the class. "
@@ -341,10 +344,6 @@ class PipelineSelector(ABC):
                 # with one validation split.
                 cv = CVComposer.dummy_cv()
 
-
-
-
-
                 train_inds = list(range(len(X_train)))
                 val_inds = list(range(len(X_train),
                                       len(X_train) + len(X_val)))
@@ -398,12 +397,17 @@ class PipelineSelector(ABC):
         try:
             assert(os.path.isfile(data_hdf5_store_path)),\
                 "Parameter hdf5_store_path is not a file"
+                
+            # close all opened files, because hdf5 will 
+            # fail to reopen an opened (for some reason) file
+            import tables
+            tables.file._open_files.close_all()
 
             store = pd.HDFStore(data_hdf5_store_path)
 
             self._data_path = data_hdf5_store_path
 
-            data_input = {key: store["key"] if key in store else None
+            data_input = {key: store[key] if key in store else None
                           for key in ["X_train", "y_train", "X_val", "y_val"]}
 
             if cv_pickle_path is not None:
@@ -564,7 +568,7 @@ class PipelineSelector(ABC):
             over the folds.
         """
         try:
-            scoring = {"score": make_scorer(self.cost_func)}
+            scoring = {"score": make_scorer(self._cost_func)}
 
             scoring.update({metric_name: make_scorer(metric)
                             for metric_name, metric
@@ -575,7 +579,7 @@ class PipelineSelector(ABC):
                     X=self._X,
                     y=self._y,
                     cv=self._cv,
-                    scoring=self._scoring,
+                    scoring=scoring,
                     error_score=np.nan)
 
             averaging_funcs = {