|
@@ -37,8 +37,8 @@ if (sys.version_info.major == 3) & (sys.version_info.minor >= 8):
|
|
|
Literal, Dict, List, Union
|
|
|
else:
|
|
|
# from typing_extensions import *
|
|
|
- from typing_extensions import Callable, Optional,\
|
|
|
- Literal
|
|
|
+ # from typing_extensions import Callable, Optional,\
|
|
|
+ # Literal
|
|
|
|
|
|
from cdplib.log import Log
|
|
|
|
|
@@ -65,14 +65,17 @@ class HyperoptPipelineSelector(PipelineSelector):
|
|
|
cost_func,
|
|
|
greater_is_better: bool,
|
|
|
trials_path: str,
|
|
|
- backup_trials_freq: Optional[int] = None,
|
|
|
+ # backup_trials_freq: Optional[int] = None,
|
|
|
+ backup_trials_freq = None,
|
|
|
cross_validation_needs_scorer: bool = True,
|
|
|
cross_val_averaging_func: Callable = np.mean,
|
|
|
# additional_metrics: Optional[Dict[str, Callable]] = None,
|
|
|
additional_metrics = None,
|
|
|
- strategy_name: Optional[str] = None,
|
|
|
- stdout_log_level: Literal["INFO", "WARNING", "ERROR"]
|
|
|
- = "INFO"):
|
|
|
+ # strategy_name: Optional[str] = None,
|
|
|
+ strategy_name = None,
|
|
|
+ # stdout_log_level: Literal["INFO", "WARNING", "ERROR"]
|
|
|
+ # = "INFO")
|
|
|
+ stdout_log_level = "INFO"):
|
|
|
"""
|
|
|
param Callable cost_func: function to minimize or maximize
|
|
|
over the elements of a given (pipeline/hyperparameter) space
|
|
@@ -131,7 +134,8 @@ class HyperoptPipelineSelector(PipelineSelector):
|
|
|
|
|
|
def run_trials(self,
|
|
|
niter: int,
|
|
|
- algo: Literal[tpe.suggest, rand.suggest] = tpe.suggest)\
|
|
|
+ # algo: Literal[tpe.suggest, rand.suggest] = tpe.suggest)\
|
|
|
+ algo = tpe.suggest)\
|
|
|
-> None:
|
|
|
'''
|
|
|
Method performing the search of the best pipeline in the given space.
|