|
@@ -41,6 +41,7 @@ class GridSearchPipelineSelector(PipelineSelector):
|
|
greater_is_better: bool,
|
|
greater_is_better: bool,
|
|
trials_path: str,
|
|
trials_path: str,
|
|
backup_trials_freq: Optional[int] = None,
|
|
backup_trials_freq: Optional[int] = None,
|
|
|
|
+ cross_validation_needs_scorer: bool = True,
|
|
cross_val_averaging_func: Callable = np.mean,
|
|
cross_val_averaging_func: Callable = np.mean,
|
|
additional_metrics: Optional[Dict[str, Callable]] = None,
|
|
additional_metrics: Optional[Dict[str, Callable]] = None,
|
|
strategy_name: Optional[str] = None,
|
|
strategy_name: Optional[str] = None,
|
|
@@ -86,6 +87,8 @@ class GridSearchPipelineSelector(PipelineSelector):
|
|
greater_is_better=greater_is_better,
|
|
greater_is_better=greater_is_better,
|
|
trials_path=trials_path,
|
|
trials_path=trials_path,
|
|
backup_trials_freq=backup_trials_freq,
|
|
backup_trials_freq=backup_trials_freq,
|
|
|
|
+ cross_validation_needs_scorer=
|
|
|
|
+ cross_validation_needs_scorer,
|
|
cross_val_averaging_func=cross_val_averaging_func,
|
|
cross_val_averaging_func=cross_val_averaging_func,
|
|
additional_metrics=additional_metrics,
|
|
additional_metrics=additional_metrics,
|
|
strategy_name=strategy_name,
|
|
strategy_name=strategy_name,
|