tanja 3 years ago
parent
commit
27a40c41a8

+ 3 - 0
cdplib/gridsearch/GridSearchPipelineSelector.py

@@ -41,6 +41,7 @@ class GridSearchPipelineSelector(PipelineSelector):
                  greater_is_better: bool,
                  trials_path: str,
                  backup_trials_freq: Optional[int] = None,
+                 cross_validation_needs_scorer: bool = True,
                  cross_val_averaging_func: Callable = np.mean,
                  additional_metrics: Optional[Dict[str, Callable]] = None,
                  strategy_name: Optional[str] = None,
@@ -86,6 +87,8 @@ class GridSearchPipelineSelector(PipelineSelector):
                              greater_is_better=greater_is_better,
                              trials_path=trials_path,
                              backup_trials_freq=backup_trials_freq,
+                             cross_validation_needs_scorer=
+                                 cross_validation_needs_scorer,
                              cross_val_averaging_func=cross_val_averaging_func,
                              additional_metrics=additional_metrics,
                              strategy_name=strategy_name,

+ 3 - 0
cdplib/hyperopt/HyperoptPipelineSelector.py

@@ -57,6 +57,7 @@ class HyperoptPipelineSelector(PipelineSelector):
                  greater_is_better: bool,
                  trials_path: str,
                  backup_trials_freq: Optional[int] = None,
+                 cross_validation_needs_scorer: bool = True,
                  cross_val_averaging_func: Callable = np.mean,
                  additional_metrics: Optional[Dict[str, Callable]] = None,
                  strategy_name: Optional[str] = None,
@@ -102,6 +103,8 @@ class HyperoptPipelineSelector(PipelineSelector):
                              greater_is_better=greater_is_better,
                              trials_path=trials_path,
                              backup_trials_freq=backup_trials_freq,
+                             cross_validation_needs_scorer=
+                                 cross_validation_needs_scorer,
                              cross_val_averaging_func=cross_val_averaging_func,
                              additional_metrics=additional_metrics,
                              strategy_name=strategy_name,