|
@@ -179,6 +179,9 @@ class PipelineSelector(ABC):
|
|
self.start_tuning_time = datetime.datetime.today()
|
|
self.start_tuning_time = datetime.datetime.today()
|
|
self.total_tuning_time = None
|
|
self.total_tuning_time = None
|
|
self.finished_tuning = False
|
|
self.finished_tuning = False
|
|
|
|
+
|
|
|
|
+ self._additional_metrics = additional_metrics
|
|
|
|
+ self._additional_averaging_funcs = additional_averaging_funcs
|
|
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
err = ("Failed to initialize the class. "
|
|
err = ("Failed to initialize the class. "
|
|
@@ -341,10 +344,6 @@ class PipelineSelector(ABC):
|
|
# with one validation split.
|
|
# with one validation split.
|
|
cv = CVComposer.dummy_cv()
|
|
cv = CVComposer.dummy_cv()
|
|
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-
|
|
|
|
train_inds = list(range(len(X_train)))
|
|
train_inds = list(range(len(X_train)))
|
|
val_inds = list(range(len(X_train),
|
|
val_inds = list(range(len(X_train),
|
|
len(X_train) + len(X_val)))
|
|
len(X_train) + len(X_val)))
|
|
@@ -398,12 +397,17 @@ class PipelineSelector(ABC):
|
|
try:
|
|
try:
|
|
assert(os.path.isfile(data_hdf5_store_path)),\
|
|
assert(os.path.isfile(data_hdf5_store_path)),\
|
|
"Parameter hdf5_store_path is not a file"
|
|
"Parameter hdf5_store_path is not a file"
|
|
|
|
+
|
|
|
|
+ # close all opened files, because hdf5 will
|
|
|
|
+ # fail to reopen an opened (for some reason) file
|
|
|
|
+ import tables
|
|
|
|
+ tables.file._open_files.close_all()
|
|
|
|
|
|
store = pd.HDFStore(data_hdf5_store_path)
|
|
store = pd.HDFStore(data_hdf5_store_path)
|
|
|
|
|
|
self._data_path = data_hdf5_store_path
|
|
self._data_path = data_hdf5_store_path
|
|
|
|
|
|
- data_input = {key: store["key"] if key in store else None
|
|
|
|
|
|
+ data_input = {key: store[key] if key in store else None
|
|
for key in ["X_train", "y_train", "X_val", "y_val"]}
|
|
for key in ["X_train", "y_train", "X_val", "y_val"]}
|
|
|
|
|
|
if cv_pickle_path is not None:
|
|
if cv_pickle_path is not None:
|
|
@@ -564,7 +568,7 @@ class PipelineSelector(ABC):
|
|
over the folds.
|
|
over the folds.
|
|
"""
|
|
"""
|
|
try:
|
|
try:
|
|
- scoring = {"score": make_scorer(self.cost_func)}
|
|
|
|
|
|
+ scoring = {"score": make_scorer(self._cost_func)}
|
|
|
|
|
|
scoring.update({metric_name: make_scorer(metric)
|
|
scoring.update({metric_name: make_scorer(metric)
|
|
for metric_name, metric
|
|
for metric_name, metric
|
|
@@ -575,7 +579,7 @@ class PipelineSelector(ABC):
|
|
X=self._X,
|
|
X=self._X,
|
|
y=self._y,
|
|
y=self._y,
|
|
cv=self._cv,
|
|
cv=self._cv,
|
|
- scoring=self._scoring,
|
|
|
|
|
|
+ scoring=scoring,
|
|
error_score=np.nan)
|
|
error_score=np.nan)
|
|
|
|
|
|
averaging_funcs = {
|
|
averaging_funcs = {
|