|
@@ -33,7 +33,7 @@ if (sys.version_info.major == 3) & (sys.version_info.minor >= 8):
|
|
|
else:
|
|
|
|
|
|
print("I have python version {}.{} and will import typing_extensions".format(sys.version_info.major, sys.version_info.minor))
|
|
|
- from typing_extensions import Callable, TypedDict
|
|
|
+ from typing_extensions import TypedDict
|
|
|
|
|
|
|
|
|
import functools
|
|
@@ -79,7 +79,8 @@ class PipelineSelector(ABC):
|
|
|
trials_path: str,
|
|
|
backup_trials_freq: int = None,
|
|
|
cross_validation_needs_scorer: bool = True,
|
|
|
- cross_val_averaging_func: Callable = np.mean,
|
|
|
+
|
|
|
+ cross_val_averaging_func = np.mean,
|
|
|
|
|
|
additional_metrics = None,
|
|
|
|
|
@@ -230,7 +231,8 @@ class PipelineSelector(ABC):
|
|
|
self._logger.log_and_raise_error(err)
|
|
|
|
|
|
def configure_cross_validation(self,
|
|
|
- cross_validation: Callable,
|
|
|
+
|
|
|
+ cross_validation,
|
|
|
kwargs: dict = None) -> None:
|
|
|
"""
|
|
|
Method for attaching a custom cross-validation function
|
|
@@ -504,7 +506,8 @@ class PipelineSelector(ABC):
|
|
|
return summary
|
|
|
|
|
|
def configer_summary_saving(self,
|
|
|
- save_method: Callable
|
|
|
+
|
|
|
+ save_method
|
|
|
= functools.partial(
|
|
|
pd.DataFrame.to_excel,
|
|
|
**{"path_or_buf": "result.csv"}),
|