|
@@ -75,6 +75,7 @@ class PipelineSelector(ABC):
|
|
greater_is_better: bool,
|
|
greater_is_better: bool,
|
|
trials_path: str,
|
|
trials_path: str,
|
|
backup_trials_freq: int = None,
|
|
backup_trials_freq: int = None,
|
|
|
|
+ cross_validation_needs_scorer: bool = True,
|
|
cross_val_averaging_func: Callable = np.mean,
|
|
cross_val_averaging_func: Callable = np.mean,
|
|
additional_metrics: Dict[str, Callable] = None,
|
|
additional_metrics: Dict[str, Callable] = None,
|
|
additional_averaging_funcs: Dict[str, Callable] = None,
|
|
additional_averaging_funcs: Dict[str, Callable] = None,
|
|
@@ -139,6 +140,7 @@ class PipelineSelector(ABC):
|
|
self.configured_summary_saving = False
|
|
self.configured_summary_saving = False
|
|
|
|
|
|
self._cost_func = cost_func
|
|
self._cost_func = cost_func
|
|
|
|
+ self._greater_is_better = greater_is_better
|
|
# score factor is 1 when cost_func is minimized,
|
|
# score factor is 1 when cost_func is minimized,
|
|
# -1 when cost func is maximized
|
|
# -1 when cost func is maximized
|
|
self._score_factor = (not greater_is_better) - greater_is_better
|
|
self._score_factor = (not greater_is_better) - greater_is_better
|
|
@@ -160,6 +162,8 @@ class PipelineSelector(ABC):
|
|
# if cross-valition is not configured,
|
|
# if cross-valition is not configured,
|
|
# sklearn cross-validation method is taken by default
|
|
# sklearn cross-validation method is taken by default
|
|
self._cross_validation = sklearn_cross_validation
|
|
self._cross_validation = sklearn_cross_validation
|
|
|
|
+
|
|
|
|
+ self._cross_validation_needs_scorer = cross_validation_needs_scorer
|
|
|
|
|
|
# if a trials object already exists at the given path,
|
|
# if a trials object already exists at the given path,
|
|
# it is loaded and the search is continued. Else,
|
|
# it is loaded and the search is continued. Else,
|
|
@@ -589,11 +593,13 @@ class PipelineSelector(ABC):
|
|
over the folds.
|
|
over the folds.
|
|
"""
|
|
"""
|
|
try:
|
|
try:
|
|
- scoring = {"score": make_scorer(self._cost_func)}
|
|
|
|
-
|
|
|
|
- scoring.update({metric_name: make_scorer(metric)
|
|
|
|
- for metric_name, metric
|
|
|
|
- in self._additional_metrics.items()})
|
|
|
|
|
|
+
|
|
|
|
+ scoring = {"score": self._cost_func} | self._additional_metrics
|
|
|
|
+
|
|
|
|
+ if self._cross_validation_needs_scorer:
|
|
|
|
+ for metric_name, metric in scoring.itmes():
|
|
|
|
+ scoring[metric_name] = make_scorer(
|
|
|
|
+ metric, greater_is_better=self._greater_is_better)
|
|
|
|
|
|
scores = self._cross_validation(
|
|
scores = self._cross_validation(
|
|
estimator=pipeline,
|
|
estimator=pipeline,
|