Selaa lähdekoodia

remove callable

ogert 3 vuotta sitten
vanhempi
commit
34a3b9089e

+ 2 - 1
cdplib/hyperopt/HyperoptPipelineSelector.py

@@ -69,7 +69,8 @@ class HyperoptPipelineSelector(PipelineSelector):
                 #  backup_trials_freq: Optional[int] = None,
                  backup_trials_freq = None,
                  cross_validation_needs_scorer: bool = True,
-                 cross_val_averaging_func: Callable = np.mean,
+                #  cross_val_averaging_func: Callable = np.mean,
+                 cross_val_averaging_func = np.mean,
                 #  additional_metrics: Optional[Dict[str, Callable]] = None,
                 additional_metrics = None,
                 #  strategy_name: Optional[str] = None,

+ 7 - 4
cdplib/pipeline_selector/PipelineSelector.py

@@ -33,7 +33,7 @@ if (sys.version_info.major == 3) & (sys.version_info.minor >= 8):
 else:
     # from typing_extensions import *
     print("I have python version {}.{} and will import typing_extensions".format(sys.version_info.major, sys.version_info.minor))
-    from typing_extensions import Callable, TypedDict
+    from typing_extensions import TypedDict
     
 
 import functools
@@ -79,7 +79,8 @@ class PipelineSelector(ABC):
                  trials_path: str,
                  backup_trials_freq: int = None,
                  cross_validation_needs_scorer: bool = True,
-                 cross_val_averaging_func: Callable = np.mean,
+                #  cross_val_averaging_func: Callable = np.mean,
+                cross_val_averaging_func = np.mean,
                 #  additional_metrics: Dict[str, Callable] = None,
                  additional_metrics = None,
                 #  additional_averaging_funcs: Dict[str, Callable] = None,
@@ -230,7 +231,8 @@ class PipelineSelector(ABC):
             self._logger.log_and_raise_error(err)
 
     def configure_cross_validation(self,
-                                   cross_validation: Callable,
+                                #    cross_validation: Callable,
+                                   cross_validation,
                                    kwargs: dict = None) -> None:
         """
         Method for attaching a custom cross-validation function
@@ -504,7 +506,8 @@ class PipelineSelector(ABC):
         return summary
 
     def configer_summary_saving(self,
-                                save_method: Callable
+                                # save_method: Callable
+                                save_method
                                 = functools.partial(
                                         pd.DataFrame.to_excel,
                                         **{"path_or_buf": "result.csv"}),