|
@@ -133,12 +133,16 @@ class PipelineSelector(ABC):
|
|
# score factor is 1 when cost_func is minimized,
|
|
# score factor is 1 when cost_func is minimized,
|
|
# -1 when cost func is maximized
|
|
# -1 when cost func is maximized
|
|
self._score_factor = (not greater_is_better) - greater_is_better
|
|
self._score_factor = (not greater_is_better) - greater_is_better
|
|
|
|
+ self._cross_val_averaging_func = cross_val_averaging_func
|
|
|
|
+ self._additional_metrics = additional_metrics
|
|
|
|
+ self._additional_averaging_funcs = additional_averaging_funcs or {}
|
|
|
|
+
|
|
self.trials_path = trials_path
|
|
self.trials_path = trials_path
|
|
self._backup_trials_freq = backup_trials_freq
|
|
self._backup_trials_freq = backup_trials_freq
|
|
|
|
+
|
|
self._strategy_name = strategy_name
|
|
self._strategy_name = strategy_name
|
|
self._data_path = None
|
|
self._data_path = None
|
|
self._cv_path = None
|
|
self._cv_path = None
|
|
-
|
|
|
|
self._X = None
|
|
self._X = None
|
|
self._y = None
|
|
self._y = None
|
|
self._cv = None
|
|
self._cv = None
|
|
@@ -188,9 +192,6 @@ class PipelineSelector(ABC):
|
|
self.total_tuning_time = None
|
|
self.total_tuning_time = None
|
|
self.finished_tuning = False
|
|
self.finished_tuning = False
|
|
|
|
|
|
- self._additional_metrics = additional_metrics
|
|
|
|
- self._additional_averaging_funcs = additional_averaging_funcs or {}
|
|
|
|
-
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
err = ("Failed to initialize the class. "
|
|
err = ("Failed to initialize the class. "
|
|
"Exit with error: {}".format(e))
|
|
"Exit with error: {}".format(e))
|
|
@@ -350,6 +351,9 @@ class PipelineSelector(ABC):
|
|
|
|
|
|
# Here we create a trivial cv object
|
|
# Here we create a trivial cv object
|
|
# with one validation split.
|
|
# with one validation split.
|
|
|
|
+
|
|
|
|
+ # XXX Tanya finish here
|
|
|
|
+
|
|
cv = CVComposer.dummy_cv()
|
|
cv = CVComposer.dummy_cv()
|
|
|
|
|
|
train_inds = list(range(len(X_train)))
|
|
train_inds = list(range(len(X_train)))
|