Browse Source

fixed bugs in pipeline selection

tanja 3 years ago
parent
commit
d7a2089cce

+ 5 - 3
cdplib/gridsearch/GridSearchPipelineSelector.py

@@ -16,13 +16,14 @@ Created on Wed Sep 30 14:15:17 2020
 import os
 import datetime
 import numpy as np
+from copy import deepcopy
 from itertools import product
 from collections import ChainMap
 from sklearn.pipeline import Pipeline
 from typing import Callable, Optional, Literal, Dict, Union, List
 from cdplib.log import Log
 
-from cdplib.pipeline_selector.PipelineSelector import PipelineSelector
+# from cdplib.pipeline_selector.PipelineSelector import PipelineSelector
 
 
 class GridSearchPipelineSelector(PipelineSelector):
@@ -140,8 +141,9 @@ class GridSearchPipelineSelector(PipelineSelector):
 
                 result = self._objective(space_element)
 
-                pipeline = space_element["pipeline"].set_params(
-                        **space_element["params"])
+                pipeline = deepcopy(space_element["pipeline"])
+                
+                pipeline = pipeline.set_params(**space_element["params"])
 
                 trial = {"name": space_element["name"],
                          "params": space_element["params"],

+ 8 - 4
cdplib/pipeline_selector/PipelineSelector.py

@@ -133,12 +133,16 @@ class PipelineSelector(ABC):
             # The score factor is 1 when cost_func is minimized and
             # -1 when cost_func is maximized.
             self._score_factor = (not greater_is_better) - greater_is_better
+            self._cross_val_averaging_func = cross_val_averaging_func
+            self._additional_metrics = additional_metrics
+            self._additional_averaging_funcs = additional_averaging_funcs or {}
+            
             self.trials_path = trials_path
             self._backup_trials_freq = backup_trials_freq
+
             self._strategy_name = strategy_name
             self._data_path = None
             self._cv_path = None
-
             self._X = None
             self._y = None
             self._cv = None
@@ -188,9 +192,6 @@ class PipelineSelector(ABC):
             self.total_tuning_time = None
             self.finished_tuning = False
             
-            self._additional_metrics = additional_metrics
-            self._additional_averaging_funcs = additional_averaging_funcs or {}
-
         except Exception as e:
             err = ("Failed to initialize the class. "
                    "Exit with error: {}".format(e))
@@ -350,6 +351,9 @@ class PipelineSelector(ABC):
 
                 # Here we create a trivial cv object
                 # with one validation split.
+                
+                # XXX Tanya finish here
+                
                 cv = CVComposer.dummy_cv()
 
                 train_inds = list(range(len(X_train)))