|
@@ -12,8 +12,6 @@ Created on Tue Oct 6 15:04:25 2020
|
|
the best pipeline so far as well as the entire tuning history
|
|
the best pipeline so far as well as the entire tuning history
|
|
if needed.
|
|
if needed.
|
|
"""
|
|
"""
|
|
-from __future__ import annotations
|
|
|
|
-
|
|
|
|
import os
|
|
import os
|
|
import sys
|
|
import sys
|
|
|
|
|
|
@@ -35,14 +33,12 @@ from cdplib.pipeline_selector.PipelineSelector import PipelineSelector,\
|
|
|
|
|
|
# from typing import Callable, Optional, Literal, Dict, Union, List
|
|
# from typing import Callable, Optional, Literal, Dict, Union, List
|
|
if (sys.version_info.major == 3) & (sys.version_info.minor >= 8):
|
|
if (sys.version_info.major == 3) & (sys.version_info.minor >= 8):
|
|
- from typing import Callable, \
|
|
|
|
|
|
+ from typing import Callable, Optional,\
|
|
Literal, Dict, List, Union
|
|
Literal, Dict, List, Union
|
|
else:
|
|
else:
|
|
# from typing_extensions import *
|
|
# from typing_extensions import *
|
|
- Dict = dict
|
|
|
|
- List = list
|
|
|
|
- from typing_extensions import Callable, \
|
|
|
|
- Literal, Dict, List, Union
|
|
|
|
|
|
+ from typing_extensions import Callable, Optional,\
|
|
|
|
+ Literal, Union
|
|
|
|
|
|
from cdplib.log import Log
|
|
from cdplib.log import Log
|
|
|
|
|
|
@@ -71,7 +67,8 @@ class HyperoptPipelineSelector(PipelineSelector):
|
|
backup_trials_freq: Optional[int] = None,
|
|
backup_trials_freq: Optional[int] = None,
|
|
cross_validation_needs_scorer: bool = True,
|
|
cross_validation_needs_scorer: bool = True,
|
|
cross_val_averaging_func: Callable = np.mean,
|
|
cross_val_averaging_func: Callable = np.mean,
|
|
- additional_metrics: Optional[Dict[str, Callable]] = None,
|
|
|
|
|
|
+ # additional_metrics: Optional[Dict[str, Callable]] = None,
|
|
|
|
+ additional_metrics = None,
|
|
strategy_name: Optional[str] = None,
|
|
strategy_name: Optional[str] = None,
|
|
stdout_log_level: Literal["INFO", "WARNING", "ERROR"]
|
|
stdout_log_level: Literal["INFO", "WARNING", "ERROR"]
|
|
= "INFO"):
|
|
= "INFO"):
|
|
@@ -198,8 +195,8 @@ class HyperoptPipelineSelector(PipelineSelector):
|
|
|
|
|
|
self._logger.log_and_raise_error(err)
|
|
self._logger.log_and_raise_error(err)
|
|
|
|
|
|
- def _get_space_element_from_trial(self, trial: dict)\
|
|
|
|
- -> Union[Dict[str, SpaceElementType], None]:
|
|
|
|
|
|
+ def _get_space_element_from_trial(self, trial: dict):#\
|
|
|
|
+ # -> Union[Dict[str, SpaceElementType], None]:
|
|
"""
|
|
"""
|
|
Hyperopt trials object does not contain the space
|
|
Hyperopt trials object does not contain the space
|
|
elements that result in the corresponding trials.
|
|
elements that result in the corresponding trials.
|
|
@@ -234,8 +231,8 @@ class HyperoptPipelineSelector(PipelineSelector):
|
|
|
|
|
|
self._logger.log_and_raise_error(err)
|
|
self._logger.log_and_raise_error(err)
|
|
|
|
|
|
- def _get_space_element_from_index(self, i: int)\
|
|
|
|
- -> Union[Dict[str, SpaceElementType], None]:
|
|
|
|
|
|
+ def _get_space_element_from_index(self, i: int): #\
|
|
|
|
+ # -> Union[Dict[str, SpaceElementType], None]:
|
|
"""
|
|
"""
|
|
Gets the space element of shape
|
|
Gets the space element of shape
|
|
{"name": NAME, "params": PARAMS, "pipeline": PIPELINE}
|
|
{"name": NAME, "params": PARAMS, "pipeline": PIPELINE}
|
|
@@ -402,8 +399,8 @@ class HyperoptPipelineSelector(PipelineSelector):
|
|
|
|
|
|
self._logger.log_and_raise_error(err)
|
|
self._logger.log_and_raise_error(err)
|
|
|
|
|
|
- def get_n_best_trial_pipelines_of_each_type(self, n: int)\
|
|
|
|
- -> Union[Dict[str, List[Pipeline]], None]:
|
|
|
|
|
|
+ def get_n_best_trial_pipelines_of_each_type(self, n: int): #\
|
|
|
|
+ # -> Union[Dict[str, List[Pipeline]], None]:
|
|
"""
|
|
"""
|
|
:return: a dictiionry where keys are pipeline names,
|
|
:return: a dictiionry where keys are pipeline names,
|
|
and values are lists of best pipelines with this name
|
|
and values are lists of best pipelines with this name
|